From 5d09d14f848eedf48eb7751de0c3471ec07903ff Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 1 Apr 2024 20:30:55 +0200 Subject: [PATCH 1/7] Fix combined dataloader bug --- .../experimental/alchemical_model/train.py | 19 +------ .../models/experimental/soap_bpnn/train.py | 6 +- src/metatensor/models/utils/data/__init__.py | 2 +- .../models/utils/data/combine_dataloaders.py | 57 +++++++------------ tests/utils/data/test_combine_dataloaders.py | 27 +++++++-- 5 files changed, 48 insertions(+), 63 deletions(-) diff --git a/src/metatensor/models/experimental/alchemical_model/train.py b/src/metatensor/models/experimental/alchemical_model/train.py index d8d515439..0c529ad39 100644 --- a/src/metatensor/models/experimental/alchemical_model/train.py +++ b/src/metatensor/models/experimental/alchemical_model/train.py @@ -10,10 +10,10 @@ from ...utils.composition import calculate_composition_weights from ...utils.data import ( + CombinedDataLoader, DatasetInfo, check_datasets, collate_fn, - combine_dataloaders, get_all_species, get_all_targets, ) @@ -163,7 +163,7 @@ def train( collate_fn=collate_fn, ) ) - train_dataloader = combine_dataloaders(train_dataloaders, shuffle=True) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) # Create dataloader for the validation datasets: validation_dataloaders = [] @@ -176,20 +176,7 @@ def train( collate_fn=collate_fn, ) ) - validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False) - - # Create dataloader for the validation datasets: - validation_dataloaders = [] - for dataset in validation_datasets: - validation_dataloaders.append( - DataLoader( - dataset=dataset, - batch_size=hypers_training["batch_size"], - shuffle=False, - collate_fn=collate_fn, - ) - ) - validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False) + validation_dataloader = CombinedDataLoader(validation_dataloaders, shuffle=False) # Extract all the possible outputs and their gradients from the training set: outputs_dict = get_outputs_dict(train_datasets) diff --git a/src/metatensor/models/experimental/soap_bpnn/train.py b/src/metatensor/models/experimental/soap_bpnn/train.py index 51d6ed935..bdb9a8015 100644 --- a/src/metatensor/models/experimental/soap_bpnn/train.py +++ b/src/metatensor/models/experimental/soap_bpnn/train.py @@ -11,10 +11,10 @@ from ...utils.composition import calculate_composition_weights from ...utils.data import ( + CombinedDataLoader, DatasetInfo, check_datasets, collate_fn, - combine_dataloaders, get_all_species, get_all_targets, ) @@ -173,7 +173,7 @@ def train( collate_fn=collate_fn, ) ) - train_dataloader = combine_dataloaders(train_dataloaders, shuffle=True) + train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True) # Create dataloader for the validation datasets: validation_dataloaders = [] @@ -186,7 +186,7 @@ def train( collate_fn=collate_fn, ) ) - validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False) + validation_dataloader = CombinedDataLoader(validation_dataloaders, shuffle=False) # Extract all the possible outputs and their gradients from the training set: outputs_dict = get_outputs_dict(train_datasets) diff --git a/src/metatensor/models/utils/data/__init__.py b/src/metatensor/models/utils/data/__init__.py index eb883ee5a..aff18af9b 100644 --- a/src/metatensor/models/utils/data/__init__.py +++ b/src/metatensor/models/utils/data/__init__.py @@ -16,5 +16,5 @@ ) from .writers import write_predictions # noqa: F401 -from .combine_dataloaders import combine_dataloaders # noqa: F401 +from .combine_dataloaders import CombinedDataLoader # noqa: F401 from .system_to_ase import system_to_ase # noqa: F401 diff --git a/src/metatensor/models/utils/data/combine_dataloaders.py b/src/metatensor/models/utils/data/combine_dataloaders.py index 88bf3f15a..6259aee0a 100644 --- a/src/metatensor/models/utils/data/combine_dataloaders.py +++ b/src/metatensor/models/utils/data/combine_dataloaders.py @@ -1,57 +1,38 @@ -import itertools from typing import List import numpy as np import torch -class CombinedIterableDataset(torch.utils.data.IterableDataset): - """ - Combines multiple dataloaders into a single iterable dataset. - This is useful for combining multiple dataloaders into a single - dataloader. The new dataloader can be shuffled or not. - - :param dataloaders: list of dataloaders to combine - :param shuffle: whether to shuffle the combined dataloader - - :return: combined dataloader - """ - - def __init__(self, dataloaders, shuffle): +class CombinedDataLoader: + def __init__(self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool): self.dataloaders = dataloaders self.shuffle = shuffle # Create the indices: - indices = [ - (i, dl_idx) - for dl_idx, dl in enumerate(self.dataloaders) - for i in range(len(dl)) - ] + self.indices = list(range(len(self))) # Shuffle the indices if requested if self.shuffle: - np.random.shuffle(indices) - - self.indices = indices + np.random.shuffle(self.indices) - def __iter__(self): - for idx, dataloader_idx in self.indices: - yield next(itertools.islice(self.dataloaders[dataloader_idx], idx, None)) + self.reset() - def __len__(self): - return len(self.indices) + def reset(self): + self.current_index = 0 + self.full_list = [batch for dl in self.dataloaders for batch in dl] + def __iter__(self): + return self -def combine_dataloaders( - dataloaders: List[torch.utils.data.DataLoader], shuffle: bool = True -): - """ - Combines multiple dataloaders into a single dataloader. + def __next__(self): + if self.current_index >= len(self.indices): + self.reset() # Reset the index for the next iteration + raise StopIteration - :param dataloaders: list of dataloaders to combine - :param shuffle: whether to shuffle the combined dataloader + idx = self.indices[self.current_index] + self.current_index += 1 + return self.full_list[idx] - :return: combined dataloader - """ - combined_dataset = CombinedIterableDataset(dataloaders, shuffle) - return torch.utils.data.DataLoader(combined_dataset, batch_size=None) + def __len__(self): + return sum(len(dl) for dl in self.dataloaders) diff --git a/tests/utils/data/test_combine_dataloaders.py b/tests/utils/data/test_combine_dataloaders.py index 278bdf243..5676ceb49 100644 --- a/tests/utils/data/test_combine_dataloaders.py +++ b/tests/utils/data/test_combine_dataloaders.py @@ -5,8 +5,8 @@ from omegaconf import OmegaConf from metatensor.models.utils.data import ( + CombinedDataLoader, collate_fn, - combine_dataloaders, read_systems, read_targets, ) @@ -56,7 +56,7 @@ def test_without_shuffling(): dataloader_alchemical = DataLoader(dataset, batch_size=2, collate_fn=collate_fn) # will yield 5 batches of 2 - combined_dataloader = combine_dataloaders( + combined_dataloader = CombinedDataLoader( [dataloader_qm9, dataloader_alchemical], shuffle=False ) @@ -88,7 +88,9 @@ def test_with_shuffling(): } targets = read_targets(OmegaConf.create(conf)) dataset = Dataset(system=systems, U0=targets["U0"]) - dataloader_qm9 = DataLoader(dataset, batch_size=10, collate_fn=collate_fn) + dataloader_qm9 = DataLoader( + dataset, batch_size=10, collate_fn=collate_fn, shuffle=True + ) # will yield 10 batches of 10 systems = read_systems(RESOURCES_PATH / "alchemical_reduced_10.xyz") @@ -106,10 +108,12 @@ def test_with_shuffling(): } targets = read_targets(OmegaConf.create(conf)) dataset = Dataset(system=systems, free_energy=targets["free_energy"]) - dataloader_alchemical = DataLoader(dataset, batch_size=2, collate_fn=collate_fn) + dataloader_alchemical = DataLoader( + dataset, batch_size=2, collate_fn=collate_fn, shuffle=True + ) # will yield 5 batches of 2 - combined_dataloader = combine_dataloaders( + combined_dataloader = CombinedDataLoader( [dataloader_qm9, dataloader_alchemical], shuffle=True ) @@ -119,17 +123,30 @@ def test_with_shuffling(): alchemical_batch_count = 0 original_ordering = ["qm9"] * 10 + ["alchemical"] * 5 actual_ordering = [] + qm9_samples = [] + alchemical_samples = [] for batch in combined_dataloader: if "U0" in batch[1]: qm9_batch_count += 1 assert batch[1]["U0"].block().values.shape == (10, 1) actual_ordering.append("qm9") + qm9_samples.append(batch[1]["U0"].block().samples.column("system")) else: alchemical_batch_count += 1 assert batch[1]["free_energy"].block().values.shape == (2, 1) actual_ordering.append("alchemical") + alchemical_samples.append( + batch[1]["free_energy"].block().samples.column("system") + ) assert qm9_batch_count == 10 assert alchemical_batch_count == 5 assert actual_ordering != original_ordering + + qm9_samples = [int(item) for sublist in qm9_samples for item in sublist] + alchemical_samples = [ + int(item) for sublist in alchemical_samples for item in sublist + ] + assert set(qm9_samples) == set(range(100)) + assert set(alchemical_samples) == set(range(10)) From 9feb92cd688ad2c8f58288ce80e11b62dd1ac41f Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 2 Apr 2024 09:49:14 +0200 Subject: [PATCH 2/7] Add docstring --- .../models/utils/data/combine_dataloaders.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/metatensor/models/utils/data/combine_dataloaders.py b/src/metatensor/models/utils/data/combine_dataloaders.py index 6259aee0a..96082dd6c 100644 --- a/src/metatensor/models/utils/data/combine_dataloaders.py +++ b/src/metatensor/models/utils/data/combine_dataloaders.py @@ -5,7 +5,24 @@ class CombinedDataLoader: + """ + Combines multiple dataloaders into a single dataloader. + + This is useful for learning from multiple datasets at the same time, + each of which may have different batch sizes, properties, etc. + """ + def __init__(self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool): + """Creates the combined dataloader. + + :param dataloaders: list of dataloaders to combine + :param shuffle: whether to shuffle the combined dataloader (this does not + act on the individual batches, but it shuffles the order in which + they are returned) + + :return: the combined dataloader + """ + self.dataloaders = dataloaders self.shuffle = shuffle From 02b145a0d06a71f4812b35b002a51f2e9706f487 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Tue, 2 Apr 2024 09:51:23 +0200 Subject: [PATCH 3/7] Add comment for __len__ --- src/metatensor/models/utils/data/combine_dataloaders.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/metatensor/models/utils/data/combine_dataloaders.py b/src/metatensor/models/utils/data/combine_dataloaders.py index 96082dd6c..3daa43c85 100644 --- a/src/metatensor/models/utils/data/combine_dataloaders.py +++ b/src/metatensor/models/utils/data/combine_dataloaders.py @@ -52,4 +52,7 @@ def __next__(self): return self.full_list[idx] def __len__(self): + # this returns the total number of batches in all dataloaders + # (as opposed to the total number of samples or the number of + # individual dataloaders) return sum(len(dl) for dl in self.dataloaders) From 5c9bf268081f89cf26b3d1327117f3037f919219 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 4 Apr 2024 13:39:04 +0200 Subject: [PATCH 4/7] Debug regression test --- .../models/experimental/soap_bpnn/tests/test_regression.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py b/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py index 35ade310f..61ceacaf5 100644 --- a/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py @@ -44,6 +44,7 @@ def test_regression_init(): ) expected_output = torch.tensor([[0.3964], [0.0813], [0.0491], [0.2726], [0.4292]]) + print(output["U0"].block().values) torch.testing.assert_close( output["U0"].block().values, expected_output, rtol=1e-3, atol=1e-08 ) @@ -90,6 +91,7 @@ def test_regression_train(): [[-40.4551], [-56.5427], [-76.3641], [-77.3653], [-93.4208]] ) + print(output["U0"].block().values) torch.testing.assert_close( output["U0"].block().values, expected_output, rtol=1e-3, atol=1e-08 ) From d61d2d4e78779fc71052a1088a9176e3a9802ba5 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 4 Apr 2024 15:44:54 +0200 Subject: [PATCH 5/7] Fix regression tests --- .../alchemical_model/tests/test_regression.py | 3 +-- .../experimental/soap_bpnn/tests/test_regression.py | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/metatensor/models/experimental/alchemical_model/tests/test_regression.py b/src/metatensor/models/experimental/alchemical_model/tests/test_regression.py index 90edd6b6f..e44928ad6 100644 --- a/src/metatensor/models/experimental/alchemical_model/tests/test_regression.py +++ b/src/metatensor/models/experimental/alchemical_model/tests/test_regression.py @@ -134,10 +134,9 @@ def test_regression_train(): ) expected_output = torch.tensor( - [[-118.6454], [-106.1644], [-137.0310], [-164.7832], [-139.8678]] + [[-123.0245], [-109.3167], [-129.6946], [-160.1561], [-138.4090]] ) - print(output["U0"].block().values) torch.testing.assert_close( output["U0"].block().values, expected_output, rtol=1e-05, atol=1e-4 ) diff --git a/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py b/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py index 61ceacaf5..9a24f2330 100644 --- a/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py +++ b/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py @@ -42,9 +42,10 @@ def test_regression_init(): [systems_to_torch(system) for system in systems], {"U0": soap_bpnn.capabilities.outputs["U0"]}, ) - expected_output = torch.tensor([[0.3964], [0.0813], [0.0491], [0.2726], [0.4292]]) + expected_output = torch.tensor( + [[-0.0840], [0.0352], [0.0389], [-0.3115], [-0.1372]] + ) - print(output["U0"].block().values) torch.testing.assert_close( output["U0"].block().values, expected_output, rtol=1e-3, atol=1e-08 ) @@ -88,10 +89,9 @@ def test_regression_train(): output = soap_bpnn(systems[:5], {"U0": soap_bpnn.capabilities.outputs["U0"]}) expected_output = torch.tensor( - [[-40.4551], [-56.5427], [-76.3641], [-77.3653], [-93.4208]] + [[-40.4913], [-56.5962], [-76.5165], [-77.3447], [-93.4256]] ) - print(output["U0"].block().values) torch.testing.assert_close( output["U0"].block().values, expected_output, rtol=1e-3, atol=1e-08 ) From 4be826e036f54340e454ce3e5bb5dfd5184cb130 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 4 Apr 2024 16:38:50 +0200 Subject: [PATCH 6/7] Suggestions from code review --- .../models/utils/data/combine_dataloaders.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/metatensor/models/utils/data/combine_dataloaders.py b/src/metatensor/models/utils/data/combine_dataloaders.py index 3daa43c85..8e839b039 100644 --- a/src/metatensor/models/utils/data/combine_dataloaders.py +++ b/src/metatensor/models/utils/data/combine_dataloaders.py @@ -10,19 +10,16 @@ class CombinedDataLoader: This is useful for learning from multiple datasets at the same time, each of which may have different batch sizes, properties, etc. - """ - - def __init__(self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool): - """Creates the combined dataloader. - :param dataloaders: list of dataloaders to combine - :param shuffle: whether to shuffle the combined dataloader (this does not - act on the individual batches, but it shuffles the order in which - they are returned) + :param dataloaders: list of dataloaders to combine + :param shuffle: whether to shuffle the combined dataloader (this does not + act on the individual batches, but it shuffles the order in which + they are returned) - :return: the combined dataloader - """ + :return: the combined dataloader + """ + def __init__(self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool): self.dataloaders = dataloaders self.shuffle = shuffle @@ -52,7 +49,12 @@ def __next__(self): return self.full_list[idx] def __len__(self): - # this returns the total number of batches in all dataloaders - # (as opposed to the total number of samples or the number of - # individual dataloaders) + """Returns the total number of batches in all dataloaders. + + This returns the total number of batches in all dataloaders + (as opposed to the total number of samples or the number of + individual dataloaders). + + :return: the total number of batches in all dataloaders + """ return sum(len(dl) for dl in self.dataloaders) From 2c0340e824a315259babf54b68438cfa9c43cc4e Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 4 Apr 2024 17:17:01 +0200 Subject: [PATCH 7/7] Show __len__ in docs --- docs/src/dev-docs/utils/data/combine_dataloaders.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/dev-docs/utils/data/combine_dataloaders.rst b/docs/src/dev-docs/utils/data/combine_dataloaders.rst index 0aff64b69..29db5ab7f 100644 --- a/docs/src/dev-docs/utils/data/combine_dataloaders.rst +++ b/docs/src/dev-docs/utils/data/combine_dataloaders.rst @@ -5,3 +5,5 @@ Combining dataloaders :members: :undoc-members: :show-inheritance: + :special-members: + :exclude-members: __init__, reset, __iter__, __next__