Triton neighbor list implementation #373
Conversation
OK, I think this PR is done now. If someone wants to give it a quick look again, feel free. @sef43, maybe take a look as well.
I have checked it in a full TensorNet model (aceff-1.0). It looks correct, and the speed is as reported, i.e. fast enough not to be the bottleneck compared to the rest of the TensorNet model. Before we merge, there are two points that I want to check further / that should be discussed:
        The number of pairs found.
    """
    return torch.ops.torchmdnet_extensions.get_neighbor_pairs(
    if torch.jit.is_scripting() or not positions.is_cuda:
This means the TorchScript model will use the pure Python code.
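The dispatch the reviewer is pointing at can be sketched as follows. This is a simplified stand-in, not the actual torchmd-net code: the real `get_neighbor_pairs` op takes more arguments (cutoffs, maximum pair count, box vectors) and returns distances and deltas as well; the helper names here are hypothetical.

```python
import torch


def _pairs_brute_pytorch(positions: torch.Tensor, cutoff: float) -> torch.Tensor:
    """O(N^2) pure-PyTorch fallback: all unique pairs within the cutoff."""
    n = positions.shape[0]
    i, j = torch.triu_indices(n, n, offset=1)
    d2 = ((positions[i] - positions[j]) ** 2).sum(dim=-1)
    keep = d2 < cutoff * cutoff
    return torch.stack([i[keep], j[keep]])


def get_neighbor_pairs(positions: torch.Tensor, cutoff: float) -> torch.Tensor:
    # TorchScript cannot call into Triton, and Triton kernels require CUDA,
    # so both cases must take the pure-PyTorch path.
    if torch.jit.is_scripting() or not positions.is_cuda:
        return _pairs_brute_pytorch(positions, cutoff)
    # Hypothetical Triton-backed fast path would go here; for this sketch
    # we reuse the fallback so the function runs anywhere.
    return _pairs_brute_pytorch(positions, cutoff)
```

The upside of this pattern is that scripted models stay functional everywhere; the downside, as noted, is that they silently lose the Triton speedup.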
Very good point about the recompilation! I had not taken that into consideration. The kernels get recompiled whenever one of the compile-time constants changes. Here is the performance of the brute kernel after removing the constants related to the number of atoms and the cutoffs. The overhead seems negligible, which is great!

Before the change (infinite kernels):
After the change (8 kernels only):
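The recompilation issue comes from Triton specializing one compiled kernel per distinct value of each compile-time constant, so baking the atom count in as a constant produces a new compile for every system size. The effect can be sketched with a plain memoized "compiler" (a toy stand-in for Triton's kernel cache, not its actual implementation):

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def compile_kernel(block_size: int, use_pbc: bool) -> str:
    # Stand-in for an expensive JIT compile: the cache key is the tuple of
    # compile-time constants, so each new combination triggers a recompile.
    return f"kernel<block={block_size}, pbc={use_pbc}>"


# Atom count passed as a *runtime* argument: the cache stays small.
for num_atoms in (100, 1000, 5000):
    compile_kernel(block_size=128, use_pbc=False)
assert compile_kernel.cache_info().currsize == 1

# Atom count baked in as a compile-time constant: one kernel per system size.
compile_kernel_bad = lru_cache(maxsize=None)(lambda n: f"kernel<natoms={n}>")
for num_atoms in (100, 1000, 5000):
    compile_kernel_bad(num_atoms)
assert compile_kernel_bad.cache_info().currsize == 3
```

Keeping only values like the block size and the use-of-PBC flag as constants bounds the number of compiled variants (the "8 kernels" above), while the atom count and cutoffs become ordinary runtime arguments.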
About the other point you raised, @sef43 @peastman, would it be possible to change […]? Although now that I'm looking at it, this also requires some more changes to make the model torch.export-able.
Okay, I'll try working a bit more on it locally to see if I can get it to torch.export correctly with cell lists / Triton code without spamming the PR. I'll let you know when I have an update.
The next version of OpenMM will introduce a mechanism for calling back into Python to compute forces. That's how I plan to support most ML potentials in the future. There are already a lot of potentials that can't be compiled to TorchScript, and since it's deprecated, I expect the number to keep increasing.
I did two benchmarks using a TensorNet model (aceff-1.0).
Training speed also seems to be the same, which is good! (If I check out the code before 9179ad7, training speed is impacted by a factor of 2 due to the repeated kernel recompilations whenever the number of atoms changes.)
I think we could merge without torch.export support for now and keep the current behavior of the torch-scripted model falling back to the pure PyTorch version, given that for TensorNet models used through openmm-torch it is fast enough for typical numbers of atoms.
Currently, installation, maintenance, and development of torchmd-net are hampered by the neighbor-list CUDA and C++ extensions (see also #372 (comment)). It is very complicated to create and publish Python wheels that support multiple different PyTorch, CUDA, and Python versions.
Triton uses nvrtc to compile kernels at runtime, with theoretically comparable performance to CUDA. This turns the project into a pure Python project that can be installed directly from a source distribution instead of platform- and dependency-specific wheels. The Triton reimplementation here is comparable to the CUDA one; it actually beats CUDA in the scenarios we most care about (small batches). See the benchmarks in #372 (comment).
What are your opinions, @RaulPPelaez? By the way, many thanks for the wonderful and extensive test suite! Without it, this reimplementation would have been near impossible.