diff --git a/pyproject.toml b/pyproject.toml index fd797745b..6d33b0891 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] soap-bpnn = [ - "rascaline-torch @ git+https://github.com/luthaf/rascaline@ae05064#subdirectory=python/rascaline-torch", + "rascaline-torch @ git+https://github.com/luthaf/rascaline@211511f#subdirectory=python/rascaline-torch", ] alchemical-model = [ "torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@357a01f", diff --git a/src/metatensor/models/experimental/soap_bpnn/train.py b/src/metatensor/models/experimental/soap_bpnn/train.py index b7f47df83..a3bbefffb 100644 --- a/src/metatensor/models/experimental/soap_bpnn/train.py +++ b/src/metatensor/models/experimental/soap_bpnn/train.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -import rascaline import torch from metatensor.learn.data import DataLoader from metatensor.learn.data.dataset import _BaseDataset @@ -30,9 +29,6 @@ logger = logging.getLogger(__name__) -# disable rascaline logger -rascaline.set_logging_callback(lambda x, y: None) - # Filter out the second derivative and device warnings from rascaline-torch warnings.filterwarnings("ignore", category=UserWarning, message="second derivative") warnings.filterwarnings( diff --git a/src/metatensor/models/utils/io.py b/src/metatensor/models/utils/io.py index 830709230..a478f77bd 100644 --- a/src/metatensor/models/utils/io.py +++ b/src/metatensor/models/utils/io.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Union +import metatensor.torch import torch from metatensor.torch.atomistic import ( MetatensorAtomisticModel, @@ -11,8 +12,6 @@ ModelMetadata, ) -import metatensor - # This import is necessary to avoid errors when loading an # exported alchemical model, which depends on sphericart-torch.