Add evaluation timings #353

Merged (2 commits, Oct 9, 2024)
src/metatrain/cli/eval.py: 31 additions & 0 deletions
@@ -1,10 +1,12 @@
import argparse
import itertools
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Union

import metatensor.torch
import numpy as np
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import MetatensorAtomisticModel
@@ -161,11 +163,14 @@
"""Evaluates an exported model on a dataset and prints the RMSEs for each target.
Optionally, it also returns the predictions of the model.

The total and per-atom timings for the evaluation are also printed.

Wraps around metatrain.cli.evaluate_model.
"""

if len(dataset) == 0:
logger.info("This dataset is empty. No evaluation will be performed.")
return None

# Attach neighbor lists to the systems:
# TODO: these might already be present... find a way to avoid recomputing
@@ -194,6 +199,10 @@
if return_predictions:
all_predictions = []

# Set up timings:
total_time = 0.0
timings_per_atom = []

# Evaluate the model
for batch in dataloader:
systems, batch_targets = batch
@@ -202,13 +211,21 @@
key: value.to(dtype=dtype, device=device)
for key, value in batch_targets.items()
}

start_time = time.time()

batch_predictions = evaluate_model(
model,
systems,
options,
is_training=False,
check_consistency=check_consistency,
)

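# CUDA kernels are launched asynchronously; wait for them to finish so
# that end_time measures the completed evaluation, not just the launch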
if torch.cuda.is_available():
torch.cuda.synchronize()

[Codecov warning on line 226 in src/metatrain/cli/eval.py: Added line #L226 was not covered by tests]
end_time = time.time()

batch_predictions = average_by_num_atoms(
batch_predictions, systems, per_structure_keys=[]
)
@@ -219,6 +236,10 @@
if return_predictions:
all_predictions.append(batch_predictions)

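# accumulate the batch wall time and normalize it by the number of atoms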
time_taken = end_time - start_time
total_time += time_taken
timings_per_atom.append(time_taken / sum(len(system) for system in systems))

# Finalize the RMSEs
rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"])
# print the RMSEs with MetricLogger
@@ -229,6 +250,16 @@
)
metric_logger.log(rmse_values)

# Log timings
timings_per_atom = np.array(timings_per_atom)
mean_per_atom = np.mean(timings_per_atom)
std_per_atom = np.std(timings_per_atom)
logger.info(
f"evaluation time: {total_time:.2f} s "
f"[{1000.0*mean_per_atom:.2f} ± "
f"{1000.0*std_per_atom:.2f} ms per atom]"
)

if return_predictions:
# concatenate the TensorMaps
all_predictions_joined = _concatenate_tensormaps(all_predictions)
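Pulled out of context, the timing logic above follows the pattern sketched below. This is a minimal, self-contained illustration, not the PR's code: timed_batches, run_batch, and the toy batches are hypothetical stand-ins for the evaluation loop, evaluate_model, and the dataloader output.

import time

import numpy as np
import torch


def timed_batches(batches, run_batch):
    """Run each batch, returning total seconds and per-atom timings."""
    total_time = 0.0
    timings_per_atom = []
    for systems in batches:
        start_time = time.time()
        run_batch(systems)
        if torch.cuda.is_available():
            # wait for queued GPU kernels so the measurement is complete
            torch.cuda.synchronize()
        end_time = time.time()
        time_taken = end_time - start_time
        total_time += time_taken
        # one per-atom timing per batch, normalized by the total atom count
        timings_per_atom.append(time_taken / sum(len(s) for s in systems))
    return total_time, np.array(timings_per_atom)


# toy usage: two "batches" of "systems", where len(system) is the atom count
batches = [[list(range(10)), list(range(20))], [list(range(30))]]
total, per_atom = timed_batches(batches, lambda systems: time.sleep(0.01))
print(
    f"evaluation time: {total:.2f} s "
    f"[{1000.0 * per_atom.mean():.2f} ± {1000.0 * per_atom.std():.2f} ms per atom]"
)

Synchronizing before the second time.time() matters because CUDA kernels run asynchronously; without it the measurement would cover only the kernel launches rather than the actual compute.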
tests/cli/test_eval_model.py: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
log = "".join([rec.message for rec in caplog.records])
assert "energy RMSE (per atom)" in log
assert "dataset with index" not in log
assert "evaluation time" in log
assert "ms per atom" in log

# Test that the predictions file is written
frames = ase.io.read("foo.xyz", ":")
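With these assertions in place, a successful eval run logs a line of the form below (numbers are illustrative, not from an actual run):

evaluation time: 12.34 s [1.56 ± 0.21 ms per atom]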