diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml
index 265bc252ee8..dbce874cfc0 100644
--- a/.github/workflows/cuda.yml
+++ b/.github/workflows/cuda.yml
@@ -87,6 +87,33 @@ jobs:
         export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
         PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda
 
+  test-cuda-shims:
+    name: test-cuda-shims
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
+    with:
+      timeout: 90
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: 12.6
+      use-custom-docker-registry: false
+      submodules: recursive
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      script: |
+        set -eux
+        # Install requirements
+        bash ./install_requirements.sh
+
+        # Build ExecuTorch with CUDA support
+        cmake --workflow --preset llm-release-cuda
+
+        # Build and run CUDA shim tests
+        pushd backends/cuda/runtime/shims/tests
+        cmake --workflow --preset default
+        popd
+
   export-model-cuda-artifact:
     name: export-model-cuda-artifact
     # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index aaaf3913381..ecb1ded2f39 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -119,8 +119,6 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
     int32_t layout,
     const uint8_t* opaque_metadata,
     int64_t opaque_metadata_size) {
-  // TODO(gasoonjia): verify given data is on the target device
-  (void)device_type;
   (void)opaque_metadata;
   (void)layout;
   (void)opaque_metadata_size;
@@ -154,6 +152,34 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
   // Storage offset must be 0 since from_blob cannot handle different offsets
   ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset));
 
+  // Verify that data pointer location matches the requested device_type
+  cudaPointerAttributes data_attributes{};
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&data_attributes, data));
+
+  bool data_is_on_device = data_attributes.type == cudaMemoryTypeDevice;
+  bool data_is_on_host = data_attributes.type == cudaMemoryTypeHost ||
+      data_attributes.type == cudaMemoryTypeUnregistered;
+  bool requested_device =
+      device_type == static_cast<int32_t>(SupportedDevices::CUDA);
+  bool requested_cpu =
+      device_type == static_cast<int32_t>(SupportedDevices::CPU);
+
+  // Error if data location doesn't match requested device type
+  ET_CHECK_OR_RETURN_ERROR(
+      !(data_is_on_device && requested_cpu),
+      InvalidArgument,
+      "aoti_torch_create_tensor_from_blob_v2 failed: data pointer %p is on CUDA "
+      "but device_type is CPU. Data must be on CPU for CPU tensors.",
+      data);
+
+  ET_CHECK_OR_RETURN_ERROR(
+      !(data_is_on_host && requested_device),
+      InvalidArgument,
+      "aoti_torch_create_tensor_from_blob_v2 failed: data pointer %p is on CPU "
+      "but device_type is CUDA. Data must be on GPU for CUDA tensors.",
+      data);
+
   // Convert sizes to the format expected by ExecuTorch using SizesType
   std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
@@ -305,7 +331,10 @@ void clear_all_tensors() {
   // tensors set should now be empty, but ensure it's cleared
   tensors.clear();
 
-  ET_LOG(Info, "Cleared all tensors");
+  // Clear memory tracking map (includes leftover NOT_OWN entries)
+  memory_to_n_tensor.clear();
+
+  ET_LOG(Info, "Cleared all tensors and memory tracking");
 }
 
 AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
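The validation added above leans entirely on how `cudaPointerGetAttributes` classifies a pointer: since CUDA 11 the call succeeds even for plain `malloc` memory and reports it as `cudaMemoryTypeUnregistered`, which is why the shim folds that case into "host". A minimal standalone sketch of the same classification, independent of this PR's code:

```cpp
// classify_ptr.cu -- standalone sketch, not part of this diff.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Returns a human-readable location for a pointer, mirroring the
// device/host split used by the new shim validation.
static const char* classify(const void* p) {
  cudaPointerAttributes attr{};
  if (cudaPointerGetAttributes(&attr, p) != cudaSuccess) {
    return "error";
  }
  switch (attr.type) {
    case cudaMemoryTypeDevice:
      return "device (valid for CUDA tensors)";
    case cudaMemoryTypeHost:
    case cudaMemoryTypeUnregistered:
      return "host (valid for CPU tensors)";
    case cudaMemoryTypeManaged:
      return "managed (accessible from both)";
    default:
      return "unknown";
  }
}

int main() {
  void* dev = nullptr;
  cudaMalloc(&dev, 64);           // device allocation
  void* pinned = nullptr;
  cudaMallocHost(&pinned, 64);    // registered (pinned) host allocation
  void* plain = std::malloc(64);  // unregistered host allocation

  std::printf("dev:    %s\n", classify(dev));
  std::printf("pinned: %s\n", classify(pinned));
  std::printf("plain:  %s\n", classify(plain));

  cudaFree(dev);
  cudaFreeHost(pinned);
  std::free(plain);
  return 0;
}
```

Note that `cudaMemoryTypeManaged` pointers satisfy neither `data_is_on_device` nor `data_is_on_host` in the shim code above, so managed allocations pass both checks regardless of the requested device_type.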
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index 1a89d8b782c..935df853748 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -167,6 +167,9 @@ AOTITorchError aoti_torch_new_tensor_handle(
 
 // Function to clear all tensors from internal storage
 AOTI_SHIM_EXPORT void clear_all_tensors();
+
+// Function to clear memory tracking map (for test cleanup)
+AOTI_SHIM_EXPORT void clear_memory_tracking();
 } // extern "C"
 
 } // namespace executorch::backends::cuda
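Since `clear_memory_tracking()` is exported expressly for test cleanup, the natural call site is a gtest fixture's `TearDown()`. A hypothetical sketch; the fixture name is invented, and the include path assumes the `${EXECUTORCH_ROOT}/..` include root configured in the test CMakeLists below:

```cpp
// Hypothetical fixture sketch showing the intended use of the cleanup
// hooks exported above; the actual fixtures in this PR may differ.
#include <gtest/gtest.h>

#include <executorch/backends/cuda/runtime/shims/memory.h>

using namespace executorch::backends::cuda;

class ShimCleanupTest : public ::testing::Test {
 protected:
  void TearDown() override {
    // Drop every tensor created through the shims, then reset the
    // pointer-refcount map so state cannot leak across test cases.
    clear_all_tensors();
    clear_memory_tracking();
  }
};
```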
diff --git a/backends/cuda/runtime/shims/tests/CMakeLists.txt b/backends/cuda/runtime/shims/tests/CMakeLists.txt
new file mode 100644
index 00000000000..a7df6075c37
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/CMakeLists.txt
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.19)
+project(aoti_cuda_shim_tests LANGUAGES CXX CUDA)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# Find required packages
+find_package(CUDAToolkit REQUIRED)
+
+# Fetch GoogleTest
+include(FetchContent)
+FetchContent_Declare(
+  googletest
+  GIT_REPOSITORY https://github.com/google/googletest.git
+  GIT_TAG v1.14.0
+)
+# For Windows: Prevent overriding the parent project's compiler/linker settings
+set(gtest_force_shared_crt
+    ON
+    CACHE BOOL "" FORCE
+)
+FetchContent_MakeAvailable(googletest)
+
+# Get EXECUTORCH_ROOT
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)
+endif()
+
+# Find installed ExecuTorch
+find_package(executorch CONFIG REQUIRED HINTS ${CMAKE_INSTALL_PREFIX})
+
+# List of test files
+set(CUDA_SHIM_TESTS
+    test_aoti_torch_create_tensor_from_blob_v2 test_aoti_torch_empty_strided
+    test_aoti_torch_delete_tensor_object test_aoti_torch__reinterpret_tensor
+    test_aoti_torch_copy_ test_aoti_torch_new_tensor_handle
+)
+
+enable_testing()
+
+foreach(test_name ${CUDA_SHIM_TESTS})
+  add_executable(${test_name} ${test_name}.cpp)
+
+  target_include_directories(
+    ${test_name} PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}
+                         ${CUDAToolkit_INCLUDE_DIRS}
+  )
+
+  target_link_libraries(
+    ${test_name}
+    PRIVATE GTest::gtest
+            GTest::gtest_main
+            aoti_cuda_shims
+            aoti_cuda_backend
+            cuda_tensor_maker
+            cuda_platform
+            executorch_core
+            extension_tensor
+            CUDA::cudart
+  )
+
+  add_test(NAME ${test_name} COMMAND ${test_name})
+endforeach()
diff --git a/backends/cuda/runtime/shims/tests/CMakePresets.json b/backends/cuda/runtime/shims/tests/CMakePresets.json
new file mode 100644
index 00000000000..33b448f1538
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/CMakePresets.json
@@ -0,0 +1,95 @@
+{
+  "version": 6,
+  "configurePresets": [
+    {
+      "name": "default",
+      "displayName": "CUDA Shim Tests",
+      "binaryDir": "${sourceDir}/../../../../../cmake-out/backends/cuda/runtime/shims/tests",
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Release",
+        "CMAKE_PREFIX_PATH": "${sourceDir}/../../../../../cmake-out"
+      },
+      "condition": {
+        "type": "inList",
+        "string": "${hostSystemName}",
+        "list": ["Linux", "Windows"]
+      }
+    },
+    {
+      "name": "debug",
+      "displayName": "CUDA Shim Tests (Debug)",
+      "inherits": ["default"],
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Debug"
+      }
+    }
+  ],
+  "buildPresets": [
+    {
+      "name": "default",
+      "displayName": "Build CUDA Shim Tests",
+      "configurePreset": "default"
+    },
+    {
+      "name": "debug",
+      "displayName": "Build CUDA Shim Tests (Debug)",
+      "configurePreset": "debug"
+    }
+  ],
+  "workflowPresets": [
+    {
+      "name": "default",
+      "displayName": "Configure, build, and test CUDA Shim Tests",
+      "steps": [
+        {
+          "type": "configure",
+          "name": "default"
+        },
+        {
+          "type": "build",
+          "name": "default"
+        },
+        {
+          "type": "test",
+          "name": "default"
+        }
+      ]
+    },
+    {
+      "name": "debug",
+      "displayName": "Configure, build, and test CUDA Shim Tests (Debug)",
+      "steps": [
+        {
+          "type": "configure",
+          "name": "debug"
+        },
+        {
+          "type": "build",
+          "name": "debug"
+        },
+        {
+          "type": "test",
+          "name": "debug"
+        }
+      ]
+    }
+  ],
+  "testPresets": [
+    {
+      "name": "default",
+      "displayName": "Run all CUDA Shim Tests",
+      "configurePreset": "default",
+      "output": {
+        "outputOnFailure": true
+      }
+    },
+    {
+      "name": "debug",
+      "displayName": "Run all CUDA Shim Tests (Debug)",
+      "configurePreset": "debug",
+      "output": {
+        "outputOnFailure": true
+      }
+    }
+  ]
+}
diff --git a/backends/cuda/runtime/shims/tests/README.md b/backends/cuda/runtime/shims/tests/README.md
new file mode 100644
index 00000000000..3e95280811b
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/README.md
@@ -0,0 +1,94 @@
+# CUDA AOTI Shim Tests
+
+Unit tests for the CUDA AOTI (Ahead-Of-Time Inductor) shim functions used by the ExecuTorch CUDA backend.
+
+## Prerequisites
+
+1. **CUDA Toolkit**: Ensure CUDA is installed and available
+2. **ExecuTorch with CUDA**: Build and install ExecuTorch with CUDA support first
+
+## Building ExecuTorch with CUDA
+
+From the ExecuTorch root directory:
+
+```bash
+# Release build
+cmake --workflow --preset llm-release-cuda
+
+# Or debug build (recommended for debugging test failures)
+cmake --workflow --preset llm-debug-cuda
+```
+
+## Building and Running the Tests
+
+### Option 1: Using CMake Presets (Recommended)
+
+From this directory (`backends/cuda/runtime/shims/tests/`):
+
+```bash
+# Release build
+cmake --workflow --preset default
+
+# Debug build
+cmake --workflow --preset debug
+```
+
+### Option 2: Manual CMake Commands
+
+From the ExecuTorch root directory:
+
+```bash
+# Configure
+cmake -B cmake-out/backends/cuda/runtime/shims/tests \
+  -S backends/cuda/runtime/shims/tests \
+  -DCMAKE_PREFIX_PATH=$(pwd)/cmake-out \
+  -DCMAKE_BUILD_TYPE=Debug
+
+# Build
+cmake --build cmake-out/backends/cuda/runtime/shims/tests -j$(nproc)
+```
+
+### Run Specific Test Cases
+
+Use Google Test filters to run specific test cases:
+
+```bash
+# From the build directory
+cd cmake-out/backends/cuda/runtime/shims/tests
+
+# Run only device mismatch tests
+./test_aoti_torch_create_tensor_from_blob_v2 --gtest_filter="*DeviceMismatch*"
+
+# Run a single test
+./test_aoti_torch_create_tensor_from_blob_v2 --gtest_filter="AOTITorchCreateTensorFromBlobV2Test.BasicFunctionalityCUDA"
+
+# List all available tests
+./test_aoti_torch_create_tensor_from_blob_v2 --gtest_list_tests
+```
+
+## Troubleshooting
+
+### CUDA Not Available
+
+If tests are skipped with "CUDA not available", ensure:
+- CUDA drivers are installed
+- A CUDA-capable GPU is present
+- `nvidia-smi` shows the GPU
+
+### Link Errors
+
+If you get link errors, ensure ExecuTorch was built with CUDA support:
+```bash
+cmake --workflow --preset llm-release-cuda
+```
+
+### Test Failures
+
+To debug test failures, build in debug mode:
+```bash
+cmake --workflow --preset debug
+```
+
+Then run the failing test binary directly, breaking into the debugger on the first failure:
+```bash
+./test_aoti_torch_create_tensor_from_blob_v2 --gtest_break_on_failure
+```
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
index d9b785a5a78..db0ab84970d 100644
--- a/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_create_tensor_from_blob_v2.cpp
@@ -752,3 +752,68 @@ TEST_F(AOTITorchCreateTensorFromBlobV2Test, StressTestManySmallTensors) {
     EXPECT_EQ(error, Error::Ok);
   }
 }
+
+// Test device type mismatch: CPU data with CUDA device request should fail
+TEST_F(AOTITorchCreateTensorFromBlobV2Test, DeviceMismatchCPUDataCUDADevice) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculate_contiguous_strides(sizes);
+
+  // Allocate CPU memory
+  size_t bytes = calculate_numel(sizes) * sizeof(float);
+  void* cpu_data = allocate_cpu_memory(bytes);
+  ASSERT_NE(cpu_data, nullptr);
+
+  Tensor* tensor;
+  // Request CUDA device but provide CPU memory - should fail
+  AOTITorchError error = aoti_torch_create_tensor_from_blob_v2(
+      cpu_data,
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      0, // storage_offset
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA), // Request CUDA
+      0, // device index
+      &tensor,
+      0, // layout
+      nullptr, // opaque_metadata
+      0); // opaque_metadata_size
+
+  EXPECT_EQ(error, Error::InvalidArgument)
requested"; +} + +// Test device type mismatch: CUDA data with CPU device request should fail +TEST_F(AOTITorchCreateTensorFromBlobV2Test, DeviceMismatchCUDADataCPUDevice) { + std::vector sizes = {2, 3}; + std::vector strides = calculate_contiguous_strides(sizes); + + // Allocate CUDA memory (device memory, not managed) + size_t bytes = calculate_numel(sizes) * sizeof(float); + void* cuda_data = nullptr; + cudaError_t cuda_err = cudaMalloc(&cuda_data, bytes); + ASSERT_EQ(cuda_err, cudaSuccess); + ASSERT_NE(cuda_data, nullptr); + + Tensor* tensor; + // Request CPU device but provide CUDA memory - should fail + AOTITorchError error = aoti_torch_create_tensor_from_blob_v2( + cuda_data, + sizes.size(), + sizes.data(), + strides.data(), + 0, // storage_offset + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDevices::CPU), // Request CPU + 0, // device index + &tensor, + 0, // layout + nullptr, // opaque_metadata + 0); // opaque_metadata_size + + EXPECT_EQ(error, Error::InvalidArgument) + << "Should fail when CUDA data is provided but CPU device is requested"; + + // Clean up the CUDA memory we allocated directly + cudaFree(cuda_data); +}