Skip to content

Commit

Permalink
Faster ASE NL filtering and remove vesin checks
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 29, 2024
1 parent 2c0a748 commit 7868391
Showing 1 changed file with 25 additions and 43 deletions.
68 changes: 25 additions & 43 deletions src/metatrain/utils/neighbor_lists.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
from typing import List

import ase.neighborlist
Expand Down Expand Up @@ -110,51 +109,34 @@ def _compute_single_neighbor_list(
cutoff=options.cutoff,
)

# Check the vesin NL against the ASE NL (5% of the time)
if random.random() < 0.05:
nl_i_ase, nl_j_ase, nl_S_ase, nl_D_ase = ase.neighborlist.neighbor_list(
"ijSD",
atoms,
cutoff=options.cutoff,
)
assert len(nl_i) == len(nl_i_ase)
assert len(nl_j) == len(nl_j_ase)
assert len(nl_S) == len(nl_S_ase)
assert len(nl_D) == len(nl_D_ase)
nl_ijS = np.concatenate(
(nl_i.reshape(-1, 1), nl_j.reshape(-1, 1), nl_S), axis=1
)
nl_ijS_ase = np.concatenate(
(nl_i_ase.reshape(-1, 1), nl_j_ase.reshape(-1, 1), nl_S_ase), axis=1
)
sort_indices = np.lexsort(nl_ijS.T)
sort_indices_ase = np.lexsort(nl_ijS_ase.T)
assert np.array_equal(nl_ijS[sort_indices], nl_ijS_ase[sort_indices_ase])
assert np.allclose(nl_D[sort_indices], nl_D_ase[sort_indices_ase])

selected = []
for pair_i, (i, j, S) in enumerate(zip(nl_i, nl_j, nl_S)):
# The pair selection code here below avoids a relatively slow loop over
# all pairs to improve performance
reject_condition = (
# we want a half neighbor list, so drop all duplicated neighbors
if j < i:
continue
elif i == j:
if S[0] == 0 and S[1] == 0 and S[2] == 0:
(nl_j < nl_i)
| (
(nl_i == nl_j)
& (
# only create pairs with the same atom twice if the pair spans more
# than one unit cell
continue
elif S[0] + S[1] + S[2] < 0 or (
(S[0] + S[1] + S[2] == 0) and (S[2] < 0 or (S[2] == 0 and S[1] < 0))
):
# When creating pairs between an atom and one of its periodic
# images, the code generate multiple redundant pairs (e.g. with
# shifts 0 1 1 and 0 -1 -1); and we want to only keep one of these.
# We keep the pair in the positive half plane of shifts.
continue

selected.append(pair_i)

selected = np.array(selected, dtype=np.int32)
n_pairs = len(selected)
((nl_S[:, 0] == 0) & (nl_S[:, 1] == 0) & (nl_S[:, 2] == 0))
|
# When creating pairs between an atom and one of its periodic images,
# the code generates multiple redundant pairs
# (e.g. with shifts 0 1 1 and 0 -1 -1); and we want to only keep one of
# these. We keep the pair in the positive half plane of shifts.
(
(nl_S.sum(axis=1) < 0)
| (
(nl_S.sum(axis=1) == 0)
& ((nl_S[:, 2] < 0) | ((nl_S[:, 2] == 0) & (nl_S[:, 1] < 0)))
)
)
)
)
)
selected = np.logical_not(reject_condition)
n_pairs = np.sum(selected)

if options.full_list:
distances = np.empty((2 * n_pairs, 3), dtype=np.float64)
Expand Down

0 comments on commit 7868391

Please sign in to comment.