Skip to content

Commit

Permalink
finish tests and docs for eval
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Feb 8, 2024
1 parent f7627aa commit 84d8d2a
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 152 deletions.
15 changes: 13 additions & 2 deletions docs/src/getting-started/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,20 @@ The sub-command to evaluate an already trained model is
metatensor-models eval
Besides the a already trained `model` you will also have to provide a file containing
the structure and possible target values for evaluation. The structure of this
``eval.yaml`` is exactly the same as for a dataset in the ``options.yaml`` file.

.. literalinclude:: ../../static/qm9/eval.yaml
:language: yaml

Note that the ``targets`` section is optional. If the ``targets`` section is present the
function will calculate and report RMSE values of the predictions with respect to the
real values as loaded from the ``targets`` section. You can run an evaluation by typing

.. literalinclude:: ../../../examples/basic_usage/usage.sh
:language: bash
:lines: 9-25
:lines: 9-24


Exporting
Expand All @@ -84,7 +95,7 @@ The sub-command to export a trained model is
.. literalinclude:: ../../../examples/basic_usage/usage.sh
:language: bash
:lines: 25-
:lines: 26-

In the next tutorials we show how adjust the dataset section of ``options.yaml`` file
to use it for your own datasets.
4 changes: 4 additions & 0 deletions docs/static/qm9/eval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
structures: "qm9_reduced_100.xyz" # file where the positions are stored
targets:
energy:
key: "U0" # name of the target value
1 change: 1 addition & 0 deletions examples/basic_usage/eval.yaml
8 changes: 4 additions & 4 deletions examples/basic_usage/usage.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ metatensor-models train options.yaml
metatensor-models train --help

# We now evaluate the model on the training dataset, where the first arguments specifies
# the model and the second the structure file
# trained model and the second an option file containing the dataset for evaulation.

metatensor-models eval model.pt qm9_reduced_100.xyz
metatensor-models eval model.pt eval.yaml

# The evaluation command predicts the property the model was trained against; here "U0".
# The predictions together with the structures have been written in a file named
# The evaluation command predicts those properties the model was trained against; here
# "U0". The predictions together with the structures have been written in a file named
# ``output.xyz`` in the current directory. The written file starts with the following
# lines

Expand Down
104 changes: 25 additions & 79 deletions src/metatensor/models/cli/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, Tuple, Union

import torch
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf

from ..utils.compute_loss import compute_model_loss
from ..utils.data import (
Expand All @@ -13,19 +13,16 @@
read_targets,
write_predictions,
)
from ..utils.data.dataset import _train_test_random_split
from ..utils.extract_targets import get_outputs_dict
from ..utils.info import finalize_aggregated_info, update_aggregated_info
from ..utils.loss import TensorMapDictLoss
from ..utils.model_io import load_model
from ..utils.omegaconf import _has_yaml_suffix, check_units, expand_dataset_config
from ..utils.omegaconf import expand_dataset_config
from .formatter import CustomHelpFormatter


logger = logging.getLogger(__name__)

CHOICES_EVAL_ON = ["train", "validation", "test"]


def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
"""Add the `eval_model` paramaters to an argparse (sub)-parser"""
Expand All @@ -41,21 +38,15 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
formatter_class=CustomHelpFormatter,
)
parser.set_defaults(callable="eval_model")
parser.add_argument(
"options",
type=_has_yaml_suffix,
help="Options file to define a test dataset taken for the evaluation.",
)
parser.add_argument(
"model",
type=str,
help="saved model to be evaluated",
type=load_model,
help="Saved model to be evaluated.",
)
parser.add_argument(
"eval_on",
type=str,
choices=CHOICES_EVAL_ON,
help="On which part of the dataset should the model be evaluated.",
"options",
type=OmegaConf.load,
help="Eval options file to define a dataset for evaluation.",
)
parser.add_argument(
"-o",
Expand Down Expand Up @@ -139,77 +130,32 @@ def _eval_targets(model, dataset: Union[Dataset, torch.utils.data.Subset]) -> No


def eval_model(
options: str, model: str, eval_on: str = "test", output: str = "output.xyz"
model: torch.nn.Module, options: DictConfig, output: str = "output.xyz"
) -> None:
"""Evaluate a pretrained model on a set.
"""Evaluate a pretrained model on a certain data set.
The test dataset will be selected as defined in the options yaml file. Predicted
values will be written ``output``.
If ``options`` contains a ``targets`` sub-section, RMSE values will be reported. If
this sub-section is missing only a xyz-file with containing the properties the model
was trained against is written.
:param options: Options file path to define a test dataset taken for the evaluation.
:param model: Path to a saved model
:param eval_on: On which part of the dataset should the model be evaluated. Possible
values are 'test', 'train' or 'validation'.
:param model: Saved model to be evaluated.
:param options: DictConfig to define a test dataset taken for the evaluation.
:param output: Path to save the predicted values
"""
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger.info("Setting up evaluation set.")

conf = OmegaConf.load(options)

if eval_on not in CHOICES_EVAL_ON:
raise ValueError(
f"{eval_on!r} is not a possible choice for `eval_on`. Choose from: "
f"{','.join(CHOICES_EVAL_ON)}"
)

logger.info("Setting up {eval_on} set")
train_options = conf["test_set"]
eval_options = conf["{eval_on}_set"]

loaded_model = load_model(model)

options = expand_dataset_config(options)
eval_structures = read_structures(
filename=eval_options["structures"]["read_from"],
fileformat=eval_options["structures"]["file_format"],
filename=options["structures"]["read_from"],
fileformat=options["structures"]["file_format"],
)

# Predict targets
if hasattr(eval_options, "targets"):
if isinstance(eval_options, float):
eval_size = eval_options
train_size = 1 - eval_size

if eval_size < 0 or eval_size >= 1:
raise ValueError(f"{eval_on} set split must be between 0 and 1.")

train_structures = read_structures(
filename=train_options["structures"]["read_from"],
fileformat=train_options["structures"]["file_format"],
)
train_targets = read_targets(train_options["targets"])
train_dataset = Dataset(train_structures, train_targets)

generator = torch.Generator()
if conf["seed"] is not None:
generator.manual_seed(conf["seed"])

_, eval_dataset = _train_test_random_split(
train_dataset=train_dataset,
train_size=train_size,
test_size=eval_size,
generator=generator,
)

# Select eval_structures based on fraction
eval_structures = [eval_structures[index] for index in eval_dataset.indices]

else:
eval_options = expand_dataset_config(eval_options)
eval_targets = read_targets(eval_options["targets"])
eval_dataset = Dataset(eval_structures, eval_targets)
check_units(actual_options=eval_options, desired_options=train_options)

_eval_targets(loaded_model, eval_dataset)
if hasattr(options, "targets"):
eval_targets = read_targets(options["targets"])
eval_dataset = Dataset(eval_structures, eval_targets)
_eval_targets(model, eval_dataset)

# Predict strcutures
predictions = loaded_model(eval_structures, loaded_model.capabilities.outputs)
# Predict structures
predictions = model(eval_structures, model.capabilities.outputs)
write_predictions(output, predictions, eval_structures)
115 changes: 59 additions & 56 deletions src/metatensor/models/utils/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,65 +103,68 @@ def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig:
{"structures": read_from, "targets": {"energy": read_from}}
)

if type(conf["structures"]) is str:
conf["structures"] = _resolve_single_str(conf["structures"])
if hasattr(conf, "structures"):
if type(conf["structures"]) is str:
conf["structures"] = _resolve_single_str(conf["structures"])

conf["structures"] = OmegaConf.merge(CONF_STRUCTURES, conf["structures"])
conf["structures"] = OmegaConf.merge(CONF_STRUCTURES, conf["structures"])

for target_key, target in conf["targets"].items():
if type(target) is str:
target = _resolve_single_str(target)
if hasattr(conf, "targets"):
for target_key, target in conf["targets"].items():
if type(target) is str:
target = _resolve_single_str(target)

# Add default gradients "energy" target section
if target_key == "energy":
# For special case of the "energy" we add the section for force and stress
# gradient by default
target = OmegaConf.merge(CONF_ENERGY, target)
else:
target = OmegaConf.merge(CONF_TARGET, target)

if target["key"] is None:
target["key"] = target_key

# Update DictConfig to allow for config node interpolation
conf["targets"][target_key] = target

# merge and interpolate possibly present gradients with default gradient config
for gradient_key, gradient_conf in conf["targets"][target_key].items():
if gradient_key in KNWON_GRADIENTS:
if gradient_conf is True:
gradient_conf = CONF_GRADIENT.copy()
elif type(gradient_conf) is str:
gradient_conf = _resolve_single_str(gradient_conf)

if isinstance(gradient_conf, DictConfig):
gradient_conf = OmegaConf.merge(CONF_GRADIENT, gradient_conf)

if gradient_conf["key"] is None:
gradient_conf["key"] = gradient_key

conf["targets"][target_key][gradient_key] = gradient_conf

# If user sets the virial gradient and leaves the stress section untouched,
# we disable the by default enabled stress gradient section.
base_stress_gradient_conf = CONF_GRADIENT.copy()
base_stress_gradient_conf["key"] = "stress"

if (
target_key == "energy"
and conf["targets"][target_key]["virial"]
and conf["targets"][target_key]["stress"] == base_stress_gradient_conf
):
conf["targets"][target_key]["stress"] = False

if (
conf["targets"][target_key]["stress"]
and conf["targets"][target_key]["virial"]
):
raise ValueError(
f"Cannot perform training with respect to virials and stress as in "
f"section {target_key}. Set either `virials: off` or `stress: off`."
)
# Add default gradients "energy" target section
if target_key == "energy":
# For special case of the "energy" we add the section for force and
# stress gradient by default
target = OmegaConf.merge(CONF_ENERGY, target)
else:
target = OmegaConf.merge(CONF_TARGET, target)

if target["key"] is None:
target["key"] = target_key

# Update DictConfig to allow for config node interpolation
conf["targets"][target_key] = target

# merge and interpolate possibly present gradients with default gradient
# config
for gradient_key, gradient_conf in conf["targets"][target_key].items():
if gradient_key in KNWON_GRADIENTS:
if gradient_conf is True:
gradient_conf = CONF_GRADIENT.copy()
elif type(gradient_conf) is str:
gradient_conf = _resolve_single_str(gradient_conf)

if isinstance(gradient_conf, DictConfig):
gradient_conf = OmegaConf.merge(CONF_GRADIENT, gradient_conf)

if gradient_conf["key"] is None:
gradient_conf["key"] = gradient_key

conf["targets"][target_key][gradient_key] = gradient_conf

# If user sets the virial gradient and leaves the stress section untouched,
# we disable the by default enabled stress gradient section.
base_stress_gradient_conf = CONF_GRADIENT.copy()
base_stress_gradient_conf["key"] = "stress"

if (
target_key == "energy"
and conf["targets"][target_key]["virial"]
and conf["targets"][target_key]["stress"] == base_stress_gradient_conf
):
conf["targets"][target_key]["stress"] = False

if (
conf["targets"][target_key]["stress"]
and conf["targets"][target_key]["virial"]
):
raise ValueError(
f"Cannot perform training with respect to virials and stress as in "
f"section {target_key}. Set either `virials: off` or `stress: off`."
)

return conf

Expand Down
Loading

0 comments on commit 84d8d2a

Please sign in to comment.