Skip to content

Commit

Permalink
Merge branch 'main' into dataset-repr
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Jun 12, 2024
2 parents d61aaf4 + 005176c commit 9413119
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
13 changes: 10 additions & 3 deletions src/metatrain/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,15 @@ def __init__(

self.names = names

# Since the quantities are supposed to decrease, we want to store the number of
# digits at the start of the training, so that we can align the output later:
# Since the quantities are supposed to decrease, we want to store the
# number of digits at the start of the training, so that we can align
# the output later:
self.digits = {}
for name, metrics_dict in zip(names, initial_metrics):
for key, value in metrics_dict.items():
if "loss" in key:
# losses will be printed in scientific notation
continue
self.digits[f"{name}_{key}"] = _get_digits(value)

# Save the model outputs. This will be useful to know
Expand Down Expand Up @@ -96,7 +100,10 @@ def log(
logging_string += f", {new_key}: "
else:
logging_string += f", {name} {new_key}: "
logging_string += f"{value:{self.digits[f'{name}_{key}'][0]}.{self.digits[f'{name}_{key}'][1]}f}" # noqa: E501
if "loss" in key: # print losses with scientific notation
logging_string += f"{value:.3e}"
else:
logging_string += f"{value:{self.digits[f'{name}_{key}'][0]}.{self.digits[f'{name}_{key}'][1]}f}" # noqa: E501

# If there is no epoch, the string will start with a comma. Remove it:
if logging_string.startswith(", "):
Expand Down
36 changes: 34 additions & 2 deletions tests/utils/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import re

from metatrain.utils.logging import setup_logging
from metatensor.torch.atomistic import ModelOutput

from metatrain.utils.logging import MetricLogger, setup_logging


def assert_log_entry(logtext: str, loglevel: str, message: str) -> None:
Expand Down Expand Up @@ -78,4 +80,34 @@ def test_debug_log(caplog, monkeypatch, tmp_path, capsys):
assert "foo" in logtext
assert "A debug message" in logtext
# Test that debug information is in output
assert "test_logging.py:test_debug_log:67" in logtext
assert "test_logging.py:test_debug_log:68" in logtext


def test_metric_logger(caplog, capsys):
"""Tests the MetricLogger class."""
caplog.set_level(logging.INFO)
logger = logging.getLogger("test")

outputs = {
"foo": ModelOutput(),
"bar": ModelOutput(),
}

names = ["train"]

initial_metrics = [
{
"loss": 0.1,
"foo RMSE": 1.0,
"bar RMSE": 0.1,
}
]

with setup_logging(logger, logfile="logfile.log", level=logging.INFO):
metric_logger = MetricLogger(logger, outputs, initial_metrics, names)
metric_logger.log(initial_metrics)

stdout_log = capsys.readouterr().out
assert "train loss: 1.000e-01" in stdout_log
assert "train foo RMSE: 1.0000" in stdout_log
assert "train bar RMSE: 0.1000" in stdout_log

0 comments on commit 9413119

Please sign in to comment.