Skip to content

Commit

Permalink
Fix edge case in dataset split
Browse files: browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jul 17, 2024
1 parent bedc9bb commit c7d9aae
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 90 deletions.
92 changes: 41 additions & 51 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ def train_model(
torch.cuda.manual_seed(options["seed"])
torch.cuda.manual_seed_all(options["seed"])

###########################
# SETUP TRAINING SET ######
###########################
############################
# SET UP TRAINING SET ######
############################

logger.info("Setting up training set")
options["training_set"] = expand_dataset_config(options["training_set"])
Expand All @@ -192,91 +192,81 @@ def train_model(

train_size = 1.0

###########################
# SETUP TEST SET ##########
###########################

logger.info("Setting up test set")
test_datasets = []
if isinstance(options["test_set"], float):
test_size = options["test_set"]
train_size -= test_size
############################
# SET UP VALIDATION SET ####
############################

generator = torch.Generator()
if options["seed"] is not None:
generator.manual_seed(options["seed"])
logger.info("Setting up validation set")
val_datasets = []
if isinstance(options["validation_set"], float):
val_size = options["validation_set"]
train_size -= val_size

for i_dataset, train_dataset in enumerate(train_datasets):
train_dataset_new, test_dataset = _train_test_random_split(
train_dataset_new, val_dataset = _train_test_random_split(
train_dataset=train_dataset,
train_size=train_size,
test_size=test_size,
generator=generator,
test_size=val_size,
)

train_datasets[i_dataset] = train_dataset_new
test_datasets.append(test_dataset)
val_datasets.append(val_dataset)
else:
options["test_set"] = expand_dataset_config(options["test_set"])
options["validation_set"] = expand_dataset_config(options["validation_set"])

if len(options["test_set"]) != len(options["training_set"]):
if len(options["validation_set"]) != len(options["training_set"]):
raise ValueError(
f"Test dataset with length {len(options['test_set'])} has a different "
f"size than the training datatset with length "
f"Validation dataset with length {len(options['validation_set'])} has "
"a different size than the training datatset with length "
f"{len(options['training_set'])}."
)

check_units(
actual_options=options["test_set"],
actual_options=options["validation_set"],
desired_options=options["training_set"],
)

for test_options in options["test_set"]:
dataset, _ = get_dataset(test_options)
test_datasets.append(dataset)

###########################
# SETUP VALIDATION SET ####
###########################
for valid_options in options["validation_set"]:
dataset, _ = get_dataset(valid_options)
val_datasets.append(dataset)

logger.info("Setting up validation set")
val_datasets = []
if isinstance(options["validation_set"], float):
val_size = options["validation_set"]
train_size -= val_size
############################
# SET UP TEST SET ##########
############################

generator = torch.Generator()
if options["seed"] is not None:
generator.manual_seed(options["seed"])
logger.info("Setting up test set")
test_datasets = []
if isinstance(options["test_set"], float):
test_size = options["test_set"]
train_size -= test_size

for i_dataset, train_dataset in enumerate(train_datasets):
train_dataset_new, val_dataset = _train_test_random_split(
train_dataset_new, test_dataset = _train_test_random_split(
train_dataset=train_dataset,
train_size=train_size,
test_size=val_size,
generator=generator,
test_size=test_size,
)

train_datasets[i_dataset] = train_dataset_new
val_datasets.append(val_dataset)
test_datasets.append(test_dataset)
else:
options["validation_set"] = expand_dataset_config(options["validation_set"])
options["test_set"] = expand_dataset_config(options["test_set"])

if len(options["validation_set"]) != len(options["training_set"]):
if len(options["test_set"]) != len(options["training_set"]):
raise ValueError(
f"Validation dataset with length {len(options['validation_set'])} has "
"a different size than the training datatset with length "
f"Test dataset with length {len(options['test_set'])} has a different "
f"size than the training datatset with length "
f"{len(options['training_set'])}."
)

check_units(
actual_options=options["validation_set"],
actual_options=options["test_set"],
desired_options=options["training_set"],
)

for valid_options in options["validation_set"]:
dataset, _ = get_dataset(valid_options)
val_datasets.append(dataset)
for test_options in options["test_set"]:
dataset, _ = get_dataset(test_options)
test_datasets.append(dataset)

###########################
# CREATE DATASET_INFO #####
Expand Down
62 changes: 24 additions & 38 deletions src/metatrain/utils/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import itertools
import math
import warnings
from collections import UserDict
from typing import Any, Dict, List, Optional, Tuple, Union

import metatensor.learn
import numpy as np
import torch
from metatensor.torch import TensorMap
from torch import Generator, default_generator

from ..external_naming import to_external_name
from ..units import get_gradient_units
Expand Down Expand Up @@ -485,49 +484,36 @@ def _train_test_random_split(
train_dataset: Dataset,
train_size: float,
test_size: float,
generator: Optional[Generator] = default_generator,
) -> List[Dataset]:
if train_size <= 0:
raise ValueError("Fraction of the train set is smaller or equal to 0!")

# normalize fractions
lengths = torch.tensor([train_size, test_size])
lengths /= lengths.sum()

if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
subset_lengths: List[int] = []
for i, frac in enumerate(lengths):
if frac < 0 or frac > 1:
raise ValueError(f"Fraction at index {i} is not between 0 and 1")
n_items_in_split = int(
math.floor(len(train_dataset) * frac) # type: ignore[arg-type]
)
subset_lengths.append(n_items_in_split)
remainder = len(train_dataset) - sum(subset_lengths) # type: ignore[arg-type]
# add 1 to all the lengths in round-robin fashion until the remainder is 0
for i in range(remainder):
idx_to_add_at = i % len(subset_lengths)
subset_lengths[idx_to_add_at] += 1
lengths = subset_lengths
for i, length in enumerate(lengths):
if length == 0:
warnings.warn(
f"Length of split at index {i} is 0. "
f"This might result in an empty dataset.",
UserWarning,
stacklevel=2,
)

# Cannot verify that train_dataset is Sized
if sum(lengths) != len(train_dataset): # type: ignore[arg-type]
raise ValueError(
"Sum of input lengths does not equal the length of the input dataset!"
# normalize the sizes
size_sum = train_size + test_size
train_size /= size_sum
test_size /= size_sum

# find number of samples in the train and test sets
test_len = math.floor(len(train_dataset) * test_size)
if test_len == 0:
warnings.warn(
"Requested dataset of zero length. This dataset will be empty.",
UserWarning,
stacklevel=2,
)
train_len = len(train_dataset) - test_len
if train_len == 0:
raise ValueError("No samples left in the training set.")

Check warning on line 506 in src/metatrain/utils/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/metatrain/utils/data/dataset.py#L506

Added line #L506 was not covered by tests

# find train, test indices
indices = list(range(len(train_dataset)))
np.random.shuffle(indices)
train_indices = indices[:train_len]
test_indices = indices[train_len:]

indices = torch.randperm(sum(lengths), generator=generator).tolist()
return [
Subset(train_dataset, indices[offset - length : offset])
for offset, length in zip(itertools.accumulate(lengths), lengths)
Subset(train_dataset, train_indices),
Subset(train_dataset, test_indices),
]


Expand Down
21 changes: 20 additions & 1 deletion tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def test_empty_test_set(caplog, monkeypatch, tmp_path, options):
options["validation_set"] = 0.4
options["test_set"] = 0.0

match = "Length of split at index 1 is 0. This might result in an empty dataset."
match = "Requested dataset of zero length. This dataset will be empty."
with pytest.warns(UserWarning, match=match):
train_model(options)

Expand Down Expand Up @@ -491,3 +491,22 @@ def test_architecture_error(options, monkeypatch, tmp_path):

with pytest.raises(ArchitectureError, match="originates from an architecture"):
train_model(options)


def test_train_issue_290(monkeypatch, tmp_path):
    """Regression check for issue #290: training with fractional splits.

    The reference ethanol file is copied into a temporary working directory
    and enlarged to 15 repetitions of its structures plus one extra structure
    (written to ``ethanol_1501.xyz``). Training is then run with a validation
    fraction of 0.01 and a test fraction of 0.85. The test passes as long as
    ``train_model`` completes without raising.
    """
    monkeypatch.chdir(tmp_path)
    shutil.copy(DATASET_PATH_ETHANOL, "ethanol_reduced_100.xyz")

    # Enlarge the dataset: 15 copies of every frame, then the first frame again.
    frames = ase.io.read("ethanol_reduced_100.xyz", ":")
    enlarged = frames * 15
    enlarged.append(frames[0])
    ase.io.write("ethanol_1501.xyz", enlarged)

    # Start from the stock options and point them at the enlarged dataset.
    options = OmegaConf.load(OPTIONS_PATH)
    options["training_set"]["systems"]["read_from"] = "ethanol_1501.xyz"
    options["training_set"]["targets"]["energy"]["key"] = "energy"

    # Request the validation and test sets as fractions rather than files.
    for split_name, fraction in (("validation_set", 0.01), ("test_set", 0.85)):
        options[split_name] = fraction

    train_model(options)

0 comments on commit c7d9aae

Please sign in to comment.