Skip to content

Commit

Permalink
Add test for consistent logging order
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jul 23, 2024
1 parent 3c30e09 commit 2ff7d2b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/metatrain/experimental/alchemical_model/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz")

CARBON_DATASET_PATH = str(
Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz"
QM9_DATASET_PATH = str(
Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz"
)

DEFAULT_HYPERS = get_default_hypers("experimental.alchemical_model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists

from . import ALCHEMICAL_DATASET_PATH, MODEL_HYPERS
from . import MODEL_HYPERS, QM9_DATASET_PATH


random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

systems = read_systems(ALCHEMICAL_DATASET_PATH)
systems = read_systems(QM9_DATASET_PATH)
systems = [system.to(torch.float32) for system in systems]
nl_options = NeighborListOptions(
cutoff=5.0,
full_list=True,
)
systems = [get_system_with_neighbor_lists(system, [nl_options]) for system in systems]

frames = read(ALCHEMICAL_DATASET_PATH, ":")
frames = read(QM9_DATASET_PATH, ":")
dataset = AtomisticDataset(
frames,
target_properties=["energies", "forces"],
Expand Down
3 changes: 2 additions & 1 deletion src/metatrain/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def log(
logging_string = f"Epoch {epoch:4}"

for name, metrics_dict in zip(self.names, metrics):
for key, value in metrics_dict.items():
for key in sorted(metrics_dict.keys()):
value = metrics_dict[key]

new_key = key
if key != "loss": # special case: not a metric associated with a target
Expand Down
48 changes: 35 additions & 13 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from metatrain import RANDOM_SEED
from metatrain.cli.train import train_model
from metatrain.utils.errors import ArchitectureError
from metatrain.utils.logging import setup_logging

from . import (
DATASET_PATH_CARBON,
DATASET_PATH_ETHANOL,
DATASET_PATH_QM9,
MODEL_PATH_64_BIT,
Expand Down Expand Up @@ -513,19 +513,41 @@ def test_train_issue_290(monkeypatch, tmp_path):
train_model(options)


# def test_train_log_order(caplog, monkeypatch, tmp_path, options):
# """Tests that the log is always printed in the same order for forces
# and virials."""
def test_train_log_order(caplog, monkeypatch, tmp_path, options):
"""Tests that the log is always printed in the same order for forces
and virials."""

# caplog.set_level(logging.INFO)
# logger = logging.getLogger()
monkeypatch.chdir(tmp_path)
shutil.copy(DATASET_PATH_CARBON, "carbon_reduced_100.xyz")

# with setup_logging(logger, level=logging.INFO):
# logger.info("foo")
# logger.debug("A debug message")
options["architecture"]["training"]["num_epochs"] = 5
options["architecture"]["training"]["log_interval"] = 1

# stdout_log = capsys.readouterr().out
options["training_set"]["systems"]["read_from"] = str(DATASET_PATH_CARBON)
options["training_set"]["targets"]["energy"]["read_from"] = str(DATASET_PATH_CARBON)
options["training_set"]["targets"]["energy"]["key"] = "energy"
options["training_set"]["targets"]["energy"]["forces"] = {
"key": "force",
}
options["training_set"]["targets"]["energy"]["virial"] = True

# assert "Logging to file is disabled." not in caplog.text # DEBUG message
# assert_log_entry(stdout_log, loglevel="INFO", message="foo")
# assert "A debug message" not in stdout_log
caplog.set_level(logging.INFO)
train_model(options)
log_test = caplog.text

# find all the lines that have "Epoch" in them; these are the lines that
# contain the training metrics
epoch_lines = [line for line in log_test.split("\n") if "Epoch" in line]

# check that "training forces RMSE" comes before "training virial RMSE"
# in every line
for line in epoch_lines:
force_index = line.index("training forces RMSE")
virial_index = line.index("training virial RMSE")
assert force_index < virial_index

# same for validation
for line in epoch_lines:
force_index = line.index("validation forces RMSE")
virial_index = line.index("validation virial RMSE")
assert force_index < virial_index

0 comments on commit 2ff7d2b

Please sign in to comment.