Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement get_stats for Dataset and print it before training #251

Merged
merged 12 commits into from
Jun 13, 2024
39 changes: 22 additions & 17 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,20 @@ def train_model(
)
validation_datasets.append(validation_dataset)

###########################
# CREATE DATASET_INFO #####
###########################

atomic_types = get_atomic_types(
train_datasets + train_datasets + validation_datasets
)

dataset_info = DatasetInfo(
length_unit=train_options_list[0]["systems"]["length_unit"],
atomic_types=atomic_types,
targets=target_infos,
)

###########################
# PRINT DATASET STATS #####
###########################
Expand All @@ -317,21 +331,26 @@ def train_model(
index = ""
else:
index = f" {i}"
logger.info(f"Training dataset{index} has size {len(train_dataset)}")
logger.info(
f"Training dataset{index}:\n {train_dataset.get_stats(dataset_info)}"
)

for i, validation_dataset in enumerate(validation_datasets):
if len(validation_datasets) == 1:
index = ""
else:
index = f" {i}"
logger.info(f"Validation dataset{index} has size {len(validation_dataset)}")
logger.info(
f"Validation dataset{index}:\n "
f"{validation_dataset.get_stats(dataset_info)}"
)

for i, test_dataset in enumerate(test_datasets):
if len(test_datasets) == 1:
index = ""
else:
index = f" {i}"
logger.info(f"Test dataset{index} has size {len(test_dataset)}")
logger.info(f"Test dataset{index}:\n {test_dataset.get_stats(dataset_info)}")

###########################
# SAVE EXPANDED OPTIONS ###
Expand All @@ -341,20 +360,6 @@ def train_model(
config=options, f=Path(checkpoint_dir) / "options_restart.yaml", resolve=True
)

###########################
# CREATE DATASET_INFO #####
###########################

atomic_types = get_atomic_types(
train_datasets + train_datasets + validation_datasets
)

dataset_info = DatasetInfo(
length_unit=train_options_list[0]["systems"]["length_unit"],
atomic_types=atomic_types,
targets=target_infos,
)

###########################
# SETTING UP MODEL ########
###########################
Expand Down
207 changes: 167 additions & 40 deletions src/metatrain/utils/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import itertools
import math
import warnings
from collections import UserDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple, Union
Expand All @@ -6,45 +9,8 @@
import torch
from metatensor.torch import TensorMap
from torch import Generator, default_generator
from torch.utils.data import Subset, random_split


class Dataset:
"""A version of the `metatensor.learn.Dataset` class that allows for
the use of `mtm::` prefixes in the keys of the dictionary. See
https://github.com/lab-cosmo/metatensor/issues/621.

It is important to note that, instead of named tuples, this class
accepts and returns dictionaries.

:param dict: A dictionary with the data to be stored in the dataset.
"""

def __init__(self, dict: Dict):

new_dict = {}
for key, value in dict.items():
key = key.replace("mtm::", "mtm_")
new_dict[key] = value

self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict)

def __getitem__(self, idx: int) -> Dict:

mts_dataset_item = self.mts_learn_dataset[idx]._asdict()
new_dict = {}
for key, value in mts_dataset_item.items():
key = key.replace("mtm_", "mtm::")
new_dict[key] = value

return new_dict

def __len__(self) -> int:
return len(self.mts_learn_dataset)

def __iter__(self):
for i in range(len(self)):
yield self[i]
from ..external_naming import to_external_name


@dataclass
Expand Down Expand Up @@ -230,6 +196,133 @@
return new


class Dataset:
    """A version of the `metatensor.learn.Dataset` class that allows for
    the use of `mtm::` prefixes in the keys of the dictionary. See
    https://github.com/lab-cosmo/metatensor/issues/621.

    It is important to note that, instead of named tuples, this class
    accepts and returns dictionaries.

    :param dict: A dictionary with the data to be stored in the dataset.
    """

    def __init__(self, dict: Dict):
        # NOTE: the parameter shadows the ``dict`` builtin; the name is kept
        # for backward compatibility with keyword callers.

        # `metatensor.learn.Dataset` exposes samples as named tuples, and
        # "mtm::" is not a valid identifier, so the prefix is mangled here
        # and restored in `__getitem__`.
        new_dict = {}
        for key, value in dict.items():
            key = key.replace("mtm::", "mtm_")
            new_dict[key] = value

        self.mts_learn_dataset = metatensor.learn.Dataset(**new_dict)

    def __getitem__(self, idx: int) -> Dict:
        """Return the sample at ``idx`` as a dictionary, with the original
        (un-mangled) ``mtm::`` target names restored."""
        mts_dataset_item = self.mts_learn_dataset[idx]._asdict()
        new_dict = {}
        for key, value in mts_dataset_item.items():
            key = key.replace("mtm_", "mtm::")
            new_dict[key] = value

        return new_dict

    def __len__(self) -> int:
        """Number of samples in the dataset."""
        return len(self.mts_learn_dataset)

    def __iter__(self):
        """Iterate over all samples in index order."""
        for i in range(len(self)):
            yield self[i]

    def get_stats(self, dataset_info: DatasetInfo) -> str:
        """Return a human-readable summary of this dataset (size and
        per-target mean/standard deviation).

        :param dataset_info: used to translate internal target names into
            their external, user-facing names.
        """
        # The statistics are cheap relative to training, so they are
        # recomputed on every call instead of being cached (the previous
        # ``_cached_stats`` attribute was flagged in review as an
        # over-optimization that also dodged test coverage).
        return _get_dataset_stats(self, dataset_info)


class Subset(torch.utils.data.Subset):
    """
    A version of `torch.utils.data.Subset` containing a `get_stats` method
    allowing us to print information about atomistic datasets.
    """

    def get_stats(self, dataset_info: DatasetInfo) -> str:
        """Return a human-readable summary of this subset (size and
        per-target mean/standard deviation).

        :param dataset_info: used to translate internal target names into
            their external, user-facing names.
        """
        # Recomputed on every call: the previous ``_cached_stats``
        # memoization was flagged in review as unnecessary (the
        # computation is cheap) and it required a ``type: ignore`` for
        # the dynamically-added attribute.
        return _get_dataset_stats(self, dataset_info)


def _get_dataset_stats(dataset: Union[Dataset, Subset], dataset_info: DatasetInfo):
    """Returns the statistics of a dataset or subset as a string.

    The summary contains the dataset size and, for each target (and each of
    its gradients), the mean and standard deviation over all elements.

    :param dataset: the dataset (or subset) to summarize; samples are
        dictionaries mapping a "system" key and target names to TensorMaps.
    :param dataset_info: used to translate internal target names into
        external, user-facing names.
    """

    dataset_len = len(dataset)
    stats = f"Dataset of size {dataset_len}"
    if dataset_len == 0:
        return stats

    # Build the list of quantities to summarize. Note that, unlike
    # ``dataset_info.targets``, this also contains one entry per gradient.
    # An explicit display-key -> (target key, gradient name) mapping is kept
    # so the display key never has to be parsed back apart: the previous
    # ``key.split("_")[0]`` approach broke for target names that themselves
    # contain an underscore.
    target_names = []
    sources = {}  # display key -> (original target key, gradient name or None)
    for key, tensor_map in dataset[0].items():
        if key == "system":
            continue
        target_names.append(key)
        sources[key] = (key, None)
        for gradient in tensor_map.block(0).gradients_list():
            gradient_key = f"{key}_{gradient}_gradients"
            target_names.append(gradient_key)
            sources[gradient_key] = (key, gradient)

    def _tensors(sample, key):
        # All block values (or gradient values) for ``key`` in one sample.
        target_key, gradient_name = sources[key]
        if gradient_name is None:
            return [block.values for block in sample[target_key].blocks()]
        return [
            block.gradient(gradient_name).values
            for block in sample[target_key].blocks()
        ]

    # First pass over the dataset: accumulate sums and element counts to
    # obtain the means.
    sums = {key: 0.0 for key in target_names}
    n_elements = {key: 0 for key in target_names}
    for sample in dataset:
        for key in target_names:
            tensors = _tensors(sample, key)
            sums[key] += sum(tensor.sum() for tensor in tensors)
            n_elements[key] += sum(tensor.numel() for tensor in tensors)
    means = {key: sums[key] / n_elements[key] for key in target_names}

    # Second pass: sum of squared residuals for the standard deviations.
    # Two passes are deliberate: the one-pass E[x^2] - E[x]^2 formula is
    # numerically less stable.
    sum_of_squared_residuals = {key: 0.0 for key in target_names}
    for sample in dataset:
        for key in target_names:
            sum_of_squared_residuals[key] += sum(
                ((tensor - means[key]) ** 2).sum()
                for tensor in _tensors(sample, key)
            )
    stds = {
        key: (sum_of_squared_residuals[key] / n_elements[key]) ** 0.5
        for key in target_names
    }

    stats += "\n Mean and standard deviation of targets:"
    for key in target_names:
        stats += (
            f"\n - {to_external_name(key, dataset_info.targets)}: "  # type: ignore
            f"mean={means[key]:.3e}, std={stds[key]:.3e}"
        )

    return stats


def get_atomic_types(datasets: Union[Dataset, List[Dataset]]) -> Set[int]:
"""List of all atomic types present in a dataset or list of datasets.

Expand Down Expand Up @@ -344,15 +437,49 @@
train_size: float,
test_size: float,
generator: Optional[Generator] = default_generator,
) -> List[Subset]:
) -> List[Dataset]:
if train_size <= 0:
raise ValueError("Fraction of the train set is smaller or equal to 0!")

# normalize fractions
lengths = torch.tensor([train_size, test_size])
lengths /= lengths.sum()

return random_split(dataset=train_dataset, lengths=lengths, generator=generator)
if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
subset_lengths: List[int] = []
for i, frac in enumerate(lengths):
if frac < 0 or frac > 1:
raise ValueError(f"Fraction at index {i} is not between 0 and 1")

Check warning on line 452 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L452

Added line #L452 was not covered by tests
n_items_in_split = int(
math.floor(len(train_dataset) * frac) # type: ignore[arg-type]
)
subset_lengths.append(n_items_in_split)
remainder = len(train_dataset) - sum(subset_lengths) # type: ignore[arg-type]
# add 1 to all the lengths in round-robin fashion until the remainder is 0
for i in range(remainder):
idx_to_add_at = i % len(subset_lengths)
subset_lengths[idx_to_add_at] += 1

Check warning on line 461 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L460-L461

Added lines #L460 - L461 were not covered by tests
lengths = subset_lengths
for i, length in enumerate(lengths):
if length == 0:
warnings.warn(
f"Length of split at index {i} is 0. "
f"This might result in an empty dataset.",
UserWarning,
stacklevel=2,
)

# Cannot verify that train_dataset is Sized
if sum(lengths) != len(train_dataset): # type: ignore[arg-type]
raise ValueError(

Check warning on line 474 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L474

Added line #L474 was not covered by tests
"Sum of input lengths does not equal the length of the input dataset!"
)

indices = torch.randperm(sum(lengths), generator=generator).tolist()
return [
Subset(train_dataset, indices[offset - length : offset])
for offset, length in zip(itertools.accumulate(lengths), lengths)
]


def group_and_join(
Expand Down
9 changes: 6 additions & 3 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,12 @@ def test_train(capfd, monkeypatch, tmp_path, output):
for logtext in [stdout_log, file_log]:
assert "This log is also available"
assert re.search(r"Random seed of this run is [1-9]\d*", logtext)
assert "Training dataset has size"
assert "Validation dataset has size"
assert "Test dataset has size"
assert "Training dataset:" in logtext
assert "Validation dataset:" in logtext
assert "Test dataset:" in logtext
assert "size 50" in logtext
assert "mean=" in logtext
assert "std=" in logtext
assert "[INFO]" in logtext
assert "Epoch" in logtext
assert "loss" in logtext
Expand Down
52 changes: 52 additions & 0 deletions tests/utils/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,3 +507,55 @@ def test_collate_fn():
assert isinstance(batch[0], tuple)
assert len(batch[0]) == 3
assert isinstance(batch[1], dict)


def test_get_stats():
    """Tests the get_stats method of Dataset and Subset."""

    def energy_conf(target_key, file_name, xyz_key):
        # Minimal single-energy-target reader configuration for one file.
        return {
            target_key: {
                "quantity": "energy",
                "read_from": str(RESOURCES_PATH / file_name),
                "file_format": ".xyz",
                "key": xyz_key,
                "unit": "eV",
                "forces": False,
                "stress": False,
                "virial": False,
            }
        }

    systems = read_systems(RESOURCES_PATH / "qm9_reduced_100.xyz")
    systems_2 = read_systems(RESOURCES_PATH / "ethanol_reduced_100.xyz")

    targets, _ = read_targets(
        OmegaConf.create(energy_conf("mtm::U0", "qm9_reduced_100.xyz", "U0"))
    )
    targets_2, _ = read_targets(
        OmegaConf.create(energy_conf("energy", "ethanol_reduced_100.xyz", "energy"))
    )

    dataset = Dataset({"system": systems, **targets})
    dataset_2 = Dataset({"system": systems_2, **targets_2})

    dataset_info = DatasetInfo(
        length_unit="angstrom",
        atomic_types={1, 6},
        targets={
            "mtm::U0": TargetInfo(quantity="energy", unit="eV"),
            "energy": TargetInfo(quantity="energy", unit="eV"),
        },
    )

    stats = dataset.get_stats(dataset_info)
    stats_2 = dataset_2.get_stats(dataset_info)

    assert "size 100" in stats
    assert "mtm::U0" in stats
    assert "energy" in stats_2
    assert "stress" not in stats_2
Loading