180 changes: 141 additions & 39 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -1502,13 +1502,19 @@ def __init__( # noqa C901
self.prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate(
list[PrefetchedInfo], []
)

if self.enable_raw_embedding_streaming:
self.res_params: RESParams = res_params or RESParams()
self.res_params.table_sizes = [0] + list(accumulate(rows))
res_port_from_env = os.getenv("LOCAL_RES_PORT")
self.res_params.res_server_port = (
int(res_port_from_env) if res_port_from_env else 0
)
self._res_copy_event: torch.cuda.Event = torch.cuda.Event()
self._res_require_copy: bool = True
self._res_sync_copy: bool = False
self._register_res_buffers()

# pyre-fixme[4]: Attribute must be annotated.
self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer(
self.uuid,
@@ -1520,8 +1526,110 @@ def __init__( # noqa C901
self.res_params.table_sizes,
)
logging.info(
- f"{self.uuid} raw embedding streaming enabled with {self.res_params=}"
f"{self.uuid} raw embedding streaming enabled with {self.res_params=}, {self._res_require_copy=}, {self._res_sync_copy=}"
)
else:
self._register_empty_res_buffers()

def _register_res_buffers(self) -> None:
assert (
self.enable_raw_embedding_streaming
), "Should not register res buffers when raw embedding streaming is not enabled"
cache_size = self.lxu_cache_weights.size(0)
if cache_size == 0:
self.log("Registering empty res buffers when there is no cache")
self._register_empty_res_buffers()
return
self.register_buffer(
"res_indices",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=torch.long,
),
(cache_size,),
is_host_mapped=self.uvm_host_mapped,
),
)
self.register_buffer(
"res_weights",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
# pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.zeros`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
dtype=self.lxu_cache_weights.dtype,
),
self.lxu_cache_weights.shape,
is_host_mapped=self.uvm_host_mapped,
),
)
self.register_buffer(
"res_identities",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=torch.long,
),
(cache_size, 1),
is_host_mapped=self.uvm_host_mapped,
),
)
self.register_buffer(
"res_runtime_meta",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=torch.long,
),
(cache_size, 1),
is_host_mapped=self.uvm_host_mapped,
),
)
self.register_buffer(
"res_count",
torch.ops.fbgemm.new_unified_tensor(
torch.zeros(
1,
device=self.current_device,
dtype=torch.int,
),
(1,),
is_host_mapped=self.uvm_host_mapped,
),
)
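For reference, a minimal sketch of the `new_unified_tensor` call pattern used above: the first argument is a prototype tensor that only supplies dtype and device, the second is the requested shape, and `is_host_mapped` selects a host-mapped (UVM) allocation that the CPU-side copy path can read. The sizes below are made up, and the snippet assumes an fbgemm_gpu install with a CUDA device.

```python
import torch
import fbgemm_gpu  # noqa: F401  (importing registers the torch.ops.fbgemm operators)

# Hypothetical sizes, for illustration only.
cache_size, dim = 64, 128
res_weights = torch.ops.fbgemm.new_unified_tensor(
    torch.zeros(1, device="cuda", dtype=torch.float32),  # prototype: supplies dtype/device
    [cache_size, dim],                                    # shape of the buffer to allocate
    is_host_mapped=True,                                  # host-mapped UVM allocation
)
assert list(res_weights.shape) == [cache_size, dim]
```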

def _register_empty_res_buffers(self) -> None:
"""Register empty res buffers for TorchScript compatibility when streaming is disabled or cache is empty"""
self.register_buffer(
"res_indices",
torch.zeros(0, device=self.current_device, dtype=torch.long),
)
self.register_buffer(
"res_weights",
torch.zeros(
0,
0,
device=self.current_device,
# pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.zeros`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
dtype=self.lxu_cache_weights.dtype,
),
)
self.register_buffer(
"res_identities",
torch.zeros(0, 1, device=self.current_device, dtype=torch.long),
)
self.register_buffer(
"res_runtime_meta",
torch.zeros(0, 1, device=self.current_device, dtype=torch.long),
)
self.register_buffer(
"res_count",
torch.zeros(1, device=self.current_device, dtype=torch.int),
)
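The `_register_empty_res_buffers` path above exists so that scripted modules see the same attribute set whether or not streaming is enabled. A small self-contained stand-in (an illustrative sketch, not the FBGEMM module) shows the pattern: TorchScript needs every buffer it may reference to be registered with a stable type, so the disabled branch registers zero-size placeholders rather than skipping registration.

```python
import torch


class ResBufferParity(torch.nn.Module):
    """Hypothetical stand-in for the placeholder-buffer pattern."""

    def __init__(self, cache_size: int, enable_streaming: bool) -> None:
        super().__init__()
        if enable_streaming and cache_size > 0:
            # Real, cache-sized buffer (the diff allocates these with
            # torch.ops.fbgemm.new_unified_tensor rather than torch.zeros).
            self.register_buffer("res_indices", torch.zeros(cache_size, dtype=torch.long))
        else:
            # Zero-size placeholder keeps the attribute present for scripting.
            self.register_buffer("res_indices", torch.zeros(0, dtype=torch.long))

    def forward(self) -> torch.Tensor:
        # Valid to reference regardless of which branch ran in __init__.
        return self.res_indices


# Both variants script cleanly because the attribute set is identical.
torch.jit.script(ResBufferParity(cache_size=0, enable_streaming=False))
torch.jit.script(ResBufferParity(cache_size=8, enable_streaming=True))
```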

@torch.jit.ignore
def log(self, msg: str) -> None:
@@ -4152,6 +4260,8 @@ def raw_embedding_stream(self) -> None:
with record_function(
"## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid)
):
if not self._res_sync_copy and self._res_require_copy:
self._raw_embedding_streamer.join_stream_tensor_copy_thread()
prefetched_info = self.prefetched_info_list.pop(0)
updated_locations = torch.ops.fbgemm.lxu_cache_lookup(
prefetched_info.linear_unique_cache_indices,
@@ -4160,60 +4270,52 @@
gather_cache_stats=False, # not collecting cache stats
num_uniq_cache_indices=prefetched_info.linear_unique_indices_length,
)
- updated_weights = torch.empty(
- [
- prefetched_info.linear_unique_cache_indices.size()[0],
- self.max_D_cache,
- ],
- # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]`
- dtype=self.lxu_cache_weights.dtype,
- # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]`
- device=self.lxu_cache_weights.device,
- )
# only found items in cache will be filled
torch.ops.fbgemm.masked_index_select(
- updated_weights,
self.res_weights,
updated_locations,
self.lxu_cache_weights,
prefetched_info.linear_unique_indices_length,
)
- # TODO: this statement triggers a sync
- # added here to make this diff self-contained
- # will remove in later change
- cache_hit_mask_index = (
- updated_locations.narrow(
- 0, 0, prefetched_info.linear_unique_indices_length.item()
- )
- .not_equal(-1)
- .nonzero()
- .flatten()

# fill cache miss rows with -1
linear_unique_hit_indices = torch.where(
torch.where(updated_locations != -1, True, False),
prefetched_info.linear_unique_indices,
-1,
)
# pyre-ignore[29]: `Union[...]` is not a function.
self.res_indices[: prefetched_info.linear_unique_indices.size(0)].copy_(
linear_unique_hit_indices
)
# pyre-ignore[29]: `Union[...]` is not a function.
self.res_count[:1].copy_(prefetched_info.linear_unique_indices_length)
if prefetched_info.hash_zch_identities is not None:
# pyre-ignore[29]: `Union[...]` is not a function.
self.res_identities[
: prefetched_info.hash_zch_identities.size(0)
].copy_(prefetched_info.hash_zch_identities)

self._res_copy_event.record()

# stream weights
self._raw_embedding_streamer.stream(
- prefetched_info.linear_unique_indices.index_select(
- dim=0, index=cache_hit_mask_index
- ).to(device=torch.device("cpu")),
- updated_weights.index_select(dim=0, index=cache_hit_mask_index).to(
- device=torch.device("cpu")
- ),
self.res_indices,
self.res_weights,
(
- prefetched_info.hash_zch_identities.index_select(
- dim=0, index=cache_hit_mask_index
- ).to(device=torch.device("cpu"))
self.res_identities
if prefetched_info.hash_zch_identities is not None
else None
),
(
- prefetched_info.hash_zch_runtime_meta.index_select(
- dim=0, index=cache_hit_mask_index
- ).to(device=torch.device("cpu"))
self.res_runtime_meta
if prefetched_info.hash_zch_runtime_meta is not None
else None
),
- prefetched_info.linear_unique_indices_length.to(
- device=torch.device("cpu")
- ),
- False, # require_tensor_copy
- False, # blocking_tensor_copy
self.res_count,
self._res_require_copy, # require_tensor_copy
self._res_sync_copy, # blocking_tensor_copy
self._res_copy_event.cuda_event, # event_ptr_to_wait - pass the raw CUDA event pointer
)

@staticmethod
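Taken together, the `raw_embedding_stream` changes above replace per-call `.index_select(...).to("cpu")` copies with reusable host-mapped staging buffers plus a recorded CUDA event. The sketch below is plain PyTorch with hypothetical sizes (not the FBGEMM API) and shows the underlying handoff: record an event after the asynchronous copies, let a background consumer wait on that event before reading the staged data (the role `event_ptr_to_wait` plays on the C++ side), and join the consumer before the buffers are reused (the role of `join_stream_tensor_copy_thread`).

```python
import threading

import torch


def stage_and_stream() -> None:
    """Event-guarded handoff sketch; assumes a CUDA device is available."""
    assert torch.cuda.is_available()
    rows = torch.randn(1024, 128, device="cuda")  # stand-in for cache rows
    staging = torch.empty_like(rows)              # stand-in for the res_* buffers
    copy_done = torch.cuda.Event()

    staging.copy_(rows, non_blocking=True)        # async copy on the current stream
    copy_done.record()                            # analogous to self._res_copy_event.record()

    def consumer() -> None:
        # The C++ copy thread performs the equivalent wait via cudaEventSynchronize
        # on the raw pointer passed as event_ptr_to_wait.
        copy_done.synchronize()
        print(float(staging.sum()))               # safe to read only after the wait

    worker = threading.Thread(target=consumer)
    worker.start()
    worker.join()  # analogous to join_stream_tensor_copy_thread() before reusing staging


if __name__ == "__main__":
    stage_and_stream()
```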
@@ -8,6 +8,11 @@

#pragma once
#include <ATen/ATen.h>
#ifdef USE_ROCM
#include <ATen/hip/HIPEvent.h>
#else
#include <ATen/cuda/CUDAEvent.h>
#endif
#ifdef FBGEMM_FBCODE
#include <folly/coro/Task.h>
#endif
Expand Down Expand Up @@ -77,7 +82,8 @@ class RawEmbeddingStreamer : public torch::jit::CustomClassHolder {
std::optional<at::Tensor> runtime_meta,
const at::Tensor& count,
bool require_tensor_copy,
- bool blocking_tensor_copy = true);
bool blocking_tensor_copy = true,
std::optional<int64_t> event_ptr_to_wait = std::nullopt);

#ifdef FBGEMM_FBCODE
folly::coro::Task<void> tensor_stream(
42 changes: 40 additions & 2 deletions fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp
@@ -205,7 +205,8 @@ void RawEmbeddingStreamer::stream(
std::optional<at::Tensor> runtime_meta,
const at::Tensor& count,
bool require_tensor_copy,
- bool blocking_tensor_copy) {
bool blocking_tensor_copy,
std::optional<int64_t> event_ptr_to_wait) {
if (!enable_raw_embedding_streaming_) {
return;
}
@@ -223,6 +224,17 @@
return;
}
if (blocking_tensor_copy) {
if (event_ptr_to_wait.has_value() && event_ptr_to_wait.value() != 0) {
#ifdef USE_ROCM
hipEvent_t cuda_event =
reinterpret_cast<hipEvent_t>(event_ptr_to_wait.value());
AT_CUDA_CHECK(hipEventSynchronize(cuda_event));
#else
cudaEvent_t cuda_event =
reinterpret_cast<cudaEvent_t>(event_ptr_to_wait.value());
AT_CUDA_CHECK(cudaEventSynchronize(cuda_event));
#endif
}
copy_and_enqueue_stream_tensors(
indices,
weights,
@@ -237,7 +249,33 @@
// callbacks don't need to be serialized.
// So, We need to spin up a new thread to unblock the CUDA stream, so the CUDA
// can continue executing other host callbacks, eg. get/evict.
- stream_tensor_copy_thread_ = std::make_unique<std::thread>([=, this]() {
stream_tensor_copy_thread_ = std::make_unique<std::thread>([this,
event_ptr_to_wait,
indices,
weights,
identities,
runtime_meta,
count]() {
// Set the CUDA device for this thread - this is critical!
// Without this, cudaEventSynchronize may fail with illegal memory
// access if the thread doesn't have the correct CUDA context.
if (event_ptr_to_wait.has_value() && event_ptr_to_wait.value() != 0) {
#ifdef USE_ROCM
if (weights.is_cuda()) {
AT_CUDA_CHECK(hipSetDevice(weights.device().index()));
}
hipEvent_t cuda_event =
reinterpret_cast<hipEvent_t>(event_ptr_to_wait.value());
AT_CUDA_CHECK(hipEventSynchronize(cuda_event));
#else
if (weights.is_cuda()) {
AT_CUDA_CHECK(cudaSetDevice(weights.device().index()));
}
cudaEvent_t cuda_event =
reinterpret_cast<cudaEvent_t>(event_ptr_to_wait.value());
AT_CUDA_CHECK(cudaEventSynchronize(cuda_event));
#endif
}
copy_and_enqueue_stream_tensors(
indices, weights, identities, runtime_meta, count);
});
@@ -112,6 +112,10 @@ static auto raw_embedding_streamer =
torch::arg("count"),
torch::arg("require_tensor_copy"),
torch::arg("blocking_tensor_copy"),
- });
torch::arg("event_ptr_to_wait") = std::nullopt,
})
.def(
"join_stream_tensor_copy_thread",
&fbgemm_gpu::RawEmbeddingStreamer::join_stream_tensor_copy_thread);

} // namespace
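For context on the new optional argument registered above: the binding receives `event_ptr_to_wait` as a plain int64 holding the raw CUDA/HIP event pointer, with 0 (or an absent value) meaning "do not wait", and the Python caller obtains it from `torch.cuda.Event.cuda_event` after recording. A minimal hedged sketch (requires a CUDA build of PyTorch):

```python
import torch

if torch.cuda.is_available():
    evt = torch.cuda.Event()
    evt.record()                 # ensures the underlying cudaEvent_t exists
    raw_ptr = evt.cuda_event     # integer handle passed through as event_ptr_to_wait
    assert raw_ptr != 0          # the C++ side skips the wait when the pointer is 0
```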