diff --git a/src/metatensor/models/cli/eval.py b/src/metatensor/models/cli/eval.py
index 90f43dc81..f9a46ea13 100644
--- a/src/metatensor/models/cli/eval.py
+++ b/src/metatensor/models/cli/eval.py
@@ -67,6 +67,9 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
 
 def _eval_targets(model, dataset: Union[Dataset, torch.utils.data.Subset]) -> None:
     """Evaluate an exported model on a dataset and print the RMSEs for each target."""
+    if len(dataset) == 0:
+        logger.info("This dataset is empty. No evaluation will be performed.")
+        return
     # Attach neighbor lists to the systems:
     requested_neighbor_lists = model.requested_neighbors_lists()
     # working around https://github.com/lab-cosmo/metatensor/issues/521
diff --git a/src/metatensor/models/cli/train.py b/src/metatensor/models/cli/train.py
index c38a62d4b..1b2875fd5 100644
--- a/src/metatensor/models/cli/train.py
+++ b/src/metatensor/models/cli/train.py
@@ -255,7 +255,10 @@ def _train_model_hydra(options: DictConfig) -> None:
         train_size -= test_size
 
         if test_size < 0 or test_size >= 1:
-            raise ValueError("Test set split must be between 0 and 1.")
+            raise ValueError(
+                "Test set split must be greater "
+                "than (or equal to) 0 and lesser than 1."
+            )
 
         generator = torch.Generator()
         if options["seed"] is not None:
@@ -304,8 +307,10 @@ def _train_model_hydra(options: DictConfig) -> None:
         validation_size = validation_options
         train_size -= validation_size
 
-        if validation_size < 0 or validation_size >= 1:
-            raise ValueError("Validation set split must be between 0 and 1.")
+        if validation_size <= 0 or validation_size >= 1:
+            raise ValueError(
+                "Validation set split must be greater " "than 0 and lesser than 1."
+            )
 
         generator = torch.Generator()
         if options["seed"] is not None:
diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py
index f23b8c31e..342a6d4b7 100644
--- a/tests/cli/test_train_model.py
+++ b/tests/cli/test_train_model.py
@@ -135,6 +135,51 @@ def test_train_multiple_datasets(monkeypatch, tmp_path, options):
     train_model(options)
 
 
+def test_empty_training_set(monkeypatch, tmp_path, options):
+    """Test that an error is raised if no training set is provided."""
+    monkeypatch.chdir(tmp_path)
+
+    shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")
+
+    options["validation_set"] = 0.6
+    options["test_set"] = 0.4
+
+    with pytest.raises(
+        ValueError, match="Fraction of the train set is smaller or equal to 0!"
+    ):
+        train_model(options)
+
+
+def test_empty_validation_set(monkeypatch, tmp_path, options):
+    """Test that an error is raised if no validation set is provided."""
+    monkeypatch.chdir(tmp_path)
+
+    shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")
+
+    options["validation_set"] = 0.0
+    options["test_set"] = 0.4
+
+    with pytest.raises(ValueError, match="must be greater than 0"):
+        train_model(options)
+
+
+def test_empty_test_set(monkeypatch, tmp_path, options):
+    """Test that no error is raised if no test set is provided."""
+    monkeypatch.chdir(tmp_path)
+
+    shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz")
+
+    options["validation_set"] = 0.4
+    options["test_set"] = 0.0
+
+    train_model(options)
+
+    # check if the logging is correct
+    with open(glob.glob("outputs/*/*/train.log")[0]) as f:
+        log = f.read()
+    assert "This dataset is empty. No evaluation" in log
+
+
 @pytest.mark.parametrize(
     "test_set_file, validation_set_file", [(True, False), (False, True)]
 )
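
The sketch below is a standalone, hypothetical illustration of the split-fraction rules the patched train.py enforces; it is not part of the patch and does not call the metatensor-models API. It mirrors the two new checks: the test split may be exactly 0 (an empty test set is allowed and simply skipped at evaluation time by the new guard in eval.py), the validation split must be strictly positive, and both must stay below 1.

# Hypothetical helper mirroring the checks added in train.py above;
# check_split_fractions is not a metatensor-models function.
def check_split_fractions(test_size: float, validation_size: float) -> None:
    # A test split of exactly 0 is accepted: the resulting empty test set is
    # skipped at evaluation time rather than rejected here.
    if test_size < 0 or test_size >= 1:
        raise ValueError(
            "Test set split must be greater "
            "than (or equal to) 0 and lesser than 1."
        )
    # The validation split must be strictly positive and below 1.
    if validation_size <= 0 or validation_size >= 1:
        raise ValueError(
            "Validation set split must be greater than 0 and lesser than 1."
        )


if __name__ == "__main__":
    check_split_fractions(test_size=0.0, validation_size=0.4)  # accepted
    try:
        check_split_fractions(test_size=0.4, validation_size=0.0)  # rejected
    except ValueError as err:
        print(err)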