Fix combined dataloader bug #163

Merged: 10 commits, Apr 4, 2024
2 changes: 2 additions & 0 deletions docs/src/dev-docs/utils/data/combine_dataloaders.rst
@@ -5,3 +5,5 @@ Combining dataloaders
    :members:
    :undoc-members:
    :show-inheritance:
+   :special-members:
+   :exclude-members: __init__, reset, __iter__, __next__
@@ -134,7 +134,7 @@ def test_regression_train():
     )

     expected_output = torch.tensor(
-        [[-118.6454], [-106.1644], [-137.0310], [-164.7832], [-139.8678]]
+        [[-123.0245], [-109.3167], [-129.6946], [-160.1561], [-138.4090]]
     )

     torch.testing.assert_close(
19 changes: 3 additions & 16 deletions src/metatensor/models/experimental/alchemical_model/train.py
@@ -10,10 +10,10 @@

 from ...utils.composition import calculate_composition_weights
 from ...utils.data import (
+    CombinedDataLoader,
     DatasetInfo,
     check_datasets,
     collate_fn,
-    combine_dataloaders,
     get_all_species,
     get_all_targets,
 )
@@ -163,7 +163,7 @@ def train(
                 collate_fn=collate_fn,
             )
         )
-    train_dataloader = combine_dataloaders(train_dataloaders, shuffle=True)
+    train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)

     # Create dataloader for the validation datasets:
     validation_dataloaders = []
@@ -176,20 +176,7 @@
                 collate_fn=collate_fn,
             )
         )
-    validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False)
-
-    # Create dataloader for the validation datasets:
-    validation_dataloaders = []
-    for dataset in validation_datasets:
-        validation_dataloaders.append(
-            DataLoader(
-                dataset=dataset,
-                batch_size=hypers_training["batch_size"],
-                shuffle=False,
-                collate_fn=collate_fn,
-            )
-        )
-    validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False)
+    validation_dataloader = CombinedDataLoader(validation_dataloaders, shuffle=False)

     # Extract all the possible outputs and their gradients from the training set:
     outputs_dict = get_outputs_dict(train_datasets)
@@ -42,7 +42,9 @@ def test_regression_init():
         [systems_to_torch(system) for system in systems],
         {"U0": soap_bpnn.capabilities.outputs["U0"]},
     )
-    expected_output = torch.tensor([[0.3964], [0.0813], [0.0491], [0.2726], [0.4292]])
+    expected_output = torch.tensor(
+        [[-0.0840], [0.0352], [0.0389], [-0.3115], [-0.1372]]
+    )

     torch.testing.assert_close(
         output["U0"].block().values, expected_output, rtol=1e-3, atol=1e-08
@@ -87,7 +89,7 @@ def test_regression_train():
     output = soap_bpnn(systems[:5], {"U0": soap_bpnn.capabilities.outputs["U0"]})

     expected_output = torch.tensor(
-        [[-40.4551], [-56.5427], [-76.3641], [-77.3653], [-93.4208]]
+        [[-40.4913], [-56.5962], [-76.5165], [-77.3447], [-93.4256]]
     )

     torch.testing.assert_close(
6 changes: 3 additions & 3 deletions src/metatensor/models/experimental/soap_bpnn/train.py
@@ -11,10 +11,10 @@

 from ...utils.composition import calculate_composition_weights
 from ...utils.data import (
+    CombinedDataLoader,
     DatasetInfo,
     check_datasets,
     collate_fn,
-    combine_dataloaders,
     get_all_species,
     get_all_targets,
 )
@@ -173,7 +173,7 @@ def train(
                 collate_fn=collate_fn,
             )
         )
-    train_dataloader = combine_dataloaders(train_dataloaders, shuffle=True)
+    train_dataloader = CombinedDataLoader(train_dataloaders, shuffle=True)

     # Create dataloader for the validation datasets:
     validation_dataloaders = []
@@ -186,7 +186,7 @@
                 collate_fn=collate_fn,
             )
         )
-    validation_dataloader = combine_dataloaders(validation_dataloaders, shuffle=False)
+    validation_dataloader = CombinedDataLoader(validation_dataloaders, shuffle=False)

     # Extract all the possible outputs and their gradients from the training set:
     outputs_dict = get_outputs_dict(train_datasets)
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/data/__init__.py
@@ -16,5 +16,5 @@
 )

 from .writers import write_predictions  # noqa: F401
-from .combine_dataloaders import combine_dataloaders  # noqa: F401
+from .combine_dataloaders import CombinedDataLoader  # noqa: F401
 from .system_to_ase import system_to_ase  # noqa: F401
63 changes: 33 additions & 30 deletions src/metatensor/models/utils/data/combine_dataloaders.py
@@ -1,57 +1,60 @@
-import itertools
 from typing import List

 import numpy as np
 import torch


-class CombinedIterableDataset(torch.utils.data.IterableDataset):
+class CombinedDataLoader:
     """
-    Combines multiple dataloaders into a single iterable dataset.
-    This is useful for combining multiple dataloaders into a single
-    dataloader. The new dataloader can be shuffled or not.
+    Combines multiple dataloaders into a single dataloader.
+
+    This is useful for learning from multiple datasets at the same time,
+    each of which may have different batch sizes, properties, etc.

     :param dataloaders: list of dataloaders to combine
-    :param shuffle: whether to shuffle the combined dataloader
+    :param shuffle: whether to shuffle the combined dataloader (this does not
+        act on the individual batches, but it shuffles the order in which
+        they are returned)

-    :return: combined dataloader
+    :return: the combined dataloader
     """

-    def __init__(self, dataloaders, shuffle):
+    def __init__(self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool):
         self.dataloaders = dataloaders
         self.shuffle = shuffle

         # Create the indices:
-        indices = [
-            (i, dl_idx)
-            for dl_idx, dl in enumerate(self.dataloaders)
-            for i in range(len(dl))
-        ]
+        self.indices = list(range(len(self)))

         # Shuffle the indices if requested
         if self.shuffle:
-            np.random.shuffle(indices)
+            np.random.shuffle(self.indices)

-        self.indices = indices
+        self.reset()
+
+    def reset(self):
+        self.current_index = 0
+        self.full_list = [batch for dl in self.dataloaders for batch in dl]

     def __iter__(self):
-        for idx, dataloader_idx in self.indices:
-            yield next(itertools.islice(self.dataloaders[dataloader_idx], idx, None))
+        return self
Review thread on the __iter__ change:

Contributor: Is this working? Shouldn't you return the next dataloader here instead of the full instance?

frostedoyster (Collaborator, Author), Apr 2, 2024: It does work. It comes from ChatGPT, but intuitively it makes sense: by iterating, you effectively call iterable = iter(dataloader) and then next(iterable) (a bunch of times). __next__ is defined in the class, so it makes sense to return self.

Contributor: Okay, then please add a test for the iterator.

frostedoyster (Collaborator, Author): It's there already from the old function + class.
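For readers following this thread: in Python's iterator protocol, a for loop calls iter(obj) once and then next(...) repeatedly until StopIteration is raised, so a class that defines __next__ can simply return self from __iter__. A minimal, self-contained sketch of the pattern (the MiniLoader name and toy batches are illustrative only, not part of this PR):

    class MiniLoader:
        """Toy stand-in for CombinedDataLoader, showing the iterator protocol."""

        def __init__(self, batches):
            self.batches = batches
            self.current_index = 0

        def __iter__(self):
            # a for loop calls iter(loader) once; since __next__ is defined
            # on this same class, returning self is sufficient
            return self

        def __next__(self):
            if self.current_index >= len(self.batches):
                self.current_index = 0  # reset so the object can be iterated again
                raise StopIteration
            batch = self.batches[self.current_index]
            self.current_index += 1
            return batch

    loader = MiniLoader(["batch_a", "batch_b"])
    assert list(loader) == ["batch_a", "batch_b"]
    assert list(loader) == ["batch_a", "batch_b"]  # usable again after exhaustion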


-    def __len__(self):
-        return len(self.indices)
+    def __next__(self):
+        if self.current_index >= len(self.indices):
+            self.reset()  # Reset the index for the next iteration
+            raise StopIteration
+
+        idx = self.indices[self.current_index]
+        self.current_index += 1
+        return self.full_list[idx]

-
-def combine_dataloaders(
-    dataloaders: List[torch.utils.data.DataLoader], shuffle: bool = True
-):
-    """
-    Combines multiple dataloaders into a single dataloader.
-
-    :param dataloaders: list of dataloaders to combine
-    :param shuffle: whether to shuffle the combined dataloader
-
-    :return: combined dataloader
-    """
-    combined_dataset = CombinedIterableDataset(dataloaders, shuffle)
-    return torch.utils.data.DataLoader(combined_dataset, batch_size=None)
+    def __len__(self):
+        """Returns the total number of batches in all dataloaders.
+
+        This returns the total number of batches in all dataloaders
+        (as opposed to the total number of samples or the number of
+        individual dataloaders).
+
+        :return: the total number of batches in all dataloaders
+        """
+        return sum(len(dl) for dl in self.dataloaders)
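To make the new interface concrete, here is a short usage sketch. This is a hypothetical example, not part of the PR: TensorDataset stands in for the real systems/targets datasets, and only CombinedDataLoader itself is taken from the code above.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from metatensor.models.utils.data import CombinedDataLoader

    # two toy dataloaders with different batch sizes, standing in for two datasets
    loader_a = DataLoader(TensorDataset(torch.arange(20.0)), batch_size=10)  # 2 batches
    loader_b = DataLoader(TensorDataset(torch.arange(6.0)), batch_size=2)  # 3 batches

    combined = CombinedDataLoader([loader_a, loader_b], shuffle=True)
    assert len(combined) == 5  # total number of batches, not samples

    for epoch in range(2):  # the loader resets itself on exhaustion
        for batch in combined:
            pass  # a training step would go here

Note that the combined order is drawn once at construction: reset() re-collects batches from the underlying dataloaders (so loaders created with shuffle=True yield freshly shuffled batches each epoch), but the order in which the combined batches are returned stays fixed.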
27 changes: 22 additions & 5 deletions tests/utils/data/test_combine_dataloaders.py
@@ -5,8 +5,8 @@
 from omegaconf import OmegaConf

 from metatensor.models.utils.data import (
+    CombinedDataLoader,
     collate_fn,
-    combine_dataloaders,
     read_systems,
     read_targets,
 )
@@ -56,7 +56,7 @@ def test_without_shuffling():
     dataloader_alchemical = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
     # will yield 5 batches of 2

-    combined_dataloader = combine_dataloaders(
+    combined_dataloader = CombinedDataLoader(
         [dataloader_qm9, dataloader_alchemical], shuffle=False
     )

@@ -88,7 +88,9 @@ def test_with_shuffling():
     }
     targets = read_targets(OmegaConf.create(conf))
     dataset = Dataset(system=systems, U0=targets["U0"])
-    dataloader_qm9 = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
+    dataloader_qm9 = DataLoader(
+        dataset, batch_size=10, collate_fn=collate_fn, shuffle=True
+    )
     # will yield 10 batches of 10

     systems = read_systems(RESOURCES_PATH / "alchemical_reduced_10.xyz")
@@ -106,10 +108,12 @@
     }
     targets = read_targets(OmegaConf.create(conf))
     dataset = Dataset(system=systems, free_energy=targets["free_energy"])
-    dataloader_alchemical = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
+    dataloader_alchemical = DataLoader(
+        dataset, batch_size=2, collate_fn=collate_fn, shuffle=True
+    )
     # will yield 5 batches of 2

-    combined_dataloader = combine_dataloaders(
+    combined_dataloader = CombinedDataLoader(
         [dataloader_qm9, dataloader_alchemical], shuffle=True
     )

@@ -119,17 +123,30 @@
     alchemical_batch_count = 0
     original_ordering = ["qm9"] * 10 + ["alchemical"] * 5
     actual_ordering = []
+    qm9_samples = []
+    alchemical_samples = []

     for batch in combined_dataloader:
         if "U0" in batch[1]:
             qm9_batch_count += 1
             assert batch[1]["U0"].block().values.shape == (10, 1)
             actual_ordering.append("qm9")
+            qm9_samples.append(batch[1]["U0"].block().samples.column("system"))
         else:
             alchemical_batch_count += 1
             assert batch[1]["free_energy"].block().values.shape == (2, 1)
             actual_ordering.append("alchemical")
+            alchemical_samples.append(
+                batch[1]["free_energy"].block().samples.column("system")
+            )

     assert qm9_batch_count == 10
     assert alchemical_batch_count == 5
     assert actual_ordering != original_ordering
+
+    qm9_samples = [int(item) for sublist in qm9_samples for item in sublist]
+    alchemical_samples = [
+        int(item) for sublist in alchemical_samples for item in sublist
+    ]
+    assert set(qm9_samples) == set(range(100))
+    assert set(alchemical_samples) == set(range(10))