Skip to content

Commit

Permalink
Repair LAMMPS-GAP interface (#377)
Browse files Browse the repository at this point in the history
* Added systems converter with the selected_atoms extraction
  • Loading branch information
abmazitov authored Nov 4, 2024
1 parent b8f283e commit 4b1d1c3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/metatrain/experimental/gap/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ...utils.additive import ZBL, CompositionModel
from ...utils.export import export
from .utils import extract_selected_atoms


class GAP(torch.nn.Module):
Expand Down Expand Up @@ -147,6 +148,7 @@ def forward(
outputs: Dict[str, ModelOutput],
selected_atoms: Optional[TorchLabels] = None,
) -> Dict[str, TorchTensorMap]:
systems = extract_selected_atoms(systems, selected_atoms)
soap_features = self._soap_torch_calculator(
systems, selected_samples=selected_atoms
)
Expand Down Expand Up @@ -213,7 +215,6 @@ def forward(
output_key = list(outputs.keys())[0]
energies = self._subset_of_regressors_torch(soap_features)
return_dict = {output_key: energies}

if not self.training:
# at evaluation, we also add the additive contributions
for additive_model in self.additive_models:
Expand Down Expand Up @@ -243,7 +244,7 @@ def export(self) -> MetatensorAtomisticModel:
atomic_types=sorted(self.dataset_info.atomic_types),
interaction_range=interaction_range,
length_unit=self.dataset_info.length_unit,
supported_devices=["cuda", "cpu"],
supported_devices=["cpu"],
dtype="float64",
)

Expand Down
5 changes: 5 additions & 0 deletions src/metatrain/experimental/gap/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .extract_selected_atoms import extract_selected_atoms

__all__ = [
"extract_selected_atoms",
]
41 changes: 41 additions & 0 deletions src/metatrain/experimental/gap/utils/extract_selected_atoms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import List, Optional

from metatensor.torch import Labels
from metatensor.torch.atomistic import System


def extract_selected_atoms(
systems: List[System], selected_atoms: Optional[Labels] = None
):
"""
Preprocesses the systems by selecting only the atoms in selected_atoms.
This is particularly important for LAMMPS interface, which returns both
real and ghost atoms as a part of the system. This happens when the
length of the `system` is greater than the length of the `selected_atoms`.
:param systems: List of systems to preprocess.
:param selected_atoms: The atoms to select from the systems.
:return: The preprocessed systems.
"""
if selected_atoms is None:
return systems
processed_systems: List[System] = []
for i, system in enumerate(systems):
selected_atoms_index = selected_atoms.values[:, 1][
selected_atoms.values[:, 0] == i
]
if len(system) > len(selected_atoms_index):
positions = system.positions[selected_atoms_index]
types = system.types[selected_atoms_index]
cell = system.cell
pbc = system.pbc
processed_system = System(
positions=positions, types=types, cell=cell, pbc=pbc
)
for nl_option in system.known_neighbor_lists():
nl = system.get_neighbor_list(nl_option)
processed_system.add_neighbor_list(nl_option, nl)
else:
processed_system = system
processed_systems.append(processed_system)
return processed_systems

0 comments on commit 4b1d1c3

Please sign in to comment.