Skip to content

Commit

Permalink
Simplify calculation of outer product.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Dec 6, 2023
1 parent 7f2038a commit 467ce9d
Showing 1 changed file with 6 additions and 32 deletions.
38 changes: 6 additions & 32 deletions emle/emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,7 +1969,12 @@ def _get_r_data(cls, xyz, device):

r_inv1 = r_inv.repeat_interleave(3, dim=1)
r_inv2 = r_inv1.repeat_interleave(3, dim=0)
outer = cls._get_outer(rr_mat, device)

# Get a stacked matrix of outer products over the rr_mat tensors.
outer = torch.einsum("bik,bij->bjik", rr_mat, rr_mat).reshape(
(n_atoms * 3, n_atoms * 3)
)

id2 = torch.tile(
torch.tile(
torch.eye(3, dtype=torch.float32, device=device).T, (1, n_atoms)
Expand All @@ -1984,37 +1989,6 @@ def _get_r_data(cls, xyz, device):

return {"r_mat": r_mat, "T01": t01, "T11": t11, "T21": t21, "T22": t22}

@staticmethod
def _get_outer(a, device):
"""
Internal method, calculates stacked matrix of outer products of a
list of vectors.
Parameters
----------
a: torch.tensor (N_ATOMS, 3)
List of vectors.
device: torch.device
The PyTorch device to use.
Returns
-------
result: torch.tensor (N_ATOMS * 3, N_ATOMS * 3)
"""
n = len(a)
idx = np.triu_indices(n, 1)

result = torch.zeros((n, n, 3, 3), dtype=torch.float32, device=device)
result[idx] = a[idx][:, :, None] @ a[idx][:, None, :]
tmp = result
result = result.swapaxes(0, 1)
result[idx] = tmp[idx]

return result.swapaxes(1, 2).reshape((n * 3, n * 3))

@classmethod
def _get_mesh_data(cls, xyz, xyz_mesh, s):
"""
Expand Down

0 comments on commit 467ce9d

Please sign in to comment.