Skip to content

Commit a5799b4

Browse files
authored
Add aoti_torch_item_bool and aoti_torch_assign_tensors_out shims (#16345)
Add two new shim implementations for the CUDA AOTI backend:

1. aoti_torch_item_bool: Extracts a boolean value from a 0D boolean tensor. Handles both CPU and CUDA tensors by using cudaPointerGetAttributes to determine the memory location and copying from device if needed.
2. aoti_torch_assign_tensors_out: Creates a new tensor view that shares the same underlying data as the source tensor. The new tensor has the same shape, strides, and dtype as the source.

Also adds:
- Declaration of aoti_torch_dtype_bool() in common_shims.h
- Unit tests for both new functions
- Update CMakeLists.txt with new test targets
- Update targets.bzl with new test targets
1 parent 210047b commit a5799b4

File tree

7 files changed

+615
-6
lines changed

7 files changed

+615
-6
lines changed

backends/aoti/common_shims.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
6464
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
6565
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
6666
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
67+
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
6768
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
6869
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
6970
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();

backends/cuda/runtime/shims/memory.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace executorch::backends::cuda {
2424

2525
using executorch::aten::SizesType;
2626
using executorch::aten::StridesType;
27+
using executorch::backends::aoti::aoti_torch_dtype_bool;
2728
using executorch::backends::aoti::aoti_torch_get_device_index;
2829
using executorch::backends::aoti::aoti_torch_get_dtype;
2930
using executorch::backends::aoti::aoti_torch_get_sizes;
@@ -800,6 +801,126 @@ AOTITorchError aoti_torch_new_tensor_handle(
800801

801802
return Error::Ok;
802803
}
804+
805+
// Extracts the scalar boolean value from a single-element bool tensor.
// Works for both CPU- and CUDA-resident tensors: the memory space of the
// data pointer is probed with cudaPointerGetAttributes, and device memory
// is copied to the host before being read.
//
// @param tensor    Single-element (0-D) bool tensor; must not be null.
// @param ret_value Output location for the extracted value; must not be null.
// @return Error::Ok on success; Error::InvalidArgument for null arguments,
//         a non-bool dtype, a multi-element tensor, or a null data pointer.
AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
  // Validate input parameters
  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor is null");

  ET_CHECK_OR_RETURN_ERROR(
      ret_value != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: ret_value is null");

  // Validate that tensor dtype is bool
  int32_t dtype;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(tensor, &dtype));

  ET_CHECK_OR_RETURN_ERROR(
      dtype == aoti_torch_dtype_bool(),
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor dtype is not bool (got %d)",
      dtype);

  // item() is only meaningful for a single-element (0-D) tensor; without this
  // check a larger tensor would silently yield only its first element.
  ET_CHECK_OR_RETURN_ERROR(
      tensor->numel() == 1,
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor must contain exactly one element");

  // Get the data pointer
  const void* data_ptr = tensor->const_data_ptr();
  ET_CHECK_OR_RETURN_ERROR(
      data_ptr != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor data pointer is null");

  // Check if tensor is on CUDA or CPU
  cudaPointerAttributes attributes{};
  ET_CUDA_CHECK_OR_RETURN_ERROR(
      cudaPointerGetAttributes(&attributes, data_ptr));

  if (attributes.type == cudaMemoryTypeDevice) {
    // CUDA memory case: copy from device to host
    bool device_value;
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
        &device_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost));
    *ret_value = device_value;
  } else {
    // CPU memory case: direct access
    const bool* bool_ptr = static_cast<const bool*>(data_ptr);
    *ret_value = *bool_ptr;
  }

  return Error::Ok;
}
853+
854+
// Creates a new tensor view that aliases the source tensor's storage.
// The view copies the source's sizes, strides, and dtype, shares its data
// pointer, and is registered in the shim's tensor/reference-counting
// bookkeeping so its lifetime is tracked like any other shim tensor.
//
// @param src     Source tensor supplying data and metadata; must not be null
//                and its data pointer must already be tracked by the
//                reference-counting map.
// @param ret_dst Output location for the new tensor handle; must not be null.
// @return Error::Ok on success; Error::InvalidArgument for null arguments,
//         untracked memory, or a failed view construction.
AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
  // Validate input parameters
  ET_CHECK_OR_RETURN_ERROR(
      src != nullptr,
      InvalidArgument,
      "aoti_torch_assign_tensors_out failed: src is null");

  ET_CHECK_OR_RETURN_ERROR(
      ret_dst != nullptr,
      InvalidArgument,
      "aoti_torch_assign_tensors_out failed: ret_dst is null");

  // Get the data pointer from the source tensor
  void* data_ptr = src->mutable_data_ptr();
  ET_CHECK_OR_RETURN_ERROR(
      data_ptr != nullptr,
      InvalidArgument,
      "Source tensor has null data pointer");

  // Check if the given memory is in the map, if not return error
  auto memory_it = memory_to_n_tensor.find(data_ptr);
  ET_CHECK_OR_RETURN_ERROR(
      memory_it != memory_to_n_tensor.end(),
      InvalidArgument,
      "Memory address %p is not being tracked by reference counting system",
      data_ptr);

  // Get dtype from source tensor
  int32_t dtype = 0;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype));

  // Get sizes and strides from source tensor
  int64_t* sizes_ptr;
  int64_t* strides_ptr;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(src, &sizes_ptr));
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_strides(src, &strides_ptr));

  int64_t ndim = src->dim();

  // Convert to vectors
  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
  std::vector<StridesType> strides =
      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

  // Create new tensor view that shares the same memory as source tensor
  std::shared_ptr<Tensor> tensor = make_tensor(
      sizes,
      data_ptr, // Share the same memory from source tensor
      {}, // dim_order (empty, will be auto-generated)
      strides,
      dtype_to_scalar_type(dtype));

  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr,
      InvalidArgument,
      "Failed to create tensor view in aoti_torch_assign_tensors_out");

  // Store the tensor so it doesn't get destroyed
  tensors.insert(tensor);

  *ret_dst = tensor.get();

  // Increment the reference count for this memory address only if it is
  // owned by a tensor; NOT_OWN memory stays NOT_OWN. Reuse the iterator from
  // the find() above instead of three redundant operator[] lookups.
  if (memory_it->second != NOT_OWN) {
    memory_it->second += 1;
  }

  return Error::Ok;
}
803924
} // extern "C"
804925

805926
} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/memory.h

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
161161
* @return Error::Ok on success, appropriate error code on failure:
162162
* - Error::InvalidArgument: null pointers or invalid parameters
163163
*/
164-
AOTITorchError aoti_torch_new_tensor_handle(
165-
Tensor* orig_handle,
166-
Tensor** new_handle);
164+
AOTI_SHIM_EXPORT AOTITorchError
165+
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
166+
167+
/**
168+
* Retrieves a boolean value from a 0D boolean tensor.
169+
*
170+
* This function extracts the scalar boolean value from a tensor that contains
171+
* a single boolean element. The tensor can be on either CPU or CUDA device.
172+
* For CUDA tensors, the value is copied from device to host memory.
173+
*
174+
* @param tensor Pointer to a 0D boolean tensor (must not be null)
175+
* @param ret_value Output pointer to store the boolean value (must not be null)
176+
*
177+
* @return Error::Ok on success, appropriate error code on failure:
178+
* - Error::InvalidArgument: null pointers or tensor dtype is not bool
179+
*/
180+
AOTI_SHIM_EXPORT AOTITorchError
181+
aoti_torch_item_bool(Tensor* tensor, bool* ret_value);
182+
183+
/**
184+
* Creates a new tensor that shares the same underlying data as the source
185+
* tensor.
186+
*
187+
* This function creates a new tensor view with the same shape, strides, and
188+
* dtype as the source tensor, sharing the same underlying memory. The new
189+
* tensor handle will be stored in ret_dst.
190+
*
191+
* @param src The source tensor providing the data and metadata.
192+
* @param ret_dst On output, this will point to the new tensor view.
193+
*
194+
* @return Error::Ok on success, appropriate error code on failure:
195+
* - Error::InvalidArgument: null pointers or memory not tracked
196+
*/
197+
AOTI_SHIM_EXPORT AOTITorchError
198+
aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);
167199

168200
// Function to clear all tensors from internal storage
169201
AOTI_SHIM_EXPORT void clear_all_tensors();

backends/cuda/runtime/shims/tests/CMakeLists.txt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,14 @@ find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX})
3737

3838
# List of test files
3939
set(CUDA_SHIM_TESTS
40-
test_aoti_torch_create_tensor_from_blob_v2 test_aoti_torch_empty_strided
41-
test_aoti_torch_delete_tensor_object test_aoti_torch__reinterpret_tensor
42-
test_aoti_torch_copy_ test_aoti_torch_new_tensor_handle
40+
test_aoti_torch_create_tensor_from_blob_v2
41+
test_aoti_torch_empty_strided
42+
test_aoti_torch_delete_tensor_object
43+
test_aoti_torch__reinterpret_tensor
44+
test_aoti_torch_copy_
45+
test_aoti_torch_new_tensor_handle
46+
test_aoti_torch_item_bool
47+
test_aoti_torch_assign_tensors_out
4348
)
4449

4550
enable_testing()

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,5 @@ def define_common_targets():
3535
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
3636
cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
3737
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
38+
cuda_shim_cpp_unittest("aoti_torch_item_bool")
39+
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")

0 commit comments

Comments
 (0)