From 67a4186ae7ded8446d361d8a8fcc28bab886b251 Mon Sep 17 00:00:00 2001 From: Sergey Pozdnyakov Date: Thu, 29 Feb 2024 21:20:33 +0100 Subject: [PATCH] Add an interface to PET (#68) --------- Co-authored-by: frostedoyster Co-authored-by: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com> Co-authored-by: Arslan Mazitov --- .github/workflows/pet-tests.yml | 36 ++ docs/src/dev-docs/utils/data/index.rst | 1 + .../dev-docs/utils/data/systems_to_ase.rst | 11 + pyproject.toml | 3 + .../conf/architecture/experimental.pet.yaml | 59 ++++ .../models/experimental/pet/__init__.py | 2 + .../models/experimental/pet/model.py | 101 ++++++ .../models/experimental/pet/tests/__init__.py | 6 + .../pet/tests/test_functionality.py | 47 +++ .../pet/tests/test_torchscript.py | 41 +++ .../models/experimental/pet/train.py | 140 ++++++++ .../models/experimental/pet/utils/__init__.py | 5 + .../pet/utils/systems_to_batch_dict.py | 308 ++++++++++++++++++ src/metatensor/models/utils/data/__init__.py | 1 + .../models/utils/data/system_to_ase.py | 26 ++ .../models/utils/neighbors_lists.py | 13 +- tests/utils/data/test_system_to_ase.py | 27 ++ tox.ini | 10 + 18 files changed, 827 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/pet-tests.yml create mode 100644 docs/src/dev-docs/utils/data/systems_to_ase.rst create mode 100644 src/metatensor/models/cli/conf/architecture/experimental.pet.yaml create mode 100644 src/metatensor/models/experimental/pet/__init__.py create mode 100644 src/metatensor/models/experimental/pet/model.py create mode 100644 src/metatensor/models/experimental/pet/tests/__init__.py create mode 100644 src/metatensor/models/experimental/pet/tests/test_functionality.py create mode 100644 src/metatensor/models/experimental/pet/tests/test_torchscript.py create mode 100644 src/metatensor/models/experimental/pet/train.py create mode 100644 src/metatensor/models/experimental/pet/utils/__init__.py create mode 100644 
src/metatensor/models/experimental/pet/utils/systems_to_batch_dict.py create mode 100644 src/metatensor/models/utils/data/system_to_ase.py create mode 100644 tests/utils/data/test_system_to_ase.py diff --git a/.github/workflows/pet-tests.yml b/.github/workflows/pet-tests.yml new file mode 100644 index 000000000..87dfd2c13 --- /dev/null +++ b/.github/workflows/pet-tests.yml @@ -0,0 +1,36 @@ +name: PET tests + +on: + push: + branches: [main] + pull_request: + # Check all PR + +jobs: + tests: + runs-on: ${{ matrix.os }} + strategy: + matrix: + include: + - os: ubuntu-22.04 + python-version: "3.12" + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - run: pip install tox + + - name: run PET tests + run: tox -e pet-tests + env: + # Use the CPU only version of torch when building/running the code + PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu + + - name: Upload codecoverage + uses: codecov/codecov-action@v4 + with: + files: ./tests/coverage.xml diff --git a/docs/src/dev-docs/utils/data/index.rst b/docs/src/dev-docs/utils/data/index.rst index 1cfa5bc97..9f8612825 100644 --- a/docs/src/dev-docs/utils/data/index.rst +++ b/docs/src/dev-docs/utils/data/index.rst @@ -10,3 +10,4 @@ API for handling data in ``metatensor-models``. dataset readers/index writers + systems_to_ase diff --git a/docs/src/dev-docs/utils/data/systems_to_ase.rst b/docs/src/dev-docs/utils/data/systems_to_ase.rst new file mode 100644 index 000000000..ea44b4580 --- /dev/null +++ b/docs/src/dev-docs/utils/data/systems_to_ase.rst @@ -0,0 +1,11 @@ +Converting Systems to ASE +######################### + +Some machine learning models might train on ``ase.Atoms`` objects. +This module provides a function to convert a ``metatensor.torch.atomistic.System`` +object to an ``ase.Atoms`` object. + +.. 
automodule:: metatensor.models.utils.data.system_to_ase + :members: + :undoc-members: + :show-inheritance: diff --git a/pyproject.toml b/pyproject.toml index eaf93099b..cf6081df3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,9 @@ soap-bpnn = [] alchemical-model = [ "torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@fafb0bd", ] +pet = [ + "pet @ git+https://github.com/serfg/pet.git@5668bda", +] [tool.setuptools.packages.find] where = ["src"] diff --git a/src/metatensor/models/cli/conf/architecture/experimental.pet.yaml b/src/metatensor/models/cli/conf/architecture/experimental.pet.yaml new file mode 100644 index 000000000..eadfd4c01 --- /dev/null +++ b/src/metatensor/models/cli/conf/architecture/experimental.pet.yaml @@ -0,0 +1,59 @@ +ARCHITECTURAL_HYPERS: + CUTOFF_DELTA: 0.2 + AVERAGE_POOLING: False + TRANSFORMERS_CENTRAL_SPECIFIC: False + HEADS_CENTRAL_SPECIFIC: False + ADD_TOKEN_FIRST: True + ADD_TOKEN_SECOND: True + N_GNN_LAYERS: 3 + TRANSFORMER_D_MODEL: 128 + TRANSFORMER_N_HEAD: 4 + TRANSFORMER_DIM_FEEDFORWARD: 512 + HEAD_N_NEURONS: 128 + N_TRANS_LAYERS: 3 + ACTIVATION: silu + USE_LENGTH: True + USE_ONLY_LENGTH: False + R_CUT: 5.0 + R_EMBEDDING_ACTIVATION: False + COMPRESS_MODE: mlp + BLEND_NEIGHBOR_SPECIES: False + AVERAGE_BOND_ENERGIES: False + USE_BOND_ENERGIES: True + USE_ADDITIONAL_SCALAR_ATTRIBUTES: False + SCALAR_ATTRIBUTES_SIZE: null + TRANSFORMER_TYPE: PostLN # PostLN or PreLN + USE_LONG_RANGE: False + K_CUT: null # should be float; only used when USE_LONG_RANGE is True + + +FITTING_SCHEME: + INITIAL_LR: 1e-4 + EPOCH_NUM_ATOMIC: 1000000000000000000 + SCHEDULER_STEP_SIZE_ATOMIC: 500000000 + EPOCHS_WARMUP_ATOMIC: 250000000 + GLOBAL_AUG: True + SLIDING_FACTOR: 0.7 + ATOMIC_BATCH_SIZE: 850 + MAX_TIME: 234000 + ENERGY_WEIGHT: 0.1 # only used when fitting MLIP + MULTI_GPU: False + RANDOM_SEED: 0 + CUDA_DETERMINISTIC: False + MODEL_TO_START_WITH: null + SUPPORT_MISSING_VALUES: False + USE_WEIGHT_DECAY: False + 
WEIGHT_DECAY: 0.0 + DO_GRADIENT_CLIPPING: False + GRADIENT_CLIPPING_MAX_NORM: null # must be overwritten if DO_GRADIENT_CLIPPING is True + USE_SHIFT_AGNOSTIC_LOSS: False # only used when fitting general target. Primary use case: EDOS + ENERGIES_LOSS: per_structure # per_structure or per_atom + +MLIP_SETTINGS: # only used when fitting MLIP + ENERGY_KEY: energy + FORCES_KEY: forces + USE_ENERGIES: True + USE_FORCES: True + +UTILITY_FLAGS: #for internal usage; do not change/overwrite + CALCULATION_TYPE: null diff --git a/src/metatensor/models/experimental/pet/__init__.py b/src/metatensor/models/experimental/pet/__init__.py new file mode 100644 index 000000000..ff9a77daf --- /dev/null +++ b/src/metatensor/models/experimental/pet/__init__.py @@ -0,0 +1,2 @@ +from .model import Model, DEFAULT_HYPERS # noqa: F401 +from .train import train # noqa: F401 diff --git a/src/metatensor/models/experimental/pet/model.py b/src/metatensor/models/experimental/pet/model.py new file mode 100644 index 000000000..00c86da3d --- /dev/null +++ b/src/metatensor/models/experimental/pet/model.py @@ -0,0 +1,101 @@ +from typing import Dict, List, Optional + +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import ( + ModelCapabilities, + ModelOutput, + NeighborsListOptions, + System, +) +from omegaconf import OmegaConf +from pet.hypers import Hypers +from pet.pet import PET + +from ... import ARCHITECTURE_CONFIG_PATH +from .utils import systems_to_batch_dict + + +DEFAULT_HYPERS = OmegaConf.to_container( + OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "experimental.pet.yaml") +) + +DEFAULT_MODEL_HYPERS = DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"] + +# We hardcode some of the hypers to make PET work as a MLIP. 
class Model(torch.nn.Module):
    """Expose a ``PET`` network through the metatensor atomistic model interface.

    :param capabilities: capabilities of the model; ``capabilities.species`` is
        kept as ``all_species`` and its length is handed to ``PET``.
    :param hypers: architectural hyper-parameters, either a plain ``dict``
        (which gets wrapped into ``Hypers``) or an already built hypers object.
    """

    def __init__(
        self, capabilities: ModelCapabilities, hypers: Dict = DEFAULT_MODEL_HYPERS
    ) -> None:
        super().__init__()
        self.name = ARCHITECTURE_NAME

        # plain dictionaries are wrapped, anything else is stored untouched
        if isinstance(hypers, dict):
            self.hypers = Hypers(hypers)
        else:
            self.hypers = hypers

        # read the cutoff by key from a dict, by attribute from anything else
        if isinstance(self.hypers, dict):
            self.cutoff = self.hypers["R_CUT"]
        else:
            self.cutoff = self.hypers.R_CUT

        self.all_species: List[int] = capabilities.species
        self.capabilities = capabilities
        self.pet = PET(self.hypers, 0.0, len(self.all_species))

    def set_trained_model(self, trained_model: torch.nn.Module) -> None:
        """Replace the wrapped network by ``trained_model``."""
        self.pet = trained_model

    def requested_neighbors_lists(
        self,
    ) -> List[NeighborsListOptions]:
        """Return the single full neighbor list, with the model cutoff, that
        ``forward`` looks up on every system."""
        full_list_options = NeighborsListOptions(
            model_cutoff=self.cutoff,
            full_list=True,
        )
        return [full_list_options]

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels] = None,
    ) -> Dict[str, TensorMap]:
        """Run PET on ``systems`` and wrap its predictions into ``TensorMap``.

        :param systems: systems to evaluate; each one must carry the neighbor
            list described by :py:meth:`requested_neighbors_lists`.
        :param outputs: requested outputs; every requested name receives a
            ``TensorMap`` built from the same predictions.
        :param selected_atoms: must be ``None``.

        :raises NotImplementedError: if ``selected_atoms`` is given.
        :return: one single-block ``TensorMap`` per requested output name, with
            one ``"structure"`` sample per row of the predictions.
        """
        if selected_atoms is not None:
            raise NotImplementedError("PET does not support selected atoms.")

        neighbor_options = self.requested_neighbors_lists()[0]
        batch = systems_to_batch_dict(systems, neighbor_options, self.all_species)
        predictions = self.pet(batch)

        results: Dict[str, TensorMap] = {}
        for output_name in outputs:
            device = predictions.device

            keys = Labels(
                names=["_"],
                values=torch.tensor([[0]], device=device),
            )
            samples = Labels(
                names=["structure"],
                values=torch.arange(len(predictions), device=device).view(-1, 1),
            )
            properties = Labels(
                names=["_"],
                values=torch.zeros((1, 1), dtype=torch.int32, device=device),
            )
            block = TensorBlock(
                samples=samples,
                components=[],
                properties=properties,
                values=predictions,
            )

            results[output_name] = TensorMap(keys=keys, blocks=[block])

        return results
def test_prediction_subset():
    """Check that a model declared for H, C, N and O can be evaluated on a
    system made only of oxygen atoms."""

    energy_output = ModelOutput(
        quantity="energy",
        unit="eV",
    )
    capabilities = ModelCapabilities(
        length_unit="Angstrom",
        species=[1, 6, 7, 8],
        outputs={"energy": energy_output},
    )

    pet_model = Model(capabilities, DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"])
    pet_model = pet_model.to(torch.float64)

    # oxygen dimer, 1 unit apart along z
    oxygen_dimer = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    system = get_system_with_neighbors_lists(
        rascaline.torch.systems_to_torch(oxygen_dimer),
        pet_model.requested_neighbors_lists(),
    )

    options = ModelEvaluationOptions(
        length_unit=capabilities.length_unit,
        outputs=capabilities.outputs,
    )

    exported = MetatensorAtomisticModel(pet_model.eval(), pet_model.capabilities)
    exported([system], options, check_consistency=True)
def _energy_capabilities() -> ModelCapabilities:
    """Capabilities shared by the TorchScript tests: H, C, N, O and a single
    ``"energy"`` output in eV."""
    return ModelCapabilities(
        length_unit="Angstrom",
        species=[1, 6, 7, 8],
        outputs={
            "energy": ModelOutput(
                quantity="energy",
                unit="eV",
            )
        },
    )


def test_torchscript():
    """Tests that the model can be jitted."""

    pet = Model(_energy_capabilities(), DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"])
    torch.jit.script(pet)


def test_torchscript_save(tmp_path):
    """Tests that the model can be jitted and saved.

    :param tmp_path: pytest fixture providing a per-test temporary directory.
    """

    pet = Model(_energy_capabilities(), DEFAULT_HYPERS["ARCHITECTURAL_HYPERS"])
    # Save into the temporary directory so that running the tests does not
    # leave a ``pet.pt`` file behind in the current working directory.
    torch.jit.save(
        torch.jit.script(pet),
        str(tmp_path / "pet.pt"),
    )
List[Union[_BaseDataset, torch.utils.data.Subset]], + requested_capabilities: ModelCapabilities, + hypers: Dict = DEFAULT_HYPERS, + continue_from: Optional[str] = None, + output_dir: str = ".", + device_str: str = "cpu", +): + if torch.get_default_dtype() != torch.float32: + raise ValueError("PET only supports float32") + if device_str != "cuda" and device_str != "gpu": + raise ValueError("PET only supports cuda (gpu) training") + if len(requested_capabilities.outputs) != 1: + raise ValueError("PET only supports a single output") + target_name = next(iter(requested_capabilities.outputs.keys())) + if requested_capabilities.outputs[target_name].quantity != "energy": + raise ValueError("PET only supports energies as output") + if requested_capabilities.outputs[target_name].per_atom: + raise ValueError("PET does not support per-atom energies") + if len(train_datasets) != 1: + raise ValueError("PET only supports a single training dataset") + if len(validation_datasets) != 1: + raise ValueError("PET only supports a single validation dataset") + + if device_str == "gpu": + device_str = "cuda" + + if continue_from is not None: + hypers["FITTING_SCHEME"]["MODEL_TO_START_WITH"] = continue_from + + train_dataset = train_datasets[0] + validation_dataset = validation_datasets[0] + + # dummy dataloaders due to https://github.com/lab-cosmo/metatensor/issues/521 + train_dataloader = DataLoader( + train_dataset, + batch_size=1, + shuffle=False, + collate_fn=collate_fn, + ) + validation_dataloader = DataLoader( + validation_dataset, + batch_size=1, + shuffle=False, + collate_fn=collate_fn, + ) + + # only energies or energies and forces? 
+ do_forces = next(iter(train_dataset))[1].block().has_gradient("positions") + all_species = requested_capabilities.species + if not do_forces: + hypers["MLIP_SETTINGS"]["USE_FORCES"] = False + + ase_train_dataset = [] + for (system,), targets in train_dataloader: + ase_atoms = system_to_ase(system) + ase_atoms.info["energy"] = float( + targets[target_name].block().values.squeeze(-1).detach().cpu().numpy() + ) + if do_forces: + ase_atoms.arrays["forces"] = ( + -targets[target_name] + .block() + .gradient("positions") + .values.squeeze(-1) + .detach() + .cpu() + .numpy() + ) + ase_train_dataset.append(ase_atoms) + + ase_validation_dataset = [] + for (system,), targets in validation_dataloader: + ase_atoms = system_to_ase(system) + ase_atoms.info["energy"] = float( + targets[target_name].block().values.squeeze(-1).detach().cpu().numpy() + ) + if do_forces: + ase_atoms.arrays["forces"] = ( + -targets[target_name] + .block() + .gradient("positions") + .values.squeeze(-1) + .detach() + .cpu() + .numpy() + ) + ase_validation_dataset.append(ase_atoms) + + fit_pet( + ase_train_dataset, ase_validation_dataset, hypers, "pet", device_str, output_dir + ) + + if do_forces: + load_path = Path(output_dir) / "pet" / "best_val_rmse_forces_model_state_dict" + else: + load_path = Path(output_dir) / "pet" / "best_val_rmse_energies_model_state_dict" + + state_dict = torch.load(load_path) + + ARCHITECTURAL_HYPERS = Hypers(hypers["ARCHITECTURAL_HYPERS"]) + ARCHITECTURAL_HYPERS.D_OUTPUT = 1 # energy is a single scalar + ARCHITECTURAL_HYPERS.TARGET_TYPE = "structural" # energy is structural property + ARCHITECTURAL_HYPERS.TARGET_AGGREGATION = "sum" # sum of atomic energies + + raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(all_species)) + + new_state_dict = {} + for name, value in state_dict.items(): + name = name.replace("model.pet_model.", "") + new_state_dict[name] = value + + raw_pet.load_state_dict(new_state_dict) + + model = Model(requested_capabilities, ARCHITECTURAL_HYPERS) + + 
class NeighborIndexConstructor:
    """Build, from a canonical (flat) neighbor list, the per-atom neighbor
    indices that are needed for internal usage in the PET model.

    Atoms without any neighbor are supported: they get empty per-atom lists
    and fully padded rows in the output of :py:meth:`get_neighbor_index`.

    :param i_list: first atom of every pair of the neighbor list.
    :param j_list: second atom of every pair of the neighbor list.
    :param S_list: cell shift of every pair, one integer tensor per pair.
    :param species: species of every atom; its length is the number of atoms.
    """

    def __init__(
        self,
        i_list: List[int],
        j_list: List[int],
        S_list: List[torch.Tensor],
        species: List[int],
    ) -> None:
        n_atoms: int = len(species)

        # neighbors_index[i]: second atoms of all the pairs starting at atom i
        self.neighbors_index: List[List[int]] = []
        for _ in range(n_atoms):
            neighbors_index_now: List[int] = []
            self.neighbors_index.append(neighbors_index_now)

        # neighbors_shift[i]: cell shifts of these same pairs
        self.neighbors_shift: List[List[torch.Tensor]] = []
        for _ in range(n_atoms):
            neighbors_shift_now: List[torch.Tensor] = []
            self.neighbors_shift.append(neighbors_shift_now)

        for i, j, S in zip(i_list, j_list, S_list):
            self.neighbors_index[i].append(j)
            self.neighbors_shift[i].append(S)

        # relative_positions_raw[i]: positions, inside the flat neighbor list,
        # of the pairs starting at atom i
        self.relative_positions_raw: List[List[torch.Tensor]] = [
            [] for i in range(n_atoms)
        ]
        # neighbor_species[i]: species of the neighbors of atom i
        self.neighbor_species: List[List[int]] = []
        for _ in range(n_atoms):
            now: List[int] = []
            self.neighbor_species.append(now)

        # neighbors_pos[i]: for every pair (i, j, S), the position of the
        # reversed pair (j, i, -S) inside the neighbors of atom j
        self.neighbors_pos: List[List[torch.Tensor]] = [[] for i in range(n_atoms)]

        for i, j, index, S in zip(i_list, j_list, range(len(i_list)), S_list):
            self.relative_positions_raw[i].append(torch.LongTensor([index]))
            self.neighbor_species[i].append(species[j])
            for k in range(len(self.neighbors_index[j])):
                if (self.neighbors_index[j][k] == i) and torch.equal(
                    self.neighbors_shift[j][k], -S
                ):
                    self.neighbors_pos[i].append(torch.LongTensor([k]))

        # torch.cat refuses an empty list, so atoms without neighbors get an
        # explicit empty index tensor instead
        self.relative_positions: List[torch.Tensor] = []
        for chunk in self.relative_positions_raw:
            if len(chunk) > 0:
                self.relative_positions.append(torch.cat(chunk, dim=0))
            else:
                self.relative_positions.append(torch.zeros([0], dtype=torch.long))

    def get_max_num(self) -> int:
        """Return the largest number of neighbors of any atom, or ``-1`` if
        there are no atoms at all."""
        maximum: int = -1
        for chunk in self.relative_positions:
            if chunk.shape[0] > maximum:
                maximum = chunk.shape[0]
        return maximum

    def get_neighbor_index(self, max_num: int, all_species: torch.Tensor) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Pack the per-atom lists into tensors padded to ``max_num`` neighbors.

        :param max_num: number of neighbor slots per atom; it must be at least
            the number of neighbors of every atom.
        :param all_species: 1D tensor of all known species; every species is
            replaced by its position inside this tensor.

        :return: the tuple ``(neighbors_pos, neighbors_index, nums, mask,
            neighbor_species, relative_positions)``. All but ``nums`` have the
            shape ``[n_atoms, max_num]``. ``nums`` contains the number of
            neighbors of every atom and ``mask`` is ``True`` on the padding
            slots. Padding is ``0`` everywhere, except in ``neighbor_species``
            where it is ``all_species.shape[0]``.
        """
        n_atoms: int = len(self.relative_positions)

        nums_raw: List[int] = []
        mask_list: List[torch.Tensor] = []
        relative_positions: torch.Tensor = torch.zeros(
            [n_atoms, max_num], dtype=torch.long
        )
        neighbors_pos: torch.Tensor = torch.zeros([n_atoms, max_num], dtype=torch.long)
        neighbors_index: torch.Tensor = torch.zeros(
            [n_atoms, max_num], dtype=torch.long
        )

        for i in range(n_atoms):
            now: torch.Tensor = self.relative_positions[i]

            # nothing to copy for atoms without neighbors, the row stays padded
            if len(now) > 0:
                relative_positions[i, : len(now)] = now
                neighbors_pos[i, : len(now)] = torch.cat(self.neighbors_pos[i], dim=0)
                neighbors_index[i, : len(now)] = torch.LongTensor(
                    self.neighbors_index[i]
                )

            nums_raw.append(len(now))
            current_mask: torch.Tensor = torch.zeros([max_num], dtype=torch.bool)
            current_mask[len(now) :] = True
            mask_list.append(current_mask[None, :])

        mask: torch.Tensor = torch.cat(mask_list, dim=0).to(dtype=torch.bool)

        nums: torch.Tensor = torch.LongTensor(nums_raw)

        # padding slots keep the value ``all_species.shape[0]``, which is not
        # the position of any real species
        neighbor_species: torch.Tensor = all_species.shape[0] * torch.ones(
            [len(self.neighbor_species), max_num], dtype=torch.long
        )
        for i in range(len(self.neighbor_species)):
            species_now: List[int] = self.neighbor_species[i]
            values_now: List[int] = [
                int(torch.where(all_species == specie)[0][0].item())
                for specie in species_now
            ]
            values_now_torch: torch.Tensor = torch.LongTensor(values_now)
            neighbor_species[i, : len(values_now_torch)] = values_now_torch

        return (
            neighbors_pos,
            neighbors_index,
            nums,
            mask,
            neighbor_species,
            relative_positions,
        )
def systems_to_batch_dict(
    systems: List[System], options: NeighborsListOptions, all_species_list: List[int]
) -> Dict[str, torch.Tensor]:
    """
    Converts the standard input data format of `metatensor-models` into the
    batch dictionary passed to the `PET` model.

    :param systems: The list of systems in `metatensor.torch.atomistic.System`
        format, that needs to be converted.
    :param options: A `NeighborsListOptions` object specifying the parameters
        of the neighbor list which will be used during the conversion. Every
        system must already contain this neighbor list.
    :param all_species_list: The list of all the species known to the model.
        Every species of the systems is replaced by its position in this list.

    :raises ValueError: if one of the systems does not contain the neighbor
        list described by ``options``.

    :return: The dictionary built by ``collate_graph_dicts`` from one graph per
        system, with the keys ``"central_species"``, ``"x"``,
        ``"neighbor_species"``, ``"neighbors_pos"``, ``"neighbors_index"``,
        ``"nums"``, ``"mask"`` and ``"batch"``.
    """

    all_species: torch.Tensor = torch.LongTensor(all_species_list)
    neighbor_index_constructors: List[NeighborIndexConstructor] = []
    # displacement vectors of every system, kept from the first pass so that
    # the neighbor list is only looked up once per system
    displacement_vectors_list: List[torch.Tensor] = []

    for system in systems:
        known_neighbors_lists = system.known_neighbors_lists()
        if not torch.any(
            torch.tensor([known == options for known in known_neighbors_lists])
        ):
            raise ValueError(
                f"System does not have the neighbor list with the options {options}"
            )

        neighbors = system.get_neighbors_list(options)
        displacement_vectors_list.append(neighbors.values[:, :, 0])

        i_list: torch.Tensor = neighbors.samples.column("first_atom")
        j_list: torch.Tensor = neighbors.samples.column("second_atom")

        S_list_raw: List[torch.Tensor] = [
            neighbors.samples.column("cell_shift_a")[None],
            neighbors.samples.column("cell_shift_b")[None],
            neighbors.samples.column("cell_shift_c")[None],
        ]

        # stack the three shift columns into one row of 3 values per pair
        S_list: torch.Tensor = torch.cat(S_list_raw)
        S_list = S_list.transpose(0, 1)

        species: torch.Tensor = system.species

        # the index construction below is plain Python and runs on the CPU
        i_list = i_list.cpu()
        j_list = j_list.cpu()
        S_list = S_list.cpu()
        species = species.cpu()

        i_list_proper: List[int] = [int(el.item()) for el in i_list]
        j_list_proper: List[int] = [int(el.item()) for el in j_list]
        S_list_proper: List[torch.Tensor] = [el.to(dtype=torch.long) for el in S_list]
        species_proper: List[int] = [int(el.item()) for el in species]

        neighbor_index_constructor: NeighborIndexConstructor = NeighborIndexConstructor(
            i_list_proper, j_list_proper, S_list_proper, species_proper
        )
        neighbor_index_constructors.append(neighbor_index_constructor)

    # every system is padded to the largest number of neighbors in the batch
    max_nums: List[int] = [
        neighbor_index_constructor.get_max_num()
        for neighbor_index_constructor in neighbor_index_constructors
    ]
    max_num: int = max(max_nums)

    graphs: List[Dict[str, torch.Tensor]] = []
    device = "cpu"  # initial value to make torch script happy; to be overwritten

    for neighbor_index_constructor, system, displacement_vectors in zip(
        neighbor_index_constructors, systems, displacement_vectors_list
    ):
        (
            neighbors_pos,
            neighbors_index,
            nums,
            mask,
            neighbor_species,
            relative_positions_index,
        ) = neighbor_index_constructor.get_neighbor_index(max_num, all_species)

        # all the index tensors follow the device of the displacement vectors
        device = str(displacement_vectors.device)
        neighbors_pos = neighbors_pos.to(device)
        neighbors_index = neighbors_index.to(device)
        nums = nums.to(device)
        mask = mask.to(device)
        neighbor_species = neighbor_species.to(device)
        relative_positions_index = relative_positions_index.to(device)

        relative_positions = displacement_vectors[relative_positions_index]

        # `all_species` lives on the CPU, so the lookup is done on a CPU copy
        # of the species and only the result is moved to `device`
        central_species_indices: List[int] = [
            int(torch.where(all_species == specie)[0][0].item())
            for specie in system.species.cpu()
        ]
        central_species = torch.LongTensor(central_species_indices).to(device)

        graph_now = {
            "central_species": central_species,
            "x": relative_positions,
            "neighbor_species": neighbor_species,
            "neighbors_pos": neighbors_pos,
            "neighbors_index": neighbors_index,
            "nums": nums,
            "mask": mask,
        }
        graphs.append(graph_now)

    return collate_graph_dicts(graphs, device)
""" # Convert the system to an ASE atoms object - positions = system.positions.detach().cpu().numpy() - numbers = system.species.detach().cpu().numpy() - cell = system.cell.detach().cpu().numpy() - pbc = list(cell.any(axis=1)) - atoms = ase.Atoms( - numbers=numbers, - positions=positions, - cell=cell, - pbc=pbc, - ) + atoms = system_to_ase(system) # Compute the neighbor lists for options in neighbor_lists: diff --git a/tests/utils/data/test_system_to_ase.py b/tests/utils/data/test_system_to_ase.py new file mode 100644 index 000000000..ab53c8ee4 --- /dev/null +++ b/tests/utils/data/test_system_to_ase.py @@ -0,0 +1,27 @@ +import torch +from metatensor.torch.atomistic import System + +from metatensor.models.utils.data import system_to_ase + + +def test_system_to_ase(): + """Tests the conversion of a System to an ASE atoms object.""" + # Create a system + system = System( + positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), + species=torch.tensor([1, 8]), + cell=torch.tensor([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]), + ) + + # Convert the system to an ASE atoms object + atoms = system_to_ase(system) + + # Check the positions + assert atoms.positions.tolist() == system.positions.tolist() + + # Check the species + assert atoms.numbers.tolist() == system.species.tolist() + + # Check the cell + assert atoms.cell.tolist() == system.cell.tolist() + assert atoms.pbc.tolist() == [True, True, True] diff --git a/tox.ini b/tox.ini index 166546773..8de85355f 100644 --- a/tox.ini +++ b/tox.ini @@ -103,6 +103,16 @@ changedir = src/metatensor/models/experimental/alchemical_model/tests/ commands = pytest {[testenv]warning_options} {posargs} +[testenv:pet-tests] +description = Run PET tests with pytest +passenv = * +deps = + pytest + pet @ git+https://github.com/serfg/pet.git@5668bda +changedir = src/metatensor/models/experimental/pet/tests/ +commands = + pytest {[testenv]warning_options} {posargs} + [testenv:docs] description = builds the documentation with 
sphinx deps =