From 005176c7b018f7843d4e0e750fd9a14aaa3bdd46 Mon Sep 17 00:00:00 2001
From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com>
Date: Wed, 12 Jun 2024 08:18:20 +0200
Subject: [PATCH] Output loss in scientific notation (#227)

---
 src/metatrain/utils/logging.py | 13 +++++++++---
 tests/utils/test_logging.py    | 36 ++++++++++++++++++++++++++++++++--
 2 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/src/metatrain/utils/logging.py b/src/metatrain/utils/logging.py
index acf833ac3..05a5e5e71 100644
--- a/src/metatrain/utils/logging.py
+++ b/src/metatrain/utils/logging.py
@@ -45,11 +45,15 @@ def __init__(
 
         self.names = names
 
-        # Since the quantities are supposed to decrease, we want to store the number of
-        # digits at the start of the training, so that we can align the output later:
+        # Since the quantities are supposed to decrease, we want to store the
+        # number of digits at the start of the training, so that we can align
+        # the output later:
         self.digits = {}
         for name, metrics_dict in zip(names, initial_metrics):
             for key, value in metrics_dict.items():
+                if "loss" in key:
+                    # losses will be printed in scientific notation
+                    continue
                 self.digits[f"{name}_{key}"] = _get_digits(value)
 
         # Save the model outputs. This will be useful to know
@@ -96,7 +100,10 @@ def log(
                     logging_string += f", {new_key}: "
                 else:
                     logging_string += f", {name} {new_key}: "
-                logging_string += f"{value:{self.digits[f'{name}_{key}'][0]}.{self.digits[f'{name}_{key}'][1]}f}"  # noqa: E501
+                if "loss" in key:  # print losses with scientific notation
+                    logging_string += f"{value:.3e}"
+                else:
+                    logging_string += f"{value:{self.digits[f'{name}_{key}'][0]}.{self.digits[f'{name}_{key}'][1]}f}"  # noqa: E501
 
         # If there is no epoch, the string will start with a comma. Remove it:
         if logging_string.startswith(", "):
diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py
index 39b5f2d97..a13f4797a 100644
--- a/tests/utils/test_logging.py
+++ b/tests/utils/test_logging.py
@@ -1,7 +1,9 @@
 import logging
 import re
 
-from metatrain.utils.logging import setup_logging
+from metatensor.torch.atomistic import ModelOutput
+
+from metatrain.utils.logging import MetricLogger, setup_logging
 
 
 def assert_log_entry(logtext: str, loglevel: str, message: str) -> None:
@@ -78,4 +80,34 @@ def test_debug_log(caplog, monkeypatch, tmp_path, capsys):
     assert "foo" in logtext
     assert "A debug message" in logtext
     # Test that debug information is in output
-    assert "test_logging.py:test_debug_log:67" in logtext
+    assert "test_logging.py:test_debug_log:68" in logtext
+
+
+def test_metric_logger(caplog, capsys):
+    """Tests the MetricLogger class."""
+    caplog.set_level(logging.INFO)
+    logger = logging.getLogger("test")
+
+    outputs = {
+        "foo": ModelOutput(),
+        "bar": ModelOutput(),
+    }
+
+    names = ["train"]
+
+    initial_metrics = [
+        {
+            "loss": 0.1,
+            "foo RMSE": 1.0,
+            "bar RMSE": 0.1,
+        }
+    ]
+
+    with setup_logging(logger, logfile="logfile.log", level=logging.INFO):
+        metric_logger = MetricLogger(logger, outputs, initial_metrics, names)
+        metric_logger.log(initial_metrics)
+
+    stdout_log = capsys.readouterr().out
+    assert "train loss: 1.000e-01" in stdout_log
+    assert "train foo RMSE: 1.0000" in stdout_log
+    assert "train bar RMSE: 0.1000" in stdout_log
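
A minimal sketch, not part of the patch, of the two formatting paths that
MetricLogger.log() chooses between after this change. The (width, precision)
pair is a hypothetical stand-in for what _get_digits() is assumed to return;
the printed strings match the assertions in test_metric_logger above.

    loss = 0.1
    rmse = 1.0
    width, precision = 6, 4  # hypothetical, as if returned by _get_digits(rmse)

    # Losses now use scientific notation with three decimal places:
    print(f"train loss: {loss:.3e}")  # train loss: 1.000e-01

    # All other metrics keep the fixed-point, width-aligned formatting:
    print(f"train foo RMSE: {rmse:{width}.{precision}f}")  # train foo RMSE: 1.0000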