Change how eval function works (#65)
---------

Co-authored-by: frostedoyster <[email protected]>
PicoCentauri and frostedoyster authored Feb 9, 2024
1 parent 14bc244 commit cdb74f9
Showing 13 changed files with 384 additions and 177 deletions.
25 changes: 18 additions & 7 deletions docs/src/getting-started/usage.rst
@@ -66,25 +66,36 @@ The sub-command to evaluate an already trained model is
metatensor-models eval
Besides the trained ``model``, you will also have to provide a file containing the
structures and, optionally, target values for evaluation. The layout of this
``eval.yaml`` is exactly the same as that of a dataset in the ``options.yaml`` file.

.. literalinclude:: ../../static/qm9/eval.yaml
:language: yaml

Note that the ``targets`` section is optional. If it is present, the function will
calculate and report RMSE values of the predictions with respect to the reference
values loaded from that section. You can run an evaluation by typing

.. literalinclude:: ../../../examples/basic_usage/usage.sh
:language: bash
:lines: 9-25
:lines: 9-24
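
For reference, the lines included from ``usage.sh`` boil down to a single call; a
minimal sketch, assuming the trained model was saved as ``model.pt`` and the options
file above is named ``eval.yaml``:

.. code-block:: bash

metatensor-models eval model.pt eval.yaml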


Exporting
#########

Exporting a model is very useful if you want to use it in other frameworks, especially
in molecular dynamics simulations. The sub-command to export a trained model is

.. code-block:: bash
metatensor-models export
.. literalinclude:: ../../../examples/basic_usage/usage.sh
:language: bash
:lines: 25-
:lines: 26-
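
The included lines contain the export call itself. The exact arguments are not shown in
this diff, so the following one-liner is a hypothetical sketch in which the positional
``model.pt`` argument is an assumption:

.. code-block:: bash

metatensor-models export model.pt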

In the next tutorials we show how to adjust the dataset section of the ``options.yaml``
file to use it for your own datasets.
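
Since a dataset section uses the same layout as ``eval.yaml``, here is a hypothetical
sketch of such a section; the top-level key name ``training_set`` is an assumption and
not taken from this diff:

.. code-block:: yaml

training_set:
  structures: "qm9_reduced_100.xyz"
  targets:
    energy:
      key: "U0"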
4 changes: 4 additions & 0 deletions docs/static/qm9/eval.yaml
@@ -0,0 +1,4 @@
structures: "qm9_reduced_100.xyz" # file where the positions are stored
targets:
energy:
key: "U0" # name of the target value
1 change: 1 addition & 0 deletions examples/basic_usage/eval.yaml
8 changes: 4 additions & 4 deletions examples/basic_usage/usage.sh
@@ -8,12 +8,12 @@ metatensor-models train options.yaml
metatensor-models train --help

# We now evaluate the model on the training dataset, where the first argument specifies
# the model and the second the structure file
# the trained model and the second an options file containing the path of the dataset for evaluation.

metatensor-models eval model.pt qm9_reduced_100.xyz
metatensor-models eval model.pt eval.yaml

# The evaluation command predicts the property the model was trained against; here "U0".
# The predictions together with the structures have been written in a file named
# The evaluation command predicts those properties the model was trained against; here
# "U0". The predictions, together with the structures, are written to a file named
# ``output.xyz`` in the current directory. The written file starts with the following
# lines

140 changes: 119 additions & 21 deletions src/metatensor/models/cli/eval_model.py
@@ -1,11 +1,29 @@
import argparse
import logging
from typing import Dict, Tuple, Union

from ..utils.data.readers import read_structures
from ..utils.data.writers import write_predictions
import torch
from omegaconf import DictConfig, OmegaConf

from ..utils.compute_loss import compute_model_loss
from ..utils.data import (
Dataset,
collate_fn,
read_structures,
read_targets,
write_predictions,
)
from ..utils.extract_targets import get_outputs_dict
from ..utils.info import finalize_aggregated_info, update_aggregated_info
from ..utils.loss import TensorMapDictLoss
from ..utils.model_io import load_model
from ..utils.omegaconf import expand_dataset_config
from .formatter import CustomHelpFormatter


logger = logging.getLogger(__name__)


def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
"""Add the `eval_model` paramaters to an argparse (sub)-parser"""

@@ -20,16 +38,15 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
formatter_class=CustomHelpFormatter,
)
parser.set_defaults(callable="eval_model")

parser.add_argument(
"model",
type=str,
help="saved model to be evaluated",
type=load_model,
help="Saved model to be evaluated.",
)
parser.add_argument(
"structures",
type=str,
help="Structure file which should be considered for the evaluation.",
"options",
type=OmegaConf.load,
help="Eval options file to define a dataset for evaluation.",
)
parser.add_argument(
"-o",
@@ -42,22 +59,103 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
)


def eval_model(model: str, structures: str, output: str = "output.xyz") -> None:
"""Evaluate a pretrained model.
def _eval_targets(model, dataset: Union[Dataset, torch.utils.data.Subset]) -> None:
"""Evaluate a model on a dataset and print the RMSEs for each target."""

# Extract all the possible outputs and their gradients from the dataset:
outputs_dict = get_outputs_dict([dataset])
for output_name in outputs_dict.keys():
if output_name not in model.capabilities.outputs:
raise ValueError(
f"Output {output_name} is not in the model's capabilities."
)

# Create the loss function:
loss_weights_dict = {}
for output_name, value_or_gradient_list in outputs_dict.items():
loss_weights_dict[output_name] = {
value_or_gradient: 0.0 for value_or_gradient in value_or_gradient_list
}
loss_fn = TensorMapDictLoss(loss_weights_dict)

# Create a dataloader:
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=4, # Choose a small value so evaluation does not run out of memory
shuffle=True,
collate_fn=collate_fn,
)

# Compute the RMSEs:
aggregated_info: Dict[str, Tuple[float, int]] = {}
for batch in dataloader:
structures, targets = batch
_, info = compute_model_loss(loss_fn, model, structures, targets)
aggregated_info = update_aggregated_info(aggregated_info, info)
finalized_info = finalize_aggregated_info(aggregated_info)

energy_counter = 0
for output in model.capabilities.outputs.values():
if output.quantity == "energy":
energy_counter += 1
if energy_counter == 1:
only_one_energy = True
else:
only_one_energy = False

log_output = []
for key, value in finalized_info.items():
new_key = key
if key.endswith("_positions_gradients"):
# check if this is a force
target_name = key[: -len("_positions_gradients")]
if model.capabilities.outputs[target_name].quantity == "energy":
# if this is a force, replace the ugly name with "force"
if only_one_energy:
new_key = "force"
else:
new_key = f"force[{target_name}]"
elif key.endswith("_displacement_gradients"):
# check if this is a virial/stress
target_name = key[: -len("_displacement_gradients")]
if model.capabilities.outputs[target_name].quantity == "energy":
# if this is a virial/stress,
# replace the ugly name with "virial/stress"
if only_one_energy:
new_key = "virial/stress"
else:
new_key = f"virial/stress[{target_name}]"
log_output.append(f"{new_key} RMSE: {value}")
logger.info(", ".join(log_output))
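# For illustration only (hypothetical numbers): with a single energy target that also
# has position gradients, the line above would log something like
# energy RMSE: 0.0123, force RMSE: 0.4567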

``target_property`` will be predicted on a provided set of structures. Predicted
values will be written ``output``.

:param model: Path to a saved model
:param structure: Path to a structure file which should be considered for the
evaluation.
def eval_model(
model: torch.nn.Module, options: DictConfig, output: str = "output.xyz"
) -> None:
"""Evaluate a pretrained model on a given data set.
If ``options`` contains a ``targets`` sub-section, RMSE values will be reported. If
this sub-section is missing, only a xyz-file with containing the properties the
model was trained against is written.
:param model: Saved model to be evaluated.
:param options: DictConfig to define a test dataset taken for the evaluation.
:param output: Path to save the predicted values
"""
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger.info("Setting up evaluation set.")

loaded_model = load_model(model)
structure_list = read_structures(structures)

# this calculates all the properties that the model is capable of predicting:
predictions = loaded_model(structure_list, loaded_model.capabilities.outputs)
options = expand_dataset_config(options)
eval_structures = read_structures(
filename=options["structures"]["read_from"],
fileformat=options["structures"]["file_format"],
)
# Evaluate targets, if they were provided
if hasattr(options, "targets"):
eval_targets = read_targets(options["targets"])
eval_dataset = Dataset(eval_structures, eval_targets)
_eval_targets(model, eval_dataset)

write_predictions(output, predictions, structure_list)
# Predict properties on all structures
predictions = model(eval_structures, model.capabilities.outputs)
write_predictions(output, predictions, eval_structures)
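
For context, a minimal sketch of calling this function outside the CLI; it simply
mirrors the conversions the ``model`` and ``options`` CLI arguments already apply
(``load_model`` and ``OmegaConf.load``), assuming ``model.pt`` and ``eval.yaml`` exist:

from omegaconf import OmegaConf

from metatensor.models.cli.eval_model import eval_model
from metatensor.models.utils.model_io import load_model

model = load_model("model.pt")  # same loader the CLI applies to the "model" argument
options = OmegaConf.load("eval.yaml")  # same as the "options" argument
eval_model(model, options, output="output.xyz")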