Skip to content

Commit

Permalink
Optimisations to the _get_r_data method.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Dec 6, 2023
1 parent f8947a7 commit 7f2038a
Showing 1 changed file with 2 additions and 12 deletions.
14 changes: 2 additions & 12 deletions emle/emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,18 +1964,8 @@ def _get_r_data(cls, xyz, device):
n_atoms = len(xyz)

rr_mat = xyz[:, None, :] - xyz[None, :, :]

r2_mat = torch.sum(rr_mat**2, axis=2)
r_mat = torch.sqrt(torch.where(r2_mat > 0.0, r2_mat, 1.0))

new_diag = torch.zeros_like(
r_mat.diagonal(), dtype=torch.float32, device=device
)
mask = torch.diag(torch.ones_like(new_diag, dtype=torch.float32, device=device))
r_mat = mask * torch.diag(new_diag) + (1.0 - mask) * r_mat

tmp = torch.where(r_mat == 0.0, 1.0, r_mat)
r_inv = torch.where(r_mat == 0.0, 0.0, 1.0 / tmp)
r_mat = torch.cdist(xyz, xyz)
r_inv = torch.where(r_mat == 0.0, 0.0, 1.0 / r_mat)

r_inv1 = r_inv.repeat_interleave(3, dim=1)
r_inv2 = r_inv1.repeat_interleave(3, dim=0)
Expand Down

0 comments on commit 7f2038a

Please sign in to comment.