From 5c4e517017fe042023c1840e21bae4860b6fb915 Mon Sep 17 00:00:00 2001 From: dylanljones Date: Thu, 27 Jan 2022 15:14:30 +0100 Subject: [PATCH] fixed neighbor computation for small systems in any dimension --- lattpy/lattice.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/lattpy/lattice.py b/lattpy/lattice.py index b9db3a5..ceb4971 100644 --- a/lattpy/lattice.py +++ b/lattpy/lattice.py @@ -1583,10 +1583,10 @@ def _filter_neighbors(self, indices, neighbors, distances, x_ind=None): j = np.argsort(distances, axis=1) distances = distances[i, j] neighbors = neighbors[i, j] + num_sites = len(neighbors) - all_valid = np.any(distances != np.inf, axis=0) - if all(all_valid): - # No Invalid entries found. This usually happens for (2, 2, ..) systems + if not np.any(distances == np.inf): + # No invalid entries found. This usually happens for (2, 2, ..) systems # which results in a bug for the neighbor data: # Assume a 1D chain of 3 atoms. The outer two atoms each have only one # neighbor (the center atom), whereas the center atom has two. @@ -1594,13 +1594,12 @@ def _filter_neighbors(self, indices, neighbors, distances, x_ind=None): # the algorithm to create an array of size (N, 2). # To prevent this a column of invalid values is appended to the data. # Since the system size is small this doesn't create any memory issues. - col = np.full(shape=(len(neighbors), 1), fill_value=invalid_ind) - neighbors = np.append(neighbors, col, axis=1) - - col = np.full(shape=(len(neighbors), 1), fill_value=np.inf) - distances = np.append(distances, col, axis=1) + shape = num_sites, 1 + neighbors = np.append(neighbors, np.full(shape, invalid_ind), axis=1) + distances = np.append(distances, np.full(shape, np.inf), axis=1) else: # Remove columns containing only invalid data + all_valid = np.any(distances != np.inf, axis=0) distances = distances[:, all_valid] neighbors = neighbors[:, all_valid] @@ -1678,15 +1677,8 @@ def compute_neighbors(self, indices: ArrayLike, positions: ArrayLike, # Query and filter neighbors neighbors, distances = tree.query(num_jobs=num_jobs, decimals=self.DIST_DECIMALS) - print(neighbors) neighbors, distances = self._filter_neighbors(indices, neighbors, distances) - # # Fix bug for two sites: - # # Add invalid indices and distances to data - # if len(indices) == 2: - # neighbors = np.array([[neighbors[0, 0], 2], [neighbors[1, 0], 2]]) - # distances = np.array([[distances[0, 0], np.inf], [distances[1, 0], np.inf]]) - return neighbors, distances # ==================================================================================