From a489b8de76ad1400f7dcc287230c3a376896fbf2 Mon Sep 17 00:00:00 2001
From: Zheng Qi
Date: Wed, 17 Dec 2025 10:41:20 -0800
Subject: [PATCH] Move the prefetched info to preallocated buffers

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2219

This change improves the performance of tracking the deltas in TBE, mainly by replacing the per-iteration DtoH copy ({F1984231816}) with a DtoD copy followed by an async DtoH copy under a stream callback ({F1984231839}).

To achieve this, the following is added:
- Pre-registered UVA buffers, accessible from both GPU and CPU, are reused every iteration:
  - tying the tensors' lifetime to the TBE module makes the async copy safe;
  - reusing the same buffers avoids repeated allocations.
- raw_embedding_streamer.stream() triggers a CPU thread to perform the copy asynchronously, so GPU ops do not wait on the DtoH copy.
- To keep the DtoD copy from overlapping with the DtoH copy:
  - a GPU event tracks completion of the DtoD copy, and the CPU thread waits on it before copying;
  - join_stream_tensor_copy_thread triggers a blocking wait in the next iteration, in case the CPU copy takes too long, before the pre-registered buffers are overwritten.

Differential Revision: D86888586

--- ...t_table_batched_embeddings_ops_training.py | 180 ++++++++++++++---- .../raw_embedding_streamer.h | 8 +- .../raw_embedding_streamer.cpp | 42 +++- .../split_embeddings_cache_ops.cpp | 6 +- .../tests/raw_embedding_streamer_test.cpp | 63 +++++- 5 files changed, 254 insertions(+), 45 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index d0e326b843..ad353c23ea 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -1502,6 +1502,7 @@ def __init__( # noqa C901 self.prefetched_info_list: list[PrefetchedInfo] = torch.jit.annotate( list[PrefetchedInfo], [] ) + if self.enable_raw_embedding_streaming: self.res_params: RESParams = res_params or RESParams() self.res_params.table_sizes = [0] + list(accumulate(rows)) @@ -1509,6 +1510,11 @@ def __init__( # noqa C901 self.res_params.res_server_port = ( int(res_port_from_env) if res_port_from_env else 0 ) + self._res_copy_event: torch.cuda.Event = torch.cuda.Event() + self._res_require_copy: bool = True + self._res_sync_copy: bool = False + self._register_res_buffers() + # pyre-fixme[4]: Attribute must be annotated.
self._raw_embedding_streamer = torch.classes.fbgemm.RawEmbeddingStreamer( self.uuid, @@ -1520,8 +1526,110 @@ def __init__( # noqa C901 self.res_params.table_sizes, ) logging.info( - f"{self.uuid} raw embedding streaming enabled with {self.res_params=}" + f"{self.uuid} raw embedding streaming enabled with {self.res_params=}, {self._res_require_copy=}, {self._res_sync_copy=}" ) + else: + self._register_empty_res_buffers() + + def _register_res_buffers(self) -> None: + assert ( + self.enable_raw_embedding_streaming + ), "Should not register res buffers when raw embedding streaming is not enabled" + cache_size = self.lxu_cache_weights.size(0) + if cache_size == 0: + self.log("Registering empty res buffers when there is no cache") + self._register_empty_res_buffers() + return + self.register_buffer( + "res_indices", + torch.ops.fbgemm.new_unified_tensor( + torch.zeros( + 1, + device=self.current_device, + dtype=torch.long, + ), + (cache_size,), + is_host_mapped=self.uvm_host_mapped, + ), + ) + self.register_buffer( + "res_weights", + torch.ops.fbgemm.new_unified_tensor( + torch.zeros( + 1, + device=self.current_device, + # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]` + dtype=self.lxu_cache_weights.dtype, + ), + self.lxu_cache_weights.shape, + is_host_mapped=self.uvm_host_mapped, + ), + ) + self.register_buffer( + "res_identities", + torch.ops.fbgemm.new_unified_tensor( + torch.zeros( + 1, + device=self.current_device, + dtype=torch.long, + ), + (cache_size, 1), + is_host_mapped=self.uvm_host_mapped, + ), + ) + self.register_buffer( + "res_runtime_meta", + torch.ops.fbgemm.new_unified_tensor( + torch.zeros( + 1, + device=self.current_device, + dtype=torch.long, + ), + (cache_size, 1), + is_host_mapped=self.uvm_host_mapped, + ), + ) + self.register_buffer( + "res_count", + torch.ops.fbgemm.new_unified_tensor( + torch.zeros( + 1, + device=self.current_device, + dtype=torch.int, + ), + (1,), + is_host_mapped=self.uvm_host_mapped, + ), + ) + + def _register_empty_res_buffers(self) -> None: + """Register empty res buffers for TorchScript compatibility when streaming is disabled or cache is empty""" + self.register_buffer( + "res_indices", + torch.zeros(0, device=self.current_device, dtype=torch.long), + ) + self.register_buffer( + "res_weights", + torch.zeros( + 0, + 0, + device=self.current_device, + # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.zeros`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]` + dtype=self.lxu_cache_weights.dtype, + ), + ) + self.register_buffer( + "res_identities", + torch.zeros(0, 1, device=self.current_device, dtype=torch.long), + ) + self.register_buffer( + "res_runtime_meta", + torch.zeros(0, 1, device=self.current_device, dtype=torch.long), + ) + self.register_buffer( + "res_count", + torch.zeros(1, device=self.current_device, dtype=torch.int), + ) @torch.jit.ignore def log(self, msg: str) -> None: @@ -4152,6 +4260,8 @@ def raw_embedding_stream(self) -> None: with record_function( "## uvm_lookup_prefetched_rows {} {} ##".format(self.timestep, self.uuid) ): + if not self._res_sync_copy and self._res_require_copy: + self._raw_embedding_streamer.join_stream_tensor_copy_thread() prefetched_info = self.prefetched_info_list.pop(0) updated_locations = torch.ops.fbgemm.lxu_cache_lookup( prefetched_info.linear_unique_cache_indices, @@ -4160,60 +4270,52 @@ def 
raw_embedding_stream(self) -> None: gather_cache_stats=False, # not collecting cache stats num_uniq_cache_indices=prefetched_info.linear_unique_indices_length, ) - updated_weights = torch.empty( - [ - prefetched_info.linear_unique_cache_indices.size()[0], - self.max_D_cache, - ], - # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `dtype`, expected `Optional[dtype]` but got `Union[Module, dtype, Tensor]` - dtype=self.lxu_cache_weights.dtype, - # pyre-ignore Incompatible parameter type [6]: In call `torch._C._VariableFunctions.empty`, for argument `device`, expected `Union[None, int, str, device]` but got `Union[Module, device, Tensor]` - device=self.lxu_cache_weights.device, - ) + # only found items in cache will be filled torch.ops.fbgemm.masked_index_select( - updated_weights, + self.res_weights, updated_locations, self.lxu_cache_weights, prefetched_info.linear_unique_indices_length, ) - # TODO: this statement triggers a sync - # added here to make this diff self-contained - # will remove in later change - cache_hit_mask_index = ( - updated_locations.narrow( - 0, 0, prefetched_info.linear_unique_indices_length.item() - ) - .not_equal(-1) - .nonzero() - .flatten() + + # fill cache miss rows with -1 + linear_unique_hit_indices = torch.where( + torch.where(updated_locations != -1, True, False), + prefetched_info.linear_unique_indices, + -1, + ) + # pyre-ignore[29]: `Union[...]` is not a function. + self.res_indices[: prefetched_info.linear_unique_indices.size(0)].copy_( + linear_unique_hit_indices ) + # pyre-ignore[29]: `Union[...]` is not a function. + self.res_count[:1].copy_(prefetched_info.linear_unique_indices_length) + if prefetched_info.hash_zch_identities is not None: + # pyre-ignore[29]: `Union[...]` is not a function. 
+ self.res_identities[ + : prefetched_info.hash_zch_identities.size(0) + ].copy_(prefetched_info.hash_zch_identities) + + self._res_copy_event.record() + # stream weights self._raw_embedding_streamer.stream( - prefetched_info.linear_unique_indices.index_select( - dim=0, index=cache_hit_mask_index - ).to(device=torch.device("cpu")), - updated_weights.index_select(dim=0, index=cache_hit_mask_index).to( - device=torch.device("cpu") - ), + self.res_indices, + self.res_weights, ( - prefetched_info.hash_zch_identities.index_select( - dim=0, index=cache_hit_mask_index - ).to(device=torch.device("cpu")) + self.res_identities if prefetched_info.hash_zch_identities is not None else None ), ( - prefetched_info.hash_zch_runtime_meta.index_select( - dim=0, index=cache_hit_mask_index - ).to(device=torch.device("cpu")) + self.res_runtime_meta if prefetched_info.hash_zch_runtime_meta is not None else None ), - prefetched_info.linear_unique_indices_length.to( - device=torch.device("cpu") - ), - False, # require_tensor_copy - False, # blocking_tensor_copy + self.res_count, + self._res_require_copy, # require_tensor_copy + self._res_sync_copy, # blocking_tensor_copy + self._res_copy_event.cuda_event, # event_ptr_to_wait - pass the raw CUDA event pointer ) @staticmethod diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h index 9ee379045e..336653887b 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/raw_embedding_streamer.h @@ -8,6 +8,11 @@ #pragma once #include +#ifdef USE_ROCM +#include +#else +#include +#endif #ifdef FBGEMM_FBCODE #include #endif @@ -77,7 +82,8 @@ class RawEmbeddingStreamer : public torch::jit::CustomClassHolder { std::optional runtime_meta, const at::Tensor& count, bool require_tensor_copy, - bool blocking_tensor_copy = true); + bool blocking_tensor_copy = true, + std::optional event_ptr_to_wait = std::nullopt); #ifdef FBGEMM_FBCODE folly::coro::Task tensor_stream( diff --git a/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp index e1b54d8c82..0d8b03fb94 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp @@ -205,7 +205,8 @@ void RawEmbeddingStreamer::stream( std::optional runtime_meta, const at::Tensor& count, bool require_tensor_copy, - bool blocking_tensor_copy) { + bool blocking_tensor_copy, + std::optional event_ptr_to_wait) { if (!enable_raw_embedding_streaming_) { return; } @@ -223,6 +224,17 @@ void RawEmbeddingStreamer::stream( return; } if (blocking_tensor_copy) { + if (event_ptr_to_wait.has_value() && event_ptr_to_wait.value() != 0) { +#ifdef USE_ROCM + hipEvent_t cuda_event = + reinterpret_cast(event_ptr_to_wait.value()); + AT_CUDA_CHECK(hipEventSynchronize(cuda_event)); +#else + cudaEvent_t cuda_event = + reinterpret_cast(event_ptr_to_wait.value()); + AT_CUDA_CHECK(cudaEventSynchronize(cuda_event)); +#endif + } copy_and_enqueue_stream_tensors( indices, weights, @@ -237,7 +249,33 @@ void RawEmbeddingStreamer::stream( // callbacks don't need to be serialized. // So, We need to spin up a new thread to unblock the CUDA stream, so the CUDA // can continue executing other host callbacks, eg. get/evict. 
- stream_tensor_copy_thread_ = std::make_unique([=, this]() { + stream_tensor_copy_thread_ = std::make_unique([this, + event_ptr_to_wait, + indices, + weights, + identities, + runtime_meta, + count]() { + // Set the CUDA device for this thread - this is critical! + // Without this, cudaEventSynchronize may fail with illegal memory + // access if the thread doesn't have the correct CUDA context. + if (event_ptr_to_wait.has_value() && event_ptr_to_wait.value() != 0) { +#ifdef USE_ROCM + if (weights.is_cuda()) { + AT_CUDA_CHECK(hipSetDevice(weights.device().index())); + } + hipEvent_t cuda_event = + reinterpret_cast(event_ptr_to_wait.value()); + AT_CUDA_CHECK(hipEventSynchronize(cuda_event)); +#else + if (weights.is_cuda()) { + AT_CUDA_CHECK(cudaSetDevice(weights.device().index())); + } + cudaEvent_t cuda_event = + reinterpret_cast(event_ptr_to_wait.value()); + AT_CUDA_CHECK(cudaEventSynchronize(cuda_event)); +#endif + } copy_and_enqueue_stream_tensors( indices, weights, identities, runtime_meta, count); }); diff --git a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp index 6bbd1dd7a6..b56ad90520 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/split_embeddings_cache_ops.cpp @@ -112,6 +112,10 @@ static auto raw_embedding_streamer = torch::arg("count"), torch::arg("require_tensor_copy"), torch::arg("blocking_tensor_copy"), - }); + torch::arg("event_ptr_to_wait") = std::nullopt, + }) + .def( + "join_stream_tensor_copy_thread", + &fbgemm_gpu::RawEmbeddingStreamer::join_stream_tensor_copy_thread); } // namespace diff --git a/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp b/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp index d6567c6e93..00760af6b7 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/tests/raw_embedding_streamer_test.cpp @@ -300,11 +300,11 @@ TEST(RawEmbeddingStreamerTest, TestStreamWithIdentities) { auto identities = at::tensor( {1001, 1002, 1003, 1004, 1005, 1006, 1007}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)) - .reshape({7, 1}); + .view({-1, 1}); auto runtime_meta = at::tensor( {101, 102, 103, 104, 105, 106, 107}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)) - .reshape({7, 1}); + .view({-1, 1}); auto count = at::tensor( {indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); @@ -333,4 +333,63 @@ TEST(RawEmbeddingStreamerTest, TestStreamWithIdentities) { indices, weights, identities, runtime_meta, count, true, true); EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1); } + +TEST(RawEmbeddingStreamerTest, TestStreamWithEventPtrNonBlockingCopy) { + if (!at::cuda::is_available()) { + GTEST_SKIP() << "Skipping test because CUDA/HIP is not available"; + } + + std::vector table_names = {"tb1", "tb2", "tb3"}; + std::vector table_offsets = {0, 100, 300}; + std::vector table_sizes = {0, 50, 200, 300}; + + auto streamer = getRawEmbeddingStreamer( + "test_event_nonblocking", true, table_names, table_offsets, table_sizes); + + // Create CPU tensors + auto indices = at::tensor( + {10, 2, 1, 150, 170, 230, 280}, + at::TensorOptions().device(at::kCPU).dtype(at::kLong)); + auto weights = at::randn( + {indices.size(0), EMBEDDING_DIMENSION}, + at::TensorOptions().device(at::kCPU).dtype(c10::kFloat)); + auto count = at::tensor( + 
{indices.size(0)}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); + + // Create and record a GPU event (compatible with both CUDA and HIP) +#ifdef USE_ROCM + hipEvent_t event; + AT_CUDA_CHECK(hipEventCreate(&event)); + AT_CUDA_CHECK(hipEventRecord(event, at::cuda::getCurrentHIPStream())); +#else + cudaEvent_t event; + AT_CUDA_CHECK(cudaEventCreate(&event)); + AT_CUDA_CHECK(cudaEventRecord(event, at::cuda::getCurrentCUDAStream())); +#endif + + // Stop the dequeue thread to get accurate queue size + streamer->join_weights_stream_thread(); + + // Test with valid event_ptr_to_wait in non-blocking mode + streamer->stream( + indices, + weights, + std::nullopt, + std::nullopt, + count, + true, + false, + reinterpret_cast(event)); + + // Wait for the async thread to complete + streamer->join_stream_tensor_copy_thread(); + EXPECT_EQ(streamer->get_weights_to_stream_queue_size(), 1); + + // Cleanup +#ifdef USE_ROCM + AT_CUDA_CHECK(hipEventDestroy(event)); +#else + AT_CUDA_CHECK(cudaEventDestroy(event)); +#endif +} #endif
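
Note (not part of the patch): below is a minimal sketch of the per-iteration flow this change introduces on the TBE side. The function name stream_one_iteration and its arguments are hypothetical stand-ins; streamer is assumed to be an already-constructed torch.classes.fbgemm.RawEmbeddingStreamer, and res_indices / res_weights / res_count are assumed to be the pre-registered UVA buffers created via torch.ops.fbgemm.new_unified_tensor, as in _register_res_buffers() above. Shapes are simplified; this is not the actual TBE implementation.

import torch

def stream_one_iteration(streamer, res_indices, res_weights, res_count,
                         copy_event, updated_locations, unique_indices,
                         unique_indices_length, lxu_cache_weights):
    # 1) Block until the previous iteration's async host copy has finished,
    #    so the pre-registered UVA buffers can be safely overwritten.
    streamer.join_stream_tensor_copy_thread()

    # 2) Device-to-device copy of the cache-hit rows into the reusable UVA
    #    weight buffer; no host synchronization is needed here.
    torch.ops.fbgemm.masked_index_select(
        res_weights, updated_locations, lxu_cache_weights, unique_indices_length
    )
    # Stage indices (cache-miss rows marked -1) and the count in the UVA buffers.
    # Assumes updated_locations and unique_indices have the same length.
    res_indices[: unique_indices.size(0)].copy_(
        torch.where(updated_locations != -1, unique_indices, -1)
    )
    res_count[:1].copy_(unique_indices_length)

    # 3) Record a GPU event after the D2D copy; the streamer's copy thread
    #    waits on this event so the D2H copy never overlaps the D2D copy.
    copy_event.record()

    # 4) Kick off the async D2H copy on a CPU thread; GPU ops are not blocked.
    streamer.stream(
        res_indices,
        res_weights,
        None,                   # identities
        None,                   # runtime_meta
        res_count,
        True,                   # require_tensor_copy
        False,                  # blocking_tensor_copy
        copy_event.cuda_event,  # raw CUDA event pointer to wait on
    )

The event wait plus the join_stream_tensor_copy_thread call in the next iteration are what keep the D2D and D2H copies from racing on the shared UVA buffers.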