Erase dtype and device #166

Draft · wants to merge 4 commits into main
Conversation

@E-Rum (Contributor) commented Feb 3, 2025

A couple of PRs ago, we decided to include dtype and device as explicit and obligatory parameters for both calculators and potentials.

Unfortunately, after thorough consideration of how typical pipelines are built, I concluded that we should abandon this design choice.

The main reason is that, in most cases, when working with an NN model, the preferred strategy is to first initialize the model and then move it to the desired device using model.to(device).

Since torch-pme is designed to be an internal component of a model, this creates a conflict: we fix dtype and device once at initialization, but moving the model to a different device afterwards leaves those stored values stale and breaks the device checks we set up earlier.
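
A minimal sketch of the problem (hypothetical class and parameter names, not torch-pme's actual API); the example uses a dtype change, which triggers the same staleness as a device change:

```python
import torch


class ExplicitDtypeCalculator(torch.nn.Module):
    """Toy calculator that takes dtype/device as constructor arguments."""

    def __init__(self, dtype=torch.float32, device="cpu"):
        super().__init__()
        self.dtype = dtype  # recorded once, never updated by .to()
        self.weight = torch.nn.Parameter(torch.ones(1, dtype=dtype, device=device))

    def forward(self, x):
        # check against the value stored at construction time
        if x.dtype != self.dtype:
            raise TypeError(f"expected {self.dtype}, got {x.dtype}")
        return self.weight * x


model = ExplicitDtypeCalculator()
model = model.to(torch.float64)  # parameters are converted to float64 ...
try:
    model(torch.ones(3, dtype=torch.float64))
except TypeError as err:
    # ... but self.dtype is still float32, so the stale check rejects valid input
    print(err)
```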

Luckily, since the entire pipeline consists of torch.nn.Module subclasses, we can integrate it smoothly with models that change their device and dtype. The key idea is to rewrite the pipeline so that every tensor created during the calculations is registered as a buffer via self.register_buffer; buffers are moved together with the module's parameters whenever .to() is called.
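
A minimal sketch of the pattern (illustrative names, not the actual torch-pme classes): the tensor is stored with register_buffer, so model.to(dtype) or model.to(device) moves it along with the rest of the module's state.

```python
import torch


class BufferedPotential(torch.nn.Module):
    """Toy potential whose constant tensor is stored as a buffer."""

    def __init__(self, smearing: float = 1.0):
        super().__init__()
        # register_buffer makes the tensor part of the module's state:
        # it follows .to(), .cuda(), .double() and is saved in state_dict()
        self.register_buffer("smearing", torch.tensor(smearing))

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        # the buffer already has the right device/dtype after model.to(...)
        return torch.erfc(distances / self.smearing) / distances


potential = BufferedPotential().to(torch.float64)
print(potential.smearing.dtype)  # torch.float64 -- the buffer followed the module
print(potential(torch.linspace(0.5, 2.0, 4, dtype=torch.float64)))
```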

This PR aims to achieve exactly that.


📚 Documentation preview 📚: https://torch-pme--166.org.readthedocs.build/en/166/
