diff --git a/src/metatrain/utils/neighbor_lists.py b/src/metatrain/utils/neighbor_lists.py index 91f9d964..dc8e4c8b 100644 --- a/src/metatrain/utils/neighbor_lists.py +++ b/src/metatrain/utils/neighbor_lists.py @@ -1,4 +1,3 @@ -import random from typing import List import ase.neighborlist @@ -110,51 +109,34 @@ def _compute_single_neighbor_list( cutoff=options.cutoff, ) - # Check the vesin NL against the ASE NL (5% of the time) - if random.random() < 0.05: - nl_i_ase, nl_j_ase, nl_S_ase, nl_D_ase = ase.neighborlist.neighbor_list( - "ijSD", - atoms, - cutoff=options.cutoff, - ) - assert len(nl_i) == len(nl_i_ase) - assert len(nl_j) == len(nl_j_ase) - assert len(nl_S) == len(nl_S_ase) - assert len(nl_D) == len(nl_D_ase) - nl_ijS = np.concatenate( - (nl_i.reshape(-1, 1), nl_j.reshape(-1, 1), nl_S), axis=1 - ) - nl_ijS_ase = np.concatenate( - (nl_i_ase.reshape(-1, 1), nl_j_ase.reshape(-1, 1), nl_S_ase), axis=1 - ) - sort_indices = np.lexsort(nl_ijS.T) - sort_indices_ase = np.lexsort(nl_ijS_ase.T) - assert np.array_equal(nl_ijS[sort_indices], nl_ijS_ase[sort_indices_ase]) - assert np.allclose(nl_D[sort_indices], nl_D_ase[sort_indices_ase]) - - selected = [] - for pair_i, (i, j, S) in enumerate(zip(nl_i, nl_j, nl_S)): + # The pair selection code here below avoids a relatively slow loop over + # all pairs to improve performance + reject_condition = ( # we want a half neighbor list, so drop all duplicated neighbors - if j < i: - continue - elif i == j: - if S[0] == 0 and S[1] == 0 and S[2] == 0: + (nl_j < nl_i) + | ( + (nl_i == nl_j) + & ( # only create pairs with the same atom twice if the pair spans more # than one unit cell - continue - elif S[0] + S[1] + S[2] < 0 or ( - (S[0] + S[1] + S[2] == 0) and (S[2] < 0 or (S[2] == 0 and S[1] < 0)) - ): - # When creating pairs between an atom and one of its periodic - # images, the code generate multiple redundant pairs (e.g. with - # shifts 0 1 1 and 0 -1 -1); and we want to only keep one of these. - # We keep the pair in the positive half plane of shifts. - continue - - selected.append(pair_i) - - selected = np.array(selected, dtype=np.int32) - n_pairs = len(selected) + ((nl_S[:, 0] == 0) & (nl_S[:, 1] == 0) & (nl_S[:, 2] == 0)) + | + # When creating pairs between an atom and one of its periodic images, + # the code generates multiple redundant pairs + # (e.g. with shifts 0 1 1 and 0 -1 -1); and we want to only keep one of + # these. We keep the pair in the positive half plane of shifts. + ( + (nl_S.sum(axis=1) < 0) + | ( + (nl_S.sum(axis=1) == 0) + & ((nl_S[:, 2] < 0) | ((nl_S[:, 2] == 0) & (nl_S[:, 1] < 0))) + ) + ) + ) + ) + ) + selected = np.logical_not(reject_condition) + n_pairs = np.sum(selected) if options.full_list: distances = np.empty((2 * n_pairs, 3), dtype=np.float64)