Skip to content

Commit

Permalink
move hyper consts to unified places
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Dec 10, 2023
1 parent 07d1f23 commit dc8bb6f
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 38 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
"ase",
"torch",
"hydra-core",
"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
#"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
"metatensor-core",
"metatensor-operations",
"metatensor-torch",
Expand Down
14 changes: 2 additions & 12 deletions src/metatensor/models/soap_bpnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,2 @@
from .model import Model # noqa: F401
from .train import train # noqa: F401

from metatensor.models import ARCHITECTURE_CONFIG_PATH
from omegaconf import OmegaConf

DEAFAULT_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)

DEFAULT_MODEL_HYPERS = DEAFAULT_HYPERS["model"]
DEFAULT_TRAIN_HYPERS = DEAFAULT_HYPERS["train"]
from .model import Model, DEFAULT_MODEL_HYPERS # noqa: F401
from .train import train, DEFAULT_TRAINING_HYPERS # noqa: F401
9 changes: 8 additions & 1 deletion src/metatensor/models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
import rascaline.torch
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from omegaconf import OmegaConf

from . import DEFAULT_MODEL_HYPERS
from .. import ARCHITECTURE_CONFIG_PATH
from ..utils.composition import apply_composition_contribution


DEAFAULT_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)

DEFAULT_MODEL_HYPERS = DEAFAULT_HYPERS["model"]

ARCHITECTURE_NAME = "soap_bpnn"


Expand Down
6 changes: 2 additions & 4 deletions src/metatensor/models/soap_bpnn/tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model

from . import DEAFAULT_HYPERS
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model


def test_prediction_subset():
"""Tests that the model can predict on a subset
of the elements it was trained on."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
soap_bpnn([rascaline.torch.systems_to_torch(structure)])
6 changes: 3 additions & 3 deletions src/metatensor/models/soap_bpnn/tests/test_invariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model

from . import DATASET_PATH, DEAFAULT_HYPERS
from . import DATASET_PATH


def test_rotational_invariance():
"""Tests that the model is rotationally invariant."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

structure = ase.io.read(DATASET_PATH)
original_structure = copy.deepcopy(structure)
Expand Down
19 changes: 12 additions & 7 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
import rascaline.torch
import torch

from metatensor.models.soap_bpnn import Model, train
from metatensor.models.soap_bpnn import (
DEFAULT_MODEL_HYPERS,
DEFAULT_TRAINING_HYPERS,
Model,
train,
)
from metatensor.models.utils.data import Dataset
from metatensor.models.utils.data.readers import read_structures, read_targets

from . import DATASET_PATH, DEAFAULT_HYPERS
from . import DATASET_PATH


torch.manual_seed(0)
Expand All @@ -16,7 +21,7 @@ def test_regression_init():
"""Perform a regression test on the model at initialization"""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

# Predict on the first fivestructures
structures = ase.io.read(DATASET_PATH, ":5")
Expand All @@ -42,11 +47,11 @@ def test_regression_train():
targets = read_targets(DATASET_PATH, "U0")

dataset = Dataset(structures, targets)
soap_bpnn = Model(dataset.all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(dataset.all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)

hypers = DEAFAULT_HYPERS.copy()
hypers["training"]["num_epochs"] = 2
train(soap_bpnn, dataset, hypers["training"])
hypers = DEFAULT_TRAINING_HYPERS.copy()
hypers["num_epochs"] = 2
train(soap_bpnn, dataset, hypers)

# Predict on the first five structures
output = soap_bpnn(structures[:5])
Expand Down
6 changes: 2 additions & 4 deletions src/metatensor/models/soap_bpnn/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import torch

from metatensor.models.soap_bpnn import Model

from . import DEAFAULT_HYPERS
from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model


def test_torchscript():
"""Tests that the model can be jitted."""

all_species = [1, 6, 7, 8]
soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64)
torch.jit.script(soap_bpnn)
8 changes: 2 additions & 6 deletions src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import logging

import torch
from omegaconf import OmegaConf

from metatensor.models import ARCHITECTURE_CONFIG_PATH

from ..utils.composition import calculate_composition_weights
from ..utils.data import collate_fn
from ..utils.model_io import save_model
from .model import DEAFAULT_HYPERS


DEFAULT_TRAINING_HYPERS = OmegaConf.to_container(
OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml")
)["training"]
DEFAULT_TRAINING_HYPERS = DEAFAULT_HYPERS["training"]

logger = logging.getLogger(__name__)

Expand Down

0 comments on commit dc8bb6f

Please sign in to comment.