Skip to content

Commit

Permalink
Interface with mops (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Mar 28, 2024
1 parent d587a0f commit d235144
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 34 deletions.
2 changes: 1 addition & 1 deletion examples/alchemical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, hypers, all_species, do_forces) -> None:
super().__init__()
self.all_species = all_species
self.spherical_expansion_calculator = SphericalExpansion(hypers, all_species)
n_max = self.spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.n_max_l
n_max = self.spherical_expansion_calculator.radial_basis_calculator.n_max_l
print("Radial basis:", n_max)
l_max = len(n_max) - 1
n_feat = sum([n_max[l]**2 * n_pseudo**2 for l in range(l_max+1)])
Expand Down
2 changes: 1 addition & 1 deletion tests/data/computing_ref_coeffs-artificial.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
spherical_expansion_calculator = SphericalExpansion(hypers, all_species)
# some random combination matrix, it is only important that we use the same one in the tests
with torch.no_grad():
spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
spherical_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
torch.tensor(
[[-0.00432252, 0.30971584, -0.47518533],
[-0.4248946 , -0.22236897, 0.15482073]],
Expand Down
4 changes: 2 additions & 2 deletions tests/test_spherical_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_spherical_expansion_coeffs_alchemical(self):
with torch.no_grad():
# wtf? suggested way by torch developers
# https://discuss.pytorch.org/t/initialize-nn-linear-with-specific-weights/29005/4
spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(torch.tensor(
spherical_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(torch.tensor(
[[-0.00432252, 0.30971584, -0.47518533],
[-0.4248946 , -0.22236897, 0.15482073]],
device=self.device, dtype=self.dtype))
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_spherical_expansion_coeffs_artificial(self):
tm_ref = tm_ref.to(device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
spherical_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
torch.tensor(
[[-0.00432252, 0.30971584, -0.47518533],
[-0.4248946 , -0.22236897, 0.15482073]],
Expand Down
1 change: 1 addition & 0 deletions torch_spex/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .opsa import outer_product_scatter_add
17 changes: 17 additions & 0 deletions torch_spex/operations/opsa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch

try:
import mops.torch
HAS_MOPS = True
except ImportError:
HAS_MOPS = False


if HAS_MOPS:
outer_product_scatter_add = mops.torch.outer_product_scatter_add
else:
def outer_product_scatter_add(A, B, idx, n_out: int):
outer = A.unsqueeze(2) * B.unsqueeze(1)
out = torch.zeros(n_out, A.shape[1], B.shape[1], device=A.device, dtype=A.dtype)
out.index_add_(0, idx, outer)
return out
67 changes: 37 additions & 30 deletions torch_spex/spherical_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import torch
from metatensor.torch import TensorMap, Labels, TensorBlock
import sphericart.torch
from .operations import outer_product_scatter_add

from .radial_basis import RadialBasis
from typing import Dict, List, Optional
from typing import Dict, List


class SphericalExpansion(torch.nn.Module):
Expand Down Expand Up @@ -90,14 +91,23 @@ def __init__(self, hypers: Dict, all_species: List[int]) -> None:
self.normalization_factor = 1.0 # dummy for torchscript
self.normalization_factor_0 = 1.0 # dummy for torchscript
self.all_species = all_species
self.vector_expansion_calculator = VectorExpansion(hypers, self.all_species)

# radial basis needs to know cutoff so we pass it, as well as whether to normalize or not
hypers_radial_basis = copy.deepcopy(hypers["radial basis"])
hypers_radial_basis["r_cut"] = hypers["cutoff radius"]
hypers_radial_basis["normalize"] = self.normalize
if "alchemical" in self.hypers:
self.is_alchemical = True
self.n_pseudo_species = self.hypers["alchemical"]
hypers_radial_basis["alchemical"] = self.hypers["alchemical"]
else:
self.is_alchemical = False
self.n_pseudo_species = 0 # dummy for torchscript
self.is_alchemical = False
self.radial_basis_calculator = RadialBasis(hypers_radial_basis, all_species)
self.l_max = self.radial_basis_calculator.l_max
self.n_max_l = self.radial_basis_calculator.n_max_l
self.spherical_harmonics_calculator = sphericart.torch.SphericalHarmonics(self.l_max, normalized=True)
self.spherical_harmonics_split_list = [(2*l+1) for l in range(self.l_max+1)]

def forward(self,
positions: torch.Tensor,
Expand All @@ -124,7 +134,7 @@ def forward(self,
the original cell expressed with the cell basis.
:param centers: [n_atoms] tensor of integers with the atom indices
for all centers over all structures
:param centers: [n_pairs, 2] tensor of integers with the atom indices
:param pairs: [n_pairs, 2] tensor of integers with the atom indices
for all center and neighbor pairs over all structures
:param structure_centers: [n_atoms] tensor of integers with the indices of the
corresponding structure for each central atom
Expand All @@ -138,59 +148,53 @@ def forward(self,
:math:`c^{a_il}_{Ai, m, a_jn}`
"""

expanded_vectors = self.vector_expansion_calculator(
positions, cells, species, cell_shifts, centers, pairs, structure_centers, structure_pairs, structure_offsets)
cartesian_vectors = get_cartesian_vectors(positions, cells, species, cell_shifts, centers, pairs, structure_centers, structure_pairs, structure_offsets)

bare_cartesian_vectors = cartesian_vectors.values.squeeze(dim=-1)
r = torch.sqrt(
(bare_cartesian_vectors**2)
.sum(dim=-1)
)
samples_metadata = cartesian_vectors.samples # This can be needed by the radial basis to do alchemical contractions
radial_basis = self.radial_basis_calculator(r, samples_metadata)

samples_metadata = expanded_vectors.block({"o3_lambda": 0}).samples
spherical_harmonics = self.spherical_harmonics_calculator.compute(bare_cartesian_vectors) # Get the spherical harmonics
if self.normalize: spherical_harmonics *= (4*torch.pi)**(0.5) # normalize them
spherical_harmonics = torch.split(spherical_harmonics, self.spherical_harmonics_split_list, dim=1) # Split them into l chunks

n_species = len(self.all_species)
species_to_index = {atomic_number : i_species for i_species, atomic_number in enumerate(self.all_species)}

unique_s_i_indices = torch.stack((structure_centers, centers), dim=1)
s_i_metadata_to_unique = structure_offsets[structure_pairs] + pairs[:, 0]

l_max = self.vector_expansion_calculator.l_max
n_centers = len(centers) # total number of atoms in this batch of structures

densities = []
if self.is_alchemical:
density_indices = s_i_metadata_to_unique
for l in range(l_max+1):
expanded_vectors_l = expanded_vectors.block({"o3_lambda": l}).values
densities_l = torch.zeros(
(n_centers, expanded_vectors_l.shape[1], expanded_vectors_l.shape[2]),
dtype = expanded_vectors_l.dtype,
device = expanded_vectors_l.device
)
densities_l.index_add_(dim=0, index=density_indices, source=expanded_vectors_l)
densities_l = densities_l.reshape((n_centers, 2*l+1, -1))
for l in range(self.l_max+1):
# in the case of an alchemical model, the radial basis has an extra dimension (alpha_j)
densities_l = outer_product_scatter_add(spherical_harmonics[l], radial_basis[l].reshape(len(pairs), -1), density_indices.to(torch.int32), n_centers)
densities.append(densities_l)
unique_species = -torch.arange(self.n_pseudo_species, dtype=torch.int64, device=density_indices.device)
else:
aj_metadata = samples_metadata.column("neighbor_type")
aj_shifts = torch.tensor([species_to_index[int(aj_index)] for aj_index in aj_metadata], dtype=torch.int64, device=aj_metadata.device)
density_indices = s_i_metadata_to_unique*n_species+aj_shifts

for l in range(l_max+1):
expanded_vectors_l = expanded_vectors.block({"o3_lambda": l}).values
densities_l = torch.zeros(
(n_centers*n_species, expanded_vectors_l.shape[1], expanded_vectors_l.shape[2]),
dtype = expanded_vectors_l.dtype,
device = expanded_vectors_l.device
)
densities_l.index_add_(dim=0, index=density_indices, source=expanded_vectors_l)
for l in range(self.l_max+1):
densities_l = outer_product_scatter_add(spherical_harmonics[l], radial_basis[l], density_indices.to(torch.int32), n_centers*n_species)
densities_l = densities_l.reshape((n_centers, n_species, 2*l+1, -1)).swapaxes(1, 2).reshape((n_centers, 2*l+1, -1)) # need to swap n, a indices which are in the wrong order
densities.append(densities_l)
unique_species = torch.tensor(self.all_species, dtype=torch.int, device=species.device)

# constructs the TensorMap object
labels : List[List[int]] = []
blocks : List[TensorBlock] = []
for l in range(l_max+1):
for l in range(self.l_max+1):
densities_l = densities[l]
vectors_l_block = expanded_vectors.block({"o3_lambda": l})
vectors_l_block_components = vectors_l_block.components
vectors_l_block_n = torch.arange(len(torch.unique(vectors_l_block.properties.column("n"))), dtype=torch.int64, device=species.device) # Need to be smarter to optimize
vectors_l_block_n = torch.arange(self.n_max_l[l], dtype=torch.int32, device=species.device)
for a_i in self.all_species:
where_ai = torch.where(species == a_i)[0]
densities_ai_l = torch.index_select(densities_l, 0, where_ai)
Expand All @@ -208,7 +212,10 @@ def forward(self,
names = ["structure", "atom"],
values = unique_s_i_indices[where_ai]
),
components = vectors_l_block_components,
components = [Labels(
names = ("o3_mu",),
values = torch.arange(start=-l, end=l+1, dtype=torch.int32, device=spherical_harmonics[0].device).reshape(2*l+1, 1)
)],
properties = Labels(
names = ["neighbor_type", "n"],
values = torch.stack(
Expand Down

0 comments on commit d235144

Please sign in to comment.