From 2a5c945304ff6337a5f132223ba3d02924f43a18 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Wed, 13 Mar 2024 18:07:55 +0100 Subject: [PATCH 1/4] Disable rascaline logger during eval as well --- src/metatensor/models/utils/model_io.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/metatensor/models/utils/model_io.py b/src/metatensor/models/utils/model_io.py index 94c38ce1c..de0cadaf5 100644 --- a/src/metatensor/models/utils/model_io.py +++ b/src/metatensor/models/utils/model_io.py @@ -16,7 +16,10 @@ pass try: + import rascaline import rascaline.torch # noqa: F401 + + rascaline.set_logging_callback(lambda x, y: None) except ImportError: pass From 02d8650d7ed5e79f4e21e0720e64f74a282c0586 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 14 Mar 2024 13:51:52 +0100 Subject: [PATCH 2/4] Only suppress rascaline logging for INFO level --- src/metatensor/models/experimental/soap_bpnn/train.py | 10 ++++++++-- src/metatensor/models/utils/io.py | 8 +++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/metatensor/models/experimental/soap_bpnn/train.py b/src/metatensor/models/experimental/soap_bpnn/train.py index b7f47df83..fc69242dd 100644 --- a/src/metatensor/models/experimental/soap_bpnn/train.py +++ b/src/metatensor/models/experimental/soap_bpnn/train.py @@ -8,6 +8,7 @@ from metatensor.learn.data import DataLoader from metatensor.learn.data.dataset import _BaseDataset from metatensor.torch.atomistic import ModelCapabilities +from rascaline.log import RASCAL_LOG_LEVEL_ERROR, RASCAL_LOG_LEVEL_WARN from ...utils.composition import calculate_composition_weights from ...utils.compute_loss import compute_model_loss @@ -30,8 +31,13 @@ logger = logging.getLogger(__name__) -# disable rascaline logger -rascaline.set_logging_callback(lambda x, y: None) +# disable rascaline logger for info messages +def rascaline_logging(level, message): + if level in [RASCAL_LOG_LEVEL_WARN, RASCAL_LOG_LEVEL_ERROR]: + rascaline.log.default_logging_callback(level, message) + + +rascaline.set_logging_callback(rascaline_logging) # Filter out the second derivative and device warnings from rascaline-torch warnings.filterwarnings("ignore", category=UserWarning, message="second derivative") diff --git a/src/metatensor/models/utils/io.py b/src/metatensor/models/utils/io.py index 5127ea76a..828d01942 100644 --- a/src/metatensor/models/utils/io.py +++ b/src/metatensor/models/utils/io.py @@ -25,8 +25,14 @@ try: import rascaline import rascaline.torch # noqa: F401 + from rascaline import RASCAL_LOG_LEVEL_ERROR, RASCAL_LOG_LEVEL_WARN - rascaline.set_logging_callback(lambda x, y: None) + # disable rascaline logger for info messages + def rascaline_logging(level, message): + if level in [RASCAL_LOG_LEVEL_WARN, RASCAL_LOG_LEVEL_ERROR]: + rascaline.log.default_logging_callback(level, message) + + rascaline.set_logging_callback(rascaline_logging) except ImportError: pass From b67e44af44a9f31c2823bfb1dbba903d6151f8a2 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 15 Mar 2024 11:46:33 +0100 Subject: [PATCH 3/4] Re-allow rascaline logging --- src/metatensor/models/experimental/soap_bpnn/train.py | 10 ---------- src/metatensor/models/utils/io.py | 9 --------- 2 files changed, 19 deletions(-) diff --git a/src/metatensor/models/experimental/soap_bpnn/train.py b/src/metatensor/models/experimental/soap_bpnn/train.py index fc69242dd..a3bbefffb 100644 --- a/src/metatensor/models/experimental/soap_bpnn/train.py +++ b/src/metatensor/models/experimental/soap_bpnn/train.py @@ -3,12 +3,10 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -import rascaline import torch from metatensor.learn.data import DataLoader from metatensor.learn.data.dataset import _BaseDataset from metatensor.torch.atomistic import ModelCapabilities -from rascaline.log import RASCAL_LOG_LEVEL_ERROR, RASCAL_LOG_LEVEL_WARN from ...utils.composition import calculate_composition_weights from ...utils.compute_loss import compute_model_loss @@ -31,14 +29,6 @@ logger = logging.getLogger(__name__) -# disable rascaline logger for info messages -def rascaline_logging(level, message): - if level in [RASCAL_LOG_LEVEL_WARN, RASCAL_LOG_LEVEL_ERROR]: - rascaline.log.default_logging_callback(level, message) - - -rascaline.set_logging_callback(rascaline_logging) - # Filter out the second derivative and device warnings from rascaline-torch warnings.filterwarnings("ignore", category=UserWarning, message="second derivative") warnings.filterwarnings( diff --git a/src/metatensor/models/utils/io.py b/src/metatensor/models/utils/io.py index 828d01942..a478f77bd 100644 --- a/src/metatensor/models/utils/io.py +++ b/src/metatensor/models/utils/io.py @@ -23,16 +23,7 @@ pass try: - import rascaline import rascaline.torch # noqa: F401 - from rascaline import RASCAL_LOG_LEVEL_ERROR, RASCAL_LOG_LEVEL_WARN - - # disable rascaline logger for info messages - def rascaline_logging(level, message): - if level in [RASCAL_LOG_LEVEL_WARN, RASCAL_LOG_LEVEL_ERROR]: - rascaline.log.default_logging_callback(level, message) - - rascaline.set_logging_callback(rascaline_logging) except ImportError: pass From 96ba5188775c05b63d7a37c8c199a379548f8d45 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 15 Mar 2024 14:18:07 +0100 Subject: [PATCH 4/4] Update rascaline pin --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fd797745b..6d33b0891 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] soap-bpnn = [ - "rascaline-torch @ git+https://github.com/luthaf/rascaline@ae05064#subdirectory=python/rascaline-torch", + "rascaline-torch @ git+https://github.com/luthaf/rascaline@211511f#subdirectory=python/rascaline-torch", ] alchemical-model = [ "torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@357a01f",