From 53ddd41b6b0a44dff22c94fda1f326fa44b8726a Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sat, 9 Mar 2024 05:47:26 +0100 Subject: [PATCH] Remove the rascaline.torch dependency --- pyproject.toml | 5 +- src/metatensor/models/utils/composition.py | 41 ++++++------- tests/utils/test_composition.py | 71 ++++++++++++++++++++++ tests/utils/test_model_io.py | 10 +-- tests/utils/test_output_gradient.py | 24 ++++++-- 5 files changed, 120 insertions(+), 31 deletions(-) create mode 100644 tests/utils/test_composition.py diff --git a/pyproject.toml b/pyproject.toml index cf6081df3..4be3709d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,6 @@ dependencies = [ "ase", "torch", "hydra-core", - "rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch", "metatensor-core", "metatensor-operations", "metatensor-torch", @@ -56,7 +55,9 @@ requires = [ build-backend = "setuptools.build_meta" [project.optional-dependencies] -soap-bpnn = [] +soap-bpnn = [ + "rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch", +] alchemical-model = [ "torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@fafb0bd", ] diff --git a/src/metatensor/models/utils/composition.py b/src/metatensor/models/utils/composition.py index 50682a3c3..fe3d87289 100644 --- a/src/metatensor/models/utils/composition.py +++ b/src/metatensor/models/utils/composition.py @@ -1,6 +1,5 @@ -from typing import List, Tuple +from typing import List, Tuple, Union -import rascaline.torch import torch from metatensor.learn.data.dataset import _BaseDataset from metatensor.torch import Labels, TensorBlock, TensorMap @@ -9,7 +8,7 @@ def calculate_composition_weights( - datasets: _BaseDataset, property: str + datasets: Union[_BaseDataset, List[_BaseDataset]], property: str ) -> Tuple[torch.Tensor, List[int]]: """Calculate the composition weights for a dataset. @@ -19,6 +18,12 @@ def calculate_composition_weights( :returns: Composition weights for the dataset, as well as the list of species that the weights correspond to. """ + if not isinstance(datasets, list): + datasets = [datasets] + + species = get_all_species(datasets) + # note that this is sorted, and the composition weights are sorted + # as well, because the species are sorted in the composition features # Get the target for each system in the dataset # TODO: the dataset will be iterable once metatensor PR #500 merged. @@ -29,24 +34,22 @@ def calculate_composition_weights( for sample_id in range(len(dataset)) ] ) + targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions - # Get the composition for each system in the dataset - composition_calculator = rascaline.torch.AtomicComposition(per_system=True) # TODO: the dataset will be iterable once metatensor PR #500 merged. - composition_features = composition_calculator.compute( - [ - dataset[sample_id]._asdict()["system"] - for dataset in datasets - for sample_id in range(len(dataset)) - ] - ) - composition_features = composition_features.keys_to_properties("center_type") - composition_features = composition_features.block().values - - targets = targets.squeeze(dim=(1, 2)) # remove component and property dimensions + structure_list = [ + dataset[sample_id]._asdict()["system"] + for dataset in datasets + for sample_id in range(len(dataset)) + ] + + dtype = structure_list[0].positions.dtype + composition_features = torch.empty((len(structure_list), len(species)), dtype=dtype) + for i, structure in enumerate(structure_list): + for j, s in enumerate(species): + composition_features[i, j] = torch.sum(structure.types == s) regularizer = 1e-20 - while regularizer: if regularizer > 1e5: raise RuntimeError( @@ -69,10 +72,6 @@ def calculate_composition_weights( except torch._C._LinAlgError: regularizer *= 10.0 - species = get_all_species(datasets) - # note that this is sorted, and the composition weights are sorted - # as well, because the species are sorted in the composition features - return solution, species diff --git a/tests/utils/test_composition.py b/tests/utils/test_composition.py new file mode 100644 index 000000000..9513ead81 --- /dev/null +++ b/tests/utils/test_composition.py @@ -0,0 +1,71 @@ +from pathlib import Path + +import torch +from metatensor.learn import Dataset +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import System + +from metatensor.models.utils.composition import calculate_composition_weights + + +RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources" + + +def test_calculate_composition_weights(): + """Test the calculation of composition weights.""" + + # Here we use three synthetic structures: + # - O atom, with an energy of 1.0 + # - H2O molecule, with an energy of 5.0 + # - H4O2 molecule, with an energy of 10.0 + # The expected composition weights are 2.0 for H and 1.0 for O. + + systems = [ + System( + positions=torch.tensor([[0.0, 0.0, 0.0]]), + types=torch.tensor([8]), + cell=torch.eye(3), + ), + System( + positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), + types=torch.tensor([1, 1, 8]), + cell=torch.eye(3), + ), + System( + positions=torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + ] + ), + types=torch.tensor([1, 1, 8, 1, 1, 8]), + cell=torch.eye(3), + ), + ] + energies = [1.0, 5.0, 10.0] + energies = [ + TensorMap( + keys=Labels(names=["_"], values=torch.tensor([[0]])), + blocks=[ + TensorBlock( + values=torch.tensor([[e]]), + samples=Labels(names=["system"], values=torch.tensor([[i]])), + components=[], + properties=Labels(names=["energy"], values=torch.tensor([[0]])), + ) + ], + ) + for i, e in enumerate(energies) + ] + dataset = Dataset(system=systems, energy=energies) + + weights, species = calculate_composition_weights(dataset, "energy") + + assert len(weights) == len(species) + assert len(weights) == 2 + assert species == [1, 8] + assert torch.allclose(weights, torch.tensor([2.0, 1.0])) diff --git a/tests/utils/test_model_io.py b/tests/utils/test_model_io.py index 33bfbc853..443c1a59f 100644 --- a/tests/utils/test_model_io.py +++ b/tests/utils/test_model_io.py @@ -3,7 +3,7 @@ import metatensor.torch import pytest -import rascaline.torch +import torch from metatensor.torch.atomistic import ModelCapabilities, ModelOutput from metatensor.models.experimental import soap_bpnn @@ -35,10 +35,12 @@ def test_save_load_checkpoint(monkeypatch, tmp_path): ) model = soap_bpnn.Model(capabilities) - systems = read_systems(RESOURCES_PATH / "qm9_reduced_100.xyz") + systems = read_systems( + RESOURCES_PATH / "qm9_reduced_100.xyz", dtype=torch.get_default_dtype() + ) output_before_save = model( - rascaline.torch.systems_to_torch(systems), + systems, {"energy": model.capabilities.outputs["energy"]}, ) @@ -46,7 +48,7 @@ def test_save_load_checkpoint(monkeypatch, tmp_path): loaded_model = load_checkpoint("test_model.ckpt") output_after_load = loaded_model( - rascaline.torch.systems_to_torch(systems), + systems, {"energy": model.capabilities.outputs["energy"]}, ) diff --git a/tests/utils/test_output_gradient.py b/tests/utils/test_output_gradient.py index 98866a5a6..0159499a5 100644 --- a/tests/utils/test_output_gradient.py +++ b/tests/utils/test_output_gradient.py @@ -2,9 +2,8 @@ import metatensor.torch import pytest -import rascaline.torch import torch -from metatensor.torch.atomistic import ModelCapabilities, ModelOutput +from metatensor.torch.atomistic import ModelCapabilities, ModelOutput, System from metatensor.models.experimental import soap_bpnn from metatensor.models.utils.data import read_systems @@ -33,7 +32,14 @@ def test_forces(is_training): systems = read_systems( RESOURCES_PATH / "qm9_reduced_100.xyz", dtype=torch.get_default_dtype() )[:5] - systems = rascaline.torch.systems_to_torch(systems, positions_requires_grad=True) + systems = [ + System( + positions=system.positions.requires_grad_(True), + cell=system.cell, + types=system.types, + ) + for system in systems + ] output = model(systems, {"energy": model.capabilities.outputs["energy"]}) position_gradients = compute_gradient( output["energy"].block().values, @@ -43,7 +49,17 @@ def test_forces(is_training): forces = [-position_gradient for position_gradient in position_gradients] jitted_model = torch.jit.script(model) - systems = rascaline.torch.systems_to_torch(systems, positions_requires_grad=True) + systems = read_systems( + RESOURCES_PATH / "qm9_reduced_100.xyz", dtype=torch.get_default_dtype() + )[:5] + systems = [ + System( + positions=system.positions.requires_grad_(True), + cell=system.cell, + types=system.types, + ) + for system in systems + ] output = jitted_model(systems, {"energy": model.capabilities.outputs["energy"]}) jitted_position_gradients = compute_gradient( output["energy"].block().values,