Skip to content
Open
1 change: 1 addition & 0 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
Expand Down
121 changes: 121 additions & 0 deletions backends/cuda/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace executorch::backends::cuda {

using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::aoti_torch_dtype_bool;
using executorch::backends::aoti::aoti_torch_get_device_index;
using executorch::backends::aoti::aoti_torch_get_dtype;
using executorch::backends::aoti::aoti_torch_get_sizes;
Expand Down Expand Up @@ -800,6 +801,126 @@ AOTITorchError aoti_torch_new_tensor_handle(

return Error::Ok;
}

AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
  // Extracts the scalar value from a single-element bool tensor, copying
  // device->host when the storage lives on the GPU.

  // Validate input parameters
  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor is null");

  ET_CHECK_OR_RETURN_ERROR(
      ret_value != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: ret_value is null");

  // Validate that tensor dtype is bool
  int32_t dtype;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(tensor, &dtype));

  ET_CHECK_OR_RETURN_ERROR(
      dtype == aoti_torch_dtype_bool(),
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor dtype is not bool (got %d)",
      dtype);

  // item() is only meaningful for a single-element (0D) tensor. Reject
  // anything else rather than silently reading the first element.
  ET_CHECK_OR_RETURN_ERROR(
      tensor->numel() == 1,
      InvalidArgument,
      "aoti_torch_item_bool failed: expected a single-element tensor, got %ld elements",
      static_cast<long>(tensor->numel()));

  // Get the data pointer
  const void* data_ptr = tensor->const_data_ptr();
  ET_CHECK_OR_RETURN_ERROR(
      data_ptr != nullptr,
      InvalidArgument,
      "aoti_torch_item_bool failed: tensor data pointer is null");

  // Classify the allocation so we know whether a device->host copy is needed
  cudaPointerAttributes attributes{};
  ET_CUDA_CHECK_OR_RETURN_ERROR(
      cudaPointerGetAttributes(&attributes, data_ptr));

  if (attributes.type == cudaMemoryTypeDevice) {
    // CUDA memory case: copy the single bool from device to host
    bool device_value = false;
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
        &device_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost));
    *ret_value = device_value;
  } else {
    // Host (or host-accessible managed/unregistered) memory: direct access
    *ret_value = *static_cast<const bool*>(data_ptr);
  }

  return Error::Ok;
}

AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
  // Creates a new tensor view sharing src's memory and registers it with the
  // shim's tensor/reference-counting bookkeeping.

  // Validate input parameters
  ET_CHECK_OR_RETURN_ERROR(
      src != nullptr,
      InvalidArgument,
      "aoti_torch_assign_tensors_out failed: src is null");

  ET_CHECK_OR_RETURN_ERROR(
      ret_dst != nullptr,
      InvalidArgument,
      "aoti_torch_assign_tensors_out failed: ret_dst is null");

  // Get the data pointer from the source tensor
  void* data_ptr = src->mutable_data_ptr();
  ET_CHECK_OR_RETURN_ERROR(
      data_ptr != nullptr,
      InvalidArgument,
      "Source tensor has null data pointer");

  // Check if the given memory is in the map, if not return error.
  // Keep the iterator: it is reused below to update the refcount without
  // re-hashing the key (no map mutation happens in between, so it stays
  // valid).
  auto memory_it = memory_to_n_tensor.find(data_ptr);
  ET_CHECK_OR_RETURN_ERROR(
      memory_it != memory_to_n_tensor.end(),
      InvalidArgument,
      "Memory address %p is not being tracked by reference counting system",
      data_ptr);

  // Get dtype from source tensor
  int32_t dtype = 0;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype));

  // Get sizes and strides from source tensor
  int64_t* sizes_ptr;
  int64_t* strides_ptr;
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(src, &sizes_ptr));
  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_strides(src, &strides_ptr));

  int64_t ndim = src->dim();

  // Convert to vectors (sizes are also needed to derive fallback strides)
  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
  std::vector<StridesType> strides =
      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

  // Create new tensor view that shares the same memory as source tensor
  std::shared_ptr<Tensor> tensor = make_tensor(
      sizes,
      data_ptr, // Share the same memory from source tensor
      {}, // dim_order (empty, will be auto-generated)
      strides,
      dtype_to_scalar_type(dtype));

  ET_CHECK_OR_RETURN_ERROR(
      tensor != nullptr,
      InvalidArgument,
      "Failed to create tensor view in aoti_torch_assign_tensors_out");

  // Store the tensor so it doesn't get destroyed
  tensors.insert(tensor);

  *ret_dst = tensor.get();

  // Increment the reference count for this memory address only if it is
  // owned by tensor; NOT_OWN is a sentinel and must stay unchanged.
  if (memory_it->second != NOT_OWN) {
    memory_it->second += 1;
  }

  return Error::Ok;
}
} // extern "C"

} // namespace executorch::backends::cuda
38 changes: 35 additions & 3 deletions backends/cuda/runtime/shims/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
* @return Error::Ok on success, appropriate error code on failure:
* - Error::InvalidArgument: null pointers or invalid parameters
*/
AOTITorchError aoti_torch_new_tensor_handle(
Tensor* orig_handle,
Tensor** new_handle);
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);

/**
 * Retrieves a boolean value from a 0D boolean tensor.
 *
 * This function extracts the scalar boolean value from a tensor that contains
 * a single boolean element. The tensor can be on either CPU or CUDA device.
 * For CUDA tensors, the value is copied from device to host memory.
 *
 * @note Callers must pass a single-element (0D) tensor of dtype bool;
 *       behavior for tensors with more than one element is unspecified.
 *
 * @param tensor Pointer to a 0D boolean tensor (must not be null)
 * @param ret_value Output pointer to store the boolean value (must not be null)
 *
 * @return Error::Ok on success, appropriate error code on failure:
 * - Error::InvalidArgument: null pointers or tensor dtype is not bool
 */
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_item_bool(Tensor* tensor, bool* ret_value);

/**
 * Creates a new tensor that shares the same underlying data as the source
 * tensor.
 *
 * This function creates a new tensor view with the same shape, strides, and
 * dtype as the source tensor, sharing the same underlying memory. The new
 * tensor handle will be stored in ret_dst.
 *
 * @note The source tensor's data pointer must already be tracked by the
 *       shim's reference-counting system (e.g. it was created through these
 *       shims); untracked memory is rejected with InvalidArgument.
 *
 * @param src The source tensor providing the data and metadata.
 * @param ret_dst On output, this will point to the new tensor view.
 *
 * @return Error::Ok on success, appropriate error code on failure:
 * - Error::InvalidArgument: null pointers or memory not tracked
 */
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);

// Function to clear all tensors from internal storage
AOTI_SHIM_EXPORT void clear_all_tensors();
Expand Down
11 changes: 8 additions & 3 deletions backends/cuda/runtime/shims/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,14 @@ find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX})

# List of test files
set(CUDA_SHIM_TESTS
# NOTE(review): one entry per test target — presumably each name matches a
# <name>.cpp source in this directory; confirm when adding new entries.
test_aoti_torch_create_tensor_from_blob_v2
test_aoti_torch_empty_strided
test_aoti_torch_delete_tensor_object
test_aoti_torch__reinterpret_tensor
test_aoti_torch_copy_
test_aoti_torch_new_tensor_handle
test_aoti_torch_item_bool
test_aoti_torch_assign_tensors_out
)

enable_testing()
Expand Down
2 changes: 2 additions & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_cpp_unittest("aoti_torch_item_bool")
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
Loading
Loading