diff --git a/src/metatensor/models/cli/eval.py b/src/metatensor/models/cli/eval.py
index 3e02bb027..db43b1ba1 100644
--- a/src/metatensor/models/cli/eval.py
+++ b/src/metatensor/models/cli/eval.py
@@ -67,6 +67,9 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
 
 def _eval_targets(model, dataset: Union[_BaseDataset, torch.utils.data.Subset]) -> None:
     """Evaluate an exported model on a dataset and print the RMSEs for each target."""
+    if len(dataset) == 0:
+        logger.info("This dataset is empty")
+        return
     # Attach neighbor lists to the systems:
     requested_neighbor_lists = model.requested_neighbors_lists()
     # working around https://github.com/lab-cosmo/metatensor/issues/521
diff --git a/src/metatensor/models/cli/train.py b/src/metatensor/models/cli/train.py
index c5c87bb9a..72b82a857 100644
--- a/src/metatensor/models/cli/train.py
+++ b/src/metatensor/models/cli/train.py
@@ -351,6 +351,16 @@ def _train_model_hydra(options: DictConfig) -> None:
         )
         validation_datasets.append(validation_dataset)
 
+    if (
+        sum([len(validation_dataset) for validation_dataset in validation_datasets])
+        == 0
+    ):
+        raise ValueError(
+            "The validation set is empty. Please provide a validation set, "
+            "either by setting a fraction of the training set or by providing it "
+            "explicitly."
+        )
+
     # Save fully expanded config
     OmegaConf.save(config=options, f=Path(output_dir) / "options.yaml")
 
diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py
index f23b8c31e..f7c4ee2d8 100644
--- a/tests/cli/test_train_model.py
+++ b/tests/cli/test_train_model.py
@@ -135,6 +135,51 @@ def test_train_multiple_datasets(monkeypatch, tmp_path, options):
     train_model(options)
 
+
+def test_empty_training_set(monkeypatch, tmp_path, options):
+    """Test that an error is raised if no training set is provided."""
+    monkeypatch.chdir(tmp_path)
+
+    shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")
+
+    options["validation_set"] = 0.6
+    options["test_set"] = 0.4
+
+    with pytest.raises(
+        ValueError, match="Fraction of the train set is smaller or equal to 0!"
+    ):
+        train_model(options)
+
+
+def test_empty_validation_set(monkeypatch, tmp_path, options):
+    """Test that an error is raised if no validation set is provided."""
+    monkeypatch.chdir(tmp_path)
+
+    shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")
+
+    options["validation_set"] = 0.0
+    options["test_set"] = 0.4
+
+    with pytest.raises(ValueError, match="The validation set is empty. Please provide"):
+        train_model(options)
+
+
+def test_empty_test_set(monkeypatch, tmp_path, options):
+    """Test that no error is raised if no test set is provided."""
+    monkeypatch.chdir(tmp_path)
+
+    shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")
+
+    options["validation_set"] = 0.4
+    options["test_set"] = 0.0
+
+    train_model(options)
+
+    # check if the logging is correct
+    with open(glob.glob("outputs/*/*/train.log")[0]) as f:
+        log = f.read()
+    assert "This dataset is empty" in log
+
 
 @pytest.mark.parametrize(
     "test_set_file, validation_set_file", [(True, False), (False, True)]
 )