-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extract and speed up device transfers #357
Conversation
@Luthaf do you know what's happening in the docs? |
How does this speed it up? This is faster by using torchscript? Also the gains you show above are negligible, do you also have timings for a more realistic training set?
This looks like trying to compile to torchscript the documentation classes, which only contain documentation. Maybe disabling TorchScript for the docs would help? There is an environment variable for this. |
The speed-up is nice, but I think the major advantage is to extract a function that would be essentially used everywhere |
We're already using the environment variable, right? |
The docs are already using one env variable to switch between the TorchScript version of the classes and the documentation version of the classes. There is a separate env variable to disable TorchScript compilation altogether as well. |
Moving systems to the device can be optimized as it requires moving many tensors sequentially (cell, positions, types, NLs, times the number of systems in a batch).
no compilation: 4.7 ms
script: 3.9 ms
torch.compile
: 4.0 ms📚 Documentation preview 📚: https://metatrain--357.org.readthedocs.build/en/357/