diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 4768e3dcd..3369c761d 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -6,77 +6,28 @@ on:
jobs:
build:
- name: Build wheels on ${{ matrix.os }}-${{ matrix.accelerator }}
- runs-on: ${{ matrix.os }}
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, ubuntu-24.04-arm, windows-2022, macos-latest]
- accelerator: [cpu, cu118, cu126] #, cu128]
- exclude:
- - os: ubuntu-24.04-arm
- accelerator: cu118
- - os: ubuntu-24.04-arm
- accelerator: cu126
- # - os: ubuntu-24.04-arm
- # accelerator: cu128
- - os: macos-latest
- accelerator: cu118
- - os: macos-latest
- accelerator: cu126
- # - os: macos-latest
- # accelerator: cu128
+ name: Create source distribution
+ runs-on: ubuntu-latest
steps:
- - name: Free space of Github Runner (otherwise it will fail by running out of disk)
- if: matrix.os == 'ubuntu-latest'
- run: |
- sudo rm -rf /usr/share/dotnet
- sudo rm -rf /opt/ghc
- sudo rm -rf "/usr/local/share/boost"
- sudo rm -rf "/usr/local/.ghcup"
- sudo rm -rf "/usr/local/julia1.9.2"
- sudo rm -rf "/usr/local/lib/android"
- sudo rm -rf "$AGENT_TOOLSDIRECTORY"
-
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v5
- - uses: actions/setup-python@v5
+ - uses: actions/setup-python@v6
with:
python-version: "3.13"
-      - name: Install cibuildwheel
+      - name: Install build
- run: python -m pip install cibuildwheel==3.1.3
-
- - name: Activate MSVC
- uses: ilammy/msvc-dev-cmd@v1
- with:
- toolset: 14.29
- if: matrix.os == 'windows-2022'
+ run: pip install build
- - name: Build wheels
- if: matrix.os != 'windows-2022'
- shell: bash
- run: python -m cibuildwheel --output-dir wheelhouse
- env:
- ACCELERATOR: ${{ matrix.accelerator }}
- CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
-
- - name: Build wheels
- if: matrix.os == 'windows-2022'
- shell: cmd # Use cmd on Windows to avoid bash environment taking priority over MSVC variables
- run: python -m cibuildwheel --output-dir wheelhouse
- env:
- ACCELERATOR: ${{ matrix.accelerator }}
- DISTUTILS_USE_SDK: "1" # Windows requires this to use vc for building
- SKIP_TORCH_COMPILE: "true"
+      - name: Build source distribution
+ run: python -m build --sdist
- uses: actions/upload-artifact@v4
with:
- name: ${{ matrix.accelerator }}-cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
- path: ./wheelhouse/*.whl
+ name: source_dist
+ path: dist/*.tar.gz
- publish-to-public-pypi:
+ publish-to-pypi:
name: >-
Publish Python 🐍 distribution 📦 to PyPI
needs:
@@ -84,18 +35,14 @@ jobs:
runs-on: ubuntu-latest
environment:
name: pypi
+ url: https://pypi.org/p/torchmd-net
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
- strategy:
- fail-fast: false
- matrix:
- accelerator: [cpu, cu118, cu126] #, hip, cu124, cu126, cu128]
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
- pattern: "${{ matrix.accelerator }}-cibw-wheels*"
path: dist/
merge-multiple: true
@@ -103,69 +50,7 @@ jobs:
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.TMDNET_PYPI_API_TOKEN }}
- skip-existing: true
-
- # publish-to-accelera-pypi:
- # name: >-
- # Publish Python 🐍 distribution 📦 to Acellera PyPI
- # needs:
- # - build
- # runs-on: ubuntu-latest
- # permissions: # Needed for GCP authentication
- # contents: "read"
- # id-token: "write"
- # strategy:
- # fail-fast: false
- # matrix:
- # accelerator: [cpu, cu118, cu126, cu128]
-
- # steps:
- # - uses: actions/checkout@v4 # Needed for GCP authentication for some reason
-
- # - name: Set up Cloud SDK
- # uses: google-github-actions/auth@v2
- # with:
- # workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }}
- # service_account: ${{ secrets.GCP_PYPI_SERVICE_ACCOUNT }}
-
- # - name: Download all the dists
- # uses: actions/download-artifact@v4
- # with:
- # pattern: "${{ matrix.accelerator }}-cibw-wheels*"
- # path: dist/
- # merge-multiple: true
-
- # - name: Publish distribution 📦 to Acellera PyPI
- # run: |
- # pip install build twine keyring keyrings.google-artifactregistry-auth
- # pip install -U packaging
- # twine upload --repository-url https://us-central1-python.pkg.dev/pypi-packages-455608/${{ matrix.accelerator }} dist/* --verbose --skip-existing
-
- # publish-to-official-pypi:
- # name: >-
- # Publish Python 🐍 distribution 📦 to PyPI
- # needs:
- # - build
- # runs-on: ubuntu-latest
- # environment:
- # name: pypi
- # url: https://pypi.org/p/torchmd-net
- # permissions:
- # id-token: write # IMPORTANT: mandatory for trusted publishing
-
- # steps:
- # - name: Download all the dists
- # uses: actions/download-artifact@v4
- # with:
- # pattern: "cu118-cibw-wheels*"
- # path: dist/
- # merge-multiple: true
-
- # - name: Publish distribution 📦 to PyPI
- # uses: pypa/gh-action-pypi-publish@release/v1
- # with:
- # password: ${{ secrets.TMDNET_PYPI_API_TOKEN }}
- # skip_existing: true
+ skip_existing: true
github-release:
name: >-
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 2644663ee..bb9a2101b 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -17,93 +17,28 @@ jobs:
["ubuntu-latest", "ubuntu-22.04-arm", "macos-latest", "windows-2022"]
python-version: ["3.13"]
- defaults: # Needed for conda
- run:
- shell: bash -l {0}
-
steps:
- name: Check out
- uses: actions/checkout@v4
-
- - uses: conda-incubator/setup-miniconda@v3
- with:
- python-version: ${{ matrix.python-version }}
- channels: conda-forge
- conda-remove-defaults: "true"
- if: matrix.os != 'macos-13'
+ uses: actions/checkout@v5
- - uses: conda-incubator/setup-miniconda@v3
+ - name: Install uv
+ uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python-version }}
- channels: conda-forge
- mamba-version: "*"
- conda-remove-defaults: "true"
- if: matrix.os == 'macos-13'
-
- - name: Install OS-specific compilers
- run: |
- if [[ "${{ matrix.os }}" == "ubuntu-22.04-arm" ]]; then
- conda install gxx --channel conda-forge --override-channels
- elif [[ "${{ runner.os }}" == "Linux" ]]; then
- conda install gxx --channel conda-forge --override-channels
- elif [[ "${{ runner.os }}" == "macOS" ]]; then
- conda install llvm-openmp pybind11 --channel conda-forge --override-channels
- echo "CC=clang" >> $GITHUB_ENV
- echo "CXX=clang++" >> $GITHUB_ENV
- elif [[ "${{ runner.os }}" == "Windows" ]]; then
- conda install vc vc14_runtime vs2015_runtime --channel conda-forge --override-channels
- fi
-
- - name: List the conda environment
- run: conda list
- - name: Install testing packages
- run: conda install -y -c conda-forge flake8 pytest psutil python-build
+ - name: Install the project
+ run: uv sync --all-extras --dev
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+ uv run flake8 ./torchmdnet --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
-
- - name: Set pytorch index
- run: |
- if [[ "${{ runner.os }}" == "Windows" ]]; then
- mkdir -p "C:\ProgramData\pip"
- echo "[global]
- extra-index-url = https://download.pytorch.org/whl/cpu" > "C:\ProgramData\pip\pip.ini"
- else
- mkdir -p $HOME/.config/pip
- echo "[global]
- extra-index-url = https://download.pytorch.org/whl/cpu" > $HOME/.config/pip/pip.conf
- fi
-
- - name: Build and install the package
- run: |
- if [[ "${{ runner.os }}" == "Windows" ]]; then
- export LIB="C:/Miniconda/envs/test/Library/lib"
- fi
- python -m build
- pip install dist/*.whl
- env:
- ACCELERATOR: "cpu"
-
- # - name: Install nnpops
- # if: matrix.os == 'ubuntu-latest' || matrix.os == 'macos-latest'
- # run: conda install nnpops --channel conda-forge --override-channels
-
- - name: List the conda environment
- run: conda list
+ uv run flake8 ./torchmdnet --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Run tests
- run: pytest -v -s --durations=10
- env:
- ACCELERATOR: "cpu"
- SKIP_TORCH_COMPILE: ${{ runner.os == 'Windows' && 'true' || 'false' }}
- OMP_PREFIX: ${{ runner.os == 'macOS' && '/Users/runner/miniconda3/envs/test' || '' }}
- CPU_TRAIN: ${{ runner.os == 'macOS' && 'true' || 'false' }}
- LONG_TRAIN: "true"
+ run: uv run pytest tests
- name: Test torchmd-train utility
- run: torchmd-train --help
+ run: uv run torchmd-train --help
diff --git a/README.md b/README.md
index 0d1be37cb..0e3de23c7 100644
--- a/README.md
+++ b/README.md
@@ -21,19 +21,17 @@ Documentation is available at https://torchmd-net.readthedocs.io
## Installation
-TorchMD-Net is available as a pip installable wheel as well as in [conda-forge](https://conda-forge.org/)
+TorchMD-Net is available as a pip package as well as in [conda-forge](https://conda-forge.org/).
-TorchMD-Net provides builds for CPU-only, CUDA 11 and CUDA 12. CPU versions are only provided as reference,
-as the performance will be extremely limited.
-Depending on which variant you wish to install, you can install it with one of the following commands:
+As TorchMD-Net depends on PyTorch, we need to add an additional index URL to the installation command, as described in the [PyTorch installation instructions](https://pytorch.org/get-started/locally/):
```sh
-# The following will install the CUDA 11.8 version
-pip install torchmd-net-cu11 --extra-index-url https://download.pytorch.org/whl/cu118
-# The following will install the CUDA 12.4 version
-pip install torchmd-net-cu12 --extra-index-url https://download.pytorch.org/whl/cu124
-# The following will install the CPU only version (not recommended)
-pip install torchmd-net-cpu --extra-index-url https://download.pytorch.org/whl/cpu
+# The following installs TorchMD-Net with the CUDA 11.8 build of PyTorch
+pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu118
+# The following installs TorchMD-Net with the CUDA 12.4 build of PyTorch
+pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu124
+# The following installs TorchMD-Net with the CPU-only build of PyTorch (not recommended)
+pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cpu
```
Alternatively it can be installed with conda or mamba with one of the following commands.
@@ -46,7 +44,7 @@ mamba install torchmd-net cuda-version=12.4
### Install from source
-TorchMD-Net is installed using pip, but you will need to install some dependencies before. Check [this documentation page](https://torchmd-net.readthedocs.io/en/latest/installation.html#install-from-source).
+TorchMD-Net can be installed from source in editable mode with `pip install -e .`.
## Usage
Specifying training arguments can either be done via a configuration yaml file or through command line arguments directly. Several examples of architectural and training specifications for some models and datasets can be found in [examples/](https://github.com/torchmd/torchmd-net/tree/main/examples). Note that if a parameter is present both in the yaml file and the command line, the command line version takes precedence.
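
As a quick sanity check after installation (a minimal sketch, not part of the diff above), the resolved PyTorch build can be inspected directly; this mirrors the `torch.version.cuda` check the old installation docs used:

```python
import torch

# CUDA toolkit version the installed PyTorch wheel was built against,
# e.g. "11.8" or "12.4"; None for the CPU-only build.
print(torch.version.cuda)
print(torch.cuda.is_available())  # True only if a usable GPU and driver are present
```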
diff --git a/benchmarks/neighbors.py b/benchmarks/neighbors.py
index 2db969d92..58209c4d7 100644
--- a/benchmarks/neighbors.py
+++ b/benchmarks/neighbors.py
@@ -178,17 +178,16 @@ def benchmark_neighbors(
if __name__ == "__main__":
+ strategies = ["distance", "brute", "cell"]
n_particles = 32767
mean_num_neighbors = min(n_particles, 64)
density = 0.8
print(
- "Benchmarking neighbor list generation for {} particles with {} neighbors on average".format(
- n_particles, mean_num_neighbors
- )
+ f"Benchmarking neighbor list generation for {n_particles} particles with {mean_num_neighbors} neighbors on average"
)
results = {}
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
- for strategy in ["shared", "brute", "cell", "distance"]:
+ for strategy in strategies:
for n_batches in batch_sizes:
time = benchmark_neighbors(
device="cuda",
@@ -200,33 +199,22 @@ def benchmark_neighbors(
)
# Store results in a dictionary
results[strategy, n_batches] = time
+
print("Summary")
print("-------")
- print(
- "{:<10} {:<21} {:<21} {:<18} {:<10}".format(
- "Batch size", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)"
- )
- )
- print(
- "{:<10} {:<21} {:<21} {:<18} {:<10}".format(
- "----------", "---------", "---------", "---------", "---------"
- )
- )
+ headers = "Batch size "
+ for strategy in strategies:
+ headers += f"{strategy.capitalize()+'(ms)':<21}"
+ print(headers)
+ print("-" * len(headers))
# Print a column per strategy, show speedup over Distance in parenthesis
for n_batches in batch_sizes:
base = results["distance", n_batches]
- print(
- "{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format(
- n_batches,
- results["shared", n_batches],
- base / results["shared", n_batches],
- results["brute", n_batches],
- base / results["brute", n_batches],
- results["cell", n_batches],
- base / results["cell", n_batches],
- results["distance", n_batches],
- )
- )
+ row = f"{n_batches:<11}"
+ for strategy in strategies:
+ row += f"{results[strategy, n_batches]:<5.2f} x{base / results[strategy, n_batches]:<13.2f} "
+ print(row)
+
n_particles_list = np.power(2, np.arange(8, 18))
for n_batches in [1, 2, 32, 64]:
@@ -236,7 +224,7 @@ def benchmark_neighbors(
)
)
results = {}
- for strategy in ["shared", "brute", "cell", "distance"]:
+ for strategy in strategies:
for n_particles in n_particles_list:
mean_num_neighbors = min(n_particles, 64)
time = benchmark_neighbors(
@@ -251,32 +239,19 @@ def benchmark_neighbors(
results[strategy, n_particles] = time
print("Summary")
print("-------")
- print(
- "{:<10} {:<21} {:<21} {:<18} {:<10}".format(
- "N Particles", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)"
- )
- )
- print(
- "{:<10} {:<21} {:<21} {:<18} {:<10}".format(
- "----------", "---------", "---------", "---------", "---------"
- )
- )
+ headers = "N Particles "
+ for strategy in strategies:
+ headers += f"{strategy.capitalize()+'(ms)':<21}"
+ print(headers)
+ print("-" * len(headers))
# Print a column per strategy, show speedup over Distance in parenthesis
for n_particles in n_particles_list:
base = results["distance", n_particles]
-            brute_speedup = base / results["brute", n_particles]
- if n_particles > 32000:
- results["brute", n_particles] = 0
- brute_speedup = 0
- print(
- "{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format(
- n_particles,
- results["shared", n_particles],
- base / results["shared", n_particles],
- results["brute", n_particles],
- brute_speedup,
- results["cell", n_particles],
- base / results["cell", n_particles],
- results["distance", n_particles],
- )
- )
+ row = f"{n_particles:<12}"
+ for strategy in strategies:
+ row += f"{results[strategy, n_particles]:<5.2f} x{base / results[strategy, n_particles]:<13.2f} "
+ print(row)
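
The refactor above replaces the hard-coded four-column format strings with a loop over `strategies`, so adding or removing a strategy no longer requires touching the printing code. A standalone sketch of the same pattern with fabricated timings (values are illustrative only):

```python
strategies = ["distance", "brute", "cell"]
results = {("distance", 1): 2.00, ("brute", 1): 0.50, ("cell", 1): 0.40}

headers = "Batch size "
for strategy in strategies:
    headers += f"{strategy.capitalize() + '(ms)':<21}"
print(headers)
print("-" * len(headers))

for n_batches in [1]:
    base = results["distance", n_batches]  # speedups are relative to "distance"
    row = f"{n_batches:<11}"
    for strategy in strategies:
        row += f"{results[strategy, n_batches]:<5.2f} x{base / results[strategy, n_batches]:<13.2f} "
    print(row)
```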
diff --git a/cibuildwheel_support/before_all_linux.sh b/cibuildwheel_support/before_all_linux.sh
deleted file mode 100755
index 12c814b31..000000000
--- a/cibuildwheel_support/before_all_linux.sh
+++ /dev/null
@@ -1,81 +0,0 @@
-#! /bin/bash
-
-set -e
-set -x
-
-# Configure pip to use PyTorch extra-index-url for CPU
-mkdir -p $HOME/.config/pip
-echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cpu" > $HOME/.config/pip/pip.conf
-
-
-if [ "$ACCELERATOR" == "cu118" ]; then
-
- # Install CUDA 11.8:
- dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
-
- dnf install --setopt=obsoletes=0 -y \
- cuda-nvcc-11-8-11.8.89-1 \
- cuda-cudart-devel-11-8-11.8.89-1 \
- libcurand-devel-11-8-10.3.0.86-1 \
- libcudnn9-devel-cuda-11-9.8.0.87-1 \
- libcublas-devel-11-8-11.11.3.6-1 \
- libnccl-devel-2.15.5-1+cuda11.8 \
- libcusparse-devel-11-8-11.7.5.86-1 \
- libcusolver-devel-11-8-11.4.1.48-1 \
- gcc-toolset-11
-
- ln -s cuda-11.8 /usr/local/cuda
- ln -s /opt/rh/gcc-toolset-11/root/usr/bin/gcc /usr/local/cuda/bin/gcc
- ln -s /opt/rh/gcc-toolset-11/root/usr/bin/g++ /usr/local/cuda/bin/g++
-
- export CUDA_HOME="/usr/local/cuda"
-
- # Configure pip to use PyTorch extra-index-url for CUDA 11.8
- mkdir -p $HOME/.config/pip
- echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu118" > $HOME/.config/pip/pip.conf
-
-elif [ "$ACCELERATOR" == "cu126" ]; then
- # Install CUDA 12.6
- dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
-
- dnf install --setopt=obsoletes=0 -y \
- cuda-compiler-12-6-12.6.3-1 \
- cuda-libraries-12-6-12.6.3-1 \
- cuda-libraries-devel-12-6-12.6.3-1 \
- cuda-toolkit-12-6-12.6.3-1 \
- gcc-toolset-13
-
- ln -s cuda-12.6 /usr/local/cuda
- ln -s /opt/rh/gcc-toolset-13/root/usr/bin/gcc /usr/local/cuda/bin/gcc
- ln -s /opt/rh/gcc-toolset-13/root/usr/bin/g++ /usr/local/cuda/bin/g++
- ln -s /usr/local/cuda/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/libcuda.so.1
-
- # Configure pip to use PyTorch extra-index-url for CUDA 12.6
- mkdir -p $HOME/.config/pip
- echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu126" > $HOME/.config/pip/pip.conf
-
-elif [ "$ACCELERATOR" == "cu128" ]; then
- # Install CUDA 12.8
- dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
-
- dnf install --setopt=obsoletes=0 -y \
- cuda-compiler-12-8-12.8.1-1 \
- cuda-libraries-12-8-12.8.1-1 \
- cuda-libraries-devel-12-8-12.8.1-1 \
- cuda-toolkit-12-8-12.8.1-1 \
- gcc-toolset-13
-
- ln -s cuda-12.8 /usr/local/cuda
- ln -s /opt/rh/gcc-toolset-13/root/usr/bin/gcc /usr/local/cuda/bin/gcc
- ln -s /opt/rh/gcc-toolset-13/root/usr/bin/g++ /usr/local/cuda/bin/g++
- ln -s /usr/local/cuda/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/libcuda.so.1
-
- # Configure pip to use PyTorch extra-index-url for CUDA 12.8
- mkdir -p $HOME/.config/pip
- echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu128" > $HOME/.config/pip/pip.conf
-
-fi
\ No newline at end of file
diff --git a/cibuildwheel_support/before_all_windows.sh b/cibuildwheel_support/before_all_windows.sh
deleted file mode 100755
index 9cf789159..000000000
--- a/cibuildwheel_support/before_all_windows.sh
+++ /dev/null
@@ -1,37 +0,0 @@
-#!/bin/bash
-
-set -e
-set -x
-
-# Create pip directory if it doesn't exist
-mkdir -p "C:\ProgramData\pip"
-
-# Create pip.ini file with PyTorch CPU index
-echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cpu" > "C:\ProgramData\pip\pip.ini"
-
-if [ "$ACCELERATOR" == "cu118" ]; then
- curl --netrc-optional -L -nv -o cuda.exe https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe
- ./cuda.exe -s nvcc_11.8 cudart_11.8 cublas_dev_11.8 curand_dev_11.8 cusparse_dev_11.8 cusolver_dev_11.8 thrust_11.8
- rm cuda.exe
-
- # Create pip.ini file with PyTorch CUDA 11.8 index
- echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu118" > "C:\ProgramData\pip\pip.ini"
-elif [ "$ACCELERATOR" == "cu126" ]; then
- curl --netrc-optional -L -nv -o cuda.exe https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda_12.6.0_560.76_windows.exe
- ./cuda.exe -s nvcc_12.6 cudart_12.6 cublas_dev_12.6 curand_dev_12.6 cusparse_dev_12.6 cusolver_dev_12.6 thrust_12.6
- rm cuda.exe
-
- # Create pip.ini file with PyTorch CUDA 12.6 index
- echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu126" > "C:\ProgramData\pip\pip.ini"
-elif [ "$ACCELERATOR" == "cu128" ]; then
- curl --netrc-optional -L -nv -o cuda.exe https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_572.61_windows.exe
- ./cuda.exe -s nvcc_12.8 cudart_12.8 cublas_dev_12.8 curand_dev_12.8 cusparse_dev_12.8 cusolver_dev_12.8 thrust_12.8
- rm cuda.exe
-
- # Create pip.ini file with PyTorch CUDA 12.8 index
- echo "[global]
-extra-index-url = https://download.pytorch.org/whl/cu128" > "C:\ProgramData\pip\pip.ini"
-fi
\ No newline at end of file
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index 769bce14c..6e01ca096 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -1,20 +1,18 @@
Installation
============
-TorchMD-Net is available as a pip installable wheel as well as in `conda-forge `_
+TorchMD-Net is available as a pip package as well as in `conda-forge <https://conda-forge.org/>`_.
-TorchMD-Net provides builds for CPU-only, CUDA 11 and CUDA 12.
-CPU versions are only provided as reference, as the performance will be extremely limited.
-Depending on which variant you wish to install, you can install it with one of the following commands:
+As TorchMD-Net depends on PyTorch, we need to add an additional index URL to the installation command, as described in the `PyTorch installation instructions <https://pytorch.org/get-started/locally/>`_:
.. code-block:: shell
- # The following will install the CUDA 11.8 version
- pip install torchmd-net-cu11 --extra-index-url https://download.pytorch.org/whl/cu118
- # The following will install the CUDA 12.4 version
- pip install torchmd-net-cu12 --extra-index-url https://download.pytorch.org/whl/cu124
- # The following will install the CPU only version (not recommended)
- pip install torchmd-net-cpu --extra-index-url https://download.pytorch.org/whl/cpu
+   # The following installs TorchMD-Net with the CUDA 11.8 build of PyTorch
+   pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu118
+   # The following installs TorchMD-Net with the CUDA 12.4 build of PyTorch
+   pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cu124
+   # The following installs TorchMD-Net with the CPU-only build of PyTorch (not recommended)
+   pip install torchmd-net --extra-index-url https://download.pytorch.org/whl/cpu
Alternatively it can be installed with conda or mamba with one of the following commands.
We recommend using `Miniforge <https://github.com/conda-forge/miniforge>`_ instead of Anaconda.
@@ -27,70 +25,14 @@ We recommend using `Miniforge `_ inst
Install from source
-------------------
+For development purposes, we recommend using `uv <https://docs.astral.sh/uv/>`_ to install TorchMD-Net in editable mode.
+After installing uv, run the following command from the repository root to install TorchMD-Net and its dependencies:
-1. Clone the repository:
-
-.. code-block:: shell
-
- git clone https://github.com/torchmd/torchmd-net.git
- cd torchmd-net
-
-2. Install the dependencies in environment.yml.
-
-.. code-block:: shell
-
- conda env create -f environment.yml
- conda activate torchmd-net
-
-3. CUDA enabled installation
-
-You can skip this section if you only need a CPU installation.
-
-You will need the CUDA compiler (nvcc) and the corresponding development libraries to build TorchMD-Net with CUDA support. You can install CUDA from the `official NVIDIA channel `_ or from conda-forge.
-
-The conda-forge channel `changed the way to install CUDA from versions 12 and above `_, thus the following instructions depend on whether you need CUDA < 12. If you have a GPU available, conda-forge probably installed the CUDA runtime (not the developer tools) on your system already, you can check with conda:
-
-.. code-block:: shell
-
- conda list | grep cuda
-
-
-Or by asking pytorch:
-
.. code-block:: shell
-
- python -c "import torch; print(torch.version.cuda)"
-
-
-It is recommended to install the same version as the one used by torch.
-
-.. warning:: At the time of writing there is a `bug in Mamba `_ (v1.5.6) that can cause trouble when installing CUDA on an already created environment. We thus recommend conda for this step.
-
-* CUDA>=12
-
-.. code-block:: shell
-
- conda install -c conda-forge python=3.10 cuda-version=12.6 cuda-nvvm cuda-nvcc cuda-libraries-dev
-
-
-* CUDA<12
-
-The nvidia channel provides the developer tools for CUDA<12.
-
-.. code-block:: shell
-
- conda install -c nvidia "cuda-nvcc<12" "cuda-libraries-dev<12" "cuda-version<12" "gxx<12" pytorch=*=*cuda*
-
-
-4. Install TorchMD-NET into the environment:
-
-.. code-block:: shell
-
- pip install -e .
+ uv sync
-.. note:: Pip installation in CUDA mode requires compiling CUDA source codes, this can take a really long time and the process might appear as if it is "stuck". Run pip with `-vv` to see the compilation process.
+This will install the package together with PyTorch built for CUDA 12.6. To target a different
+CUDA version, edit ``pyproject.toml`` and change the PyTorch index URL (the ``pytorch-cu126``
+entry under ``[[tool.uv.index]]``) to the desired version.
-This will install TorchMD-NET in editable mode, so that changes to the source code are immediately available.
-Besides making all python utilities available environment-wide, this will also install the ``torchmd-train`` command line utility.
diff --git a/environment.yml b/environment.yml
deleted file mode 100644
index a1b851e38..000000000
--- a/environment.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: torchmd-net
-channels:
- - conda-forge
-dependencies:
- - h5py
- - nnpops
- - pip
- - pytorch
- - pytorch_geometric
- - lightning
- - torchmetrics
- - tqdm
- # Dev tools
- - flake8
- - pytest
- - psutil
- - gxx
- # optional
- - ase
diff --git a/examples/aceff_examples/ase_aceff.py b/examples/aceff_examples/ase_aceff.py
index 9c9ceaf15..93e57c3cf 100644
--- a/examples/aceff_examples/ase_aceff.py
+++ b/examples/aceff_examples/ase_aceff.py
@@ -17,7 +17,7 @@
# We create the ASE calculator by supplying the path to the model and specifying the device and dtype
calc = TMDNETCalculator(model_file_path, device='cuda')
-atoms = read('caffiene.pdb')
+atoms = read('caffeine.pdb')
print(atoms)
atoms.calc = calc
@@ -81,9 +81,6 @@
atoms.calc = calc
-# Single point calcuation to trigger compile
-atoms.get_potential_energy()
-
# Run more dynamics
t1 = time.perf_counter()
dyn.run(steps=nsteps)
diff --git a/pyproject.toml b/pyproject.toml
index 0092a3cbc..5939d1b95 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,15 +1,27 @@
[project]
-name = "PLACEHOLDER"
+name = "torchmd-net"
description = "TorchMD-NET provides state-of-the-art neural networks potentials for biomolecular systems"
authors = [{ name = "Acellera", email = "info@acellera.com" }]
readme = "README.md"
license = "MIT"
-requires-python = ">=3.8"
-dynamic = ["version", "dependencies"]
+requires-python = ">=3.9"
+dynamic = ["version"]
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: POSIX :: Linux",
]
+dependencies = [
+ "h5py",
+ # "nnpops",
+ "torch",
+ "torch_geometric",
+ "lightning",
+ "tqdm",
+ "numpy",
+ "triton; sys_platform == 'linux' and platform_machine != 'aarch64'",
+ "triton-windows; sys_platform == 'win32'",
+ "ase",
+]
[project.urls]
"Homepage" = "https://github.com/torchmd/torchmd-net"
@@ -27,40 +39,21 @@ include = ["torchmdnet*"]
"*" = ["*.c", "*.cpp", "*.h", "*.cuh", "*.cu", ".gitignore"]
[build-system]
-requires = ["setuptools>=78", "setuptools-scm>=8", "torch==2.7.1", "numpy"]
+requires = ["setuptools>=78", "setuptools-scm>=8"]
build-backend = "setuptools.build_meta"
+[dependency-groups]
+dev = ["flake8>=7.3.0", "ipython>=8.18.1", "pytest>=8.4.2"]
-[tool.cibuildwheel]
-# Disable builds which can't support CUDA and pytorch
-skip = [
- "cp38-*",
- "cp314-*",
- "cp314t-*",
- "pp*",
- "*win32",
- "*armv7l",
- "*_i686",
- "*_ppc64le",
- "*_s390x",
- "*_universal2",
- "*-musllinux_*",
-]
-test-requires = ["pytest", "pytest-xdist"]
-test-command = "pytest {project}/tests"
-environment-pass = ["ACCELERATOR", "CI"]
-# container-engine = "docker; create_args: --gpus all"
-
-[tool.cibuildwheel.linux]
-before-all = "bash {project}/cibuildwheel_support/before_all_linux.sh"
-repair-wheel-command = [
- "auditwheel repair --exclude libcudart.so.* --exclude libc10.so --exclude libc10_cuda.so --exclude libtorch.so --exclude libtorch_cuda.so --exclude libtorch_cpu.so --exclude libtorch_python.so -w {dest_dir} {wheel}",
+[tool.uv.sources]
+torch = [
+ { index = "pytorch-cu126", marker = "(sys_platform == 'linux' or sys_platform == 'win32') and platform_machine != 'aarch64' and platform_machine != 'arm64'" },
]
-
-[tool.cibuildwheel.macos]
-repair-wheel-command = [
- "delocate-wheel --ignore-missing-dependencies --require-archs {delocate_archs} -w {dest_dir} -v {wheel}",
+torchvision = [
+ { index = "pytorch-cu126", marker = "(sys_platform == 'linux' or sys_platform == 'win32') and platform_machine != 'aarch64' and platform_machine != 'arm64'" },
]
-[tool.cibuildwheel.windows]
-before-all = "bash {project}/cibuildwheel_support/before_all_windows.sh"
+[[tool.uv.index]]
+name = "pytorch-cu126"
+url = "https://download.pytorch.org/whl/cu126"
+explicit = true
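
The new `[tool.uv.sources]` and `[[tool.uv.index]]` tables pin PyTorch to the cu126 wheel index on x86-64 Linux and Windows. As a hedged illustration of the "edit `pyproject.toml` to change the CUDA version" instruction from the docs, a hypothetical helper (not part of the repository) could rewrite the pinned index:

```python
import re
from pathlib import Path

def set_pytorch_cuda_index(pyproject: Path, cuda: str = "cu118") -> None:
    """Hypothetical helper: retarget the [[tool.uv.index]] PyTorch URL."""
    text = pyproject.read_text()
    # Rewrite e.g. https://download.pytorch.org/whl/cu126 -> .../whl/cu118
    text = re.sub(r"(download\.pytorch\.org/whl/)cu\d+", rf"\g<1>{cuda}", text)
    # The index name is only a label; rename it so the file stays self-describing.
    text = text.replace("pytorch-cu126", f"pytorch-{cuda}")
    pyproject.write_text(text)

set_pytorch_cuda_index(Path("pyproject.toml"), "cu118")
```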
diff --git a/setup.py b/setup.py
deleted file mode 100644
index b2ee351e9..000000000
--- a/setup.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
-# Distributed under the MIT License.
-# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
-
-from setuptools import setup
-import torch
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
-import os
-import platform
-
-
-if os.environ.get("ACCELERATOR", None) is not None:
- use_cuda = os.environ.get("ACCELERATOR", "").startswith("cu")
-else:
- use_cuda = torch.cuda._is_compiled()
-
-
-def _replace_name(name):
- import pathlib
-
- pyproject_path = pathlib.Path(__file__).parent / "pyproject.toml"
- with open(pyproject_path, "r") as f:
- pyproject_text = f.read()
- pyproject_text = pyproject_text.replace("PLACEHOLDER", name)
- with open(pyproject_path, "w") as f:
- f.write(pyproject_text)
-
-
-if os.getenv("ACCELERATOR", "").startswith("cpu"):
- _replace_name("torchmd-net-cpu")
-if os.getenv("ACCELERATOR", "").startswith("cu"):
- cuda_ver = os.getenv("ACCELERATOR", "")[2:4]
- _replace_name(f"torchmd-net-cu{cuda_ver}")
-
-
-def set_torch_cuda_arch_list():
- """Set the CUDA arch list according to the architectures the current torch installation was compiled for.
- This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support.
- """
- if use_cuda and not os.environ.get("TORCH_CUDA_ARCH_LIST"):
- arch_flags = torch._C._cuda_getArchFlags()
- sm_versions = [x[3:] for x in arch_flags.split() if x.startswith("sm_")]
- formatted_versions = ";".join([f"{y[:-1]}.{y[-1]}" for y in sm_versions])
- formatted_versions += "+PTX"
- os.environ["TORCH_CUDA_ARCH_LIST"] = formatted_versions
-
-
-set_torch_cuda_arch_list()
-
-extension_root = os.path.join("torchmdnet", "extensions")
-neighbor_sources = ["neighbors_cpu.cpp"]
-if use_cuda:
- neighbor_sources.append("neighbors_cuda.cu")
-neighbor_sources = [
- os.path.join(extension_root, "neighbors", source) for source in neighbor_sources
-]
-
-runtime_library_dirs = None
-if platform.system() == "Darwin":
- runtime_library_dirs = [
- "@loader_path/../../torch/lib",
- "@loader_path/../../nvidia/cuda_runtime/lib",
- ]
-elif platform.system() == "Linux":
- runtime_library_dirs = [
- "$ORIGIN/../../torch/lib",
- "$ORIGIN/../../nvidia/cuda_runtime/lib",
- ]
-
-extra_deps = []
-if os.getenv("ACCELERATOR", "").startswith("cu"):
- cuda_ver = os.getenv("ACCELERATOR")[2:4]
- extra_deps = [f"nvidia-cuda-runtime-cu{cuda_ver}"]
-
-ExtensionType = CppExtension if not use_cuda else CUDAExtension
-extensions = ExtensionType(
- name="torchmdnet.extensions.torchmdnet_extensions",
- sources=[os.path.join(extension_root, "torchmdnet_extensions.cpp")]
- + neighbor_sources,
- define_macros=[("WITH_CUDA", 1)] if use_cuda else [],
- runtime_library_dirs=runtime_library_dirs,
-)
-
-kwargs = {}
-if "CI" in os.environ:
- from setuptools_scm import get_version
-
- # Drop the dev version suffix because we modify pyproject.toml
- # We do this only in CI because we need to upload to PyPI
-
- kwargs = {"version": ".".join(get_version().split(".")[:3])}
-
-if __name__ == "__main__":
- setup(
- ext_modules=[extensions],
- cmdclass={
- "build_ext": BuildExtension.with_options(
- no_python_abi_suffix=True, use_ninja=False
- )
- },
- install_requires=[
- "h5py",
- # "nnpops",
- "torch==2.7.1",
- "torch_geometric",
- "lightning",
- "tqdm",
- "numpy",
- "ase",
- ]
- + extra_deps,
- **kwargs,
- )
diff --git a/tests/caffeine.pdb b/tests/caffeine.pdb
new file mode 100644
index 000000000..080ed8f97
--- /dev/null
+++ b/tests/caffeine.pdb
@@ -0,0 +1,51 @@
+MODEL 1
+HETATM 1 O UNL 1 0.470 2.569 0.001 1.00 0.00 O
+HETATM 2 O UNL 2 -3.127 -0.444 -0.000 1.00 0.00 O
+HETATM 3 N UNL 3 -0.969 -1.312 0.000 1.00 0.00 N
+HETATM 4 N UNL 4 2.218 0.141 -0.000 1.00 0.00 N
+HETATM 5 N UNL 5 -1.348 1.080 -0.000 1.00 0.00 N
+HETATM 6 N UNL 6 1.412 -1.937 0.000 1.00 0.00 N
+HETATM 7 C UNL 7 0.858 0.259 -0.001 1.00 0.00 C
+HETATM 8 C UNL 8 0.390 -1.026 -0.000 1.00 0.00 C
+HETATM 9 C UNL 9 0.031 1.422 -0.001 1.00 0.00 C
+HETATM 10 C UNL 10 -1.906 -0.249 -0.000 1.00 0.00 C
+HETATM 11 C UNL 11 2.503 -1.200 0.000 1.00 0.00 C
+HETATM 12 C UNL 12 -1.428 -2.696 0.001 1.00 0.00 C
+HETATM 13 C UNL 13 3.193 1.206 0.000 1.00 0.00 C
+HETATM 14 C UNL 14 -2.297 2.188 0.001 1.00 0.00 C
+HETATM 15 H UNL 15 3.516 -1.579 0.001 1.00 0.00 H
+HETATM 16 H UNL 16 -1.045 -3.197 -0.894 1.00 0.00 H
+HETATM 17 H UNL 17 -2.519 -2.760 0.001 1.00 0.00 H
+HETATM 18 H UNL 18 -1.045 -3.196 0.896 1.00 0.00 H
+HETATM 19 H UNL 19 4.199 0.780 0.000 1.00 0.00 H
+HETATM 20 H UNL 20 3.047 1.809 -0.899 1.00 0.00 H
+HETATM 21 H UNL 21 3.047 1.808 0.900 1.00 0.00 H
+HETATM 22 H UNL 22 -1.809 3.165 -0.000 1.00 0.00 H
+HETATM 23 H UNL 23 -2.932 2.103 0.888 1.00 0.00 H
+HETATM 24 H UNL 24 -2.935 2.102 -0.885 1.00 0.00 H
+CONECT 1 9
+CONECT 2 10
+CONECT 3 8 10 12
+CONECT 4 7 11 13
+CONECT 5 9 10 14
+CONECT 6 8 11
+CONECT 7 4 8 9
+CONECT 8 3 6 7
+CONECT 9 1 5 7
+CONECT 10 2 3 5
+CONECT 11 4 6 15
+CONECT 12 3 16 17 18
+CONECT 13 4 19 20 21
+CONECT 14 5 22 23 24
+CONECT 15 11
+CONECT 16 12
+CONECT 17 12
+CONECT 18 12
+CONECT 19 13
+CONECT 20 13
+CONECT 21 13
+CONECT 22 14
+CONECT 23 14
+CONECT 24 14
+MASTER 0 0 0 0 0 0 0 0 24 0 24 0
+END
diff --git a/tests/example_tensornet.ckpt b/tests/example_tensornet.ckpt
new file mode 100644
index 000000000..2c7f521ac
Binary files /dev/null and b/tests/example_tensornet.ckpt differ
diff --git a/tests/test_calculator.py b/tests/test_calculator.py
index 90d552060..08b53f9c7 100644
--- a/tests/test_calculator.py
+++ b/tests/test_calculator.py
@@ -81,3 +81,76 @@ def test_compare_forward_multiple():
assert_close(e_calc, e_pred)
assert_close(f_calc, f_pred.view(-1, len(z1), 3), rtol=3e-4, atol=1e-5)
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_ase_calculator(device):
+ import platform
+
+ if device == "cuda" and not torch.cuda.is_available():
+ pytest.skip("CUDA not available")
+ # Skip on Windows CI for now because we don't have cl.exe installed for torch.compile
+ if platform.system() == "Windows":
+ pytest.skip("Skipping test on Windows")
+
+ from torchmdnet.calculators import TMDNETCalculator
+ from ase.io import read
+ from ase import units
+ from ase.md.langevin import Langevin
+ import os
+ import numpy as np
+
+ ref_forces = np.array(
+ [
+ [-7.76999056e-01, -6.89736724e-01, -9.31625906e-03],
+ [2.16895628e00, 5.29322922e-01, 9.77647374e-04],
+ [-6.50327325e-01, 1.11337602e00, 2.27598846e-03],
+ [1.19350255e00, -1.23314810e00, -7.83117674e-03],
+ [5.11314452e-01, -3.33878160e-01, -5.03035402e-03],
+ [-7.18148768e-01, 7.37230778e-02, 3.08941817e-03],
+ [-1.25317931e-01, -5.19263268e-01, 2.00758013e-03],
+ [-3.05806249e-01, -4.49415118e-01, -8.53991229e-03],
+ [5.10734320e-03, 2.49908626e-01, 2.04431713e-02],
+ [-3.65967184e-01, -1.57078415e-01, -1.55984145e-03],
+ [-6.44133329e-01, 1.16345167e00, 4.54566162e-03],
+ [2.05828249e-02, -2.64510632e-01, -1.38899162e-02],
+ [1.73451304e-02, 3.65104795e-01, 1.13833081e-02],
+ [-8.57830405e-01, -2.25283504e-01, -2.49589253e-02],
+ [-1.56955227e-01, 1.19012646e-01, -1.87584094e-03],
+ [-1.50042176e-02, 1.75106078e-02, 2.51995742e-01],
+ [3.01239967e-01, 3.67318511e-01, 4.64916229e-06],
+ [-9.57870483e-03, 1.21697336e-02, -2.39765823e-01],
+ [-2.48186022e-01, 2.74000764e-02, -1.08634552e-03],
+ [1.26295090e-01, 1.04473650e-01, 2.81187654e-01],
+ [1.28753006e-01, 1.03064716e-01, -2.88918495e-01],
+ [2.80321002e-01, -5.11180341e-01, -1.12308562e-03],
+ [5.06305993e-02, 6.65888190e-02, -2.11322665e-01],
+ [7.02065229e-02, 7.10679889e-02, 2.37307906e-01],
+ ]
+ )
+
+ curr_dir = os.path.dirname(__file__)
+
+ checkpoint = join(curr_dir, "example_tensornet.ckpt")
+ calc = TMDNETCalculator(checkpoint, device=device)
+
+ atoms = read(join(curr_dir, "caffeine.pdb"))
+ atoms.calc = calc
+ # The total molecular charge must be set
+ atoms.info["charge"] = 0
+ assert np.allclose(atoms.get_potential_energy(), -113.6652, atol=1e-4)
+ assert np.allclose(atoms.get_forces(), ref_forces, atol=1e-4)
+
+ # Molecular dynamics
+ temperature_K: float = 300
+ timestep: float = 1.0 * units.fs
+ friction: float = 0.01 / units.fs
+ nsteps: int = 10
+ dyn = Langevin(atoms, timestep, temperature_K=temperature_K, friction=friction)
+ dyn.run(steps=nsteps)
+
+ # Now we can do the same but enabling torch.compile for increased speed
+ calc = TMDNETCalculator(checkpoint, device=device, compile=True)
+ atoms.calc = calc
+ # Run more dynamics
+ dyn.run(steps=nsteps)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 7402f8f31..736bc8a28 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -84,19 +84,17 @@ def test_custom(energy, forces, num_files, preload, tmpdir):
# Assert shapes of whole dataset:
for i in range(len(data)):
n_atoms_i = n_atoms_per_sample[i]
- assert np.array(data[i].z).shape == (
+ assert data[i].z.shape == (
n_atoms_i,
), "Dataset has incorrect atom numbers shape"
- assert np.array(data[i].pos).shape == (
+ assert data[i].pos.shape == (
n_atoms_i,
3,
), "Dataset has incorrect coords shape"
if energy:
- assert np.array(data[i].y).shape == (
- 1,
- ), "Dataset has incorrect energy shape"
+ assert data[i].y.shape == (1,), "Dataset has incorrect energy shape"
if forces:
- assert np.array(data[i].neg_dy).shape == (
+ assert data[i].neg_dy.shape == (
n_atoms_i,
3,
), "Dataset has incorrect forces shape"
@@ -190,19 +188,17 @@ def test_hdf5(preload, energy, forces, num_files, tmpdir):
# Assert shapes of whole dataset:
for i in range(len(data)):
n_atoms_i = n_atoms_per_sample[i]
- assert np.array(data[i].z).shape == (
+ assert data[i].z.shape == (
n_atoms_i,
), "Dataset has incorrect atom numbers shape"
- assert np.array(data[i].pos).shape == (
+ assert data[i].pos.shape == (
n_atoms_i,
3,
), "Dataset has incorrect coords shape"
if energy:
- assert np.array(data[i].y).shape == (
- 1,
- ), "Dataset has incorrect energy shape"
+ assert data[i].y.shape == (1,), "Dataset has incorrect energy shape"
if forces:
- assert np.array(data[i].neg_dy).shape == (
+ assert data[i].neg_dy.shape == (
n_atoms_i,
3,
), "Dataset has incorrect forces shape"
diff --git a/tests/test_model.py b/tests/test_model.py
index 323c0fa67..65ca3863e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -114,6 +114,40 @@ def test_torchscript_dynamic_shapes(model_name, device):
)[0]
+@mark.parametrize("model_name", models.__all_models__)
+@mark.parametrize("device", ["cpu", "cuda"])
+def test_torchscript_then_compile(model_name, device):
+ """Test that a TorchScripted model can be torch.compiled afterwards"""
+ if device == "cuda" and not torch.cuda.is_available():
+ pytest.skip("CUDA not available")
+
+ # Create and TorchScript the model
+ z, pos, batch = create_example_batch()
+ z = z.to(device)
+ pos = pos.to(device).requires_grad_(True)
+ batch = batch.to(device)
+
+ scripted_model = torch.jit.script(
+ create_model(load_example_args(model_name, remove_prior=True, derivative=True))
+ ).to(device=device)
+
+ # Get baseline output from scripted model
+ y_scripted, neg_dy_scripted = scripted_model(z, pos, batch=batch)
+
+ # Now try to torch.compile the scripted model
+ try:
+ compiled_model = torch.compile(scripted_model, backend="inductor")
+ y_compiled, neg_dy_compiled = compiled_model(z, pos, batch=batch)
+
+ # Verify outputs match
+ torch.testing.assert_close(y_scripted, y_compiled, atol=1e-5, rtol=1e-5)
+ torch.testing.assert_close(
+ neg_dy_scripted, neg_dy_compiled, atol=1e-5, rtol=1e-5
+ )
+ except Exception as e:
+ pytest.fail(f"torch.compile failed on TorchScripted {model_name} model: {e}")
+
+
# Currently only tensornet is CUDA graph compatible
@mark.parametrize("model_name", ["tensornet"])
def test_cuda_graph_compatible(model_name):
@@ -161,6 +195,65 @@ def test_cuda_graph_compatible(model_name):
assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)
+@mark.parametrize("model_name", ["tensornet"])
+def test_torchscript_cuda_graph_compatible(model_name):
+ """Test that a TorchScripted model is compatible with CUDA graphs.
+
+ This is important when models are saved as TorchScript
+ and then CUDA graphs are used for performance.
+ """
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA not available")
+ z, pos, batch = create_example_batch()
+ args = {
+ "model": model_name,
+ "embedding_dimension": 128,
+ "num_layers": 2,
+ "num_rbf": 32,
+ "rbf_type": "expnorm",
+ "trainable_rbf": False,
+ "activation": "silu",
+ "cutoff_lower": 0.0,
+ "cutoff_upper": 5.0,
+ "max_z": 100,
+ "max_num_neighbors": 128,
+ "equivariance_invariance_group": "O(3)",
+ "prior_model": None,
+ "atom_filter": -1,
+ "derivative": True,
+ "check_errors": False,
+ "static_shapes": True,
+ "output_model": "Scalar",
+ "reduce_op": "sum",
+ "precision": 32,
+ }
+ # Create the model
+ base_model = create_model(args)
+ # Setup for CUDA graphs before TorchScripting
+ z_cuda = z.to("cuda")
+ pos_cuda = pos.to("cuda").requires_grad_(True)
+ batch_cuda = batch.to("cuda")
+
+ # Now TorchScript the model (like OpenMM-Torch does)
+ model = torch.jit.script(base_model).to(device="cuda")
+ model.eval()
+
+ # Warm up the model
+ with torch.cuda.stream(torch.cuda.Stream()):
+ for _ in range(0, 15):
+ y, neg_dy = model(z_cuda, pos_cuda, batch=batch_cuda)
+ # Capture the model in a CUDA graph
+ g = torch.cuda.CUDAGraph()
+ y2, neg_dy2 = model(z_cuda, pos_cuda, batch=batch_cuda)
+ with torch.cuda.graph(g):
+ y, neg_dy = model(z_cuda, pos_cuda, batch=batch_cuda)
+ y.fill_(0.0)
+ neg_dy.fill_(0.0)
+ g.replay()
+ assert torch.allclose(y, y2)
+ assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)
+
+
@mark.parametrize("model_name", models.__all_models__)
def test_seed(model_name):
args = load_example_args(model_name, remove_prior=True)
diff --git a/tests/test_module.py b/tests/test_module.py
index 021672472..7f55418f2 100644
--- a/tests/test_module.py
+++ b/tests/test_module.py
@@ -43,6 +43,10 @@ def test_train(model_name, use_atomref, precision, tmpdir):
# OSX MPS backend runs out of memory on Github Actions
torch.set_default_device("cpu")
accelerator = "cpu"
+ elif precision == 64 and torch.backends.mps.is_available():
+ # MPS backend doesn't support float64
+ torch.set_default_device("cpu")
+ accelerator = "cpu"
args = load_example_args(
model_name,
@@ -95,6 +99,10 @@ def test_dummy_train(model_name, use_atomref, precision, tmpdir):
# OSX MPS backend runs out of memory on Github Actions
torch.set_default_device("cpu")
accelerator = "cpu"
+ elif precision == 64 and torch.backends.mps.is_available():
+ # MPS backend doesn't support float64
+ torch.set_default_device("cpu")
+ accelerator = "cpu"
extra_args = {}
if model_name != "tensornet":
diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py
index fa3e06046..574d9b611 100644
--- a/tests/test_neighbors.py
+++ b/tests/test_neighbors.py
@@ -73,7 +73,7 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vecto
@pytest.mark.parametrize(
("device", "strategy"),
- [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+ [("cpu", "brute"), ("cuda", "brute"), ("cuda", "cell")],
)
@pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128])
@pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9])
@@ -130,6 +130,7 @@ def test_neighbors(
return_vecs=True,
include_transpose=include_transpose,
)
+ nl.to(device)
batch.to(device)
neighbors, distances, distance_vecs = nl(pos, batch)
neighbors = neighbors.cpu().detach().numpy()
@@ -149,7 +150,7 @@ def test_neighbors(
@pytest.mark.parametrize(
("device", "strategy"),
- [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+ [("cpu", "brute"), ("cuda", "brute"), ("cuda", "cell")],
)
@pytest.mark.parametrize("loop", [True, False])
@pytest.mark.parametrize("include_transpose", [True, False])
@@ -225,6 +226,7 @@ def test_neighbor_grads(
resize_to_fit=True,
box=box,
)
+ nl.to(device)
neighbors, distances, deltas = nl(positions)
# Check neighbor pairs are correct
ref_neighbors_sort, _, _ = sort_neighbors(
@@ -261,7 +263,7 @@ def test_neighbor_grads(
@pytest.mark.parametrize(
("device", "strategy"),
- [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+ [("cpu", "brute"), ("cuda", "brute"), ("cuda", "cell")],
)
@pytest.mark.parametrize("loop", [True, False])
@pytest.mark.parametrize("include_transpose", [True, False])
@@ -300,6 +302,7 @@ def test_neighbor_autograds(
resize_to_fit=True,
box=box,
)
+ nl.to(device)
positions = 0.25 * lbox * torch.rand(num_atoms, 3, device=device, dtype=dtype)
positions.requires_grad_(True)
batch = torch.zeros((num_atoms,), dtype=torch.long, device=device)
@@ -314,7 +317,7 @@ def test_neighbor_autograds(
)
-@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"])
+@pytest.mark.parametrize("strategy", ["brute", "cell"])
@pytest.mark.parametrize("n_batches", [1, 2, 3, 4])
def test_large_size(strategy, n_batches):
device = "cuda"
@@ -359,6 +362,7 @@ def test_large_size(strategy, n_batches):
include_transpose=True,
resize_to_fit=True,
)
+ nl.to(device)
neighbors, distances, distance_vecs = nl(pos, batch)
neighbors = neighbors.cpu().detach().numpy()
distance_vecs = distance_vecs.cpu().detach().numpy()
@@ -373,7 +377,7 @@ def test_large_size(strategy, n_batches):
@pytest.mark.parametrize(
("device", "strategy"),
- [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
+ [("cpu", "brute"), ("cuda", "brute"), ("cuda", "cell")],
)
@pytest.mark.parametrize("n_batches", [1, 128])
@pytest.mark.parametrize("cutoff", [1.0])
@@ -424,6 +428,7 @@ def test_jit_script_compatible(
return_vecs=True,
include_transpose=include_transpose,
)
+ nl.to(device)
batch.to(device)
nl = torch.jit.script(nl)
@@ -444,7 +449,7 @@ def test_jit_script_compatible(
@pytest.mark.parametrize("device", ["cuda"])
-@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"])
+@pytest.mark.parametrize("strategy", ["brute", "cell"])
@pytest.mark.parametrize("n_batches", [1, 128])
@pytest.mark.parametrize("cutoff", [1.0])
@pytest.mark.parametrize("loop", [True, False])
@@ -494,6 +499,7 @@ def test_cuda_graph_compatible_forward(
check_errors=False,
resize_to_fit=False,
)
+ nl.to(device)
batch.to(device)
graph = torch.cuda.CUDAGraph()
@@ -527,7 +533,7 @@ def test_cuda_graph_compatible_forward(
@pytest.mark.parametrize("device", ["cuda"])
-@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"])
+@pytest.mark.parametrize("strategy", ["brute", "cell"])
@pytest.mark.parametrize("n_batches", [1, 128])
@pytest.mark.parametrize("cutoff", [1.0])
@pytest.mark.parametrize("loop", [True, False])
@@ -580,6 +586,7 @@ def test_cuda_graph_compatible_backward(
check_errors=False,
resize_to_fit=False,
)
+ nl.to(device)
batch.to(device)
graph = torch.cuda.CUDAGraph()
@@ -600,9 +607,7 @@ def test_cuda_graph_compatible_backward(
torch.cuda.synchronize()
-@pytest.mark.parametrize(
- ("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")]
-)
+@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute")])
@pytest.mark.parametrize("n_batches", [1, 128])
@pytest.mark.parametrize("use_forward", [True, False])
def test_per_batch_box(device, strategy, n_batches, use_forward):
@@ -646,6 +651,7 @@ def test_per_batch_box(device, strategy, n_batches, use_forward):
return_vecs=True,
include_transpose=include_transpose,
)
+ nl.to(device)
batch.to(device)
neighbors, distances, distance_vecs = nl(
pos, batch, box=box if use_forward else None
@@ -694,8 +700,6 @@ def test_torch_compile(device, dtype, loop, include_transpose):
if sys.version_info >= (3, 12):
pytest.skip("Not available in this version")
- if torch.__version__ < "2.0.0":
- pytest.skip("Not available in this version")
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
np.random.seed(123456)
diff --git a/torchmdnet/calculators.py b/torchmdnet/calculators.py
index 662f7b7a4..dbf1e6b27 100644
--- a/torchmdnet/calculators.py
+++ b/torchmdnet/calculators.py
@@ -268,10 +268,12 @@ def calculate(
else:
print("atomic numbers changed, re-compiling...")
-
- self.model.representation_model.setup_for_compile_cudagraphs(batch)
- self.model.output_model.setup_for_compile_cudagraphs(batch)
+ # Warmup pass to set dim_size before compilation
+ # This is needed because torch.compile doesn't support .item() calls
self.model.to(self.device)
+ with torch.no_grad():
+ _ = self.model(numbers, positions, batch=batch, q=total_charge)
+
self.compiled_model = torch.compile(self.model, backend='inductor', dynamic=False, fullgraph=True, mode='reduce-overhead')
self.compiled = True
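
The `calculators.py` change above drops the dedicated `setup_for_compile_cudagraphs` calls in favor of a plain eager warmup pass, because sizes fixed via `.item()` cannot be derived under `torch.compile`. A minimal sketch of the warmup-then-compile pattern, using a stand-in `Linear` module rather than the real TorchMD-Net model:

```python
import torch

model = torch.nn.Linear(3, 1)  # placeholder for the loaded TorchMD-Net model
x = torch.randn(8, 3)

# Warmup: one eager forward pass lets the module record any data-dependent
# sizes (e.g. values read via .item()) before compilation.
with torch.no_grad():
    _ = model(x)

compiled = torch.compile(
    model, backend="inductor", dynamic=False, fullgraph=True, mode="reduce-overhead"
)
y = compiled(x)  # later calls with identical input shapes reuse the compiled graph
```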
diff --git a/torchmdnet/extensions/__init__.py b/torchmdnet/extensions/__init__.py
index f474bff89..f163ff7d7 100644
--- a/torchmdnet/extensions/__init__.py
+++ b/torchmdnet/extensions/__init__.py
@@ -1,12 +1,3 @@
# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
-
-# Place here any short extensions to torch that you want to use in your code.
-# The extensions present in extensions.cpp will be automatically compiled in setup.py and loaded here.
-# The extensions will be available under torch.ops.torchmdnet_extensions, but you can add wrappers here to make them more convenient to use.
-# Place here too any meta registrations for your extensions if required.
-
-import torch
-from pathlib import Path
-from . import torchmdnet_extensions, ops
diff --git a/torchmdnet/extensions/neighbors.py b/torchmdnet/extensions/neighbors.py
new file mode 100644
index 000000000..f495f4b4c
--- /dev/null
+++ b/torchmdnet/extensions/neighbors.py
@@ -0,0 +1,174 @@
+from torch import Tensor
+from typing import Tuple
+import torch
+
+
+def _round_nearest(x: Tensor) -> Tensor:
+ # Equivalent to torch.round but works for both float32/float64 and keeps TorchScript happy.
+ return torch.where(x >= 0, torch.floor(x + 0.5), torch.ceil(x - 0.5))
+
+
+def _apply_pbc_torch(deltas: Tensor, box_for_pairs: Tensor) -> Tensor:
+ # box_for_pairs: (num_pairs, 3, 3)
+ scale3 = _round_nearest(deltas[:, 2] / box_for_pairs[:, 2, 2])
+ deltas = deltas - scale3.unsqueeze(1) * box_for_pairs[:, 2]
+ scale2 = _round_nearest(deltas[:, 1] / box_for_pairs[:, 1, 1])
+ deltas = deltas - scale2.unsqueeze(1) * box_for_pairs[:, 1]
+ scale1 = _round_nearest(deltas[:, 0] / box_for_pairs[:, 0, 0])
+ deltas = deltas - scale1.unsqueeze(1) * box_for_pairs[:, 0]
+ return deltas
+
+
+def torch_neighbor_bruteforce(
+ positions: Tensor,
+ batch: Tensor,
+ box_vectors: Tensor,
+ use_periodic: bool,
+ cutoff_lower: float,
+ cutoff_upper: float,
+ max_num_pairs: int,
+ loop: bool,
+ include_transpose: bool,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Optimized brute-force neighbor list using pure PyTorch.
+
+ This implementation avoids nonzero() to be torch.compile compatible.
+ Uses triangular indexing to reduce memory usage and computation from O(n^2) to O(n^2/2).
+ It uses argsort with infinity for invalid pairs to achieve fixed output shapes.
+ """
+ if positions.dim() != 2 or positions.size(1) != 3:
+ raise ValueError('Expected "positions" to have shape (N, 3)')
+ if batch.dim() != 1 or batch.size(0) != positions.size(0):
+ raise ValueError('Expected "batch" to have shape (N,)')
+ if max_num_pairs <= 0:
+ raise ValueError('Expected "max_num_pairs" to be positive')
+
+ device = positions.device
+ dtype = positions.dtype
+ n_atoms = positions.size(0)
+
+ if use_periodic:
+ if box_vectors.dim() == 2:
+ box_vectors = box_vectors.unsqueeze(0)
+ elif box_vectors.dim() != 3:
+ raise ValueError('Expected "box_vectors" to have shape (n_batch, 3, 3)')
+ box_vectors = box_vectors.to(device=device, dtype=dtype)
+
+ # Generate base pairs
+ if loop:
+ # loop=True: i >= j (lower triangle including diagonal)
+ tril_indices = torch.tril_indices(n_atoms, n_atoms, device=device)
+ i_indices = tril_indices[0]
+ j_indices = tril_indices[1]
+ else:
+ # loop=False: i > j (lower triangle excluding diagonal)
+ tril_indices = torch.tril_indices(n_atoms, n_atoms, offset=-1, device=device)
+ i_indices = tril_indices[0]
+ j_indices = tril_indices[1]
+
+ # If include_transpose, add the flipped pairs (j,i)
+ if include_transpose:
+ if loop:
+ # For loop=True, base pairs are i >= j, so add i < j transposes
+ triu_indices = torch.triu_indices(n_atoms, n_atoms, offset=1, device=device)
+ i_transpose = triu_indices[0]
+ j_transpose = triu_indices[1]
+ else:
+ # For loop=False, base pairs are i > j, so add all transposes (j,i)
+ i_transpose = j_indices
+ j_transpose = i_indices
+ # Combine base and transpose pairs
+ i_indices = torch.cat([i_indices, i_transpose])
+ j_indices = torch.cat([j_indices, j_transpose])
+
+ # Compute deltas for all pairs
+ deltas = positions[i_indices] - positions[j_indices]
+
+ # Apply PBC if needed
+ if use_periodic:
+ batch_i = batch[i_indices]
+ if box_vectors.size(0) == 1:
+ # Single box for all - use the same box
+ box_for_pairs = box_vectors.expand(len(deltas), 3, 3)
+ else:
+ # Per-batch boxes - index by batch of atom i
+ box_for_pairs = box_vectors[batch_i]
+ # Apply PBC to pairs
+ deltas = _apply_pbc_torch(deltas, box_for_pairs)
+
+ # Compute distances for all pairs
+ dist_sq = (deltas * deltas).sum(dim=-1)
+ zero_mask = dist_sq == 0
+ distances = torch.where(
+ zero_mask,
+ torch.zeros_like(dist_sq),
+ torch.sqrt(dist_sq.clamp(min=1e-32)),
+ )
+
+ # Build validity mask for all pairs
+ valid_mask = torch.ones(len(distances), device=device, dtype=torch.bool)
+
+ # Apply batch constraint
+ if batch.numel() > 0:
+ same_batch = batch[i_indices] == batch[j_indices]
+ valid_mask = valid_mask & same_batch
+
+ # Apply cutoff constraints
+ # Self-loops (i == j) are exempt from cutoff_lower since they have distance 0
+ is_self_loop = i_indices == j_indices
+ valid_mask = (
+ valid_mask
+ & (distances < cutoff_upper)
+ & ((distances >= cutoff_lower) | is_self_loop)
+ )
+
+ # Sort key: valid pairs by distance, invalid pairs get infinity (sorted last)
+ sort_key = torch.where(
+ valid_mask,
+ distances,
+ torch.full_like(distances, float("inf")),
+ )
+
+ # Sort and take top max_num_pairs (fixed output shape)
+ # For include_transpose + loop case, we may have more pairs than expected
+ num_candidates = min(sort_key.size(0), max_num_pairs)
+ order = torch.argsort(sort_key)[:num_candidates]
+
+ # Gather results using the sorted indices
+ i_out = i_indices.index_select(0, order)
+ j_out = j_indices.index_select(0, order)
+ deltas_out = deltas.index_select(0, order)
+ distances_out = distances.index_select(0, order)
+ valid_out = valid_mask.index_select(0, order)
+
+ # Pad to max_num_pairs if needed
+ if num_candidates < max_num_pairs:
+ pad_size = max_num_pairs - num_candidates
+ pad_indices = torch.full((pad_size,), -1, device=device, dtype=torch.long)
+ pad_deltas = torch.zeros((pad_size, 3), device=device, dtype=dtype)
+ pad_distances = torch.zeros((pad_size,), device=device, dtype=dtype)
+ pad_valid = torch.zeros((pad_size,), device=device, dtype=torch.bool)
+
+ i_out = torch.cat([i_out, pad_indices])
+ j_out = torch.cat([j_out, pad_indices])
+ deltas_out = torch.cat([deltas_out, pad_deltas])
+ distances_out = torch.cat([distances_out, pad_distances])
+ valid_out = torch.cat([valid_out, pad_valid])
+
+ # Replace invalid entries with -1 (neighbors) and 0 (deltas/distances)
+ # Ensure gradients flow through valid entries
+ neighbors_out = torch.stack(
+ [
+ torch.where(valid_out, i_out, torch.full_like(i_out, -1)),
+ torch.where(valid_out, j_out, torch.full_like(j_out, -1)),
+ ]
+ )
+ # For deltas/distances, use where to preserve gradients for valid entries
+ zero_deltas = deltas_out.detach() * 0
+ zero_distances = distances_out.detach() * 0
+ deltas_out = torch.where(valid_out.unsqueeze(1), deltas_out, zero_deltas)
+ distances_out = torch.where(valid_out, distances_out, zero_distances)
+
+ # Count valid pairs (before slicing) to detect overflow
+ num_pairs_tensor = valid_mask.sum().view(1).to(torch.long)
+ return neighbors_out, deltas_out, distances_out, num_pairs_tensor
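
For reference, a minimal usage sketch of the new pure-PyTorch neighbor list (assuming the import path this diff introduces). All outputs have the fixed shape implied by `max_num_pairs`, with padded slots marked by index -1:

```python
import torch
from torchmdnet.extensions.neighbors import torch_neighbor_bruteforce

pos = 5.0 * torch.rand(10, 3)
batch = torch.zeros(10, dtype=torch.long)
box = 5.0 * torch.eye(3)  # only consulted when use_periodic=True

neighbors, deltas, distances, num_pairs = torch_neighbor_bruteforce(
    pos, batch, box,
    use_periodic=True, cutoff_lower=0.0, cutoff_upper=2.0,
    max_num_pairs=64, loop=False, include_transpose=True,
)
valid = neighbors[0] >= 0  # padded entries carry -1 in both rows
print(num_pairs.item(), distances[valid])
```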
diff --git a/torchmdnet/extensions/neighbors/common.cuh b/torchmdnet/extensions/neighbors/common.cuh
deleted file mode 100644
index bf875af5b..000000000
--- a/torchmdnet/extensions/neighbors/common.cuh
+++ /dev/null
@@ -1,206 +0,0 @@
-/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
- * Distributed under the MIT License.
- *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
- * Raul P. Pelaez 2023. Common utilities for the CUDA neighbor operation.
- */
-#ifndef NEIGHBORS_COMMON_CUH
-#define NEIGHBORS_COMMON_CUH
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <c10/cuda/CUDAStream.h>
-#include <torch/extension.h>
-#include <tuple>
-
-using c10::cuda::getCurrentCUDAStream;
-using c10::cuda::CUDAStreamGuard;
-using at::Tensor;
-using at::TensorOptions;
-using at::TensorAccessor;
-using torch::Scalar;
-using torch::empty;
-using torch::full;
-
-template <typename scalar_t, int num_dims>
-using Accessor = torch::PackedTensorAccessor32<scalar_t, num_dims, torch::RestrictPtrTraits>;
-
-template <typename scalar_t, int num_dims>
-using KernelAccessor = TensorAccessor<scalar_t, num_dims, torch::RestrictPtrTraits, int32_t>;
-
-template <typename scalar_t, int num_dims>
-inline Accessor<scalar_t, num_dims> get_accessor(const Tensor& tensor) {
-    return tensor.packed_accessor32<scalar_t, num_dims, torch::RestrictPtrTraits>();
-};
-
-template <typename scalar_t> __device__ __forceinline__ scalar_t sqrt_(scalar_t x){return ::sqrt(x);};
-template <> __device__ __forceinline__ float sqrt_(float x) {
- return ::sqrtf(x);
-};
-template <> __device__ __forceinline__ double sqrt_(double x) {
- return ::sqrt(x);
-};
-
-template <typename scalar_t> struct vec3 {
-    using type = void;
-};
-
-template <> struct vec3<float> {
-    using type = float3;
-};
-
-template <> struct vec3<double> {
-    using type = double3;
-};
-
-template <typename scalar_t> using scalar3 = typename vec3<scalar_t>::type;
-
-/*
- * @brief Get the position of the i'th particle
- * @param positions The positions tensor
- * @param i The index of the particle
- * @return The position of the i'th particle
- */
-template <typename scalar_t>
-__device__ scalar3<scalar_t> fetchPosition(const Accessor<scalar_t, 2> positions, const int i) {
- return {positions[i][0], positions[i][1], positions[i][2]};
-}
-
-struct PairList {
- Tensor i_curr_pair;
- Tensor neighbors;
- Tensor deltas;
- Tensor distances;
- const bool loop, include_transpose, use_periodic;
- PairList(int max_num_pairs, torch::TensorOptions options, bool loop, bool include_transpose,
- bool use_periodic)
- : i_curr_pair(torch::zeros({1}, options.dtype(at::kInt))),
- neighbors(torch::full({2, max_num_pairs}, -1, options.dtype(at::kInt))),
- deltas(torch::full({max_num_pairs, 3}, 0, options)),
- distances(torch::full({max_num_pairs}, 0, options)), loop(loop),
- include_transpose(include_transpose), use_periodic(use_periodic) {
- }
-};
-
-template <typename scalar_t> struct PairListAccessor {
-    Accessor<int32_t, 1> i_curr_pair;
-    Accessor<int32_t, 2> neighbors;
-    Accessor<scalar_t, 2> deltas;
-    Accessor<scalar_t, 1> distances;
-    bool loop, include_transpose, use_periodic;
-    explicit PairListAccessor(const PairList& pl)
-        : i_curr_pair(get_accessor<int32_t, 1>(pl.i_curr_pair)),
-          neighbors(get_accessor<int32_t, 2>(pl.neighbors)),
-          deltas(get_accessor<scalar_t, 2>(pl.deltas)),
-          distances(get_accessor<scalar_t, 1>(pl.distances)), loop(pl.loop),
- include_transpose(pl.include_transpose), use_periodic(pl.use_periodic) {
- }
-};
-
-template <typename scalar_t>
-__device__ void writeAtomPair(PairListAccessor<scalar_t>& list, int i, int j,
-                              scalar3<scalar_t> delta, scalar_t distance, int i_pair) {
- if (i_pair < list.neighbors.size(1)) {
- list.neighbors[0][i_pair] = i;
- list.neighbors[1][i_pair] = j;
- list.deltas[i_pair][0] = delta.x;
- list.deltas[i_pair][1] = delta.y;
- list.deltas[i_pair][2] = delta.z;
- list.distances[i_pair] = distance;
- }
-}
-
-template <typename scalar_t>
-__device__ void addAtomPairToList(PairListAccessor<scalar_t>& list, int i, int j,
-                                  scalar3<scalar_t> delta, scalar_t distance, bool add_transpose) {
- const int32_t i_pair = atomicAdd(&list.i_curr_pair[0], add_transpose ? 2 : 1);
- // Neighbors after the max number of pairs are ignored, although the pair is counted
- writeAtomPair(list, i, j, delta, distance, i_pair);
- if (add_transpose) {
- writeAtomPair(list, j, i, {-delta.x, -delta.y, -delta.z}, distance, i_pair + 1);
- }
-}
-
-static void checkInput(const Tensor& positions, const Tensor& batch) {
- TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
- TORCH_CHECK(positions.size(0) > 0,
- "Expected the 1nd dimension size of \"positions\" to be more than 0");
- TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3");
- TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous");
-
- TORCH_CHECK(batch.dim() == 1, "Expected \"batch\" to have one dimension");
- TORCH_CHECK(batch.size(0) == positions.size(0),
- "Expected the 1st dimension size of \"batch\" to be the same as the 1st dimension "
- "size of \"positions\"");
- TORCH_CHECK(batch.is_contiguous(), "Expected \"batch\" to be contiguous");
- TORCH_CHECK(batch.dtype() == torch::kInt64, "Expected \"batch\" to be of type at::kLong");
-}
-
-namespace rect {
-
-/*
- * @brief Takes a point to the unit cell in the range [-0.5, 0.5]*box_size using Minimum Image
- * Convention
- * @param p The point position
- * @param box_size The box size
- * @return The point in the unit cell
- */
-template <typename scalar_t>
-__device__ auto apply_pbc(scalar3<scalar_t> p, scalar3<scalar_t> box_size) {
- p.x = p.x - floorf(p.x / box_size.x + scalar_t(0.5)) * box_size.x;
- p.y = p.y - floorf(p.y / box_size.y + scalar_t(0.5)) * box_size.y;
- p.z = p.z - floorf(p.z / box_size.z + scalar_t(0.5)) * box_size.z;
- return p;
-}
-
-template <typename scalar_t>
-__device__ auto compute_distance(scalar3<scalar_t> pos_i, scalar3<scalar_t> pos_j,
-                                 bool use_periodic, scalar3<scalar_t> box_size) {
-    scalar3<scalar_t> delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z};
- if (use_periodic) {
- delta = apply_pbc(delta, box_size);
- }
- return delta;
-}
-
-} // namespace rect
-
-namespace triclinic {
-template <typename scalar_t> using BoxAccessor = Accessor<scalar_t, 3>;
-template <typename scalar_t>
-BoxAccessor<scalar_t> get_box_accessor(const Tensor& box_vectors, bool use_periodic) {
-    return get_accessor<scalar_t, 3>(box_vectors);
-}
-
-/*
- * @brief Takes a point to the unit cell using Minimum Image
- * Convention
- * @param p The point position
- * @param box_vectors The box vectors (3x3 matrix)
- * @return The point in the unit cell
- */
-template <typename scalar_t>
-__device__ auto apply_pbc(scalar3<scalar_t> delta, const KernelAccessor<scalar_t, 2>& box) {
- scalar_t scale3 = round(delta.z / box[2][2]);
- delta.x -= scale3 * box[2][0];
- delta.y -= scale3 * box[2][1];
- delta.z -= scale3 * box[2][2];
- scalar_t scale2 = round(delta.y / box[1][1]);
- delta.x -= scale2 * box[1][0];
- delta.y -= scale2 * box[1][1];
- scalar_t scale1 = round(delta.x / box[0][0]);
- delta.x -= scale1 * box[0][0];
- return delta;
-}
-
-template <typename scalar_t>
-__device__ auto compute_distance(scalar3<scalar_t> pos_i, scalar3<scalar_t> pos_j,
-                                 bool use_periodic, const KernelAccessor<scalar_t, 2>& box) {
-    scalar3<scalar_t> delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z};
- if (use_periodic) {
- delta = apply_pbc(delta, box);
- }
- return delta;
-}
-
-} // namespace triclinic
-
-#endif
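The `triclinic::apply_pbc` deleted above is the minimum-image step the new torch fallback applies via `_apply_pbc_torch` (see the hunk at the top of this diff). A batched torch sketch of the same reduction, with a hypothetical helper name:

```python
import torch

def apply_pbc_triclinic(deltas: torch.Tensor, box: torch.Tensor) -> torch.Tensor:
    """Minimum-image convention for (N, 3) deltas with per-pair (N, 3, 3) boxes."""
    # Reduce along the third, second, then first box vector, as in the CUDA code.
    scale3 = torch.round(deltas[:, 2] / box[:, 2, 2])
    deltas = deltas - scale3.unsqueeze(-1) * box[:, 2]
    scale2 = torch.round(deltas[:, 1] / box[:, 1, 1])
    deltas = deltas - scale2.unsqueeze(-1) * box[:, 1]
    scale1 = torch.round(deltas[:, 0] / box[:, 0, 0])
    deltas = deltas - scale1.unsqueeze(-1) * box[:, 0]
    return deltas
```

The order matters: each box vector may only have components at or below its own axis (the validity checks in `neighbors_cpu.cpp` below enforce exactly that), so eliminating images along v3, then v2, then v1 never reintroduces an offset.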
diff --git a/torchmdnet/extensions/neighbors/neighbors_cpu.cpp b/torchmdnet/extensions/neighbors/neighbors_cpu.cpp
deleted file mode 100644
index 350bbd26f..000000000
--- a/torchmdnet/extensions/neighbors/neighbors_cpu.cpp
+++ /dev/null
@@ -1,273 +0,0 @@
-/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
- * Distributed under the MIT License.
- *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
- */
-#include <torch/extension.h>
-#include <ATen/ATen.h>
-#include <ATen/core/dispatch/Dispatcher.h>
-#include <string>
-#include <tuple>
-
-using std::tuple;
-using torch::arange;
-using torch::div;
-using torch::linalg_vector_norm;
-using torch::full;
-using torch::hstack;
-using torch::index_select;
-using torch::kInt32;
-using torch::round;
-using torch::Scalar;
-using torch::vstack;
-using torch::autograd::AutogradContext;
-using torch::autograd::Function;
-using torch::autograd::tensor_list;
-using torch::indexing::Slice;
-using at::Tensor;
-
-static tuple<Tensor, Tensor, Tensor, Tensor>
-forward_impl(const std::string& strategy, const Tensor& positions, const Tensor& batch,
- const Tensor& in_box_vectors, bool use_periodic, const Scalar& cutoff_lower,
- const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
- bool include_transpose) {
- TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
- TORCH_CHECK(positions.size(0) > 0,
- "Expected the 1nd dimension size of \"positions\" to be more than 0");
- TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3");
- TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous");
-    TORCH_CHECK(cutoff_upper.to<double>() > 0, "Expected \"cutoff\" to be positive");
- Tensor box_vectors = in_box_vectors;
-    const int n_batch = batch.max().item<int>() + 1;
- if (use_periodic) {
- if (box_vectors.dim() == 2) {
- box_vectors = box_vectors.unsqueeze(0).expand({n_batch, 3, 3});
- }
-        TORCH_CHECK(box_vectors.dim() == 3, "Expected \"box_vectors\" to have three dimensions");
- TORCH_CHECK(box_vectors.size(1) == 3 && box_vectors.size(2) == 3,
- "Expected \"box_vectors\" to have shape (n_batch, 3, 3)");
- // Ensure the box first dimension has size max(batch) + 1
- TORCH_CHECK(box_vectors.size(0) == n_batch,
- "Expected \"box_vectors\" to have shape (n_batch, 3, 3)");
- // Check that the box is a valid triclinic box, in the case of a box per sample we only
- // check the first one
- double v[3][3];
- for (int i = 0; i < 3; i++)
- for (int j = 0; j < 3; j++)
-                v[i][j] = box_vectors[0][i][j].item<double>();
-        double c = cutoff_upper.to<double>();
- TORCH_CHECK(v[0][1] == 0, "Invalid box vectors: box_vectors[0][1] != 0");
- TORCH_CHECK(v[0][2] == 0, "Invalid box vectors: box_vectors[0][2] != 0");
- TORCH_CHECK(v[1][2] == 0, "Invalid box vectors: box_vectors[1][2] != 0");
- TORCH_CHECK(v[0][0] >= 2 * c, "Invalid box vectors: box_vectors[0][0] < 2*cutoff");
- TORCH_CHECK(v[1][1] >= 2 * c, "Invalid box vectors: box_vectors[1][1] < 2*cutoff");
- TORCH_CHECK(v[2][2] >= 2 * c, "Invalid box vectors: box_vectors[2][2] < 2*cutoff");
- TORCH_CHECK(v[0][0] >= 2 * v[1][0],
- "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[1][0]");
-        TORCH_CHECK(v[0][0] >= 2 * v[2][0],
-                    "Invalid box vectors: box_vectors[0][0] < 2*box_vectors[2][0]");
- TORCH_CHECK(v[1][1] >= 2 * v[2][1],
- "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]");
- }
- TORCH_CHECK(max_num_pairs.toLong() > 0, "Expected \"max_num_neighbors\" to be positive");
- const int n_atoms = positions.size(0);
- Tensor neighbors = torch::empty({0}, positions.options().dtype(kInt32));
- Tensor distances = torch::empty({0}, positions.options());
- Tensor deltas = torch::empty({0}, positions.options());
- neighbors = torch::vstack((torch::tril_indices(n_atoms, n_atoms, -1, neighbors.options())));
- auto mask = index_select(batch, 0, neighbors.index({0, Slice()})) ==
- index_select(batch, 0, neighbors.index({1, Slice()}));
- neighbors = neighbors.index({Slice(), mask}).to(kInt32);
- deltas = index_select(positions, 0, neighbors.index({0, Slice()})) -
- index_select(positions, 0, neighbors.index({1, Slice()}));
- if (use_periodic) {
- const auto pair_batch = batch.index({neighbors.index({0, Slice()})});
- const auto scale3 =
- round(deltas.index({Slice(), 2}) / box_vectors.index({pair_batch, 2, 2}));
- deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) -
- scale3 * box_vectors.index({pair_batch, 2, 0}));
- deltas.index_put_({Slice(), 1}, deltas.index({Slice(), 1}) -
- scale3 * box_vectors.index({pair_batch, 2, 1}));
- deltas.index_put_({Slice(), 2}, deltas.index({Slice(), 2}) -
- scale3 * box_vectors.index({pair_batch, 2, 2}));
- const auto scale2 =
- round(deltas.index({Slice(), 1}) / box_vectors.index({pair_batch, 1, 1}));
- deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) -
- scale2 * box_vectors.index({pair_batch, 1, 0}));
- deltas.index_put_({Slice(), 1}, deltas.index({Slice(), 1}) -
- scale2 * box_vectors.index({pair_batch, 1, 1}));
- const auto scale1 =
- round(deltas.index({Slice(), 0}) / box_vectors.index({pair_batch, 0, 0}));
- deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) -
- scale1 * box_vectors.index({pair_batch, 0, 0}));
- }
- distances = linalg_vector_norm(deltas, 2, 1);
- mask = (distances < cutoff_upper) * (distances >= cutoff_lower);
- neighbors = neighbors.index({Slice(), mask});
- deltas = deltas.index({mask, Slice()});
- distances = distances.index({mask});
- if (include_transpose) {
- neighbors = torch::hstack({neighbors, torch::stack({neighbors[1], neighbors[0]})});
- distances = torch::hstack({distances, distances});
- deltas = torch::vstack({deltas, -deltas});
- }
- if (loop) {
- const Tensor range = torch::arange(0, n_atoms, torch::kInt32);
- neighbors = torch::hstack({neighbors, torch::stack({range, range})});
- distances = torch::hstack({distances, torch::zeros_like(range)});
- deltas = torch::vstack({deltas, torch::zeros({n_atoms, 3}, deltas.options())});
- }
- Tensor num_pairs_found = torch::empty(1, distances.options().dtype(kInt32));
- num_pairs_found[0] = distances.size(0);
-    // This seems wasteful, but it enables torch.compile by guaranteeing that the output of
-    // this operator has a predictable size. Resize to max_num_pairs, adding zeros if necessary.
- int64_t extension = std::max(max_num_pairs.toLong() - distances.size(0), (int64_t)0);
- if (extension > 0) {
- deltas = torch::vstack({deltas, torch::zeros({extension, 3}, deltas.options())});
- distances = torch::hstack({distances, torch::zeros({extension}, distances.options())});
- // For the neighbors add (-1,-1) pairs to fill the tensor
- neighbors = torch::hstack(
- {neighbors, torch::full({2, extension}, -1, neighbors.options().dtype(kInt32))});
- }
- return {neighbors, deltas, distances, num_pairs_found};
-}
-
-// The backwards operation is implemented fully in pytorch so that it can be differentiated a second
-// time automatically via Autograd.
-static Tensor backward_impl(const Tensor& grad_edge_vec, const Tensor& grad_edge_weight,
- const Tensor& edge_index, const Tensor& edge_vec,
- const Tensor& edge_weight, const int64_t num_atoms) {
- auto zero_mask = edge_weight.eq(0);
- auto zero_mask3 = zero_mask.unsqueeze(-1).expand_as(grad_edge_vec);
- // We need to avoid dividing by 0. Otherwise Autograd fills the gradient with NaNs in the
- // case of a double backwards. This is why we index_select like this.
- auto grad_distances_ = edge_vec / edge_weight.masked_fill(zero_mask, 1).unsqueeze(-1) *
- grad_edge_weight.masked_fill(zero_mask, 0).unsqueeze(-1);
- auto result = grad_edge_vec.masked_fill(zero_mask3, 0) + grad_distances_;
- // Avoid out-of-bounds indices under compiler fusion by mapping masked indices to 0
- // and ensuring their contributions are exactly zero.
- // This removes the need for allocating an extra dummy row (num_atoms + 1).
- auto grad_positions = torch::zeros({num_atoms, 3}, edge_vec.options());
- auto edge_index_ =
- edge_index.masked_fill(zero_mask.unsqueeze(0).expand_as(edge_index), 0);
- grad_positions.index_add_(0, edge_index_[0], result);
- grad_positions.index_add_(0, edge_index_[1], -result);
- return grad_positions;
-}
-
-// This is the autograd function that is called when the user calls get_neighbor_pairs.
-// It dispatches the required strategy for the forward function and implements the backward
-// function. The backward function is written in full pytorch so that it can be differentiated a
-// second time automatically via Autograd.
-class NeighborAutograd : public torch::autograd::Function<NeighborAutograd> {
-public:
- static tensor_list forward(AutogradContext* ctx, const std::string& strategy,
- const Tensor& positions, const Tensor& batch,
- const Tensor& box_vectors, bool use_periodic,
- const Scalar& cutoff_lower, const Scalar& cutoff_upper,
- const Scalar& max_num_pairs, bool loop, bool include_transpose) {
- static auto fwd =
- torch::Dispatcher::singleton()
- .findSchemaOrThrow("torchmdnet_extensions::get_neighbor_pairs_fwd", "")
-                .typed<decltype(forward_impl)>();
- Tensor neighbors, deltas, distances, i_curr_pair;
- std::tie(neighbors, deltas, distances, i_curr_pair) =
- fwd.call(strategy, positions, batch, box_vectors, use_periodic, cutoff_lower,
- cutoff_upper, max_num_pairs, loop, include_transpose);
- ctx->save_for_backward({neighbors, deltas, distances});
- ctx->saved_data["num_atoms"] = positions.size(0);
- return {neighbors, deltas, distances, i_curr_pair};
- }
-
- using Slice = torch::indexing::Slice;
-
- static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
- auto saved = ctx->get_saved_variables();
- auto edge_index = saved[0];
- auto edge_vec = saved[1];
- auto edge_weight = saved[2];
- auto num_atoms = ctx->saved_data["num_atoms"].toInt();
- auto grad_edge_vec = grad_outputs[1];
- auto grad_edge_weight = grad_outputs[2];
- static auto backward =
- torch::Dispatcher::singleton()
- .findSchemaOrThrow("torchmdnet_extensions::get_neighbor_pairs_bkwd", "")
-                .typed<decltype(backward_impl)>();
- auto grad_positions = backward.call(grad_edge_vec, grad_edge_weight, edge_index, edge_vec,
- edge_weight, num_atoms);
- Tensor ignore;
- return {ignore, grad_positions, ignore, ignore, ignore, ignore,
- ignore, ignore, ignore, ignore, ignore};
- }
-};
-
-// By registering as CompositeImplicitAutograd we can torch.compile this Autograd function.
-// This mode generates the meta registration for NeighborAutograd automatically, in this case
-// using the meta registrations provided for the forward and backward functions python side.
-// We provide meta registrations python side because it is the recommended way to do it.
-TORCH_LIBRARY_IMPL(torchmdnet_extensions, CompositeImplicitAutograd, m) {
- m.impl("get_neighbor_pairs",
- [](const std::string& strategy, const Tensor& positions, const Tensor& batch,
- const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
- const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
- bool include_transpose) {
- auto result = NeighborAutograd::apply(strategy, positions, batch, box_vectors,
- use_periodic, cutoff_lower, cutoff_upper,
- max_num_pairs, loop, include_transpose);
- return std::make_tuple(result[0], result[1], result[2], result[3]);
- });
-}
-
-// // Explicit device backend registrations for PyTorch versions that do not
-// // automatically fall back to CompositeImplicitAutograd for device dispatch.
-// TORCH_LIBRARY_IMPL(torchmdnet_extensions, CPU, m) {
-// m.impl("get_neighbor_pairs",
-// [](const std::string& strategy, const Tensor& positions, const Tensor& batch,
-// const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
-// const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
-// bool include_transpose) {
-// auto result = NeighborAutograd::apply(strategy, positions, batch, box_vectors,
-// use_periodic, cutoff_lower, cutoff_upper,
-// max_num_pairs, loop, include_transpose);
-// return std::make_tuple(result[0], result[1], result[2], result[3]);
-// });
-// }
-
-// TORCH_LIBRARY_IMPL(torchmdnet_extensions, CUDA, m) {
-// m.impl("get_neighbor_pairs",
-// [](const std::string& strategy, const Tensor& positions, const Tensor& batch,
-// const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
-// const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
-// bool include_transpose) {
-// auto result = NeighborAutograd::apply(strategy, positions, batch, box_vectors,
-// use_periodic, cutoff_lower, cutoff_upper,
-// max_num_pairs, loop, include_transpose);
-// return std::make_tuple(result[0], result[1], result[2], result[3]);
-// });
-// }
-
-// The registration for the CUDA version of the forward function is done in a separate .cu file.
-TORCH_LIBRARY_IMPL(torchmdnet_extensions, CPU, m) {
- m.impl("get_neighbor_pairs_fwd", forward_impl);
-}
-
-// Ideally we would register this just once using CompositeImplicitAutograd, but that causes a
-// segfault when trying to torch.compile this function.
-// Doing it this way prints a message about Autograd not being provided a backward function of the
-// backward function. It gets it from the implementation just fine now, but warns that in the
-// future this will be deprecated.
-TORCH_LIBRARY_IMPL(torchmdnet_extensions, CPU, m) {
- m.impl("get_neighbor_pairs_bkwd", backward_impl);
-}
-
-TORCH_LIBRARY_IMPL(torchmdnet_extensions, CUDA, m) {
- m.impl("get_neighbor_pairs_bkwd", backward_impl);
-}
-
-// // Register explicit Autograd fallthroughs to avoid deprecated autograd fallback warnings
-// // when backpropagating through these custom operators on CUDA/CPU.
-// TORCH_LIBRARY_IMPL(torchmdnet_extensions, Autograd, m) {
-// m.impl("get_neighbor_pairs", torch::CppFunction::makeFallthrough());
-// m.impl("get_neighbor_pairs_bkwd", torch::CppFunction::makeFallthrough());
-// }
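The deleted `backward_impl` is plain tensor algebra, so a Python rewrite can reproduce it one-to-one. A sketch of the same double-backward-safe gradient (hypothetical function name; the PR's actual Python version may differ in details):

```python
import torch

def neighbor_backward(grad_edge_vec, grad_edge_weight, edge_index, edge_vec,
                      edge_weight, num_atoms):
    zero_mask = edge_weight == 0  # padded slots and self pairs
    # Guard the division so a double backward never sees NaNs from 0 / 0.
    grad_dist = (edge_vec / edge_weight.masked_fill(zero_mask, 1).unsqueeze(-1)
                 * grad_edge_weight.masked_fill(zero_mask, 0).unsqueeze(-1))
    result = grad_edge_vec.masked_fill(zero_mask.unsqueeze(-1), 0) + grad_dist
    grad_positions = torch.zeros(num_atoms, 3, dtype=edge_vec.dtype,
                                 device=edge_vec.device)
    # Masked pairs are remapped to index 0; their contribution is already zero,
    # which avoids allocating a dummy (num_atoms + 1) row.
    safe_index = edge_index.masked_fill(zero_mask.unsqueeze(0), 0)
    grad_positions.index_add_(0, safe_index[0], result)
    grad_positions.index_add_(0, safe_index[1], -result)
    return grad_positions
```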
diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda.cu b/torchmdnet/extensions/neighbors/neighbors_cuda.cu
deleted file mode 100644
index 68f775485..000000000
--- a/torchmdnet/extensions/neighbors/neighbors_cuda.cu
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
- * Distributed under the MIT License.
- * (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
- * Raul P. Pelaez 2023
- */
-#include <c10/cuda/CUDAGuard.h>
-#include <c10/cuda/CUDAStream.h>
-#include <torch/extension.h>
-
-#include <stdexcept>
-#include <string>
-#include <tuple>
-
-#include "neighbors_cuda_brute.cuh"
-#include "neighbors_cuda_cell.cuh"
-#include "neighbors_cuda_shared.cuh"
-
-static std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-forward_impl_cuda(const std::string& strategy, const at::Tensor& positions, const at::Tensor& batch,
- const at::Tensor& in_box_vectors, bool use_periodic, const at::Scalar& cutoff_lower,
- const at::Scalar& cutoff_upper, const at::Scalar& max_num_pairs, bool loop,
- bool include_transpose) {
- auto kernel = forward_brute;
- if (positions.size(0) >= 32768 && strategy == "brute") {
- kernel = forward_shared;
- }
- if (strategy == "brute") {
- } else if (strategy == "cell") {
- kernel = forward_cell;
- } else if (strategy == "shared") {
- kernel = forward_shared;
- } else {
- throw std::runtime_error("Unknown kernel name");
- }
- return kernel(positions, batch, in_box_vectors, use_periodic, cutoff_lower, cutoff_upper,
- max_num_pairs, loop, include_transpose);
-}
-
-// We only need to register the CUDA version of the forward function here. This way we can avoid
-// compiling this file in CPU-only mode. The rest of the registrations take place in
-// neighbors_cpu.cpp.
-TORCH_LIBRARY_IMPL(torchmdnet_extensions, CUDA, m) {
- m.impl("get_neighbor_pairs_fwd", forward_impl_cuda);
-}
diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh
deleted file mode 100644
index 4d33979f3..000000000
--- a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh
+++ /dev/null
@@ -1,122 +0,0 @@
-/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
- * Distributed under the MIT License.
- *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
- *
- * Raul P. Pelaez 2023. Brute force neighbor list construction in CUDA.
- *
- * A brute force approach that assigns a thread per each possible pair of particles in the system.
- * Based on an implementation by Raimondas Galvelis.
- * Works fantastically for small (less than 10K atoms) systems, but cannot handle more than 32K
- * atoms.
- */
-#ifndef NEIGHBORS_BRUTE_CUH
-#define NEIGHBORS_BRUTE_CUH
-#include "common.cuh"
-#include <algorithm>
-
-__device__ uint32_t get_row(uint32_t index) {
- uint32_t row = floor((sqrtf(8 * index + 1) + 1) / 2);
- if (row * (row - 1) > 2 * index)
- row--;
- return row;
-}
-
-template <typename scalar_t>
-__global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor<scalar_t, 2> positions,
-                                     const Accessor<int64_t, 1> batch, scalar_t cutoff_lower2,
-                                     scalar_t cutoff_upper2, PairListAccessor<scalar_t> list,
-                                     triclinic::BoxAccessor<scalar_t> box) {
- const uint32_t index = blockIdx.x * blockDim.x + threadIdx.x;
- if (index >= num_all_pairs)
- return;
- const uint32_t row = get_row(index);
- const uint32_t column = (index - row * (row - 1) / 2);
- const auto batch_row = batch[row];
- if (batch_row == batch[column]) {
- const auto pos_i = fetchPosition(positions, row);
- const auto pos_j = fetchPosition(positions, column);
- const auto box_row = box[batch_row];
- const auto delta = triclinic::compute_distance(pos_i, pos_j, list.use_periodic, box_row);
- const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z;
- if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) {
- const scalar_t r2 = sqrt_(distance2);
- addAtomPairToList(list, row, column, delta, r2, list.include_transpose);
- }
- }
-}
-
-template <typename scalar_t>
-__global__ void add_self_kernel(const int num_atoms, Accessor<scalar_t, 2> positions,
-                                PairListAccessor<scalar_t> list) {
- const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x;
- if (i_atom >= num_atoms)
- return;
- __shared__ int i_pair;
- if (threadIdx.x == 0) { // Each block adds blockDim.x pairs to the list.
- // Handle the last block, so that only num_atoms are added in total
- i_pair =
- atomicAdd(&list.i_curr_pair[0], min(blockDim.x, num_atoms - blockIdx.x * blockDim.x));
- }
- __syncthreads();
-    scalar3<scalar_t> delta{};
- scalar_t distance = 0;
- writeAtomPair(list, i_atom, i_atom, delta, distance, i_pair + threadIdx.x);
-}
-
-static std::tuple<Tensor, Tensor, Tensor, Tensor>
-forward_brute(const Tensor& positions, const Tensor& batch, const Tensor& in_box_vectors,
- bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
- const Scalar& max_num_pairs, bool loop, bool include_transpose) {
- checkInput(positions, batch);
- const auto max_num_pairs_ = max_num_pairs.toLong();
- TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive");
- auto box_vectors = in_box_vectors.to(positions.device()).clone();
- if (box_vectors.dim() == 2) {
- // If the box is a 3x3 tensor it is assumed every sample has the same box
- if (use_periodic) {
- TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3,
- "Expected \"box_vectors\" to have shape (3, 3)");
- }
- // Make the box (None,3,3), expand artificially to positions.size(0)
- box_vectors = box_vectors.unsqueeze(0);
- if (use_periodic) {
- // I use positions.size(0) because the batch dimension is not available here
- box_vectors = box_vectors.expand({positions.size(0), 3, 3});
- }
- }
- if (use_periodic) {
- TORCH_CHECK(box_vectors.dim() == 3, "Expected \"box_vectors\" to have three dimensions");
- TORCH_CHECK(box_vectors.size(1) == 3 && box_vectors.size(2) == 3,
- "Expected \"box_vectors\" to have shape (n_batch, 3, 3)");
- }
- const int num_atoms = positions.size(0);
- TORCH_CHECK(num_atoms < 32768, "The brute strategy fails with \"num_atoms\" larger than 32768");
- const int num_pairs = max_num_pairs_;
- const TensorOptions options = positions.options();
- const auto stream = getCurrentCUDAStream(positions.get_device());
- PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic);
- const CUDAStreamGuard guard(stream);
- const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL;
- const uint64_t num_threads = 128;
-    const uint64_t num_blocks = std::max(static_cast<uint64_t>((num_all_pairs + num_threads - 1UL) / num_threads), static_cast<uint64_t>(1UL));
- AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() {
- PairListAccessor list_accessor(list);
- auto box = triclinic::get_box_accessor(box_vectors, use_periodic);
-        const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
-        const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
- TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive");
-        forward_kernel_brute<<<num_blocks, num_threads, 0, stream>>>(
-            num_all_pairs, get_accessor<scalar_t, 2>(positions), get_accessor<int64_t, 1>(batch),
- cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor, box);
- if (loop) {
- const uint32_t num_threads_self = 256;
- const uint32_t num_blocks_self =
- std::max((num_atoms + num_threads_self - 1U) / num_threads_self, 1U);
-            add_self_kernel<<<num_blocks_self, num_threads_self, 0, stream>>>(
-                num_atoms, get_accessor<scalar_t, 2>(positions), list_accessor);
- }
- });
- return {list.neighbors, list.deltas, list.distances, list.i_curr_pair};
-}
-
-#endif
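The `get_row` mapping deleted above is the same closed-form triangular indexing that the new Triton kernel in `triton_brute.py` (added later in this diff) reimplements. A host-side sketch of the math, with a hypothetical helper name:

```python
import math

def pair_from_index(k: int):
    """Map a flat index k < n*(n-1)/2 to a unique pair (i, j) with j < i."""
    i = int((math.sqrt(8 * k + 1) + 1) // 2)
    if i * (i - 1) > 2 * k:  # correct the occasional float rounding error
        i -= 1
    j = k - i * (i - 1) // 2
    return i, j

# For n = 4 atoms, k = 0..5 enumerates every unique pair exactly once:
assert [pair_from_index(k) for k in range(6)] == [
    (1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
```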
diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh
deleted file mode 100644
index b54ab1c13..000000000
--- a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh
+++ /dev/null
@@ -1,390 +0,0 @@
-/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
- * Distributed under the MIT License.
- *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
- * Raul P. Pelaez 2023. Batched cell list neighbor list implementation for CUDA.
- */
-#ifndef NEIGHBOR_CUDA_CELL_H
-#define NEIGHBOR_CUDA_CELL_H
-#include "common.cuh"
-
-/*
- * @brief Calculates the cell dimensions for a given box size and cutoff
- * @param box_size The box size
- * @param cutoff The cutoff
- * @return The cell dimensions
- */
-template <typename scalar_t>
-__host__ __device__ int3 getCellDimensions(scalar3<scalar_t> box_size, scalar_t cutoff) {
- int3 cell_dim = make_int3(box_size.x / cutoff, box_size.y / cutoff, box_size.z / cutoff);
- // Minimum 3 cells in each dimension
- cell_dim.x = max(cell_dim.x, 3);
- cell_dim.y = max(cell_dim.y, 3);
- cell_dim.z = max(cell_dim.z, 3);
-// In the host, throw if there are more than 1024 cells in any dimension
-#ifndef __CUDA_ARCH__
- if (cell_dim.x > 1024 || cell_dim.y > 1024 || cell_dim.z > 1024) {
- throw std::runtime_error("Too many cells in one dimension. Maximum is 1024");
- }
-#endif
- return cell_dim;
-}
-
-/*
- * @brief Get the cell coordinates of a point
- * @param p The point position
- * @param box_size The size of the box in each dimension
- * @param cutoff The cutoff
- * @param cell_dim The number of cells in each dimension
- * @return The cell coordinates
- */
-template <typename scalar_t>
-__device__ int3 getCell(scalar3<scalar_t> p, scalar3<scalar_t> box_size, scalar_t cutoff,
- int3 cell_dim) {
- p = rect::apply_pbc(p, box_size);
- // Take to the [0, box_size] range and divide by cutoff (which is the cell size)
- int cx = floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff);
- int cy = floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff);
- int cz = floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff);
- // Wrap around. If the position of a particle is exactly box_size, it will be in the last cell,
- // which results in an illegal access down the line.
- if (cx == cell_dim.x)
- cx = 0;
- if (cy == cell_dim.y)
- cy = 0;
- if (cz == cell_dim.z)
- cz = 0;
- return make_int3(cx, cy, cz);
-}
-
-/*
- * @brief Get the index of a cell in a 1D array of cells.
- * @param cell The cell coordinates, assumed to be in the range [0, cell_dim].
- * @param cell_dim The number of cells in each dimension
- */
-__device__ int getCellIndex(int3 cell, int3 cell_dim) {
- return cell.x + cell_dim.x * (cell.y + cell_dim.y * cell.z);
-}
-
-/*
- @brief Fold a cell coordinate to the range [0, cell_dim)
- @param cell The cell coordinate
- @param cell_dim The dimensions of the grid
- @return The folded cell coordinate
-*/
-__device__ int3 getPeriodicCell(int3 cell, int3 cell_dim) {
- int3 periodic_cell = cell;
- if (cell.x < 0)
- periodic_cell.x += cell_dim.x;
- if (cell.x >= cell_dim.x)
- periodic_cell.x -= cell_dim.x;
- if (cell.y < 0)
- periodic_cell.y += cell_dim.y;
- if (cell.y >= cell_dim.y)
- periodic_cell.y -= cell_dim.y;
- if (cell.z < 0)
- periodic_cell.z += cell_dim.z;
- if (cell.z >= cell_dim.z)
- periodic_cell.z -= cell_dim.z;
- return periodic_cell;
-}
-
-// Computes and stores the cell index of each atom.
-template <typename scalar_t>
-__global__ void assignCellIndex(const Accessor<scalar_t, 2> positions,
-                                Accessor<int32_t, 1> cell_indices, scalar3<scalar_t> box_size,
-                                scalar_t cutoff) {
- const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x;
- if (i_atom >= positions.size(0))
- return;
- const auto pi = fetchPosition(positions, i_atom);
- const auto cell_dim = getCellDimensions(box_size, cutoff);
- const auto ci = getCell(pi, box_size, cutoff, cell_dim);
- cell_indices[i_atom] = getCellIndex(ci, cell_dim);
-}
-
-/*
- * @brief Sort the positions by cell index
- * @param positions The positions of the atoms
- * @param box_size The box vectors
- * @param cutoff The cutoff
- * @return A tuple of the sorted indices and cell indices
- */
-static auto sortAtomsByCellIndex(const Tensor& positions, const Tensor& box_size,
- const Scalar& cutoff) {
- const int num_atoms = positions.size(0);
- Tensor cell_index = empty({num_atoms}, positions.options().dtype(torch::kInt32));
- const int threads = 128;
- const int blocks = (num_atoms + threads - 1) / threads;
- auto stream = at::cuda::getCurrentCUDAStream();
- AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] {
-        scalar_t cutoff_ = cutoff.to<scalar_t>();
-        scalar3<scalar_t> box_size_ = {box_size[0][0].item<scalar_t>(),
-                                       box_size[1][1].item<scalar_t>(),
-                                       box_size[2][2].item<scalar_t>()};
-        assignCellIndex<<<blocks, threads, 0, stream>>>(get_accessor<scalar_t, 2>(positions),
-                                                        get_accessor<int32_t, 1>(cell_index),
-                                                        box_size_, cutoff_);
- });
- // Sort the atom indices by cell index
- Tensor sorted_atom_index;
- Tensor sorted_cell_index;
- std::tie(sorted_cell_index, sorted_atom_index) = torch::sort(cell_index);
- return std::make_tuple(sorted_atom_index.to(torch::kInt32), sorted_cell_index);
-}
-
-__global__ void fillCellOffsetsD(const Accessor<int32_t, 1> sorted_cell_indices,
-                                 Accessor<int32_t, 1> cell_start, Accessor<int32_t, 1> cell_end) {
- // Since positions are sorted by cell, for a given atom, if the previous atom is in a different
- // cell, then the current atom is the first atom in its cell We use this fact to fill the
- // cell_start and cell_end arrays
- const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x;
- if (i_atom >= sorted_cell_indices.size(0))
- return;
- const int icell = sorted_cell_indices[i_atom];
- int im1_cell;
- if (i_atom > 0) {
- const int im1 = i_atom - 1;
- im1_cell = sorted_cell_indices[im1];
- } else {
- im1_cell = 0;
- }
- if (icell != im1_cell || i_atom == 0) {
- cell_start[icell] = i_atom;
- if (i_atom > 0) {
- cell_end[im1_cell] = i_atom;
- }
- }
- if (i_atom == sorted_cell_indices.size(0) - 1) {
- cell_end[icell] = i_atom + 1;
- }
-}
-
-/*
- @brief Fills the cell_start and cell_end arrays, identifying the first and last atom in each cell
- @param sorted_cell_indices The cell indices of each position
- @param cell_dim The dimensions of the cell grid
- @return A tuple of cell_start and cell_end arrays
-*/
-static auto fillCellOffsets(const Tensor& sorted_cell_indices, int3 cell_dim) {
- const TensorOptions options = sorted_cell_indices.options();
- const int num_cells = cell_dim.x * cell_dim.y * cell_dim.z;
- const Tensor cell_start = full({num_cells}, -1, options.dtype(torch::kInt));
- const Tensor cell_end = empty({num_cells}, options.dtype(torch::kInt));
- const int threads = 128;
- const int blocks = (sorted_cell_indices.size(0) + threads - 1) / threads;
- auto stream = at::cuda::getCurrentCUDAStream();
-    fillCellOffsetsD<<<blocks, threads, 0, stream>>>(get_accessor<int32_t, 1>(sorted_cell_indices),
-                                                     get_accessor<int32_t, 1>(cell_start),
-                                                     get_accessor<int32_t, 1>(cell_end));
- return std::make_tuple(cell_start, cell_end);
-}
-
-/*
- @brief Get the cell index of the i'th neighboring cell for a given cell
- @param cell_i The cell coordinates
- @param i The index of the neighboring cell, from 0 to 26
- @param cell_dim The dimensions of the cell grid
- @return The cell index of the i'th neighboring cell
-*/
-__device__ int getNeighborCellIndex(int3 cell_i, int i, int3 cell_dim) {
- auto cell_j = cell_i;
- cell_j.x += i % 3 - 1;
- cell_j.y += (i / 3) % 3 - 1;
- cell_j.z += i / 9 - 1;
- cell_j = getPeriodicCell(cell_j, cell_dim);
- const int icellj = getCellIndex(cell_j, cell_dim);
- return icellj;
-}
-
-template <typename scalar_t> struct Particle {
-    int index;          // Index in the sorted arrays
-    int original_index; // Index in the original arrays
-    int batch;
-    scalar3<scalar_t> position;
- scalar_t cutoff_upper2, cutoff_lower2;
-};
-
-struct CellList {
- Tensor cell_start, cell_end;
- Tensor sorted_indices;
- Tensor sorted_positions, sorted_batch;
-};
-
-CellList constructCellList(const Tensor& positions, const Tensor& batch, const Tensor& box_size,
- const Scalar& cutoff) {
- // The algorithm for the cell list construction can be summarized in three separate steps:
- // 1. Label the particles according to the cell (bin) they lie in.
- // 2. Sort the particles using the cell index as the ordering label
- // (technically this is known as sorting by key). So that particles with positions
- // lying in the same cell become contiguous in memory.
- // 3. Identify where each cell starts and ends in the sorted particle positions
- // array.
- const TensorOptions options = positions.options();
- CellList cl;
- Tensor sorted_cell_indices;
- // Steps 1 and 2
- std::tie(cl.sorted_indices, sorted_cell_indices) =
- sortAtomsByCellIndex(positions, box_size, cutoff);
- cl.sorted_positions = positions.index_select(0, cl.sorted_indices);
- cl.sorted_batch = batch.index_select(0, cl.sorted_indices);
- // Step 3
- int3 cell_dim;
- AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "computeCellDim", [&] {
-        scalar_t cutoff_ = cutoff.to<scalar_t>();
-        scalar3<scalar_t> box_size_ = {box_size[0][0].item<scalar_t>(),
-                                       box_size[1][1].item<scalar_t>(),
-                                       box_size[2][2].item<scalar_t>()};
- cell_dim = getCellDimensions(box_size_, cutoff_);
- });
- std::tie(cl.cell_start, cl.cell_end) = fillCellOffsets(sorted_cell_indices, cell_dim);
- return cl;
-}
-
-template <typename scalar_t> struct CellListAccessor {
-    Accessor<int32_t, 1> cell_start, cell_end;
-    Accessor<int32_t, 1> sorted_indices;
-    Accessor<scalar_t, 2> sorted_positions;
-    Accessor<int64_t, 1> sorted_batch;
-
-    explicit CellListAccessor(const CellList& cl)
-        : cell_start(get_accessor<int32_t, 1>(cl.cell_start)),
-          cell_end(get_accessor<int32_t, 1>(cl.cell_end)),
-          sorted_indices(get_accessor<int32_t, 1>(cl.sorted_indices)),
-          sorted_positions(get_accessor<scalar_t, 2>(cl.sorted_positions)),
-          sorted_batch(get_accessor<int64_t, 1>(cl.sorted_batch)) {
- }
-};
-
-/*
- * @brief Add a pair of particles to the pair list. If necessary, also add the transpose pair.
- * @param list The pair list
- * @param i The index of the first particle
- * @param j The index of the second particle
- * @param distance2 The squared distance between the particles
- * @param delta The vector between the particles
- */
-template <typename scalar_t>
-__device__ void addNeighborPair(PairListAccessor<scalar_t>& list, const int i, const int j,
-                                scalar_t distance2, scalar3<scalar_t> delta) {
- const bool requires_transpose = list.include_transpose && (j != i);
- const int ni = max(i, j);
- const int nj = min(i, j);
- const scalar_t delta_sign = (ni == i) ? scalar_t(1.0) : scalar_t(-1.0);
- const scalar_t distance = sqrt_(distance2);
- delta = {delta_sign * delta.x, delta_sign * delta.y, delta_sign * delta.z};
- addAtomPairToList(list, ni, nj, delta, distance, requires_transpose);
-}
-
-/*
- * @brief Add to the pair list all neighbors of particle i_atom in cell j_cell
- * @param i_atom The Information of the particle for which we are adding neighbors
- * @param j_cell The index of the cell in which we are looking for neighbors
- * @param cl The cell list
- * @param box_size The box size
- * @param list The pair list
- */
-template <typename scalar_t>
-__device__ void addNeighborsForCell(const Particle<scalar_t>& i_atom, int j_cell,
-                                    const CellListAccessor<scalar_t>& cl,
-                                    scalar3<scalar_t> box_size, PairListAccessor<scalar_t>& list) {
- const auto first_particle = cl.cell_start[j_cell];
- if (first_particle != -1) { // Continue only if there are particles in this cell
- const auto last_particle = cl.cell_end[j_cell];
- for (int cur_j = first_particle; cur_j < last_particle; cur_j++) {
- const auto j_batch = cl.sorted_batch[cur_j];
- if ((j_batch == i_atom.batch) &&
- ((cur_j < i_atom.index) || (list.loop && cur_j == i_atom.index))) {
- const auto position_j = fetchPosition(cl.sorted_positions, cur_j);
- const auto delta = rect::compute_distance(i_atom.position, position_j,
- list.use_periodic, box_size);
- const scalar_t distance2 =
- delta.x * delta.x + delta.y * delta.y + delta.z * delta.z;
- if ( ((distance2 < i_atom.cutoff_upper2 && distance2 >= i_atom.cutoff_lower2)) ||
- (list.loop && cur_j == i_atom.index)) {
- const int orj = cl.sorted_indices[cur_j];
- addNeighborPair(list, i_atom.original_index, orj, distance2, delta);
- } // endif
- } // endif
- } // endfor
- } // endif
-}
-
-// Traverse the cell list for each atom and find the neighbors
-template <typename scalar_t>
-__global__ void traverseCellList(const CellListAccessor<scalar_t> cell_list,
-                                 PairListAccessor<scalar_t> list, int num_atoms,
-                                 scalar3<scalar_t> box_size, scalar_t cutoff_lower,
-                                 scalar_t cutoff_upper) {
- // Each atom traverses the cells around it and finds the neighbors
- // Atoms for all batches are placed in the same cell list, but other batches are ignored while
- // traversing
-    Particle<scalar_t> i_atom;
- i_atom.index = blockIdx.x * blockDim.x + threadIdx.x;
- if (i_atom.index >= num_atoms) {
- return;
- }
- i_atom.original_index = cell_list.sorted_indices[i_atom.index];
- i_atom.batch = cell_list.sorted_batch[i_atom.index];
- i_atom.position = fetchPosition(cell_list.sorted_positions, i_atom.index);
- i_atom.cutoff_lower2 = cutoff_lower * cutoff_lower;
- i_atom.cutoff_upper2 = cutoff_upper * cutoff_upper;
- const int3 cell_dim = getCellDimensions(box_size, cutoff_upper);
- const int3 cell_i = getCell(i_atom.position, box_size, cutoff_upper, cell_dim);
- // Loop over the 27 cells around the current cell
- for (int i = 0; i < 27; i++) {
- const int neighbor_cell = getNeighborCellIndex(cell_i, i, cell_dim);
- addNeighborsForCell(i_atom, neighbor_cell, cell_list, box_size, list);
- }
-}
-
-static std::tuple<Tensor, Tensor, Tensor, Tensor>
-forward_cell(const Tensor& positions, const Tensor& batch, const Tensor& in_box_size,
- bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
- const Scalar& max_num_pairs, bool loop, bool include_transpose) {
- // This module computes the pair list for a given set of particles, which may be in multiple
- // batches. The strategy is to first compute a cell list for all particles, and then
- // traverse the cell list for each particle to construct a pair list.
- checkInput(positions, batch);
- auto box_size = in_box_size.to("cpu");
- // If the box has dimensions (1, 3,3) squeeze it
- if (box_size.dim() == 3) {
- TORCH_CHECK(box_size.size(0) == 1 && box_size.size(1) == 3 && box_size.size(2) == 3,
- "Cell list does not support a box per sample. Expected \"box_size\" to have shape (1, 3, 3) or (3, 3)");
- box_size = box_size.squeeze(0);
- }
-
- TORCH_CHECK(box_size.dim() == 2, "Expected \"box_size\" to have two dimensions");
- TORCH_CHECK(box_size.size(0) == 3 && box_size.size(1) == 3,
- "Expected \"box_size\" to have shape (3, 3)");
-    TORCH_CHECK(box_size[0][1].item<double>() == 0 && box_size[0][2].item<double>() == 0 &&
-                    box_size[1][0].item<double>() == 0 && box_size[1][2].item<double>() == 0 &&
-                    box_size[2][0].item<double>() == 0 && box_size[2][1].item<double>() == 0,
- "Expected \"box_size\" to be diagonal");
- const auto max_num_pairs_ = max_num_pairs.toInt();
- TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive");
- const int num_atoms = positions.size(0);
- const auto cell_list = constructCellList(positions, batch, box_size, cutoff_upper);
- PairList list(max_num_pairs_, positions.options(), loop, include_transpose, use_periodic);
- const auto stream = getCurrentCUDAStream(positions.get_device());
- { // Traverse the cell list to find the neighbors
- const CUDAStreamGuard guard(stream);
- AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] {
-            const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
-            TORCH_CHECK(cutoff_upper_ > 0, "Expected cutoff_upper to be positive");
-            const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
-            const scalar3<scalar_t> box_size_ = {box_size[0][0].item<scalar_t>(),
-                                                 box_size[1][1].item<scalar_t>(),
-                                                 box_size[2][2].item<scalar_t>()};
- PairListAccessor list_accessor(list);
- CellListAccessor cell_list_accessor(cell_list);
- const int threads = 128;
- const int blocks = (num_atoms + threads - 1) / threads;
-            traverseCellList<<<blocks, threads, 0, stream>>>(cell_list_accessor, list_accessor,
-                                                             num_atoms, box_size_, cutoff_lower_,
-                                                             cutoff_upper_);
- });
- }
- return {list.neighbors, list.deltas, list.distances, list.i_curr_pair};
-}
-
-#endif
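The three steps spelled out in `constructCellList` have a compact torch analogue. A simplified sketch for a single batch and a cubic box of side `box_length` (hypothetical helper; the deleted CUDA version additionally handles batching and writes per-cell start/end offsets in a kernel):

```python
import torch

def cell_list(positions: torch.Tensor, box_length: float, cutoff: float):
    n = max(int(box_length / cutoff), 3)  # at least 3 cells per axis
    # 1. Label each atom with the flat index of the cell (bin) it lies in.
    frac = (positions / box_length) % 1.0  # wrap into [0, 1)
    cell_xyz = (frac * n).long().clamp(max=n - 1)
    strides = torch.tensor([1, n, n * n], device=positions.device)
    flat = (cell_xyz * strides).sum(-1)
    # 2. Sort atoms by cell index so each cell is contiguous in memory.
    sorted_cells, order = torch.sort(flat)
    # 3. Find where each cell starts in the sorted array.
    starts = torch.nonzero(sorted_cells[1:] != sorted_cells[:-1]).squeeze(-1) + 1
    return order, sorted_cells, starts
```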
diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh
deleted file mode 100644
index 950a3bb6b..000000000
--- a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
- * Distributed under the MIT License.
- *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
- * Raul P. Pelaez 2023. Shared memory neighbor list construction for CUDA.
- * This brute force approach checks all pairs of atoms by collaboratively loading and processing
- * tiles of atoms into shared memory.
- * This approach is typically slower than the brute force approach, but can handle an arbitrarily
- * large number of atoms.
- */
-#ifndef NEIGHBORS_SHARED_CUH
-#define NEIGHBORS_SHARED_CUH
-#include "common.cuh"
-#include <algorithm>
-
-template <typename scalar_t, int BLOCKSIZE>
-__global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor<scalar_t, 2> positions,
-                                      const Accessor<int64_t, 1> batch, scalar_t cutoff_lower2,
-                                      scalar_t cutoff_upper2, PairListAccessor<scalar_t> list,
-                                      int32_t num_tiles, triclinic::BoxAccessor<scalar_t> box) {
- // A thread per atom
- const int id = blockIdx.x * blockDim.x + threadIdx.x;
-    // All threads must pass through __syncthreads,
-    // but when N is not a multiple of 32 some threads are assigned a particle i >= N.
-    // These threads can't return early, so they are masked to not do any work.
- const bool active = id < num_atoms;
-    __shared__ scalar3<scalar_t> sh_pos[BLOCKSIZE];
-    __shared__ int64_t sh_batch[BLOCKSIZE];
-    scalar3<scalar_t> pos_i;
- int64_t batch_i;
- if (active) {
- pos_i = fetchPosition(positions, id);
- batch_i = batch[id];
- }
-    // Distribute the N particles in a group of tiles, storing blockDim.x values per tile in
-    // shared memory. This way all threads are accessing the same memory addresses at the same time
- for (int tile = 0; tile < num_tiles; tile++) {
- // Load this tiles particles values to shared memory
- const int i_load = tile * blockDim.x + threadIdx.x;
-        if (i_load < num_atoms) { // Even if this thread is not active, it may load a value to
-                                  // shared memory each tile.
- sh_pos[threadIdx.x] = fetchPosition(positions, i_load);
- sh_batch[threadIdx.x] = batch[i_load];
- }
- // Wait for all threads to arrive
- __syncthreads();
- // Go through all the particles in the current tile
-#pragma unroll 8
- for (int counter = 0; counter < blockDim.x; counter++) {
- if (!active)
- break; // An out of bounds thread must be masked
- const int cur_j = tile * blockDim.x + counter;
- const bool testPair = (cur_j < num_atoms) && (cur_j < id || (list.loop && cur_j == id));
- if (testPair) {
- const auto batch_j = sh_batch[counter];
- if (batch_i == batch_j) {
- const auto pos_j = sh_pos[counter];
- const auto box_i = box[batch_i];
- const auto delta =
- triclinic::compute_distance(pos_i, pos_j, list.use_periodic, box_i);
- const scalar_t distance2 =
- delta.x * delta.x + delta.y * delta.y + delta.z * delta.z;
- if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) {
- const bool requires_transpose = list.include_transpose && !(cur_j == id);
- const auto distance = sqrt_(distance2);
- addAtomPairToList(list, id, cur_j, delta, distance, requires_transpose);
- }
- }
- }
- }
- __syncthreads();
- }
-}
-
-static std::tuple<Tensor, Tensor, Tensor, Tensor>
-forward_shared(const Tensor& positions, const Tensor& batch, const Tensor& in_box_vectors,
- bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
- const Scalar& max_num_pairs, bool loop, bool include_transpose) {
- checkInput(positions, batch);
- const auto max_num_pairs_ = max_num_pairs.toLong();
- auto box_vectors = in_box_vectors.to(positions.device());
- if (box_vectors.dim() == 2) {
- // If the box is a 3x3 tensor it is assumed every sample has the same box
- if (use_periodic) {
- TORCH_CHECK(box_vectors.size(0) == 3 && box_vectors.size(1) == 3,
- "Expected \"box_vectors\" to have shape (n_batch, 3, 3)");
- }
- // Make the box (None,3,3), expand artificially to positions.size(0)
- box_vectors = box_vectors.unsqueeze(0);
- if (use_periodic) {
- // I use positions.size(0) because the batch dimension is not available here
- box_vectors = box_vectors.expand({positions.size(0), 3, 3});
- }
- }
- TORCH_CHECK(max_num_pairs_ > 0, "Expected \"max_num_neighbors\" to be positive");
- if (use_periodic) {
- TORCH_CHECK(box_vectors.dim() == 3, "Expected \"box_vectors\" to have three dimensions");
- TORCH_CHECK(box_vectors.size(1) == 3 && box_vectors.size(2) == 3,
- "Expected \"box_vectors\" to have shape (n_batch, 3, 3)");
- }
- const int num_atoms = positions.size(0);
- const int num_pairs = max_num_pairs_;
- const TensorOptions options = positions.options();
- const auto stream = getCurrentCUDAStream(positions.get_device());
- PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic);
- const CUDAStreamGuard guard(stream);
- AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() {
-        const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
-        const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
- auto box = triclinic::get_box_accessor(box_vectors, use_periodic);
- TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive");
- constexpr int BLOCKSIZE = 64;
- const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1);
- const int num_threads = BLOCKSIZE;
- const int num_tiles = num_blocks;
- PairListAccessor list_accessor(list);
-        forward_kernel_shared<scalar_t, BLOCKSIZE><<<num_blocks, num_threads, 0, stream>>>(
-            num_atoms, get_accessor<scalar_t, 2>(positions), get_accessor<int64_t, 1>(batch),
- cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor, num_tiles,
- box);
- });
- return {list.neighbors, list.deltas, list.distances, list.i_curr_pair};
-}
-
-#endif
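The tiling idea in the deleted shared-memory kernel (march over the atoms one BLOCK-sized slab at a time so every thread reuses the same loads) also caps peak memory in a pure-torch setting. An illustrative sketch, not part of the PR:

```python
import torch

def blocked_pairs(positions: torch.Tensor, cutoff: float, block: int = 64):
    """Yield (i, j) pairs with j < i and distance < cutoff, one tile at a time."""
    n = positions.size(0)
    for start in range(0, n, block):
        tile = positions[start:start + block]  # plays the role of shared memory
        dist = torch.cdist(positions, tile)    # (n, tile_size) distances
        i, j = torch.nonzero(dist < cutoff, as_tuple=True)
        j = j + start
        keep = j < i                           # lower triangle only
        yield i[keep], j[keep]
```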
diff --git a/torchmdnet/extensions/ops.py b/torchmdnet/extensions/ops.py
index 8a03c87b6..a9ebbce24 100644
--- a/torchmdnet/extensions/ops.py
+++ b/torchmdnet/extensions/ops.py
@@ -10,23 +10,21 @@
import torch
from torch import Tensor
from typing import Tuple
-from torch.library import register_fake
+import logging
+from torchmdnet.extensions.neighbors import torch_neighbor_bruteforce
+try:
+ import triton
+ from torchmdnet.extensions.triton_neighbors import triton_neighbor_pairs
-__all__ = ["is_current_stream_capturing", "get_neighbor_pairs_kernel"]
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
-def is_current_stream_capturing():
- """Returns True if the current CUDA stream is capturing.
+logger = logging.getLogger(__name__)
- Returns False if CUDA is not available or the current stream is not capturing.
-
- This utility is required because the builtin torch function that does this is not scriptable.
- """
- _is_current_stream_capturing = (
- torch.ops.torchmdnet_extensions.is_current_stream_capturing
- )
- return _is_current_stream_capturing()
+__all__ = ["get_neighbor_pairs_kernel"]
def get_neighbor_pairs_kernel(
@@ -40,6 +38,7 @@ def get_neighbor_pairs_kernel(
max_num_pairs: int,
loop: bool,
include_transpose: bool,
+ num_cells: int,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Computes the neighbor pairs for a given set of atomic positions.
The list is generated as a list of pairs (i,j) without any enforced ordering.
@@ -48,7 +47,7 @@ def get_neighbor_pairs_kernel(
Parameters
----------
strategy : str
- Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`.
+ Strategy to use for computing the neighbor list. Can be one of :code:`["brute", "cell"]`.
positions : Tensor
A tensor with shape (N, 3) representing the atomic positions.
batch : Tensor
@@ -67,7 +66,8 @@ def get_neighbor_pairs_kernel(
Whether to include self-interactions.
include_transpose : bool
Whether to include the transpose of the neighbor list (pair i,j and pair j,i).
-
+    num_cells : int
+        The number of cells in the grid when using the cell strategy.
+
Returns
-------
neighbors : Tensor
@@ -79,53 +79,46 @@ def get_neighbor_pairs_kernel(
num_pairs : Tensor
The number of pairs found.
"""
- return torch.ops.torchmdnet_extensions.get_neighbor_pairs(
+    if torch.jit.is_scripting() or not positions.is_cuda:
+        return torch_neighbor_bruteforce(
+ positions,
+ batch=batch,
+ box_vectors=box_vectors,
+ use_periodic=use_periodic,
+ cutoff_lower=cutoff_lower,
+ cutoff_upper=cutoff_upper,
+ max_num_pairs=max_num_pairs,
+ loop=loop,
+ include_transpose=include_transpose,
+ )
+
+ if not HAS_TRITON:
+ logger.warning(
+ "Triton is not available, using torch version of the neighbor pairs kernel."
+ )
+ return torch_neighbor_bruteforce(
+ positions,
+ batch=batch,
+ box_vectors=box_vectors,
+ use_periodic=use_periodic,
+ cutoff_lower=cutoff_lower,
+ cutoff_upper=cutoff_upper,
+ max_num_pairs=max_num_pairs,
+ loop=loop,
+ include_transpose=include_transpose,
+ )
+
+ return triton_neighbor_pairs(
strategy,
- positions,
- batch,
- box_vectors,
- use_periodic,
- cutoff_lower,
- cutoff_upper,
- max_num_pairs,
- loop,
- include_transpose,
+ positions=positions,
+ batch=batch,
+ box_vectors=box_vectors,
+ use_periodic=use_periodic,
+ cutoff_lower=cutoff_lower,
+ cutoff_upper=cutoff_upper,
+ max_num_pairs=max_num_pairs,
+ loop=loop,
+ include_transpose=include_transpose,
+ num_cells=num_cells,
)
-
-
-# Registers a FakeTensor kernel (aka "meta kernel", "abstract impl")
-# that describes what the properties of the output Tensor are given
-# the properties of the input Tensor. The FakeTensor kernel is necessary
-# for the op to work performantly with torch.compile.
-@torch.library.register_fake("torchmdnet_extensions::get_neighbor_pairs_bkwd")
-def _(
- grad_edge_vec: Tensor,
- grad_edge_weight: Tensor,
- edge_index: Tensor,
- edge_vec: Tensor,
- edge_weight: Tensor,
- num_atoms: int,
-):
- return torch.zeros((num_atoms, 3), dtype=edge_vec.dtype, device=edge_vec.device)
-
-
-@torch.library.register_fake("torchmdnet_extensions::get_neighbor_pairs_fwd")
-def _(
- strategy: str,
- positions: Tensor,
- batch: Tensor,
- box_vectors: Tensor,
- use_periodic: bool,
- cutoff_lower: float,
- cutoff_upper: float,
- max_num_pairs: int,
- loop: bool,
- include_transpose: bool,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
- """Returns empty vectors with the correct shape for the output of get_neighbor_pairs_kernel."""
- size = max_num_pairs
- edge_index = torch.empty((2, size), dtype=torch.int32, device=positions.device)
- edge_distance = torch.empty((size,), dtype=positions.dtype, device=positions.device)
- edge_vec = torch.empty((size, 3), dtype=positions.dtype, device=positions.device)
- num_pairs = torch.empty((1,), dtype=torch.int32, device=positions.device)
- return edge_index, edge_vec, edge_distance, num_pairs
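With the fake-tensor registrations gone, dispatch is now decided entirely in Python: CPU or TorchScript runs `torch_neighbor_bruteforce`, CUDA prefers the Triton kernel, and CUDA without Triton falls back to torch with a logged warning. A usage sketch of the rewritten entry point (import path as in this diff, values illustrative):

```python
import torch
from torchmdnet.extensions.ops import get_neighbor_pairs_kernel

pos = torch.rand(8, 3)
neighbors, deltas, distances, num_pairs = get_neighbor_pairs_kernel(
    "brute",
    pos,
    batch=torch.zeros(8, dtype=torch.long),
    box_vectors=torch.zeros(1, 3, 3),
    use_periodic=False,
    cutoff_lower=0.0,
    cutoff_upper=0.5,
    max_num_pairs=32,
    loop=False,
    include_transpose=True,
    num_cells=1,
)
```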
diff --git a/torchmdnet/extensions/torchmdnet_extensions.cpp b/torchmdnet/extensions/torchmdnet_extensions.cpp
deleted file mode 100644
index 992a3b178..000000000
--- a/torchmdnet/extensions/torchmdnet_extensions.cpp
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
- * Distributed under the MIT License.
- *(See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
- * Raul P. Pelaez 2023. Torch extensions to the torchmdnet library.
- * You can expose functions to python here which will be compatible with TorchScript.
- * Add your exports to the TORCH_LIBRARY macro below, see __init__.py to see how to access them from python.
- * The WITH_CUDA macro will be defined when compiling with CUDA support.
- */
-
-
-#include <Python.h>
-#include <torch/extension.h>
-#include <stdexcept>
-#include <string>
-
-#if defined(WITH_CUDA)
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda_runtime.h>
-#endif
-
-
-extern "C" {
-    /* Creates a dummy empty torchmdnet_extensions module that can be imported from Python.
-       Importing it from Python loads the .so built from this extension, so that the
-       TORCH_LIBRARY static initializers below are run. */
- PyObject* PyInit_torchmdnet_extensions(void)
- {
- static struct PyModuleDef module_def = {
- PyModuleDef_HEAD_INIT,
- "torchmdnet_extensions", /* name of module */
- NULL, /* module documentation, may be NULL */
- -1, /* size of per-interpreter state of the module,
- or -1 if the module keeps state in global variables. */
- NULL, /* methods */
- };
- return PyModule_Create(&module_def);
- }
-}
-
-/* @brief Returns true if the current torch CUDA stream is capturing.
- * This function is required because the one available in torch is not compatible with TorchScript.
- * @return True if the current torch CUDA stream is capturing.
- */
-bool is_current_stream_capturing() {
-#if defined(WITH_CUDA)
- auto current_stream = at::cuda::getCurrentCUDAStream().stream();
- cudaStreamCaptureStatus capture_status;
- cudaError_t err = cudaStreamGetCaptureInfo(current_stream, &capture_status, nullptr);
- if (err != cudaSuccess) {
- throw std::runtime_error(cudaGetErrorString(err));
- }
- return capture_status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive;
-#else
- return false;
-#endif
-}
-
-TORCH_LIBRARY(torchmdnet_extensions, m) {
- m.def("is_current_stream_capturing", is_current_stream_capturing);
- m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, "
- "bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool "
- "loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor "
- "distance_vecs, Tensor num_pairs)");
- //The individual fwd and bkwd functions must be exposed in order to register their meta implementations python side.
- m.def("get_neighbor_pairs_fwd(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, "
- "bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool "
- "loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor "
- "distance_vecs, Tensor num_pairs)");
- m.def("get_neighbor_pairs_bkwd(Tensor grad_edge_vec, Tensor grad_edge_weight, Tensor edge_index, "
- "Tensor edge_vec, Tensor edge_weight, int num_atoms) -> Tensor");
-}
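
For context, schemas registered through TORCH_LIBRARY, as above, were reachable from Python via the torch.ops namespace once the extension had been imported; a rough sketch of the old call path (tensor values are illustrative):

import torch
import torchmdnet.extensions  # importing the package loaded the compiled .so

pos = torch.rand(10, 3)
batch = torch.zeros(10, dtype=torch.long)
box = torch.zeros(3, 3)
neighbors, distances, distance_vecs, num_pairs = (
    torch.ops.torchmdnet_extensions.get_neighbor_pairs(
        "brute", pos, batch, box, False, 0.0, 5.0, 100, False, True
    )
)
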
diff --git a/torchmdnet/extensions/triton_brute.py b/torchmdnet/extensions/triton_brute.py
new file mode 100644
index 000000000..6a6e19019
--- /dev/null
+++ b/torchmdnet/extensions/triton_brute.py
@@ -0,0 +1,255 @@
+# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
+# Distributed under the MIT License.
+# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
+import triton
+import triton.language as tl
+from torch import Tensor
+import torch
+from torchmdnet.extensions.triton_neighbors import _tl_round, TritonNeighborAutograd
+from torch.library import triton_op, wrap_triton
+from typing import Tuple
+
+
+@triton.jit
+def _neighbor_brute_kernel(
+ pos_ptr,
+ batch_ptr,
+ box_ptr,
+ neighbors0_ptr,
+ neighbors1_ptr,
+ deltas_ptr,
+ distances_ptr,
+ counter_ptr,
+ box_batch_stride,
+ n_atoms,
+ num_all_pairs,
+ use_periodic: tl.constexpr,
+ include_transpose: tl.constexpr,
+ loop: tl.constexpr,
+ max_pairs,
+ cutoff_lower,
+ cutoff_upper,
+ BLOCK: tl.constexpr,
+):
+ """Brute-force neighbor list kernel using triangular indexing and atomic compaction.
+
+ Uses triangular indexing to iterate over only n*(n-1)/2 pairs (or n*(n+1)/2 with loop),
+ achieving 100% thread utilization while maintaining block-level atomic compaction.
+ """
+ pid = tl.program_id(axis=0)
+ start = pid * BLOCK
+ idx = start + tl.arange(0, BLOCK)
+
+ valid = idx < num_all_pairs
+
+ # Convert linear index to (i, j) using triangular formula (same as CUDA get_row)
+ # Do integer arithmetic first, only convert to float for sqrt
+ if loop:
+ # With self-loops: j <= i, num_pairs = n*(n+1)/2
+ # row = floor((-1 + sqrt(1 + 8k)) / 2)
+ sqrt_arg = (1 + 8 * idx).to(tl.float32)
+ row_f = tl.math.floor((-1.0 + tl.math.sqrt(sqrt_arg)) * 0.5)
+ i = row_f.to(tl.int32)
+ # Handle floating-point edge case: if i*(i+1)/2 > idx, decrement i
+ i = tl.where(i * (i + 1) > 2 * idx, i - 1, i)
+ # col = k - row*(row+1)/2
+ j = idx - (i * (i + 1)) // 2
+ else:
+ # Without self-loops: j < i, num_pairs = n*(n-1)/2
+ # row = floor((1 + sqrt(1 + 8k)) / 2)
+ sqrt_arg = (1 + 8 * idx).to(tl.float32)
+ row_f = tl.math.floor((1.0 + tl.math.sqrt(sqrt_arg)) * 0.5)
+ i = row_f.to(tl.int32)
+ # Handle floating-point edge case (same correction as CUDA)
+ i = tl.where(i * (i - 1) > 2 * idx, i - 1, i)
+ # col = k - row*(row-1)/2
+ j = idx - (i * (i - 1)) // 2
+
+ # Validate indices: check bounds and triangular constraint
+ # Due to float precision, we may get invalid (i, j) pairs
+ valid = valid & (i >= 0) & (i < n_atoms) & (j >= 0) & (j <= i)
+ if not loop:
+ # For non-loop case, also require j < i (no self-loops)
+ valid = valid & (j < i)
+
+ batch_i = tl.load(batch_ptr + i, mask=valid, other=0)
+ batch_j = tl.load(batch_ptr + j, mask=valid, other=0)
+ valid = valid & (batch_i == batch_j)
+
+ pos_ix = tl.load(pos_ptr + i * 3 + 0, mask=valid, other=0.0)
+ pos_iy = tl.load(pos_ptr + i * 3 + 1, mask=valid, other=0.0)
+ pos_iz = tl.load(pos_ptr + i * 3 + 2, mask=valid, other=0.0)
+ pos_jx = tl.load(pos_ptr + j * 3 + 0, mask=valid, other=0.0)
+ pos_jy = tl.load(pos_ptr + j * 3 + 1, mask=valid, other=0.0)
+ pos_jz = tl.load(pos_ptr + j * 3 + 2, mask=valid, other=0.0)
+
+ dx = pos_ix - pos_jx
+ dy = pos_iy - pos_jy
+ dz = pos_iz - pos_jz
+
+ if use_periodic:
+ box_base = box_ptr + batch_i * box_batch_stride
+
+ b20 = tl.load(box_base + 2 * 3 + 0, mask=valid, other=0.0)
+ b21 = tl.load(box_base + 2 * 3 + 1, mask=valid, other=0.0)
+ b22 = tl.load(box_base + 2 * 3 + 2, mask=valid, other=1.0)
+ b10 = tl.load(box_base + 1 * 3 + 0, mask=valid, other=0.0)
+ b11 = tl.load(box_base + 1 * 3 + 1, mask=valid, other=1.0)
+ b00 = tl.load(box_base + 0 * 3 + 0, mask=valid, other=1.0)
+
+ scale3 = _tl_round(dz / b22)
+ dx = dx - scale3 * b20
+ dy = dy - scale3 * b21
+ dz = dz - scale3 * b22
+ scale2 = _tl_round(dy / b11)
+ dx = dx - scale2 * b10
+ dy = dy - scale2 * b11
+ scale1 = _tl_round(dx / b00)
+ dx = dx - scale1 * b00
+
+ dist2 = dx * dx + dy * dy + dz * dz
+ dist = tl.sqrt(dist2)
+ # Self-loops (i == j) are exempt from cutoff_lower since they have distance 0
+ is_self_loop = i == j
+ valid = valid & (dist < cutoff_upper) & ((dist >= cutoff_lower) | is_self_loop)
+
+ valid_int = valid.to(tl.int32)
+ local_idx = tl.cumsum(valid_int, axis=0) - 1
+ total_pairs = tl.sum(valid_int, axis=0)
+
+ # For transpose, don't count self-loops (i == j)
+ if include_transpose:
+ is_not_self_loop = i != j
+ valid_transpose = valid & is_not_self_loop
+ total_transpose = tl.sum(valid_transpose.to(tl.int32), axis=0)
+ total_out = total_pairs + total_transpose
+ else:
+ total_out = total_pairs
+
+ has_work = total_out > 0
+ start_idx = tl.atomic_add(counter_ptr, total_out, mask=has_work)
+ start_idx = tl.where(has_work, start_idx, 0)
+ start_idx = tl.broadcast_to(start_idx, local_idx.shape)
+
+ write_idx = start_idx + local_idx
+ mask_store = valid & (write_idx < max_pairs)
+ tl.store(neighbors0_ptr + write_idx, i, mask=mask_store)
+ tl.store(neighbors1_ptr + write_idx, j, mask=mask_store)
+ tl.store(deltas_ptr + write_idx * 3 + 0, dx, mask=mask_store)
+ tl.store(deltas_ptr + write_idx * 3 + 1, dy, mask=mask_store)
+ tl.store(deltas_ptr + write_idx * 3 + 2, dz, mask=mask_store)
+ tl.store(distances_ptr + write_idx, dist, mask=mask_store)
+
+ if include_transpose:
+ # Don't add transpose for self-loops (i == j)
+ is_not_self_loop = i != j
+ valid_transpose = valid & is_not_self_loop
+ valid_t_int = valid_transpose.to(tl.int32)
+ local_idx_t = tl.cumsum(valid_t_int, axis=0) - 1
+ total_pairs_t = tl.sum(valid_t_int, axis=0)
+
+ write_idx_t = start_idx + total_pairs + local_idx_t
+ mask_store_t = valid_transpose & (write_idx_t < max_pairs)
+ tl.store(neighbors0_ptr + write_idx_t, j, mask=mask_store_t)
+ tl.store(neighbors1_ptr + write_idx_t, i, mask=mask_store_t)
+ tl.store(deltas_ptr + write_idx_t * 3 + 0, -dx, mask=mask_store_t)
+ tl.store(deltas_ptr + write_idx_t * 3 + 1, -dy, mask=mask_store_t)
+ tl.store(deltas_ptr + write_idx_t * 3 + 2, -dz, mask=mask_store_t)
+ tl.store(distances_ptr + write_idx_t, dist, mask=mask_store_t)
+
+
+@triton_op("torchmdnet::triton_neighbor_bruteforce", mutates_args={})
+def triton_neighbor_bruteforce(
+ positions: Tensor,
+ batch: Tensor,
+ box_vectors: Tensor,
+ use_periodic: bool,
+ cutoff_lower: float,
+ cutoff_upper: float,
+ max_num_pairs: int,
+ loop: bool,
+ include_transpose: bool,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ device = positions.device
+ dtype = positions.dtype
+ n_atoms = positions.size(0)
+
+ batch = batch.contiguous()
+ positions = positions.contiguous()
+ if use_periodic:
+ if box_vectors.dim() == 2:
+ box_vectors = box_vectors.unsqueeze(0)
+ elif box_vectors.dim() != 3:
+ raise ValueError('Expected "box_vectors" to have shape (n_batch, 3, 3)')
+ box_vectors = box_vectors.to(device=device, dtype=dtype)
+ box_vectors = box_vectors.contiguous()
+ # Use stride 0 to broadcast single box to all batches (avoids CPU sync)
+ box_batch_stride = 0 if box_vectors.size(0) == 1 else 9
+
+ neighbors = torch.full((2, max_num_pairs), -1, device=device, dtype=torch.long)
+ deltas = torch.zeros((max_num_pairs, 3), device=device, dtype=dtype)
+ distances = torch.zeros((max_num_pairs,), device=device, dtype=dtype)
+ num_pairs = torch.zeros((1,), device=device, dtype=torch.int32)
+
+ # Compute triangular pair count: n*(n-1)/2 without self-loops, n*(n+1)/2 with
+ if loop:
+ num_all_pairs = n_atoms * (n_atoms + 1) // 2
+ else:
+ num_all_pairs = n_atoms * (n_atoms - 1) // 2
+
+ # Grid covers only triangular pairs (not n*n)
+ grid = lambda meta: (triton.cdiv(num_all_pairs, meta["BLOCK"]),)
+
+ wrap_triton(_neighbor_brute_kernel)[grid](
+ positions,
+ batch,
+ box_vectors if use_periodic else positions, # dummy pointer if not periodic
+ neighbors[0],
+ neighbors[1],
+ deltas,
+ distances,
+ num_pairs,
+ box_batch_stride if use_periodic else 0,
+ n_atoms,
+ num_all_pairs,
+ use_periodic,
+ include_transpose,
+ loop,
+ max_num_pairs,
+ cutoff_lower,
+ cutoff_upper,
+ BLOCK=256,
+ )
+
+ return neighbors, deltas, distances, num_pairs
+
+
+class TritonBruteNeighborAutograd(TritonNeighborAutograd):
+ @staticmethod
+ def forward( # type: ignore[override]
+ ctx,
+ positions: Tensor,
+ batch: Tensor,
+ box_vectors: Tensor,
+ use_periodic: bool,
+ cutoff_lower: float,
+ cutoff_upper: float,
+ max_num_pairs: int,
+ loop: bool,
+ include_transpose: bool,
+ ):
+ neighbors, deltas, distances, num_pairs = triton_neighbor_bruteforce(
+ positions,
+ batch,
+ box_vectors,
+ use_periodic,
+ cutoff_lower,
+ cutoff_upper,
+ max_num_pairs,
+ loop,
+ include_transpose,
+ )
+
+ ctx.save_for_backward(neighbors, deltas, distances)
+ ctx.num_atoms = positions.size(0)
+ return neighbors, deltas, distances, num_pairs
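
The float-based triangular decode in the kernel above is easy to spot-check outside Triton by replaying the same math in plain PyTorch; a minimal sketch (the helper name tri_decode is ours):

import torch

def tri_decode(k: torch.Tensor, loop: bool):
    # Same float32 sqrt plus integer off-by-one correction as the kernel.
    kf = (1 + 8 * k).to(torch.float32)
    if loop:  # j <= i, with k = i*(i+1)/2 + j
        i = torch.floor((torch.sqrt(kf) - 1.0) * 0.5).long()
        i = torch.where(i * (i + 1) > 2 * k, i - 1, i)
        j = k - i * (i + 1) // 2
    else:  # j < i, with k = i*(i-1)/2 + j
        i = torch.floor((torch.sqrt(kf) + 1.0) * 0.5).long()
        i = torch.where(i * (i - 1) > 2 * k, i - 1, i)
        j = k - i * (i - 1) // 2
    return i, j

n = 500
k = torch.arange(n * (n - 1) // 2)
i, j = tri_decode(k, loop=False)
assert bool((j < i).all()) and bool((k == i * (i - 1) // 2 + j).all())
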
diff --git a/torchmdnet/extensions/triton_cell.py b/torchmdnet/extensions/triton_cell.py
new file mode 100644
index 000000000..2586fdfcb
--- /dev/null
+++ b/torchmdnet/extensions/triton_cell.py
@@ -0,0 +1,474 @@
+# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
+# Distributed under the MIT License.
+# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
+import triton
+import triton.language as tl
+import torch
+from torch import Tensor
+from typing import Tuple
+from torchmdnet.extensions.triton_neighbors import TritonNeighborAutograd
+from torch.library import triton_op, wrap_triton
+
+
+def _get_cell_dimensions(
+    box_x: Tensor,
+    box_y: Tensor,
+    box_z: Tensor,
+    cutoff_upper: float,
+) -> Tensor:
+    """Number of cells along each box axis, clamped to a minimum of 3 (stays on the GPU)."""
+    nx = torch.floor(box_x / cutoff_upper).clamp(min=3).long()
+    ny = torch.floor(box_y / cutoff_upper).clamp(min=3).long()
+    nz = torch.floor(box_z / cutoff_upper).clamp(min=3).long()
+    return torch.stack([nx, ny, nz])
+
+
+@triton.jit
+def _tl_round(x):
+ return tl.where(x >= 0, tl.math.floor(x + 0.5), tl.math.ceil(x - 0.5))
+
+
+@triton.jit
+def cell_neighbor_kernel(
+ # Cell data structure (1D sorted approach)
+ SortedIndices, # [n_atoms] - original atom indices, sorted by cell
+ SortedPositions, # [n_atoms, 3] - positions sorted by cell (for coalesced access)
+ SortedBatch, # [n_atoms] - batch indices sorted by cell
+ CellStart, # [num_cells] - start index in sorted arrays for each cell
+ CellEnd, # [num_cells] - end index (exclusive) in sorted arrays for each cell
+ # Box parameters
+ BoxSizes, # [3] - box dimensions
+ CellDims, # [3] - number of cells in each dimension (int32)
+ # Output
+ OutPairs, # [2, max_pairs]
+ OutDeltas, # [max_pairs, 3]
+ OutDists, # [max_pairs]
+ GlobalCounter, # [1]
+ # Scalar parameters
+ max_pairs,
+ cutoff_lower_sq,
+ cutoff_upper_sq,
+ # Flags
+ use_periodic: tl.constexpr,
+ loop: tl.constexpr,
+ include_transpose: tl.constexpr,
+ # Batch size for vectorized processing
+ BATCH_SIZE: tl.constexpr, # e.g., 32 -> processes 32×32=1024 pairs per iteration
+):
+ """
+ Each program processes one cell (the "home" cell).
+ Uses 1D sorted array with cell_start/cell_end pointers.
+
+ Vectorized batched processing:
+ - Loads BATCH_SIZE atoms at a time for both home and neighbor
+ - Computes BATCH_SIZE × BATCH_SIZE distance matrix per iteration
+ - Uses while loops to iterate only over actual atoms
+ - Minimal waste: only last partial batch may have masked elements
+
+ To avoid double-counting:
+ - For half list (include_transpose=False): only emit pairs where home_atom > neighbor_atom
+ - For full list (include_transpose=True): emit both directions
+ """
+ home_cell_id = tl.program_id(0)
+
+ # Load box and cell dimensions
+ box_x = tl.load(BoxSizes + 0)
+ box_y = tl.load(BoxSizes + 1)
+ box_z = tl.load(BoxSizes + 2)
+
+ num_cells_x = tl.load(CellDims + 0)
+ num_cells_y = tl.load(CellDims + 1)
+ num_cells_z = tl.load(CellDims + 2)
+
+ # Decompose home cell ID into 3D coordinates
+ cells_yz = num_cells_y * num_cells_z
+ home_cx = home_cell_id // cells_yz
+ home_cy = (home_cell_id % cells_yz) // num_cells_z
+ home_cz = home_cell_id % num_cells_z
+
+ # Load home cell boundaries
+ home_start = tl.load(CellStart + home_cell_id)
+ home_end = tl.load(CellEnd + home_cell_id)
+
+ # Loop over 27 neighbor cells
+ for neighbor_offset in tl.range(0, 27):
+ # Decompose neighbor_offset into di, dj, dk (each in {-1, 0, 1})
+ di = (neighbor_offset % 3) - 1
+ dj = ((neighbor_offset // 3) % 3) - 1
+ dk = (neighbor_offset // 9) - 1
+
+ # Compute neighbor cell coordinates
+ ni = home_cx + di
+ nj = home_cy + dj
+ nk = home_cz + dk
+
+ # Handle boundary conditions
+ if use_periodic:
+ ni = (ni + num_cells_x) % num_cells_x
+ nj = (nj + num_cells_y) % num_cells_y
+ nk = (nk + num_cells_z) % num_cells_z
+ cell_valid = True
+ else:
+ cell_valid = (
+ (ni >= 0)
+ & (ni < num_cells_x)
+ & (nj >= 0)
+ & (nj < num_cells_y)
+ & (nk >= 0)
+ & (nk < num_cells_z)
+ )
+
+ neighbor_cell_id = ni * cells_yz + nj * num_cells_z + nk
+
+ # Load neighbor cell boundaries
+ neighbor_start = tl.load(CellStart + neighbor_cell_id)
+ neighbor_end = tl.load(CellEnd + neighbor_cell_id)
+
+ # If cell is invalid (non-periodic boundary), make it empty
+ neighbor_start = tl.where(cell_valid, neighbor_start, 0)
+ neighbor_end = tl.where(cell_valid, neighbor_end, 0)
+
+ # Batched iteration over home atoms
+ home_batch_start = home_start
+ while home_batch_start < home_end:
+ # Load BATCH_SIZE home atoms
+ home_offsets = tl.arange(0, BATCH_SIZE)
+ home_global_idx = home_batch_start + home_offsets
+ home_mask = home_global_idx < home_end
+
+ # Load home atom original indices (for output pair indices)
+ home_atoms = tl.load(
+ SortedIndices + home_global_idx, mask=home_mask, other=0
+ )
+
+ # Load home atom positions (sequential access - coalesced!)
+ home_x = tl.load(
+ SortedPositions + home_global_idx * 3 + 0, mask=home_mask, other=0.0
+ )
+ home_y = tl.load(
+ SortedPositions + home_global_idx * 3 + 1, mask=home_mask, other=0.0
+ )
+ home_z = tl.load(
+ SortedPositions + home_global_idx * 3 + 2, mask=home_mask, other=0.0
+ )
+ home_batch = tl.load(
+ SortedBatch + home_global_idx, mask=home_mask, other=-1
+ )
+
+ # Batched iteration over neighbor atoms
+ neighbor_batch_start = neighbor_start
+ while neighbor_batch_start < neighbor_end:
+ # Load BATCH_SIZE neighbor atoms
+ neighbor_offsets = tl.arange(0, BATCH_SIZE)
+ neighbor_global_idx = neighbor_batch_start + neighbor_offsets
+ neighbor_mask = neighbor_global_idx < neighbor_end
+
+ # Load neighbor atom original indices (for output pair indices)
+ neighbor_atoms = tl.load(
+ SortedIndices + neighbor_global_idx, mask=neighbor_mask, other=0
+ )
+
+ # Load neighbor atom positions (sequential access - coalesced!)
+ neighbor_x = tl.load(
+ SortedPositions + neighbor_global_idx * 3 + 0,
+ mask=neighbor_mask,
+ other=0.0,
+ )
+ neighbor_y = tl.load(
+ SortedPositions + neighbor_global_idx * 3 + 1,
+ mask=neighbor_mask,
+ other=0.0,
+ )
+ neighbor_z = tl.load(
+ SortedPositions + neighbor_global_idx * 3 + 2,
+ mask=neighbor_mask,
+ other=0.0,
+ )
+ neighbor_batch_vals = tl.load(
+ SortedBatch + neighbor_global_idx, mask=neighbor_mask, other=-2
+ )
+
+ # Compute pairwise distances: [BATCH_SIZE, BATCH_SIZE]
+ dx = home_x[:, None] - neighbor_x[None, :]
+ dy = home_y[:, None] - neighbor_y[None, :]
+ dz = home_z[:, None] - neighbor_z[None, :]
+
+ # Apply PBC
+ if use_periodic:
+ dx = dx - box_x * _tl_round(dx / box_x)
+ dy = dy - box_y * _tl_round(dy / box_y)
+ dz = dz - box_z * _tl_round(dz / box_z)
+
+ dist_sq = dx * dx + dy * dy + dz * dz
+
+ # Build validity mask
+ # 1. Distance within cutoff
+ cond_dist = (dist_sq < cutoff_upper_sq) & (dist_sq >= cutoff_lower_sq)
+
+ # 2. Same batch
+ cond_batch = home_batch[:, None] == neighbor_batch_vals[None, :]
+
+ # 3. Index ordering to avoid double-counting
+ home_atoms_bc = home_atoms[:, None]
+ neighbor_atoms_bc = neighbor_atoms[None, :]
+
+ if include_transpose:
+ if loop:
+ cond_idx = True
+ else:
+ cond_idx = home_atoms_bc != neighbor_atoms_bc
+ else:
+ if loop:
+ cond_idx = home_atoms_bc >= neighbor_atoms_bc
+ else:
+ cond_idx = home_atoms_bc > neighbor_atoms_bc
+
+ # 4. Both atoms must be valid (within actual cell bounds)
+ cond_valid = home_mask[:, None] & neighbor_mask[None, :]
+
+ # Combined validity
+ valid_mask = cond_dist & cond_batch & cond_idx & cond_valid
+
+ # Count and store valid pairs
+ num_found = tl.sum(valid_mask.to(tl.int32))
+
+ if num_found > 0:
+ # Atomically reserve space in output
+ current_offset = tl.atomic_add(GlobalCounter, num_found)
+
+ if current_offset + num_found <= max_pairs:
+ # Compute storage indices using cumsum
+ flat_mask = tl.ravel(valid_mask)
+ csum = tl.cumsum(flat_mask.to(tl.int32), axis=0)
+ store_idx = current_offset + csum - 1
+
+ # Prepare flattened data
+ flat_home = tl.ravel(
+ tl.broadcast_to(
+ home_atoms[:, None], (BATCH_SIZE, BATCH_SIZE)
+ )
+ )
+ flat_neighbor = tl.ravel(
+ tl.broadcast_to(
+ neighbor_atoms[None, :], (BATCH_SIZE, BATCH_SIZE)
+ )
+ )
+ flat_dx = tl.ravel(dx)
+ flat_dy = tl.ravel(dy)
+ flat_dz = tl.ravel(dz)
+ flat_dist = tl.sqrt(tl.ravel(dist_sq))
+
+ # Store pairs
+ tl.store(
+ OutPairs + 0 * max_pairs + store_idx,
+ flat_home,
+ mask=flat_mask,
+ )
+ tl.store(
+ OutPairs + 1 * max_pairs + store_idx,
+ flat_neighbor,
+ mask=flat_mask,
+ )
+
+ # Store deltas
+ tl.store(OutDeltas + store_idx * 3 + 0, flat_dx, mask=flat_mask)
+ tl.store(OutDeltas + store_idx * 3 + 1, flat_dy, mask=flat_mask)
+ tl.store(OutDeltas + store_idx * 3 + 2, flat_dz, mask=flat_mask)
+
+ # Store distances
+ tl.store(OutDists + store_idx, flat_dist, mask=flat_mask)
+
+ neighbor_batch_start += BATCH_SIZE
+ home_batch_start += BATCH_SIZE
+
+
+def build_cell_list(
+ positions: Tensor,
+ batch: Tensor,
+ box_sizes: Tensor, # [3] diagonal elements
+ use_periodic: bool,
+ cell_dims: Tensor, # [3] number of cells in each dimension
+ num_cells: int, # total number of cells (fixed for CUDA graphs)
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """
+ Build the cell list data structure using 1D sorted arrays.
+
+ Args:
+ positions: [N, 3] atom positions
+ batch: [N] batch indices
+ box_sizes: [3] box diagonal elements
+ use_periodic: whether to use periodic boundary conditions
+ cell_dims: [3] number of cells in each dimension (pre-computed)
+ num_cells: total number of cells (pre-computed, fixed for CUDA graphs)
+
+ Returns:
+ sorted_indices: [n_atoms] - original atom indices, sorted by cell
+ sorted_positions: [n_atoms, 3] - positions sorted by cell (for coalesced access)
+ sorted_batch: [n_atoms] - batch indices sorted by cell
+ cell_start: [num_cells] - start index in sorted arrays for each cell
+ cell_end: [num_cells] - end index (exclusive) in sorted arrays for each cell
+ """
+ device = positions.device
+ n_atoms = positions.size(0)
+
+ # Compute cell index for each atom
+ if use_periodic:
+ # Wrap to [0, box)
+ inv_box = 1.0 / box_sizes
+ wrapped = positions - torch.floor(positions * inv_box) * box_sizes
+ else:
+ # Shift by half box (like CUDA implementation)
+ wrapped = positions + 0.5 * box_sizes
+
+ # Cell coordinates
+ cell_size = box_sizes / cell_dims.float()
+ cell_coords = (wrapped / cell_size).long()
+ cell_coords = torch.clamp(
+ cell_coords, min=torch.zeros(3, device=device), max=cell_dims - 1
+ )
+
+ # Flat cell index
+ cell_idx = (
+ cell_coords[:, 0] * (cell_dims[1] * cell_dims[2])
+ + cell_coords[:, 1] * cell_dims[2]
+ + cell_coords[:, 2]
+ ).long()
+
+ # Sort atoms by cell index
+ sorted_cell_idx, sort_order = torch.sort(cell_idx)
+ sorted_indices = sort_order.int() # Original atom indices, now sorted by cell
+
+ # Create sorted positions and batch for coalesced memory access
+ sorted_positions = positions.index_select(0, sort_order).contiguous()
+ sorted_batch = batch.index_select(0, sort_order).contiguous()
+
+ # Count atoms per cell
+ cell_counts = torch.zeros(num_cells, dtype=torch.int32, device=device)
+ cell_counts.scatter_add_(
+ 0, cell_idx, torch.ones(n_atoms, dtype=torch.int32, device=device)
+ )
+
+ # Compute cell_start and cell_end using cumsum
+ cell_end = torch.cumsum(cell_counts, dim=0).int()
+ cell_start = torch.zeros(num_cells, dtype=torch.int32, device=device)
+ cell_start[1:] = cell_end[:-1]
+
+ return sorted_indices, sorted_positions, sorted_batch, cell_start, cell_end
+
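
The sorted layout produced by build_cell_list is the classic compacted cell list: after sorting by cell id, each cell's atoms occupy one contiguous slice of the sorted arrays. A toy CPU illustration of the same bookkeeping (values are arbitrary):

import torch

cell_idx = torch.tensor([2, 0, 2, 1, 0])                    # cell of each of 5 atoms
sorted_cells, order = torch.sort(cell_idx, stable=True)      # sorted_cells = [0, 0, 1, 2, 2]
counts = torch.zeros(3, dtype=torch.long).scatter_add_(
    0, sorted_cells, torch.ones(5, dtype=torch.long)
)
cell_end = torch.cumsum(counts, dim=0)                       # [2, 3, 5]
cell_start = torch.cat([torch.zeros(1, dtype=torch.long), cell_end[:-1]])
assert order[cell_start[2] : cell_end[2]].tolist() == [0, 2]  # atoms in cell 2
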
+
+@triton_op("torchmdnet::triton_neighbor_cell", mutates_args={})
+def triton_neighbor_cell(
+ positions: Tensor,
+ batch: Tensor,
+ box_vectors: Tensor,
+ use_periodic: bool,
+ cutoff_lower: float,
+ cutoff_upper: float,
+ max_num_pairs: int,
+ loop: bool,
+ include_transpose: bool,
+ num_cells: int,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ device = positions.device
+ dtype = positions.dtype
+ n_atoms = positions.size(0)
+
+ # Validate inputs
+ if positions.dim() != 2 or positions.size(1) != 3:
+ raise ValueError('Expected "positions" to have shape (N, 3)')
+ if batch.dim() != 1 or batch.size(0) != n_atoms:
+ raise ValueError('Expected "batch" to have shape (N,)')
+
+ # Extract box diagonal
+ box_vectors = box_vectors.contiguous()
+    if box_vectors.dim() == 3:
+        # The kernel reads a single box, so a batched input must share one box
+        box_diag = box_vectors[0]
+    else:
+        box_diag = box_vectors
+ box_sizes = torch.stack(
+ [box_diag[0, 0], box_diag[1, 1], box_diag[2, 2]]
+ ).contiguous()
+
+ # Compute cell dimensions using shared utility (stays on GPU)
+ cell_dims = _get_cell_dimensions(
+ box_sizes[0], box_sizes[1], box_sizes[2], cutoff_upper
+ )
+
+ # Build cell list (1D sorted approach with sorted positions for coalesced access)
+ sorted_indices, sorted_positions, sorted_batch, cell_start, cell_end = (
+ build_cell_list(positions, batch, box_sizes, use_periodic, cell_dims, num_cells)
+ )
+
+ # Allocate outputs
+ neighbors = torch.full((2, max_num_pairs), -1, device=device, dtype=torch.long)
+ deltas = torch.zeros((max_num_pairs, 3), device=device, dtype=dtype)
+ distances = torch.zeros((max_num_pairs,), device=device, dtype=dtype)
+ num_pairs = torch.zeros((1,), device=device, dtype=torch.int32)
+
+ # Launch kernel: one program per cell
+ # BATCH_SIZE: process atoms in batches for vectorized compute
+    # 32 is a good balance: 32×32=1024 elements fit in registers, with minimal waste on partial batches
+ BATCH_SIZE = 32
+
+ grid = (num_cells,)
+ wrap_triton(cell_neighbor_kernel)[grid](
+ sorted_indices,
+ sorted_positions,
+ sorted_batch,
+ cell_start,
+ cell_end,
+ box_sizes,
+ cell_dims,
+ neighbors,
+ deltas,
+ distances,
+ num_pairs,
+ max_num_pairs,
+ cutoff_lower**2,
+ cutoff_upper**2,
+ use_periodic=use_periodic,
+ loop=loop,
+ include_transpose=include_transpose,
+ BATCH_SIZE=BATCH_SIZE,
+ )
+ return neighbors, deltas, distances, num_pairs
+
+
+class TritonCellNeighborAutograd(TritonNeighborAutograd):
+ @staticmethod
+ def forward(
+ ctx,
+ positions: Tensor,
+ batch: Tensor,
+ box_vectors: Tensor,
+ use_periodic: bool,
+ cutoff_lower: float,
+ cutoff_upper: float,
+ max_num_pairs: int,
+ loop: bool,
+ include_transpose: bool,
+ num_cells: int,
+ ):
+ neighbors, deltas, distances, num_pairs = triton_neighbor_cell(
+ positions,
+ batch,
+ box_vectors,
+ use_periodic,
+ cutoff_lower,
+ cutoff_upper,
+ max_num_pairs,
+ loop,
+ include_transpose,
+ num_cells,
+ )
+
+ ctx.save_for_backward(neighbors, deltas, distances)
+ ctx.num_atoms = positions.size(0)
+ return neighbors, deltas, distances, num_pairs
+
+ @staticmethod
+ def backward(ctx, grad_neighbors, grad_deltas, grad_distances, grad_num_pairs): # type: ignore[override]
+ # Call parent backward (returns 9 values) and add None for num_cells
+ parent_grads = TritonNeighborAutograd.backward(
+ ctx, grad_neighbors, grad_deltas, grad_distances, grad_num_pairs
+ )
+ return (*parent_grads, None)
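
A minimal usage sketch for the op defined above (requires a CUDA device with Triton; sizes, cutoff and box values are illustrative):

import torch
from torchmdnet.extensions.triton_cell import _get_cell_dimensions, triton_neighbor_cell

cutoff = 5.0
box = torch.diag(torch.tensor([30.0, 30.0, 30.0])).cuda()
pos = torch.rand(1000, 3, device="cuda") * 30.0
batch = torch.zeros(1000, dtype=torch.long, device="cuda")

# num_cells is fixed ahead of time so the launch stays CUDA-graph friendly
dims = _get_cell_dimensions(box[0, 0], box[1, 1], box[2, 2], cutoff)
num_cells = int(dims.prod())

neighbors, deltas, distances, num_pairs = triton_neighbor_cell(
    pos, batch, box, True, 0.0, cutoff, 64 * 1000, False, True, num_cells
)
valid = neighbors[0] >= 0  # unused slots stay padded with -1
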
diff --git a/torchmdnet/extensions/triton_neighbors.py b/torchmdnet/extensions/triton_neighbors.py
new file mode 100644
index 000000000..5d2df81af
--- /dev/null
+++ b/torchmdnet/extensions/triton_neighbors.py
@@ -0,0 +1,108 @@
+# Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org
+# Distributed under the MIT License.
+# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)
+
+from typing import Tuple
+import torch
+from torch import Tensor
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _tl_round(x):
+ return tl.where(x >= 0, tl.math.floor(x + 0.5), tl.math.ceil(x - 0.5))
+
+
+class TritonNeighborAutograd(torch.autograd.Function):
+ @staticmethod
+ def backward(ctx, grad_neighbors, grad_deltas, grad_distances, grad_num_pairs): # type: ignore[override]
+ neighbors, edge_vec, edge_weight = ctx.saved_tensors
+ num_atoms = ctx.num_atoms
+
+ if grad_deltas is None:
+ grad_deltas = torch.zeros_like(edge_vec)
+ if grad_distances is None:
+ grad_distances = torch.zeros_like(edge_weight)
+
+ zero_mask = edge_weight.eq(0)
+ zero_mask3 = zero_mask.unsqueeze(-1).expand_as(grad_deltas)
+
+ grad_distances_term = edge_vec / edge_weight.masked_fill(
+ zero_mask, 1
+ ).unsqueeze(-1)
+ grad_distances_term = grad_distances_term * grad_distances.masked_fill(
+ zero_mask, 0
+ ).unsqueeze(-1)
+
+ grad_positions = torch.zeros(
+ (num_atoms, 3), device=edge_vec.device, dtype=edge_vec.dtype
+ )
+ edge_index_safe = neighbors.masked_fill(
+ zero_mask.unsqueeze(0).expand_as(neighbors), 0
+ )
+ grad_vec = grad_deltas.masked_fill(zero_mask3, 0) + grad_distances_term
+ grad_positions.index_add_(0, edge_index_safe[0], grad_vec)
+ grad_positions.index_add_(0, edge_index_safe[1], -grad_vec)
+
+ return (
+ grad_positions,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def triton_neighbor_pairs(
+ strategy: str,
+ positions: Tensor,
+ batch: Tensor,
+ box_vectors: Tensor,
+ use_periodic: bool,
+ cutoff_lower: float,
+ cutoff_upper: float,
+ max_num_pairs: int,
+ loop: bool,
+ include_transpose: bool,
+ num_cells: int,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ from torchmdnet.extensions.triton_cell import TritonCellNeighborAutograd
+ from torchmdnet.extensions.triton_brute import TritonBruteNeighborAutograd
+
+ if positions.device.type != "cuda":
+ raise RuntimeError("Triton neighbor list requires CUDA tensors")
+ if positions.dtype not in (torch.float32, torch.float64):
+ raise RuntimeError("Unsupported dtype for Triton neighbor list")
+
+ if strategy == "brute":
+ return TritonBruteNeighborAutograd.apply(
+ positions,
+ batch,
+ box_vectors,
+ use_periodic,
+ float(cutoff_lower),
+ float(cutoff_upper),
+ int(max_num_pairs),
+ bool(loop),
+ bool(include_transpose),
+ )
+ elif strategy == "cell":
+ return TritonCellNeighborAutograd.apply(
+ positions,
+ batch,
+ box_vectors,
+ use_periodic,
+ cutoff_lower,
+ cutoff_upper,
+ int(max_num_pairs),
+ bool(loop),
+ bool(include_transpose),
+ num_cells,
+ )
+ else:
+ raise ValueError(f"Unsupported strategy {strategy}")
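
The dispatcher above can be exercised end to end, including the shared backward, with a quick smoke test (requires CUDA + Triton; sizes are illustrative):

import torch
from torchmdnet.extensions.triton_neighbors import triton_neighbor_pairs

pos = torch.randn(64, 3, device="cuda", requires_grad=True)
batch = torch.zeros(64, dtype=torch.long, device="cuda")
box = torch.zeros(3, 3, device="cuda")  # ignored since use_periodic=False

neighbors, deltas, distances, num_pairs = triton_neighbor_pairs(
    "brute", pos, batch, box, False, 0.0, 3.0, 64 * 64, False, True, 0
)
distances.sum().backward()  # runs TritonNeighborAutograd.backward
assert pos.grad is not None and pos.grad.shape == pos.shape
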
diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index b7c8398cc..942f7932b 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -127,6 +127,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
activation=args["activation"],
reduce_op=args["reduce_op"],
dtype=dtype,
+ static_shapes=args.get("static_shapes", False),
num_hidden_layers=args.get("output_mlp_num_layers", 0),
)
@@ -251,6 +252,13 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
for p in patterns:
state_dict = {re.sub(p[0], p[1], k): v for k, v in state_dict.items()}
+    # Backward compatibility: box was changed from a plain attribute to a registered buffer.
+    # Old checkpoints don't have the box buffer, so add it with a default value
+ if "representation_model.distance.box" not in state_dict:
+ state_dict["representation_model.distance.box"] = torch.zeros(
+ (3, 3), device="cpu"
+ )
+
model.load_state_dict(state_dict)
return model.to(device)
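
The same idea generalizes whenever a buffer is added to a module after checkpoints were already saved: fill the missing key with a sensible default before load_state_dict. A generic sketch (function name is ours, not part of the patch):

import torch
from torch import nn

def patch_missing_buffers(model: nn.Module, state_dict: dict, suffix: str = ".box") -> dict:
    # Old checkpoints may predate newly registered buffers; default them to zeros
    for key, ref in model.state_dict().items():
        if key not in state_dict and key.endswith(suffix):
            state_dict[key] = torch.zeros_like(ref)
    return state_dict
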
diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py
index 64f9cebfc..c9f933b9b 100644
--- a/torchmdnet/models/output_modules.py
+++ b/torchmdnet/models/output_modules.py
@@ -6,14 +6,8 @@
from typing import Optional
import torch
from torch import nn
-from torchmdnet.models.utils import (
- act_class_mapping,
- GatedEquivariantBlock,
- scatter,
- MLP,
-)
+from torchmdnet.models.utils import GatedEquivariantBlock, scatter, MLP
from torchmdnet.utils import atomic_masses
-from torchmdnet.extensions.ops import is_current_stream_capturing
from warnings import warn
__all__ = ["Scalar", "DipoleMoment", "ElectronicSpatialExtent"]
@@ -26,27 +20,34 @@ class OutputModel(nn.Module, metaclass=ABCMeta):
As an example, have a look at the :py:mod:`torchmdnet.output_modules.Scalar` output model.
"""
- def __init__(self, allow_prior_model, reduce_op):
+ def __init__(self, allow_prior_model, reduce_op, static_shapes=False):
super(OutputModel, self).__init__()
self.allow_prior_model = allow_prior_model
self.reduce_op = reduce_op
+ self.static_shapes = static_shapes
self.dim_size = 0
- self.setup_for_compile = False
def reset_parameters(self):
pass
- def setup_for_compile_cudagraphs(self, batch):
- self.dim_size = int(batch.max().item() + 1)
- self.setup_for_compile = True
-
@abstractmethod
def pre_reduce(self, x, v, z, pos, batch):
return
def reduce(self, x, batch):
- if not self.setup_for_compile:
- is_capturing = x.is_cuda and is_current_stream_capturing()
+ # torch.compile and torch.export don't support .item() calls during tracing
+ # The model should be warmed up before compilation to set the correct dim_size
+ if torch.compiler.is_compiling():
+ pass
+ elif torch.jit.is_scripting():
+ # TorchScript doesn't support torch.cuda.is_current_stream_capturing()
+ # For CPU, always update dim_size (no CUDA graphs on CPU)
+ # For CUDA with static_shapes, only update once (first call sets dim_size for CUDA graph capture)
+ # For CUDA without static_shapes, always update (dynamic batch sizes)
+ if not x.is_cuda or not self.static_shapes or self.dim_size == 0:
+ self.dim_size = int(batch.max().item() + 1)
+ else:
+ is_capturing = x.is_cuda and torch.cuda.is_current_stream_capturing()
if not x.is_cuda or not is_capturing:
self.dim_size = int(batch.max().item() + 1)
if is_capturing:
@@ -72,10 +73,13 @@ def __init__(
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float,
+ static_shapes=False,
**kwargs,
):
super(Scalar, self).__init__(
- allow_prior_model=allow_prior_model, reduce_op=reduce_op
+ allow_prior_model=allow_prior_model,
+ reduce_op=reduce_op,
+ static_shapes=static_shapes,
)
self.output_network = MLP(
in_channels=hidden_channels,
@@ -102,10 +106,13 @@ def __init__(
allow_prior_model=True,
reduce_op="sum",
dtype=torch.float,
+ static_shapes=False,
**kwargs,
):
super(EquivariantScalar, self).__init__(
- allow_prior_model=allow_prior_model, reduce_op=reduce_op
+ allow_prior_model=allow_prior_model,
+ reduce_op=reduce_op,
+ static_shapes=static_shapes,
)
if kwargs.get("num_layers", 0) > 0:
warn("num_layers is not used in EquivariantScalar")
@@ -144,6 +151,7 @@ def __init__(
activation="silu",
reduce_op="sum",
dtype=torch.float,
+ static_shapes=False,
**kwargs,
):
super(DipoleMoment, self).__init__(
@@ -152,6 +160,7 @@ def __init__(
allow_prior_model=False,
reduce_op=reduce_op,
dtype=dtype,
+ static_shapes=static_shapes,
**kwargs,
)
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
@@ -177,6 +186,7 @@ def __init__(
activation="silu",
reduce_op="sum",
dtype=torch.float,
+ static_shapes=False,
**kwargs,
):
super(EquivariantDipoleMoment, self).__init__(
@@ -185,6 +195,7 @@ def __init__(
allow_prior_model=False,
reduce_op=reduce_op,
dtype=dtype,
+ static_shapes=static_shapes,
**kwargs,
)
atomic_mass = torch.from_numpy(atomic_masses).to(dtype)
@@ -211,10 +222,11 @@ def __init__(
activation="silu",
reduce_op="sum",
dtype=torch.float,
+ static_shapes=False,
**kwargs,
):
super(ElectronicSpatialExtent, self).__init__(
- allow_prior_model=False, reduce_op=reduce_op
+ allow_prior_model=False, reduce_op=reduce_op, static_shapes=static_shapes
)
self.output_network = MLP(
in_channels=hidden_channels,
@@ -254,6 +266,7 @@ def __init__(
activation="silu",
reduce_op="sum",
dtype=torch.float,
+ static_shapes=False,
**kwargs,
):
super(EquivariantVectorOutput, self).__init__(
@@ -262,6 +275,7 @@ def __init__(
allow_prior_model=False,
reduce_op="sum",
dtype=dtype,
+ static_shapes=static_shapes,
**kwargs,
)
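
All of the reduce() branching above comes down to one question: is it currently safe to synchronize with the host via .item()? A standalone sketch of the same decision (helper name _dim_size is ours):

import torch

def _dim_size(batch: torch.Tensor, cached: int, static_shapes: bool) -> int:
    if torch.compiler.is_compiling():
        return cached  # .item() would break tracing; rely on the warm-up value
    if batch.is_cuda and torch.cuda.is_current_stream_capturing():
        return cached  # no host sync while a CUDA graph is being captured
    if batch.is_cuda and static_shapes and cached > 0:
        return cached  # shapes are frozen after the first (warm-up) call
    return int(batch.max().item() + 1)
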
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index 23a3c2683..9aec148ca 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -218,9 +218,6 @@ def reset_parameters(self):
self.linear.reset_parameters()
self.out_norm.reset_parameters()
- def setup_for_compile_cudagraphs(self, batch):
- self.distance.setup_for_compile_cudagraphs()
-
def forward(
self,
z: Tensor,
diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py
index 55cd89982..031da9a5c 100644
--- a/torchmdnet/models/utils.py
+++ b/torchmdnet/models/utils.py
@@ -152,7 +152,7 @@ class OptimizedDistance(torch.nn.Module):
If the number of pairs found is larger than this, the pairs are randomly sampled. When check_errors is True, an exception is raised in this case.
If negative, it is interpreted as (minus) the maximum number of neighbors per atom.
strategy : str
- Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`.
+ Strategy to use for computing the neighbor list. Can be one of :code:`["brute", "cell"]`.
1. *Shared*: An O(N^2) algorithm that leverages CUDA shared memory, best for large number of particles.
2. *Brute*: A brute force O(N^2) algorithm, best for small number of particles.
@@ -201,34 +201,46 @@ def __init__(
self.cutoff_lower = cutoff_lower
self.max_num_pairs = max_num_pairs
self.strategy = strategy
- self.box: Optional[Tensor] = box
self.loop = loop
self.return_vecs = return_vecs
self.include_transpose = include_transpose
self.resize_to_fit = resize_to_fit
self.use_periodic = True
- if self.box is None:
+ self.num_cells = 0
+
+ # Use register_buffer for box to make it export-compatible and handle device movement
+ if box is None:
self.use_periodic = False
- self.box = torch.empty((0, 0))
- if self.strategy == "cell":
+ if strategy == "cell":
# Default the box to 3 times the cutoff, really inefficient for the cell list
lbox = cutoff_upper * 3.0
- self.box = torch.tensor(
+ box = torch.tensor(
[[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]], device="cpu"
)
+ else:
+ # Use a placeholder box instead of empty (0,0) to avoid shape issues in torch.export
+ # This won't be used when use_periodic=False, but prevents export tracing issues
+ box = torch.zeros((3, 3), device="cpu")
+
+ # Register box as a buffer so it moves with the module and is export-compatible
+ self.register_buffer("box", box, persistent=True)
+
if self.strategy == "cell":
- self.box = self.box.cpu()
+ from torchmdnet.extensions.triton_cell import _get_cell_dimensions
+
+ cell_dims = _get_cell_dimensions(
+ self.box[0, 0], self.box[1, 1], self.box[2, 2], cutoff_upper
+ )
+ self.num_cells = int(cell_dims.prod())
+ if self.num_cells > 1024**3:
+ raise RuntimeError(
+ f"Too many cells: {self.num_cells}. Maximum is 1024^3. "
+ f"Reduce box size or increase cutoff."
+ )
+
self.check_errors = check_errors
self.long_edge_index = long_edge_index
- def setup_for_compile_cudagraphs(self):
-        # box needs to be a buffer so it moves to the correct device; otherwise torch.compile
-        # can fail to use cuda graphs with "skipping cudagraphs due to cpu device (primals_3)."
-        # we can't just make it a buffer in the constructor above because then old state dicts will fail to load
- _box = self.box
- del self.box
- self.register_buffer('box', _box)
-
def forward(
self, pos: Tensor, batch: Optional[Tensor] = None, box: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
@@ -263,9 +275,18 @@ def forward(
use_periodic = self.use_periodic
if not use_periodic:
use_periodic = box is not None
- box = self.box if box is None else box
+
+ # Use the registered buffer box if no box is provided
+ # The buffer is automatically moved to the correct device with the module
+ if box is None:
+ box = self.box
+
assert box is not None, "Box must be provided"
- box = box.to(pos.dtype)
+
+ # Ensure box has correct dtype (device is already correct if using self.box)
+ if box.dtype != pos.dtype:
+ box = box.to(dtype=pos.dtype)
+
max_pairs: int = self.max_num_pairs
if self.max_num_pairs < 0:
max_pairs = -self.max_num_pairs * pos.shape[0]
@@ -282,6 +303,7 @@ def forward(
include_transpose=self.include_transpose,
box_vectors=box,
use_periodic=use_periodic,
+ num_cells=self.num_cells,
)
if self.check_errors:
assert (