Skip to content

Commit

Permalink
Account for empty validation and test sets
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Mar 15, 2024
1 parent 555c01c commit 6ce6233
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/metatensor/models/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:

def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset]) -> None:
"""Evaluate an exported model on a dataset and print the RMSEs for each target."""
if len(dataset) == 0:
logger.info("This dataset is empty")
return
# Attach neighbor lists to the systems:
requested_neighbor_lists = model.requested_neighbors_lists()
# working around https://github.com/lab-cosmo/metatensor/issues/521
Expand Down
10 changes: 10 additions & 0 deletions src/metatensor/models/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,16 @@ def _train_model_hydra(options: DictConfig) -> None:
)
validation_datasets.append(validation_dataset)

if (
sum([len(validation_dataset) for validation_dataset in validation_datasets])
== 0
):
raise ValueError(
"The validation set is empty. Please provide a validation set, "
"either by setting a fraction of the training set or by providing it "
"explicitly."
)

# Save fully expanded config
OmegaConf.save(config=options, f=Path(output_dir) / "options.yaml")

Expand Down
45 changes: 45 additions & 0 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,51 @@ def test_train_multiple_datasets(monkeypatch, tmp_path, options):
train_model(options)


def test_empty_training_set(monkeypatch, tmp_path, options):
"""Test that an error is raised if no training set is provided."""
monkeypatch.chdir(tmp_path)

shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")

options["validation_set"] = 0.6
options["test_set"] = 0.4

with pytest.raises(
ValueError, match="Fraction of the train set is smaller or equal to 0!"
):
train_model(options)


def test_empty_validation_set(monkeypatch, tmp_path, options):
"""Test that an error is raised if no validation set is provided."""
monkeypatch.chdir(tmp_path)

shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")

options["validation_set"] = 0.0
options["test_set"] = 0.4

with pytest.raises(ValueError, match="The validation set is empty. Please provide"):
train_model(options)


def test_empty_test_set(monkeypatch, tmp_path, options):
"""Test that no error is raised if no test set is provided."""
monkeypatch.chdir(tmp_path)

shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")

options["validation_set"] = 0.4
options["test_set"] = 0.0

train_model(options)

# check if the logging is correct
with open(glob.glob("outputs/*/*/train.log")[0]) as f:
log = f.read()
assert "This dataset is empty" in log


@pytest.mark.parametrize(
"test_set_file, validation_set_file", [(True, False), (False, True)]
)
Expand Down

0 comments on commit 6ce6233

Please sign in to comment.