Description
Is your feature request related to a problem? Please describe.
I've been observing the work being done for the v2 release of pytorch-forecasting, including the implementation of the TimeXer model, taken from https://github.com/thuml/Time-Series-Library.
We all know TSLib is great when it comes to releasing SOTA models to the public, but imo a lot of the code in that original library is outdated and unoptimized given what's possible in newer versions of torch. The reason I'm highlighting this is that TSLib pins torch at 1.7.1 (see here), while pytorch-forecasting's minimum allowed version is torch==2.0.0.
Given this, imo there's no need to carry over TSLib's torch.einsum()-based attention computation into this project (see l66 of layers/_attention/_full_attention.py here and l67 of /models/timexer/sub_modules.py here).
The torch.einsum()-based computation is more memory-hungry and slower than torch.nn.functional.scaled_dot_product_attention, which is now natively available in torch; the latter can also dispatch to a FlashAttention kernel if the user's GPU supports it. So I don't really see a reason to keep the slower implementation, especially since you don't need it: the attention scores aren't exposed to users anyway (see output_attention=False here in the TimeXer class). A rough sketch of the swap is below.
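Here's a minimal sketch of what the replacement would look like, assuming the (batch, seq_len, n_heads, head_dim) layout used by the TSLib-style FullAttention; the tensor names and shapes are just for illustration:

```python
import torch
import torch.nn.functional as F

# Assumed layout, matching the TSLib convention: (batch, seq_len, n_heads, head_dim)
B, L, S, H, E = 32, 96, 96, 8, 64
queries = torch.randn(B, L, H, E)
keys = torch.randn(B, S, H, E)
values = torch.randn(B, S, H, E)
scale = 1.0 / E ** 0.5

# einsum path (current style): materializes the full (B, H, L, S) score matrix.
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
attn = torch.softmax(scale * scores, dim=-1)
out_einsum = torch.einsum("bhls,bshd->blhd", attn, values)

# SDPA path: expects (batch, n_heads, seq_len, head_dim), so transpose first.
# Default scaling is 1/sqrt(head_dim), i.e. the same as the manual path above,
# and it can dispatch to flash / memory-efficient kernels on supported GPUs.
out_sdpa = F.scaled_dot_product_attention(
    queries.transpose(1, 2),
    keys.transpose(1, 2),
    values.transpose(1, 2),
).transpose(1, 2)

torch.testing.assert_close(out_einsum, out_sdpa, rtol=1e-4, atol=1e-4)
```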
Besides that, the original TSLib modules carry a lot of unused arguments kept around for specific models (e.g. tau and delta in the FullAttention module are meant specifically for the Non-Stationary Transformer, and none of the more recent TSLib models use them). There are also other memory-related optimizations possible on top of the TSLib implementations, such as using a single fused linear layer for the Q, K, V projections and chunking its output instead of three separate projections (a rough sketch is below), and so on.
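For example, a fused projection for the self-attention case could look roughly like this (hypothetical class and argument names; it only applies where Q, K, V come from the same input tensor, so TimeXer's cross-attention would still need a separate query projection):

```python
import torch
from torch import nn

class FusedQKVProjection(nn.Module):
    """Single fused linear layer for Q, K, V instead of three separate projections."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # one matmul instead of three

    def forward(self, x: torch.Tensor):
        B, L, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (B, n_heads, L, head_dim), ready for scaled_dot_product_attention.
        q, k, v = (t.view(B, L, self.n_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        return q, k, v
```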
All of these things currently seem to be carried over into your implementations of the models, so I think a lot could be improved ahead of the v2 release, so that this package can offer users a major performance improvement over what's possible with the original TSLib models.
Describe the solution you'd like
I'm not really requesting a solution from your side. If this is something you'd allow me to pick up as an issue and refactor the TimeXer class accordingly, I'd be very happy to do so. I've worked with TimeXer extensively for my own use cases and saw major speedups and memory savings after refactoring the original, outdated TSLib attention implementations.
Describe alternatives you've considered
Another possible alternative: if the reason for keeping the torch.einsum()-based attention implementation is that you want the attention scores to remain accessible, we could refactor the attention classes to support two backends. One would be .einsum()-based and able to return the attention scores (as you have now); the other would be based on torch.nn.functional.scaled_dot_product_attention, offering optimized performance but without the ability to return the attention scores. A rough sketch of what this could look like is below.
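To make the idea concrete, here is a rough sketch of such a dual-backend layer (class and argument names are placeholders, not the actual pytorch-forecasting API):

```python
import torch
import torch.nn.functional as F
from torch import nn

class DualBackendAttention(nn.Module):
    """Placeholder sketch: einsum backend when scores are needed, SDPA otherwise."""

    def __init__(self, attention_dropout: float = 0.1, output_attention: bool = False):
        super().__init__()
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values):
        # Inputs assumed as (B, L, H, E), the TSLib layout.
        if self.output_attention:
            # einsum backend: slower, materializes scores, but can return them.
            scale = 1.0 / queries.shape[-1] ** 0.5
            scores = torch.einsum("blhe,bshe->bhls", queries, keys)
            attn = self.dropout(torch.softmax(scale * scores, dim=-1))
            out = torch.einsum("bhls,bshd->blhd", attn, values)
            return out, attn
        # SDPA backend: fused kernels, no score matrix exposed.
        out = F.scaled_dot_product_attention(
            queries.transpose(1, 2),
            keys.transpose(1, 2),
            values.transpose(1, 2),
            dropout_p=self.dropout.p if self.training else 0.0,
        ).transpose(1, 2)
        return out, None
```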
Additional context
I've done some benchmarking between the two attention methods discussed here for my own purposes and can share the results with you. If I can take this issue on, I'd also include some test modules (and maybe a benchmarking notebook) to demonstrate the speedup; a rough outline of such a benchmark is below.
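The kind of micro-benchmark I have in mind would look roughly like the following (shapes are illustrative, not the ones from my actual runs):

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

# Illustrative shapes; layout here is (batch, n_heads, seq_len, head_dim).
B, L, H, E = 32, 512, 8, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(B, H, L, E, device=device)
k = torch.randn(B, H, L, E, device=device)
v = torch.randn(B, H, L, E, device=device)

def einsum_attn(q, k, v):
    scores = torch.einsum("bhle,bhse->bhls", q, k) / E ** 0.5
    return torch.einsum("bhls,bhsd->bhld", torch.softmax(scores, dim=-1), v)

for label, stmt in [("einsum", "einsum_attn(q, k, v)"),
                    ("sdpa", "F.scaled_dot_product_attention(q, k, v)")]:
    timer = benchmark.Timer(
        stmt=stmt,
        globals={"q": q, "k": k, "v": v, "einsum_attn": einsum_attn, "F": F},
    )
    print(label, timer.timeit(100))
```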
Big fan of your work in general, and I'd be happy if I can contribute 🫡