
Conversation

@stefdoerr (Collaborator) commented Dec 16, 2025

Currently, installation, maintenance and development of torchmd-net are hampered by the neighbor list CUDA and C++ extensions (see also: #372 (comment)). It is very complicated to build and distribute Python wheels that support multiple combinations of PyTorch, CUDA and Python versions.

Triton uses nvrtc to compile kernels at runtime, with performance theoretically comparable to hand-written CUDA. This turns torchmd-net into a pure Python project that can be installed directly from a source distribution instead of platform- and dependency-specific wheels.

The Triton reimplementation here is comparable to the CUDA one; it actually beats CUDA in the scenarios we care most about (small batches). See benchmarks in: #372 (comment)

What are your opinions, @RaulPPelaez? By the way, many, many thanks for the wonderful and extensive test suite! Without it this reimplementation would have been near impossible.

@stefdoerr (Collaborator, Author) commented Dec 17, 2025

Ok, I think this PR is done now. If someone wants to give it another quick look, feel free. @sef43, maybe take a look as well.

@sef43 self-requested a review December 17, 2025 17:00
@sef43 (Collaborator) commented Dec 18, 2025

I have checked it in a full TensorNet model (aceff-1.0). It looks to be correct, and speed is as reported, i.e. fast enough not to be the bottleneck compared to the rest of the TensorNet model.

Before we merge, there are two points I want to check further / that should be discussed:

  1. When using it in a torch-scripted model (i.e. in openmm-torch), only the pure Python code path is used, not the Triton one (for small molecules this seems fast enough). I think the Triton path would error out, as it cannot be torch-scripted.

  2. I think the Triton code will be recompiled each time max_num_pairs (the number of atoms) changes. This could significantly reduce speed when used as a single-point calculator looping over different molecules, and maybe also when training with batches containing differently sized molecules.

@sef43 commented on the diff in get_neighbor_pairs:

```diff
        The number of pairs found.
        """
-       return torch.ops.torchmdnet_extensions.get_neighbor_pairs(
+       if torch.jit.is_scripting() or not positions.is_cuda:
```

this means the torch script model will use the pure python code
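
For context, a minimal sketch of the dispatch pattern under discussion; the helper names and signatures here are illustrative, not the exact torchmd-net API:

```python
import torch

def _neighbor_pairs_pytorch(positions: torch.Tensor, cutoff: float) -> torch.Tensor:
    # Brute-force all-to-all distances; scriptable and device-agnostic.
    dist = torch.cdist(positions, positions)
    i, j = torch.where((dist < cutoff) & (dist > 0))
    return torch.stack([i, j])

def _neighbor_pairs_triton(positions: torch.Tensor, cutoff: float) -> torch.Tensor:
    # Hypothetical stand-in for the Triton-backed implementation.
    raise NotImplementedError

def get_neighbor_pairs(positions: torch.Tensor, cutoff: float) -> torch.Tensor:
    # TorchScript cannot compile Triton kernels, so scripted models
    # (and CPU tensors) fall back to the pure-PyTorch path.
    if torch.jit.is_scripting() or not positions.is_cuda:
        return _neighbor_pairs_pytorch(positions, cutoff)
    return _neighbor_pairs_triton(positions, cutoff)
```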

@stefdoerr (Collaborator, Author) commented Dec 18, 2025

Very good point about the recompilation! I had not taken that into consideration. The kernels get recompiled whenever a tl.constexpr argument changes, which for the brute kernel included the number of atoms and the cutoffs. I have now removed tl.constexpr from all those variables and kept it only for use_periodic, loop and include_transpose, which means we get at most 8 kernels (2^3 flag combinations).
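
To make the specialization behavior concrete, here is a minimal sketch of the pattern (assuming a recent Triton; this is illustrative, not the actual torchmd-net kernel): only the boolean flag stays tl.constexpr, so Triton compiles at most one variant per flag value, while the atom count and cutoff are ordinary runtime arguments that never trigger a rebuild.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def count_pairs_kernel(
    pos_ptr, count_ptr,
    num_atoms,                        # runtime scalar: changing it reuses the same binary
    cutoff_sq,                        # runtime scalar as well
    include_transpose: tl.constexpr,  # flag: at most one specialization per value
    BLOCK: tl.constexpr,
):
    # One program per (row atom i, tile of column atoms j).
    i = tl.program_id(0)
    j = tl.program_id(1) * BLOCK + tl.arange(0, BLOCK)
    m = j < num_atoms
    xi = tl.load(pos_ptr + i * 3 + 0)
    yi = tl.load(pos_ptr + i * 3 + 1)
    zi = tl.load(pos_ptr + i * 3 + 2)
    xj = tl.load(pos_ptr + j * 3 + 0, mask=m, other=0.0)
    yj = tl.load(pos_ptr + j * 3 + 1, mask=m, other=0.0)
    zj = tl.load(pos_ptr + j * 3 + 2, mask=m, other=0.0)
    dx = xi - xj
    dy = yi - yj
    dz = zi - zj
    d2 = dx * dx + dy * dy + dz * dz
    if include_transpose:
        hit = m & (j != i) & (d2 < cutoff_sq)  # count both (i, j) and (j, i)
    else:
        hit = m & (j > i) & (d2 < cutoff_sq)   # count each pair once
    tl.atomic_add(count_ptr, tl.sum(hit.to(tl.int32), axis=0))

# Re-launching with a different number of atoms reuses the compiled kernel;
# only flipping include_transpose (or BLOCK) adds a new specialization.
pos = torch.randn(1000, 3, device="cuda")
count = torch.zeros(1, dtype=torch.int32, device="cuda")
grid = (pos.shape[0], triton.cdiv(pos.shape[0], 128))
count_pairs_kernel[grid](pos, count, pos.shape[0], 0.25,
                         include_transpose=False, BLOCK=128)
```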

Here is the performance of the brute kernel after removing the constants related to the number of atoms and cutoffs. The impact seems negligible, which is great!

Before the change (unbounded number of kernel variants):

```
Benchmarking neighbor list generation for 1 batches with 64 neighbors on average
Summary
-------
N Particles Distance(ms)         Brute(ms)            Cell(ms)
---------------------------------------------------------------------------
256         0.55  x1.00          0.01  x76.01         0.12  x4.52
512         0.58  x1.00          0.01  x67.55         0.16  x3.73
1024        0.65  x1.00          0.01  x48.34         0.13  x5.13
2048        0.78  x1.00          0.03  x25.79         0.13  x5.97
4096        1.04  x1.00          0.09  x11.48         0.14  x7.20
8192        1.60  x1.00          0.33  x4.91          0.24  x6.58
16384       2.75  x1.00          1.33  x2.06          0.33  x8.47
32768       5.26  x1.00          4.99  x1.05          0.56  x9.44
65536       11.36 x1.00          18.94 x0.60          1.11  x10.27
131072      32.46 x1.00          75.16 x0.43          2.07  x15.71
```

After the change (at most 8 kernels):

```
Benchmarking neighbor list generation for 1 batches with 64 neighbors on average
Summary
-------
N Particles Distance(ms)         Brute(ms)            Cell(ms)
---------------------------------------------------------------------------
256         0.56  x1.00          0.01  x70.62         0.12  x4.65
512         0.59  x1.00          0.01  x63.71         0.15  x3.89
1024        0.66  x1.00          0.01  x45.11         0.13  x5.27
2048        0.78  x1.00          0.03  x25.06         0.13  x6.09
4096        1.05  x1.00          0.09  x11.11         0.15  x7.19
8192        1.60  x1.00          0.33  x4.83          0.24  x6.58
16384       2.76  x1.00          1.31  x2.11          0.34  x8.02
32768       5.27  x1.00          4.94  x1.07          0.61  x8.65
65536       11.38 x1.00          19.18 x0.59          1.22  x9.36
131072      32.55 x1.00          75.71 x0.43          2.32  x14.06
```

@stefdoerr (Collaborator, Author) commented Dec 18, 2025

About the other point you raised, @sef43 @peastman: would it be possible to change openmm-torch to use torch-compiled models (or torch.export) instead of TorchScript? TorchScript has effectively been deprecated since PyTorch 2.0, so we will hit that wall soon anyway.

Although, now that I'm looking at it, this also requires some more changes to make the model torch.export-able.
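
For reference, a toy sketch of the torch.export path being suggested (assuming PyTorch ≥ 2.1; the module here is a stand-in, not a TensorNet model):

```python
import torch

class ToyEnergy(torch.nn.Module):
    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        return pos.pow(2).sum()  # stand-in for a real energy model

pos = torch.randn(10, 3)
ep = torch.export.export(ToyEnergy(), (pos,))  # ExportedProgram with a frozen graph
energy = ep.module()(pos)                      # run the exported graph module
```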

@stefdoerr (Collaborator, Author) commented:

Okay, I'll try working on this a bit more locally to see if I can get it to torch.export correctly with cell lists / Triton code without spamming the PR. I'll let you know when I have an update.

@peastman (Collaborator) commented:

The next version of OpenMM will introduce a mechanism for calling back into Python to compute forces. That's how I plan to support most ML potentials in the future. There are already a lot of potentials that can't be compiled to TorchScript, and since it's deprecated, I expect the number to keep increasing.

@sef43 (Collaborator) commented Dec 18, 2025

I did two benchmarks using a TensorNet model (aceff-1.0).

  1. For the first I used OpenMM-Torch with the torch-scripted models and CUDA graphs enabled. I compared the old version using the CUDA code to the version in this PR that uses the pure PyTorch neighbor list (Triton cannot be torch-scripted). For the atom counts we care about there is a slight speed decrease, but it is almost negligible; i.e. for fewer than 1000 atoms a simple all-to-all distance calculation is fast enough compared to the rest of the model.
| N atoms | old code, time per step (ms) | this PR, time per step (ms) |
|---------|------------------------------|-----------------------------|
| 9       | 0.712                        | 0.726                       |
| 30      | 1.078                        | 1.115                       |
| 60      | 2.039                        | 2.117                       |
| 150     | 5.113                        | 5.224                       |
| 300     | 11.426                       | 11.207                      |
| 450     | 16.821                       | 17.002                      |
| 600     | 22.743                       | 22.755                      |
| 750     | 28.363                       | 29.051                      |
| 900     | 34.113                       | 34.167                      |
| 1200    | 45.612                       | 46.187                      |
| 1500    | 57.642                       | 58.097                      |
| 1800    | 69.003                       | 69.968                      |
| 2400    | 91.243                       | 93.412                      |
  2. For the second I used the ASE calculator with torch.compile(..), comparing the old code against this PR with the Triton code (both cases use the brute-force list). Speed is essentially the same.
| N atoms | old code, time per step (ms) | this PR, time per step (ms) |
|---------|------------------------------|-----------------------------|
| 9       | 2.511                        | 2.71                        |
| 30      | 2.61                         | 2.713                       |
| 60      | 3.324                        | 3.292                       |
| 150     | 4.809                        | 4.837                       |
| 300     | 7.912                        | 7.802                       |
| 450     | 10.342                       | 10.346                      |
| 600     | 13.382                       | 13.452                      |
| 750     | 15.569                       | 15.571                      |
| 900     | 19.171                       | 18.967                      |
| 1200    | 24.097                       | 24.748                      |
| 1500    | 31.099                       | 30.967                      |
| 1800    | 37.36                        | 36.895                      |
| 2400    | 61.184                       | 62.755                      |

@sef43 (Collaborator) commented Dec 19, 2025

Training speed also seems to be the same, which is good!

(If I check out the code from before 9179ad7, training speed is impacted by a factor of 2 due to the repeated kernel recompilations whenever the number of atoms changes.)

@sef43 (Collaborator) commented Dec 19, 2025

I think we could merge without torch.export for now and keep the current implementation, where the torch-scripted model falls back to the pure PyTorch version, given that for TensorNet models using openmm-torch it is fast enough at typical atom counts.
