
Commit

Add MAE logging
frostedoyster committed Oct 17, 2024
1 parent d99721a commit ee9c69a
Showing 11 changed files with 180 additions and 14 deletions.
13 changes: 9 additions & 4 deletions src/metatrain/cli/eval.py
@@ -24,7 +24,7 @@
from ..utils.errors import ArchitectureError
from ..utils.evaluate_model import evaluate_model
from ..utils.logging import MetricLogger
from ..utils.metrics import RMSEAccumulator
from ..utils.metrics import MAEAccumulator, RMSEAccumulator
from ..utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
@@ -197,6 +197,7 @@ def _eval_targets(

# Initialize the RMSE and MAE accumulators:
rmse_accumulator = RMSEAccumulator()
mae_accumulator = MAEAccumulator()

# If we're returning the predictions, we need to store them:
if return_predictions:
@@ -249,22 +250,26 @@
batch_targets, systems, per_structure_keys=[]
)
rmse_accumulator.update(batch_predictions, batch_targets)
mae_accumulator.update(batch_predictions, batch_targets)
if return_predictions:
all_predictions.append(batch_predictions)

time_taken = end_time - start_time
total_time += time_taken
timings_per_atom.append(time_taken / sum(len(system) for system in systems))

# Finalize the RMSEs
# Finalize the metrics
rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"])
mae_values = mae_accumulator.finalize(not_per_atom=["positions_gradients"])
metrics = {**rmse_values, **mae_values}

# print the RMSEs and MAEs with MetricLogger
metric_logger = MetricLogger(
log_obj=logger,
dataset_info=model.capabilities(),
initial_metrics=rmse_values,
initial_metrics=metrics,
)
metric_logger.log(rmse_values)
metric_logger.log(metrics)

# Log timings
timings_per_atom = np.array(timings_per_atom)
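The eval path now finalizes both accumulators and merges them into one dictionary before logging. A minimal sketch of that merge, with illustrative values in the key format that the finalize methods produce:

    rmse_values = {"energy RMSE (per atom)": 0.0123}  # from RMSEAccumulator.finalize
    mae_values = {"energy MAE (per atom)": 0.0087}    # from MAEAccumulator.finalize
    metrics = {**rmse_values, **mae_values}
    # metrics now holds both entries and is what gets passed to MetricLogger
    # as initial_metrics and to metric_logger.log()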
1 change: 1 addition & 0 deletions src/metatrain/experimental/alchemical_model/default-hypers.yaml
@@ -26,3 +26,4 @@ training:
checkpoint_interval: 25
per_structure_targets: []
loss_weights: {}
log_mae: False
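The new hyperparameter defaults to off; a user opts in from their options file. A minimal sketch, assuming the usual metatrain layout in which these keys sit under the architecture's training section:

    training:
      log_mae: true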
3 changes: 3 additions & 0 deletions src/metatrain/experimental/alchemical_model/schema-hypers.json
@@ -101,6 +101,9 @@
}
},
"additionalProperties": false
},
"log_mae": {
"type": "boolean"
}
},
"additionalProperties": false
23 changes: 22 additions & 1 deletion src/metatrain/experimental/alchemical_model/trainer.py
@@ -19,7 +19,7 @@
from ...utils.io import check_file_extension
from ...utils.logging import MetricLogger
from ...utils.loss import TensorMapDictLoss
from ...utils.metrics import RMSEAccumulator
from ...utils.metrics import MAEAccumulator, RMSEAccumulator
from ...utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
@@ -223,6 +223,9 @@ def train(
for epoch in range(start_epoch, start_epoch + self.hypers["num_epochs"]):
train_rmse_calculator = RMSEAccumulator()
val_rmse_calculator = RMSEAccumulator()
if self.hypers["log_mae"]:
train_mae_calculator = MAEAccumulator()
val_mae_calculator = MAEAccumulator()

train_loss = 0.0
for batch in train_dataloader:
@@ -260,9 +263,18 @@
train_loss_batch.backward()
optimizer.step()
train_rmse_calculator.update(predictions, targets)
if self.hypers["log_mae"]:
train_mae_calculator.update(predictions, targets)

finalized_train_info = train_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets
)
if self.hypers["log_mae"]:
finalized_train_info.update(
train_mae_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets
)
)

val_loss = 0.0
for batch in val_dataloader:
@@ -297,9 +309,18 @@
val_loss_batch = loss_fn(predictions, targets)
val_loss += val_loss_batch.item()
val_rmse_calculator.update(predictions, targets)
if self.hypers["log_mae"]:
val_mae_calculator.update(predictions, targets)

finalized_val_info = val_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets
)
if self.hypers["log_mae"]:
finalized_val_info.update(
val_mae_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets
)
)

lr_scheduler.step(val_loss)

1 change: 1 addition & 0 deletions src/metatrain/experimental/soap_bpnn/default-hypers.yaml
@@ -35,3 +35,4 @@ training:
fixed_composition_weights: {}
per_structure_targets: []
loss_weights: {}
log_mae: False
3 changes: 3 additions & 0 deletions src/metatrain/experimental/soap_bpnn/schema-hypers.json
@@ -149,6 +149,9 @@
}
},
"additionalProperties": false
},
"log_mae": {
"type": "boolean"
}
},
"additionalProperties": false
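Both schema files gain the same boolean property, so a mistyped value is rejected at validation time. A self-contained sketch of that behavior using the jsonschema package directly (a simplification; this diff does not show how metatrain wires up its validation):

    from jsonschema import ValidationError, validate

    schema = {
        "type": "object",
        "properties": {"log_mae": {"type": "boolean"}},
        "additionalProperties": False,
    }
    validate({"log_mae": True}, schema)  # passes
    try:
        validate({"log_mae": "yes"}, schema)  # wrong type
    except ValidationError as err:
        print(err.message)  # 'yes' is not of type 'boolean'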
32 changes: 27 additions & 5 deletions src/metatrain/experimental/soap_bpnn/trainer.py
@@ -17,7 +17,7 @@
from ...utils.io import check_file_extension
from ...utils.logging import MetricLogger
from ...utils.loss import TensorMapDictLoss
from ...utils.metrics import RMSEAccumulator
from ...utils.metrics import MAEAccumulator, RMSEAccumulator
from ...utils.neighbor_lists import (
get_requested_neighbor_lists,
get_system_with_neighbor_lists,
@@ -244,6 +244,9 @@ def train(

train_rmse_calculator = RMSEAccumulator()
val_rmse_calculator = RMSEAccumulator()
if self.hypers["log_mae"]:
train_mae_calculator = MAEAccumulator()
val_mae_calculator = MAEAccumulator()

train_loss = 0.0
for batch in train_dataloader:
@@ -285,11 +288,22 @@
torch.distributed.all_reduce(train_loss_batch)
train_loss += train_loss_batch.item()
train_rmse_calculator.update(predictions, targets)
if self.hypers["log_mae"]:
train_mae_calculator.update(predictions, targets)

finalized_train_info = train_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
)
if self.hypers["log_mae"]:
finalized_train_info.update(
train_mae_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
)
)

val_loss = 0.0
for batch in val_dataloader:
@@ -326,20 +340,28 @@
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()
val_rmse_calculator.update(predictions, targets)
if self.hypers["log_mae"]:
val_mae_calculator.update(predictions, targets)

finalized_val_info = val_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
)
if self.hypers["log_mae"]:
finalized_val_info.update(
val_mae_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
)
)

lr_scheduler.step(val_loss)

# Now we log the information:
finalized_train_info = {"loss": train_loss, **finalized_train_info}
finalized_val_info = {
"loss": val_loss,
**finalized_val_info,
}
finalized_val_info = {"loss": val_loss, **finalized_val_info}

if epoch == start_epoch:
metric_logger = MetricLogger(
10 changes: 8 additions & 2 deletions src/metatrain/utils/logging.py
@@ -274,6 +274,12 @@ def _sort_metric_names(name_list):
# loss goes first
loss_index = name_list.index("loss")
sorted_name_list.append(name_list.pop(loss_index))
# then alphabetical order
sorted_name_list.extend(sorted(name_list))
# then alphabetical order, except for the MAEs, which should come
# after the corresponding RMSEs
sorted_remaining_name_list = sorted(
name_list,
key=lambda x: x.replace("RMSE", "AAA").replace("MAE", "ZZZ"),
)
# add the rest
sorted_name_list.extend(sorted_remaining_name_list)
return sorted_name_list
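The replace-based sort key orders each MAE immediately after the matching RMSE while keeping different targets alphabetical, because "AAA" sorts before "ZZZ" at the position where the metric name appears. A quick check with illustrative metric names:

    names = [
        "energy MAE (per atom)",
        "forces RMSE",
        "energy RMSE (per atom)",
        "forces MAE",
    ]
    ordered = sorted(names, key=lambda x: x.replace("RMSE", "AAA").replace("MAE", "ZZZ"))
    # ordered == ["energy RMSE (per atom)", "energy MAE (per atom)",
    #             "forces RMSE", "forces MAE"]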
87 changes: 86 additions & 1 deletion src/metatrain/utils/metrics.py
@@ -54,7 +54,7 @@ def finalize(
is_distributed: bool = False,
device: torch.device = None,
) -> Dict[str, float]:
"""Finalizes the accumulator and return the RMSE for each key.
"""Finalizes the accumulator and returns the RMSE for each key.
All keys will be returned as "{key} RMSE (per atom)" in the output dictionary,
unless ``key`` contains one or more of the strings in ``not_per_atom``,
@@ -85,3 +85,88 @@
finalized_info[out_key] = (value[0] / value[1]) ** 0.5

return finalized_info
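The class added below mirrors RMSEAccumulator: both track a (running error sum, element count) pair per key, but where the RMSE path accumulates squared errors and finalizes with (sum / count) ** 0.5, the MAE path accumulates absolute errors and finalizes with the plain mean sum / count.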


class MAEAccumulator:
"""Accumulates the MAE between predictions and targets for an arbitrary
number of keys, each corresponding to one target."""

def __init__(self):
"""Initialize the accumulator."""
self.information: Dict[str, Tuple[float, int]] = {}

def update(self, predictions: Dict[str, TensorMap], targets: Dict[str, TensorMap]):
"""Updates the accumulator with new predictions and targets.
:param predictions: A dictionary of predictions, where the keys correspond
to the keys in the targets dictionary, and the values are the predictions.
:param targets: A dictionary of targets, where the keys correspond to the keys
in the predictions dictionary, and the values are the targets.
"""

for key, target in targets.items():
if key not in self.information:
self.information[key] = (0.0, 0)
prediction = predictions[key]

self.information[key] = (
self.information[key][0]
+ (prediction.block().values - target.block().values)
.abs()
.sum()
.item(),
self.information[key][1] + prediction.block().values.numel(),
)

for gradient_name, target_gradient in target.block().gradients():
if f"{key}_{gradient_name}_gradients" not in self.information:
self.information[f"{key}_{gradient_name}_gradients"] = (0.0, 0)
prediction_gradient = prediction.block().gradient(gradient_name)
self.information[f"{key}_{gradient_name}_gradients"] = (
self.information[f"{key}_{gradient_name}_gradients"][0]
+ (prediction_gradient.values - target_gradient.values)
.abs()
.sum()
.item(),
self.information[f"{key}_{gradient_name}_gradients"][1]
+ prediction_gradient.values.numel(),
)

def finalize(
self,
not_per_atom: List[str],
is_distributed: bool = False,
device: torch.device = None,
) -> Dict[str, float]:
"""Finalizes the accumulator and returns the MAE for each key.
All keys will be returned as "{key} MAE (per atom)" in the output dictionary,
unless ``key`` contains one or more of the strings in ``not_per_atom``,
in which case "{key} MAE" will be returned.
:param not_per_atom: a list of strings. If any of these strings are present in
a key, the MAE key will not be labeled as "(per atom)".
:param is_distributed: if true, the MAE will be computed across all ranks
of the distributed system.
:param device: the local device to use for the computation. Only needed if
``is_distributed`` is :obj:`python:True`.
"""

if is_distributed:
for key, value in self.information.items():
sae = torch.tensor(value[0]).to(device)
n_elems = torch.tensor(value[1]).to(device)
torch.distributed.all_reduce(sae)
torch.distributed.all_reduce(n_elems)
self.information[key] = (sae.item(), n_elems.item()) # type: ignore

finalized_info = {}
for key, value in self.information.items():
if any([s in key for s in not_per_atom]):
out_key = f"{key} MAE"
else:
out_key = f"{key} MAE (per atom)"
finalized_info[out_key] = value[0] / value[1]

return finalized_info
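Under is_distributed=True, the per-rank running sums are all-reduced before the division, so the distributed result matches a single-process pass over the full dataset. The arithmetic, with a hypothetical two-rank state:

    # accumulator state per rank: (sum of |prediction - target|, element count)
    rank0 = (6.0, 4)
    rank1 = (2.0, 4)
    # all_reduce sums each component across ranks
    sae = rank0[0] + rank1[0]      # 8.0
    n_elems = rank0[1] + rank1[1]  # 8
    mae = sae / n_elems            # 1.0, as if one process had seen all 8 elements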
1 change: 1 addition & 0 deletions tests/cli/test_eval_model.py
@@ -67,6 +67,7 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
# Test target predictions
log = "".join([rec.message for rec in caplog.records])
assert "energy RMSE (per atom)" in log
assert "energy MAE (per atom)" in log
assert "dataset with index" not in log
assert "evaluation time" in log
assert "ms per atom" in log
20 changes: 19 additions & 1 deletion tests/utils/test_metrics.py
@@ -2,7 +2,7 @@
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap

from metatrain.utils.metrics import RMSEAccumulator
from metatrain.utils.metrics import MAEAccumulator, RMSEAccumulator


@pytest.fixture
@@ -63,3 +63,21 @@ def test_rmse_accumulator(tensor_map_with_grad_1, tensor_map_with_grad_2):

assert "energy RMSE (per atom)" in rmses
assert "energy_gradient_gradients RMSE" in rmses


def test_mae_accumulator(tensor_map_with_grad_1, tensor_map_with_grad_2):
"""Tests the MAEAccumulator class."""

mae_accumulator = MAEAccumulator()
for _ in range(10):
mae_accumulator.update(
{"energy": tensor_map_with_grad_1}, {"energy": tensor_map_with_grad_2}
)

assert mae_accumulator.information["energy"][1] == 30
assert mae_accumulator.information["energy_gradient_gradients"][1] == 30

maes = mae_accumulator.finalize(not_per_atom=["gradient_gradients"])

assert "energy MAE (per atom)" in maes
assert "energy_gradient_gradients MAE" in maes
