diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4768e3dcd..3369c761d 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,77 +6,28 @@ on: jobs: build: - name: Build wheels on ${{ matrix.os }}-${{ matrix.accelerator }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, ubuntu-24.04-arm, windows-2022, macos-latest] - accelerator: [cpu, cu118, cu126] #, cu128] - exclude: - - os: ubuntu-24.04-arm - accelerator: cu118 - - os: ubuntu-24.04-arm - accelerator: cu126 - # - os: ubuntu-24.04-arm - # accelerator: cu128 - - os: macos-latest - accelerator: cu118 - - os: macos-latest - accelerator: cu126 - # - os: macos-latest - # accelerator: cu128 + name: Create source distribution + runs-on: ubuntu-latest steps: - - name: Free space of Github Runner (otherwise it will fail by running out of disk) - if: matrix.os == 'ubuntu-latest' - run: | - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf "/usr/local/share/boost" - sudo rm -rf "/usr/local/.ghcup" - sudo rm -rf "/usr/local/julia1.9.2" - sudo rm -rf "/usr/local/lib/android" - sudo rm -rf "$AGENT_TOOLSDIRECTORY" - - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: "3.13" - name: Install cibuildwheel - run: python -m pip install cibuildwheel==3.1.3 - - - name: Activate MSVC - uses: ilammy/msvc-dev-cmd@v1 - with: - toolset: 14.29 - if: matrix.os == 'windows-2022' + run: pip install build - - name: Build wheels - if: matrix.os != 'windows-2022' - shell: bash - run: python -m cibuildwheel --output-dir wheelhouse - env: - ACCELERATOR: ${{ matrix.accelerator }} - CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }} - - - name: Build wheels - if: matrix.os == 'windows-2022' - shell: cmd # Use cmd on Windows to avoid bash environment taking priority over MSVC variables - run: python -m cibuildwheel --output-dir wheelhouse - env: - ACCELERATOR: ${{ matrix.accelerator }} - DISTUTILS_USE_SDK: "1" # Windows requires this to use vc for building - SKIP_TORCH_COMPILE: "true" + - name: Build pypi package + run: python -m build --sdist - uses: actions/upload-artifact@v4 with: - name: ${{ matrix.accelerator }}-cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} - path: ./wheelhouse/*.whl + name: source_dist + path: dist/*.tar.gz - publish-to-public-pypi: + publish-to-pypi: name: >- Publish Python 🐍 distribution 📦 to PyPI needs: @@ -84,18 +35,14 @@ jobs: runs-on: ubuntu-latest environment: name: pypi + url: https://pypi.org/p/torchmd-net permissions: id-token: write # IMPORTANT: mandatory for trusted publishing - strategy: - fail-fast: false - matrix: - accelerator: [cpu, cu118, cu126] #, hip, cu124, cu126, cu128] steps: - name: Download all the dists uses: actions/download-artifact@v4 with: - pattern: "${{ matrix.accelerator }}-cibw-wheels*" path: dist/ merge-multiple: true @@ -103,69 +50,7 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.TMDNET_PYPI_API_TOKEN }} - skip-existing: true - - # publish-to-accelera-pypi: - # name: >- - # Publish Python 🐍 distribution 📦 to Acellera PyPI - # needs: - # - build - # runs-on: ubuntu-latest - # permissions: # Needed for GCP authentication - # contents: "read" - # id-token: "write" - # strategy: - # fail-fast: false - # matrix: - # accelerator: [cpu, cu118, cu126, cu128] - - # steps: - # - uses: actions/checkout@v4 # Needed for GCP authentication for some reason - - # - 
name: Set up Cloud SDK - # uses: google-github-actions/auth@v2 - # with: - # workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }} - # service_account: ${{ secrets.GCP_PYPI_SERVICE_ACCOUNT }} - - # - name: Download all the dists - # uses: actions/download-artifact@v4 - # with: - # pattern: "${{ matrix.accelerator }}-cibw-wheels*" - # path: dist/ - # merge-multiple: true - - # - name: Publish distribution 📦 to Acellera PyPI - # run: | - # pip install build twine keyring keyrings.google-artifactregistry-auth - # pip install -U packaging - # twine upload --repository-url https://us-central1-python.pkg.dev/pypi-packages-455608/${{ matrix.accelerator }} dist/* --verbose --skip-existing - - # publish-to-official-pypi: - # name: >- - # Publish Python 🐍 distribution 📦 to PyPI - # needs: - # - build - # runs-on: ubuntu-latest - # environment: - # name: pypi - # url: https://pypi.org/p/torchmd-net - # permissions: - # id-token: write # IMPORTANT: mandatory for trusted publishing - - # steps: - # - name: Download all the dists - # uses: actions/download-artifact@v4 - # with: - # pattern: "cu118-cibw-wheels*" - # path: dist/ - # merge-multiple: true - - # - name: Publish distribution 📦 to PyPI - # uses: pypa/gh-action-pypi-publish@release/v1 - # with: - # password: ${{ secrets.TMDNET_PYPI_API_TOKEN }} - # skip_existing: true + skip_existing: true github-release: name: >- diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2644663ee..bb9a2101b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,93 +17,28 @@ jobs: ["ubuntu-latest", "ubuntu-22.04-arm", "macos-latest", "windows-2022"] python-version: ["3.13"] - defaults: # Needed for conda - run: - shell: bash -l {0} - steps: - name: Check out - uses: actions/checkout@v4 - - - uses: conda-incubator/setup-miniconda@v3 - with: - python-version: ${{ matrix.python-version }} - channels: conda-forge - conda-remove-defaults: "true" - if: matrix.os != 'macos-13' + uses: actions/checkout@v5 - - uses: conda-incubator/setup-miniconda@v3 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: python-version: ${{ matrix.python-version }} - channels: conda-forge - mamba-version: "*" - conda-remove-defaults: "true" - if: matrix.os == 'macos-13' - - - name: Install OS-specific compilers - run: | - if [[ "${{ matrix.os }}" == "ubuntu-22.04-arm" ]]; then - conda install gxx --channel conda-forge --override-channels - elif [[ "${{ runner.os }}" == "Linux" ]]; then - conda install gxx --channel conda-forge --override-channels - elif [[ "${{ runner.os }}" == "macOS" ]]; then - conda install llvm-openmp pybind11 --channel conda-forge --override-channels - echo "CC=clang" >> $GITHUB_ENV - echo "CXX=clang++" >> $GITHUB_ENV - elif [[ "${{ runner.os }}" == "Windows" ]]; then - conda install vc vc14_runtime vs2015_runtime --channel conda-forge --override-channels - fi - - - name: List the conda environment - run: conda list - - name: Install testing packages - run: conda install -y -c conda-forge flake8 pytest psutil python-build + - name: Install the project + run: uv sync --all-extras --dev - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + uv run flake8 ./torchmdnet --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-
-      - name: Set pytorch index
-        run: |
-          if [[ "${{ runner.os }}" == "Windows" ]]; then
-            mkdir -p "C:\ProgramData\pip"
-            echo "[global]
-          extra-index-url = https://download.pytorch.org/whl/cpu" > "C:\ProgramData\pip\pip.ini"
-          else
-            mkdir -p $HOME/.config/pip
-            echo "[global]
-          extra-index-url = https://download.pytorch.org/whl/cpu" > $HOME/.config/pip/pip.conf
-          fi
-
-      - name: Build and install the package
-        run: |
-          if [[ "${{ runner.os }}" == "Windows" ]]; then
-            export LIB="C:/Miniconda/envs/test/Library/lib"
-          fi
-          python -m build
-          pip install dist/*.whl
-        env:
-          ACCELERATOR: "cpu"
-
-      # - name: Install nnpops
-      #   if: matrix.os == 'ubuntu-latest' || matrix.os == 'macos-latest'
-      #   run: conda install nnpops --channel conda-forge --override-channels
-
-      - name: List the conda environment
-        run: conda list
+          uv run flake8 ./torchmdnet --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

       - name: Run tests
-        run: pytest -v -s --durations=10
-        env:
-          ACCELERATOR: "cpu"
-          SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
-          OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
-          CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
-          LONG_TRAIN: "true"
+        run: uv run pytest tests

       - name: Test torchmd-train utility
-        run: torchmd-train --help
+        run: uv run torchmd-train --help
diff --git a/README.md b/README.md
index 0d1be37cb..0e3de23c7 100644
--- a/README.md
+++ b/README.md
@@ -21,19 +21,17 @@ Documentation is available at https://torchmd-net.readthedocs.io

 ## Installation

-TorchMD-Net is available as a pip installable wheel as well as in [conda-forge](https://conda-forge.org/)
+TorchMD-Net is available as a pip package as well as in [conda-forge](https://conda-forge.org/)

-TorchMD-Net provides builds for CPU-only, CUDA 11 and CUDA 12. CPU versions are only provided as reference,
-as the performance will be extremely limited.
-Depending on which variant you wish to install, you can install it with one of the following commands:
+As TorchMD-Net depends on PyTorch, we need to add an additional index URL to the installation command, as described in the [PyTorch installation instructions](https://pytorch.org/get-started/locally/):

 ```sh
-# The following will install the CUDA 11.8 version
-pip install torchmd-net-cu11 --extra-index-url https://download.pytorch.org/whl/cu118
-# The following will install the CUDA 12.4 version
-pip install torchmd-net-cu12 --extra-index-url https://download.pytorch.org/whl/cu124
-# The following will install the CPU only version (not recommended)
-pip install torchmd-net-cpu --extra-index-url https://download.pytorch.org/whl/cpu
+# The following installs TorchMD-Net with the PyTorch CUDA 11.8 build
+pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu118
+# The following installs TorchMD-Net with the PyTorch CUDA 12.4 build
+pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu124
+# The following installs TorchMD-Net with the CPU-only PyTorch build (not recommended)
+pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cpu
 ```

 Alternatively it can be installed with conda or mamba with one of the following commands.
@@ -46,7 +44,7 @@ mamba install torchmd-net cuda-version=12.4

 ### Install from source

-TorchMD-Net is installed using pip, but you will need to install some dependencies before.
Check [this documentation page](https://torchmd-net.readthedocs.io/en/latest/installation.html#install-from-source). +TorchMD-Net is installed using pip with `pip install -e .` to create an editable install. ## Usage Specifying training arguments can either be done via a configuration yaml file or through command line arguments directly. Several examples of architectural and training specifications for some models and datasets can be found in [examples/](https://github.com/torchmd/torchmd-net/tree/main/examples). Note that if a parameter is present both in the yaml file and the command line, the command line version takes precedence. diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py index 2db969d92..58209c4d7 100644 --- a/benchmarks/neighbors.py +++ b/benchmarks/neighbors.py @@ -178,17 +178,16 @@ def benchmark_neighbors( if __name__ == "__main__": + strategies = ["distance", "brute", "cell"] n_particles = 32767 mean_num_neighbors = min(n_particles, 64) density = 0.8 print( - "Benchmarking neighbor list generation for {} particles with {} neighbors on average".format( - n_particles, mean_num_neighbors - ) + f"Benchmarking neighbor list generation for {n_particles} particles with {mean_num_neighbors} neighbors on average" ) results = {} batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] - for strategy in ["shared", "brute", "cell", "distance"]: + for strategy in strategies: for n_batches in batch_sizes: time = benchmark_neighbors( device="cuda", @@ -200,33 +199,22 @@ def benchmark_neighbors( ) # Store results in a dictionary results[strategy, n_batches] = time + print("Summary") print("-------") - print( - "{:<10} {:<21} {:<21} {:<18} {:<10}".format( - "Batch size", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)" - ) - ) - print( - "{:<10} {:<21} {:<21} {:<18} {:<10}".format( - "----------", "---------", "---------", "---------", "---------" - ) - ) + headers = "Batch size " + for strategy in strategies: + headers += f"{strategy.capitalize()+'(ms)':<21}" + print(headers) + print("-" * len(headers)) # Print a column per strategy, show speedup over Distance in parenthesis for n_batches in batch_sizes: base = results["distance", n_batches] - print( - "{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format( - n_batches, - results["shared", n_batches], - base / results["shared", n_batches], - results["brute", n_batches], - base / results["brute", n_batches], - results["cell", n_batches], - base / results["cell", n_batches], - results["distance", n_batches], - ) - ) + row = f"{n_batches:<11}" + for strategy in strategies: + row += f"{results[strategy, n_batches]:<5.2f} x{base / results[strategy, n_batches]:<13.2f} " + print(row) + n_particles_list = np.power(2, np.arange(8, 18)) for n_batches in [1, 2, 32, 64]: @@ -236,7 +224,7 @@ def benchmark_neighbors( ) ) results = {} - for strategy in ["shared", "brute", "cell", "distance"]: + for strategy in strategies: for n_particles in n_particles_list: mean_num_neighbors = min(n_particles, 64) time = benchmark_neighbors( @@ -251,32 +239,19 @@ def benchmark_neighbors( results[strategy, n_particles] = time print("Summary") print("-------") - print( - "{:<10} {:<21} {:<21} {:<18} {:<10}".format( - "N Particles", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)" - ) - ) - print( - "{:<10} {:<21} {:<21} {:<18} {:<10}".format( - "----------", "---------", "---------", "---------", "---------" - ) - ) + headers = "N Particles " + for strategy in strategies: + headers += 
f"{strategy.capitalize()+'(ms)':<21}" + print(headers) + print("-" * len(headers)) # Print a column per strategy, show speedup over Distance in parenthesis for n_particles in n_particles_list: base = results["distance", n_particles] brute_speedup = base / results["brute", n_particles] - if n_particles > 32000: - results["brute", n_particles] = 0 - brute_speedup = 0 - print( - "{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format( - n_particles, - results["shared", n_particles], - base / results["shared", n_particles], - results["brute", n_particles], - brute_speedup, - results["cell", n_particles], - base / results["cell", n_particles], - results["distance", n_particles], - ) - ) + # if n_particles > 32000: + # results["brute", n_particles] = 0 + # brute_speedup = 0 + row = f"{n_particles:<12}" + for strategy in strategies: + row += f"{results[strategy, n_particles]:<5.2f} x{base / results[strategy, n_particles]:<13.2f} " + print(row) diff --git a/cibuildwheel_support/before_all_linux.sh b/cibuildwheel_support/before_all_linux.sh deleted file mode 100755 index 12c814b31..000000000 --- a/cibuildwheel_support/before_all_linux.sh +++ /dev/null @@ -1,81 +0,0 @@ -#! /bin/bash - -set -e -set -x - -# Configure pip to use PyTorch extra-index-url for CPU -mkdir -p $HOME/.config/pip -echo "[global] -extra-index-url = https://download.pytorch.org/whl/cpu" > $HOME/.config/pip/pip.conf - - -if [ "$ACCELERATOR" == "cu118" ]; then - - # Install CUDA 11.8: - dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo - - dnf install --setopt=obsoletes=0 -y \ - cuda-nvcc-11-8-11.8.89-1 \ - cuda-cudart-devel-11-8-11.8.89-1 \ - libcurand-devel-11-8-10.3.0.86-1 \ - libcudnn9-devel-cuda-11-9.8.0.87-1 \ - libcublas-devel-11-8-11.11.3.6-1 \ - libnccl-devel-2.15.5-1+cuda11.8 \ - libcusparse-devel-11-8-11.7.5.86-1 \ - libcusolver-devel-11-8-11.4.1.48-1 \ - gcc-toolset-11 - - ln -s cuda-11.8 /usr/local/cuda - ln -s /opt/rh/gcc-toolset-11/root/usr/bin/gcc /usr/local/cuda/bin/gcc - ln -s /opt/rh/gcc-toolset-11/root/usr/bin/g++ /usr/local/cuda/bin/g++ - - export CUDA_HOME="/usr/local/cuda" - - # Configure pip to use PyTorch extra-index-url for CUDA 11.8 - mkdir -p $HOME/.config/pip - echo "[global] -extra-index-url = https://download.pytorch.org/whl/cu118" > $HOME/.config/pip/pip.conf - -elif [ "$ACCELERATOR" == "cu126" ]; then - # Install CUDA 12.6 - dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo - - dnf install --setopt=obsoletes=0 -y \ - cuda-compiler-12-6-12.6.3-1 \ - cuda-libraries-12-6-12.6.3-1 \ - cuda-libraries-devel-12-6-12.6.3-1 \ - cuda-toolkit-12-6-12.6.3-1 \ - gcc-toolset-13 - - ln -s cuda-12.6 /usr/local/cuda - ln -s /opt/rh/gcc-toolset-13/root/usr/bin/gcc /usr/local/cuda/bin/gcc - ln -s /opt/rh/gcc-toolset-13/root/usr/bin/g++ /usr/local/cuda/bin/g++ - ln -s /usr/local/cuda/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/libcuda.so.1 - - # Configure pip to use PyTorch extra-index-url for CUDA 12.6 - mkdir -p $HOME/.config/pip - echo "[global] -extra-index-url = https://download.pytorch.org/whl/cu126" > $HOME/.config/pip/pip.conf - -elif [ "$ACCELERATOR" == "cu128" ]; then - # Install CUDA 12.8 - dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo - - dnf install --setopt=obsoletes=0 -y \ - cuda-compiler-12-8-12.8.1-1 \ - cuda-libraries-12-8-12.8.1-1 \ - 
cuda-libraries-devel-12-8-12.8.1-1 \
-        cuda-toolkit-12-8-12.8.1-1 \
-        gcc-toolset-13
-
-    ln -s cuda-12.8 /usr/local/cuda
-    ln -s /opt/rh/gcc-toolset-13/root/usr/bin/gcc /usr/local/cuda/bin/gcc
-    ln -s /opt/rh/gcc-toolset-13/root/usr/bin/g++ /usr/local/cuda/bin/g++
-    ln -s /usr/local/cuda/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/libcuda.so.1
-
-    # Configure pip to use PyTorch extra-index-url for CUDA 12.8
-    mkdir -p $HOME/.config/pip
-    echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu128" > $HOME/.config/pip/pip.conf
-
-fi
\ No newline at end of file
diff --git a/cibuildwheel_support/before_all_windows.sh b/cibuildwheel_support/before_all_windows.sh
deleted file mode 100755
index 9cf789159..000000000
--- a/cibuildwheel_support/before_all_windows.sh
+++ /dev/null
@@ -1,37 +0,0 @@
-#!/bin/bash
-
-set -e
-set -x
-
-# Create pip directory if it doesn't exist
-mkdir -p "C:\ProgramData\pip"
-
-# Create pip.ini file with PyTorch CPU index
-echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cpu" > "C:\ProgramData\pip\pip.ini"
-
-if [ "$ACCELERATOR" == "cu118" ]; then
-    curl --netrc-optional -L -nv -o cuda.exe https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe
-    ./cuda.exe -s nvcc_11.8 cudart_11.8 cublas_dev_11.8 curand_dev_11.8 cusparse_dev_11.8 cusolver_dev_11.8 thrust_11.8
-    rm cuda.exe
-
-    # Create pip.ini file with PyTorch CUDA 11.8 index
-    echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu118" > "C:\ProgramData\pip\pip.ini"
-elif [ "$ACCELERATOR" == "cu126" ]; then
-    curl --netrc-optional -L -nv -o cuda.exe https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda_12.6.0_560.76_windows.exe
-    ./cuda.exe -s nvcc_12.6 cudart_12.6 cublas_dev_12.6 curand_dev_12.6 cusparse_dev_12.6 cusolver_dev_12.6 thrust_12.6
-    rm cuda.exe
-
-    # Create pip.ini file with PyTorch CUDA 12.6 index
-    echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu126" > "C:\ProgramData\pip\pip.ini"
-elif [ "$ACCELERATOR" == "cu128" ]; then
-    curl --netrc-optional -L -nv -o cuda.exe https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_572.61_windows.exe
-    ./cuda.exe -s nvcc_12.8 cudart_12.8 cublas_dev_12.8 curand_dev_12.8 cusparse_dev_12.8 cusolver_dev_12.8 thrust_12.8
-    rm cuda.exe
-
-    # Create pip.ini file with PyTorch CUDA 12.8 index
-    echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu128" > "C:\ProgramData\pip\pip.ini"
-fi
\ No newline at end of file
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index 769bce14c..6e01ca096 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -1,20 +1,18 @@
 Installation
 ============
 
-TorchMD-Net is available as a pip installable wheel as well as in `conda-forge <https://conda-forge.org/>`_
+TorchMD-Net is available as a pip package as well as in `conda-forge <https://conda-forge.org/>`_
 
-TorchMD-Net provides builds for CPU-only, CUDA 11 and CUDA 12.
-CPU versions are only provided as reference, as the performance will be extremely limited.
-Depending on which variant you wish to install, you can install it with one of the following commands:
+As TorchMD-Net depends on PyTorch, we need to add an additional index URL to the installation command, as described in the `PyTorch installation instructions <https://pytorch.org/get-started/locally/>`_:
 
 .. code-block:: shell
 
-   # The following will install the CUDA 11.8 version
-   pip install torchmd-net-cu11 --extra-index-url https://download.pytorch.org/whl/cu118
-   # The following will install the CUDA 12.4 version
-   pip install torchmd-net-cu12 --extra-index-url https://download.pytorch.org/whl/cu124
-   # The following will install the CPU only version (not recommended)
-   pip install torchmd-net-cpu --extra-index-url https://download.pytorch.org/whl/cpu
+   # The following installs TorchMD-Net with the PyTorch CUDA 11.8 build
+   pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu118
+   # The following installs TorchMD-Net with the PyTorch CUDA 12.4 build
+   pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu124
+   # The following installs TorchMD-Net with the CPU-only PyTorch build (not recommended)
+   pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cpu
 
 Alternatively it can be installed with conda or mamba with one of the following commands.
 We recommend using `Miniforge <https://github.com/conda-forge/miniforge/>`_ instead of anaconda.
@@ -27,70 +25,14 @@ We recommend using `Miniforge <https://github.com/conda-forge/miniforge/>`_ inst
 
 Install from source
 -------------------
+For development purposes, we recommend using `uv <https://docs.astral.sh/uv/>`_ to install TorchMD-Net in editable mode.
+After installing uv, run the following command to install TorchMD-Net and its dependencies in editable mode.
 
-1. Clone the repository:
-
-.. code-block:: shell
-
-   git clone https://github.com/torchmd/torchmd-net.git
-   cd torchmd-net
-
-2. Install the dependencies in environment.yml.
-
-.. code-block:: shell
-
-   conda env create -f environment.yml
-   conda activate torchmd-net
-
-3. CUDA enabled installation
-
-You can skip this section if you only need a CPU installation.
-
-You will need the CUDA compiler (nvcc) and the corresponding development libraries to build TorchMD-Net with CUDA support. You can install CUDA from the `official NVIDIA channel `_ or from conda-forge.
-
-The conda-forge channel `changed the way to install CUDA from versions 12 and above `_, thus the following instructions depend on whether you need CUDA < 12. If you have a GPU available, conda-forge probably installed the CUDA runtime (not the developer tools) on your system already, you can check with conda:
-
-.. code-block:: shell
-
-   conda list | grep cuda
-
-
-Or by asking pytorch:
-
-.. code-block:: shell
-
-   python -c "import torch; print(torch.version.cuda)"
-
-
-It is recommended to install the same version as the one used by torch.
-
-.. warning:: At the time of writing there is a `bug in Mamba `_ (v1.5.6) that can cause trouble when installing CUDA on an already created environment. We thus recommend conda for this step.
-
-* CUDA>=12
-
-.. code-block:: shell
-
-   conda install -c conda-forge python=3.10 cuda-version=12.6 cuda-nvvm cuda-nvcc cuda-libraries-dev
-
-
-* CUDA<12
-
-The nvidia channel provides the developer tools for CUDA<12.
-
-.. code-block:: shell
-
-   conda install -c nvidia "cuda-nvcc<12" "cuda-libraries-dev<12" "cuda-version<12" "gxx<12" pytorch=*=*cuda*
-
-
-4. Install TorchMD-NET into the environment:
-
 .. code-block:: shell
 
-   pip install -e .
+   uv sync
 
-.. note:: Pip installation in CUDA mode requires compiling CUDA source codes, this can take a really long time and the process might appear as if it is "stuck". Run pip with `-vv` to see the compilation process.
+This will install the package alongside PyTorch for CUDA 12.6.
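+
+As a quick check that the expected PyTorch build was resolved, you can print the
+CUDA version PyTorch was compiled against (this reuses the snippet from the
+previous revision of this page):
+
+.. code-block:: shell
+
+   uv run python -c "import torch; print(torch.version.cuda)"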
To change the CUDA version, +edit the `pyproject.toml `_ file and change the `torch` dependency to the desired version. -This will install TorchMD-NET in editable mode, so that changes to the source code are immediately available. -Besides making all python utilities available environment-wide, this will also install the ``torchmd-train`` command line utility. diff --git a/environment.yml b/environment.yml deleted file mode 100644 index a1b851e38..000000000 --- a/environment.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: torchmd-net -channels: - - conda-forge -dependencies: - - h5py - - nnpops - - pip - - pytorch - - pytorch_geometric - - lightning - - torchmetrics - - tqdm - # Dev tools - - flake8 - - pytest - - psutil - - gxx - # optional - - ase diff --git a/examples/aceff_examples/ase_aceff.py b/examples/aceff_examples/ase_aceff.py index 9c9ceaf15..93e57c3cf 100644 --- a/examples/aceff_examples/ase_aceff.py +++ b/examples/aceff_examples/ase_aceff.py @@ -17,7 +17,7 @@ # We create the ASE calculator by supplying the path to the model and specifying the device and dtype calc = TMDNETCalculator(model_file_path, device='cuda') -atoms = read('caffiene.pdb') +atoms = read('caffeine.pdb') print(atoms) atoms.calc = calc @@ -81,9 +81,6 @@ atoms.calc = calc -# Single point calcuation to trigger compile -atoms.get_potential_energy() - # Run more dynamics t1 = time.perf_counter() dyn.run(steps=nsteps) diff --git a/pyproject.toml b/pyproject.toml index 0092a3cbc..5939d1b95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,27 @@ [project] -name = "PLACEHOLDER" +name = "torchmd-net" description = "TorchMD-NET provides state-of-the-art neural networks potentials for biomolecular systems" authors = [{ name = "Acellera", email = "info@acellera.com" }] readme = "README.md" license = "MIT" -requires-python = ">=3.8" -dynamic = ["version", "dependencies"] +requires-python = ">=3.9" +dynamic = ["version"] classifiers = [ "Programming Language :: Python :: 3", "Operating System :: POSIX :: Linux", ] +dependencies = [ + "h5py", + # "nnpops", + "torch", + "torch_geometric", + "lightning", + "tqdm", + "numpy", + "triton; sys_platform == 'linux' and platform_machine != 'aarch64'", + "triton-windows; sys_platform == 'win32'", + "ase", +] [project.urls] "Homepage" = "https://github.com/torchmd/torchmd-net" @@ -27,40 +39,21 @@ include = ["torchmdnet*"] "*" = ["*.c", "*.cpp", "*.h", "*.cuh", "*.cu", ".gitignore"] [build-system] -requires = ["setuptools>=78", "setuptools-scm>=8", "torch==2.7.1", "numpy"] +requires = ["setuptools>=78", "setuptools-scm>=8"] build-backend = "setuptools.build_meta" +[dependency-groups] +dev = ["flake8>=7.3.0", "ipython>=8.18.1", "pytest>=8.4.2"] -[tool.cibuildwheel] -# Disable builds which can't support CUDA and pytorch -skip = [ - "cp38-*", - "cp314-*", - "cp314t-*", - "pp*", - "*win32", - "*armv7l", - "*_i686", - "*_ppc64le", - "*_s390x", - "*_universal2", - "*-musllinux_*", -] -test-requires = ["pytest", "pytest-xdist"] -test-command = "pytest {project}/tests" -environment-pass = ["ACCELERATOR", "CI"] -# container-engine = "docker; create_args: --gpus all" - -[tool.cibuildwheel.linux] -before-all = "bash {project}/cibuildwheel_support/before_all_linux.sh" -repair-wheel-command = [ - "auditwheel repair --exclude libcudart.so.* --exclude libc10.so --exclude libc10_cuda.so --exclude libtorch.so --exclude libtorch_cuda.so --exclude libtorch_cpu.so --exclude libtorch_python.so -w {dest_dir} {wheel}", +[tool.uv.sources] +torch = [ + { index = "pytorch-cu126", marker = "(sys_platform == 
'linux' or sys_platform == 'win32') and platform_machine != 'aarch64' and platform_machine != 'arm64'" }, ] - -[tool.cibuildwheel.macos] -repair-wheel-command = [ - "delocate-wheel --ignore-missing-dependencies --require-archs {delocate_archs} -w {dest_dir} -v {wheel}", +torchvision = [ + { index = "pytorch-cu126", marker = "(sys_platform == 'linux' or sys_platform == 'win32') and platform_machine != 'aarch64' and platform_machine != 'arm64'" }, ] -[tool.cibuildwheel.windows] -before-all = "bash {project}/cibuildwheel_support/before_all_windows.sh" +[[tool.uv.index]] +name = "pytorch-cu126" +url = "https://download.pytorch.org/whl/cu126" +explicit = true diff --git a/setup.py b/setup.py deleted file mode 100644 index b2ee351e9..000000000 --- a/setup.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org -# Distributed under the MIT License. -# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - -from setuptools import setup -import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension -import os -import platform - - -if os.environ.get("ACCELERATOR", None) is not None: - use_cuda = os.environ.get("ACCELERATOR", "").startswith("cu") -else: - use_cuda = torch.cuda._is_compiled() - - -def _replace_name(name): - import pathlib - - pyproject_path = pathlib.Path(__file__).parent / "pyproject.toml" - with open(pyproject_path, "r") as f: - pyproject_text = f.read() - pyproject_text = pyproject_text.replace("PLACEHOLDER", name) - with open(pyproject_path, "w") as f: - f.write(pyproject_text) - - -if os.getenv("ACCELERATOR", "").startswith("cpu"): - _replace_name("torchmd-net-cpu") -if os.getenv("ACCELERATOR", "").startswith("cu"): - cuda_ver = os.getenv("ACCELERATOR", "")[2:4] - _replace_name(f"torchmd-net-cu{cuda_ver}") - - -def set_torch_cuda_arch_list(): - """Set the CUDA arch list according to the architectures the current torch installation was compiled for. - This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support. 
- """ - if use_cuda and not os.environ.get("TORCH_CUDA_ARCH_LIST"): - arch_flags = torch._C._cuda_getArchFlags() - sm_versions = [x[3:] for x in arch_flags.split() if x.startswith("sm_")] - formatted_versions = ";".join([f"{y[:-1]}.{y[-1]}" for y in sm_versions]) - formatted_versions += "+PTX" - os.environ["TORCH_CUDA_ARCH_LIST"] = formatted_versions - - -set_torch_cuda_arch_list() - -extension_root = os.path.join("torchmdnet", "extensions") -neighbor_sources = ["neighbors_cpu.cpp"] -if use_cuda: - neighbor_sources.append("neighbors_cuda.cu") -neighbor_sources = [ - os.path.join(extension_root, "neighbors", source) for source in neighbor_sources -] - -runtime_library_dirs = None -if platform.system() == "Darwin": - runtime_library_dirs = [ - "@loader_path/../../torch/lib", - "@loader_path/../../nvidia/cuda_runtime/lib", - ] -elif platform.system() == "Linux": - runtime_library_dirs = [ - "$ORIGIN/../../torch/lib", - "$ORIGIN/../../nvidia/cuda_runtime/lib", - ] - -extra_deps = [] -if os.getenv("ACCELERATOR", "").startswith("cu"): - cuda_ver = os.getenv("ACCELERATOR")[2:4] - extra_deps = [f"nvidia-cuda-runtime-cu{cuda_ver}"] - -ExtensionType = CppExtension if not use_cuda else CUDAExtension -extensions = ExtensionType( - name="torchmdnet.extensions.torchmdnet_extensions", - sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")] - + neighbor_sources, - define_macros=[("WITH_CUDA", 1)] if use_cuda else [], - runtime_library_dirs=runtime_library_dirs, -) - -kwargs = {} -if "CI" in os.environ: - from setuptools_scm import get_version - - # Drop the dev version suffix because we modify pyproject.toml - # We do this only in CI because we need to upload to PyPI - - kwargs = {"version": ".".join(get_version().split(".")[:3])} - -if __name__ == "__main__": - setup( - ext_modules=[extensions], - cmdclass={ - "build_ext": BuildExtension.with_options( - no_python_abi_suffix=True, use_ninja=False - ) - }, - install_requires=[ - "h5py", - # "nnpops", - "torch==2.7.1", - "torch_geometric", - "lightning", - "tqdm", - "numpy", - "ase", - ] - + extra_deps, - **kwargs, - ) diff --git a/tests/caffeine.pdb b/tests/caffeine.pdb new file mode 100644 index 000000000..080ed8f97 --- /dev/null +++ b/tests/caffeine.pdb @@ -0,0 +1,51 @@ +MODEL 1 +HETATM 1 O UNL 1 0.470 2.569 0.001 1.00 0.00 O +HETATM 2 O UNL 2 -3.127 -0.444 -0.000 1.00 0.00 O +HETATM 3 N UNL 3 -0.969 -1.312 0.000 1.00 0.00 N +HETATM 4 N UNL 4 2.218 0.141 -0.000 1.00 0.00 N +HETATM 5 N UNL 5 -1.348 1.080 -0.000 1.00 0.00 N +HETATM 6 N UNL 6 1.412 -1.937 0.000 1.00 0.00 N +HETATM 7 C UNL 7 0.858 0.259 -0.001 1.00 0.00 C +HETATM 8 C UNL 8 0.390 -1.026 -0.000 1.00 0.00 C +HETATM 9 C UNL 9 0.031 1.422 -0.001 1.00 0.00 C +HETATM 10 C UNL 10 -1.906 -0.249 -0.000 1.00 0.00 C +HETATM 11 C UNL 11 2.503 -1.200 0.000 1.00 0.00 C +HETATM 12 C UNL 12 -1.428 -2.696 0.001 1.00 0.00 C +HETATM 13 C UNL 13 3.193 1.206 0.000 1.00 0.00 C +HETATM 14 C UNL 14 -2.297 2.188 0.001 1.00 0.00 C +HETATM 15 H UNL 15 3.516 -1.579 0.001 1.00 0.00 H +HETATM 16 H UNL 16 -1.045 -3.197 -0.894 1.00 0.00 H +HETATM 17 H UNL 17 -2.519 -2.760 0.001 1.00 0.00 H +HETATM 18 H UNL 18 -1.045 -3.196 0.896 1.00 0.00 H +HETATM 19 H UNL 19 4.199 0.780 0.000 1.00 0.00 H +HETATM 20 H UNL 20 3.047 1.809 -0.899 1.00 0.00 H +HETATM 21 H UNL 21 3.047 1.808 0.900 1.00 0.00 H +HETATM 22 H UNL 22 -1.809 3.165 -0.000 1.00 0.00 H +HETATM 23 H UNL 23 -2.932 2.103 0.888 1.00 0.00 H +HETATM 24 H UNL 24 -2.935 2.102 -0.885 1.00 0.00 H +CONECT 1 9 +CONECT 2 10 +CONECT 3 8 10 12 +CONECT 4 7 11 13 +CONECT 5 
9 10 14 +CONECT 6 8 11 +CONECT 7 4 8 9 +CONECT 8 3 6 7 +CONECT 9 1 5 7 +CONECT 10 2 3 5 +CONECT 11 4 6 15 +CONECT 12 3 16 17 18 +CONECT 13 4 19 20 21 +CONECT 14 5 22 23 24 +CONECT 15 11 +CONECT 16 12 +CONECT 17 12 +CONECT 18 12 +CONECT 19 13 +CONECT 20 13 +CONECT 21 13 +CONECT 22 14 +CONECT 23 14 +CONECT 24 14 +MASTER 0 0 0 0 0 0 0 0 24 0 24 0 +END diff --git a/tests/example_tensornet.ckpt b/tests/example_tensornet.ckpt new file mode 100644 index 000000000..2c7f521ac Binary files /dev/null and b/tests/example_tensornet.ckpt differ diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 90d552060..08b53f9c7 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -81,3 +81,76 @@ def test_compare_forward_multiple(): assert_close(e_calc, e_pred) assert_close(f_calc, f_pred.view(-1, len(z1), 3), rtol=3e-4, atol=1e-5) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_ase_calculator(device): + import platform + + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + # Skip on Windows CI for now because we don't have cl.exe installed for torch.compile + if platform.system() == "Windows": + pytest.skip("Skipping test on Windows") + + from torchmdnet.calculators import TMDNETCalculator + from ase.io import read + from ase import units + from ase.md.langevin import Langevin + import os + import numpy as np + + ref_forces = np.array( + [ + [-7.76999056e-01, -6.89736724e-01, -9.31625906e-03], + [2.16895628e00, 5.29322922e-01, 9.77647374e-04], + [-6.50327325e-01, 1.11337602e00, 2.27598846e-03], + [1.19350255e00, -1.23314810e00, -7.83117674e-03], + [5.11314452e-01, -3.33878160e-01, -5.03035402e-03], + [-7.18148768e-01, 7.37230778e-02, 3.08941817e-03], + [-1.25317931e-01, -5.19263268e-01, 2.00758013e-03], + [-3.05806249e-01, -4.49415118e-01, -8.53991229e-03], + [5.10734320e-03, 2.49908626e-01, 2.04431713e-02], + [-3.65967184e-01, -1.57078415e-01, -1.55984145e-03], + [-6.44133329e-01, 1.16345167e00, 4.54566162e-03], + [2.05828249e-02, -2.64510632e-01, -1.38899162e-02], + [1.73451304e-02, 3.65104795e-01, 1.13833081e-02], + [-8.57830405e-01, -2.25283504e-01, -2.49589253e-02], + [-1.56955227e-01, 1.19012646e-01, -1.87584094e-03], + [-1.50042176e-02, 1.75106078e-02, 2.51995742e-01], + [3.01239967e-01, 3.67318511e-01, 4.64916229e-06], + [-9.57870483e-03, 1.21697336e-02, -2.39765823e-01], + [-2.48186022e-01, 2.74000764e-02, -1.08634552e-03], + [1.26295090e-01, 1.04473650e-01, 2.81187654e-01], + [1.28753006e-01, 1.03064716e-01, -2.88918495e-01], + [2.80321002e-01, -5.11180341e-01, -1.12308562e-03], + [5.06305993e-02, 6.65888190e-02, -2.11322665e-01], + [7.02065229e-02, 7.10679889e-02, 2.37307906e-01], + ] + ) + + curr_dir = os.path.dirname(__file__) + + checkpoint = join(curr_dir, "example_tensornet.ckpt") + calc = TMDNETCalculator(checkpoint, device=device) + + atoms = read(join(curr_dir, "caffeine.pdb")) + atoms.calc = calc + # The total molecular charge must be set + atoms.info["charge"] = 0 + assert np.allclose(atoms.get_potential_energy(), -113.6652, atol=1e-4) + assert np.allclose(atoms.get_forces(), ref_forces, atol=1e-4) + + # Molecular dynamics + temperature_K: float = 300 + timestep: float = 1.0 * units.fs + friction: float = 0.01 / units.fs + nsteps: int = 10 + dyn = Langevin(atoms, timestep, temperature_K=temperature_K, friction=friction) + dyn.run(steps=nsteps) + + # Now we can do the same but enabling torch.compile for increased speed + calc = TMDNETCalculator(checkpoint, device=device, compile=True) + atoms.calc 
= calc + # Run more dynamics + dyn.run(steps=nsteps) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 7402f8f31..736bc8a28 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -84,19 +84,17 @@ def test_custom(energy, forces, num_files, preload, tmpdir): # Assert shapes of whole dataset: for i in range(len(data)): n_atoms_i = n_atoms_per_sample[i] - assert np.array(data[i].z).shape == ( + assert data[i].z.shape == ( n_atoms_i, ), "Dataset has incorrect atom numbers shape" - assert np.array(data[i].pos).shape == ( + assert data[i].pos.shape == ( n_atoms_i, 3, ), "Dataset has incorrect coords shape" if energy: - assert np.array(data[i].y).shape == ( - 1, - ), "Dataset has incorrect energy shape" + assert data[i].y.shape == (1,), "Dataset has incorrect energy shape" if forces: - assert np.array(data[i].neg_dy).shape == ( + assert data[i].neg_dy.shape == ( n_atoms_i, 3, ), "Dataset has incorrect forces shape" @@ -190,19 +188,17 @@ def test_hdf5(preload, energy, forces, num_files, tmpdir): # Assert shapes of whole dataset: for i in range(len(data)): n_atoms_i = n_atoms_per_sample[i] - assert np.array(data[i].z).shape == ( + assert data[i].z.shape == ( n_atoms_i, ), "Dataset has incorrect atom numbers shape" - assert np.array(data[i].pos).shape == ( + assert data[i].pos.shape == ( n_atoms_i, 3, ), "Dataset has incorrect coords shape" if energy: - assert np.array(data[i].y).shape == ( - 1, - ), "Dataset has incorrect energy shape" + assert data[i].y.shape == (1,), "Dataset has incorrect energy shape" if forces: - assert np.array(data[i].neg_dy).shape == ( + assert data[i].neg_dy.shape == ( n_atoms_i, 3, ), "Dataset has incorrect forces shape" diff --git a/tests/test_model.py b/tests/test_model.py index 323c0fa67..65ca3863e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -114,6 +114,40 @@ def test_torchscript_dynamic_shapes(model_name, device): )[0] +@mark.parametrize("model_name", models.__all_models__) +@mark.parametrize("device", ["cpu", "cuda"]) +def test_torchscript_then_compile(model_name, device): + """Test that a TorchScripted model can be torch.compiled afterwards""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Create and TorchScript the model + z, pos, batch = create_example_batch() + z = z.to(device) + pos = pos.to(device).requires_grad_(True) + batch = batch.to(device) + + scripted_model = torch.jit.script( + create_model(load_example_args(model_name, remove_prior=True, derivative=True)) + ).to(device=device) + + # Get baseline output from scripted model + y_scripted, neg_dy_scripted = scripted_model(z, pos, batch=batch) + + # Now try to torch.compile the scripted model + try: + compiled_model = torch.compile(scripted_model, backend="inductor") + y_compiled, neg_dy_compiled = compiled_model(z, pos, batch=batch) + + # Verify outputs match + torch.testing.assert_close(y_scripted, y_compiled, atol=1e-5, rtol=1e-5) + torch.testing.assert_close( + neg_dy_scripted, neg_dy_compiled, atol=1e-5, rtol=1e-5 + ) + except Exception as e: + pytest.fail(f"torch.compile failed on TorchScripted {model_name} model: {e}") + + # Currently only tensornet is CUDA graph compatible @mark.parametrize("model_name", ["tensornet"]) def test_cuda_graph_compatible(model_name): @@ -161,6 +195,65 @@ def test_cuda_graph_compatible(model_name): assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5) +@mark.parametrize("model_name", ["tensornet"]) +def test_torchscript_cuda_graph_compatible(model_name): + """Test 
that a TorchScripted model is compatible with CUDA graphs. + + This is important when models are saved as TorchScript + and then CUDA graphs are used for performance. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + z, pos, batch = create_example_batch() + args = { + "model": model_name, + "embedding_dimension": 128, + "num_layers": 2, + "num_rbf": 32, + "rbf_type": "expnorm", + "trainable_rbf": False, + "activation": "silu", + "cutoff_lower": 0.0, + "cutoff_upper": 5.0, + "max_z": 100, + "max_num_neighbors": 128, + "equivariance_invariance_group": "O(3)", + "prior_model": None, + "atom_filter": -1, + "derivative": True, + "check_errors": False, + "static_shapes": True, + "output_model": "Scalar", + "reduce_op": "sum", + "precision": 32, + } + # Create the model + base_model = create_model(args) + # Setup for CUDA graphs before TorchScripting + z_cuda = z.to("cuda") + pos_cuda = pos.to("cuda").requires_grad_(True) + batch_cuda = batch.to("cuda") + + # Now TorchScript the model (like OpenMM-Torch does) + model = torch.jit.script(base_model).to(device="cuda") + model.eval() + + # Warm up the model + with torch.cuda.stream(torch.cuda.Stream()): + for _ in range(0, 15): + y, neg_dy = model(z_cuda, pos_cuda, batch=batch_cuda) + # Capture the model in a CUDA graph + g = torch.cuda.CUDAGraph() + y2, neg_dy2 = model(z_cuda, pos_cuda, batch=batch_cuda) + with torch.cuda.graph(g): + y, neg_dy = model(z_cuda, pos_cuda, batch=batch_cuda) + y.fill_(0.0) + neg_dy.fill_(0.0) + g.replay() + assert torch.allclose(y, y2) + assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5) + + @mark.parametrize("model_name", models.__all_models__) def test_seed(model_name): args = load_example_args(model_name, remove_prior=True) diff --git a/tests/test_module.py b/tests/test_module.py index 021672472..7f55418f2 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -43,6 +43,10 @@ def test_train(model_name, use_atomref, precision, tmpdir): # OSX MPS backend runs out of memory on Github Actions torch.set_default_device("cpu") accelerator = "cpu" + elif precision == 64 and torch.backends.mps.is_available(): + # MPS backend doesn't support float64 + torch.set_default_device("cpu") + accelerator = "cpu" args = load_example_args( model_name, @@ -95,6 +99,10 @@ def test_dummy_train(model_name, use_atomref, precision, tmpdir): # OSX MPS backend runs out of memory on Github Actions torch.set_default_device("cpu") accelerator = "cpu" + elif precision == 64 and torch.backends.mps.is_available(): + # MPS backend doesn't support float64 + torch.set_default_device("cpu") + accelerator = "cpu" extra_args = {} if model_name != "tensornet": diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index fa3e06046..574d9b611 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -73,7 +73,7 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto @pytest.mark.parametrize( ("device", "strategy"), - [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")], + [("cpu", "brute"), ("cuda", "brute"), ("cuda", "cell")], ) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128]) @pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9]) @@ -130,6 +130,7 @@ def test_neighbors( return_vecs=True, include_transpose=include_transpose, ) + nl.to(device) batch.to(device) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() @@ -149,7 +150,7 @@ def test_neighbors( @pytest.mark.parametrize( ("device", 
"strategy"), - [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")], + [("cpu", "brute"), ("cuda", "brute"), ("cuda", "cell")], ) @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) @@ -225,6 +226,7 @@ def test_neighbor_grads( resize_to_fit=True, box=box, ) + nl.to(device) neighbors, distances, deltas = nl(positions) # Check neighbor pairs are correct ref_neighbors_sort, _, _ = sort_neighbors( @@ -261,7 +263,7 @@ def test_neighbor_grads( @pytest.mark.parametrize( ("device", "strategy"), - [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")], + [("cpu", "brute"), ("cuda", "brute"), ("cuda", "cell")], ) @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) @@ -300,6 +302,7 @@ def test_neighbor_autograds( resize_to_fit=True, box=box, ) + nl.to(device) positions = 0.25 * lbox * torch.rand(num_atoms, 3, device=device, dtype=dtype) positions.requires_grad_(True) batch = torch.zeros((num_atoms,), dtype=torch.long, device=device) @@ -314,7 +317,7 @@ def test_neighbor_autograds( ) -@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"]) +@pytest.mark.parametrize("strategy", ["brute", "cell"]) @pytest.mark.parametrize("n_batches", [1, 2, 3, 4]) def test_large_size(strategy, n_batches): device = "cuda" @@ -359,6 +362,7 @@ def test_large_size(strategy, n_batches): include_transpose=True, resize_to_fit=True, ) + nl.to(device) neighbors, distances, distance_vecs = nl(pos, batch) neighbors = neighbors.cpu().detach().numpy() distance_vecs = distance_vecs.cpu().detach().numpy() @@ -373,7 +377,7 @@ def test_large_size(strategy, n_batches): @pytest.mark.parametrize( ("device", "strategy"), - [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")], + [("cpu", "brute"), ("cuda", "brute"), ("cuda", "cell")], ) @pytest.mark.parametrize("n_batches", [1, 128]) @pytest.mark.parametrize("cutoff", [1.0]) @@ -424,6 +428,7 @@ def test_jit_script_compatible( return_vecs=True, include_transpose=include_transpose, ) + nl.to(device) batch.to(device) nl = torch.jit.script(nl) @@ -444,7 +449,7 @@ def test_jit_script_compatible( @pytest.mark.parametrize("device", ["cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize("strategy", ["brute", "cell"]) @pytest.mark.parametrize("n_batches", [1, 128]) @pytest.mark.parametrize("cutoff", [1.0]) @pytest.mark.parametrize("loop", [True, False]) @@ -494,6 +499,7 @@ def test_cuda_graph_compatible_forward( check_errors=False, resize_to_fit=False, ) + nl.to(device) batch.to(device) graph = torch.cuda.CUDAGraph() @@ -527,7 +533,7 @@ def test_cuda_graph_compatible_forward( @pytest.mark.parametrize("device", ["cuda"]) -@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"]) +@pytest.mark.parametrize("strategy", ["brute", "cell"]) @pytest.mark.parametrize("n_batches", [1, 128]) @pytest.mark.parametrize("cutoff", [1.0]) @pytest.mark.parametrize("loop", [True, False]) @@ -580,6 +586,7 @@ def test_cuda_graph_compatible_backward( check_errors=False, resize_to_fit=False, ) + nl.to(device) batch.to(device) graph = torch.cuda.CUDAGraph() @@ -600,9 +607,7 @@ def test_cuda_graph_compatible_backward( torch.cuda.synchronize() -@pytest.mark.parametrize( - ("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")] -) +@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute")]) @pytest.mark.parametrize("n_batches", 
[1, 128]) @pytest.mark.parametrize("use_forward", [True, False]) def test_per_batch_box(device, strategy, n_batches, use_forward): @@ -646,6 +651,7 @@ def test_per_batch_box(device, strategy, n_batches, use_forward): return_vecs=True, include_transpose=include_transpose, ) + nl.to(device) batch.to(device) neighbors, distances, distance_vecs = nl( pos, batch, box=box if use_forward else None @@ -694,8 +700,6 @@ def test_torch_compile(device, dtype, loop, include_transpose): if sys.version_info >= (3, 12): pytest.skip("Not available in this version") - if torch.__version__ < "2.0.0": - pytest.skip("Not available in this version") if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") np.random.seed(123456) diff --git a/torchmdnet/calculators.py b/torchmdnet/calculators.py index 662f7b7a4..dbf1e6b27 100644 --- a/torchmdnet/calculators.py +++ b/torchmdnet/calculators.py @@ -268,10 +268,12 @@ def calculate( else: print("atomic numbers changed, re-compiling...") - - self.model.representation_model.setup_for_compile_cudagraphs(batch) - self.model.output_model.setup_for_compile_cudagraphs(batch) + # Warmup pass to set dim_size before compilation + # This is needed because torch.compile doesn't support .item() calls self.model.to(self.device) + with torch.no_grad(): + _ = self.model(numbers, positions, batch=batch, q=total_charge) + self.compiled_model = torch.compile(self.model, backend='inductor', dynamic=False, fullgraph=True, mode='reduce-overhead') self.compiled = True diff --git a/torchmdnet/extensions/__init__.py b/torchmdnet/extensions/__init__.py index f474bff89..f163ff7d7 100644 --- a/torchmdnet/extensions/__init__.py +++ b/torchmdnet/extensions/__init__.py @@ -1,12 +1,3 @@ # Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - -# Place here any short extensions to torch that you want to use in your code. -# The extensions present in extensions.cpp will be automatically compiled in setup.py and loaded here. -# The extensions will be available under torch.ops.torchmdnet_extensions, but you can add wrappers here to make them more convenient to use. -# Place here too any meta registrations for your extensions if required. - -import torch -from pathlib import Path -from . import torchmdnet_extensions, ops diff --git a/torchmdnet/extensions/neighbors.py b/torchmdnet/extensions/neighbors.py new file mode 100644 index 000000000..f495f4b4c --- /dev/null +++ b/torchmdnet/extensions/neighbors.py @@ -0,0 +1,174 @@ +from torch import Tensor +from typing import Tuple +import torch + + +def _round_nearest(x: Tensor) -> Tensor: + # Equivalent to torch.round but works for both float32/float64 and keeps TorchScript happy. 
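+    # (Strictly speaking, halves round away from zero here rather than to even
+    # as in torch.round; this matches the C round() used by the deleted CUDA
+    # kernels in common.cuh that this function replaces.)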
+ return torch.where(x >= 0, torch.floor(x + 0.5), torch.ceil(x - 0.5)) + + +def _apply_pbc_torch(deltas: Tensor, box_for_pairs: Tensor) -> Tensor: + # box_for_pairs: (num_pairs, 3, 3) + scale3 = _round_nearest(deltas[:, 2] / box_for_pairs[:, 2, 2]) + deltas = deltas - scale3.unsqueeze(1) * box_for_pairs[:, 2] + scale2 = _round_nearest(deltas[:, 1] / box_for_pairs[:, 1, 1]) + deltas = deltas - scale2.unsqueeze(1) * box_for_pairs[:, 1] + scale1 = _round_nearest(deltas[:, 0] / box_for_pairs[:, 0, 0]) + deltas = deltas - scale1.unsqueeze(1) * box_for_pairs[:, 0] + return deltas + + +def torch_neighbor_bruteforce( + positions: Tensor, + batch: Tensor, + box_vectors: Tensor, + use_periodic: bool, + cutoff_lower: float, + cutoff_upper: float, + max_num_pairs: int, + loop: bool, + include_transpose: bool, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Optimized brute-force neighbor list using pure PyTorch. + + This implementation avoids nonzero() to be torch.compile compatible. + Uses triangular indexing to reduce memory usage and computation from O(n^2) to O(n^2/2). + It uses argsort with infinity for invalid pairs to achieve fixed output shapes. + """ + if positions.dim() != 2 or positions.size(1) != 3: + raise ValueError('Expected "positions" to have shape (N, 3)') + if batch.dim() != 1 or batch.size(0) != positions.size(0): + raise ValueError('Expected "batch" to have shape (N,)') + if max_num_pairs <= 0: + raise ValueError('Expected "max_num_pairs" to be positive') + + device = positions.device + dtype = positions.dtype + n_atoms = positions.size(0) + + if use_periodic: + if box_vectors.dim() == 2: + box_vectors = box_vectors.unsqueeze(0) + elif box_vectors.dim() != 3: + raise ValueError('Expected "box_vectors" to have shape (n_batch, 3, 3)') + box_vectors = box_vectors.to(device=device, dtype=dtype) + + # Generate base pairs + if loop: + # loop=True: i >= j (lower triangle including diagonal) + tril_indices = torch.tril_indices(n_atoms, n_atoms, device=device) + i_indices = tril_indices[0] + j_indices = tril_indices[1] + else: + # loop=False: i > j (lower triangle excluding diagonal) + tril_indices = torch.tril_indices(n_atoms, n_atoms, offset=-1, device=device) + i_indices = tril_indices[0] + j_indices = tril_indices[1] + + # If include_transpose, add the flipped pairs (j,i) + if include_transpose: + if loop: + # For loop=True, base pairs are i >= j, so add i < j transposes + triu_indices = torch.triu_indices(n_atoms, n_atoms, offset=1, device=device) + i_transpose = triu_indices[0] + j_transpose = triu_indices[1] + else: + # For loop=False, base pairs are i > j, so add all transposes (j,i) + i_transpose = j_indices + j_transpose = i_indices + # Combine base and transpose pairs + i_indices = torch.cat([i_indices, i_transpose]) + j_indices = torch.cat([j_indices, j_transpose]) + + # Compute deltas for all pairs + deltas = positions[i_indices] - positions[j_indices] + + # Apply PBC if needed + if use_periodic: + batch_i = batch[i_indices] + if box_vectors.size(0) == 1: + # Single box for all - use the same box + box_for_pairs = box_vectors.expand(len(deltas), 3, 3) + else: + # Per-batch boxes - index by batch of atom i + box_for_pairs = box_vectors[batch_i] + # Apply PBC to pairs + deltas = _apply_pbc_torch(deltas, box_for_pairs) + + # Compute distances for all pairs + dist_sq = (deltas * deltas).sum(dim=-1) + zero_mask = dist_sq == 0 + distances = torch.where( + zero_mask, + torch.zeros_like(dist_sq), + torch.sqrt(dist_sq.clamp(min=1e-32)), + ) + + # Build validity mask for all pairs + 
valid_mask = torch.ones(len(distances), device=device, dtype=torch.bool) + + # Apply batch constraint + if batch.numel() > 0: + same_batch = batch[i_indices] == batch[j_indices] + valid_mask = valid_mask & same_batch + + # Apply cutoff constraints + # Self-loops (i == j) are exempt from cutoff_lower since they have distance 0 + is_self_loop = i_indices == j_indices + valid_mask = ( + valid_mask + & (distances < cutoff_upper) + & ((distances >= cutoff_lower) | is_self_loop) + ) + + # Sort key: valid pairs by distance, invalid pairs get infinity (sorted last) + sort_key = torch.where( + valid_mask, + distances, + torch.full_like(distances, float("inf")), + ) + + # Sort and take top max_num_pairs (fixed output shape) + # For include_transpose + loop case, we may have more pairs than expected + num_candidates = min(sort_key.size(0), max_num_pairs) + order = torch.argsort(sort_key)[:num_candidates] + + # Gather results using the sorted indices + i_out = i_indices.index_select(0, order) + j_out = j_indices.index_select(0, order) + deltas_out = deltas.index_select(0, order) + distances_out = distances.index_select(0, order) + valid_out = valid_mask.index_select(0, order) + + # Pad to max_num_pairs if needed + if num_candidates < max_num_pairs: + pad_size = max_num_pairs - num_candidates + pad_indices = torch.full((pad_size,), -1, device=device, dtype=torch.long) + pad_deltas = torch.zeros((pad_size, 3), device=device, dtype=dtype) + pad_distances = torch.zeros((pad_size,), device=device, dtype=dtype) + pad_valid = torch.zeros((pad_size,), device=device, dtype=torch.bool) + + i_out = torch.cat([i_out, pad_indices]) + j_out = torch.cat([j_out, pad_indices]) + deltas_out = torch.cat([deltas_out, pad_deltas]) + distances_out = torch.cat([distances_out, pad_distances]) + valid_out = torch.cat([valid_out, pad_valid]) + + # Replace invalid entries with -1 (neighbors) and 0 (deltas/distances) + # Ensure gradients flow through valid entries + neighbors_out = torch.stack( + [ + torch.where(valid_out, i_out, torch.full_like(i_out, -1)), + torch.where(valid_out, j_out, torch.full_like(j_out, -1)), + ] + ) + # For deltas/distances, use where to preserve gradients for valid entries + zero_deltas = deltas_out.detach() * 0 + zero_distances = distances_out.detach() * 0 + deltas_out = torch.where(valid_out.unsqueeze(1), deltas_out, zero_deltas) + distances_out = torch.where(valid_out, distances_out, zero_distances) + + # Count valid pairs (before slicing) to detect overflow + num_pairs_tensor = valid_mask.sum().view(1).to(torch.long) + return neighbors_out, deltas_out, distances_out, num_pairs_tensor diff --git a/torchmdnet/extensions/neighbors/common.cuh b/torchmdnet/extensions/neighbors/common.cuh deleted file mode 100644 index bf875af5b..000000000 --- a/torchmdnet/extensions/neighbors/common.cuh +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org - * Distributed under the MIT License. - *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - * Raul P. Pelaez 2023. Common utilities for the CUDA neighbor operation. 
- */ -#ifndef NEIGHBORS_COMMON_CUH -#define NEIGHBORS_COMMON_CUH -#include -#include -#include -#include -#include - -using c10::cuda::getCurrentCUDAStream; -using c10::cuda::CUDAStreamGuard; -using at::Tensor; -using at::TensorOptions; -using at::TensorAccessor; -using torch::Scalar; -using torch::empty; -using torch::full; - -template -using Accessor = torch::PackedTensorAccessor32; - -template -using KernelAccessor = TensorAccessor; - -template -inline Accessor get_accessor(const Tensor& tensor) { - return tensor.packed_accessor32(); -}; - -template __device__ __forceinline__ scalar_t sqrt_(scalar_t x){return ::sqrt(x);}; -template <> __device__ __forceinline__ float sqrt_(float x) { - return ::sqrtf(x); -}; -template <> __device__ __forceinline__ double sqrt_(double x) { - return ::sqrt(x); -}; - -template struct vec3 { - using type = void; -}; - -template <> struct vec3 { - using type = float3; -}; - -template <> struct vec3 { - using type = double3; -}; - -template using scalar3 = typename vec3::type; - -/* - * @brief Get the position of the i'th particle - * @param positions The positions tensor - * @param i The index of the particle - * @return The position of the i'th particle - */ -template -__device__ scalar3 fetchPosition(const Accessor positions, const int i) { - return {positions[i][0], positions[i][1], positions[i][2]}; -} - -struct PairList { - Tensor i_curr_pair; - Tensor neighbors; - Tensor deltas; - Tensor distances; - const bool loop, include_transpose, use_periodic; - PairList(int max_num_pairs, torch::TensorOptions options, bool loop, bool include_transpose, - bool use_periodic) - : i_curr_pair(torch::zeros({1}, options.dtype(at::kInt))), - neighbors(torch::full({2, max_num_pairs}, -1, options.dtype(at::kInt))), - deltas(torch::full({max_num_pairs, 3}, 0, options)), - distances(torch::full({max_num_pairs}, 0, options)), loop(loop), - include_transpose(include_transpose), use_periodic(use_periodic) { - } -}; - -template struct PairListAccessor { - Accessor i_curr_pair; - Accessor neighbors; - Accessor deltas; - Accessor distances; - bool loop, include_transpose, use_periodic; - explicit PairListAccessor(const PairList& pl) - : i_curr_pair(get_accessor(pl.i_curr_pair)), - neighbors(get_accessor(pl.neighbors)), - deltas(get_accessor(pl.deltas)), - distances(get_accessor(pl.distances)), loop(pl.loop), - include_transpose(pl.include_transpose), use_periodic(pl.use_periodic) { - } -}; - -template -__device__ void writeAtomPair(PairListAccessor& list, int i, int j, - scalar3 delta, scalar_t distance, int i_pair) { - if (i_pair < list.neighbors.size(1)) { - list.neighbors[0][i_pair] = i; - list.neighbors[1][i_pair] = j; - list.deltas[i_pair][0] = delta.x; - list.deltas[i_pair][1] = delta.y; - list.deltas[i_pair][2] = delta.z; - list.distances[i_pair] = distance; - } -} - -template -__device__ void addAtomPairToList(PairListAccessor& list, int i, int j, - scalar3 delta, scalar_t distance, bool add_transpose) { - const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], add_transpose ? 
2 : 1); - // Neighbors after the max number of pairs are ignored, although the pair is counted - writeAtomPair(list, i, j, delta, distance, i_pair); - if (add_transpose) { - writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1); - } -} - -static void checkInput(const Tensor& positions, const Tensor& batch) { - TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); - TORCH_CHECK(positions.size(0) > 0, - "Expected the 1nd dimension size of \"positions\" to be more than 0"); - TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); - TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); - - TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension"); - TORCH_CHECK(batch.size(0) == positions.size(0), - "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension " - "size of \"positions\""); - TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous"); - TORCH_CHECK(batch.dtype() == torch::kInt64, "Expected \"batch\" to be of type at::kLong"); -} - -namespace rect { - -/* - * @brief Takes a point to the unit cell in the range [-0.5, 0.5]*box_size using Minimum Image - * Convention - * @param p The point position - * @param box_size The box size - * @return The point in the unit cell - */ -template -__device__ auto apply_pbc(scalar3 p, scalar3 box_size) { - p.x = p.x - floorf(p.x / box_size.x + scalar_t(0.5)) * box_size.x; - p.y = p.y - floorf(p.y / box_size.y + scalar_t(0.5)) * box_size.y; - p.z = p.z - floorf(p.z / box_size.z + scalar_t(0.5)) * box_size.z; - return p; -} - -template -__device__ auto compute_distance(scalar3 pos_i, scalar3 pos_j, - bool use_periodic, scalar3 box_size) { - scalar3 delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z}; - if (use_periodic) { - delta = apply_pbc(delta, box_size); - } - return delta; -} - -} // namespace rect - -namespace triclinic { -template using BoxAccessor = Accessor; -template -BoxAccessor get_box_accessor(const Tensor& box_vectors, bool use_periodic) { - return get_accessor(box_vectors); -} - -/* - * @brief Takes a point to the unit cell using Minimum Image - * Convention - * @param p The point position - * @param box_vectors The box vectors (3x3 matrix) - * @return The point in the unit cell - */ -template -__device__ auto apply_pbc(scalar3 delta, const KernelAccessor& box) { - scalar_t scale3 = round(delta.z / box[2][2]); - delta.x -= scale3 * box[2][0]; - delta.y -= scale3 * box[2][1]; - delta.z -= scale3 * box[2][2]; - scalar_t scale2 = round(delta.y / box[1][1]); - delta.x -= scale2 * box[1][0]; - delta.y -= scale2 * box[1][1]; - scalar_t scale1 = round(delta.x / box[0][0]); - delta.x -= scale1 * box[0][0]; - return delta; -} - - template -__device__ auto compute_distance(scalar3 pos_i, scalar3 pos_j, - bool use_periodic, const KernelAccessor& box) { - scalar3 delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z}; - if (use_periodic) { - delta = apply_pbc(delta, box); - } - return delta; -} - -} // namespace triclinic - -#endif diff --git a/torchmdnet/extensions/neighbors/neighbors_cpu.cpp b/torchmdnet/extensions/neighbors/neighbors_cpu.cpp deleted file mode 100644 index 350bbd26f..000000000 --- a/torchmdnet/extensions/neighbors/neighbors_cpu.cpp +++ /dev/null @@ -1,273 +0,0 @@ -/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org - * Distributed under the MIT License. 
- *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - */ -#include -#include -#include -#include -#include - -using std::tuple; -using torch::arange; -using torch::div; -using torch::linalg_vector_norm; -using torch::full; -using torch::hstack; -using torch::index_select; -using torch::kInt32; -using torch::round; -using torch::Scalar; -using torch::vstack; -using torch::autograd::AutogradContext; -using torch::autograd::Function; -using torch::autograd::tensor_list; -using torch::indexing::Slice; -using at::Tensor; - -static tuple -forward_impl(const std::string& strategy, const Tensor& positions, const Tensor& batch, - const Tensor& in_box_vectors, bool use_periodic, const Scalar& cutoff_lower, - const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, - bool include_transpose) { - TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); - TORCH_CHECK(positions.size(0) > 0, - "Expected the 1nd dimension size of \"positions\" to be more than 0"); - TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); - TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); - TORCH_CHECK(cutoff_upper.to() > 0, "Expected \"cutoff\" to be positive"); - Tensor box_vectors = in_box_vectors; - const int n_batch = batch.max().item() + 1; - if (use_periodic) { - if (box_vectors.dim() == 2) { - box_vectors = box_vectors.unsqueeze(0).expand({n_batch, 3, 3}); - } - TORCH_CHECK(box_vectors.dim() == 3, "Expected \"box_vectors\" to have two dimensions"); - TORCH_CHECK(box_vectors.size(1) == 3 && box_vectors.size(2) == 3, - "Expected \"box_vectors\" to have shape (n_batch, 3, 3)"); - // Ensure the box first dimension has size max(batch) + 1 - TORCH_CHECK(box_vectors.size(0) == n_batch, - "Expected \"box_vectors\" to have shape (n_batch, 3, 3)"); - // Check that the box is a valid triclinic box, in the case of a box per sample we only - // check the first one - double v[3][3]; - for (int i = 0; i < 3; i++) - for (int j = 0; j < 3; j++) - v[i][j] = box_vectors[0][i][j].item(); - double c = cutoff_upper.to(); - TORCH_CHECK(v[0][1] == 0, "Invalid box vectors: box_vectors[0][1] != 0"); - TORCH_CHECK(v[0][2] == 0, "Invalid box vectors: box_vectors[0][2] != 0"); - TORCH_CHECK(v[1][2] == 0, "Invalid box vectors: box_vectors[1][2] != 0"); - TORCH_CHECK(v[0][0] >= 2 * c, "Invalid box vectors: box_vectors[0][0] < 2*cutoff"); - TORCH_CHECK(v[1][1] >= 2 * c, "Invalid box vectors: box_vectors[1][1] < 2*cutoff"); - TORCH_CHECK(v[2][2] >= 2 * c, "Invalid box vectors: box_vectors[2][2] < 2*cutoff"); - TORCH_CHECK(v[0][0] >= 2 * v[1][0], - "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); - TORCH_CHECK(v[0][0] >= 2 * v[2][0], - "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]"); - TORCH_CHECK(v[1][1] >= 2 * v[2][1], - "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); - } - TORCH_CHECK(max_num_pairs.toLong() > 0, "Expected \"max_num_neighbors\" to be positive"); - const int n_atoms = positions.size(0); - Tensor neighbors = torch::empty({0}, positions.options().dtype(kInt32)); - Tensor distances = torch::empty({0}, positions.options()); - Tensor deltas = torch::empty({0}, positions.options()); - neighbors = torch::vstack((torch::tril_indices(n_atoms, n_atoms, -1, neighbors.options()))); - auto mask = index_select(batch, 0, neighbors.index({0, Slice()})) == - index_select(batch, 0, neighbors.index({1, Slice()})); - neighbors = neighbors.index({Slice(), 
mask}).to(kInt32); - deltas = index_select(positions, 0, neighbors.index({0, Slice()})) - - index_select(positions, 0, neighbors.index({1, Slice()})); - if (use_periodic) { - const auto pair_batch = batch.index({neighbors.index({0, Slice()})}); - const auto scale3 = - round(deltas.index({Slice(), 2}) / box_vectors.index({pair_batch, 2, 2})); - deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) - - scale3 * box_vectors.index({pair_batch, 2, 0})); - deltas.index_put_({Slice(), 1}, deltas.index({Slice(), 1}) - - scale3 * box_vectors.index({pair_batch, 2, 1})); - deltas.index_put_({Slice(), 2}, deltas.index({Slice(), 2}) - - scale3 * box_vectors.index({pair_batch, 2, 2})); - const auto scale2 = - round(deltas.index({Slice(), 1}) / box_vectors.index({pair_batch, 1, 1})); - deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) - - scale2 * box_vectors.index({pair_batch, 1, 0})); - deltas.index_put_({Slice(), 1}, deltas.index({Slice(), 1}) - - scale2 * box_vectors.index({pair_batch, 1, 1})); - const auto scale1 = - round(deltas.index({Slice(), 0}) / box_vectors.index({pair_batch, 0, 0})); - deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) - - scale1 * box_vectors.index({pair_batch, 0, 0})); - } - distances = linalg_vector_norm(deltas, 2, 1); - mask = (distances < cutoff_upper) * (distances >= cutoff_lower); - neighbors = neighbors.index({Slice(), mask}); - deltas = deltas.index({mask, Slice()}); - distances = distances.index({mask}); - if (include_transpose) { - neighbors = torch::hstack({neighbors, torch::stack({neighbors[1], neighbors[0]})}); - distances = torch::hstack({distances, distances}); - deltas = torch::vstack({deltas, -deltas}); - } - if (loop) { - const Tensor range = torch::arange(0, n_atoms, torch::kInt32); - neighbors = torch::hstack({neighbors, torch::stack({range, range})}); - distances = torch::hstack({distances, torch::zeros_like(range)}); - deltas = torch::vstack({deltas, torch::zeros({n_atoms, 3}, deltas.options())}); - } - Tensor num_pairs_found = torch::empty(1, distances.options().dtype(kInt32)); - num_pairs_found[0] = distances.size(0); - // This seems wasteful, but it allows to enable torch.compile by guaranteeing that the output of - // this operator has a predictable size Resize to max_num_pairs, add zeros if necessary - int64_t extension = std::max(max_num_pairs.toLong() - distances.size(0), (int64_t)0); - if (extension > 0) { - deltas = torch::vstack({deltas, torch::zeros({extension, 3}, deltas.options())}); - distances = torch::hstack({distances, torch::zeros({extension}, distances.options())}); - // For the neighbors add (-1,-1) pairs to fill the tensor - neighbors = torch::hstack( - {neighbors, torch::full({2, extension}, -1, neighbors.options().dtype(kInt32))}); - } - return {neighbors, deltas, distances, num_pairs_found}; -} - -// The backwards operation is implemented fully in pytorch so that it can be differentiated a second -// time automatically via Autograd. -static Tensor backward_impl(const Tensor& grad_edge_vec, const Tensor& grad_edge_weight, - const Tensor& edge_index, const Tensor& edge_vec, - const Tensor& edge_weight, const int64_t num_atoms) { - auto zero_mask = edge_weight.eq(0); - auto zero_mask3 = zero_mask.unsqueeze(-1).expand_as(grad_edge_vec); - // We need to avoid dividing by 0. Otherwise Autograd fills the gradient with NaNs in the - // case of a double backwards. This is why we index_select like this. 
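To make that guard concrete, here is the same backward computation as a standalone PyTorch sketch (the helper name is invented; the actual implementation follows below). Padded pairs carry edge_weight == 0, so both the division and the scatter indices are masked instead of producing 0/0 or out-of-bounds -1 indices:

import torch

def neighbor_backward_sketch(grad_edge_vec, grad_edge_weight,
                             edge_index, edge_vec, edge_weight, num_atoms):
    zero_mask = edge_weight == 0  # padded or self-loop edges
    # Divide by a dummy 1 where the weight is zero and zero the gradient there,
    # so a double backward never sees a 0/0.
    safe_weight = edge_weight.masked_fill(zero_mask, 1).unsqueeze(-1)
    grad_from_weight = (edge_vec / safe_weight) * grad_edge_weight.masked_fill(
        zero_mask, 0
    ).unsqueeze(-1)
    result = grad_edge_vec.masked_fill(zero_mask.unsqueeze(-1), 0) + grad_from_weight
    # Map padded (-1) indices to 0; their contribution is already exactly zero.
    safe_index = edge_index.masked_fill(zero_mask.unsqueeze(0).expand_as(edge_index), 0)
    grad_positions = torch.zeros(num_atoms, 3, dtype=edge_vec.dtype, device=edge_vec.device)
    grad_positions.index_add_(0, safe_index[0], result)
    grad_positions.index_add_(0, safe_index[1], -result)
    return grad_positions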
- auto grad_distances_ = edge_vec / edge_weight.masked_fill(zero_mask, 1).unsqueeze(-1) * - grad_edge_weight.masked_fill(zero_mask, 0).unsqueeze(-1); - auto result = grad_edge_vec.masked_fill(zero_mask3, 0) + grad_distances_; - // Avoid out-of-bounds indices under compiler fusion by mapping masked indices to 0 - // and ensuring their contributions are exactly zero. - // This removes the need for allocating an extra dummy row (num_atoms + 1). - auto grad_positions = torch::zeros({num_atoms, 3}, edge_vec.options()); - auto edge_index_ = - edge_index.masked_fill(zero_mask.unsqueeze(0).expand_as(edge_index), 0); - grad_positions.index_add_(0, edge_index_[0], result); - grad_positions.index_add_(0, edge_index_[1], -result); - return grad_positions; -} - -// This is the autograd function that is called when the user calls get_neighbor_pairs. -// It dispatches the required strategy for the forward function and implements the backward -// function. The backward function is written in full pytorch so that it can be differentiated a -// second time automatically via Autograd. -class NeighborAutograd : public torch::autograd::Function { -public: - static tensor_list forward(AutogradContext* ctx, const std::string& strategy, - const Tensor& positions, const Tensor& batch, - const Tensor& box_vectors, bool use_periodic, - const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool include_transpose) { - static auto fwd = - torch::Dispatcher::singleton() - .findSchemaOrThrow("torchmdnet_extensions::get_neighbor_pairs_fwd", "") - .typed(); - Tensor neighbors, deltas, distances, i_curr_pair; - std::tie(neighbors, deltas, distances, i_curr_pair) = - fwd.call(strategy, positions, batch, box_vectors, use_periodic, cutoff_lower, - cutoff_upper, max_num_pairs, loop, include_transpose); - ctx->save_for_backward({neighbors, deltas, distances}); - ctx->saved_data["num_atoms"] = positions.size(0); - return {neighbors, deltas, distances, i_curr_pair}; - } - - using Slice = torch::indexing::Slice; - - static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) { - auto saved = ctx->get_saved_variables(); - auto edge_index = saved[0]; - auto edge_vec = saved[1]; - auto edge_weight = saved[2]; - auto num_atoms = ctx->saved_data["num_atoms"].toInt(); - auto grad_edge_vec = grad_outputs[1]; - auto grad_edge_weight = grad_outputs[2]; - static auto backward = - torch::Dispatcher::singleton() - .findSchemaOrThrow("torchmdnet_extensions::get_neighbor_pairs_bkwd", "") - .typed(); - auto grad_positions = backward.call(grad_edge_vec, grad_edge_weight, edge_index, edge_vec, - edge_weight, num_atoms); - Tensor ignore; - return {ignore, grad_positions, ignore, ignore, ignore, ignore, - ignore, ignore, ignore, ignore, ignore}; - } -}; - -// By registering as a CompositeImplicitAutograd we can torch.compile this Autograd function. -// This mode will generate meta registration for NeighborAutograd automatically, in this case using -// the -// meta registrations provided for the forward and backward functions python side. -// We provide meta registrations python side because it is the recommended way to do it. 
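On the Python side this wiring corresponds to a torch.autograd.Function; a rough sketch (class name invented, reusing the hypothetical neighbor_backward_sketch from the previous example): forward computes and saves the edge tensors, backward returns a gradient only for positions.

import torch

class NeighborAutogradSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, positions, batch, box_vectors, cutoff_upper, max_num_pairs):
        neighbors, deltas, distances, num_pairs = torch_neighbor_bruteforce(
            positions, batch, box_vectors,
            use_periodic=False, cutoff_lower=0.0, cutoff_upper=cutoff_upper,
            max_num_pairs=max_num_pairs, loop=False, include_transpose=True,
        )
        ctx.save_for_backward(neighbors, deltas, distances)
        ctx.num_atoms = positions.size(0)
        return neighbors, deltas, distances, num_pairs

    @staticmethod
    def backward(ctx, grad_neighbors, grad_deltas, grad_distances, grad_num_pairs):
        neighbors, deltas, distances = ctx.saved_tensors
        grad_positions = neighbor_backward_sketch(
            grad_deltas, grad_distances, neighbors, deltas, distances, ctx.num_atoms)
        # One return value per forward input; only positions is differentiable.
        return grad_positions, None, None, None, None

Because the backward body is plain differentiable PyTorch, autograd can differentiate it a second time, which is exactly what the comments above are arranging for the C++ operator.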
-TORCH_LIBRARY_IMPL(torchmdnet_extensions, CompositeImplicitAutograd, m) { - m.impl("get_neighbor_pairs", - [](const std::string& strategy, const Tensor& positions, const Tensor& batch, - const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower, - const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, - bool include_transpose) { - auto result = NeighborAutograd::apply(strategy, positions, batch, box_vectors, - use_periodic, cutoff_lower, cutoff_upper, - max_num_pairs, loop, include_transpose); - return std::make_tuple(result[0], result[1], result[2], result[3]); - }); -} - -// // Explicit device backend registrations for PyTorch versions that do not -// // automatically fall back to CompositeImplicitAutograd for device dispatch. -// TORCH_LIBRARY_IMPL(torchmdnet_extensions, CPU, m) { -// m.impl("get_neighbor_pairs", -// [](const std::string& strategy, const Tensor& positions, const Tensor& batch, -// const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower, -// const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, -// bool include_transpose) { -// auto result = NeighborAutograd::apply(strategy, positions, batch, box_vectors, -// use_periodic, cutoff_lower, cutoff_upper, -// max_num_pairs, loop, include_transpose); -// return std::make_tuple(result[0], result[1], result[2], result[3]); -// }); -// } - -// TORCH_LIBRARY_IMPL(torchmdnet_extensions, CUDA, m) { -// m.impl("get_neighbor_pairs", -// [](const std::string& strategy, const Tensor& positions, const Tensor& batch, -// const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower, -// const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop, -// bool include_transpose) { -// auto result = NeighborAutograd::apply(strategy, positions, batch, box_vectors, -// use_periodic, cutoff_lower, cutoff_upper, -// max_num_pairs, loop, include_transpose); -// return std::make_tuple(result[0], result[1], result[2], result[3]); -// }); -// } - -// The registration for the CUDA version of the forward function is done in a separate .cu file. -TORCH_LIBRARY_IMPL(torchmdnet_extensions, CPU, m) { - m.impl("get_neighbor_pairs_fwd", forward_impl); -} - -// Ideally we would register this just once using CompositeImplicitAutograd, but this causes a -// segfault -// when trying to torch.compile this function. -// Doing it this way prints a message about Autograd not being provided a backward function of the -// backward function. It gets it from the implementation just fine now, but warns that in the future -// this will be deprecated. -TORCH_LIBRARY_IMPL(torchmdnet_extensions, CPU, m) { - m.impl("get_neighbor_pairs_bkwd", backward_impl); -} - -TORCH_LIBRARY_IMPL(torchmdnet_extensions, CUDA, m) { - m.impl("get_neighbor_pairs_bkwd", backward_impl); -} - -// // Register explicit Autograd fallthroughs to avoid deprecated autograd fallback warnings -// // when backpropagating through these custom operators on CUDA/CPU. 
-// TORCH_LIBRARY_IMPL(torchmdnet_extensions, Autograd, m) { -// m.impl("get_neighbor_pairs", torch::CppFunction::makeFallthrough()); -// m.impl("get_neighbor_pairs_bkwd", torch::CppFunction::makeFallthrough()); -// } diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda.cu b/torchmdnet/extensions/neighbors/neighbors_cuda.cu deleted file mode 100644 index 68f775485..000000000 --- a/torchmdnet/extensions/neighbors/neighbors_cuda.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org - * Distributed under the MIT License. - * (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - * Raul P. Pelaez 2023 - */ -#include -#include -#include - -#include -#include -#include - -#include "neighbors_cuda_brute.cuh" -#include "neighbors_cuda_cell.cuh" -#include "neighbors_cuda_shared.cuh" - -static std::tuple -forward_impl_cuda(const std::string& strategy, const at::Tensor& positions, const at::Tensor& batch, - const at::Tensor& in_box_vectors, bool use_periodic, const at::Scalar& cutoff_lower, - const at::Scalar& cutoff_upper, const at::Scalar& max_num_pairs, bool loop, - bool include_transpose) { - auto kernel = forward_brute; - if (positions.size(0) >= 32768 && strategy == "brute") { - kernel = forward_shared; - } - if (strategy == "brute") { - } else if (strategy == "cell") { - kernel = forward_cell; - } else if (strategy == "shared") { - kernel = forward_shared; - } else { - throw std::runtime_error("Unknown kernel name"); - } - return kernel(positions, batch, in_box_vectors, use_periodic, cutoff_lower, cutoff_upper, - max_num_pairs, loop, include_transpose); -} - -// We only need to register the CUDA version of the forward function here. This way we can avoid -// compiling this file in CPU-only mode The rest of the registrations take place in -// neighbors_cpu.cpp -TORCH_LIBRARY_IMPL(torchmdnet_extensions, CUDA, m) { - m.impl("get_neighbor_pairs_fwd", forward_impl_cuda); -} diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh deleted file mode 100644 index 4d33979f3..000000000 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org - * Distributed under the MIT License. - *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - * - * Raul P. Pelaez 2023. Brute force neighbor list construction in CUDA. - * - * A brute force approach that assigns a thread per each possible pair of particles in the system. - * Based on an implementation by Raimondas Galvelis. - * Works fantastically for small (less than 10K atoms) systems, but cannot handle more than 32K - * atoms. 
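The per-thread pair assignment relies on inverting a triangular number, and the float sqrt can overshoot by one for large indices, hence the correction in get_row (the new Triton kernel later in this diff applies the same fix). A small self-contained check of that arithmetic:

import math

def row_col_from_pair_index(k: int):
    # Inverse of k = row * (row - 1) // 2 + col, with 0 <= col < row.
    row = int((math.sqrt(8 * k + 1) + 1) / 2)
    if row * (row - 1) > 2 * k:  # floating-point overshoot correction
        row -= 1
    return row, k - row * (row - 1) // 2

# Every pair (i, j) with j < i appears exactly once, in row-major order.
n_atoms = 100
pairs = [row_col_from_pair_index(k) for k in range(n_atoms * (n_atoms - 1) // 2)]
assert pairs == [(i, j) for i in range(1, n_atoms) for j in range(i)]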
- */ -#ifndef NEIGHBORS_BRUTE_CUH -#define NEIGHBORS_BRUTE_CUH -#include "common.cuh" -#include - -__device__ uint32_t get_row(uint32_t index) { - uint32_t row = floor((sqrtf(8 * index + 1) + 1) / 2); - if (row * (row - 1) > 2 * index) - row--; - return row; -} - -template -__global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor positions, - const Accessor batch, scalar_t cutoff_lower2, - scalar_t cutoff_upper2, PairListAccessor list, - triclinic::BoxAccessor box) { - const uint32_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_all_pairs) - return; - const uint32_t row = get_row(index); - const uint32_t column = (index - row * (row - 1) / 2); - const auto batch_row = batch[row]; - if (batch_row == batch[column]) { - const auto pos_i = fetchPosition(positions, row); - const auto pos_j = fetchPosition(positions, column); - const auto box_row = box[batch_row]; - const auto delta = triclinic::compute_distance(pos_i, pos_j, list.use_periodic, box_row); - const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; - if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { - const scalar_t r2 = sqrt_(distance2); - addAtomPairToList(list, row, column, delta, r2, list.include_transpose); - } - } -} - -template -__global__ void add_self_kernel(const int num_atoms, Accessor positions, - PairListAccessor list) { - const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; - if (i_atom >= num_atoms) - return; - __shared__ int i_pair; - if (threadIdx.x == 0) { // Each block adds blockDim.x pairs to the list. - // Handle the last block, so that only num_atoms are added in total - i_pair = - atomicAdd(&list.i_curr_pair[0], min(blockDim.x, num_atoms - blockIdx.x * blockDim.x)); - } - __syncthreads(); - scalar3 delta{}; - scalar_t distance = 0; - writeAtomPair(list, i_atom, i_atom, delta, distance, i_pair + threadIdx.x); -} - -static std::tuple -forward_brute(const Tensor& positions, const Tensor& batch, const Tensor& in_box_vectors, - bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool include_transpose) { - checkInput(positions, batch); - const auto max_num_pairs_ = max_num_pairs.toLong(); - TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); - auto box_vectors = in_box_vectors.to(positions.device()).clone(); - if (box_vectors.dim() == 2) { - // If the box is a 3x3 tensor it is assumed every sample has the same box - if (use_periodic) { - TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, - "Expected \"box_vectors\" to have shape (3, 3)"); - } - // Make the box (None,3,3), expand artificially to positions.size(0) - box_vectors = box_vectors.unsqueeze(0); - if (use_periodic) { - // I use positions.size(0) because the batch dimension is not available here - box_vectors = box_vectors.expand({positions.size(0), 3, 3}); - } - } - if (use_periodic) { - TORCH_CHECK(box_vectors.dim() == 3, "Expected \"box_vectors\" to have three dimensions"); - TORCH_CHECK(box_vectors.size(1) == 3 && box_vectors.size(2) == 3, - "Expected \"box_vectors\" to have shape (n_batch, 3, 3)"); - } - const int num_atoms = positions.size(0); - TORCH_CHECK(num_atoms < 32768, "The brute strategy fails with \"num_atoms\" larger than 32768"); - const int num_pairs = max_num_pairs_; - const TensorOptions options = positions.options(); - const auto stream = getCurrentCUDAStream(positions.get_device()); - PairList list(num_pairs, positions.options(), loop, 
include_transpose, use_periodic); - const CUDAStreamGuard guard(stream); - const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL; - const uint64_t num_threads = 128; - const uint64_t num_blocks = std::max(static_cast((num_all_pairs + num_threads - 1UL) / num_threads), static_cast(1UL)); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { - PairListAccessor list_accessor(list); - auto box = triclinic::get_box_accessor(box_vectors, use_periodic); - const scalar_t cutoff_upper_ = cutoff_upper.to(); - const scalar_t cutoff_lower_ = cutoff_lower.to(); - TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); - forward_kernel_brute<<>>( - num_all_pairs, get_accessor(positions), get_accessor(batch), - cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor, box); - if (loop) { - const uint32_t num_threads_self = 256; - const uint32_t num_blocks_self = - std::max((num_atoms + num_threads_self - 1U) / num_threads_self, 1U); - add_self_kernel<<>>( - num_atoms, get_accessor(positions), list_accessor); - } - }); - return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; -} - -#endif diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh deleted file mode 100644 index b54ab1c13..000000000 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh +++ /dev/null @@ -1,390 +0,0 @@ -/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org - * Distributed under the MIT License. - *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - * Raul P. Pelaez 2023. Batched cell list neighbor list implementation for CUDA. - */ -#ifndef NEIGHBOR_CUDA_CELL_H -#define NEIGHBOR_CUDA_CELL_H -#include "common.cuh" - -/* - * @brief Calculates the cell dimensions for a given box size and cutoff - * @param box_size The box size - * @param cutoff The cutoff - * @return The cell dimensions - */ -template -__host__ __device__ int3 getCellDimensions(scalar3 box_size, scalar_t cutoff) { - int3 cell_dim = make_int3(box_size.x / cutoff, box_size.y / cutoff, box_size.z / cutoff); - // Minimum 3 cells in each dimension - cell_dim.x = max(cell_dim.x, 3); - cell_dim.y = max(cell_dim.y, 3); - cell_dim.z = max(cell_dim.z, 3); -// In the host, throw if there are more than 1024 cells in any dimension -#ifndef __CUDA_ARCH__ - if (cell_dim.x > 1024 || cell_dim.y > 1024 || cell_dim.z > 1024) { - throw std::runtime_error("Too many cells in one dimension. Maximum is 1024"); - } -#endif - return cell_dim; -} - -/* - * @brief Get the cell coordinates of a point - * @param p The point position - * @param box_size The size of the box in each dimension - * @param cutoff The cutoff - * @param cell_dim The number of cells in each dimension - * @return The cell coordinates - */ -template -__device__ int3 getCell(scalar3 p, scalar3 box_size, scalar_t cutoff, - int3 cell_dim) { - p = rect::apply_pbc(p, box_size); - // Take to the [0, box_size] range and divide by cutoff (which is the cell size) - int cx = floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff); - int cy = floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff); - int cz = floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff); - // Wrap around. If the position of a particle is exactly box_size, it will be in the last cell, - // which results in an illegal access down the line. 
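A rough PyTorch rendering of this binning, assuming an orthorhombic box described by its diagonal (the helper name is invented). It reproduces the minimum of three cells per dimension and the wrap-around fold described in the comment above:

import torch

def assign_cells(positions, box_diag, cutoff):
    # At least 3 cells per dimension, mirroring getCellDimensions.
    cell_dim = torch.clamp((box_diag / cutoff).to(torch.long), min=3)
    # Wrap into [-0.5, 0.5) * box, shift to [0, box), then bin by the cutoff.
    wrapped = positions - torch.floor(positions / box_diag + 0.5) * box_diag
    cell = torch.floor((wrapped + 0.5 * box_diag) / cutoff).to(torch.long)
    # A coordinate exactly at the box edge lands one past the last cell; fold it back.
    cell = torch.where(cell == cell_dim, torch.zeros_like(cell), cell)
    # Linear cell index: x + nx * (y + ny * z), as in getCellIndex.
    return cell[:, 0] + cell_dim[0] * (cell[:, 1] + cell_dim[1] * cell[:, 2]), cell_dim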
- if (cx == cell_dim.x) - cx = 0; - if (cy == cell_dim.y) - cy = 0; - if (cz == cell_dim.z) - cz = 0; - return make_int3(cx, cy, cz); -} - -/* - * @brief Get the index of a cell in a 1D array of cells. - * @param cell The cell coordinates, assumed to be in the range [0, cell_dim]. - * @param cell_dim The number of cells in each dimension - */ -__device__ int getCellIndex(int3 cell, int3 cell_dim) { - return cell.x + cell_dim.x * (cell.y + cell_dim.y * cell.z); -} - -/* - @brief Fold a cell coordinate to the range [0, cell_dim) - @param cell The cell coordinate - @param cell_dim The dimensions of the grid - @return The folded cell coordinate -*/ -__device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) { - int3 periodic_cell = cell; - if (cell.x < 0) - periodic_cell.x += cell_dim.x; - if (cell.x >= cell_dim.x) - periodic_cell.x -= cell_dim.x; - if (cell.y < 0) - periodic_cell.y += cell_dim.y; - if (cell.y >= cell_dim.y) - periodic_cell.y -= cell_dim.y; - if (cell.z < 0) - periodic_cell.z += cell_dim.z; - if (cell.z >= cell_dim.z) - periodic_cell.z -= cell_dim.z; - return periodic_cell; -} - -// Computes and stores the cell index of each atom. -template -__global__ void assignCellIndex(const Accessor positions, - Accessor cell_indices, scalar3 box_size, - scalar_t cutoff) { - const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; - if (i_atom >= positions.size(0)) - return; - const auto pi = fetchPosition(positions, i_atom); - const auto cell_dim = getCellDimensions(box_size, cutoff); - const auto ci = getCell(pi, box_size, cutoff, cell_dim); - cell_indices[i_atom] = getCellIndex(ci, cell_dim); -} - -/* - * @brief Sort the positions by cell index - * @param positions The positions of the atoms - * @param box_size The box vectors - * @param cutoff The cutoff - * @return A tuple of the sorted indices and cell indices - */ -static auto sortAtomsByCellIndex(const Tensor& positions, const Tensor& box_size, - const Scalar& cutoff) { - const int num_atoms = positions.size(0); - Tensor cell_index = empty({num_atoms}, positions.options().dtype(torch::kInt32)); - const int threads = 128; - const int blocks = (num_atoms + threads - 1) / threads; - auto stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] { - scalar_t cutoff_ = cutoff.to(); - scalar3 box_size_ = {box_size[0][0].item(), - box_size[1][1].item(), - box_size[2][2].item()}; - assignCellIndex<<>>(get_accessor(positions), - get_accessor(cell_index), - box_size_, cutoff_); - }); - // Sort the atom indices by cell index - Tensor sorted_atom_index; - Tensor sorted_cell_index; - std::tie(sorted_cell_index, sorted_atom_index) = torch::sort(cell_index); - return std::make_tuple(sorted_atom_index.to(torch::kInt32), sorted_cell_index); -} - -__global__ void fillCellOffsetsD(const Accessor sorted_cell_indices, - Accessor cell_start, Accessor cell_end) { - // Since positions are sorted by cell, for a given atom, if the previous atom is in a different - // cell, then the current atom is the first atom in its cell We use this fact to fill the - // cell_start and cell_end arrays - const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x; - if (i_atom >= sorted_cell_indices.size(0)) - return; - const int icell = sorted_cell_indices[i_atom]; - int im1_cell; - if (i_atom > 0) { - const int im1 = i_atom - 1; - im1_cell = sorted_cell_indices[im1]; - } else { - im1_cell = 0; - } - if (icell != im1_cell || i_atom == 0) { - cell_start[icell] = i_atom; - if (i_atom > 0) { - cell_end[im1_cell] 
= i_atom; - } - } - if (i_atom == sorted_cell_indices.size(0) - 1) { - cell_end[icell] = i_atom + 1; - } -} - -/* - @brief Fills the cell_start and cell_end arrays, identifying the first and last atom in each cell - @param sorted_cell_indices The cell indices of each position - @param cell_dim The dimensions of the cell grid - @return A tuple of cell_start and cell_end arrays -*/ -static auto fillCellOffsets(const Tensor& sorted_cell_indices, int3 cell_dim) { - const TensorOptions options = sorted_cell_indices.options(); - const int num_cells = cell_dim.x * cell_dim.y * cell_dim.z; - const Tensor cell_start = full({num_cells}, -1, options.dtype(torch::kInt)); - const Tensor cell_end = empty({num_cells}, options.dtype(torch::kInt)); - const int threads = 128; - const int blocks = (sorted_cell_indices.size(0) + threads - 1) / threads; - auto stream = at::cuda::getCurrentCUDAStream(); - fillCellOffsetsD<<>>(get_accessor(sorted_cell_indices), - get_accessor(cell_start), - get_accessor(cell_end)); - return std::make_tuple(cell_start, cell_end); -} - -/* - @brief Get the cell index of the i'th neighboring cell for a given cell - @param cell_i The cell coordinates - @param i The index of the neighboring cell, from 0 to 26 - @param cell_dim The dimensions of the cell grid - @return The cell index of the i'th neighboring cell -*/ -__device__ int getNeighborCellIndex(int3 cell_i, int i, int3 cell_dim) { - auto cell_j = cell_i; - cell_j.x += i % 3 - 1; - cell_j.y += (i / 3) % 3 - 1; - cell_j.z += i / 9 - 1; - cell_j = getPeriodicCell(cell_j, cell_dim); - const int icellj = getCellIndex(cell_j, cell_dim); - return icellj; -} - -template struct Particle { - int index; // Index in the sorted arrays - int original_index; // Index in the original arrays - int batch; - scalar3 position; - scalar_t cutoff_upper2, cutoff_lower2; -}; - -struct CellList { - Tensor cell_start, cell_end; - Tensor sorted_indices; - Tensor sorted_positions, sorted_batch; -}; - -CellList constructCellList(const Tensor& positions, const Tensor& batch, const Tensor& box_size, - const Scalar& cutoff) { - // The algorithm for the cell list construction can be summarized in three separate steps: - // 1. Label the particles according to the cell (bin) they lie in. - // 2. Sort the particles using the cell index as the ordering label - // (technically this is known as sorting by key). So that particles with positions - // lying in the same cell become contiguous in memory. - // 3. Identify where each cell starts and ends in the sorted particle positions - // array. 
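Those three steps map almost directly onto sort-by-key primitives; an illustrative CPU-side sketch, reusing the hypothetical assign_cells helper from the earlier example (the deleted code does the same with torch::sort plus the fillCellOffsetsD kernel):

import torch

def build_cell_list(positions, box_diag, cutoff):
    # Step 1: label each atom with the linear index of the cell it lies in.
    cell_index, cell_dim = assign_cells(positions, box_diag, cutoff)
    # Step 2: sort atoms by cell index so each cell is contiguous in memory.
    sorted_cell_index, order = torch.sort(cell_index)
    # Step 3: first and one-past-last sorted position of every occupied cell.
    n = sorted_cell_index.numel()
    is_first = torch.ones(n, dtype=torch.bool)
    is_first[1:] = sorted_cell_index[1:] != sorted_cell_index[:-1]
    starts = torch.nonzero(is_first).squeeze(1)
    ends = torch.cat([starts[1:], torch.tensor([n])])
    num_cells = int(cell_dim.prod())
    cell_start = torch.full((num_cells,), -1, dtype=torch.long)  # -1 marks empty cells
    cell_end = torch.zeros(num_cells, dtype=torch.long)
    cell_start[sorted_cell_index[starts]] = starts
    cell_end[sorted_cell_index[starts]] = ends
    return cell_start, cell_end, order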
- const TensorOptions options = positions.options(); - CellList cl; - Tensor sorted_cell_indices; - // Steps 1 and 2 - std::tie(cl.sorted_indices, sorted_cell_indices) = - sortAtomsByCellIndex(positions, box_size, cutoff); - cl.sorted_positions = positions.index_select(0, cl.sorted_indices); - cl.sorted_batch = batch.index_select(0, cl.sorted_indices); - // Step 3 - int3 cell_dim; - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "computeCellDim", [&] { - scalar_t cutoff_ = cutoff.to(); - scalar3 box_size_ = {box_size[0][0].item(), - box_size[1][1].item(), - box_size[2][2].item()}; - cell_dim = getCellDimensions(box_size_, cutoff_); - }); - std::tie(cl.cell_start, cl.cell_end) = fillCellOffsets(sorted_cell_indices, cell_dim); - return cl; -} - -template struct CellListAccessor { - Accessor cell_start, cell_end; - Accessor sorted_indices; - Accessor sorted_positions; - Accessor sorted_batch; - - explicit CellListAccessor(const CellList& cl) - : cell_start(get_accessor(cl.cell_start)), - cell_end(get_accessor(cl.cell_end)), - sorted_indices(get_accessor(cl.sorted_indices)), - sorted_positions(get_accessor(cl.sorted_positions)), - sorted_batch(get_accessor(cl.sorted_batch)) { - } -}; - -/* - * @brief Add a pair of particles to the pair list. If necessary, also add the transpose pair. - * @param list The pair list - * @param i The index of the first particle - * @param j The index of the second particle - * @param distance2 The squared distance between the particles - * @param delta The vector between the particles - */ -template -__device__ void addNeighborPair(PairListAccessor& list, const int i, const int j, - scalar_t distance2, scalar3 delta) { - const bool requires_transpose = list.include_transpose && (j != i); - const int ni = max(i, j); - const int nj = min(i, j); - const scalar_t delta_sign = (ni == i) ? 
scalar_t(1.0) : scalar_t(-1.0); - const scalar_t distance = sqrt_(distance2); - delta = {delta_sign * delta.x, delta_sign * delta.y, delta_sign * delta.z}; - addAtomPairToList(list, ni, nj, delta, distance, requires_transpose); -} - -/* - * @brief Add to the pair list all neighbors of particle i_atom in cell j_cell - * @param i_atom The Information of the particle for which we are adding neighbors - * @param j_cell The index of the cell in which we are looking for neighbors - * @param cl The cell list - * @param box_size The box size - * @param list The pair list - */ -template -__device__ void addNeighborsForCell(const Particle& i_atom, int j_cell, - const CellListAccessor& cl, - scalar3 box_size, PairListAccessor& list) { - const auto first_particle = cl.cell_start[j_cell]; - if (first_particle != -1) { // Continue only if there are particles in this cell - const auto last_particle = cl.cell_end[j_cell]; - for (int cur_j = first_particle; cur_j < last_particle; cur_j++) { - const auto j_batch = cl.sorted_batch[cur_j]; - if ((j_batch == i_atom.batch) && - ((cur_j < i_atom.index) || (list.loop && cur_j == i_atom.index))) { - const auto position_j = fetchPosition(cl.sorted_positions, cur_j); - const auto delta = rect::compute_distance(i_atom.position, position_j, - list.use_periodic, box_size); - const scalar_t distance2 = - delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; - if ( ((distance2 < i_atom.cutoff_upper2 && distance2 >= i_atom.cutoff_lower2)) || - (list.loop && cur_j == i_atom.index)) { - const int orj = cl.sorted_indices[cur_j]; - addNeighborPair(list, i_atom.original_index, orj, distance2, delta); - } // endif - } // endif - } // endfor - } // endif -} - -// Traverse the cell list for each atom and find the neighbors -template -__global__ void traverseCellList(const CellListAccessor cell_list, - PairListAccessor list, int num_atoms, - scalar3 box_size, scalar_t cutoff_lower, - scalar_t cutoff_upper) { - // Each atom traverses the cells around it and finds the neighbors - // Atoms for all batches are placed in the same cell list, but other batches are ignored while - // traversing - Particle i_atom; - i_atom.index = blockIdx.x * blockDim.x + threadIdx.x; - if (i_atom.index >= num_atoms) { - return; - } - i_atom.original_index = cell_list.sorted_indices[i_atom.index]; - i_atom.batch = cell_list.sorted_batch[i_atom.index]; - i_atom.position = fetchPosition(cell_list.sorted_positions, i_atom.index); - i_atom.cutoff_lower2 = cutoff_lower * cutoff_lower; - i_atom.cutoff_upper2 = cutoff_upper * cutoff_upper; - const int3 cell_dim = getCellDimensions(box_size, cutoff_upper); - const int3 cell_i = getCell(i_atom.position, box_size, cutoff_upper, cell_dim); - // Loop over the 27 cells around the current cell - for (int i = 0; i < 27; i++) { - const int neighbor_cell = getNeighborCellIndex(cell_i, i, cell_dim); - addNeighborsForCell(i_atom, neighbor_cell, cell_list, box_size, list); - } -} - -static std::tuple -forward_cell(const Tensor& positions, const Tensor& batch, const Tensor& in_box_size, - bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool include_transpose) { - // This module computes the pair list for a given set of particles, which may be in multiple - // batches. The strategy is to first compute a cell list for all particles, and then - // traverse the cell list for each particle to construct a pair list. 
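The traversal visits the 3x3x3 block of cells around each atom; the flat index i in [0, 27) used by getNeighborCellIndex is just a base-3 digit split with a periodic fold at the grid edges. In Python terms:

def neighbor_cell_offsets():
    # i in [0, 27) encodes a (dx, dy, dz) offset in {-1, 0, 1}^3.
    return [(i % 3 - 1, (i // 3) % 3 - 1, i // 9 - 1) for i in range(27)]

offsets = neighbor_cell_offsets()
assert len(set(offsets)) == 27 and (0, 0, 0) in offsets

def fold_cell(c, dim):
    # Periodic fold of a single cell coordinate into [0, dim), as getPeriodicCell does.
    return (c + dim) % dim

assert fold_cell(-1, 5) == 4  # the cell left of cell 0 wraps to the last cell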
- checkInput(positions, batch); - auto box_size = in_box_size.to("cpu"); - // If the box has dimensions (1, 3,3) squeeze it - if (box_size.dim() == 3) { - TORCH_CHECK(box_size.size(0) == 1 && box_size.size(1) == 3 && box_size.size(2) == 3, - "Cell list does not support a box per sample. Expected \"box_size\" to have shape (1, 3, 3) or (3, 3)"); - box_size = box_size.squeeze(0); - } - - TORCH_CHECK(box_size.dim() == 2, "Expected \"box_size\" to have two dimensions"); - TORCH_CHECK(box_size.size(0) == 3 && box_size.size(1) == 3, - "Expected \"box_size\" to have shape (3, 3)"); - TORCH_CHECK(box_size[0][1].item() == 0 && box_size[0][2].item() == 0 && - box_size[1][0].item() == 0 && box_size[1][2].item() == 0 && - box_size[2][0].item() == 0 && box_size[2][1].item() == 0, - "Expected \"box_size\" to be diagonal"); - const auto max_num_pairs_ = max_num_pairs.toInt(); - TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); - const int num_atoms = positions.size(0); - const auto cell_list = constructCellList(positions, batch, box_size, cutoff_upper); - PairList list(max_num_pairs_, positions.options(), loop, include_transpose, use_periodic); - const auto stream = getCurrentCUDAStream(positions.get_device()); - { // Traverse the cell list to find the neighbors - const CUDAStreamGuard guard(stream); - AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] { - const scalar_t cutoff_upper_ = cutoff_upper.to(); - TORCH_CHECK(cutoff_upper_ > 0, "Expected cutoff_upper to be positive"); - const scalar_t cutoff_lower_ = cutoff_lower.to(); - const scalar3 box_size_ = {box_size[0][0].item(), - box_size[1][1].item(), - box_size[2][2].item()}; - PairListAccessor list_accessor(list); - CellListAccessor cell_list_accessor(cell_list); - const int threads = 128; - const int blocks = (num_atoms + threads - 1) / threads; - traverseCellList<<>>(cell_list_accessor, list_accessor, - num_atoms, box_size_, cutoff_lower_, - cutoff_upper_); - }); - } - return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; -} - -#endif diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh deleted file mode 100644 index 950a3bb6b..000000000 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org - * Distributed under the MIT License. - *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - * Raul P. Pelaez 2023. Shared memory neighbor list construction for CUDA. - * This brute force approach checks all pairs of atoms by collaborativelly loading and processing - * tiles of atoms into shared memory. - * This approach is tipically slower than the brute force approach, but can handle an arbitrarily - * large number of atoms. - */ -#ifndef NEIGHBORS_SHARED_CUH -#define NEIGHBORS_SHARED_CUH -#include "common.cuh" -#include - -template -__global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor positions, - const Accessor batch, scalar_t cutoff_lower2, - scalar_t cutoff_upper2, PairListAccessor list, - int32_t num_tiles, triclinic::BoxAccessor box) { - // A thread per atom - const int id = blockIdx.x * blockDim.x + threadIdx.x; - // All threads must pass through __syncthreads, - // but when N is not a multiple of 32 some threads are assigned a particle i>N. 
- // This threads cant return, so they are masked to not do any work - const bool active = id < num_atoms; - __shared__ scalar3 sh_pos[BLOCKSIZE]; - __shared__ int64_t sh_batch[BLOCKSIZE]; - scalar3 pos_i; - int64_t batch_i; - if (active) { - pos_i = fetchPosition(positions, id); - batch_i = batch[id]; - } - // Distribute the N particles in a group of tiles. Storing in each tile blockDim.x values in - // shared memory. This way all threads are accesing the same memory addresses at the same time - for (int tile = 0; tile < num_tiles; tile++) { - // Load this tiles particles values to shared memory - const int i_load = tile * blockDim.x + threadIdx.x; - if (i_load < num_atoms) { // Even if im not active, my thread may load a value each tile to - // shared memory. - sh_pos[threadIdx.x] = fetchPosition(positions, i_load); - sh_batch[threadIdx.x] = batch[i_load]; - } - // Wait for all threads to arrive - __syncthreads(); - // Go through all the particles in the current tile -#pragma unroll 8 - for (int counter = 0; counter < blockDim.x; counter++) { - if (!active) - break; // An out of bounds thread must be masked - const int cur_j = tile * blockDim.x + counter; - const bool testPair = (cur_j < num_atoms) && (cur_j < id || (list.loop && cur_j == id)); - if (testPair) { - const auto batch_j = sh_batch[counter]; - if (batch_i == batch_j) { - const auto pos_j = sh_pos[counter]; - const auto box_i = box[batch_i]; - const auto delta = - triclinic::compute_distance(pos_i, pos_j, list.use_periodic, box_i); - const scalar_t distance2 = - delta.x * delta.x + delta.y * delta.y + delta.z * delta.z; - if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) { - const bool requires_transpose = list.include_transpose && !(cur_j == id); - const auto distance = sqrt_(distance2); - addAtomPairToList(list, id, cur_j, delta, distance, requires_transpose); - } - } - } - } - __syncthreads(); - } -} - -static std::tuple -forward_shared(const Tensor& positions, const Tensor& batch, const Tensor& in_box_vectors, - bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper, - const Scalar& max_num_pairs, bool loop, bool include_transpose) { - checkInput(positions, batch); - const auto max_num_pairs_ = max_num_pairs.toLong(); - auto box_vectors = in_box_vectors.to(positions.device()); - if (box_vectors.dim() == 2) { - // If the box is a 3x3 tensor it is assumed every sample has the same box - if (use_periodic) { - TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3, - "Expected \"box_vectors\" to have shape (n_batch, 3, 3)"); - } - // Make the box (None,3,3), expand artificially to positions.size(0) - box_vectors = box_vectors.unsqueeze(0); - if (use_periodic) { - // I use positions.size(0) because the batch dimension is not available here - box_vectors = box_vectors.expand({positions.size(0), 3, 3}); - } - } - TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive"); - if (use_periodic) { - TORCH_CHECK(box_vectors.dim() == 3, "Expected \"box_vectors\" to have three dimensions"); - TORCH_CHECK(box_vectors.size(1) == 3 && box_vectors.size(2) == 3, - "Expected \"box_vectors\" to have shape (n_batch, 3, 3)"); - } - const int num_atoms = positions.size(0); - const int num_pairs = max_num_pairs_; - const TensorOptions options = positions.options(); - const auto stream = getCurrentCUDAStream(positions.get_device()); - PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic); - const CUDAStreamGuard guard(stream); - 
AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() { - const scalar_t cutoff_upper_ = cutoff_upper.to(); - const scalar_t cutoff_lower_ = cutoff_lower.to(); - auto box = triclinic::get_box_accessor(box_vectors, use_periodic); - TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive"); - constexpr int BLOCKSIZE = 64; - const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1); - const int num_threads = BLOCKSIZE; - const int num_tiles = num_blocks; - PairListAccessor list_accessor(list); - forward_kernel_shared<<>>( - num_atoms, get_accessor(positions), get_accessor(batch), - cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor, num_tiles, - box); - }); - return {list.neighbors, list.deltas, list.distances, list.i_curr_pair}; -} - -#endif diff --git a/torchmdnet/extensions/ops.py b/torchmdnet/extensions/ops.py index 8a03c87b6..a9ebbce24 100644 --- a/torchmdnet/extensions/ops.py +++ b/torchmdnet/extensions/ops.py @@ -10,23 +10,21 @@ import torch from torch import Tensor from typing import Tuple -from torch.library import register_fake +import logging +from torchmdnet.extensions.neighbors import torch_neighbor_bruteforce +try: + import triton + from torchmdnet.extensions.triton_neighbors import triton_neighbor_pairs -__all__ = ["is_current_stream_capturing", "get_neighbor_pairs_kernel"] + HAS_TRITON = True +except ImportError: + HAS_TRITON = False -def is_current_stream_capturing(): - """Returns True if the current CUDA stream is capturing. +logger = logging.getLogger(__name__) - Returns False if CUDA is not available or the current stream is not capturing. - - This utility is required because the builtin torch function that does this is not scriptable. - """ - _is_current_stream_capturing = ( - torch.ops.torchmdnet_extensions.is_current_stream_capturing - ) - return _is_current_stream_capturing() +__all__ = ["get_neighbor_pairs_kernel"] def get_neighbor_pairs_kernel( @@ -40,6 +38,7 @@ def get_neighbor_pairs_kernel( max_num_pairs: int, loop: bool, include_transpose: bool, + num_cells: int, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Computes the neighbor pairs for a given set of atomic positions. The list is generated as a list of pairs (i,j) without any enforced ordering. @@ -48,7 +47,7 @@ def get_neighbor_pairs_kernel( Parameters ---------- strategy : str - Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`. + Strategy to use for computing the neighbor list. Can be one of :code:`["brute", "cell"]`. positions : Tensor A tensor with shape (N, 3) representing the atomic positions. batch : Tensor @@ -67,7 +66,8 @@ def get_neighbor_pairs_kernel( Whether to include self-interactions. include_transpose : bool Whether to include the transpose of the neighbor list (pair i,j and pair j,i). - + num_cells : int + The number of cells in the grid if using the cell strategy. Returns ------- neighbors : Tensor @@ -79,53 +79,46 @@ def get_neighbor_pairs_kernel( num_pairs : Tensor The number of pairs found. 
""" - return torch.ops.torchmdnet_extensions.get_neighbor_pairs( + if torch.jit.is_scripting() or not positions.is_cuda: + + return torch_neighbor_bruteforce( + positions, + batch=batch, + box_vectors=box_vectors, + use_periodic=use_periodic, + cutoff_lower=cutoff_lower, + cutoff_upper=cutoff_upper, + max_num_pairs=max_num_pairs, + loop=loop, + include_transpose=include_transpose, + ) + + if not HAS_TRITON: + logger.warning( + "Triton is not available, using torch version of the neighbor pairs kernel." + ) + return torch_neighbor_bruteforce( + positions, + batch=batch, + box_vectors=box_vectors, + use_periodic=use_periodic, + cutoff_lower=cutoff_lower, + cutoff_upper=cutoff_upper, + max_num_pairs=max_num_pairs, + loop=loop, + include_transpose=include_transpose, + ) + + return triton_neighbor_pairs( strategy, - positions, - batch, - box_vectors, - use_periodic, - cutoff_lower, - cutoff_upper, - max_num_pairs, - loop, - include_transpose, + positions=positions, + batch=batch, + box_vectors=box_vectors, + use_periodic=use_periodic, + cutoff_lower=cutoff_lower, + cutoff_upper=cutoff_upper, + max_num_pairs=max_num_pairs, + loop=loop, + include_transpose=include_transpose, + num_cells=num_cells, ) - - -# Registers a FakeTensor kernel (aka "meta kernel", "abstract impl") -# that describes what the properties of the output Tensor are given -# the properties of the input Tensor. The FakeTensor kernel is necessary -# for the op to work performantly with torch.compile. -@torch.library.register_fake("torchmdnet_extensions::get_neighbor_pairs_bkwd") -def _( - grad_edge_vec: Tensor, - grad_edge_weight: Tensor, - edge_index: Tensor, - edge_vec: Tensor, - edge_weight: Tensor, - num_atoms: int, -): - return torch.zeros((num_atoms, 3), dtype=edge_vec.dtype, device=edge_vec.device) - - -@torch.library.register_fake("torchmdnet_extensions::get_neighbor_pairs_fwd") -def _( - strategy: str, - positions: Tensor, - batch: Tensor, - box_vectors: Tensor, - use_periodic: bool, - cutoff_lower: float, - cutoff_upper: float, - max_num_pairs: int, - loop: bool, - include_transpose: bool, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """Returns empty vectors with the correct shape for the output of get_neighbor_pairs_kernel.""" - size = max_num_pairs - edge_index = torch.empty((2, size), dtype=torch.int32, device=positions.device) - edge_distance = torch.empty((size,), dtype=positions.dtype, device=positions.device) - edge_vec = torch.empty((size, 3), dtype=positions.dtype, device=positions.device) - num_pairs = torch.empty((1,), dtype=torch.int32, device=positions.device) - return edge_index, edge_vec, edge_distance, num_pairs diff --git a/torchmdnet/extensions/torchmdnet_extensions.cpp b/torchmdnet/extensions/torchmdnet_extensions.cpp deleted file mode 100644 index 992a3b178..000000000 --- a/torchmdnet/extensions/torchmdnet_extensions.cpp +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org - * Distributed under the MIT License. - *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - * Raul P. Pelaez 2023. Torch extensions to the torchmdnet library. - * You can expose functions to python here which will be compatible with TorchScript. - * Add your exports to the TORCH_LIBRARY macro below, see __init__.py to see how to access them from python. - * The WITH_CUDA macro will be defined when compiling with CUDA support. 
- */ - - -#include -#include -#include -#include - -#if defined(WITH_CUDA) -#include -#include -#endif - - -extern "C" { - /* Creates a dummy empty torchmdnet_extensions module that can be imported from Python. - The import from Python will load the .so consisting of this file - in this extension, so that the TORCH_LIBRARY static initializers - below are run. */ - PyObject* PyInit_torchmdnet_extensions(void) - { - static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - "torchmdnet_extensions", /* name of module */ - NULL, /* module documentation, may be NULL */ - -1, /* size of per-interpreter state of the module, - or -1 if the module keeps state in global variables. */ - NULL, /* methods */ - }; - return PyModule_Create(&module_def); - } -} - -/* @brief Returns true if the current torch CUDA stream is capturing. - * This function is required because the one available in torch is not compatible with TorchScript. - * @return True if the current torch CUDA stream is capturing. - */ -bool is_current_stream_capturing() { -#if defined(WITH_CUDA) - auto current_stream = at::cuda::getCurrentCUDAStream().stream(); - cudaStreamCaptureStatus capture_status; - cudaError_t err = cudaStreamGetCaptureInfo(current_stream, &capture_status, nullptr); - if (err != cudaSuccess) { - throw std::runtime_error(cudaGetErrorString(err)); - } - return capture_status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive; -#else - return false; -#endif -} - -TORCH_LIBRARY(torchmdnet_extensions, m) { - m.def("is_current_stream_capturing", is_current_stream_capturing); - m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, " - "bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool " - "loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor " - "distance_vecs, Tensor num_pairs)"); - //The individual fwd and bkwd functions must be exposed in order to register their meta implementations python side. - m.def("get_neighbor_pairs_fwd(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, " - "bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool " - "loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor " - "distance_vecs, Tensor num_pairs)"); - m.def("get_neighbor_pairs_bkwd(Tensor grad_edge_vec, Tensor grad_edge_weight, Tensor edge_index, " - "Tensor edge_vec, Tensor edge_weight, int num_atoms) -> Tensor"); -} diff --git a/torchmdnet/extensions/triton_brute.py b/torchmdnet/extensions/triton_brute.py new file mode 100644 index 000000000..6a6e19019 --- /dev/null +++ b/torchmdnet/extensions/triton_brute.py @@ -0,0 +1,255 @@ +# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org +# Distributed under the MIT License. 
+# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) +import triton +import triton.language as tl +from torch import Tensor +import torch +from torchmdnet.extensions.triton_neighbors import _tl_round, TritonNeighborAutograd +from torch.library import triton_op, wrap_triton +from typing import Tuple + + +@triton.jit +def _neighbor_brute_kernel( + pos_ptr, + batch_ptr, + box_ptr, + neighbors0_ptr, + neighbors1_ptr, + deltas_ptr, + distances_ptr, + counter_ptr, + box_batch_stride, + n_atoms, + num_all_pairs, + use_periodic: tl.constexpr, + include_transpose: tl.constexpr, + loop: tl.constexpr, + max_pairs, + cutoff_lower, + cutoff_upper, + BLOCK: tl.constexpr, +): + """Brute-force neighbor list kernel using triangular indexing and atomic compaction. + + Uses triangular indexing to iterate over only n*(n-1)/2 pairs (or n*(n+1)/2 with loop), + achieving 100% thread utilization while maintaining block-level atomic compaction. + """ + pid = tl.program_id(axis=0) + start = pid * BLOCK + idx = start + tl.arange(0, BLOCK) + + valid = idx < num_all_pairs + + # Convert linear index to (i, j) using triangular formula (same as CUDA get_row) + # Do integer arithmetic first, only convert to float for sqrt + if loop: + # With self-loops: j <= i, num_pairs = n*(n+1)/2 + # row = floor((-1 + sqrt(1 + 8k)) / 2) + sqrt_arg = (1 + 8 * idx).to(tl.float32) + row_f = tl.math.floor((-1.0 + tl.math.sqrt(sqrt_arg)) * 0.5) + i = row_f.to(tl.int32) + # Handle floating-point edge case: if i*(i+1)/2 > idx, decrement i + i = tl.where(i * (i + 1) > 2 * idx, i - 1, i) + # col = k - row*(row+1)/2 + j = idx - (i * (i + 1)) // 2 + else: + # Without self-loops: j < i, num_pairs = n*(n-1)/2 + # row = floor((1 + sqrt(1 + 8k)) / 2) + sqrt_arg = (1 + 8 * idx).to(tl.float32) + row_f = tl.math.floor((1.0 + tl.math.sqrt(sqrt_arg)) * 0.5) + i = row_f.to(tl.int32) + # Handle floating-point edge case (same correction as CUDA) + i = tl.where(i * (i - 1) > 2 * idx, i - 1, i) + # col = k - row*(row-1)/2 + j = idx - (i * (i - 1)) // 2 + + # Validate indices: check bounds and triangular constraint + # Due to float precision, we may get invalid (i, j) pairs + valid = valid & (i >= 0) & (i < n_atoms) & (j >= 0) & (j <= i) + if not loop: + # For non-loop case, also require j < i (no self-loops) + valid = valid & (j < i) + + batch_i = tl.load(batch_ptr + i, mask=valid, other=0) + batch_j = tl.load(batch_ptr + j, mask=valid, other=0) + valid = valid & (batch_i == batch_j) + + pos_ix = tl.load(pos_ptr + i * 3 + 0, mask=valid, other=0.0) + pos_iy = tl.load(pos_ptr + i * 3 + 1, mask=valid, other=0.0) + pos_iz = tl.load(pos_ptr + i * 3 + 2, mask=valid, other=0.0) + pos_jx = tl.load(pos_ptr + j * 3 + 0, mask=valid, other=0.0) + pos_jy = tl.load(pos_ptr + j * 3 + 1, mask=valid, other=0.0) + pos_jz = tl.load(pos_ptr + j * 3 + 2, mask=valid, other=0.0) + + dx = pos_ix - pos_jx + dy = pos_iy - pos_jy + dz = pos_iz - pos_jz + + if use_periodic: + box_base = box_ptr + batch_i * box_batch_stride + + b20 = tl.load(box_base + 2 * 3 + 0, mask=valid, other=0.0) + b21 = tl.load(box_base + 2 * 3 + 1, mask=valid, other=0.0) + b22 = tl.load(box_base + 2 * 3 + 2, mask=valid, other=1.0) + b10 = tl.load(box_base + 1 * 3 + 0, mask=valid, other=0.0) + b11 = tl.load(box_base + 1 * 3 + 1, mask=valid, other=1.0) + b00 = tl.load(box_base + 0 * 3 + 0, mask=valid, other=1.0) + + scale3 = _tl_round(dz / b22) + dx = dx - scale3 * b20 + dy = dy - scale3 * b21 + dz = dz - scale3 * b22 + scale2 = _tl_round(dy / b11) + dx = dx - scale2 * b10 + dy = 
dy - scale2 * b11 + scale1 = _tl_round(dx / b00) + dx = dx - scale1 * b00 + + dist2 = dx * dx + dy * dy + dz * dz + dist = tl.sqrt(dist2) + # Self-loops (i == j) are exempt from cutoff_lower since they have distance 0 + is_self_loop = i == j + valid = valid & (dist < cutoff_upper) & ((dist >= cutoff_lower) | is_self_loop) + + valid_int = valid.to(tl.int32) + local_idx = tl.cumsum(valid_int, axis=0) - 1 + total_pairs = tl.sum(valid_int, axis=0) + + # For transpose, don't count self-loops (i == j) + if include_transpose: + is_not_self_loop = i != j + valid_transpose = valid & is_not_self_loop + total_transpose = tl.sum(valid_transpose.to(tl.int32), axis=0) + total_out = total_pairs + total_transpose + else: + total_out = total_pairs + + has_work = total_out > 0 + start_idx = tl.atomic_add(counter_ptr, total_out, mask=has_work) + start_idx = tl.where(has_work, start_idx, 0) + start_idx = tl.broadcast_to(start_idx, local_idx.shape) + + write_idx = start_idx + local_idx + mask_store = valid & (write_idx < max_pairs) + tl.store(neighbors0_ptr + write_idx, i, mask=mask_store) + tl.store(neighbors1_ptr + write_idx, j, mask=mask_store) + tl.store(deltas_ptr + write_idx * 3 + 0, dx, mask=mask_store) + tl.store(deltas_ptr + write_idx * 3 + 1, dy, mask=mask_store) + tl.store(deltas_ptr + write_idx * 3 + 2, dz, mask=mask_store) + tl.store(distances_ptr + write_idx, dist, mask=mask_store) + + if include_transpose: + # Don't add transpose for self-loops (i == j) + is_not_self_loop = i != j + valid_transpose = valid & is_not_self_loop + valid_t_int = valid_transpose.to(tl.int32) + local_idx_t = tl.cumsum(valid_t_int, axis=0) - 1 + total_pairs_t = tl.sum(valid_t_int, axis=0) + + write_idx_t = start_idx + total_pairs + local_idx_t + mask_store_t = valid_transpose & (write_idx_t < max_pairs) + tl.store(neighbors0_ptr + write_idx_t, j, mask=mask_store_t) + tl.store(neighbors1_ptr + write_idx_t, i, mask=mask_store_t) + tl.store(deltas_ptr + write_idx_t * 3 + 0, -dx, mask=mask_store_t) + tl.store(deltas_ptr + write_idx_t * 3 + 1, -dy, mask=mask_store_t) + tl.store(deltas_ptr + write_idx_t * 3 + 2, -dz, mask=mask_store_t) + tl.store(distances_ptr + write_idx_t, dist, mask=mask_store_t) + + +@triton_op("torchmdnet::triton_neighbor_bruteforce", mutates_args={}) +def triton_neighbor_bruteforce( + positions: Tensor, + batch: Tensor, + box_vectors: Tensor, + use_periodic: bool, + cutoff_lower: float, + cutoff_upper: float, + max_num_pairs: int, + loop: bool, + include_transpose: bool, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + device = positions.device + dtype = positions.dtype + n_atoms = positions.size(0) + + batch = batch.contiguous() + positions = positions.contiguous() + if use_periodic: + if box_vectors.dim() == 2: + box_vectors = box_vectors.unsqueeze(0) + elif box_vectors.dim() != 3: + raise ValueError('Expected "box_vectors" to have shape (n_batch, 3, 3)') + box_vectors = box_vectors.to(device=device, dtype=dtype) + box_vectors = box_vectors.contiguous() + # Use stride 0 to broadcast single box to all batches (avoids CPU sync) + box_batch_stride = 0 if box_vectors.size(0) == 1 else 9 + + neighbors = torch.full((2, max_num_pairs), -1, device=device, dtype=torch.long) + deltas = torch.zeros((max_num_pairs, 3), device=device, dtype=dtype) + distances = torch.zeros((max_num_pairs,), device=device, dtype=dtype) + num_pairs = torch.zeros((1,), device=device, dtype=torch.int32) + + # Compute triangular pair count: n*(n-1)/2 without self-loops, n*(n+1)/2 with + if loop: + num_all_pairs = n_atoms * (n_atoms + 1) 
// 2
+    else:
+        num_all_pairs = n_atoms * (n_atoms - 1) // 2
+
+    # Grid covers only triangular pairs (not n*n)
+    grid = lambda meta: (triton.cdiv(num_all_pairs, meta["BLOCK"]),)
+
+    wrap_triton(_neighbor_brute_kernel)[grid](
+        positions,
+        batch,
+        box_vectors if use_periodic else positions,  # dummy pointer if not periodic
+        neighbors[0],
+        neighbors[1],
+        deltas,
+        distances,
+        num_pairs,
+        box_batch_stride if use_periodic else 0,
+        n_atoms,
+        num_all_pairs,
+        use_periodic,
+        include_transpose,
+        loop,
+        max_num_pairs,
+        cutoff_lower,
+        cutoff_upper,
+        BLOCK=256,
+    )
+
+    return neighbors, deltas, distances, num_pairs
+
+
+class TritonBruteNeighborAutograd(TritonNeighborAutograd):
+    @staticmethod
+    def forward(  # type: ignore[override]
+        ctx,
+        positions: Tensor,
+        batch: Tensor,
+        box_vectors: Tensor,
+        use_periodic: bool,
+        cutoff_lower: float,
+        cutoff_upper: float,
+        max_num_pairs: int,
+        loop: bool,
+        include_transpose: bool,
+    ):
+        neighbors, deltas, distances, num_pairs = triton_neighbor_bruteforce(
+            positions,
+            batch,
+            box_vectors,
+            use_periodic,
+            cutoff_lower,
+            cutoff_upper,
+            max_num_pairs,
+            loop,
+            include_transpose,
+        )
+
+        ctx.save_for_backward(neighbors, deltas, distances)
+        ctx.num_atoms = positions.size(0)
+        return neighbors, deltas, distances, num_pairs
diff --git a/torchmdnet/extensions/triton_cell.py b/torchmdnet/extensions/triton_cell.py
new file mode 100644
index 000000000..2586fdfcb
--- /dev/null
+++ b/torchmdnet/extensions/triton_cell.py
@@ -0,0 +1,474 @@
+# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
+# Distributed under the MIT License.
+# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
+import triton
+import triton.language as tl
+import torch
+from torch import Tensor
+from typing import Tuple
+from torchmdnet.extensions.triton_neighbors import _tl_round, TritonNeighborAutograd
+from torch.library import triton_op, wrap_triton
+
+
+def _get_cell_dimensions(
+    box_x: Tensor,
+    box_y: Tensor,
+    box_z: Tensor,
+    cutoff_upper: float,
+) -> Tensor:
+    """Number of cells along each box axis (at least 3), as a length-3 long tensor."""
+    nx = torch.floor(box_x / cutoff_upper).clamp(min=3).long()
+    ny = torch.floor(box_y / cutoff_upper).clamp(min=3).long()
+    nz = torch.floor(box_z / cutoff_upper).clamp(min=3).long()
+    return torch.stack([nx, ny, nz])
+
+
+@triton.jit
+def cell_neighbor_kernel(
+    # Cell data structure (1D sorted approach)
+    SortedIndices,  # [n_atoms] - original atom indices, sorted by cell
+    SortedPositions,  # [n_atoms, 3] - positions sorted by cell (for coalesced access)
+    SortedBatch,  # [n_atoms] - batch indices sorted by cell
+    CellStart,  # [num_cells] - start index in sorted arrays for each cell
+    CellEnd,  # [num_cells] - end index (exclusive) in sorted arrays for each cell
+    # Box parameters
+    BoxSizes,  # [3] - box dimensions
+    CellDims,  # [3] - number of cells in each dimension
+    # Output
+    OutPairs,  # [2, max_pairs]
+    OutDeltas,  # [max_pairs, 3]
+    OutDists,  # [max_pairs]
+    GlobalCounter,  # [1]
+    # Scalar parameters
+    max_pairs,
+    cutoff_lower_sq,
+    cutoff_upper_sq,
+    # Flags
+    use_periodic: tl.constexpr,
+    loop: tl.constexpr,
+    include_transpose: tl.constexpr,
+    # Batch size for vectorized processing
+    BATCH_SIZE: tl.constexpr,  # e.g., 32 -> processes 32×32=1024 pairs per iteration
+):
+    """
+    Each program processes one cell (the "home" cell).
+    Uses 1D sorted array with cell_start/cell_end pointers.
+ + Vectorized batched processing: + - Loads BATCH_SIZE atoms at a time for both home and neighbor + - Computes BATCH_SIZE × BATCH_SIZE distance matrix per iteration + - Uses while loops to iterate only over actual atoms + - Minimal waste: only last partial batch may have masked elements + + To avoid double-counting: + - For half list (include_transpose=False): only emit pairs where home_atom > neighbor_atom + - For full list (include_transpose=True): emit both directions + """ + home_cell_id = tl.program_id(0) + + # Load box and cell dimensions + box_x = tl.load(BoxSizes + 0) + box_y = tl.load(BoxSizes + 1) + box_z = tl.load(BoxSizes + 2) + + num_cells_x = tl.load(CellDims + 0) + num_cells_y = tl.load(CellDims + 1) + num_cells_z = tl.load(CellDims + 2) + + # Decompose home cell ID into 3D coordinates + cells_yz = num_cells_y * num_cells_z + home_cx = home_cell_id // cells_yz + home_cy = (home_cell_id % cells_yz) // num_cells_z + home_cz = home_cell_id % num_cells_z + + # Load home cell boundaries + home_start = tl.load(CellStart + home_cell_id) + home_end = tl.load(CellEnd + home_cell_id) + + # Loop over 27 neighbor cells + for neighbor_offset in tl.range(0, 27): + # Decompose neighbor_offset into di, dj, dk (each in {-1, 0, 1}) + di = (neighbor_offset % 3) - 1 + dj = ((neighbor_offset // 3) % 3) - 1 + dk = (neighbor_offset // 9) - 1 + + # Compute neighbor cell coordinates + ni = home_cx + di + nj = home_cy + dj + nk = home_cz + dk + + # Handle boundary conditions + if use_periodic: + ni = (ni + num_cells_x) % num_cells_x + nj = (nj + num_cells_y) % num_cells_y + nk = (nk + num_cells_z) % num_cells_z + cell_valid = True + else: + cell_valid = ( + (ni >= 0) + & (ni < num_cells_x) + & (nj >= 0) + & (nj < num_cells_y) + & (nk >= 0) + & (nk < num_cells_z) + ) + + neighbor_cell_id = ni * cells_yz + nj * num_cells_z + nk + + # Load neighbor cell boundaries + neighbor_start = tl.load(CellStart + neighbor_cell_id) + neighbor_end = tl.load(CellEnd + neighbor_cell_id) + + # If cell is invalid (non-periodic boundary), make it empty + neighbor_start = tl.where(cell_valid, neighbor_start, 0) + neighbor_end = tl.where(cell_valid, neighbor_end, 0) + + # Batched iteration over home atoms + home_batch_start = home_start + while home_batch_start < home_end: + # Load BATCH_SIZE home atoms + home_offsets = tl.arange(0, BATCH_SIZE) + home_global_idx = home_batch_start + home_offsets + home_mask = home_global_idx < home_end + + # Load home atom original indices (for output pair indices) + home_atoms = tl.load( + SortedIndices + home_global_idx, mask=home_mask, other=0 + ) + + # Load home atom positions (sequential access - coalesced!) 
+ home_x = tl.load( + SortedPositions + home_global_idx * 3 + 0, mask=home_mask, other=0.0 + ) + home_y = tl.load( + SortedPositions + home_global_idx * 3 + 1, mask=home_mask, other=0.0 + ) + home_z = tl.load( + SortedPositions + home_global_idx * 3 + 2, mask=home_mask, other=0.0 + ) + home_batch = tl.load( + SortedBatch + home_global_idx, mask=home_mask, other=-1 + ) + + # Batched iteration over neighbor atoms + neighbor_batch_start = neighbor_start + while neighbor_batch_start < neighbor_end: + # Load BATCH_SIZE neighbor atoms + neighbor_offsets = tl.arange(0, BATCH_SIZE) + neighbor_global_idx = neighbor_batch_start + neighbor_offsets + neighbor_mask = neighbor_global_idx < neighbor_end + + # Load neighbor atom original indices (for output pair indices) + neighbor_atoms = tl.load( + SortedIndices + neighbor_global_idx, mask=neighbor_mask, other=0 + ) + + # Load neighbor atom positions (sequential access - coalesced!) + neighbor_x = tl.load( + SortedPositions + neighbor_global_idx * 3 + 0, + mask=neighbor_mask, + other=0.0, + ) + neighbor_y = tl.load( + SortedPositions + neighbor_global_idx * 3 + 1, + mask=neighbor_mask, + other=0.0, + ) + neighbor_z = tl.load( + SortedPositions + neighbor_global_idx * 3 + 2, + mask=neighbor_mask, + other=0.0, + ) + neighbor_batch_vals = tl.load( + SortedBatch + neighbor_global_idx, mask=neighbor_mask, other=-2 + ) + + # Compute pairwise distances: [BATCH_SIZE, BATCH_SIZE] + dx = home_x[:, None] - neighbor_x[None, :] + dy = home_y[:, None] - neighbor_y[None, :] + dz = home_z[:, None] - neighbor_z[None, :] + + # Apply PBC + if use_periodic: + dx = dx - box_x * _tl_round(dx / box_x) + dy = dy - box_y * _tl_round(dy / box_y) + dz = dz - box_z * _tl_round(dz / box_z) + + dist_sq = dx * dx + dy * dy + dz * dz + + # Build validity mask + # 1. Distance within cutoff + cond_dist = (dist_sq < cutoff_upper_sq) & (dist_sq >= cutoff_lower_sq) + + # 2. Same batch + cond_batch = home_batch[:, None] == neighbor_batch_vals[None, :] + + # 3. Index ordering to avoid double-counting + home_atoms_bc = home_atoms[:, None] + neighbor_atoms_bc = neighbor_atoms[None, :] + + if include_transpose: + if loop: + cond_idx = True + else: + cond_idx = home_atoms_bc != neighbor_atoms_bc + else: + if loop: + cond_idx = home_atoms_bc >= neighbor_atoms_bc + else: + cond_idx = home_atoms_bc > neighbor_atoms_bc + + # 4. 
Both atoms must be valid (within actual cell bounds) + cond_valid = home_mask[:, None] & neighbor_mask[None, :] + + # Combined validity + valid_mask = cond_dist & cond_batch & cond_idx & cond_valid + + # Count and store valid pairs + num_found = tl.sum(valid_mask.to(tl.int32)) + + if num_found > 0: + # Atomically reserve space in output + current_offset = tl.atomic_add(GlobalCounter, num_found) + + if current_offset + num_found <= max_pairs: + # Compute storage indices using cumsum + flat_mask = tl.ravel(valid_mask) + csum = tl.cumsum(flat_mask.to(tl.int32), axis=0) + store_idx = current_offset + csum - 1 + + # Prepare flattened data + flat_home = tl.ravel( + tl.broadcast_to( + home_atoms[:, None], (BATCH_SIZE, BATCH_SIZE) + ) + ) + flat_neighbor = tl.ravel( + tl.broadcast_to( + neighbor_atoms[None, :], (BATCH_SIZE, BATCH_SIZE) + ) + ) + flat_dx = tl.ravel(dx) + flat_dy = tl.ravel(dy) + flat_dz = tl.ravel(dz) + flat_dist = tl.sqrt(tl.ravel(dist_sq)) + + # Store pairs + tl.store( + OutPairs + 0 * max_pairs + store_idx, + flat_home, + mask=flat_mask, + ) + tl.store( + OutPairs + 1 * max_pairs + store_idx, + flat_neighbor, + mask=flat_mask, + ) + + # Store deltas + tl.store(OutDeltas + store_idx * 3 + 0, flat_dx, mask=flat_mask) + tl.store(OutDeltas + store_idx * 3 + 1, flat_dy, mask=flat_mask) + tl.store(OutDeltas + store_idx * 3 + 2, flat_dz, mask=flat_mask) + + # Store distances + tl.store(OutDists + store_idx, flat_dist, mask=flat_mask) + + neighbor_batch_start += BATCH_SIZE + home_batch_start += BATCH_SIZE + + +def build_cell_list( + positions: Tensor, + batch: Tensor, + box_sizes: Tensor, # [3] diagonal elements + use_periodic: bool, + cell_dims: Tensor, # [3] number of cells in each dimension + num_cells: int, # total number of cells (fixed for CUDA graphs) +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Build the cell list data structure using 1D sorted arrays. 
+ + Args: + positions: [N, 3] atom positions + batch: [N] batch indices + box_sizes: [3] box diagonal elements + use_periodic: whether to use periodic boundary conditions + cell_dims: [3] number of cells in each dimension (pre-computed) + num_cells: total number of cells (pre-computed, fixed for CUDA graphs) + + Returns: + sorted_indices: [n_atoms] - original atom indices, sorted by cell + sorted_positions: [n_atoms, 3] - positions sorted by cell (for coalesced access) + sorted_batch: [n_atoms] - batch indices sorted by cell + cell_start: [num_cells] - start index in sorted arrays for each cell + cell_end: [num_cells] - end index (exclusive) in sorted arrays for each cell + """ + device = positions.device + n_atoms = positions.size(0) + + # Compute cell index for each atom + if use_periodic: + # Wrap to [0, box) + inv_box = 1.0 / box_sizes + wrapped = positions - torch.floor(positions * inv_box) * box_sizes + else: + # Shift by half box (like CUDA implementation) + wrapped = positions + 0.5 * box_sizes + + # Cell coordinates + cell_size = box_sizes / cell_dims.float() + cell_coords = (wrapped / cell_size).long() + cell_coords = torch.clamp( + cell_coords, min=torch.zeros(3, device=device), max=cell_dims - 1 + ) + + # Flat cell index + cell_idx = ( + cell_coords[:, 0] * (cell_dims[1] * cell_dims[2]) + + cell_coords[:, 1] * cell_dims[2] + + cell_coords[:, 2] + ).long() + + # Sort atoms by cell index + sorted_cell_idx, sort_order = torch.sort(cell_idx) + sorted_indices = sort_order.int() # Original atom indices, now sorted by cell + + # Create sorted positions and batch for coalesced memory access + sorted_positions = positions.index_select(0, sort_order).contiguous() + sorted_batch = batch.index_select(0, sort_order).contiguous() + + # Count atoms per cell + cell_counts = torch.zeros(num_cells, dtype=torch.int32, device=device) + cell_counts.scatter_add_( + 0, cell_idx, torch.ones(n_atoms, dtype=torch.int32, device=device) + ) + + # Compute cell_start and cell_end using cumsum + cell_end = torch.cumsum(cell_counts, dim=0).int() + cell_start = torch.zeros(num_cells, dtype=torch.int32, device=device) + cell_start[1:] = cell_end[:-1] + + return sorted_indices, sorted_positions, sorted_batch, cell_start, cell_end + + +@triton_op("torchmdnet::triton_neighbor_cell", mutates_args={}) +def triton_neighbor_cell( + positions: Tensor, + batch: Tensor, + box_vectors: Tensor, + use_periodic: bool, + cutoff_lower: float, + cutoff_upper: float, + max_num_pairs: int, + loop: bool, + include_transpose: bool, + num_cells: int, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + device = positions.device + dtype = positions.dtype + n_atoms = positions.size(0) + + # Validate inputs + if positions.dim() != 2 or positions.size(1) != 3: + raise ValueError('Expected "positions" to have shape (N, 3)') + if batch.dim() != 1 or batch.size(0) != n_atoms: + raise ValueError('Expected "batch" to have shape (N,)') + + # Extract box diagonal + box_vectors = box_vectors.contiguous() + if box_vectors.dim() == 3: + box_diag = box_vectors[0] + else: + box_diag = box_vectors + box_sizes = torch.stack( + [box_diag[0, 0], box_diag[1, 1], box_diag[2, 2]] + ).contiguous() + + # Compute cell dimensions using shared utility (stays on GPU) + cell_dims = _get_cell_dimensions( + box_sizes[0], box_sizes[1], box_sizes[2], cutoff_upper + ) + + # Build cell list (1D sorted approach with sorted positions for coalesced access) + sorted_indices, sorted_positions, sorted_batch, cell_start, cell_end = ( + build_cell_list(positions, batch, box_sizes, 
use_periodic, cell_dims, num_cells) + ) + + # Allocate outputs + neighbors = torch.full((2, max_num_pairs), -1, device=device, dtype=torch.long) + deltas = torch.zeros((max_num_pairs, 3), device=device, dtype=dtype) + distances = torch.zeros((max_num_pairs,), device=device, dtype=dtype) + num_pairs = torch.zeros((1,), device=device, dtype=torch.int32) + + # Launch kernel: one program per cell + # BATCH_SIZE: process atoms in batches for vectorized compute + # 32 is a good balance: 32×32=1024 elements fits in registers, minimal waste on partial batches + BATCH_SIZE = 32 + + grid = (num_cells,) + wrap_triton(cell_neighbor_kernel)[grid]( + sorted_indices, + sorted_positions, + sorted_batch, + cell_start, + cell_end, + box_sizes, + cell_dims, + neighbors, + deltas, + distances, + num_pairs, + max_num_pairs, + cutoff_lower**2, + cutoff_upper**2, + use_periodic=use_periodic, + loop=loop, + include_transpose=include_transpose, + BATCH_SIZE=BATCH_SIZE, + ) + return neighbors, deltas, distances, num_pairs + + +class TritonCellNeighborAutograd(TritonNeighborAutograd): + @staticmethod + def forward( + ctx, + positions: Tensor, + batch: Tensor, + box_vectors: Tensor, + use_periodic: bool, + cutoff_lower: float, + cutoff_upper: float, + max_num_pairs: int, + loop: bool, + include_transpose: bool, + num_cells: int, + ): + neighbors, deltas, distances, num_pairs = triton_neighbor_cell( + positions, + batch, + box_vectors, + use_periodic, + cutoff_lower, + cutoff_upper, + max_num_pairs, + loop, + include_transpose, + num_cells, + ) + + ctx.save_for_backward(neighbors, deltas, distances) + ctx.num_atoms = positions.size(0) + return neighbors, deltas, distances, num_pairs + + @staticmethod + def backward(ctx, grad_neighbors, grad_deltas, grad_distances, grad_num_pairs): # type: ignore[override] + # Call parent backward (returns 9 values) and add None for num_cells + parent_grads = TritonNeighborAutograd.backward( + ctx, grad_neighbors, grad_deltas, grad_distances, grad_num_pairs + ) + return (*parent_grads, None) diff --git a/torchmdnet/extensions/triton_neighbors.py b/torchmdnet/extensions/triton_neighbors.py new file mode 100644 index 000000000..5d2df81af --- /dev/null +++ b/torchmdnet/extensions/triton_neighbors.py @@ -0,0 +1,108 @@ +# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org +# Distributed under the MIT License. 
+# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) + +from typing import Tuple +import torch +from torch import Tensor +import triton +import triton.language as tl + + +@triton.jit +def _tl_round(x): + return tl.where(x >= 0, tl.math.floor(x + 0.5), tl.math.ceil(x - 0.5)) + + +class TritonNeighborAutograd(torch.autograd.Function): + @staticmethod + def backward(ctx, grad_neighbors, grad_deltas, grad_distances, grad_num_pairs): # type: ignore[override] + neighbors, edge_vec, edge_weight = ctx.saved_tensors + num_atoms = ctx.num_atoms + + if grad_deltas is None: + grad_deltas = torch.zeros_like(edge_vec) + if grad_distances is None: + grad_distances = torch.zeros_like(edge_weight) + + zero_mask = edge_weight.eq(0) + zero_mask3 = zero_mask.unsqueeze(-1).expand_as(grad_deltas) + + grad_distances_term = edge_vec / edge_weight.masked_fill( + zero_mask, 1 + ).unsqueeze(-1) + grad_distances_term = grad_distances_term * grad_distances.masked_fill( + zero_mask, 0 + ).unsqueeze(-1) + + grad_positions = torch.zeros( + (num_atoms, 3), device=edge_vec.device, dtype=edge_vec.dtype + ) + edge_index_safe = neighbors.masked_fill( + zero_mask.unsqueeze(0).expand_as(neighbors), 0 + ) + grad_vec = grad_deltas.masked_fill(zero_mask3, 0) + grad_distances_term + grad_positions.index_add_(0, edge_index_safe[0], grad_vec) + grad_positions.index_add_(0, edge_index_safe[1], -grad_vec) + + return ( + grad_positions, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def triton_neighbor_pairs( + strategy: str, + positions: Tensor, + batch: Tensor, + box_vectors: Tensor, + use_periodic: bool, + cutoff_lower: float, + cutoff_upper: float, + max_num_pairs: int, + loop: bool, + include_transpose: bool, + num_cells: int, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + from torchmdnet.extensions.triton_cell import TritonCellNeighborAutograd + from torchmdnet.extensions.triton_brute import TritonBruteNeighborAutograd + + if positions.device.type != "cuda": + raise RuntimeError("Triton neighbor list requires CUDA tensors") + if positions.dtype not in (torch.float32, torch.float64): + raise RuntimeError("Unsupported dtype for Triton neighbor list") + + if strategy == "brute": + return TritonBruteNeighborAutograd.apply( + positions, + batch, + box_vectors, + use_periodic, + float(cutoff_lower), + float(cutoff_upper), + int(max_num_pairs), + bool(loop), + bool(include_transpose), + ) + elif strategy == "cell": + return TritonCellNeighborAutograd.apply( + positions, + batch, + box_vectors, + use_periodic, + cutoff_lower, + cutoff_upper, + int(max_num_pairs), + bool(loop), + bool(include_transpose), + num_cells, + ) + else: + raise ValueError(f"Unsupported strategy {strategy}") diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index b7c8398cc..942f7932b 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -127,6 +127,7 @@ def create_model(args, prior_model=None, mean=None, std=None): activation=args["activation"], reduce_op=args["reduce_op"], dtype=dtype, + static_shapes=args.get("static_shapes", False), num_hidden_layers=args.get("output_mlp_num_layers", 0), ) @@ -251,6 +252,13 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): for p in patterns: state_dict = {re.sub(p[0], p[1], k): v for k, v in state_dict.items()} + # Backward compatibility: box was changed from class attribute to registered buffer + # Old checkpoints don't have the box buffer, so add it with default value + if 
"representation_model.distance.box" not in state_dict: + state_dict["representation_model.distance.box"] = torch.zeros( + (3, 3), device="cpu" + ) + model.load_state_dict(state_dict) return model.to(device) diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py index 64f9cebfc..c9f933b9b 100644 --- a/torchmdnet/models/output_modules.py +++ b/torchmdnet/models/output_modules.py @@ -6,14 +6,8 @@ from typing import Optional import torch from torch import nn -from torchmdnet.models.utils import ( - act_class_mapping, - GatedEquivariantBlock, - scatter, - MLP, -) +from torchmdnet.models.utils import GatedEquivariantBlock, scatter, MLP from torchmdnet.utils import atomic_masses -from torchmdnet.extensions.ops import is_current_stream_capturing from warnings import warn __all__ = ["Scalar", "DipoleMoment", "ElectronicSpatialExtent"] @@ -26,27 +20,34 @@ class OutputModel(nn.Module, metaclass=ABCMeta): As an example, have a look at the :py:mod:`torchmdnet.output_modules.Scalar` output model. """ - def __init__(self, allow_prior_model, reduce_op): + def __init__(self, allow_prior_model, reduce_op, static_shapes=False): super(OutputModel, self).__init__() self.allow_prior_model = allow_prior_model self.reduce_op = reduce_op + self.static_shapes = static_shapes self.dim_size = 0 - self.setup_for_compile = False def reset_parameters(self): pass - def setup_for_compile_cudagraphs(self, batch): - self.dim_size = int(batch.max().item() + 1) - self.setup_for_compile = True - @abstractmethod def pre_reduce(self, x, v, z, pos, batch): return def reduce(self, x, batch): - if not self.setup_for_compile: - is_capturing = x.is_cuda and is_current_stream_capturing() + # torch.compile and torch.export don't support .item() calls during tracing + # The model should be warmed up before compilation to set the correct dim_size + if torch.compiler.is_compiling(): + pass + elif torch.jit.is_scripting(): + # TorchScript doesn't support torch.cuda.is_current_stream_capturing() + # For CPU, always update dim_size (no CUDA graphs on CPU) + # For CUDA with static_shapes, only update once (first call sets dim_size for CUDA graph capture) + # For CUDA without static_shapes, always update (dynamic batch sizes) + if not x.is_cuda or not self.static_shapes or self.dim_size == 0: + self.dim_size = int(batch.max().item() + 1) + else: + is_capturing = x.is_cuda and torch.cuda.is_current_stream_capturing() if not x.is_cuda or not is_capturing: self.dim_size = int(batch.max().item() + 1) if is_capturing: @@ -72,10 +73,13 @@ def __init__( allow_prior_model=True, reduce_op="sum", dtype=torch.float, + static_shapes=False, **kwargs, ): super(Scalar, self).__init__( - allow_prior_model=allow_prior_model, reduce_op=reduce_op + allow_prior_model=allow_prior_model, + reduce_op=reduce_op, + static_shapes=static_shapes, ) self.output_network = MLP( in_channels=hidden_channels, @@ -102,10 +106,13 @@ def __init__( allow_prior_model=True, reduce_op="sum", dtype=torch.float, + static_shapes=False, **kwargs, ): super(EquivariantScalar, self).__init__( - allow_prior_model=allow_prior_model, reduce_op=reduce_op + allow_prior_model=allow_prior_model, + reduce_op=reduce_op, + static_shapes=static_shapes, ) if kwargs.get("num_layers", 0) > 0: warn("num_layers is not used in EquivariantScalar") @@ -144,6 +151,7 @@ def __init__( activation="silu", reduce_op="sum", dtype=torch.float, + static_shapes=False, **kwargs, ): super(DipoleMoment, self).__init__( @@ -152,6 +160,7 @@ def __init__( allow_prior_model=False, 
reduce_op=reduce_op,
             dtype=dtype,
+            static_shapes=static_shapes,
             **kwargs,
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
@@ -177,6 +186,7 @@ def __init__(
         activation="silu",
         reduce_op="sum",
         dtype=torch.float,
+        static_shapes=False,
         **kwargs,
     ):
         super(EquivariantDipoleMoment, self).__init__(
@@ -185,6 +195,7 @@ def __init__(
             allow_prior_model=False,
             reduce_op=reduce_op,
             dtype=dtype,
+            static_shapes=static_shapes,
             **kwargs,
         )
         atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
@@ -211,10 +222,11 @@ def __init__(
         activation="silu",
         reduce_op="sum",
         dtype=torch.float,
+        static_shapes=False,
         **kwargs,
     ):
         super(ElectronicSpatialExtent, self).__init__(
-            allow_prior_model=False, reduce_op=reduce_op
+            allow_prior_model=False, reduce_op=reduce_op, static_shapes=static_shapes
        )
         self.output_network = MLP(
             in_channels=hidden_channels,
@@ -254,6 +266,7 @@ def __init__(
         activation="silu",
         reduce_op="sum",
         dtype=torch.float,
+        static_shapes=False,
         **kwargs,
     ):
         super(EquivariantVectorOutput, self).__init__(
@@ -262,6 +275,7 @@
             allow_prior_model=False,
             reduce_op="sum",
             dtype=dtype,
+            static_shapes=static_shapes,
             **kwargs,
         )
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 23a3c2683..9aec148ca 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -218,9 +218,6 @@ def reset_parameters(self):
         self.linear.reset_parameters()
         self.out_norm.reset_parameters()
 
-    def setup_for_compile_cudagraphs(self, batch):
-        self.distance.setup_for_compile_cudagraphs()
-
     def forward(
         self,
         z: Tensor,
diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py
index 55cd89982..031da9a5c 100644
--- a/torchmdnet/models/utils.py
+++ b/torchmdnet/models/utils.py
@@ -152,7 +152,7 @@ class OptimizedDistance(torch.nn.Module):
         If the number of pairs found is larger than this, the pairs are randomly sampled. When check_errors is True, an exception is raised in this case. If negative, it is interpreted as (minus) the maximum number of neighbors per atom.
     strategy : str
-        Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`.
-        1. *Shared*: An O(N^2) algorithm that leverages CUDA shared memory, best for large number of particles.
-        2. *Brute*: A brute force O(N^2) algorithm, best for small number of particles.
+        Strategy to use for computing the neighbor list. Can be one of :code:`["brute", "cell"]`.
+        1. *Brute*: A brute force O(N^2) algorithm, best for small number of particles.
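The static_shapes flag threaded through the output heads above changes how reduce() obtains dim_size: on CUDA with static shapes it is read from batch once and then frozen, so compiled or graph-captured calls never hit the .item() host sync. A minimal sketch of the intended call order (the helper name and the model(z, pos, batch=batch) call shape are illustrative assumptions, not part of this patch):

import torch

def warmup_then_compile(model, z, pos, batch):
    # Eager warm-up: reduce() evaluates int(batch.max().item() + 1) and caches
    # it as dim_size while nothing is being traced or captured.
    model(z, pos, batch=batch)
    # During tracing, torch.compiler.is_compiling() makes reduce() skip the
    # .item() call and reuse the cached dim_size, keeping shapes static.
    compiled = torch.compile(model, mode="reduce-overhead")
    return compiled(z, pos, batch=batch)

Every compiled call must then use inputs with the same batch layout, which is the usual constraint for CUDA-graph replay.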
@@ -201,34 +201,46 @@ def __init__( self.cutoff_lower = cutoff_lower self.max_num_pairs = max_num_pairs self.strategy = strategy - self.box: Optional[Tensor] = box self.loop = loop self.return_vecs = return_vecs self.include_transpose = include_transpose self.resize_to_fit = resize_to_fit self.use_periodic = True - if self.box is None: + self.num_cells = 0 + + # Use register_buffer for box to make it export-compatible and handle device movement + if box is None: self.use_periodic = False - self.box = torch.empty((0, 0)) - if self.strategy == "cell": + if strategy == "cell": # Default the box to 3 times the cutoff, really inefficient for the cell list lbox = cutoff_upper * 3.0 - self.box = torch.tensor( + box = torch.tensor( [[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]], device="cpu" ) + else: + # Use a placeholder box instead of empty (0,0) to avoid shape issues in torch.export + # This won't be used when use_periodic=False, but prevents export tracing issues + box = torch.zeros((3, 3), device="cpu") + + # Register box as a buffer so it moves with the module and is export-compatible + self.register_buffer("box", box, persistent=True) + if self.strategy == "cell": - self.box = self.box.cpu() + from torchmdnet.extensions.triton_cell import _get_cell_dimensions + + cell_dims = _get_cell_dimensions( + self.box[0, 0], self.box[1, 1], self.box[2, 2], cutoff_upper + ) + self.num_cells = int(cell_dims.prod()) + if self.num_cells > 1024**3: + raise RuntimeError( + f"Too many cells: {self.num_cells}. Maximum is 1024^3. " + f"Reduce box size or increase cutoff." + ) + self.check_errors = check_errors self.long_edge_index = long_edge_index - def setup_for_compile_cudagraphs(self): - # box needs to be a buffer to it moves to correct device, otherwise we can end - # up with torch.compile failing to use cuda graphs with "skipping cudagraphs due to skipping cudagraphs due to cpu device (primals_3)." - # we cant just make it a buffer in the constructor above because then old state dicts will fail to load - _box = self.box - del self.box - self.register_buffer('box', _box) - def forward( self, pos: Tensor, batch: Optional[Tensor] = None, box: Optional[Tensor] = None ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: @@ -263,9 +275,18 @@ def forward( use_periodic = self.use_periodic if not use_periodic: use_periodic = box is not None - box = self.box if box is None else box + + # Use the registered buffer box if no box is provided + # The buffer is automatically moved to the correct device with the module + if box is None: + box = self.box + assert box is not None, "Box must be provided" - box = box.to(pos.dtype) + + # Ensure box has correct dtype (device is already correct if using self.box) + if box.dtype != pos.dtype: + box = box.to(dtype=pos.dtype) + max_pairs: int = self.max_num_pairs if self.max_num_pairs < 0: max_pairs = -self.max_num_pairs * pos.shape[0] @@ -282,6 +303,7 @@ def forward( include_transpose=self.include_transpose, box_vectors=box, use_periodic=use_periodic, + num_cells=self.num_cells, ) if self.check_errors: assert (
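Taken together, the OptimizedDistance changes above mean the box is now a persistent buffer that travels with the module across devices and checkpoints, and for the cell strategy num_cells is fixed at construction so the Triton kernel grid stays CUDA-graph friendly. A usage sketch under those assumptions (keyword names follow the constructor shown above; the numeric values are arbitrary):

import torch
from torchmdnet.models.utils import OptimizedDistance

box = 20.0 * torch.eye(3)  # cubic periodic box, registered as a buffer internally
dist = OptimizedDistance(
    cutoff_lower=0.0,
    cutoff_upper=5.0,
    max_num_pairs=-64,   # negative: (minus) the max number of neighbors per atom
    strategy="cell",     # dispatches to the Triton cell-list kernel on CUDA
    box=box,
    return_vecs=True,
).cuda()                 # the box buffer moves to the GPU with the module

pos = 20.0 * torch.rand(1000, 3, device="cuda")
batch = torch.zeros(1000, dtype=torch.long, device="cuda")
edge_index, edge_weight, edge_vec = dist(pos, batch)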