Merge branch 'main' into composition-model
Luthaf committed Sep 4, 2024
2 parents: 7a406d1 + e413a90 · commit: ab41dc2
Showing 9 changed files with 26 additions and 21 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -1,7 +1,7 @@
 graft src
 
 include LICENSE
-include README.md
+include README.rst
 
 prune developer
 prune docs
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -3,7 +3,7 @@ name = "metatrain"
 dynamic = ["version"]
 requires-python = ">=3.9"
 
-readme = "README.md"
+readme = "README.rst"
 license = {text = "BSD-3-Clause"}
 description = "Training and evaluating machine learning models for atomistic systems."
 authors = [{name = "metatrain developers"}]
@@ -12,7 +12,7 @@ dependencies = [
     "ase < 3.23.0",
     "metatensor-learn==0.2.3",
     "metatensor-operations==0.2.3",
-    "metatensor-torch==0.5.4",
+    "metatensor-torch==0.5.5",
     "jsonschema",
     "omegaconf",
     "python-hostlist",
2 changes: 1 addition & 1 deletion src/metatrain/experimental/alchemical_model/model.py
@@ -129,7 +129,7 @@ def forward(
     def load_checkpoint(cls, path: Union[str, Path]) -> "AlchemicalModel":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
 
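
The `weights_only=False` change above repeats in every model and trainer file in this commit. The likely motivation, stated here as context rather than as part of the diff: PyTorch 2.4 began warning that the default of `torch.load`'s `weights_only` argument will flip to `True`, which restricts unpickling to tensors and plain containers. metatrain checkpoints also store Python objects such as hyperparameter dicts and dataset info, so they opt back into full unpickling. Note that `weights_only=False` can execute arbitrary pickled code and should only be used on trusted checkpoint files. A minimal sketch of the pattern, with "model.ckpt" as a placeholder path:

    import torch

    # Load a full metatrain-style checkpoint: the file holds more than raw
    # tensors, so restricted (weights-only) unpickling would reject it.
    checkpoint = torch.load("model.ckpt", weights_only=False)
    model_hypers = checkpoint["model_hypers"]
    model_state_dict = checkpoint["model_state_dict"]
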
2 changes: 1 addition & 1 deletion src/metatrain/experimental/alchemical_model/trainer.py
@@ -349,7 +349,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
     def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
         epoch = checkpoint["epoch"]
4 changes: 2 additions & 2 deletions src/metatrain/experimental/pet/model.py
@@ -114,14 +114,14 @@ def forward(
     @classmethod
     def load_checkpoint(cls, path: Union[str, Path]) -> "PET":
 
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         hypers = checkpoint["hypers"]
         dataset_info = checkpoint["dataset_info"]
         model = cls(
             model_hypers=hypers["ARCHITECTURAL_HYPERS"], dataset_info=dataset_info
         )
 
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         state_dict = checkpoint["checkpoint"]["model_state_dict"]
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
6 changes: 3 additions & 3 deletions src/metatrain/experimental/pet/trainer.py
@@ -163,7 +163,7 @@ def train(
         else:
             load_path = self.pet_dir / "best_val_rmse_energies_model_state_dict"
 
-        state_dict = torch.load(load_path)
+        state_dict = torch.load(load_path, weights_only=False)
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
         raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
@@ -186,7 +186,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
         # together with the hypers inside a file that will act as a metatrain
        # checkpoint
        checkpoint_path = self.pet_dir / "checkpoint"  # type: ignore
-        checkpoint = torch.load(checkpoint_path)
+        checkpoint = torch.load(checkpoint_path, weights_only=False)
         torch.save(
             {
                 "checkpoint": checkpoint,
@@ -204,7 +204,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
         # This function loads a metatrain PET checkpoint and returns a Trainer
         # instance with the hypers, while also saving the checkpoint in the
         # class
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         trainer = cls(train_hypers)
         trainer.pet_checkpoint = checkpoint["checkpoint"]
         return trainer
2 changes: 1 addition & 1 deletion src/metatrain/experimental/soap_bpnn/model.py
@@ -290,7 +290,7 @@ def forward(
     def load_checkpoint(cls, path: Union[str, Path]) -> "SoapBpnn":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
 
21 changes: 13 additions & 8 deletions src/metatrain/experimental/soap_bpnn/trainer.py
@@ -25,13 +25,6 @@
 logger = logging.getLogger(__name__)
 
 
-# Filter out the second derivative and device warnings from rascaline-torch
-warnings.filterwarnings("ignore", category=UserWarning, message="second derivative")
-warnings.filterwarnings(
-    "ignore", category=UserWarning, message="Systems data is on device"
-)
-
-
 class Trainer:
     def __init__(self, train_hypers):
         self.hypers = train_hypers
@@ -48,6 +41,17 @@ def train(
         val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
         checkpoint_dir: str,
     ):
+        # Filter out the second derivative and device warnings from rascaline
+        warnings.filterwarnings(action="ignore", message="Systems data is on device")
+        warnings.filterwarnings(
+            action="ignore",
+            message="second derivatives with respect to positions are not implemented",
+        )
+        warnings.filterwarnings(
+            action="ignore",
+            message="second derivatives with respect to cell matrix",
+        )
+
         assert dtype in SoapBpnn.__supported_dtypes__
 
         is_distributed = self.hypers["distributed"]
@@ -248,6 +252,7 @@ def train(
                 targets = average_by_num_atoms(targets, systems, per_structure_targets)
 
                 train_loss_batch = loss_fn(predictions, targets)
+
                 train_loss_batch.backward()
                 optimizer.step()
 
@@ -368,7 +373,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
     def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
         epoch = checkpoint["epoch"]
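
Two things change in soap_bpnn/trainer.py besides the checkpoint loading: the rascaline warning filters move from import time into train(), and their message patterns become more specific. As a rough sketch of why this works (illustration only, not metatrain code): the `message` argument of `warnings.filterwarnings` is a regular expression matched against the start of the warning text, so a prefix suffices, and registering the filter inside train() avoids mutating global warning state for programs that merely import the module.

    import warnings

    # The filter matches any warning whose message *starts with* the given
    # pattern, so one prefix covers both second-derivative warnings.
    warnings.filterwarnings(action="ignore", message="second derivatives")

    # Suppressed: the message begins with the filtered prefix.
    warnings.warn(
        "second derivatives with respect to positions are not implemented"
    )
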
4 changes: 2 additions & 2 deletions tests/cli/test_train_model.py
@@ -449,8 +449,8 @@ def test_model_consistency_with_seed(options, monkeypatch, tmp_path, seed):
 
     train_model(options, output="model2.pt")
 
-    m1 = torch.load("model1.ckpt")
-    m2 = torch.load("model2.ckpt")
+    m1 = torch.load("model1.ckpt", weights_only=False)
+    m2 = torch.load("model2.ckpt", weights_only=False)
 
     for i in m1["model_state_dict"]:
         tensor1 = m1["model_state_dict"][i]
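
The loop at the end of this hunk compares the two checkpoints parameter by parameter; the lines past `tensor1` are cut off in the diff, so the continuation below is an assumption about how the test proceeds, shown only to illustrate the pattern:

    import torch

    m1 = torch.load("model1.ckpt", weights_only=False)
    m2 = torch.load("model2.ckpt", weights_only=False)

    # Hypothetical continuation: with the same seed, training twice should
    # produce identical parameters.
    for name in m1["model_state_dict"]:
        tensor1 = m1["model_state_dict"][name]
        tensor2 = m2["model_state_dict"][name]
        assert torch.equal(tensor1, tensor2)
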
