From 63a2766fb8aab7219938408047a9731049f00311 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 19 Dec 2025 11:21:00 -0800
Subject: [PATCH 1/8] Update

[ghstack-poisoned]
---
 Makefile                                  | 36 +++++++++++++++--------
 examples/models/whisper/CMakePresets.json | 34 +++++++++++++++++++++
 2 files changed, 57 insertions(+), 13 deletions(-)

diff --git a/Makefile b/Makefile
index 13fc941e135..0dcfb7a5ee3 100644
--- a/Makefile
+++ b/Makefile
@@ -87,21 +87,22 @@
 #
 # ==============================================================================

-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cpu whisper-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
+.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help

 help:
-	@echo "This Makefile adds targets to build runners for various models on various backends. Run using `make <target>`. Available targets:"
-	@echo "  voxtral-cuda  - Build Voxtral runner with CUDA backend"
-	@echo "  voxtral-cpu   - Build Voxtral runner with CPU backend"
-	@echo "  voxtral-metal - Build Voxtral runner with Metal backend (macOS only)"
-	@echo "  whisper-cuda  - Build Whisper runner with CUDA backend"
-	@echo "  whisper-cpu   - Build Whisper runner with CPU backend"
-	@echo "  whisper-metal - Build Whisper runner with Metal backend (macOS only)"
-	@echo "  llama-cpu     - Build Llama runner with CPU backend"
-	@echo "  llava-cpu     - Build Llava runner with CPU backend"
-	@echo "  gemma3-cuda   - Build Gemma3 runner with CUDA backend"
-	@echo "  gemma3-cpu    - Build Gemma3 runner with CPU backend"
-	@echo "  clean         - Clean build artifacts"
+	@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
+	@echo "  voxtral-cuda       - Build Voxtral runner with CUDA backend"
+	@echo "  voxtral-cpu        - Build Voxtral runner with CPU backend"
+	@echo "  voxtral-metal      - Build Voxtral runner with Metal backend (macOS only)"
+	@echo "  whisper-cuda       - Build Whisper runner with CUDA backend"
+	@echo "  whisper-cuda-debug - Build Whisper runner with CUDA backend (debug mode)"
+	@echo "  whisper-cpu        - Build Whisper runner with CPU backend"
+	@echo "  whisper-metal      - Build Whisper runner with Metal backend (macOS only)"
+	@echo "  llama-cpu          - Build Llama runner with CPU backend"
+	@echo "  llava-cpu          - Build Llava runner with CPU backend"
+	@echo "  gemma3-cuda        - Build Gemma3 runner with CUDA backend"
+	@echo "  gemma3-cpu         - Build Gemma3 runner with CPU backend"
+	@echo "  clean              - Clean build artifacts"

 voxtral-cuda:
 	@echo "==> Building and installing ExecuTorch with CUDA..."
@@ -139,6 +140,15 @@ whisper-cuda:
 	@echo "✓ Build complete!"
 	@echo "   Binary: cmake-out/examples/models/whisper/whisper_runner"

+whisper-cuda-debug:
+	@echo "==> Building and installing ExecuTorch with CUDA (debug mode)..."
+	cmake --workflow --preset llm-debug-cuda
+	@echo "==> Building Whisper runner with CUDA (debug mode)..."
+	cd examples/models/whisper && cmake --workflow --preset whisper-cuda-debug
+	@echo ""
+	@echo "✓ Build complete!"
+	@echo "   Binary: cmake-out/examples/models/whisper/whisper_runner"
+
 whisper-cpu:
 	@echo "==> Building and installing ExecuTorch..."
 	cmake --workflow --preset llm-release

diff --git a/examples/models/whisper/CMakePresets.json b/examples/models/whisper/CMakePresets.json
index 1419788cb0c..a5bfc2be5a4 100644
--- a/examples/models/whisper/CMakePresets.json
+++ b/examples/models/whisper/CMakePresets.json
@@ -29,6 +29,20 @@
       "list": ["Linux", "Windows"]
     }
   },
+  {
+    "name": "whisper-cuda-debug",
+    "displayName": "Whisper runner (CUDA Debug)",
+    "inherits": ["whisper-base"],
+    "cacheVariables": {
+      "EXECUTORCH_BUILD_CUDA": "ON",
+      "CMAKE_BUILD_TYPE": "Debug"
+    },
+    "condition": {
+      "type": "inList",
+      "string": "${hostSystemName}",
+      "list": ["Linux", "Windows"]
+    }
+  },
   {
     "name": "whisper-metal",
     "displayName": "Whisper runner (Metal)",
@@ -56,6 +70,12 @@
       "configurePreset": "whisper-cuda",
       "targets": ["whisper_runner"]
     },
+    {
+      "name": "whisper-cuda-debug",
+      "displayName": "Build Whisper runner (CUDA Debug)",
+      "configurePreset": "whisper-cuda-debug",
+      "targets": ["whisper_runner"]
+    },
     {
       "name": "whisper-metal",
       "displayName": "Build Whisper runner (Metal)",
@@ -92,6 +112,20 @@
       }
     ]
   },
+  {
+    "name": "whisper-cuda-debug",
+    "displayName": "Configure and build Whisper runner (CUDA Debug)",
+    "steps": [
+      {
+        "type": "configure",
+        "name": "whisper-cuda-debug"
+      },
+      {
+        "type": "build",
+        "name": "whisper-cuda-debug"
+      }
+    ]
+  },
   {
     "name": "whisper-metal",
     "displayName": "Configure and build Whisper runner (Metal)",

From f02dbe10d2d2fcf5f52e79660fc0819087e47e4c Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 19 Dec 2025 11:21:04 -0800
Subject: [PATCH 2/8] Update

[ghstack-poisoned]
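Make `aoti_torch_create_tensor_from_blob_v2` verify that the incoming `data`
pointer actually resides on the requested `device_type`, and add a standalone
googletest suite for the CUDA shims plus a CI job that runs it. As a rough
illustration of the new failure mode, modeled on the `DeviceMismatch*` tests
in this patch (the include path and standalone `main` harness are illustrative
assumptions, not code from this stack):

```cpp
// Illustrative sketch only: mirrors DeviceMismatchCPUDataCUDADevice below.
// The include path is an assumption based on this repo's layout.
#include <vector>

#include <executorch/backends/cuda/runtime/shims/memory.h>

using namespace executorch::backends::aoti;
using namespace executorch::backends::cuda;
using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

int main() {
  std::vector<int64_t> sizes = {2, 3};
  std::vector<int64_t> strides = {3, 1}; // contiguous
  std::vector<float> host_buf(6);        // plain CPU memory

  Tensor* tensor = nullptr;
  AOTITorchError err = aoti_torch_create_tensor_from_blob_v2(
      host_buf.data(),
      sizes.size(),
      sizes.data(),
      strides.data(),
      /*storage_offset=*/0,
      static_cast<int32_t>(SupportedDTypes::FLOAT32),
      static_cast<int32_t>(SupportedDevices::CUDA), // mismatch on purpose
      /*device_index=*/0,
      &tensor,
      /*layout=*/0,
      /*opaque_metadata=*/nullptr,
      /*opaque_metadata_size=*/0);

  // Before this change the call silently succeeded; it now reports the
  // host/device mismatch instead of handing back an unusable tensor.
  return err == Error::InvalidArgument ? 0 : 1;
}
```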
---
 .github/workflows/cuda.yml                          |  26 ++++
 backends/cuda/runtime/shims/memory.cpp              |  30 +++-
 .../cuda/runtime/shims/tests/CMakeLists.txt         |  60 ++++++++
 .../runtime/shims/tests/CMakePresets.json           |  88 ++++++++++++
 backends/cuda/runtime/shims/tests/README.md         | 132 ++++++++++++++++++
 ..._aoti_torch_create_tensor_from_blob_v2.cpp       |  65 +++++++++
 6 files changed, 399 insertions(+), 2 deletions(-)
 create mode 100644 backends/cuda/runtime/shims/tests/CMakeLists.txt
 create mode 100644 backends/cuda/runtime/shims/tests/CMakePresets.json
 create mode 100644 backends/cuda/runtime/shims/tests/README.md

diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml
index 265bc252ee8..904efb4e52f 100644
--- a/.github/workflows/cuda.yml
+++ b/.github/workflows/cuda.yml
@@ -87,6 +87,32 @@ jobs:
         export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
         PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda

+  test-cuda-shims:
+    name: test-cuda-shims
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
+    with:
+      timeout: 90
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: 12.6
+      use-custom-docker-registry: false
+      submodules: recursive
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      script: |
+        set -eux
+
+        # Build ExecuTorch with CUDA support
+        ./install_executorch.sh
+
+        # Build and run CUDA shim tests
+        pushd backends/cuda/runtime/shims/tests
+        cmake --workflow --preset default
+        ctest --preset default --output-on-failure
+        popd
+
   export-model-cuda-artifact:
     name: export-model-cuda-artifact
     # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)

diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index aaaf3913381..b6310b0ee5a 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -119,8 +119,6 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
     int32_t layout,
     const uint8_t* opaque_metadata,
     int64_t opaque_metadata_size) {
-  // TODO(gasoonjia): verify given data is on the target device
-  (void)device_type;
   (void)opaque_metadata;
   (void)layout;
   (void)opaque_metadata_size;
@@ -154,6 +152,34 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
   // Storage offset must be 0 since from_blob cannot handle different offsets
   ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset));

+  // Verify that data pointer location matches the requested device_type
+  cudaPointerAttributes data_attributes{};
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&data_attributes, data));
+
+  bool data_is_on_device = data_attributes.type == cudaMemoryTypeDevice;
+  bool data_is_on_host = data_attributes.type == cudaMemoryTypeHost ||
+      data_attributes.type == cudaMemoryTypeUnregistered;
+  bool requested_device =
+      device_type == static_cast<int32_t>(SupportedDevices::CUDA);
+  bool requested_cpu =
+      device_type == static_cast<int32_t>(SupportedDevices::CPU);
+
+  // Error if data location doesn't match requested device type
+  ET_CHECK_OR_RETURN_ERROR(
+      !(data_is_on_device && requested_cpu),
+      InvalidArgument,
+      "aoti_torch_create_tensor_from_blob_v2 failed: data pointer %p is on CUDA "
+      "but device_type is CPU. Data must be on CPU for CPU tensors.",
+      data);
+
+  ET_CHECK_OR_RETURN_ERROR(
+      !(data_is_on_host && requested_device),
+      InvalidArgument,
+      "aoti_torch_create_tensor_from_blob_v2 failed: data pointer %p is on CPU "
+      "but device_type is CUDA. Data must be on GPU for CUDA tensors.",
+      data);
+
   // Convert sizes to the format expected by ExecuTorch using SizesType
   std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);

diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt
new file mode 100644
index 00000000000..f626e067627
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.19)
+project(aoti_cuda_shim_tests LANGUAGES CXX CUDA)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# Find required packages
+find_package(CUDAToolkit REQUIRED)
+find_package(GTest REQUIRED)
+
+# Get EXECUTORCH_ROOT
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)
+endif()
+
+# Find installed ExecuTorch
+find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX})
+
+# List of test files
+set(CUDA_SHIM_TESTS
+    test_aoti_torch_create_tensor_from_blob_v2
+    test_aoti_torch_empty_strided
+    test_aoti_torch_delete_tensor_object
+    test_aoti_torch__reinterpret_tensor
+    test_aoti_torch_copy_
+    test_aoti_torch_new_tensor_handle
+)
+
+enable_testing()
+
+foreach(test_name ${CUDA_SHIM_TESTS})
+  add_executable(${test_name} ${test_name}.cpp)
+
+  target_include_directories(
+    ${test_name}
+    PRIVATE ${EXECUTORCH_ROOT}
+            ${CUDAToolkit_INCLUDE_DIRS}
+  )
+
+  target_link_libraries(
+    ${test_name}
+    PRIVATE GTest::gtest
+            GTest::gtest_main
+            aoti_cuda_shims
+            aoti_cuda_backend
+            cuda_tensor_maker
+            cuda_platform
+            executorch_core
+            extension_tensor
+            CUDA::cudart
+  )
+
+  add_test(NAME ${test_name} COMMAND ${test_name})
+endforeach()

diff --git a/backends/cuda/runtime/shims/tests/CMakePresets.json b/backends/cuda/runtime/shims/tests/CMakePresets.json
new file mode 100644
index 00000000000..f4e4ee7253e
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/CMakePresets.json
@@ -0,0 +1,88 @@
+{
+  "version": 6,
+  "configurePresets": [
+    {
+      "name": "default",
+      "displayName": "CUDA Shim Tests",
+      "binaryDir": "${sourceDir}/../../../../../cmake-out/backends/cuda/runtime/shims/tests",
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Release",
+        "CMAKE_PREFIX_PATH": "${sourceDir}/../../../../../cmake-out"
+      },
+      "condition": {
+        "type": "inList",
+        "string": "${hostSystemName}",
+        "list": ["Linux", "Windows"]
+      }
+    },
+    {
+      "name": "debug",
+      "displayName": "CUDA Shim Tests (Debug)",
+      "inherits": ["default"],
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Debug"
+      }
+    }
+  ],
+  "buildPresets": [
+    {
+      "name": "default",
+      "displayName": "Build CUDA Shim Tests",
+      "configurePreset": "default"
+    },
+    {
+      "name": "debug",
+      "displayName": "Build CUDA Shim Tests (Debug)",
+      "configurePreset": "debug"
+    }
+  ],
+  "workflowPresets": [
+    {
+      "name": "default",
+      "displayName": "Configure and build CUDA Shim Tests",
+      "steps": [
+        {
+          "type": "configure",
+          "name": "default"
+        },
+        {
+          "type": "build",
+          "name": "default"
+        }
+      ]
+    },
+    {
+      "name": "debug",
+      "displayName": "Configure and build CUDA Shim Tests (Debug)",
+      "steps": [
+        {
+          "type": "configure",
+          "name": "debug"
+        },
+        {
+          "type": "build",
+          "name": "debug"
+        }
+      ]
+    }
+  ],
+  "testPresets": [
+    {
+      "name": "default",
+      "displayName": "Run all CUDA Shim Tests",
+      "configurePreset": "default",
+      "output": {
+        "outputOnFailure": true
+      }
+    },
+    {
+      "name": "debug",
+      "displayName": "Run all CUDA Shim Tests (Debug)",
+      "configurePreset": "debug",
+      "output": {
+        "outputOnFailure": true
+      }
+    }
+  ]
+}

diff --git a/backends/cuda/runtime/shims/tests/README.md b/backends/cuda/runtime/shims/tests/README.md
new file mode 100644
index 00000000000..2c4b91570b0
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/README.md
@@ -0,0 +1,132 @@
+# CUDA AOTI Shim Tests
+
+Unit tests for the CUDA AOTI (Ahead-Of-Time Inductor) shim functions used by the ExecuTorch CUDA backend.
+
+## Prerequisites
+
+1. **CUDA Toolkit**: Ensure CUDA is installed and available
+2. **ExecuTorch with CUDA**: Build and install ExecuTorch with CUDA support first
+
+## Building ExecuTorch with CUDA
+
+From the ExecuTorch root directory:
+
+```bash
+# Release build
+cmake --workflow --preset llm-release-cuda
+
+# Or debug build (recommended for debugging test failures)
+cmake --workflow --preset llm-debug-cuda
+```
+
+## Building the Tests
+
+### Option 1: Using CMake Presets (Recommended)
+
+From this directory (`backends/cuda/runtime/shims/tests/`):
+
+```bash
+# Release build
+cmake --workflow --preset default
+
+# Debug build
+cmake --workflow --preset debug
+```
+
+### Option 2: Manual CMake Commands
+
+From the ExecuTorch root directory:
+
+```bash
+# Configure
+cmake -B cmake-out/backends/cuda/runtime/shims/tests \
+  -S backends/cuda/runtime/shims/tests \
+  -DCMAKE_PREFIX_PATH=$(pwd)/cmake-out \
+  -DCMAKE_BUILD_TYPE=Debug
+
+# Build
+cmake --build cmake-out/backends/cuda/runtime/shims/tests -j$(nproc)
+```
+
+## Running the Tests
+
+### Run All Tests
+
+```bash
+# Using ctest (from the build directory)
+cd cmake-out/backends/cuda/runtime/shims/tests
+ctest --output-on-failure
+
+# Or using the test preset (from this directory)
+ctest --preset default
+```
+
+### Run a Specific Test
+
+```bash
+# From the build directory
+./test_aoti_torch_create_tensor_from_blob_v2
+./test_aoti_torch_empty_strided
+./test_aoti_torch_delete_tensor_object
+./test_aoti_torch_copy_
+./test_aoti_torch_new_tensor_handle
+./test_aoti_torch_item_bool
+./test_aoti_torch_assign_tensors_out
+```
+
+### Run Specific Test Cases
+
+Use Google Test filters to run specific test cases:
+
+```bash
+# Run only device mismatch tests
+./test_aoti_torch_create_tensor_from_blob_v2 --gtest_filter="*DeviceMismatch*"
+
+# Run a single test
+./test_aoti_torch_create_tensor_from_blob_v2 --gtest_filter="AOTITorchCreateTensorFromBlobV2Test.BasicFunctionalityCUDA"
+
+# List all available tests
+./test_aoti_torch_create_tensor_from_blob_v2 --gtest_list_tests
+```
+
+## Test Descriptions
+
+| Test File | Description |
+|-----------|-------------|
+| `test_aoti_torch_create_tensor_from_blob_v2` | Tests tensor creation from existing memory blobs, including device type validation |
+| `test_aoti_torch_empty_strided` | Tests creation of uninitialized tensors with specified strides |
+| `test_aoti_torch_delete_tensor_object` | Tests proper tensor deletion and memory management |
+| `test_aoti_torch__reinterpret_tensor` | Tests tensor view reinterpretation with different shapes/strides |
+| `test_aoti_torch_copy_` | Tests tensor copy operations between CPU and CUDA |
+| `test_aoti_torch_new_tensor_handle` | Tests creating new tensor handles that share memory |
+| `test_aoti_torch_item_bool` | Tests extracting boolean values from scalar tensors |
+| `test_aoti_torch_assign_tensors_out` | Tests creating tensor views that share underlying data |
+
+## Troubleshooting
+
+### CUDA Not Available
+
+If tests are skipped with "CUDA not available", ensure:
+- CUDA drivers are installed
+- A CUDA-capable GPU is present
+- `nvidia-smi` shows the GPU
+
+### Link Errors
+
+If you get link errors, ensure ExecuTorch was built with CUDA support:
+```bash
+cmake --workflow --preset llm-release-cuda
+```
+
+### Test Failures
+
+For debugging test failures, build with debug mode:
+```bash
+cmake --workflow --preset debug
+```
+
+Then run with verbose output:
+```bash
+./test_aoti_torch_create_tensor_from_blob_v2 --gtest_break_on_failure
+```

diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
index d9b785a5a78..db0ab84970d 100644
--- a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
@@ -752,3 +752,68 @@ TEST_F(AOTITorchCreateTensorFromBlobV2Test, StressTestManySmallTensors) {
     EXPECT_EQ(error, Error::Ok);
   }
 }
+
+// Test device type mismatch: CPU data with CUDA device request should fail
+TEST_F(AOTITorchCreateTensorFromBlobV2Test, DeviceMismatchCPUDataCUDADevice) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculate_contiguous_strides(sizes);
+
+  // Allocate CPU memory
+  size_t bytes = calculate_numel(sizes) * sizeof(float);
+  void* cpu_data = allocate_cpu_memory(bytes);
+  ASSERT_NE(cpu_data, nullptr);
+
+  Tensor* tensor;
+  // Request CUDA device but provide CPU memory - should fail
+  AOTITorchError error = aoti_torch_create_tensor_from_blob_v2(
+      cpu_data,
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      0, // storage_offset
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA), // Request CUDA
+      0, // device index
+      &tensor,
+      0, // layout
+      nullptr, // opaque_metadata
+      0); // opaque_metadata_size
+
+  EXPECT_EQ(error, Error::InvalidArgument)
+      << "Should fail when CPU data is provided but CUDA device is requested";
+}
+
+// Test device type mismatch: CUDA data with CPU device request should fail
+TEST_F(AOTITorchCreateTensorFromBlobV2Test, DeviceMismatchCUDADataCPUDevice) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculate_contiguous_strides(sizes);
+
+  // Allocate CUDA memory (device memory, not managed)
+  size_t bytes = calculate_numel(sizes) * sizeof(float);
+  void* cuda_data = nullptr;
+  cudaError_t cuda_err = cudaMalloc(&cuda_data, bytes);
+  ASSERT_EQ(cuda_err, cudaSuccess);
+  ASSERT_NE(cuda_data, nullptr);
+
+  Tensor* tensor;
+  // Request CPU device but provide CUDA memory - should fail
+  AOTITorchError error = aoti_torch_create_tensor_from_blob_v2(
+      cuda_data,
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      0, // storage_offset
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CPU), // Request CPU
+      0, // device index
+      &tensor,
+      0, // layout
+      nullptr, // opaque_metadata
+      0); // opaque_metadata_size
+
+  EXPECT_EQ(error, Error::InvalidArgument)
+      << "Should fail when CUDA data is provided but CPU device is requested";
+
+  // Clean up the CUDA memory we allocated directly
+  cudaFree(cuda_data);
+}

From 9a7aa9143e0fecb99fc07118b110a0ad66cb9200 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 19 Dec 2025 11:21:09 -0800
Subject: [PATCH 3/8] Update

[ghstack-poisoned]
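Two new shims land here: `aoti_torch_item_bool`, which reads the scalar out of
a 0-D bool tensor (doing a device-to-host copy when the data lives on GPU),
and `aoti_torch_assign_tensors_out`, which hands back a second handle onto the
same storage and bumps the shim layer's reference count for that allocation.
A hedged usage sketch, following the conventions of the unit tests in this
patch (the include path and standalone `main` are illustrative assumptions):

```cpp
// Illustrative sketch only; the unit tests in this patch are authoritative.
#include <cuda_runtime.h>

#include <executorch/backends/cuda/runtime/shims/memory.h> // assumed path

using namespace executorch::backends::aoti;
using namespace executorch::backends::cuda;
using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

int main() {
  // 0-D (scalar) bool tensor on CUDA, as in create_cuda_bool_tensor below.
  Tensor* flag = nullptr;
  if (aoti_torch_empty_strided(
          /*ndim=*/0, /*sizes=*/nullptr, /*strides=*/nullptr,
          static_cast<int32_t>(SupportedDTypes::BOOL),
          static_cast<int32_t>(SupportedDevices::CUDA),
          /*device_index=*/0, &flag) != Error::Ok) {
    return 1;
  }

  bool host_value = true;
  cudaMemcpy(flag->mutable_data_ptr(), &host_value, sizeof(bool),
             cudaMemcpyHostToDevice);

  // item_bool hides the device-to-host copy.
  bool read_back = false;
  if (aoti_torch_item_bool(flag, &read_back) != Error::Ok || !read_back) {
    return 1;
  }

  // assign_tensors_out yields a second handle onto the same storage.
  Tensor* view = nullptr;
  if (aoti_torch_assign_tensors_out(flag, &view) != Error::Ok) {
    return 1;
  }
  return view->mutable_data_ptr() == flag->mutable_data_ptr() ? 0 : 1;
}
```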
---
 backends/aoti/common_shims.h                        |   1 +
 backends/cuda/runtime/shims/memory.cpp              | 121 +++++++++
 backends/cuda/runtime/shims/memory.h                |  38 ++-
 .../cuda/runtime/shims/tests/CMakeLists.txt         |   2 +
 backends/cuda/runtime/shims/tests/targets.bzl       |   2 +
 .../test_aoti_torch_assign_tensors_out.cpp          | 237 ++++++++++++++++++
 .../shims/tests/test_aoti_torch_item_bool.cpp       | 204 +++++++++++++++
 7 files changed, 602 insertions(+), 3 deletions(-)
 create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp
 create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp

diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h
index 675a9864e74..3fc414fb669 100644
--- a/backends/aoti/common_shims.h
+++ b/backends/aoti/common_shims.h
@@ -64,6 +64,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();

diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index b6310b0ee5a..13125807103 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -24,6 +24,7 @@ namespace executorch::backends::cuda {
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_dtype_bool;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::aoti_torch_get_sizes;
@@ -797,6 +798,126 @@ AOTITorchError aoti_torch_new_tensor_handle(
   return Error::Ok;
 }
+
+AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_value != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: ret_value is null");
+
+  // Validate that tensor dtype is bool
+  int32_t dtype;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(tensor, &dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      dtype == aoti_torch_dtype_bool(),
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor dtype is not bool (got %d)",
+      dtype);
+
+  // Get the data pointer
+  const void* data_ptr = tensor->const_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor data pointer is null");
+
+  // Check if tensor is on CUDA or CPU
+  cudaPointerAttributes attributes{};
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&attributes, data_ptr));
+
+  if (attributes.type == cudaMemoryTypeDevice) {
+    // CUDA memory case: copy from device to host
+    bool device_value;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+        &device_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost));
+    *ret_value = device_value;
+  } else {
+    // CPU memory case: direct access
+    const bool* bool_ptr = static_cast<const bool*>(data_ptr);
+    *ret_value = *bool_ptr;
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      src != nullptr,
+      InvalidArgument,
+      "aoti_torch_assign_tensors_out failed: src is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_dst != nullptr,
+      InvalidArgument,
+      "aoti_torch_assign_tensors_out failed: ret_dst is null");
+
+  // Get the data pointer from the source tensor
+  void* data_ptr = src->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // Check if the given memory is in the map, if not return error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Get dtype from source tensor
+  int32_t dtype = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype));
+
+  // Get sizes and strides from source tensor
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(src, &sizes_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_strides(src, &strides_ptr));
+
+  int64_t ndim = src->dim();
+
+  // Convert to vectors
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor view that shares the same memory as source tensor
+  std::shared_ptr<Tensor> tensor = make_tensor(
+      sizes,
+      data_ptr, // Share the same memory from source tensor
+      {}, // dim_order (empty, will be auto-generated)
+      strides,
+      dtype_to_scalar_type(dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "Failed to create tensor view in aoti_torch_assign_tensors_out");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_dst = tensor.get();
+
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
 } // extern "C"
 } // namespace executorch::backends::cuda

diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index 1a89d8b782c..e7801358086 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -161,9 +161,41 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
  * @return Error::Ok on success, appropriate error code on failure:
  *         - Error::InvalidArgument: null pointers or invalid parameters
  */
-AOTITorchError aoti_torch_new_tensor_handle(
-    Tensor* orig_handle,
-    Tensor** new_handle);
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
+
+/**
+ * Retrieves a boolean value from a 0D boolean tensor.
+ *
+ * This function extracts the scalar boolean value from a tensor that contains
+ * a single boolean element. The tensor can be on either CPU or CUDA device.
+ * For CUDA tensors, the value is copied from device to host memory.
+ *
+ * @param tensor Pointer to a 0D boolean tensor (must not be null)
+ * @param ret_value Output pointer to store the boolean value (must not be null)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or tensor dtype is not bool
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_item_bool(Tensor* tensor, bool* ret_value);
+
+/**
+ * Creates a new tensor that shares the same underlying data as the source
+ * tensor.
+ *
+ * This function creates a new tensor view with the same shape, strides, and
+ * dtype as the source tensor, sharing the same underlying memory. The new
+ * tensor handle will be stored in ret_dst.
+ *
+ * @param src The source tensor providing the data and metadata.
+ * @param ret_dst On output, this will point to the new tensor view.
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers or memory not tracked
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);

 // Function to clear all tensors from internal storage
 AOTI_SHIM_EXPORT void clear_all_tensors();

diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt
index f626e067627..343c3487ba3 100644
--- a/backends/cuda/runtime/shims/tests/CMakeLists.txt
+++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt
@@ -30,6 +30,8 @@ set(CUDA_SHIM_TESTS
     test_aoti_torch__reinterpret_tensor
     test_aoti_torch_copy_
     test_aoti_torch_new_tensor_handle
+    test_aoti_torch_item_bool
+    test_aoti_torch_assign_tensors_out
 )

 enable_testing()

diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl
index b274ecf3675..7736624c02a 100644
--- a/backends/cuda/runtime/shims/tests/targets.bzl
+++ b/backends/cuda/runtime/shims/tests/targets.bzl
@@ -35,3 +35,5 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_cuda_guard")
     cuda_shim_cpp_unittest("aoti_torch_cuda__weight_int4pack_mm")
     cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
+    cuda_shim_cpp_unittest("aoti_torch_item_bool")
+    cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")

diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp
new file mode 100644
index 00000000000..b883bed2856
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_assign_tensors_out.cpp
@@ -0,0 +1,237 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace executorch::backends::aoti; +using namespace executorch::backends::cuda; +using namespace executorch::runtime; +using executorch::runtime::etensor::Tensor; + +// Test fixture for aoti_torch_assign_tensors_out tests +class AOTITorchAssignTensorsOutTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize ExecuTorch Platform Abstraction Layer + et_pal_init(); + + // Check if CUDA is available + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available, skipping CUDA tests"; + } + + // Clean up any existing cached metadata before each test + cleanup_tensor_metadata(); + + // Clear any remaining tensors from previous tests + clear_all_tensors(); + } + + void TearDown() override { + // Clean up metadata + cleanup_tensor_metadata(); + + // Clear the global tensor storage using the provided function + clear_all_tensors(); + } + + // Helper to create a test tensor + Tensor* create_test_tensor( + const std::vector& sizes, + int32_t dtype = static_cast(SupportedDTypes::FLOAT32), + int32_t device_type = static_cast(SupportedDevices::CUDA)) { + std::vector strides; + // Calculate contiguous strides + if (!sizes.empty()) { + strides.resize(sizes.size()); + strides[sizes.size() - 1] = 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + } + + Tensor* tensor; + const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data(); + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), sizes.data(), strides_ptr, dtype, device_type, 0, &tensor); + + return (error == Error::Ok) ? 
+
+// Test basic functionality
+TEST_F(AOTITorchAssignTensorsOutTest, BasicFunctionality) {
+  // Create a source tensor
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  // Create output tensor handle
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+
+  // Verify the output tensor has the same properties as source
+  EXPECT_EQ(dst->dim(), src->dim());
+  EXPECT_EQ(dst->size(0), src->size(0));
+  EXPECT_EQ(dst->size(1), src->size(1));
+  EXPECT_EQ(dst->numel(), src->numel());
+
+  // Verify they share the same memory
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test with 1D tensor
+TEST_F(AOTITorchAssignTensorsOutTest, OneDimensionalTensor) {
+  std::vector<int64_t> sizes = {10};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+  EXPECT_EQ(dst->dim(), 1);
+  EXPECT_EQ(dst->size(0), 10);
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test with 3D tensor
+TEST_F(AOTITorchAssignTensorsOutTest, ThreeDimensionalTensor) {
+  std::vector<int64_t> sizes = {2, 3, 4};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+  EXPECT_EQ(dst->dim(), 3);
+  EXPECT_EQ(dst->size(0), 2);
+  EXPECT_EQ(dst->size(1), 3);
+  EXPECT_EQ(dst->size(2), 4);
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test with scalar (0D) tensor
+TEST_F(AOTITorchAssignTensorsOutTest, ScalarTensor) {
+  std::vector<int64_t> sizes = {};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+  EXPECT_EQ(dst->dim(), 0);
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test with null source pointer
+TEST_F(AOTITorchAssignTensorsOutTest, NullSourcePointer) {
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(nullptr, &dst);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test with null destination pointer
+TEST_F(AOTITorchAssignTensorsOutTest, NullDestinationPointer) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test that strides are preserved
+TEST_F(AOTITorchAssignTensorsOutTest, StridesPreserved) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* src = create_test_tensor(sizes);
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+
+  // Get strides from both tensors
+  int64_t* src_strides;
+  int64_t* dst_strides;
+  aoti_torch_get_strides(src, &src_strides);
+  aoti_torch_get_strides(dst, &dst_strides);
+
+  // Verify strides match
+  for (int64_t i = 0; i < src->dim(); i++) {
+    EXPECT_EQ(src_strides[i], dst_strides[i]);
+  }
+}
+
+// Test with CPU tensor
+TEST_F(AOTITorchAssignTensorsOutTest, CPUTensor) {
+  std::vector<int64_t> sizes = {2, 3};
+  Tensor* src = create_test_tensor(
+      sizes,
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CPU));
+  ASSERT_NE(src, nullptr);
+
+  Tensor* dst = nullptr;
+  AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(dst, nullptr);
+  EXPECT_EQ(dst->mutable_data_ptr(), src->mutable_data_ptr());
+}
+
+// Test dtype is preserved
+TEST_F(AOTITorchAssignTensorsOutTest, DtypePreserved) {
+  // Test with different dtypes
+  std::vector<int32_t> dtypes = {
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::INT32),
+      static_cast<int32_t>(SupportedDTypes::INT64),
+  };
+
+  for (int32_t dtype : dtypes) {
+    cleanup_tensor_metadata();
+    clear_all_tensors();
+
+    std::vector<int64_t> sizes = {2, 3};
+    Tensor* src = create_test_tensor(sizes, dtype);
+    ASSERT_NE(src, nullptr);
+
+    Tensor* dst = nullptr;
+    AOTITorchError error = aoti_torch_assign_tensors_out(src, &dst);
+
+    EXPECT_EQ(error, Error::Ok);
+    EXPECT_NE(dst, nullptr);
+
+    // Verify dtype is preserved
+    int32_t src_dtype, dst_dtype;
+    aoti_torch_get_dtype(src, &src_dtype);
+    aoti_torch_get_dtype(dst, &dst_dtype);
+    EXPECT_EQ(src_dtype, dst_dtype) << "Dtype mismatch for dtype code: " << dtype;
+  }
+}

diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp
new file mode 100644
index 00000000000..7e53ac7003a
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_item_bool.cpp
@@ -0,0 +1,204 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace executorch::backends::aoti;
+using namespace executorch::backends::cuda;
+using namespace executorch::runtime;
+using executorch::runtime::etensor::Tensor;
+
+// Test fixture for aoti_torch_item_bool tests
+class AOTITorchItemBoolTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // Initialize ExecuTorch Platform Abstraction Layer
+    et_pal_init();
+
+    // Check if CUDA is available
+    int device_count = 0;
+    cudaError_t err = cudaGetDeviceCount(&device_count);
+    if (err != cudaSuccess || device_count == 0) {
+      GTEST_SKIP() << "CUDA not available, skipping CUDA tests";
+    }
+
+    // Clean up any existing cached metadata before each test
+    cleanup_tensor_metadata();
+
+    // Clear any remaining tensors from previous tests
+    clear_all_tensors();
+  }
+
+  void TearDown() override {
+    // Clean up metadata
+    cleanup_tensor_metadata();
+
+    // Clear the global tensor storage using the provided function
+    clear_all_tensors();
+  }
+
+  // Helper to create a bool tensor on CUDA with a specific value
+  Tensor* create_cuda_bool_tensor(bool value) {
+    // Create a 0D (scalar) bool tensor
+    std::vector<int64_t> sizes = {}; // 0D tensor
+    std::vector<int64_t> strides = {}; // Empty strides for scalar
+    Tensor* tensor;
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(SupportedDTypes::BOOL),
+        static_cast<int32_t>(SupportedDevices::CUDA),
+        0,
+        &tensor);
+
+    if (error != Error::Ok || tensor == nullptr) {
+      return nullptr;
+    }
+
+    // Set the value
+    bool host_value = value;
+    cudaError_t cuda_err = cudaMemcpy(
+        tensor->mutable_data_ptr(),
+        &host_value,
+        sizeof(bool),
+        cudaMemcpyHostToDevice);
+
+    if (cuda_err != cudaSuccess) {
+      aoti_torch_delete_tensor_object(tensor);
+      return nullptr;
+    }
+
+    return tensor;
+  }
+
+  // Helper to create a bool tensor on CPU with a specific value
+  Tensor* create_cpu_bool_tensor(bool value) {
+    // Create a 0D (scalar) bool tensor
+    std::vector<int64_t> sizes = {}; // 0D tensor
+    std::vector<int64_t> strides = {}; // Empty strides for scalar
+    Tensor* tensor;
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(SupportedDTypes::BOOL),
+        static_cast<int32_t>(SupportedDevices::CPU),
+        0,
+        &tensor);
+
+    if (error != Error::Ok || tensor == nullptr) {
+      return nullptr;
+    }
+
+    // Set the value directly
+    bool* data_ptr = static_cast<bool*>(tensor->mutable_data_ptr());
+    *data_ptr = value;
+
+    return tensor;
+  }
+};
+
+// Test extracting true value from CUDA bool tensor
+TEST_F(AOTITorchItemBoolTest, CUDATensorTrueValue) {
+  Tensor* tensor = create_cuda_bool_tensor(true);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_TRUE(result);
+}
+
+// Test extracting false value from CUDA bool tensor
+TEST_F(AOTITorchItemBoolTest, CUDATensorFalseValue) {
+  Tensor* tensor = create_cuda_bool_tensor(false);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = true;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_FALSE(result);
+}
+
+// Test extracting true value from CPU bool tensor
+TEST_F(AOTITorchItemBoolTest, CPUTensorTrueValue) {
+  Tensor* tensor = create_cpu_bool_tensor(true);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = false;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_TRUE(result);
+}
+
+// Test extracting false value from CPU bool tensor
+TEST_F(AOTITorchItemBoolTest, CPUTensorFalseValue) {
+  Tensor* tensor = create_cpu_bool_tensor(false);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result = true;
+  AOTITorchError error = aoti_torch_item_bool(tensor, &result);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_FALSE(result);
+}
+
+// Test with null tensor pointer
+TEST_F(AOTITorchItemBoolTest, NullTensorPointer) {
+  bool result;
+  AOTITorchError error = aoti_torch_item_bool(nullptr, &result);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test with null result pointer
+TEST_F(AOTITorchItemBoolTest, NullResultPointer) {
+  Tensor* tensor = create_cuda_bool_tensor(true);
+  ASSERT_NE(tensor, nullptr);
+
+  AOTITorchError error = aoti_torch_item_bool(tensor, nullptr);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test with non-bool dtype (should fail)
+TEST_F(AOTITorchItemBoolTest, NonBoolDtype) {
+  // Create a float tensor
+  std::vector<int64_t> sizes = {};
+  std::vector<int64_t> strides = {};
+  Tensor* tensor;
+
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32), // Not bool
+      static_cast<int32_t>(SupportedDevices::CUDA),
+      0,
+      &tensor);
+
+  ASSERT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor, nullptr);
+
+  bool result;
+  error = aoti_torch_item_bool(tensor, &result);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}

From a97933b8acc62a5f8fa6ca837a1b4675228c7806 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 22 Dec 2025 08:00:16 +0000
Subject: [PATCH 4/8] Update

[ghstack-poisoned]
---
 .../cuda/runtime/shims/tests/CMakeLists.txt   | 20 ++++---
 .../runtime/shims/tests/CMakePresets.json     |  1 -
 backends/cuda/runtime/shims/tests/README.md   | 54 +++----------
 3 files changed, 17 insertions(+), 58 deletions(-)
diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt
index 9c367ab8e1a..a7df6075c37 100644
--- a/backends/cuda/runtime/shims/tests/CMakeLists.txt
+++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt
@@ -21,7 +21,10 @@ FetchContent_Declare(
   GIT_TAG v1.14.0
 )
 # For Windows: Prevent overriding the parent project's compiler/linker settings
-set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+set(gtest_force_shared_crt
+    ON
+    CACHE BOOL "" FORCE
+)
 FetchContent_MakeAvailable(googletest)

 # Get EXECUTORCH_ROOT
@@ -34,12 +37,9 @@ find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX})

 # List of test files
 set(CUDA_SHIM_TESTS
-    test_aoti_torch_create_tensor_from_blob_v2
-    test_aoti_torch_empty_strided
-    test_aoti_torch_delete_tensor_object
-    test_aoti_torch__reinterpret_tensor
-    test_aoti_torch_copy_
-    test_aoti_torch_new_tensor_handle
+    test_aoti_torch_create_tensor_from_blob_v2 test_aoti_torch_empty_strided
+    test_aoti_torch_delete_tensor_object test_aoti_torch__reinterpret_tensor
+    test_aoti_torch_copy_ test_aoti_torch_new_tensor_handle
 )

 enable_testing()
@@ -48,10 +48,8 @@ foreach(test_name ${CUDA_SHIM_TESTS})
   add_executable(${test_name} ${test_name}.cpp)

   target_include_directories(
-    ${test_name}
-    PRIVATE ${EXECUTORCH_ROOT}/..
-            ${EXECUTORCH_ROOT}
-            ${CUDAToolkit_INCLUDE_DIRS}
+    ${test_name} PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}
+                 ${CUDAToolkit_INCLUDE_DIRS}
   )

   target_link_libraries(

diff --git a/backends/cuda/runtime/shims/tests/CMakePresets.json b/backends/cuda/runtime/shims/tests/CMakePresets.json
index 40e0dbbd698..33b448f1538 100644
--- a/backends/cuda/runtime/shims/tests/CMakePresets.json
+++ b/backends/cuda/runtime/shims/tests/CMakePresets.json
@@ -93,4 +93,3 @@
     }
   ]
 }
-

diff --git a/backends/cuda/runtime/shims/tests/README.md b/backends/cuda/runtime/shims/tests/README.md
index 2c4b91570b0..1456447eeba 100644
--- a/backends/cuda/runtime/shims/tests/README.md
+++ b/backends/cuda/runtime/shims/tests/README.md
@@ -13,13 +13,13 @@ From the ExecuTorch root directory:

 ```bash
 # Release build
-cmake --workflow --preset llm-release-cuda
+cmake --workflow llm-release-cuda

 # Or debug build (recommended for debugging test failures)
-cmake --workflow --preset llm-debug-cuda
+cmake --workflow llm-debug-cuda
 ```

-## Building the Tests
+## Building and Running the Tests

 ### Option 1: Using CMake Presets (Recommended)
@@ -27,10 +27,10 @@ From this directory (`backends/cuda/runtime/shims/tests/`):

 ```bash
 # Release build
-cmake --workflow --preset default
+cmake --workflow default

 # Debug build
-cmake --workflow --preset debug
+cmake --workflow debug
 ```

 ### Option 2: Manual CMake Commands
@@ -48,37 +48,13 @@ cmake -B cmake-out/backends/cuda/runtime/shims/tests \
 cmake --build cmake-out/backends/cuda/runtime/shims/tests -j$(nproc)
 ```

-## Running the Tests
-
-### Run All Tests
-
-```bash
-# Using ctest (from the build directory)
-cd cmake-out/backends/cuda/runtime/shims/tests
-ctest --output-on-failure
-
-# Or using the test preset (from this directory)
-ctest --preset default
-```
-
-### Run a Specific Test
-
-```bash
-# From the build directory
-./test_aoti_torch_create_tensor_from_blob_v2
-./test_aoti_torch_empty_strided
-./test_aoti_torch_delete_tensor_object
-./test_aoti_torch_copy_
-./test_aoti_torch_new_tensor_handle
-./test_aoti_torch_item_bool
-./test_aoti_torch_assign_tensors_out
-```
-
 ### Run Specific Test Cases

 Use Google Test filters to run specific test cases:

 ```bash
+# From the build directory
+cd cmake-out/backends/cuda/runtime/shims/tests
 # Run only device mismatch tests
 ./test_aoti_torch_create_tensor_from_blob_v2 --gtest_filter="*DeviceMismatch*"
@@ -89,19 +65,6 @@ Use Google Test filters to run specific test cases:

 # Run a single test
 ./test_aoti_torch_create_tensor_from_blob_v2 --gtest_filter="AOTITorchCreateTensorFromBlobV2Test.BasicFunctionalityCUDA"

 # List all available tests
 ./test_aoti_torch_create_tensor_from_blob_v2 --gtest_list_tests
 ```

-## Test Descriptions
-
-| Test File | Description |
-|-----------|-------------|
-| `test_aoti_torch_create_tensor_from_blob_v2` | Tests tensor creation from existing memory blobs, including device type validation |
-| `test_aoti_torch_empty_strided` | Tests creation of uninitialized tensors with specified strides |
-| `test_aoti_torch_delete_tensor_object` | Tests proper tensor deletion and memory management |
-| `test_aoti_torch__reinterpret_tensor` | Tests tensor view reinterpretation with different shapes/strides |
-| `test_aoti_torch_copy_` | Tests tensor copy operations between CPU and CUDA |
-| `test_aoti_torch_new_tensor_handle` | Tests creating new tensor handles that share memory |
-| `test_aoti_torch_item_bool` | Tests extracting boolean values from scalar tensors |
-| `test_aoti_torch_assign_tensors_out` | Tests creating tensor views that share underlying data |
-
 ## Troubleshooting

 ### CUDA Not Available
@@ -122,11 +85,10 @@ cmake --workflow --preset llm-release-cuda

 For debugging test failures, build with debug mode:
 ```bash
-cmake --workflow --preset debug
+cmake --workflow debug
 ```

 Then run with verbose output:
 ```bash
 ./test_aoti_torch_create_tensor_from_blob_v2 --gtest_break_on_failure
 ```
-

From e1bb6c241cc008b96e63689b43c194521022fe80 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 22 Dec 2025 17:40:55 +0000
Subject: [PATCH 5/8] Update

[ghstack-poisoned]
---
 backends/cuda/runtime/shims/tests/README.md | 10 +--
 extension/llm/custom_ops/custom_ops.py      | 86 +++++++++++++++++++++
 2 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/backends/cuda/runtime/shims/tests/README.md b/backends/cuda/runtime/shims/tests/README.md
index 1456447eeba..3e95280811b 100644
--- a/backends/cuda/runtime/shims/tests/README.md
+++ b/backends/cuda/runtime/shims/tests/README.md
@@ -13,10 +13,10 @@ From the ExecuTorch root directory:

 ```bash
 # Release build
-cmake --workflow llm-release-cuda
+cmake --workflow --preset llm-release-cuda

 # Or debug build (recommended for debugging test failures)
-cmake --workflow llm-debug-cuda
+cmake --workflow --preset llm-debug-cuda
 ```

 ## Building and Running the Tests
@@ -27,10 +27,10 @@ From this directory (`backends/cuda/runtime/shims/tests/`):

 ```bash
 # Release build
-cmake --workflow default
+cmake --workflow --preset default

 # Debug build
-cmake --workflow debug
+cmake --workflow --preset debug
 ```

 ### Option 2: Manual CMake Commands
@@ -85,7 +85,7 @@ cmake --workflow --preset llm-release-cuda

 For debugging test failures, build with debug mode:
 ```bash
-cmake --workflow debug
+cmake --workflow --preset debug
 ```

 Then run with verbose output:

diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py
index dfa357fe356..51c9601454a 100644
--- a/extension/llm/custom_ops/custom_ops.py
+++ b/extension/llm/custom_ops/custom_ops.py
@@ -16,6 +16,12 @@

 from torch.library import impl

+from typing import Tuple
+
+from torch._inductor.lowering import lowerings as L, register_lowering
+
+aten = torch.ops.aten
+
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -387,3 +393,83 @@ def custom_quantized_sdpa_meta(
     )

     return torch.empty(query.size(), dtype=torch.float32, device="meta")
+
+
+# 1) Define the custom op in the "executorch" namespace with name "alias"
+@torch.library.custom_op("executorch::alias", mutates_args=())
+def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    # no copies, just pass-through
+    return x, y
+
+
+# 2) FakeTensor kernel: describes output metadata for compile-time
+@custom_alias.register_fake
+def _(x, y):
+    # For this op, outputs have exactly the same shape/dtype/device as inputs.
+    # We just need *dummy* tensors with that metadata.
+    out_x = torch.empty_like(x)
+    out_y = torch.empty_like(y)
+    return out_x, out_y
+
+
+@register_lowering(torch.ops.executorch.alias.default)
+def lowering_custom_alias(x, y):
+    # x, y here are IR values (Inductor's internal representation).
+    # Alias is logically a no-op – just pass them through.
+    return x, y
+
+
+# Expecting cache shape: (B, H, S_max, D), value shape (B, H, S, D) where S <= S_max
+def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
+    torch._assert(value.dim() == 4, "value must be 4D")
+    torch._assert(cache.dim() == 4, "cache must be 4D")
+    # Cache shape: (B, H, S_max, D)
+    # Value shape: (B, H, S, D)
+    torch._assert(
+        value.size(2) <= cache.size(2),
+        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
+    )
+    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
+    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
+    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
+    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
+
+
+# This is cheating: we deliberately do NOT mark `cache` as mutated so that this
+# custom op can be used in HOPs such as `torch.cond`, where `torch.compile` requires
+# no aliasing or mutation in the branches. This is fine because we only care about inference.
+@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[]) +def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor: + # Eager implementation + _validate_cross_attn_cache_params(value, cache) + + # Slice the cache to match value's sequence length and copy + # cache shape: [B, H, S_max, D] + # value shape: [B, H, S, D] + cache[:, :, : value.size(2), :].copy_(value) + return cache + + +# Register the fake (meta) kernel +@_update_cross_attn_cache.register_fake +def _update_cross_attn_cache_fake( + value: torch.Tensor, cache: torch.Tensor +) -> torch.Tensor: + _validate_cross_attn_cache_params(value, cache) + return torch.empty_like(cache) + + +# Register Inductor lowering +@register_lowering(torch.ops.executorch.update_cross_attn_cache) +def _update_cross_attn_cache_lowering(value, cache): + # cache shape: [B, H, S_max, D] + # value shape: [B, H, S, D] + + # We need to slice the cache along dim 2 (sequence length) + # slice(self, dim, start, end, step=1) + seq_len = value.get_size()[2] + cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1) + + # Copy value into the slice + L[aten.copy_.default](cache_slice, value) + + return cache \ No newline at end of file From 2a7a9f0d96c974adfcf9dd552ec88b92540736f1 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Mon, 22 Dec 2025 17:43:43 +0000 Subject: [PATCH 6/8] Update [ghstack-poisoned] --- extension/llm/custom_ops/custom_ops.py | 86 -------------------------- 1 file changed, 86 deletions(-) diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 51c9601454a..dfa357fe356 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -16,12 +16,6 @@ from torch.library import impl -from typing import Tuple - -from torch._inductor.lowering import lowerings as L, register_lowering - -aten = torch.ops.aten - try: op = torch.ops.llama.sdpa_with_kv_cache.default assert op is not None @@ -393,83 +387,3 @@ def custom_quantized_sdpa_meta( ) return torch.empty(query.size(), dtype=torch.float32, device="meta") - -# 1) Define the custom op in the "executorch" namespace with name "alias" -@torch.library.custom_op("executorch::alias", mutates_args=()) -def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - # no copies, just pass-through - return x, y - - -# 2) FakeTensor kernel: describes output metadata for compile-time -@custom_alias.register_fake -def _(x, y): - # For this op, outputs have exactly the same shape/dtype/device as inputs. - # We just need *dummy* tensors with that metadata. - out_x = torch.empty_like(x) - out_y = torch.empty_like(y) - return out_x, out_y - - -@register_lowering(torch.ops.executorch.alias.default) -def lowering_custom_alias(x, y): - # x, y here are IR values (Inductor's internal representation). - # Alias is logically a no-op – just pass them through. 
-    return x, y
-
-
-# Expecting cache shape: (B, H, S_max, D), value shape (B, H, S, D) where S <= S_max
-def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
-    torch._assert(value.dim() == 4, "value must be 4D")
-    torch._assert(cache.dim() == 4, "cache must be 4D")
-    # Cache shape: (B, H, S_max, D)
-    # Value shape: (B, H, S, D)
-    torch._assert(
-        value.size(2) <= cache.size(2),
-        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
-    )
-    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
-    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
-    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
-    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
-
-
-# This is cheating: we deliberately do NOT mark `cache` as mutated so that this
-# custom op can be used in HOPs such as `torch.cond`, where `torch.compile` requires
-# no aliasing or mutation in the branches. This is fine because we only care about inference.
-@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
-def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
-    # Eager implementation
-    _validate_cross_attn_cache_params(value, cache)
-
-    # Slice the cache to match value's sequence length and copy
-    # cache shape: [B, H, S_max, D]
-    # value shape: [B, H, S, D]
-    cache[:, :, : value.size(2), :].copy_(value)
-    return cache
-
-
-# Register the fake (meta) kernel
-@_update_cross_attn_cache.register_fake
-def _update_cross_attn_cache_fake(
-    value: torch.Tensor, cache: torch.Tensor
-) -> torch.Tensor:
-    _validate_cross_attn_cache_params(value, cache)
-    return torch.empty_like(cache)
-
-
-# Register Inductor lowering
-@register_lowering(torch.ops.executorch.update_cross_attn_cache)
-def _update_cross_attn_cache_lowering(value, cache):
-    # cache shape: [B, H, S_max, D]
-    # value shape: [B, H, S, D]
-
-    # We need to slice the cache along dim 2 (sequence length)
-    # slice(self, dim, start, end, step=1)
-    seq_len = value.get_size()[2]
-    cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)
-
-    # Copy value into the slice
-    L[aten.copy_.default](cache_slice, value)
-
-    return cache
\ No newline at end of file

From 8b9408775df14d8adee82a8a0f9e8fbd849dee52 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 22 Dec 2025 19:54:40 +0000
Subject: [PATCH 7/8] Update

[ghstack-poisoned]
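Fold `clear_memory_tracking` into `clear_all_tensors`, so test fixtures need
only one cleanup call, and restore the correct `cmake --workflow --preset`
spelling in CI. A sketch of the simplified fixture pattern after this change
(hypothetical fixture name; header path assumed as in the tests above):

```cpp
// Hypothetical fixture illustrating the consolidated cleanup; the real
// fixtures updated in this patch follow the same pattern.
#include <gtest/gtest.h>

#include <executorch/backends/cuda/runtime/shims/memory.h> // assumed path

using namespace executorch::backends::cuda;

class SomeShimTest : public ::testing::Test {
 protected:
  void TearDown() override {
    // Previously two calls were needed:
    //   clear_all_tensors();
    //   clear_memory_tracking();
    // clear_all_tensors() now also resets the memory-tracking map,
    // including leftover NOT_OWN entries.
    clear_all_tensors();
  }
};
```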
---
 .github/workflows/cuda.yml                                | 4 ++--
 backends/cuda/runtime/shims/memory.cpp                    | 8 +++-----
 .../shims/tests/test_aoti_torch__reinterpret_tensor.cpp   | 6 ------
 .../tests/test_aoti_torch_create_tensor_from_blob_v2.cpp  | 6 ------
 4 files changed, 5 insertions(+), 19 deletions(-)

diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml
index 0dbee174bd3..d12ab2df73e 100644
--- a/.github/workflows/cuda.yml
+++ b/.github/workflows/cuda.yml
@@ -105,11 +105,11 @@ jobs:
         set -eux

         # Build ExecuTorch with CUDA support
-        cmake --workflow llm-release-cuda
+        cmake --workflow --preset llm-release-cuda

         # Build and run CUDA shim tests
         pushd backends/cuda/runtime/shims/tests
-        cmake --workflow default
+        cmake --workflow --preset default
         popd

diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index 7599bd19b05..ecb1ded2f39 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -331,12 +331,10 @@ void clear_all_tensors() {
   // tensors set should now be empty, but ensure it's cleared
   tensors.clear();

-  ET_LOG(Info, "Cleared all tensors");
-}
-
-void clear_memory_tracking() {
+  // Clear memory tracking map (includes leftover NOT_OWN entries)
   memory_to_n_tensor.clear();
-  ET_LOG(Info, "Cleared memory tracking map");
+
+  ET_LOG(Info, "Cleared all tensors and memory tracking");
 }

 AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {

diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp
index 2e646fb9a66..d3044810b15 100644
--- a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp
@@ -42,9 +42,6 @@ class AOTITorchReinterpretTensorTest : public ::testing::Test {

     // Clear any remaining tensors from previous tests
     clear_all_tensors();
-
-    // Clear memory tracking map
-    clear_memory_tracking();
   }

   void TearDown() override {
@@ -53,9 +50,6 @@ class AOTITorchReinterpretTensorTest : public ::testing::Test {

     // Clear the global tensor storage using the provided function
     clear_all_tensors();
-
-    // Clear memory tracking map
-    clear_memory_tracking();
   }

   // Helper to calculate number of elements from sizes

diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
index ce9c2bb633d..db0ab84970d 100644
--- a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
@@ -40,9 +40,6 @@ class AOTITorchCreateTensorFromBlobV2Test : public ::testing::Test {

     // Clear any remaining tensors from previous tests
     clear_all_tensors();
-
-    // Clear memory tracking map
-    clear_memory_tracking();
   }

   void TearDown() override {
@@ -52,9 +49,6 @@ class AOTITorchCreateTensorFromBlobV2Test : public ::testing::Test {

     // Clear the global tensor storage using the provided function
     clear_all_tensors();

-    // Clear memory tracking map
-    clear_memory_tracking();
-
     // Clean up any allocated memory buffers
     for (void* ptr : cuda_memory_buffers_) {
       if (ptr) {

From 73efe126558e91f2c784847d1f7318a74b235b0a Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 22 Dec 2025 23:36:29 +0000
Subject: [PATCH 8/8] Update

[ghstack-poisoned]
---
 .github/workflows/cuda.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml
index d12ab2df73e..dbce874cfc0 100644
--- a/.github/workflows/cuda.yml
+++ b/.github/workflows/cuda.yml
@@ -103,6 +103,8 @@ jobs:
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       script: |
         set -eux
+        # Install requirements
+        bash ./install_requirements.sh

         # Build ExecuTorch with CUDA support
         cmake --workflow --preset llm-release-cuda