Commit a437b06

Add aoti_torch_item_bool and aoti_torch_assign_tensors_out shims
Add two new shim implementations for the CUDA AOTI backend:

1. aoti_torch_item_bool: Extracts a boolean value from a 0D boolean tensor.
   Handles both CPU and CUDA tensors by using cudaPointerGetAttributes to
   determine the memory location and copying from device if needed.
2. aoti_torch_assign_tensors_out: Creates a new tensor view that shares the
   same underlying data as the source tensor. The new tensor has the same
   shape, strides, and dtype as the source.

Also adds:
- Declaration of aoti_torch_dtype_bool() in common_shims.h
- Unit tests for both new functions
- Update CMakeLists.txt with new test targets
- Update targets.bzl with new test targets

ghstack-source-id: 4aaf6d8
ghstack-comment-id: 3676249127
Pull-Request: #16345
1 parent 8f0b1c3 commit a437b06
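
For orientation, a minimal caller-side sketch of the first shim (not part of this commit; the `read_predicate` helper and the `pred` tensor are hypothetical, namespace qualifiers and using-declarations are omitted, and the signatures follow the declarations added to memory.h below):

// Hypothetical caller (not from this commit). Assumes `pred` is a 0D bool
// Tensor* created through the CUDA shim layer; it may be CPU- or
// CUDA-resident, the shim handles both cases internally.
#include <executorch/backends/cuda/runtime/shims/memory.h>

bool read_predicate(Tensor* pred) {
  bool value = false;
  AOTITorchError err = aoti_torch_item_bool(pred, &value);
  if (err != Error::Ok) {
    // Handle the error as appropriate for the caller.
    return false;
  }
  return value;
}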

File tree: 7 files changed, +615 -6 lines changed


backends/aoti/common_shims.h

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();

backends/cuda/runtime/shims/memory.cpp

Lines changed: 121 additions & 0 deletions
@@ -24,6 +24,7 @@ namespace executorch::backends::cuda {
 
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_dtype_bool;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::aoti_torch_get_sizes;
@@ -802,6 +803,126 @@ AOTITorchError aoti_torch_new_tensor_handle(
 
   return Error::Ok;
 }
+
+AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_value != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: ret_value is null");
+
+  // Validate that tensor dtype is bool
+  int32_t dtype;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(tensor, &dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      dtype == aoti_torch_dtype_bool(),
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor dtype is not bool (got %d)",
+      dtype);
+
+  // Get the data pointer
+  const void* data_ptr = tensor->const_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor data pointer is null");
+
+  // Check if tensor is on CUDA or CPU
+  cudaPointerAttributes attributes{};
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&attributes, data_ptr));
+
+  if (attributes.type == cudaMemoryTypeDevice) {
+    // CUDA memory case: copy from device to host
+    bool device_value;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+        &device_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost));
+    *ret_value = device_value;
+  } else {
+    // CPU memory case: direct access
+    const bool* bool_ptr = static_cast<const bool*>(data_ptr);
+    *ret_value = *bool_ptr;
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      src != nullptr,
+      InvalidArgument,
+      "aoti_torch_assign_tensors_out failed: src is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_dst != nullptr,
+      InvalidArgument,
+      "aoti_torch_assign_tensors_out failed: ret_dst is null");
+
+  // Get the data pointer from the source tensor
+  void* data_ptr = src->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // Check if the given memory is in the map, if not return error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Get dtype from source tensor
+  int32_t dtype = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype));
+
+  // Get sizes and strides from source tensor
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(src, &sizes_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_strides(src, &strides_ptr));
+
+  int64_t ndim = src->dim();
+
+  // Convert to vectors
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor view that shares the same memory as source tensor
+  std::shared_ptr<Tensor> tensor = make_tensor(
+      sizes,
+      data_ptr, // Share the same memory from source tensor
+      {}, // dim_order (empty, will be auto-generated)
+      strides,
+      dtype_to_scalar_type(dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "Failed to create tensor view in aoti_torch_assign_tensors_out");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_dst = tensor.get();
+
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
 } // extern "C"
 
 } // namespace executorch::backends::cuda
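
The key mechanism in aoti_torch_item_bool is the cudaPointerGetAttributes dispatch: rather than tracking where a tensor's storage lives, the shim asks the CUDA runtime and copies a single byte back to the host only when the pointer is device memory. The following self-contained sketch shows the same pattern outside ExecuTorch; it uses only the CUDA runtime API, and the read_bool_anywhere helper is made up for illustration.

// Standalone illustration of the device-vs-host dispatch used above.
// Build with nvcc; uses only the CUDA runtime API.
#include <cuda_runtime.h>
#include <cstdio>

// Reads a single bool regardless of whether data_ptr points to host or
// device memory. Mirrors the branch in aoti_torch_item_bool.
static bool read_bool_anywhere(const void* data_ptr) {
  cudaPointerAttributes attrs{};
  if (cudaPointerGetAttributes(&attrs, data_ptr) != cudaSuccess) {
    return false; // real code would propagate the error
  }
  if (attrs.type == cudaMemoryTypeDevice) {
    bool host_value = false;
    cudaMemcpy(&host_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost);
    return host_value;
  }
  // Plain (unregistered) host memory reports cudaMemoryTypeUnregistered on
  // recent CUDA versions and can be dereferenced directly.
  return *static_cast<const bool*>(data_ptr);
}

int main() {
  bool host_flag = true;
  bool* device_flag = nullptr;
  cudaMalloc(&device_flag, sizeof(bool));
  cudaMemcpy(device_flag, &host_flag, sizeof(bool), cudaMemcpyHostToDevice);
  std::printf("host: %d device: %d\n",
              read_bool_anywhere(&host_flag), read_bool_anywhere(device_flag));
  cudaFree(device_flag);
  return 0;
}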

backends/cuda/runtime/shims/memory.h

Lines changed: 35 additions & 3 deletions
@@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
  * @return Error::Ok on success, appropriate error code on failure:
  *         - Error::InvalidArgument: null pointers or invalid parameters
  */
-AOTITorchError aoti_torch_new_tensor_handle(
-    Tensor* orig_handle,
-    Tensor** new_handle);
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
+
+/**
+ * Retrieves a boolean value from a 0D boolean tensor.
+ *
+ * This function extracts the scalar boolean value from a tensor that contains
+ * a single boolean element. The tensor can be on either CPU or CUDA device.
+ * For CUDA tensors, the value is copied from device to host memory.
+ *
+ * @param tensor Pointer to a 0D boolean tensor (must not be null)
+ * @param ret_value Output pointer to store the boolean value (must not be null)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or tensor dtype is not bool
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_item_bool(Tensor* tensor, bool* ret_value);
+
+/**
+ * Creates a new tensor that shares the same underlying data as the source
+ * tensor.
+ *
+ * This function creates a new tensor view with the same shape, strides, and
+ * dtype as the source tensor, sharing the same underlying memory. The new
+ * tensor handle will be stored in ret_dst.
+ *
+ * @param src The source tensor providing the data and metadata.
+ * @param ret_dst On output, this will point to the new tensor view.
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or memory not tracked
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);
 
 // Function to clear all tensors from internal storage
 AOTI_SHIM_EXPORT void clear_all_tensors();
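
Because the tensor returned by aoti_torch_assign_tensors_out aliases the source's storage rather than copying it, a write through either handle is visible through the other. A hedged sketch of that behavior (not from this commit; assumes `src` is a CPU-resident float tensor already tracked by the shim's reference-counting map, with namespace qualifiers omitted):

#include <executorch/backends/cuda/runtime/shims/memory.h>

void demonstrate_shared_storage(Tensor* src) {
  Tensor* view = nullptr;
  if (aoti_torch_assign_tensors_out(src, &view) != Error::Ok) {
    return; // src was null or its memory is not tracked
  }
  // Same underlying buffer: a write through the view is observable via src.
  static_cast<float*>(view->mutable_data_ptr())[0] = 1.0f;
  bool aliased =
      static_cast<const float*>(src->const_data_ptr())[0] == 1.0f;
  (void)aliased; // true, since both handles share the same data pointer
}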

backends/cuda/runtime/shims/tests/CMakeLists.txt

Lines changed: 8 additions & 3 deletions
@@ -37,9 +37,14 @@ find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX})
 
 # List of test files
 set(CUDA_SHIM_TESTS
-  test_aoti_torch_create_tensor_from_blob_v2 test_aoti_torch_empty_strided
-  test_aoti_torch_delete_tensor_object test_aoti_torch__reinterpret_tensor
-  test_aoti_torch_copy_ test_aoti_torch_new_tensor_handle
+  test_aoti_torch_create_tensor_from_blob_v2
+  test_aoti_torch_empty_strided
+  test_aoti_torch_delete_tensor_object
+  test_aoti_torch__reinterpret_tensor
+  test_aoti_torch_copy_
+  test_aoti_torch_new_tensor_handle
+  test_aoti_torch_item_bool
+  test_aoti_torch_assign_tensors_out
 )
 
 enable_testing()

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -35,3 +35,5 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
     cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
     cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
+    cuda_shim_cpp_unittest("aoti_torch_item_bool")
+    cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
