From c7d9aae1ce200209a70e0cf225ed31e6bbad079d Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 17 Jul 2024 21:35:19 +0200 Subject: [PATCH] Fix edge case in dataset split --- src/metatrain/cli/train.py | 92 +++++++++++++---------------- src/metatrain/utils/data/dataset.py | 62 ++++++++----------- tests/cli/test_train_model.py | 21 ++++++- 3 files changed, 85 insertions(+), 90 deletions(-) diff --git a/src/metatrain/cli/train.py b/src/metatrain/cli/train.py index d5ba05e6b..11026e265 100644 --- a/src/metatrain/cli/train.py +++ b/src/metatrain/cli/train.py @@ -176,9 +176,9 @@ def train_model( torch.cuda.manual_seed(options["seed"]) torch.cuda.manual_seed_all(options["seed"]) - ########################### - # SETUP TRAINING SET ###### - ########################### + ############################ + # SET UP TRAINING SET ###### + ############################ logger.info("Setting up training set") options["training_set"] = expand_dataset_config(options["training_set"]) @@ -192,91 +192,81 @@ def train_model( train_size = 1.0 - ########################### - # SETUP TEST SET ########## - ########################### - - logger.info("Setting up test set") - test_datasets = [] - if isinstance(options["test_set"], float): - test_size = options["test_set"] - train_size -= test_size + ############################ + # SET UP VALIDATION SET #### + ############################ - generator = torch.Generator() - if options["seed"] is not None: - generator.manual_seed(options["seed"]) + logger.info("Setting up validation set") + val_datasets = [] + if isinstance(options["validation_set"], float): + val_size = options["validation_set"] + train_size -= val_size for i_dataset, train_dataset in enumerate(train_datasets): - train_dataset_new, test_dataset = _train_test_random_split( + train_dataset_new, val_dataset = _train_test_random_split( train_dataset=train_dataset, train_size=train_size, - test_size=test_size, - generator=generator, + test_size=val_size, ) train_datasets[i_dataset] = train_dataset_new - test_datasets.append(test_dataset) + val_datasets.append(val_dataset) else: - options["test_set"] = expand_dataset_config(options["test_set"]) + options["validation_set"] = expand_dataset_config(options["validation_set"]) - if len(options["test_set"]) != len(options["training_set"]): + if len(options["validation_set"]) != len(options["training_set"]): raise ValueError( - f"Test dataset with length {len(options['test_set'])} has a different " - f"size than the training datatset with length " + f"Validation dataset with length {len(options['validation_set'])} has " + "a different size than the training datatset with length " f"{len(options['training_set'])}." ) check_units( - actual_options=options["test_set"], + actual_options=options["validation_set"], desired_options=options["training_set"], ) - for test_options in options["test_set"]: - dataset, _ = get_dataset(test_options) - test_datasets.append(dataset) - - ########################### - # SETUP VALIDATION SET #### - ########################### + for valid_options in options["validation_set"]: + dataset, _ = get_dataset(valid_options) + val_datasets.append(dataset) - logger.info("Setting up validation set") - val_datasets = [] - if isinstance(options["validation_set"], float): - val_size = options["validation_set"] - train_size -= val_size + ############################ + # SET UP TEST SET ########## + ############################ - generator = torch.Generator() - if options["seed"] is not None: - generator.manual_seed(options["seed"]) + logger.info("Setting up test set") + test_datasets = [] + if isinstance(options["test_set"], float): + test_size = options["test_set"] + train_size -= test_size for i_dataset, train_dataset in enumerate(train_datasets): - train_dataset_new, val_dataset = _train_test_random_split( + train_dataset_new, test_dataset = _train_test_random_split( train_dataset=train_dataset, train_size=train_size, - test_size=val_size, - generator=generator, + test_size=test_size, ) train_datasets[i_dataset] = train_dataset_new - val_datasets.append(val_dataset) + test_datasets.append(test_dataset) else: - options["validation_set"] = expand_dataset_config(options["validation_set"]) + options["test_set"] = expand_dataset_config(options["test_set"]) - if len(options["validation_set"]) != len(options["training_set"]): + if len(options["test_set"]) != len(options["training_set"]): raise ValueError( - f"Validation dataset with length {len(options['validation_set'])} has " - "a different size than the training datatset with length " + f"Test dataset with length {len(options['test_set'])} has a different " + f"size than the training datatset with length " f"{len(options['training_set'])}." ) check_units( - actual_options=options["validation_set"], + actual_options=options["test_set"], desired_options=options["training_set"], ) - for valid_options in options["validation_set"]: - dataset, _ = get_dataset(valid_options) - val_datasets.append(dataset) + for test_options in options["test_set"]: + dataset, _ = get_dataset(test_options) + test_datasets.append(dataset) ########################### # CREATE DATASET_INFO ##### diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 189bbe42d..f5744a1a5 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -1,13 +1,12 @@ -import itertools import math import warnings from collections import UserDict from typing import Any, Dict, List, Optional, Tuple, Union import metatensor.learn +import numpy as np import torch from metatensor.torch import TensorMap -from torch import Generator, default_generator from ..external_naming import to_external_name from ..units import get_gradient_units @@ -485,49 +484,36 @@ def _train_test_random_split( train_dataset: Dataset, train_size: float, test_size: float, - generator: Optional[Generator] = default_generator, ) -> List[Dataset]: if train_size <= 0: raise ValueError("Fraction of the train set is smaller or equal to 0!") - # normalize fractions - lengths = torch.tensor([train_size, test_size]) - lengths /= lengths.sum() - - if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: - subset_lengths: List[int] = [] - for i, frac in enumerate(lengths): - if frac < 0 or frac > 1: - raise ValueError(f"Fraction at index {i} is not between 0 and 1") - n_items_in_split = int( - math.floor(len(train_dataset) * frac) # type: ignore[arg-type] - ) - subset_lengths.append(n_items_in_split) - remainder = len(train_dataset) - sum(subset_lengths) # type: ignore[arg-type] - # add 1 to all the lengths in round-robin fashion until the remainder is 0 - for i in range(remainder): - idx_to_add_at = i % len(subset_lengths) - subset_lengths[idx_to_add_at] += 1 - lengths = subset_lengths - for i, length in enumerate(lengths): - if length == 0: - warnings.warn( - f"Length of split at index {i} is 0. " - f"This might result in an empty dataset.", - UserWarning, - stacklevel=2, - ) - - # Cannot verify that train_dataset is Sized - if sum(lengths) != len(train_dataset): # type: ignore[arg-type] - raise ValueError( - "Sum of input lengths does not equal the length of the input dataset!" + # normalize the sizes + size_sum = train_size + test_size + train_size /= size_sum + test_size /= size_sum + + # find number of samples in the train and test sets + test_len = math.floor(len(train_dataset) * test_size) + if test_len == 0: + warnings.warn( + "Requested dataset of zero length. This dataset will be empty.", + UserWarning, + stacklevel=2, ) + train_len = len(train_dataset) - test_len + if train_len == 0: + raise ValueError("No samples left in the training set.") + + # find train, test indices + indices = list(range(len(train_dataset))) + np.random.shuffle(indices) + train_indices = indices[:train_len] + test_indices = indices[train_len:] - indices = torch.randperm(sum(lengths), generator=generator).tolist() return [ - Subset(train_dataset, indices[offset - length : offset]) - for offset, length in zip(itertools.accumulate(lengths), lengths) + Subset(train_dataset, train_indices), + Subset(train_dataset, test_indices), ] diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index 82124ed11..294a49936 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -304,7 +304,7 @@ def test_empty_test_set(caplog, monkeypatch, tmp_path, options): options["validation_set"] = 0.4 options["test_set"] = 0.0 - match = "Length of split at index 1 is 0. This might result in an empty dataset." + match = "Requested dataset of zero length. This dataset will be empty." with pytest.warns(UserWarning, match=match): train_model(options) @@ -491,3 +491,22 @@ def test_architecture_error(options, monkeypatch, tmp_path): with pytest.raises(ArchitectureError, match="originates from an architecture"): train_model(options) + + +def test_train_issue_290(monkeypatch, tmp_path): + """Test the potential problem from issue #290.""" + monkeypatch.chdir(tmp_path) + shutil.copy(DATASET_PATH_ETHANOL, "ethanol_reduced_100.xyz") + + structures = ase.io.read("ethanol_reduced_100.xyz", ":") + more_structures = structures * 15 + [structures[0]] + ase.io.write("ethanol_1501.xyz", more_structures) + + # run training with original options + options = OmegaConf.load(OPTIONS_PATH) + options["training_set"]["systems"]["read_from"] = "ethanol_1501.xyz" + options["training_set"]["targets"]["energy"]["key"] = "energy" + options["validation_set"] = 0.01 + options["test_set"] = 0.85 + + train_model(options)