From 5a0c4f8ad79a3661a7b20a419c58efe335191b61 Mon Sep 17 00:00:00 2001
From: frostedoyster
Date: Fri, 4 Oct 2024 18:42:21 +0200
Subject: [PATCH] Add evaluation timings

---
 src/metatrain/cli/eval.py    | 27 +++++++++++++++++++++++++++
 tests/cli/test_eval_model.py |  2 ++
 2 files changed, 29 insertions(+)

diff --git a/src/metatrain/cli/eval.py b/src/metatrain/cli/eval.py
index 93adb162..c37d9e29 100644
--- a/src/metatrain/cli/eval.py
+++ b/src/metatrain/cli/eval.py
@@ -1,6 +1,7 @@
 import argparse
 import itertools
 import logging
+import time
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 
@@ -161,11 +162,14 @@ def _eval_targets(
     """Evaluates an exported model on a dataset and prints the RMSEs for each target.
 
     Optionally, it also returns the predictions of the model.
 
+    The total and per-atom timings for the evaluation are also printed.
+
     Wraps around metatrain.cli.evaluate_model.
     """
     if len(dataset) == 0:
         logger.info("This dataset is empty. No evaluation will be performed.")
+        return
 
     # Attach neighbor lists to the systems:
     # TODO: these might already be present... find a way to avoid recomputing
@@ -194,6 +198,10 @@ def _eval_targets(
     if return_predictions:
         all_predictions = []
 
+    # Set up timings:
+    total_time = 0.0
+    timings_per_atom = []
+
     # Evaluate the model
     for batch in dataloader:
         systems, batch_targets = batch
@@ -202,6 +210,9 @@ def _eval_targets(
             key: value.to(dtype=dtype, device=device)
             for key, value in batch_targets.items()
         }
+
+        start_time = time.time()
+
         batch_predictions = evaluate_model(
             model,
             systems,
@@ -209,6 +220,11 @@ def _eval_targets(
             is_training=False,
             check_consistency=check_consistency,
         )
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        end_time = time.time()
+
         batch_predictions = average_by_num_atoms(
             batch_predictions, systems, per_structure_keys=[]
         )
@@ -219,6 +235,10 @@ def _eval_targets(
         if return_predictions:
             all_predictions.append(batch_predictions)
 
+        time_taken = end_time - start_time
+        total_time += time_taken
+        timings_per_atom.append(time_taken / sum(len(system) for system in systems))
+
     # Finalize the RMSEs
     rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"])
     # print the RMSEs with MetricLogger
@@ -229,6 +249,13 @@ def _eval_targets(
     )
     metric_logger.log(rmse_values)
 
+    # Log timings
+    mean_time_per_atom = sum(timings_per_atom) / len(timings_per_atom)
+    logger.info(
+        f"evaluation time: {total_time:.2f} s "
+        f"[{1000.0*mean_time_per_atom:.2f} ms per atom]"
+    )
+
     if return_predictions:
         # concatenate the TensorMaps
         all_predictions_joined = _concatenate_tensormaps(all_predictions)
diff --git a/tests/cli/test_eval_model.py b/tests/cli/test_eval_model.py
index 7aaa2c2b..7eda92af 100644
--- a/tests/cli/test_eval_model.py
+++ b/tests/cli/test_eval_model.py
@@ -68,6 +68,8 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
     log = "".join([rec.message for rec in caplog.records])
     assert "energy RMSE (per atom)" in log
     assert "dataset with index" not in log
+    assert "evaluation time" in log
+    assert "ms per atom" in log
 
     # Test file is written predictions
     frames = ase.io.read("foo.xyz", ":")
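
Note on the measurement pattern: the hunks above time each evaluate_model call with
time.time() and call torch.cuda.synchronize() before reading the end time, because
CUDA kernels are launched asynchronously and would otherwise appear to return almost
immediately. Below is a minimal, self-contained sketch of this pattern; it assumes
only that PyTorch is installed, and timed_forward, model and inputs are illustrative
placeholders rather than metatrain code.

    import time

    import torch


    def timed_forward(model, inputs):
        """Run model(inputs) and return (output, wall-clock seconds)."""
        start_time = time.time()
        output = model(inputs)
        if torch.cuda.is_available():
            # GPU work is queued asynchronously; wait for it to finish so the
            # elapsed time covers the computation, not just the kernel launch.
            torch.cuda.synchronize()
        end_time = time.time()
        return output, end_time - start_time

Dividing each batch's elapsed time by the number of atoms in that batch, as the patch
does before averaging, makes the logged "ms per atom" figure comparable across
datasets whose structures have different sizes.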