Implement get_stats for Dataset and print it before training #251
Changes from 5 commits
@@ -1,3 +1,6 @@
import itertools
import math
import warnings
from collections import UserDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple, Union

@@ -6,45 +9,8 @@
import torch
from metatensor.torch import TensorMap
from torch import Generator, default_generator
from torch.utils.data import Subset, random_split

class Dataset:
    """A version of the `metatensor.learn.Dataset` class that allows for
    the use of `mtm::` prefixes in the keys of the dictionary. See
    https://github.com/lab-cosmo/metatensor/issues/621.

    It is important to note that, instead of named tuples, this class
    accepts and returns dictionaries.

    :param dict: A dictionary with the data to be stored in the dataset.
    """

    def __init__(self, dict: Dict):
        new_dict = {}
        for key, value in dict.items():
            key = key.replace("mtm::", "mtm_")
            new_dict[key] = value

        self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict)

    def __getitem__(self, idx: int) -> Dict:
        mts_dataset_item = self.mts_learn_dataset[idx]._asdict()
        new_dict = {}
        for key, value in mts_dataset_item.items():
            key = key.replace("mtm_", "mtm::")
            new_dict[key] = value

        return new_dict

    def __len__(self) -> int:
        return len(self.mts_learn_dataset)

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]
from ..external_naming import to_external_name


@dataclass
@@ -230,6 +196,133 @@
        return new


class Dataset:
    """A version of the `metatensor.learn.Dataset` class that allows for
    the use of `mtm::` prefixes in the keys of the dictionary. See
    https://github.com/lab-cosmo/metatensor/issues/621.

    It is important to note that, instead of named tuples, this class
    accepts and returns dictionaries.

    :param dict: A dictionary with the data to be stored in the dataset.
    """

    def __init__(self, dict: Dict):
Review comment: should we keep a
Reply: I don't see the point if we don't use it...
        new_dict = {}
        for key, value in dict.items():
            key = key.replace("mtm::", "mtm_")
            new_dict[key] = value

        self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict)
    def __getitem__(self, idx: int) -> Dict:
        mts_dataset_item = self.mts_learn_dataset[idx]._asdict()
        new_dict = {}
        for key, value in mts_dataset_item.items():
            key = key.replace("mtm_", "mtm::")
            new_dict[key] = value

        return new_dict

    def __len__(self) -> int:
        return len(self.mts_learn_dataset)

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def get_stats(self, dataset_info: DatasetInfo) -> str:
        if hasattr(self, "_cached_stats"):
            return self._cached_stats  # type: ignore
        stats = _get_dataset_stats(self, dataset_info)
        self._cached_stats = stats
        return stats

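A minimal usage sketch of the wrapper above (not part of the diff): plain Python values stand in for real systems and TensorMap targets, assuming `metatensor.learn.Dataset` accepts arbitrary per-sample values. It only illustrates the `mtm::` key round trip; `get_stats` would additionally need real TensorMap targets and a `DatasetInfo`.

# Hedged sketch with placeholder values, not real systems/TensorMaps
dataset = Dataset({"system": [0, 1], "mtm::energy": [1.0, 2.0]})
print(len(dataset))  # 2
print(dataset[0])    # {'system': 0, 'mtm::energy': 1.0}
# "mtm::energy" is stored internally as "mtm_energy" (a valid namedtuple field
# name) and translated back when items are returned.
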
class Subset(torch.utils.data.Subset):
    """
    A version of `torch.utils.data.Subset` containing a `get_stats` method
    allowing us to print information about atomistic datasets.
    """

    def get_stats(self, dataset_info: DatasetInfo) -> str:
        if hasattr(self, "_cached_stats"):
            return self._cached_stats  # type: ignore
Review comment: To please codecov you should call the
Reply: You're right, I over-optimized. It will be gone
        stats = _get_dataset_stats(self, dataset_info)
        self._cached_stats = stats
        return stats


def _get_dataset_stats(dataset: Union[Dataset, Subset], dataset_info: DatasetInfo):
    """Returns the statistics of a dataset or subset as a string."""

    dataset_len = len(dataset)
    stats = f"Dataset of size {dataset_len}"
    if dataset_len == 0:
        return stats

    target_names = []
Review comment: I thought the
Reply: Yes but these are different. They also include the gradients. The variable name can be changed if you think that's a good idea. Something like
Reply: Name is fine but maybe add a comment that they are different.
    # Note: unlike the targets in `dataset_info`, these names also include one
    # entry per gradient of each target.
    for key, tensor_map in dataset[0].items():
        if key == "system":
            continue
        target_names.append(key)
        gradients_list = tensor_map.block(0).gradients_list()
        for gradient in gradients_list:
            target_names.append(f"{key}_{gradient}_gradients")

    sums = {key: 0.0 for key in target_names}
    n_elements = {key: 0 for key in target_names}
    for sample in dataset:
        for key in target_names:
            if "_gradients" not in key:  # not a gradient
                tensors = [block.values for block in sample[key].blocks()]
            else:
                original_key = key.split("_")[0]
                gradient_name = key.replace(f"{original_key}_", "").replace(
                    "_gradients", ""
                )
                tensors = [
                    block.gradient(gradient_name).values
                    for block in sample[original_key].blocks()
                ]
            sums[key] += sum(tensor.sum() for tensor in tensors)
            n_elements[key] += sum(tensor.numel() for tensor in tensors)
    means = {key: sums[key] / n_elements[key] for key in target_names}

    sum_of_squared_residuals = {key: 0.0 for key in target_names}
    for sample in dataset:
Review comment: Why do you sum twice over the dataset?
Reply: Two iterations: one for the mean, one for the std (for the std you already need to know the mean)
Reply: No, you don't. You save a sum and sum of squares and do the mean and the standard deviation afterwards.
Reply: You're right
        for key in target_names:
            if "_gradients" not in key:  # not a gradient
                tensors = [block.values for block in sample[key].blocks()]
            else:
                original_key = key.split("_")[0]
                gradient_name = key.replace(f"{original_key}_", "").replace(
                    "_gradients", ""
                )
                tensors = [
                    block.gradient(gradient_name).values
                    for block in sample[original_key].blocks()
                ]
            sum_of_squared_residuals[key] += sum(
                ((tensor - means[key]) ** 2).sum() for tensor in tensors
            )
    stds = {
        key: (sum_of_squared_residuals[key] / n_elements[key]) ** 0.5
        for key in target_names
    }

    stats += "\n Mean and standard deviation of targets:"
    for key in target_names:
        stats += (
            f"\n - {to_external_name(key, dataset_info.targets)}: "  # type: ignore
            f"mean={means[key]:.3e}, std={stds[key]:.3e}"
        )

    return stats
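
As the review thread above points out, the mean and standard deviation could also be accumulated in a single pass over the dataset by keeping a running sum and sum of squares (Var[x] = E[x^2] - E[x]^2). A self-contained sketch of that alternative, with a hypothetical helper name and plain torch tensors standing in for TensorMap blocks:

import torch

def single_pass_mean_std(tensors):
    # Accumulate the sum, sum of squares and element count in one loop,
    # then derive mean and std afterwards.
    total, total_sq, count = 0.0, 0.0, 0
    for t in tensors:
        total += float(t.sum())
        total_sq += float((t ** 2).sum())
        count += t.numel()
    mean = total / count
    std = (total_sq / count - mean ** 2) ** 0.5
    return mean, std

print(single_pass_mean_std([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]))
# approximately (2.5, 1.118)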


def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> Set[int]:
    """List of all atomic types present in a dataset or list of datasets.

@@ -344,15 +437,49 @@
    train_size: float,
    test_size: float,
    generator: Optional[Generator] = default_generator,
) -> List[Subset]:
) -> List[Dataset]:
    if train_size <= 0:
        raise ValueError("Fraction of the train set is smaller or equal to 0!")

    # normalize fractions
    lengths = torch.tensor([train_size, test_size])
    lengths /= lengths.sum()

    return random_split(dataset=train_dataset, lengths=lengths, generator=generator)
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(train_dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(train_dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset.",
                    UserWarning,
                    stacklevel=2,
                )

    # Cannot verify that train_dataset is Sized
    if sum(lengths) != len(train_dataset):  # type: ignore[arg-type]
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )

    indices = torch.randperm(sum(lengths), generator=generator).tolist()
    return [
        Subset(train_dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]
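
The floor-and-remainder logic above mirrors what `torch.utils.data.random_split` does for fractional lengths. Its effect can be shown in isolation with a small self-contained sketch (hypothetical helper name, not part of the diff):

import math

def split_lengths(n_items, fractions):
    # Floor each fraction of n_items, then hand out the remainder one item at a
    # time in round-robin order so that the lengths sum back to n_items.
    lengths = [int(math.floor(n_items * frac)) for frac in fractions]
    for i in range(n_items - sum(lengths)):
        lengths[i % len(lengths)] += 1
    return lengths

print(split_lengths(10, [0.8, 0.2]))  # [8, 2]
print(split_lengths(7, [0.5, 0.5]))   # [4, 3]: the leftover item goes to the first split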


def group_and_join(
Review comment: Perhaps not for this PR, but I think this is merged now so should allow for these prefixes
Reply: This is not yet released, but we should do a patch release with this