-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement get_stats
for Dataset
and print it before training
#251
Conversation
TODO: gradients (edit: done) @PicoCentauri the main issue here is that often we don't use our |
__repr__
for Dataset
and print datasets before trainingget_stats
for Dataset
and print it before training
It's a bit cumbersome at the moment because having it as a method forces us to inherit from |
class Dataset: | ||
"""A version of the `metatensor.learn.Dataset` class that allows for | ||
the use of `mtm::` prefixes in the keys of the dictionary. See | ||
https://github.com/lab-cosmo/metatensor/issues/621. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps not for this PR, but I think this is merged now so should allow for these prefixes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not yet released, but we should do a patch release with this
src/metatrain/utils/data/dataset.py
Outdated
|
||
def get_stats(self, dataset_info: DatasetInfo) -> str: | ||
if hasattr(self, "_cached_stats"): | ||
return self._cached_stats # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To please codecov you should call the get_stats
twice. But, is this really necessary to cache this? Should not take super long to compute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, I over-optimized. It will be gone
:param dict: A dictionary with the data to be stored in the dataset. | ||
""" | ||
|
||
def __init__(self, dict: Dict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we keep a __repr__
that is also useful without the DatasetInfo
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see the point if we don't use it...
if dataset_len == 0: | ||
return stats | ||
|
||
target_names = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought the target_names
are in the datasetInfo?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes but these are different. They also include the gradients. The variable name can be changed if you think that's a good idea. Something like target_names_with_gradients
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Name is fine but maybe add a comment that they are different.
src/metatrain/utils/data/dataset.py
Outdated
means = {key: sums[key] / n_elements[key] for key in target_names} | ||
|
||
sum_of_squared_residuals = {key: 0.0 for key in target_names} | ||
for sample in dataset: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you sum twice over the dataset?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two iterations: one for the mean, one for the std (for the std you already need to know the mean)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, you don't. You save a sum and sum of squares and do the mean and the standard deviation afterwords.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right
Co-authored-by: Philip Loche <[email protected]>
9fe0df6
to
e209a4e
Compare
This implements a
get_stats()
forDataset
. Closes #205.Contributor (creator of pull-request) checklist
📚 Documentation preview 📚: https://metatrain--251.org.readthedocs.build/en/251/