Skip to content

Commit

Permalink
Remove rascaline as a dependency of the main package (#139)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Philip Loche <[email protected]>
  • Loading branch information
frostedoyster and PicoCentauri authored Mar 12, 2024
1 parent e42bc29 commit 5909f46
Show file tree
Hide file tree
Showing 17 changed files with 150 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ sphinx:
python:
install:
- method: pip
path: .
path: .[soap-bpnn]
- requirements: docs/requirements.txt
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ dependencies = [
"ase",
"torch",
"hydra-core",
"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
"metatensor-operations==0.2.1",
"metatensor-torch==0.3.0",
"metatensor-learn==0.2.1",
Expand Down Expand Up @@ -55,7 +54,9 @@ requires = [
build-backend = "setuptools.build_meta"

[project.optional-dependencies]
soap-bpnn = []
soap-bpnn = [
"rascaline-torch @ git+https://github.com/luthaf/rascaline@ae05064#subdirectory=python/rascaline-torch",
]
alchemical-model = [
"torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@357a01f",
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import ase
import rascaline.torch
import torch
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelCapabilities,
ModelEvaluationOptions,
ModelMetadata,
ModelOutput,
systems_to_torch,
)

from metatensor.models.experimental.alchemical_model import DEFAULT_HYPERS, Model
Expand All @@ -31,7 +31,7 @@ def test_prediction_subset():

alchemical_model = Model(capabilities, DEFAULT_HYPERS["model"])
system = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
system = rascaline.torch.systems_to_torch(system).to(torch.get_default_dtype())
system = systems_to_torch(system, dtype=torch.get_default_dtype())
system = get_system_with_neighbors_lists(
system, alchemical_model.requested_neighbors_lists()
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import copy

import ase.io
import rascaline.torch
import torch
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelCapabilities,
ModelEvaluationOptions,
ModelMetadata,
ModelOutput,
systems_to_torch,
)

from metatensor.models.experimental.alchemical_model import DEFAULT_HYPERS, Model
Expand All @@ -35,13 +35,11 @@ def test_rotational_invariance():
system = ase.io.read(DATASET_PATH)
original_system = copy.deepcopy(system)
system.rotate(48, "y")
original_system = rascaline.torch.systems_to_torch(original_system).to(
torch.get_default_dtype()
)
original_system = systems_to_torch(original_system, dtype=torch.get_default_dtype())
original_system = get_system_with_neighbors_lists(
original_system, alchemical_model.requested_neighbors_lists()
)
system = rascaline.torch.systems_to_torch(system).to(torch.get_default_dtype())
system = systems_to_torch(system, dtype=torch.get_default_dtype())
system = get_system_with_neighbors_lists(
system, alchemical_model.requested_neighbors_lists()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import ase.io
import numpy as np
import rascaline.torch
import torch
from metatensor.learn.data import Dataset
from metatensor.torch.atomistic import (
Expand All @@ -11,6 +10,7 @@
ModelEvaluationOptions,
ModelMetadata,
ModelOutput,
systems_to_torch,
)
from omegaconf import OmegaConf

Expand Down Expand Up @@ -46,8 +46,7 @@ def test_regression_init():
# Predict on the first five systems
systems = ase.io.read(DATASET_PATH, ":5")
systems = [
rascaline.torch.systems_to_torch(system).to(torch.get_default_dtype())
for system in systems
systems_to_torch(system, dtype=torch.get_default_dtype()) for system in systems
]
systems = [
get_system_with_neighbors_lists(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_torchscript():
)
},
)
alchemical_model = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)
alchemical_model = Model(capabilities, DEFAULT_HYPERS["model"])
torch.jit.script(
alchemical_model, {"energy": alchemical_model.capabilities.outputs["energy"]}
)
Expand All @@ -39,7 +39,7 @@ def test_torchscript_save():
)
},
)
alchemical_model = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)
alchemical_model = Model(capabilities, DEFAULT_HYPERS["model"])
torch.jit.save(
torch.jit.script(
alchemical_model,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import ase
import rascaline.torch
import torch
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelCapabilities,
ModelEvaluationOptions,
ModelMetadata,
ModelOutput,
systems_to_torch,
)

from metatensor.models.experimental.pet import DEFAULT_HYPERS, Model
Expand All @@ -29,11 +28,9 @@ def test_prediction_subset():
supported_devices=["cpu"],
)

model = Model(capabilities, DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"]).to(
torch.float64
)
model = Model(capabilities, DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"])
structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
system = rascaline.torch.systems_to_torch(structure)
system = systems_to_torch(structure)
system = get_system_with_neighbors_lists(system, model.requested_neighbors_lists())

evaluation_options = ModelEvaluationOptions(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import ase
import metatensor.torch
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput, systems_to_torch

from metatensor.models.experimental.soap_bpnn import DEFAULT_HYPERS, Model

Expand All @@ -26,7 +25,7 @@ def test_prediction_subset_elements():

system = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
soap_bpnn(
[rascaline.torch.systems_to_torch(system).to(torch.get_default_dtype())],
[systems_to_torch(system, dtype=torch.get_default_dtype())],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)

Expand Down Expand Up @@ -56,11 +55,7 @@ def test_prediction_subset_atoms():
)

energy_monomer = soap_bpnn(
[
rascaline.torch.systems_to_torch(system_monomer).to(
torch.get_default_dtype()
)
],
[systems_to_torch(system_monomer).to(torch.get_default_dtype())],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)

Expand All @@ -82,20 +77,12 @@ def test_prediction_subset_atoms():
)

energy_dimer = soap_bpnn(
[
rascaline.torch.systems_to_torch(system_far_away_dimer).to(
torch.get_default_dtype()
)
],
[systems_to_torch(system_far_away_dimer).to(torch.get_default_dtype())],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)

energy_monomer_in_dimer = soap_bpnn(
[
rascaline.torch.systems_to_torch(system_far_away_dimer).to(
torch.get_default_dtype()
)
],
[systems_to_torch(system_far_away_dimer).to(torch.get_default_dtype())],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
selected_atoms=selection_labels,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import copy

import ase.io
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput, systems_to_torch

from metatensor.models.experimental.soap_bpnn import DEFAULT_HYPERS, Model

Expand All @@ -23,18 +22,18 @@ def test_rotational_invariance():
)
},
)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"])

system = ase.io.read(DATASET_PATH)
original_system = copy.deepcopy(system)
system.rotate(48, "y")

original_output = soap_bpnn(
[rascaline.torch.systems_to_torch(original_system)],
[systems_to_torch(original_system)],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)
rotated_output = soap_bpnn(
[rascaline.torch.systems_to_torch(system)],
[systems_to_torch(system)],
{"energy": soap_bpnn.capabilities.outputs["energy"]},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import ase.io
import numpy as np
import rascaline.torch
import torch
from metatensor.learn.data import Dataset
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput, systems_to_torch
from omegaconf import OmegaConf

from metatensor.models.experimental.soap_bpnn import DEFAULT_HYPERS, Model, train
Expand Down Expand Up @@ -41,7 +40,7 @@ def test_regression_init():

output = soap_bpnn(
[
rascaline.torch.systems_to_torch(system).to(torch.get_default_dtype())
systems_to_torch(system, dtype=torch.get_default_dtype())
for system in systems
],
{"U0": soap_bpnn.capabilities.outputs["U0"]},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_torchscript():
)
},
)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"])
torch.jit.script(soap_bpnn, {"energy": soap_bpnn.capabilities.outputs["energy"]})


Expand All @@ -34,7 +34,7 @@ def test_torchscript_save():
)
},
)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"]).to(torch.float64)
soap_bpnn = Model(capabilities, DEFAULT_HYPERS["model"])
torch.jit.save(
torch.jit.script(
soap_bpnn, {"energy": soap_bpnn.capabilities.outputs["energy"]}
Expand Down
44 changes: 19 additions & 25 deletions src/metatensor/models/utils/composition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List, Tuple
from typing import List, Tuple, Union

import rascaline.torch
import torch
from metatensor.learn.data.dataset import _BaseDataset
from metatensor.torch import Labels, TensorBlock, TensorMap
Expand All @@ -9,7 +8,7 @@


def calculate_composition_weights(
datasets: _BaseDataset, property: str
datasets: Union[_BaseDataset, List[_BaseDataset]], property: str
) -> Tuple[torch.Tensor, List[int]]:
"""Calculate the composition weights for a dataset.
Expand All @@ -19,34 +18,33 @@ def calculate_composition_weights(
:returns: Composition weights for the dataset, as well as the
list of species that the weights correspond to.
"""
if not isinstance(datasets, list):
datasets = [datasets]

species = get_all_species(datasets)
# note that this is sorted, and the composition weights are sorted
# as well, because the species are sorted in the composition features

# Get the target for each system in the dataset
# TODO: the dataset will be iterable once metatensor PR #500 merged.
targets = torch.stack(
[
dataset[sample_id]._asdict()[property].block().values
sample._asdict()[property].block().values
for dataset in datasets
for sample_id in range(len(dataset))
for sample in dataset
]
)
targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions

# Get the composition for each system in the dataset
composition_calculator = rascaline.torch.AtomicComposition(per_system=True)
# TODO: the dataset will be iterable once metatensor PR #500 merged.
composition_features = composition_calculator.compute(
[
dataset[sample_id]._asdict()["system"]
for dataset in datasets
for sample_id in range(len(dataset))
]
)
composition_features = composition_features.keys_to_properties("center_type")
composition_features = composition_features.block().values
structure_list = [
sample._asdict()["system"] for dataset in datasets for sample in dataset
]

targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions
dtype = structure_list[0].positions.dtype
composition_features = torch.empty((len(structure_list), len(species)), dtype=dtype)
for i, structure in enumerate(structure_list):
for j, s in enumerate(species):
composition_features[i, j] = torch.sum(structure.types == s)

regularizer = 1e-20

while regularizer:
if regularizer > 1e5:
raise RuntimeError(
Expand All @@ -69,10 +67,6 @@ def calculate_composition_weights(
except torch._C._LinAlgError:
regularizer *= 10.0

species = get_all_species(datasets)
# note that this is sorted, and the composition weights are sorted
# as well, because the species are sorted in the composition features

return solution, species


Expand Down
2 changes: 1 addition & 1 deletion src/metatensor/models/utils/neighbors_lists.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

import ase
import ase.neighborlist
import torch
from metatensor.torch import Labels, TensorBlock
from metatensor.torch.atomistic import (
Expand Down
Loading

0 comments on commit 5909f46

Please sign in to comment.