Skip to content

Commit

Permalink
Remove the rascaline.torch dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Mar 9, 2024
1 parent 0221fe7 commit 53ddd41
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 31 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ dependencies = [
"ase",
"torch",
"hydra-core",
"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
"metatensor-core",
"metatensor-operations",
"metatensor-torch",
Expand Down Expand Up @@ -56,7 +55,9 @@ requires = [
build-backend = "setuptools.build_meta"

[project.optional-dependencies]
soap-bpnn = []
soap-bpnn = [
"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
]
alchemical-model = [
"torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@fafb0bd",
]
Expand Down
41 changes: 20 additions & 21 deletions src/metatensor/models/utils/composition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List, Tuple
from typing import List, Tuple, Union

import rascaline.torch
import torch
from metatensor.learn.data.dataset import _BaseDataset
from metatensor.torch import Labels, TensorBlock, TensorMap
Expand All @@ -9,7 +8,7 @@


def calculate_composition_weights(
datasets: _BaseDataset, property: str
datasets: Union[_BaseDataset, List[_BaseDataset]], property: str
) -> Tuple[torch.Tensor, List[int]]:
"""Calculate the composition weights for a dataset.
Expand All @@ -19,6 +18,12 @@ def calculate_composition_weights(
:returns: Composition weights for the dataset, as well as the
list of species that the weights correspond to.
"""
if not isinstance(datasets, list):
datasets = [datasets]

species = get_all_species(datasets)
# note that this is sorted, and the composition weights are sorted
# as well, because the species are sorted in the composition features

# Get the target for each system in the dataset
# TODO: the dataset will be iterable once metatensor PR #500 merged.
Expand All @@ -29,24 +34,22 @@ def calculate_composition_weights(
for sample_id in range(len(dataset))
]
)
targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions

# Get the composition for each system in the dataset
composition_calculator = rascaline.torch.AtomicComposition(per_system=True)
# TODO: the dataset will be iterable once metatensor PR #500 merged.
composition_features = composition_calculator.compute(
[
dataset[sample_id]._asdict()["system"]
for dataset in datasets
for sample_id in range(len(dataset))
]
)
composition_features = composition_features.keys_to_properties("center_type")
composition_features = composition_features.block().values

targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions
structure_list = [
dataset[sample_id]._asdict()["system"]
for dataset in datasets
for sample_id in range(len(dataset))
]

dtype = structure_list[0].positions.dtype
composition_features = torch.empty((len(structure_list), len(species)), dtype=dtype)
for i, structure in enumerate(structure_list):
for j, s in enumerate(species):
composition_features[i, j] = torch.sum(structure.types == s)

regularizer = 1e-20

while regularizer:
if regularizer > 1e5:
raise RuntimeError(
Expand All @@ -69,10 +72,6 @@ def calculate_composition_weights(
except torch._C._LinAlgError:
regularizer *= 10.0

species = get_all_species(datasets)
# note that this is sorted, and the composition weights are sorted
# as well, because the species are sorted in the composition features

return solution, species


Expand Down
71 changes: 71 additions & 0 deletions tests/utils/test_composition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from pathlib import Path

import torch
from metatensor.learn import Dataset
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import System

from metatensor.models.utils.composition import calculate_composition_weights


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"


def test_calculate_composition_weights():
"""Test the calculation of composition weights."""

# Here we use three synthetic structures:
# - O atom, with an energy of 1.0
# - H2O molecule, with an energy of 5.0
# - H4O2 molecule, with an energy of 10.0
# The expected composition weights are 2.0 for H and 1.0 for O.

systems = [
System(
positions=torch.tensor([[0.0, 0.0, 0.0]]),
types=torch.tensor([8]),
cell=torch.eye(3),
),
System(
positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
types=torch.tensor([1, 1, 8]),
cell=torch.eye(3),
),
System(
positions=torch.tensor(
[
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[1.0, 0.0, 1.0],
[0.0, 1.0, 1.0],
]
),
types=torch.tensor([1, 1, 8, 1, 1, 8]),
cell=torch.eye(3),
),
]
energies = [1.0, 5.0, 10.0]
energies = [
TensorMap(
keys=Labels(names=["_"], values=torch.tensor([[0]])),
blocks=[
TensorBlock(
values=torch.tensor([[e]]),
samples=Labels(names=["system"], values=torch.tensor([[i]])),
components=[],
properties=Labels(names=["energy"], values=torch.tensor([[0]])),
)
],
)
for i, e in enumerate(energies)
]
dataset = Dataset(system=systems, energy=energies)

weights, species = calculate_composition_weights(dataset, "energy")

assert len(weights) == len(species)
assert len(weights) == 2
assert species == [1, 8]
assert torch.allclose(weights, torch.tensor([2.0, 1.0]))
10 changes: 6 additions & 4 deletions tests/utils/test_model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import metatensor.torch
import pytest
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput

from metatensor.models.experimental import soap_bpnn
Expand Down Expand Up @@ -35,18 +35,20 @@ def test_save_load_checkpoint(monkeypatch, tmp_path):
)

model = soap_bpnn.Model(capabilities)
systems = read_systems(RESOURCES_PATH / "qm9_reduced_100.xyz")
systems = read_systems(
RESOURCES_PATH / "qm9_reduced_100.xyz", dtype=torch.get_default_dtype()
)

output_before_save = model(
rascaline.torch.systems_to_torch(systems),
systems,
{"energy": model.capabilities.outputs["energy"]},
)

save_model(model, "test_model.ckpt")
loaded_model = load_checkpoint("test_model.ckpt")

output_after_load = loaded_model(
rascaline.torch.systems_to_torch(systems),
systems,
{"energy": model.capabilities.outputs["energy"]},
)

Expand Down
24 changes: 20 additions & 4 deletions tests/utils/test_output_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import metatensor.torch
import pytest
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput, System

from metatensor.models.experimental import soap_bpnn
from metatensor.models.utils.data import read_systems
Expand Down Expand Up @@ -33,7 +32,14 @@ def test_forces(is_training):
systems = read_systems(
RESOURCES_PATH / "qm9_reduced_100.xyz", dtype=torch.get_default_dtype()
)[:5]
systems = rascaline.torch.systems_to_torch(systems, positions_requires_grad=True)
systems = [
System(
positions=system.positions.requires_grad_(True),
cell=system.cell,
types=system.types,
)
for system in systems
]
output = model(systems, {"energy": model.capabilities.outputs["energy"]})
position_gradients = compute_gradient(
output["energy"].block().values,
Expand All @@ -43,7 +49,17 @@ def test_forces(is_training):
forces = [-position_gradient for position_gradient in position_gradients]

jitted_model = torch.jit.script(model)
systems = rascaline.torch.systems_to_torch(systems, positions_requires_grad=True)
systems = read_systems(
RESOURCES_PATH / "qm9_reduced_100.xyz", dtype=torch.get_default_dtype()
)[:5]
systems = [
System(
positions=system.positions.requires_grad_(True),
cell=system.cell,
types=system.types,
)
for system in systems
]
output = jitted_model(systems, {"energy": model.capabilities.outputs["energy"]})
jitted_position_gradients = compute_gradient(
output["energy"].block().values,
Expand Down

0 comments on commit 53ddd41

Please sign in to comment.