diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h
index 675a9864e74..3fc414fb669 100644
--- a/backends/aoti/common_shims.h
+++ b/backends/aoti/common_shims.h
@@ -64,6 +64,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index ecb1ded2f39..86f6cdd6396 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -24,6 +24,7 @@ namespace executorch::backends::cuda {
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_dtype_bool;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::aoti_torch_get_sizes;
@@ -800,6 +801,126 @@ AOTITorchError aoti_torch_new_tensor_handle(
   return Error::Ok;
 }
+
+AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_value != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: ret_value is null");
+
+  // Validate that the tensor dtype is bool
+  int32_t dtype;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(tensor, &dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      dtype == aoti_torch_dtype_bool(),
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor dtype is not bool (got %d)",
+      dtype);
+
+  // Get the data pointer
+  const void* data_ptr = tensor->const_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor data pointer is null");
+
+  // Check whether the tensor lives in CUDA or CPU memory
+  cudaPointerAttributes attributes{};
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&attributes, data_ptr));
+
+  if (attributes.type == cudaMemoryTypeDevice) {
+    // CUDA memory case: copy from device to host
+    bool device_value;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+        &device_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost));
+    *ret_value = device_value;
+  } else {
+    // CPU memory case: direct access
+    const bool* bool_ptr = static_cast<const bool*>(data_ptr);
+    *ret_value = *bool_ptr;
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      src != nullptr,
+      InvalidArgument,
+      "aoti_torch_assign_tensors_out failed: src is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_dst != nullptr,
+      InvalidArgument,
+      "aoti_torch_assign_tensors_out failed: ret_dst is null");
+
+  // Get the data pointer from the source tensor
+  void* data_ptr = src->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // The memory must already be tracked in the map; otherwise return an error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference "
+      "counting system",
+      data_ptr);
+
+  // Get dtype from the source tensor
+  int32_t dtype = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype));
+
+  // Get sizes and strides from the source tensor
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(src, &sizes_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_strides(src, &strides_ptr));
+
+  int64_t ndim = src->dim();
+
+  // Convert to vectors
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create a new tensor view that shares the same memory as the source tensor
+  std::shared_ptr<Tensor> tensor = make_tensor(
+      sizes,
+      data_ptr, // Share the same memory as the source tensor
+      {}, // dim_order (empty, will be auto-generated)
+      strides,
+      dtype_to_scalar_type(dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "Failed to create tensor view in aoti_torch_assign_tensors_out");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_dst = tensor.get();
+
+  // Increment the reference count for this memory address, unless the memory
+  // is not owned by any tensor (NOT_OWN stays NOT_OWN)
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
 } // extern "C"
 } // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index 935df853748..34b781a5270 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
  * @return Error::Ok on success, appropriate error code on failure:
  *         - Error::InvalidArgument: null pointers or invalid parameters
  */
-AOTITorchError aoti_torch_new_tensor_handle(
-    Tensor* orig_handle,
-    Tensor** new_handle);
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
+
+/**
+ * Retrieves a boolean value from a 0D boolean tensor.
+ *
+ * This function extracts the scalar boolean value from a tensor that contains
+ * a single boolean element. The tensor can reside on either the CPU or a CUDA
+ * device. For CUDA tensors, the value is copied from device to host memory.
+ *
+ * @param tensor Pointer to a 0D boolean tensor (must not be null)
+ * @param ret_value Output pointer to store the boolean value (must not be null)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or tensor dtype is not bool
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_item_bool(Tensor* tensor, bool* ret_value);
+
+/**
+ * Creates a new tensor that shares the same underlying data as the source
+ * tensor.
+ *
+ * This function creates a new tensor view with the same shape, strides, and
+ * dtype as the source tensor, sharing the same underlying memory. The new
+ * tensor handle will be stored in ret_dst.
+ *
+ * @param src The source tensor providing the data and metadata.
+ * @param ret_dst On output, this will point to the new tensor view.
+ * + * @return Error::Ok on success, appropriate error code on failure: + * - Error::InvalidArgument: null pointers or memory not tracked + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst); // Function to clear all tensors from internal storage AOTI_SHIM_EXPORT void clear_all_tensors(); diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt index a7df6075c37..204c08688c4 100644 --- a/backends/cuda/runtime/shims/tests/CMakeLists.txt +++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt @@ -37,9 +37,14 @@ find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX}) # List of test files set(CUDA_SHIM_TESTS - test_aoti_torch_create_tensor_from_blob_v2 test_aoti_torch_empty_strided - test_aoti_torch_delete_tensor_object test_aoti_torch__reinterpret_tensor - test_aoti_torch_copy_ test_aoti_torch_new_tensor_handle + test_aoti_torch_create_tensor_from_blob_v2 + test_aoti_torch_empty_strided + test_aoti_torch_delete_tensor_object + test_aoti_torch__reinterpret_tensor + test_aoti_torch_copy_ + test_aoti_torch_new_tensor_handle + test_aoti_torch_item_bool + test_aoti_torch_assign_tensors_out ) enable_testing() diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index b274ecf3675..7736624c02a 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -35,3 +35,5 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_cuda_guard") cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm") cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") + cuda_shim_cpp_unittest("aoti_torch_item_bool") + cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp new file mode 100644 index 00000000000..d5e1bcb2547 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp @@ -0,0 +1,245 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_assign_tensors_out tests +class AOTITorchAssignTensorsOutTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create a test tensor + Tensor* create_test_tensor( + const std::vector& sizes, + int32_t dtype = static_cast(SupportedDTypes::FLOAT32), + int32_t device_type = static_cast(SupportedDevices::CUDA)) { + std::vector strides; + // Calculate contiguous strides + if (!sizes.empty()) { + strides.resize(sizes.size()); + strides[sizes.size() - 1] = 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + } + + Tensor* tensor; + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides_ptr, + dtype, + device_type, + 0, + &tensor); + + return (error == Error::Ok) ? 
tensor : nullptr; + } +}; + +// Test basic functionality +TEST_F(AOTITorchAssignTensorsOutTest, BasicFunctionality) { + // Create a source tensor + std::vector sizes = {2, 3}; + Tensor* src = create_test_tensor(sizes); + ASSERT_NE(src, nullptr); + + // Create output tensor handle + Tensor* dst = nullptr; + AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(dst, nullptr); + + // Verify the output tensor has the same properties as source + EXPECT_EQ(dst->dim(), src->dim()); + EXPECT_EQ(dst->size(0), src->size(0)); + EXPECT_EQ(dst->size(1), src->size(1)); + EXPECT_EQ(dst->numel(), src->numel()); + + // Verify they share the same memory + EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr()); +} + +// Test with 1D tensor +TEST_F(AOTITorchAssignTensorsOutTest, OneDimensionalTensor) { + std::vector sizes = {10}; + Tensor* src = create_test_tensor(sizes); + ASSERT_NE(src, nullptr); + + Tensor* dst = nullptr; + AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(dst, nullptr); + EXPECT_EQ(dst->dim(), 1); + EXPECT_EQ(dst->size(0), 10); + EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr()); +} + +// Test with 3D tensor +TEST_F(AOTITorchAssignTensorsOutTest, ThreeDimensionalTensor) { + std::vector sizes = {2, 3, 4}; + Tensor* src = create_test_tensor(sizes); + ASSERT_NE(src, nullptr); + + Tensor* dst = nullptr; + AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(dst, nullptr); + EXPECT_EQ(dst->dim(), 3); + EXPECT_EQ(dst->size(0), 2); + EXPECT_EQ(dst->size(1), 3); + EXPECT_EQ(dst->size(2), 4); + EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr()); +} + +// Test with scalar (0D) tensor +TEST_F(AOTITorchAssignTensorsOutTest, ScalarTensor) { + std::vector sizes = {}; + Tensor* src = create_test_tensor(sizes); + ASSERT_NE(src, nullptr); + + Tensor* dst = nullptr; + AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(dst, nullptr); + EXPECT_EQ(dst->dim(), 0); + EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr()); +} + +// Test with null source pointer +TEST_F(AOTITorchAssignTensorsOutTest, NullSourcePointer) { + Tensor* dst = nullptr; + AOTITorchError error = aoti_torch_assign_tensors_out(nullptr, &dst); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with null destination pointer +TEST_F(AOTITorchAssignTensorsOutTest, NullDestinationPointer) { + std::vector sizes = {2, 3}; + Tensor* src = create_test_tensor(sizes); + ASSERT_NE(src, nullptr); + + AOTITorchError error = aoti_torch_assign_tensors_out(src, nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test that strides are preserved +TEST_F(AOTITorchAssignTensorsOutTest, StridesPreserved) { + std::vector sizes = {2, 3}; + Tensor* src = create_test_tensor(sizes); + ASSERT_NE(src, nullptr); + + Tensor* dst = nullptr; + AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(dst, nullptr); + + // Get strides from both tensors + int64_t* src_strides; + int64_t* dst_strides; + aoti_torch_get_strides(src, &src_strides); + aoti_torch_get_strides(dst, &dst_strides); + + // Verify strides match + for (int64_t i = 0; i < src->dim(); i++) { + EXPECT_EQ(src_strides[i], dst_strides[i]); + } +} + +// Test with CPU tensor +TEST_F(AOTITorchAssignTensorsOutTest, CPUTensor) { + std::vector sizes = {2, 3}; + Tensor* src = 
create_test_tensor( + sizes, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU)); + ASSERT_NE(src, nullptr); + + Tensor* dst = nullptr; + AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(dst, nullptr); + EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr()); +} + +// Test dtype is preserved +TEST_F(AOTITorchAssignTensorsOutTest, DtypePreserved) { + // Test with different dtypes + std::vector dtypes = { + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDTypes::INT32), + static_cast(SupportedDTypes::INT64), + }; + + for (int32_t dtype : dtypes) { + cleanup_tensor_metadata(); + clear_all_tensors(); + + std::vector sizes = {2, 3}; + Tensor* src = create_test_tensor(sizes, dtype); + ASSERT_NE(src, nullptr); + + Tensor* dst = nullptr; + AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst); + + EXPECT_EQ(error, Error::Ok); + EXPECT_NE(dst, nullptr); + + // Verify dtype is preserved + int32_t src_dtype, dst_dtype; + aoti_torch_get_dtype(src, &src_dtype); + aoti_torch_get_dtype(dst, &dst_dtype); + EXPECT_EQ(src_dtype, dst_dtype) + << "Dtype mismatch for dtype code: " << dtype; + } +} diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp new file mode 100644 index 00000000000..8e6bcbbfad6 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_item_bool tests +class AOTITorchItemBoolTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create a bool tensor on CUDA with a specific value + Tensor* create_cuda_bool_tensor(bool value) { + // Create a 0D (scalar) bool tensor + std::vector sizes = {}; // 0D tensor + std::vector strides = {}; // Empty strides for scalar + Tensor* tensor; + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides.data(), + static_cast(SupportedDTypes::BOOL), + static_cast(SupportedDevices::CUDA), + 0, + &tensor); + + if (error != Error::Ok || tensor == nullptr) { + return nullptr; + } + + // Set the value + bool host_value = value; + cudaError_t cuda_err = cudaMemcpy( + tensor->mutable_data_ptr(), + &host_value, + sizeof(bool), + cudaMemcpyHostToDevice); + + if (cuda_err != cudaSuccess) { + 
aoti_torch_delete_tensor_object(tensor); + return nullptr; + } + + return tensor; + } + + // Helper to create a bool tensor on CPU with a specific value + Tensor* create_cpu_bool_tensor(bool value) { + // Create a 0D (scalar) bool tensor + std::vector sizes = {}; // 0D tensor + std::vector strides = {}; // Empty strides for scalar + Tensor* tensor; + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides.data(), + static_cast(SupportedDTypes::BOOL), + static_cast(SupportedDevices::CPU), + 0, + &tensor); + + if (error != Error::Ok || tensor == nullptr) { + return nullptr; + } + + // Set the value directly + bool* data_ptr = static_cast(tensor->mutable_data_ptr()); + *data_ptr = value; + + return tensor; + } +}; + +// Test extracting true value from CUDA bool tensor +TEST_F(AOTITorchItemBoolTest, CUDATensorTrueValue) { + Tensor* tensor = create_cuda_bool_tensor(true); + ASSERT_NE(tensor, nullptr); + + bool result = false; + AOTITorchError error = aoti_torch_item_bool(tensor, &result); + + EXPECT_EQ(error, Error::Ok); + EXPECT_TRUE(result); +} + +// Test extracting false value from CUDA bool tensor +TEST_F(AOTITorchItemBoolTest, CUDATensorFalseValue) { + Tensor* tensor = create_cuda_bool_tensor(false); + ASSERT_NE(tensor, nullptr); + + bool result = true; + AOTITorchError error = aoti_torch_item_bool(tensor, &result); + + EXPECT_EQ(error, Error::Ok); + EXPECT_FALSE(result); +} + +// Test extracting true value from CPU bool tensor +TEST_F(AOTITorchItemBoolTest, CPUTensorTrueValue) { + Tensor* tensor = create_cpu_bool_tensor(true); + ASSERT_NE(tensor, nullptr); + + bool result = false; + AOTITorchError error = aoti_torch_item_bool(tensor, &result); + + EXPECT_EQ(error, Error::Ok); + EXPECT_TRUE(result); +} + +// Test extracting false value from CPU bool tensor +TEST_F(AOTITorchItemBoolTest, CPUTensorFalseValue) { + Tensor* tensor = create_cpu_bool_tensor(false); + ASSERT_NE(tensor, nullptr); + + bool result = true; + AOTITorchError error = aoti_torch_item_bool(tensor, &result); + + EXPECT_EQ(error, Error::Ok); + EXPECT_FALSE(result); +} + +// Test with null tensor pointer +TEST_F(AOTITorchItemBoolTest, NullTensorPointer) { + bool result; + AOTITorchError error = aoti_torch_item_bool(nullptr, &result); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with null result pointer +TEST_F(AOTITorchItemBoolTest, NullResultPointer) { + Tensor* tensor = create_cuda_bool_tensor(true); + ASSERT_NE(tensor, nullptr); + + AOTITorchError error = aoti_torch_item_bool(tensor, nullptr); + EXPECT_EQ(error, Error::InvalidArgument); +} + +// Test with non-bool dtype (should fail) +TEST_F(AOTITorchItemBoolTest, NonBoolDtype) { + // Create a float tensor + std::vector sizes = {}; + std::vector strides = {}; + Tensor* tensor; + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + strides.data(), + static_cast(SupportedDTypes::FLOAT32), // Not bool + static_cast(SupportedDevices::CUDA), + 0, + &tensor); + + ASSERT_EQ(error, Error::Ok); + ASSERT_NE(tensor, nullptr); + + bool result; + error = aoti_torch_item_bool(tensor, &result); + EXPECT_EQ(error, Error::InvalidArgument); +} diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index dfa357fe356..f71f97d6d83 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -12,10 +12,16 @@ import logging +from typing import Tuple + import torch +from torch._inductor.lowering import lowerings as L, register_lowering + 
 from torch.library import impl
 
+aten = torch.ops.aten
+
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -387,3 +393,85 @@ def custom_quantized_sdpa_meta(
     )
 
     return torch.empty(query.size(), dtype=torch.float32, device="meta")
+
+
+# 1) Define the custom op in the "executorch" namespace with name "alias"
+@torch.library.custom_op("executorch::alias", mutates_args=())
+def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    # No copies, just pass-through.
+    return x, y
+
+
+# 2) FakeTensor kernel: describes the output metadata at compile time
+@custom_alias.register_fake
+def _(x, y):
+    # For this op, outputs have exactly the same shape/dtype/device as inputs.
+    # We just need *dummy* tensors with that metadata.
+    out_x = torch.empty_like(x)
+    out_y = torch.empty_like(y)
+    return out_x, out_y
+
+
+@register_lowering(torch.ops.executorch.alias.default)
+def lowering_custom_alias(x, y):
+    # x, y here are IR values (Inductor's internal representation).
+    # Alias is logically a no-op, so just pass them through.
+    return x, y
+
+
+# Expected cache shape: (B, H, S_max, D); value shape: (B, H, S, D) where S <= S_max
+def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
+    torch._assert(value.dim() == 4, "value must be 4D")
+    torch._assert(cache.dim() == 4, "cache must be 4D")
+    # Cache shape: (B, H, S_max, D)
+    # Value shape: (B, H, S, D)
+    torch._assert(
+        value.size(2) <= cache.size(2),
+        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
+    )
+    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
+    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
+    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
+    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
+
+
+# This is cheating on purpose: we deliberately do NOT mark `cache` as mutated,
+# so that this custom op can be used in HOPs such as `torch.cond`, where
+# `torch.compile` requires the branches to be free of aliasing and mutation.
+# This is fine because we only care about inference.
+@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
+def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
+    # Eager implementation
+    _validate_cross_attn_cache_params(value, cache)
+
+    # Slice the cache to match value's sequence length and copy value in.
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+    cache[:, :, : value.size(2), :].copy_(value)
+    # Return a clone of the cache to avoid aliasing with the input cache, so
+    # that we can still run the exported program.
+ return cache.clone() + + +# Register the fake (meta) kernel +@_update_cross_attn_cache.register_fake +def _update_cross_attn_cache_fake( + value: torch.Tensor, cache: torch.Tensor +) -> torch.Tensor: + _validate_cross_attn_cache_params(value, cache) + return torch.empty_like(cache) + + +# Register Inductor lowering +@register_lowering(torch.ops.executorch.update_cross_attn_cache) +def _update_cross_attn_cache_lowering(value, cache): + # cache shape: [B, H, S_max, D] + # value shape: [B, H, S, D] + + # We need to slice the cache along dim 2 (sequence length) + # slice(self, dim, start, end, step=1) + seq_len = value.get_size()[2] + cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1) + + # Copy value into the slice + L[aten.copy_.default](cache_slice, value) + + return cache diff --git a/extension/llm/custom_ops/test_update_cross_attn_cache.py b/extension/llm/custom_ops/test_update_cross_attn_cache.py new file mode 100644 index 00000000000..7b1831c9779 --- /dev/null +++ b/extension/llm/custom_ops/test_update_cross_attn_cache.py @@ -0,0 +1,279 @@ +import unittest + +import torch + +# Import the custom ops to ensure they are registered +from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 + +# Check CUDA availability once at module level +CUDA_AVAILABLE = torch.cuda.is_available() + + +class TestUpdateCrossAttnCache(unittest.TestCase): + def test_update_cross_attn_cache(self): + + # Create tensors + # Cache: [B=2, H=1, S_max=4, D=4] + cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32) + # Value: [B=2, H=1, S=2, D=4] (S < S_max) + value = torch.randn(2, 1, 2, 4, dtype=torch.float32) + + # Compile a function that uses the op + @torch.compile + def fn(v, c): + return torch.ops.executorch.update_cross_attn_cache(v, c) + + # Run it + out = fn(value, cache) + + # Check correctness + # The first 2 elements in dim 2 (sequence dim) should match value + torch.testing.assert_close( + cache[:, :, :2, :], value, msg="Cache slice not updated correctly" + ) + + # Make sure out and cache are close. In eager they are the same objects. 
+ torch.testing.assert_close( + out, cache, msg="Output and cache are different objects" + ) + + # The rest should be zeros + torch.testing.assert_close( + cache[:, :, 2:, :], + torch.zeros_like(cache[:, :, 2:, :]), + msg="Rest of cache was modified", + ) + + def test_update_cross_attn_cache_in_cond(self): + # Create tensors + + # Value: [B=2, H=1, S=2, D=4] + value = torch.randn(2, 1, 2, 4, dtype=torch.float32) + # Alternative value for false branch + value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32) + + # Define a function that uses the op inside torch.cond + def fn_with_cond(pred, v1, v2, c): + def true_fn(v1, v2, cache): + return torch.ops.executorch.update_cross_attn_cache(v1, cache) + + def false_fn(v1, v2, cache): + return torch.ops.executorch.update_cross_attn_cache(v2, cache) + + return torch.cond(pred, true_fn, false_fn, (v1, v2, c)) + + # Test with true condition + pred_true = torch.tensor(True) + cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32) + + # Compile the function + @torch.compile + def compiled_fn(pred, v1, v2, c): + return fn_with_cond(pred, v1, v2, c) + + # Run with true condition + compiled_fn(pred_true, value, value_alt, cache_true) + + # Check that the true branch was executed (value was used) + torch.testing.assert_close( + cache_true[:, :, :2, :], + value, + msg="Cache not updated correctly in true branch", + ) + + # Test with false condition + pred_false = torch.tensor(False) + cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32) + + compiled_fn(pred_false, value, value_alt, cache_false) + + # Check that the false branch was executed (value_alt was used) + torch.testing.assert_close( + cache_false[:, :, :2, :], + value_alt, + msg="Cache not updated correctly in false branch", + ) + + def test_update_cross_attn_cache_export(self): + + # Create tensors + # Cache: [B=2, H=1, S_max=4, D=4] + cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32) + # Value: [B=2, H=1, S=2, D=4] + value = torch.randn(2, 1, 2, 4, dtype=torch.float32) + # Alternative value for false branch + value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32) + + # Define a module that uses torch.cond with the op + class UpdateCacheCondModule(torch.nn.Module): + def forward(self, pred, v1, v2, c): + def true_fn(v1, v2, cache): + return torch.ops.executorch.update_cross_attn_cache(v1, cache) + + def false_fn(v1, v2, cache): + return torch.ops.executorch.update_cross_attn_cache(v2, cache) + + return torch.cond(pred, true_fn, false_fn, (v1, v2, c)) + + module = UpdateCacheCondModule() + + # Export the module with true condition + pred_true = torch.tensor(True) + exported_program = torch.export.export( + module, + (pred_true, value, value_alt, cache), + ) + + # Run the exported program with true condition + cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32) + exported_program.module()(pred_true, value, value_alt, cache_true) + + # Check that the true branch was executed (value was used) + torch.testing.assert_close( + cache_true[:, :, :2, :], + value, + msg="Cache not updated correctly in true branch after export", + ) + + # Run the exported program with false condition + pred_false = torch.tensor(False) + cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32) + exported_program.module()(pred_false, value, value_alt, cache_false) + + # Check that the false branch was executed (value_alt was used) + torch.testing.assert_close( + cache_false[:, :, :2, :], + value_alt, + msg="Cache not updated correctly in false branch after export", + ) + + def 
test_update_cross_attn_cache_different_shapes(self): + print("Testing executorch::update_cross_attn_cache with different shapes...") + + # Test with different batch sizes and sequence lengths + test_cases = [ + # (B, H, S_max, S, D) + (1, 2, 10, 5, 8), + (4, 4, 8, 3, 16), + (2, 1, 16, 10, 32), + ] + + for B, H, S_max, S, D in test_cases: + # Cache: [B, H, S_max, D], Value: [B, H, S, D] + cache = torch.zeros(B, H, S_max, D, dtype=torch.float32) + value = torch.randn(B, H, S, D, dtype=torch.float32) + + @torch.compile + def fn(v, c): + return torch.ops.executorch.update_cross_attn_cache(v, c) + + fn(value, cache) + + # Check that the first S positions in dim 2 are updated + torch.testing.assert_close( + cache[:, :, :S, :], + value, + msg=f"Failed for shape B={B}, H={H}, S_max={S_max}, S={S}, D={D}", + ) + + # Check that the rest remain zeros + if S < S_max: + torch.testing.assert_close( + cache[:, :, S:, :], + torch.zeros_like(cache[:, :, S:, :]), + msg=f"Remaining cache modified for shape B={B}, H={H}, S_max={S_max}, S={S}, D={D}", + ) + + def test_update_cross_attn_cache_full_sequence(self): + + # Cache: [B=2, H=1, S_max=4, D=4] + cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32) + # Value: [B=2, H=1, S=4, D=4] (S == S_max) + value = torch.randn(2, 1, 4, 4, dtype=torch.float32) + + @torch.compile + def fn(v, c): + return torch.ops.executorch.update_cross_attn_cache(v, c) + + fn(value, cache) + + # The entire cache should match value + torch.testing.assert_close( + cache, value, msg="Cache not fully updated when S == S_max" + ) + + @unittest.skipUnless(CUDA_AVAILABLE, "CUDA not available") + def test_alias_and_update_cross_attn_cache_with_cond_triton(self): + """Test combining alias and update_cross_attn_cache ops with torch.cond, + lowered to Triton on CUDA. 
True branch uses alias, false branch uses + update_cross_attn_cache.""" + + # Create CUDA tensors + # Value: [B=2, H=1, S=2, D=4] + value = torch.randn(2, 1, 2, 4, dtype=torch.float32, device="cuda") + # Extra tensor for alias op + extra = torch.randn(2, 1, 4, 4, dtype=torch.float32, device="cuda") + + # Define a function that uses different ops in each branch + def fn_with_cond(pred, v, extra_tensor, c): + def true_fn(v, extra_tensor, cache): + # True branch: use alias op only + aliased_cache, aliased_extra = torch.ops.executorch.alias( + cache, extra_tensor + ) + # Return sum of aliased tensors (no cache mutation) + return aliased_cache + aliased_extra + + def false_fn(v, extra_tensor, cache): + # False branch: use update_cross_attn_cache op only + updated = torch.ops.executorch.update_cross_attn_cache(v, cache) + return updated + + return torch.cond(pred, true_fn, false_fn, (v, extra_tensor, c)) + + # Compile the function with Triton backend + @torch.compile(backend="inductor") + def compiled_fn(pred, v, extra_tensor, c): + return fn_with_cond(pred, v, extra_tensor, c) + + # Test with true condition (alias branch) + pred_true = torch.tensor(True, device="cuda") + cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda") + + result_true = compiled_fn(pred_true, value, extra, cache_true) + + # Check that the true branch was executed (alias: cache + extra) + expected_true = cache_true + extra + torch.testing.assert_close( + result_true, + expected_true, + msg="Result incorrect in true branch (alias) with CUDA/Triton", + ) + + # Cache should remain unchanged in true branch (alias doesn't mutate) + torch.testing.assert_close( + cache_true, + torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda"), + msg="Cache should not be mutated in true branch (alias)", + ) + + # Test with false condition (update_cross_attn_cache branch) + pred_false = torch.tensor(False, device="cuda") + cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda") + + compiled_fn(pred_false, value, extra, cache_false) + + # Check that the false branch was executed (update_cross_attn_cache) + # The cache should be updated with value in the first S positions + torch.testing.assert_close( + cache_false[:, :, :2, :], + value, + msg="Cache not updated correctly in false branch with CUDA/Triton", + ) + + # The rest of the cache should remain zeros + torch.testing.assert_close( + cache_false[:, :, 2:, :], + torch.zeros(2, 1, 2, 4, dtype=torch.float32, device="cuda"), + msg="Rest of cache was modified in false branch", + ) diff --git a/torch_pin.py b/torch_pin.py index e934463cb70..62a2572fd78 100644 --- a/torch_pin.py +++ b/torch_pin.py @@ -1,2 +1,2 @@ -TORCH_VERSION = "2.10.0" -NIGHTLY_VERSION = "dev20251120" +TORCH_VERSION = "2.11.0" +NIGHTLY_VERSION = "dev20251222"
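
Note on intended usage: the CUDA/Triton test above only hints at how `executorch::alias` and `executorch::update_cross_attn_cache` are meant to be combined. The sketch below is a hypothetical inference-time wrapper (not part of this patch), assuming the custom ops are registered by importing `executorch.extension.llm.custom_ops.custom_ops`: on the first decode step the cross-attention KV cache is filled via `update_cross_attn_cache`, and on later steps the already-populated cache is passed through via `alias`, so both `torch.cond` branches stay free of visible mutation and aliasing. The module and tensor names here are illustrative only.

import torch

# Registers executorch::alias and executorch::update_cross_attn_cache.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


class CrossAttnKVCache(torch.nn.Module):
    """Hypothetical helper: returns a usable [B, H, S_max, D] key cache."""

    def forward(self, is_cached: torch.Tensor, new_k: torch.Tensor, k_cache: torch.Tensor):
        def cached(new_k, k_cache):
            # Cache already holds the encoder keys; pass it through via alias.
            out, _ = torch.ops.executorch.alias(k_cache, new_k)
            return out

        def not_cached(new_k, k_cache):
            # First step: write new_k into the cache and use the returned clone.
            return torch.ops.executorch.update_cross_attn_cache(new_k, k_cache)

        # Both branches look mutation- and alias-free to torch.compile/export.
        return torch.cond(is_cached, cached, not_cached, (new_k, k_cache))


if __name__ == "__main__":
    k_cache = torch.zeros(2, 1, 4, 4)
    new_k = torch.randn(2, 1, 2, 4)
    module = torch.compile(CrossAttnKVCache())
    first = module(torch.tensor(False), new_k, k_cache)  # fills the cache
    later = module(torch.tensor(True), new_k, k_cache)  # aliases the cache
    torch.testing.assert_close(first[:, :, :2, :], new_k)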