Remove `set_default_dtype` and `get_default_dtype` #155

Conversation
Force-pushed from af72695 to a7f21d7.
Thanks for the changes
Diff under review: `base_precision: 64` → `base_precision: 32`
Uhhh, are we sure we want this to be the default behavior? People have had issues with float32 MD.
I just changed this to be consistent with torch. Which people had issues? The examples seem to be fine, and problems in MD usually appear when the forces are not summed up in double precision; as a base precision, I think a normal float should be fine.
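For illustration (not code from this PR), a minimal sketch of what "float32 base precision with forces accumulated in double precision" looks like; the tensor shapes are made up:

```python
import torch

# Per-atom forces produced by a model running at float32 base precision
# (1000 atoms, 3 components each; purely illustrative values).
forces = torch.randn(1000, 3, dtype=torch.float32)

# Summing many float32 values directly can accumulate rounding error.
# Upcasting to float64 just for the reduction keeps the total accurate
# while the model itself stays in single precision.
total_force = forces.to(torch.float64).sum(dim=0)
print(total_force.dtype)  # torch.float64
```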
OK, we can discuss this together with other people. I still think float64 is a safer default.
I am using the power of omegaconf's custom resolvers to obtain the preferred dtype for each architecture. When we are happy with this, I will do the same for the device (but in a different PR).
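As an illustration of the mechanism (the resolver name, architecture keys, and lookup table below are hypothetical, not the ones used in this PR), a minimal sketch of an omegaconf custom resolver returning a per-architecture default precision:

```python
from omegaconf import OmegaConf

# Hypothetical table of preferred base precisions per architecture.
PREFERRED_PRECISION = {"soap_bpnn": 32, "alchemical_model": 64}

# "${default_precision:some_arch}" in a config resolves to the table entry.
OmegaConf.register_new_resolver(
    "default_precision", lambda arch: PREFERRED_PRECISION[arch]
)

conf = OmegaConf.create(
    {
        "architecture": "soap_bpnn",
        # Resolved lazily when the field is accessed.
        "base_precision": "${default_precision:${architecture}}",
    }
)
print(conf.base_precision)  # 32
```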
Force-pushed from 185e660 to 0e18391.
Thanks!
Fixes #148
This PR removes all calls of `set_default_dtype` and `get_default_dtype`. For now I am using the `.to()` method to move the model to the correct dtype. But is there also an option to initialize the model in the correct dtype in the first place?
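One option, sketched below under the assumption that the model is built from standard `torch.nn` layers (the `Linear` layer is just a stand-in for the real model), is the `dtype` keyword that `torch.nn` modules accept, so parameters are allocated in the target dtype from the start instead of being converted afterwards:

```python
import torch

dtype = torch.float64  # requested base precision

# What this PR does: build in the default dtype, then move the parameters.
model = torch.nn.Linear(8, 1).to(dtype)

# Alternative: pass `dtype` at construction time, so the parameters are
# created directly in the requested precision.
model = torch.nn.Linear(8, 1, dtype=dtype)
print(next(model.parameters()).dtype)  # torch.float64
```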
This revealed some issues that are also fixed within this PR:

- The `dtype` of the model is set to the one of the dataset.
- `eval`: there we adjust the loading dtype of the targets to the dtype of the exported model.

Additional changes:
- Replace `assert torch.allclose` by the more verbose `torch.testing.assert_close` (a short illustration follows this list).
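For reference, a small illustrative snippet (not from the PR's test suite) showing why the failure output of `torch.testing.assert_close` is more informative than a bare `assert torch.allclose(...)`:

```python
import torch

expected = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
actual = expected + 1e-3  # deliberately outside the default tolerances

# `assert torch.allclose(actual, expected)` would only raise a bare
# AssertionError.  assert_close additionally checks dtype and device and
# reports how many elements mismatch and the largest abs/rel difference.
try:
    torch.testing.assert_close(actual, expected)
except AssertionError as error:
    print(error)
```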
TODO:

- `eval` test for 64-bit and 32-bit exported models
- Add `__architecture_capabilities__` to each architecture indicating their supported types
- Add a `dtype` check to the `check_datasets` function

📚 Documentation preview 📚: https://metatensor-models--155.org.readthedocs.build/en/155/