diff --git a/docs/src/dev-docs/utils/data/combine_dataloaders.rst b/docs/src/dev-docs/utils/data/combine_dataloaders.rst
index 0aff64b69..29db5ab7f 100644
--- a/docs/src/dev-docs/utils/data/combine_dataloaders.rst
+++ b/docs/src/dev-docs/utils/data/combine_dataloaders.rst
@@ -5,3 +5,5 @@ Combining dataloaders
     :members:
     :undoc-members:
     :show-inheritance:
+    :special-members:
+    :exclude-members: __init__, reset, __iter__, __next__
diff --git a/src/metatensor/models/experimental/alchemical_model/tests/test_regression.py b/src/metatensor/models/experimental/alchemical_model/tests/test_regression.py
index 87da84068..e44928ad6 100644
--- a/src/metatensor/models/experimental/alchemical_model/tests/test_regression.py
+++ b/src/metatensor/models/experimental/alchemical_model/tests/test_regression.py
@@ -134,7 +134,7 @@ def test_regression_train():
     )

     expected_output = torch.tensor(
-        [[-118.6454], [-106.1644], [-137.0310], [-164.7832], [-139.8678]]
+        [[-123.0245], [-109.3167], [-129.6946], [-160.1561], [-138.4090]]
     )

     torch.testing.assert_close(
diff --git a/src/metatensor/models/experimental/alchemical_model/train.py b/src/metatensor/models/experimental/alchemical_model/train.py
index eac003976..b3364b667 100644
--- a/src/metatensor/models/experimental/alchemical_model/train.py
+++ b/src/metatensor/models/experimental/alchemical_model/train.py
@@ -10,10 +10,10 @@
 from ...utils.composition import calculate_composition_weights
 from ...utils.data import (
+    CombinedDataLoader,
     DatasetInfo,
     check_datasets,
     collate_fn,
-    combine_dataloaders,
     get_all_species,
     get_all_targets,
 )

@@ -163,7 +163,7 @@ def train(
                 collate_fn=collate_fn,
             )
         )
-    train_dataloader = combine_dataloaders(train_dataloaders, shuffle=True)
+    train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)

     # Create dataloader for the validation datasets:
     validation_dataloaders = []
@@ -176,20 +176,7 @@ def train(
                 collate_fn=collate_fn,
             )
         )
-    validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False)
-
-    # Create dataloader for the validation datasets:
-    validation_dataloaders = []
-    for dataset in validation_datasets:
-        validation_dataloaders.append(
-            DataLoader(
-                dataset=dataset,
-                batch_size=hypers_training["batch_size"],
-                shuffle=False,
-                collate_fn=collate_fn,
-            )
-        )
-    validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False)
+    validation_dataloader = CombinedDataLoader(validation_dataloaders, shuffle=False)

     # Extract all the possible outputs and their gradients from the training set:
     outputs_dict = get_outputs_dict(train_datasets)
diff --git a/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py b/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py
index 35ade310f..9a24f2330 100644
--- a/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py
+++ b/src/metatensor/models/experimental/soap_bpnn/tests/test_regression.py
@@ -42,7 +42,9 @@ def test_regression_init():
         [systems_to_torch(system) for system in systems],
         {"U0": soap_bpnn.capabilities.outputs["U0"]},
     )
-    expected_output = torch.tensor([[0.3964], [0.0813], [0.0491], [0.2726], [0.4292]])
+    expected_output = torch.tensor(
+        [[-0.0840], [0.0352], [0.0389], [-0.3115], [-0.1372]]
+    )
     torch.testing.assert_close(
         output["U0"].block().values, expected_output, rtol=1e-3, atol=1e-08
     )
@@ -87,7 +89,7 @@ def test_regression_train():
     output = soap_bpnn(systems[:5], {"U0": soap_bpnn.capabilities.outputs["U0"]})

     expected_output = torch.tensor(
-        [[-40.4551], [-56.5427], [-76.3641], [-77.3653], [-93.4208]]
+        [[-40.4913], [-56.5962], [-76.5165], [-77.3447], [-93.4256]]
     )

     torch.testing.assert_close(
diff --git a/src/metatensor/models/experimental/soap_bpnn/train.py b/src/metatensor/models/experimental/soap_bpnn/train.py
index 02556da4a..ec11d0143 100644
--- a/src/metatensor/models/experimental/soap_bpnn/train.py
+++ b/src/metatensor/models/experimental/soap_bpnn/train.py
@@ -11,10 +11,10 @@
 from ...utils.composition import calculate_composition_weights
 from ...utils.data import (
+    CombinedDataLoader,
     DatasetInfo,
     check_datasets,
     collate_fn,
-    combine_dataloaders,
     get_all_species,
     get_all_targets,
 )

@@ -173,7 +173,7 @@ def train(
                 collate_fn=collate_fn,
             )
         )
-    train_dataloader = combine_dataloaders(train_dataloaders, shuffle=True)
+    train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)

     # Create dataloader for the validation datasets:
     validation_dataloaders = []
@@ -186,7 +186,7 @@ def train(
                 collate_fn=collate_fn,
             )
         )
-    validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False)
+    validation_dataloader = CombinedDataLoader(validation_dataloaders, shuffle=False)

     # Extract all the possible outputs and their gradients from the training set:
     outputs_dict = get_outputs_dict(train_datasets)
diff --git a/src/metatensor/models/utils/data/__init__.py b/src/metatensor/models/utils/data/__init__.py
index eb883ee5a..aff18af9b 100644
--- a/src/metatensor/models/utils/data/__init__.py
+++ b/src/metatensor/models/utils/data/__init__.py
@@ -16,5 +16,5 @@
 )
 from .writers import write_predictions  # noqa: F401

-from .combine_dataloaders import combine_dataloaders  # noqa: F401
+from .combine_dataloaders import CombinedDataLoader  # noqa: F401
 from .system_to_ase import system_to_ase  # noqa: F401
diff --git a/src/metatensor/models/utils/data/combine_dataloaders.py b/src/metatensor/models/utils/data/combine_dataloaders.py
index 88bf3f15a..8e839b039 100644
--- a/src/metatensor/models/utils/data/combine_dataloaders.py
+++ b/src/metatensor/models/utils/data/combine_dataloaders.py
@@ -1,57 +1,60 @@
-import itertools
 from typing import List

 import numpy as np
 import torch


-class CombinedIterableDataset(torch.utils.data.IterableDataset):
+class CombinedDataLoader:
     """
-    Combines multiple dataloaders into a single iterable dataset.
-    This is useful for combining multiple dataloaders into a single
-    dataloader. The new dataloader can be shuffled or not.
+    Combines multiple dataloaders into a single dataloader.
+
+    This is useful for learning from multiple datasets at the same time,
+    each of which may have different batch sizes, properties, etc.
+
     :param dataloaders: list of dataloaders to combine
-    :param shuffle: whether to shuffle the combined dataloader
+    :param shuffle: whether to shuffle the combined dataloader (this does not
+        act on the individual batches, but it shuffles the order in which
+        they are returned)

-    :return: combined dataloader
+    :return: the combined dataloader
     """

-    def __init__(self, dataloaders, shuffle):
+    def __init__(self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool):
         self.dataloaders = dataloaders
         self.shuffle = shuffle

         # Create the indices:
-        indices = [
-            (i, dl_idx)
-            for dl_idx, dl in enumerate(self.dataloaders)
-            for i in range(len(dl))
-        ]
+        self.indices = list(range(len(self)))

         # Shuffle the indices if requested
         if self.shuffle:
-            np.random.shuffle(indices)
+            np.random.shuffle(self.indices)
+
+        self.reset()

-        self.indices = indices
+    def reset(self):
+        self.current_index = 0
+        self.full_list = [batch for dl in self.dataloaders for batch in dl]

     def __iter__(self):
-        for idx, dataloader_idx in self.indices:
-            yield next(itertools.islice(self.dataloaders[dataloader_idx], idx, None))
+        return self

-    def __len__(self):
-        return len(self.indices)
+    def __next__(self):
+        if self.current_index >= len(self.indices):
+            self.reset()  # Reset the index for the next iteration
+            raise StopIteration

+        idx = self.indices[self.current_index]
+        self.current_index += 1
+        return self.full_list[idx]

-def combine_dataloaders(
-    dataloaders: List[torch.utils.data.DataLoader], shuffle: bool = True
-):
-    """
-    Combines multiple dataloaders into a single dataloader.
+    def __len__(self):
+        """Returns the total number of batches in all dataloaders.

-    :param dataloaders: list of dataloaders to combine
-    :param shuffle: whether to shuffle the combined dataloader
+        This returns the total number of batches in all dataloaders
+        (as opposed to the total number of samples or the number of
+        individual dataloaders).
+
-    :return: combined dataloader
-    """
-    combined_dataset = CombinedIterableDataset(dataloaders, shuffle)
-    return torch.utils.data.DataLoader(combined_dataset, batch_size=None)
+        :return: the total number of batches in all dataloaders
+        """
+        return sum(len(dl) for dl in self.dataloaders)
diff --git a/tests/utils/data/test_combine_dataloaders.py b/tests/utils/data/test_combine_dataloaders.py
index 278bdf243..5676ceb49 100644
--- a/tests/utils/data/test_combine_dataloaders.py
+++ b/tests/utils/data/test_combine_dataloaders.py
@@ -5,8 +5,8 @@
 from omegaconf import OmegaConf

 from metatensor.models.utils.data import (
+    CombinedDataLoader,
     collate_fn,
-    combine_dataloaders,
     read_systems,
     read_targets,
 )
@@ -56,7 +56,7 @@ def test_without_shuffling():
     dataloader_alchemical = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

     # will yield 5 batches of 2
-    combined_dataloader = combine_dataloaders(
+    combined_dataloader = CombinedDataLoader(
         [dataloader_qm9, dataloader_alchemical], shuffle=False
     )

@@ -88,7 +88,9 @@
     }
     targets = read_targets(OmegaConf.create(conf))
     dataset = Dataset(system=systems, U0=targets["U0"])
-    dataloader_qm9 = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
+    dataloader_qm9 = DataLoader(
+        dataset, batch_size=10, collate_fn=collate_fn, shuffle=True
+    )
     # will yield 10 batches of 10

     systems = read_systems(RESOURCES_PATH / "alchemical_reduced_10.xyz")
@@ -106,10 +108,12 @@
     }
     targets = read_targets(OmegaConf.create(conf))
     dataset = Dataset(system=systems, free_energy=targets["free_energy"])
-    dataloader_alchemical = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
+    dataloader_alchemical = DataLoader(
+        dataset, batch_size=2, collate_fn=collate_fn, shuffle=True
+    )

     # will yield 5 batches of 2
-    combined_dataloader = combine_dataloaders(
+    combined_dataloader = CombinedDataLoader(
         [dataloader_qm9, dataloader_alchemical], shuffle=True
     )

@@ -119,17 +123,30 @@
     qm9_batch_count = 0
     alchemical_batch_count = 0
     original_ordering = ["qm9"] * 10 + ["alchemical"] * 5
     actual_ordering = []
+    qm9_samples = []
+    alchemical_samples = []
     for batch in combined_dataloader:
         if "U0" in batch[1]:
             qm9_batch_count += 1
             assert batch[1]["U0"].block().values.shape == (10, 1)
             actual_ordering.append("qm9")
+            qm9_samples.append(batch[1]["U0"].block().samples.column("system"))
         else:
             alchemical_batch_count += 1
             assert batch[1]["free_energy"].block().values.shape == (2, 1)
             actual_ordering.append("alchemical")
+            alchemical_samples.append(
+                batch[1]["free_energy"].block().samples.column("system")
+            )

     assert qm9_batch_count == 10
     assert alchemical_batch_count == 5
     assert actual_ordering != original_ordering
+
+    qm9_samples = [int(item) for sublist in qm9_samples for item in sublist]
+    alchemical_samples = [
+        int(item) for sublist in alchemical_samples for item in sublist
+    ]
+    assert set(qm9_samples) == set(range(100))
+    assert set(alchemical_samples) == set(range(10))
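
For reference, a minimal usage sketch of the new CombinedDataLoader outside the test suite. The plain torch TensorDatasets below are illustrative stand-ins for the chemistry datasets used in the tests (not part of this changeset), and the snippet assumes metatensor-models is installed so the import resolves:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from metatensor.models.utils.data import CombinedDataLoader

    # Two dataloaders with different batch sizes, mirroring the tests above:
    # 100 samples in batches of 10, and 10 samples in batches of 2.
    loader_a = DataLoader(TensorDataset(torch.arange(100.0)), batch_size=10)
    loader_b = DataLoader(TensorDataset(torch.arange(10.0)), batch_size=2)

    combined = CombinedDataLoader([loader_a, loader_b], shuffle=True)

    # __len__ counts batches across all dataloaders, not samples.
    assert len(combined) == 15

    # Each batch keeps the size of its source dataloader; shuffle=True only
    # randomizes the order in which whole batches are yielded.
    for (values,) in combined:
        assert values.shape[0] in (10, 2)

Because every batch is pre-collated by its own inner dataloader, samples from different datasets are never mixed within a batch; shuffling acts purely on the iteration order over batches, which is what test_with_shuffling checks.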