diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py
index 21b622a2..52b9af5f 100644
--- a/src/metatrain/experimental/alchemical_model/model.py
+++ b/src/metatrain/experimental/alchemical_model/model.py
@@ -129,7 +129,7 @@ def forward(
     def load_checkpoint(cls, path: Union[str, Path]) -> "AlchemicalModel":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
 
diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py
index 9ff96599..2ff0693d 100644
--- a/src/metatrain/experimental/alchemical_model/trainer.py
+++ b/src/metatrain/experimental/alchemical_model/trainer.py
@@ -349,7 +349,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
     def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
         epoch = checkpoint["epoch"]
diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py
index 0ff1ad0e..ee5202bd 100644
--- a/src/metatrain/experimental/pet/model.py
+++ b/src/metatrain/experimental/pet/model.py
@@ -114,14 +114,14 @@ def forward(
     @classmethod
     def load_checkpoint(cls, path: Union[str, Path]) -> "PET":
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         hypers = checkpoint["hypers"]
         dataset_info = checkpoint["dataset_info"]
         model = cls(
             model_hypers=hypers["ARCHITECTURAL_HYPERS"], dataset_info=dataset_info
         )
 
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         state_dict = checkpoint["checkpoint"]["model_state_dict"]
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py
index b36551aa..f9c6ef67 100644
--- a/src/metatrain/experimental/pet/trainer.py
+++ b/src/metatrain/experimental/pet/trainer.py
@@ -163,7 +163,7 @@ def train(
         else:
             load_path = self.pet_dir / "best_val_rmse_energies_model_state_dict"
 
-        state_dict = torch.load(load_path)
+        state_dict = torch.load(load_path, weights_only=False)
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
         raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
@@ -186,7 +186,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
         # together with the hypers inside a file that will act as a metatrain
         # checkpoint
         checkpoint_path = self.pet_dir / "checkpoint"  # type: ignore
-        checkpoint = torch.load(checkpoint_path)
+        checkpoint = torch.load(checkpoint_path, weights_only=False)
         torch.save(
             {
                 "checkpoint": checkpoint,
@@ -204,7 +204,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
         # This function loads a metatrain PET checkpoint and returns a Trainer
         # instance with the hypers, while also saving the checkpoint in the
         # class
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         trainer = cls(train_hypers)
         trainer.pet_checkpoint = checkpoint["checkpoint"]
         return trainer
diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py
index a2be10d1..7f144d27 100644
--- a/src/metatrain/experimental/soap_bpnn/model.py
+++ b/src/metatrain/experimental/soap_bpnn/model.py
@@ -287,7 +287,7 @@ def forward(
     def load_checkpoint(cls, path: Union[str, Path]) -> "SoapBpnn":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
 
diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py
index 2a2967fa..4a342b15 100644
--- a/src/metatrain/experimental/soap_bpnn/trainer.py
+++ b/src/metatrain/experimental/soap_bpnn/trainer.py
@@ -31,13 +31,6 @@
 logger = logging.getLogger(__name__)
 
-# Filter out the second derivative and device warnings from rascaline-torch
-warnings.filterwarnings("ignore", category=UserWarning, message="second derivative")
-warnings.filterwarnings(
-    "ignore", category=UserWarning, message="Systems data is on device"
-)
-
-
 class Trainer:
     def __init__(self, train_hypers):
         self.hypers = train_hypers
@@ -54,6 +47,17 @@ def train(
         val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
         checkpoint_dir: str,
     ):
+        # Filter out the second derivative and device warnings from rascaline
+        warnings.filterwarnings(action="ignore", message="Systems data is on device")
+        warnings.filterwarnings(
+            action="ignore",
+            message="second derivatives with respect to positions are not implemented",
+        )
+        warnings.filterwarnings(
+            action="ignore",
+            message="second derivatives with respect to cell matrix",
+        )
+
         assert dtype in SoapBpnn.__supported_dtypes__
 
         is_distributed = self.hypers["distributed"]
@@ -290,6 +294,7 @@ def train(
                 targets = average_by_num_atoms(targets, systems, per_structure_targets)
 
                 train_loss_batch = loss_fn(predictions, targets)
+
                 train_loss_batch.backward()
                 optimizer.step()
@@ -409,7 +414,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
     def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
         epoch = checkpoint["epoch"]
diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py
index 85455933..acecceea 100644
--- a/tests/cli/test_train_model.py
+++ b/tests/cli/test_train_model.py
@@ -449,8 +449,8 @@ def test_model_consistency_with_seed(options, monkeypatch, tmp_path, seed):
     train_model(options, output="model2.pt")
 
-    m1 = torch.load("model1.ckpt")
-    m2 = torch.load("model2.ckpt")
+    m1 = torch.load("model1.ckpt", weights_only=False)
+    m2 = torch.load("model2.ckpt", weights_only=False)
 
     for i in m1["model_state_dict"]:
         tensor1 = m1["model_state_dict"][i]
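
Note on the recurring torch.load change: PyTorch warned about the weights_only default of torch.load (FutureWarning since 2.4) and flipped it to True in 2.6. Under weights_only=True, a restricted unpickler accepts only tensors and primitive containers, while metatrain checkpoints also pickle arbitrary Python objects (hypers, dataset info), so every load site opts out explicitly. Below is a minimal sketch of the failure mode and the two ways around it; the DatasetInfo class here is a hypothetical local stand-in, not metatrain's own class:

    import torch


    class DatasetInfo:
        """Hypothetical stand-in for the non-tensor objects a checkpoint pickles."""

        def __init__(self, atomic_types):
            self.atomic_types = atomic_types


    torch.save({"dataset_info": DatasetInfo([1, 6, 8])}, "example.ckpt")

    # With the PyTorch >= 2.6 default (weights_only=True), the commented call
    # raises pickle.UnpicklingError, since DatasetInfo is not an allow-listed type:
    #   checkpoint = torch.load("example.ckpt")

    # Option 1: opt out, as this diff does. Only safe for trusted checkpoints,
    # because full unpickling can execute arbitrary code.
    checkpoint = torch.load("example.ckpt", weights_only=False)

    # Option 2 (PyTorch >= 2.5): allow-list the specific classes instead.
    torch.serialization.add_safe_globals([DatasetInfo])
    checkpoint = torch.load("example.ckpt", weights_only=True)
    assert checkpoint["dataset_info"].atomic_types == [1, 6, 8]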
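Note on the relocated warning filters in soap_bpnn/trainer.py: the message argument of warnings.filterwarnings is a regular expression matched against the start of the warning text, so a short pattern such as "second derivative" only suppresses warnings that literally begin with those words; the new filters spell out the warnings' actual opening phrases, and registering them inside train() rather than at import time presumably avoids mutating global warning state as a side effect of importing the module. A small sketch of the start-anchored matching, using made-up warning strings:

    import warnings

    # `message` must match the *beginning* of the warning text, so this filter
    # only silences warnings that start with the given words.
    warnings.filterwarnings(
        action="ignore",
        message="second derivatives with respect to positions",
    )

    with warnings.catch_warnings(record=True) as caught:
        warnings.warn(
            "second derivatives with respect to positions are not implemented"
        )  # suppressed: the filter pattern matches its prefix
        warnings.warn("an unrelated warning")  # no filter matches: recorded

    assert len(caught) == 1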