Fix/silence warnings in tests
Luthaf committed Sep 4, 2024
1 parent d3e286e commit f1d0c3c
Showing 7 changed files with 23 additions and 18 deletions.
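The `torch.load` changes repeated across the files below are all the same fix: since PyTorch 2.4, calling `torch.load` without an explicit `weights_only` argument emits a `FutureWarning`, because the default is scheduled to flip to `True` (which restricts unpickling to tensors and other safe primitives). These checkpoints bundle arbitrary Python objects such as hyperparameter dictionaries alongside the state dict, so the calls pass `weights_only=False` explicitly. A minimal sketch of the pattern, assuming PyTorch >= 2.4 and a made-up checkpoint layout (not part of the commit):

    import torch

    # Made-up checkpoint in the same shape as the ones in this diff:
    # a dict bundling hypers with a state dict, not just raw tensors.
    torch.save({"model_hypers": {"cutoff": 5.0}, "model_state_dict": {}}, "model.ckpt")

    checkpoint = torch.load("model.ckpt")                      # FutureWarning on >= 2.4
    checkpoint = torch.load("model.ckpt", weights_only=False)  # explicit, no warning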
src/metatrain/experimental/alchemical_model/model.py (1 addition & 1 deletion)
@@ -129,7 +129,7 @@ def forward(
     def load_checkpoint(cls, path: Union[str, Path]) -> "AlchemicalModel":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
src/metatrain/experimental/alchemical_model/trainer.py (1 addition & 1 deletion)
@@ -349,7 +349,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
     def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
         epoch = checkpoint["epoch"]
src/metatrain/experimental/pet/model.py (2 additions & 2 deletions)
@@ -114,14 +114,14 @@ def forward(
     @classmethod
     def load_checkpoint(cls, path: Union[str, Path]) -> "PET":
 
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         hypers = checkpoint["hypers"]
         dataset_info = checkpoint["dataset_info"]
         model = cls(
             model_hypers=hypers["ARCHITECTURAL_HYPERS"], dataset_info=dataset_info
         )
 
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         state_dict = checkpoint["checkpoint"]["model_state_dict"]
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
src/metatrain/experimental/pet/trainer.py (3 additions & 3 deletions)
@@ -163,7 +163,7 @@ def train(
         else:
             load_path = self.pet_dir / "best_val_rmse_energies_model_state_dict"
 
-        state_dict = torch.load(load_path)
+        state_dict = torch.load(load_path, weights_only=False)
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
         raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
@@ -186,7 +186,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
         # together with the hypers inside a file that will act as a metatrain
         # checkpoint
         checkpoint_path = self.pet_dir / "checkpoint"  # type: ignore
-        checkpoint = torch.load(checkpoint_path)
+        checkpoint = torch.load(checkpoint_path, weights_only=False)
         torch.save(
             {
                 "checkpoint": checkpoint,
@@ -204,7 +204,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
         # This function loads a metatrain PET checkpoint and returns a Trainer
         # instance with the hypers, while also saving the checkpoint in the
         # class
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         trainer = cls(train_hypers)
         trainer.pet_checkpoint = checkpoint["checkpoint"]
         return trainer
src/metatrain/experimental/soap_bpnn/model.py (1 addition & 1 deletion)
@@ -287,7 +287,7 @@ def forward(
     def load_checkpoint(cls, path: Union[str, Path]) -> "SoapBpnn":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
src/metatrain/experimental/soap_bpnn/trainer.py (13 additions & 8 deletions)
@@ -31,13 +31,6 @@
 logger = logging.getLogger(__name__)
 
 
-# Filter out the second derivative and device warnings from rascaline-torch
-warnings.filterwarnings("ignore", category=UserWarning, message="second derivative")
-warnings.filterwarnings(
-    "ignore", category=UserWarning, message="Systems data is on device"
-)
-
-
 class Trainer:
     def __init__(self, train_hypers):
         self.hypers = train_hypers
@@ -54,6 +47,17 @@ def train(
         val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
         checkpoint_dir: str,
     ):
+        # Filter out the second derivative and device warnings from rascaline
+        warnings.filterwarnings(action="ignore", message="Systems data is on device")
+        warnings.filterwarnings(
+            action="ignore",
+            message="second derivatives with respect to positions are not implemented",
+        )
+        warnings.filterwarnings(
+            action="ignore",
+            message="second derivatives with respect to cell matrix",
+        )
+
         assert dtype in SoapBpnn.__supported_dtypes__
 
         is_distributed = self.hypers["distributed"]
@@ -290,6 +294,7 @@ def train(
             targets = average_by_num_atoms(targets, systems, per_structure_targets)
 
             train_loss_batch = loss_fn(predictions, targets)
+
            train_loss_batch.backward()
             optimizer.step()
@@ -409,7 +414,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
     def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
         epoch = checkpoint["epoch"]
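Two things change in the soap_bpnn trainer above: the filters move from module import time into `train()`, so they are installed only when training actually runs rather than whenever the module is imported, and the `category=UserWarning` restriction is replaced by longer prefixes of the actual rascaline messages. `warnings.filterwarnings` treats `message` as a regular expression matched against the start of the warning text, so a prefix is enough. A short sketch, with stand-in warning text rather than rascaline's actual output:

    import warnings

    warnings.filterwarnings(
        action="ignore",
        message="second derivatives with respect to cell matrix",
    )

    # Suppressed: the warning text starts with the pattern above.
    warnings.warn("second derivatives with respect to cell matrix are not supported")

    # Still shown: the pattern does not match at the start of the text.
    warnings.warn("skipping second derivatives with respect to cell matrix")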
tests/cli/test_train_model.py (2 additions & 2 deletions)
@@ -449,8 +449,8 @@ def test_model_consistency_with_seed(options, monkeypatch, tmp_path, seed):
 
     train_model(options, output="model2.pt")
 
-    m1 = torch.load("model1.ckpt")
-    m2 = torch.load("model2.ckpt")
+    m1 = torch.load("model1.ckpt", weights_only=False)
+    m2 = torch.load("model2.ckpt", weights_only=False)
 
     for i in m1["model_state_dict"]:
         tensor1 = m1["model_state_dict"][i]

0 comments on commit f1d0c3c