Add evaluation timings
frostedoyster committed Oct 4, 2024
1 parent 161fd56 commit 5a0c4f8
Showing 2 changed files with 29 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/metatrain/cli/eval.py
@@ -1,6 +1,7 @@
import argparse
import itertools
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Union

@@ -161,11 +162,14 @@ def _eval_targets(
"""Evaluates an exported model on a dataset and prints the RMSEs for each target.
Optionally, it also returns the predictions of the model.
The total and per-atom timings for the evaluation are also printed.
Wraps around metatrain.cli.evaluate_model.
"""

if len(dataset) == 0:
logger.info("This dataset is empty. No evaluation will be performed.")
return

# Attach neighbor lists to the systems:
# TODO: these might already be present... find a way to avoid recomputing
@@ -194,6 +198,10 @@ if return_predictions:
    if return_predictions:
        all_predictions = []

    # Set up timings:
    total_time = 0.0
    timings_per_atom = []

    # Evaluate the model
    for batch in dataloader:
        systems, batch_targets = batch
@@ -202,13 +210,21 @@
            key: value.to(dtype=dtype, device=device)
            for key, value in batch_targets.items()
        }

        start_time = time.time()

        batch_predictions = evaluate_model(
            model,
            systems,
            options,
            is_training=False,
            check_consistency=check_consistency,
        )

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end_time = time.time()

        batch_predictions = average_by_num_atoms(
            batch_predictions, systems, per_structure_keys=[]
        )
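An aside on the pattern in this hunk: CUDA kernels are launched asynchronously, so without the torch.cuda.synchronize() call the clock would be read while the GPU may still be working, and the interval would mostly measure kernel launch overhead. A minimal, self-contained sketch of the same idea; the model and input are placeholders, not part of this commit:

import time

import torch


def timed_forward(model: torch.nn.Module, x: torch.Tensor) -> float:
    # Wall-clock seconds for one forward pass. synchronize() forces
    # pending CUDA kernels to finish, so the interval covers execution
    # rather than just kernel launch. (time.perf_counter() is the
    # higher-resolution stdlib clock; time.time() mirrors the commit.)
    start = time.time()
    model(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time() - start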
@@ -219,6 +235,10 @@
        if return_predictions:
            all_predictions.append(batch_predictions)

        time_taken = end_time - start_time
        total_time += time_taken
        timings_per_atom.append(time_taken / sum(len(system) for system in systems))

    # Finalize the RMSEs
    rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"])
    # print the RMSEs with MetricLogger
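A note on the accumulation above: timings_per_atom stores one per-atom figure per batch, so the mean taken further down weights every batch equally, regardless of how many atoms it contains. A sketch of the difference with made-up numbers; the atom-weighted variant is shown only for comparison and is not what the commit does:

# Two hypothetical batches: (seconds, number of atoms)
batches = [(0.10, 10), (0.40, 100)]

# Per-batch mean, as in this commit: each batch counts once
per_batch = sum(t / n for t, n in batches) / len(batches)  # 0.007 s/atom

# Atom-weighted alternative: total time over total atoms
weighted = sum(t for t, _ in batches) / sum(n for _, n in batches)  # ~0.0045 s/atom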
@@ -229,6 +249,13 @@
    )
    metric_logger.log(rmse_values)

    # Log timings
    mean_time_per_atom = sum(timings_per_atom) / len(timings_per_atom)
    logger.info(
        f"evaluation time: {total_time:.2f} s "
        f"[{1000.0*mean_time_per_atom:.2f} ms per atom]"
    )

    if return_predictions:
        # concatenate the TensorMaps
        all_predictions_joined = _concatenate_tensormaps(all_predictions)
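With the two format strings above, the emitted record looks like this (numbers illustrative):

evaluation time: 3.52 s [0.87 ms per atom]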
2 changes: 2 additions & 0 deletions tests/cli/test_eval_model.py
@@ -68,6 +68,8 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
log = "".join([rec.message for rec in caplog.records])
assert "energy RMSE (per atom)" in log
assert "dataset with index" not in log
assert "evaluation time" in log
assert "ms per atom" in log

# Test file is written predictions
frames = ase.io.read("foo.xyz", ":")
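These assertions rely on pytest's caplog fixture, which collects the records emitted through the logging module while the test runs. A stripped-down sketch of the same pattern; the logger name and message are illustrative, not taken from metatrain:

import logging


def test_timing_is_logged(caplog):
    caplog.set_level(logging.INFO)
    logging.getLogger("metatrain").info("evaluation time: 1.00 s [0.50 ms per atom]")
    log = "".join(rec.message for rec in caplog.records)
    assert "evaluation time" in log
    assert "ms per atom" in log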
