diff --git a/src/metatrain/experimental/alchemical_model/tests/__init__.py b/src/metatrain/experimental/alchemical_model/tests/__init__.py index 5fe9b1e1..4fcc1d3f 100644 --- a/src/metatrain/experimental/alchemical_model/tests/__init__.py +++ b/src/metatrain/experimental/alchemical_model/tests/__init__.py @@ -4,8 +4,8 @@ DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz") -CARBON_DATASET_PATH = str( - Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz" +QM9_DATASET_PATH = str( + Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz" ) DEFAULT_HYPERS = get_default_hypers("experimental.alchemical_model") diff --git a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py index c6c186fc..03b7ef1d 100644 --- a/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py +++ b/src/metatrain/experimental/alchemical_model/tests/test_torch_alchemical_compatibility.py @@ -17,14 +17,14 @@ from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists -from . import ALCHEMICAL_DATASET_PATH, MODEL_HYPERS +from . import MODEL_HYPERS, QM9_DATASET_PATH random.seed(0) np.random.seed(0) torch.manual_seed(0) -systems = read_systems(ALCHEMICAL_DATASET_PATH) +systems = read_systems(QM9_DATASET_PATH) systems = [system.to(torch.float32) for system in systems] nl_options = NeighborListOptions( cutoff=5.0, @@ -32,7 +32,7 @@ ) systems = [get_system_with_neighbor_lists(system, [nl_options]) for system in systems] -frames = read(ALCHEMICAL_DATASET_PATH, ":") +frames = read(QM9_DATASET_PATH, ":") dataset = AtomisticDataset( frames, target_properties=["energies", "forces"], diff --git a/src/metatrain/utils/logging.py b/src/metatrain/utils/logging.py index 83b0382b..2f2b17ba 100644 --- a/src/metatrain/utils/logging.py +++ b/src/metatrain/utils/logging.py @@ -110,7 +110,8 @@ def log( logging_string = f"Epoch {epoch:4}" for name, metrics_dict in zip(self.names, metrics): - for key, value in metrics_dict.items(): + for key in sorted(metrics_dict.keys()): + value = metrics_dict[key] new_key = key if key != "loss": # special case: not a metric associated with a target diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index f810fc7a..a6382a8a 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -14,9 +14,9 @@ from metatrain import RANDOM_SEED from metatrain.cli.train import train_model from metatrain.utils.errors import ArchitectureError -from metatrain.utils.logging import setup_logging from . import ( + DATASET_PATH_CARBON, DATASET_PATH_ETHANOL, DATASET_PATH_QM9, MODEL_PATH_64_BIT, @@ -513,19 +513,45 @@ def test_train_issue_290(monkeypatch, tmp_path): train_model(options) -# def test_train_log_order(caplog, monkeypatch, tmp_path, options): -# """Tests that the log is always printed in the same order for forces -# and virials.""" +def test_train_log_order(caplog, monkeypatch, tmp_path, options): + """Tests that the log is always printed in the same order for forces + and virials.""" -# caplog.set_level(logging.INFO) -# logger = logging.getLogger() + monkeypatch.chdir(tmp_path) + shutil.copy(DATASET_PATH_CARBON, "carbon_reduced_100.xyz") + + options["architecture"]["training"]["num_epochs"] = 5 + options["architecture"]["training"]["log_interval"] = 1 + + options["training_set"]["systems"]["read_from"] = str(DATASET_PATH_CARBON) + options["training_set"]["targets"]["energy"]["read_from"] = str(DATASET_PATH_CARBON) + options["training_set"]["targets"]["energy"]["key"] = "energy" + options["training_set"]["targets"]["energy"]["forces"] = { + "key": "force", + } + options["training_set"]["targets"]["energy"]["virial"] = True + + caplog.set_level(logging.INFO) + + train_model(options) + + log_test = caplog.text + + print(log_test) -# with setup_logging(logger, level=logging.INFO): -# logger.info("foo") -# logger.debug("A debug message") + # find all the lines that have "Epoch" in them; these are the lines that + # contain the training metrics + epoch_lines = [line for line in log_test.split("\n") if "Epoch" in line] -# stdout_log = capsys.readouterr().out + # check that "training forces RMSE" comes before "training virial RMSE" + # in every line + for line in epoch_lines: + force_index = line.index("training forces RMSE") + virial_index = line.index("training virial RMSE") + assert force_index < virial_index -# assert "Logging to file is disabled." not in caplog.text # DEBUG message -# assert_log_entry(stdout_log, loglevel="INFO", message="foo") -# assert "A debug message" not in stdout_log + # same for validation + for line in epoch_lines: + force_index = line.index("validation forces RMSE") + virial_index = line.index("validation virial RMSE") + assert force_index < virial_index