diff --git a/docs/src/getting-started/usage.rst b/docs/src/getting-started/usage.rst
index 98fc3da51..afcfcc871 100644
--- a/docs/src/getting-started/usage.rst
+++ b/docs/src/getting-started/usage.rst
@@ -66,9 +66,20 @@ The sub-command to evaluate an already trained model is
 
     metatensor-models eval
 
+Besides the already trained ``model``, you will also have to provide a file containing
+the structures and, optionally, target values for evaluation. The structure of this
+``eval.yaml`` is exactly the same as for a dataset in the ``options.yaml`` file.
+
+.. literalinclude:: ../../static/qm9/eval.yaml
+   :language: yaml
+
+Note that the ``targets`` section is optional. If a ``targets`` section is present,
+RMSE values of the predictions with respect to the reference values loaded from that
+section are calculated and reported. You can run an evaluation by typing
+
 .. literalinclude:: ../../../examples/basic_usage/usage.sh
    :language: bash
-   :lines: 9-25
+   :lines: 9-24
 
 
 Exporting
@@ -84,7 +84,7 @@ The sub-command to export a trained model is
 
 .. literalinclude:: ../../../examples/basic_usage/usage.sh
    :language: bash
-   :lines: 25-
+   :lines: 26-
 
 In the next tutorials we show how adjust the dataset section of ``options.yaml`` file
 to use it for your own datasets.
diff --git a/docs/static/qm9/eval.yaml b/docs/static/qm9/eval.yaml
new file mode 100644
index 000000000..ed9fedde3
--- /dev/null
+++ b/docs/static/qm9/eval.yaml
@@ -0,0 +1,4 @@
+structures: "qm9_reduced_100.xyz" # file where the positions are stored
+targets:
+  energy:
+    key: "U0" # name of the target value
diff --git a/examples/basic_usage/eval.yaml b/examples/basic_usage/eval.yaml
new file mode 120000
index 000000000..f5a6d431b
--- /dev/null
+++ b/examples/basic_usage/eval.yaml
@@ -0,0 +1 @@
+../../docs/static/qm9/eval.yaml
\ No newline at end of file
diff --git a/examples/basic_usage/usage.sh b/examples/basic_usage/usage.sh
index 901ccd721..62b539ed0 100755
--- a/examples/basic_usage/usage.sh
+++ b/examples/basic_usage/usage.sh
@@ -8,12 +8,12 @@ metatensor-models train options.yaml
 metatensor-models train --help
 
 # We now evaluate the model on the training dataset, where the first arguments specifies
-# the model and the second the structure file
+# the trained model and the second an options file containing the dataset for evaluation.
 
-metatensor-models eval model.pt qm9_reduced_100.xyz
+metatensor-models eval model.pt eval.yaml
 
-# The evaluation command predicts the property the model was trained against; here "U0".
-# The predictions together with the structures have been written in a file named
+# The evaluation command predicts those properties the model was trained against; here
+# "U0". The predictions together with the structures are written to a file named
 # ``output.xyz`` in the current directory. The written file starts with the following
 # lines
 
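The new ``eval.yaml`` is plain OmegaConf-compatible YAML, so the same options can also
be assembled programmatically. A minimal sketch (the dict mirrors
``docs/static/qm9/eval.yaml``; ``OmegaConf.save`` writes the equivalent YAML, minus the
inline comments):

    from omegaconf import OmegaConf

    # Same content as docs/static/qm9/eval.yaml, built in Python.
    eval_options = OmegaConf.create(
        {
            "structures": "qm9_reduced_100.xyz",  # file where the positions are stored
            "targets": {"energy": {"key": "U0"}},  # "U0" is the name of the target value
        }
    )
    OmegaConf.save(eval_options, "eval.yaml")
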
diff --git a/src/metatensor/models/cli/eval_model.py b/src/metatensor/models/cli/eval_model.py
index fc8b1a476..963f327e5 100644
--- a/src/metatensor/models/cli/eval_model.py
+++ b/src/metatensor/models/cli/eval_model.py
@@ -3,7 +3,7 @@
 from typing import Dict, Tuple, Union
 
 import torch
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 from ..utils.compute_loss import compute_model_loss
 from ..utils.data import (
@@ -13,19 +13,16 @@
     read_targets,
     write_predictions,
 )
-from ..utils.data.dataset import _train_test_random_split
 from ..utils.extract_targets import get_outputs_dict
 from ..utils.info import finalize_aggregated_info, update_aggregated_info
 from ..utils.loss import TensorMapDictLoss
 from ..utils.model_io import load_model
-from ..utils.omegaconf import _has_yaml_suffix, check_units, expand_dataset_config
+from ..utils.omegaconf import expand_dataset_config
 from .formatter import CustomHelpFormatter
 
 
 logger = logging.getLogger(__name__)
 
-CHOICES_EVAL_ON = ["train", "validation", "test"]
-
 
 def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
     """Add the `eval_model` paramaters to an argparse (sub)-parser"""
@@ -41,21 +38,15 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
         formatter_class=CustomHelpFormatter,
     )
     parser.set_defaults(callable="eval_model")
-    parser.add_argument(
-        "options",
-        type=_has_yaml_suffix,
-        help="Options file to define a test dataset taken for the evaluation.",
-    )
     parser.add_argument(
         "model",
-        type=str,
-        help="saved model to be evaluated",
+        type=load_model,
+        help="Saved model to be evaluated.",
     )
     parser.add_argument(
-        "eval_on",
-        type=str,
-        choices=CHOICES_EVAL_ON,
-        help="On which part of the dataset should the model be evaluated.",
+        "options",
+        type=OmegaConf.load,
+        help="Eval options file to define a dataset for evaluation.",
     )
     parser.add_argument(
         "-o",
@@ -139,77 +130,32 @@ def _eval_targets(model, dataset: Union[Dataset, torch.utils.data.Subset]) -> None:
 
 
 def eval_model(
-    options: str, model: str, eval_on: str = "test", output: str = "output.xyz"
+    model: torch.nn.Module, options: DictConfig, output: str = "output.xyz"
 ) -> None:
-    """Evaluate a pretrained model on a set.
+    """Evaluate a pretrained model on a given data set.
 
-    The test dataset will be selected as defined in the options yaml file. Predicted
-    values will be written ``output``.
+    If ``options`` contains a ``targets`` sub-section, RMSE values will be reported. If
+    this sub-section is missing, only an xyz file containing the properties the model
+    was trained against is written.
 
-    :param options: Options file path to define a test dataset taken for the evaluation.
-    :param model: Path to a saved model
-    :param eval_on: On which part of the dataset should the model be evaluated. Possible
-        values are 'test', 'train' or 'validation'.
+    :param model: Saved model to be evaluated.
+    :param options: DictConfig defining the dataset used for the evaluation.
     :param output: Path to save the predicted values
     """
+    logging.basicConfig(level=logging.INFO, format="%(message)s")
+    logger.info("Setting up evaluation set.")
 
-    conf = OmegaConf.load(options)
-
-    if eval_on not in CHOICES_EVAL_ON:
-        raise ValueError(
-            f"{eval_on!r} is not a possible choice for `eval_on`. Choose from: "
-            f"{','.join(CHOICES_EVAL_ON)}"
-        )
-
-    logger.info("Setting up {eval_on} set")
-    train_options = conf["test_set"]
-    eval_options = conf["{eval_on}_set"]
-
-    loaded_model = load_model(model)
-
+    options = expand_dataset_config(options)
     eval_structures = read_structures(
-        filename=eval_options["structures"]["read_from"],
-        fileformat=eval_options["structures"]["file_format"],
+        filename=options["structures"]["read_from"],
+        fileformat=options["structures"]["file_format"],
     )
 
-    # Predict targets
-    if hasattr(eval_options, "targets"):
-        if isinstance(eval_options, float):
-            eval_size = eval_options
-            train_size = 1 - eval_size
-
-            if eval_size < 0 or eval_size >= 1:
-                raise ValueError(f"{eval_on} set split must be between 0 and 1.")
-
-            train_structures = read_structures(
-                filename=train_options["structures"]["read_from"],
-                fileformat=train_options["structures"]["file_format"],
-            )
-            train_targets = read_targets(train_options["targets"])
-            train_dataset = Dataset(train_structures, train_targets)
-
-            generator = torch.Generator()
-            if conf["seed"] is not None:
-                generator.manual_seed(conf["seed"])
-
-            _, eval_dataset = _train_test_random_split(
-                train_dataset=train_dataset,
-                train_size=train_size,
-                test_size=eval_size,
-                generator=generator,
-            )
-
-            # Select eval_structures based on fraction
-            eval_structures = [eval_structures[index] for index in eval_dataset.indices]
-
-        else:
-            eval_options = expand_dataset_config(eval_options)
-            eval_targets = read_targets(eval_options["targets"])
-            eval_dataset = Dataset(eval_structures, eval_targets)
-            check_units(actual_options=eval_options, desired_options=train_options)
-
-        _eval_targets(loaded_model, eval_dataset)
+    if hasattr(options, "targets"):
+        eval_targets = read_targets(options["targets"])
+        eval_dataset = Dataset(eval_structures, eval_targets)
+        _eval_targets(model, eval_dataset)
 
-    # Predict strcutures
-    predictions = loaded_model(eval_structures, loaded_model.capabilities.outputs)
+    # Predict structures
+    predictions = model(eval_structures, model.capabilities.outputs)
     write_predictions(output, predictions, eval_structures)
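The reworked parser relies on a standard argparse idiom: any callable works as
``type=``, so the raw CLI strings are converted into a loaded model and a ``DictConfig``
before ``eval_model`` is called. A self-contained sketch of the idiom (``eval.yaml`` is
an assumed existing file):

    import argparse

    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    # argparse passes the raw string to the ``type`` callable, so
    # ``args.options`` arrives as a DictConfig rather than a path.
    parser.add_argument("options", type=OmegaConf.load)
    args = parser.parse_args(["eval.yaml"])
    print(args.options)
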
diff --git a/src/metatensor/models/utils/omegaconf.py b/src/metatensor/models/utils/omegaconf.py
index e469c7d6d..c88a4c524 100644
--- a/src/metatensor/models/utils/omegaconf.py
+++ b/src/metatensor/models/utils/omegaconf.py
@@ -103,65 +103,68 @@ def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig:
             {"structures": read_from, "targets": {"energy": read_from}}
         )
 
-    if type(conf["structures"]) is str:
-        conf["structures"] = _resolve_single_str(conf["structures"])
+    if hasattr(conf, "structures"):
+        if type(conf["structures"]) is str:
+            conf["structures"] = _resolve_single_str(conf["structures"])
 
-    conf["structures"] = OmegaConf.merge(CONF_STRUCTURES, conf["structures"])
+        conf["structures"] = OmegaConf.merge(CONF_STRUCTURES, conf["structures"])
 
-    for target_key, target in conf["targets"].items():
-        if type(target) is str:
-            target = _resolve_single_str(target)
+    if hasattr(conf, "targets"):
+        for target_key, target in conf["targets"].items():
+            if type(target) is str:
+                target = _resolve_single_str(target)
 
-        # Add default gradients "energy" target section
-        if target_key == "energy":
-            # For special case of the "energy" we add the section for force and stress
-            # gradient by default
-            target = OmegaConf.merge(CONF_ENERGY, target)
-        else:
-            target = OmegaConf.merge(CONF_TARGET, target)
-
-        if target["key"] is None:
-            target["key"] = target_key
-
-        # Update DictConfig to allow for config node interpolation
-        conf["targets"][target_key] = target
-
-        # merge and interpolate possibly present gradients with default gradient config
-        for gradient_key, gradient_conf in conf["targets"][target_key].items():
-            if gradient_key in KNWON_GRADIENTS:
-                if gradient_conf is True:
-                    gradient_conf = CONF_GRADIENT.copy()
-                elif type(gradient_conf) is str:
-                    gradient_conf = _resolve_single_str(gradient_conf)
-
-                if isinstance(gradient_conf, DictConfig):
-                    gradient_conf = OmegaConf.merge(CONF_GRADIENT, gradient_conf)
-
-                if gradient_conf["key"] is None:
-                    gradient_conf["key"] = gradient_key
-
-                conf["targets"][target_key][gradient_key] = gradient_conf
-
-        # If user sets the virial gradient and leaves the stress section untouched,
-        # we disable the by default enabled stress gradient section.
-        base_stress_gradient_conf = CONF_GRADIENT.copy()
-        base_stress_gradient_conf["key"] = "stress"
-
-        if (
-            target_key == "energy"
-            and conf["targets"][target_key]["virial"]
-            and conf["targets"][target_key]["stress"] == base_stress_gradient_conf
-        ):
-            conf["targets"][target_key]["stress"] = False
-
-        if (
-            conf["targets"][target_key]["stress"]
-            and conf["targets"][target_key]["virial"]
-        ):
-            raise ValueError(
-                f"Cannot perform training with respect to virials and stress as in "
-                f"section {target_key}. Set either `virials: off` or `stress: off`."
-            )
+            # Add default gradients "energy" target section
+            if target_key == "energy":
+                # For special case of the "energy" we add the section for force and
+                # stress gradient by default
+                target = OmegaConf.merge(CONF_ENERGY, target)
+            else:
+                target = OmegaConf.merge(CONF_TARGET, target)
+
+            if target["key"] is None:
+                target["key"] = target_key
+
+            # Update DictConfig to allow for config node interpolation
+            conf["targets"][target_key] = target
+
+            # merge and interpolate possibly present gradients with default gradient
+            # config
+            for gradient_key, gradient_conf in conf["targets"][target_key].items():
+                if gradient_key in KNWON_GRADIENTS:
+                    if gradient_conf is True:
+                        gradient_conf = CONF_GRADIENT.copy()
+                    elif type(gradient_conf) is str:
+                        gradient_conf = _resolve_single_str(gradient_conf)
+
+                    if isinstance(gradient_conf, DictConfig):
+                        gradient_conf = OmegaConf.merge(CONF_GRADIENT, gradient_conf)
+
+                    if gradient_conf["key"] is None:
+                        gradient_conf["key"] = gradient_key
+
+                    conf["targets"][target_key][gradient_key] = gradient_conf
+
+            # If user sets the virial gradient and leaves the stress section untouched,
+            # we disable the by default enabled stress gradient section.
+            base_stress_gradient_conf = CONF_GRADIENT.copy()
+            base_stress_gradient_conf["key"] = "stress"
+
+            if (
+                target_key == "energy"
+                and conf["targets"][target_key]["virial"]
+                and conf["targets"][target_key]["stress"] == base_stress_gradient_conf
+            ):
+                conf["targets"][target_key]["stress"] = False
+
+            if (
+                conf["targets"][target_key]["stress"]
+                and conf["targets"][target_key]["virial"]
+            ):
+                raise ValueError(
+                    f"Cannot perform training with respect to virials and stress as in "
+                    f"section {target_key}. Set either `virials: off` or `stress: off`."
+                )
 
     return conf
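With both sections now wrapped in ``hasattr`` guards, ``expand_dataset_config`` accepts
configurations where either ``structures`` or ``targets`` is absent. A minimal sketch of
the new behaviour, mirroring the unit tests added below:

    from omegaconf import OmegaConf

    from metatensor.models.utils.omegaconf import expand_dataset_config

    # Only a "structures" section: the "targets" branch is skipped entirely.
    conf = expand_dataset_config(OmegaConf.create({"structures": "foo.xyz"}))
    assert conf["structures"]["read_from"] == "foo.xyz"
    assert conf["structures"]["file_format"] == ".xyz"
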
diff --git a/tests/cli/test_eval_model.py b/tests/cli/test_eval_model.py
index 183ab24f3..e28eb61a3 100644
--- a/tests/cli/test_eval_model.py
+++ b/tests/cli/test_eval_model.py
@@ -1,29 +1,79 @@
+import logging
 import shutil
 import subprocess
 from pathlib import Path
 
 import ase.io
 import pytest
+from omegaconf import OmegaConf
+
+from metatensor.models.cli import eval_model
+from metatensor.models.utils.model_io import load_model
 
 
 RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"
+MODEL_PATH = RESOURCES_PATH / "bpnn-model.pt"
+OPTIONS_PATH = RESOURCES_PATH / "eval.yaml"
+
+
+@pytest.fixture
+def model():
+    return load_model(MODEL_PATH)
 
 
-@pytest.mark.parametrize("output", [None, "foo.xyz"])
-def test_eval(output, monkeypatch, tmp_path):
-    """Test that training via the training cli runs without an error raise."""
+@pytest.fixture
+def options():
+    return OmegaConf.load(OPTIONS_PATH)
+
+
+def test_eval_cli(monkeypatch, tmp_path):
+    """Test successful run of the eval script via the CLI with default arguments."""
     monkeypatch.chdir(tmp_path)
     shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz")
-    shutil.copy(RESOURCES_PATH / "bpnn-model.pt", "bpnn-model.pt")
-
-    command = ["metatensor-models", "eval", "bpnn-model.pt", "qm9_reduced_100.xyz"]
-    if output is not None:
-        command += ["-o", output]
-    else:
-        output = "output.xyz"
+    command = [
+        "metatensor-models",
+        "eval",
+        str(MODEL_PATH),
+        str(OPTIONS_PATH),
+    ]
 
     subprocess.check_call(command)
-    frames = ase.io.read(output, ":")
+    assert Path("output.xyz").is_file()
+
+
+def test_eval(monkeypatch, tmp_path, caplog, model, options):
+    """Test that eval via python API runs without an error raise."""
+    monkeypatch.chdir(tmp_path)
+    caplog.set_level(logging.INFO)
+
+    shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz")
+
+    eval_model(
+        model=model,
+        options=options,
+        output="foo.xyz",
+    )
+
+    # Test target predictions
+    assert "energy RMSE" in "".join([rec.message for rec in caplog.records])
+
+    # Test that predictions are written to file
+    frames = ase.io.read("foo.xyz", ":")
     frames[0].info["energy"]
+
+
+def test_eval_no_targets(monkeypatch, tmp_path, model, options):
+    monkeypatch.chdir(tmp_path)
+
+    shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz")
+
+    options.pop("targets")
+
+    eval_model(
+        model=model,
+        options=options,
+    )
+
+    assert Path("output.xyz").is_file()
diff --git a/tests/resources/eval.yaml b/tests/resources/eval.yaml
new file mode 120000
index 000000000..f5a6d431b
--- /dev/null
+++ b/tests/resources/eval.yaml
@@ -0,0 +1 @@
+../../docs/static/qm9/eval.yaml
\ No newline at end of file
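The tests above double as documentation of the new Python entry point. A condensed
usage sketch (``model.pt``, ``eval.yaml``, and the structure file they reference are
assumed to exist in the working directory):

    from omegaconf import OmegaConf

    from metatensor.models.cli import eval_model
    from metatensor.models.utils.model_io import load_model

    model = load_model("model.pt")
    options = OmegaConf.load("eval.yaml")

    # With a "targets" section: RMSE values are logged and predictions written.
    eval_model(model=model, options=options, output="output.xyz")

    # Without it: only the predictions file is written, no RMSE report.
    options.pop("targets")
    eval_model(model=model, options=options)
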
diff --git a/tests/utils/test_omegaconf.py b/tests/utils/test_omegaconf.py
index 63253988e..15bffd3f4 100644
--- a/tests/utils/test_omegaconf.py
+++ b/tests/utils/test_omegaconf.py
@@ -249,3 +249,19 @@ def test_check_units():
         ),
     ):
         check_units(actual_options=test_options3, desired_options=train_options)
+
+
+def test_missing_targets_section():
+    conf = {"structures": "foo.xyz"}
+    conf_expanded = expand_dataset_config(OmegaConf.create(conf))
+
+    assert conf_expanded["structures"]["read_from"] == "foo.xyz"
+    assert conf_expanded["structures"]["file_format"] == ".xyz"
+
+
+def test_missing_structures_section():
+    conf = {"targets": {"energies": "foo.xyz"}}
+    conf_expanded = expand_dataset_config(OmegaConf.create(conf))
+
+    assert conf_expanded["targets"]["energies"]["read_from"] == "foo.xyz"
+    assert conf_expanded["targets"]["energies"]["file_format"] == ".xyz"