Skip to content

Commit

Permalink
fixed neighbor computation for small systems in any dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanljones committed Jan 27, 2022
1 parent eb980b0 commit 5c4e517
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions lattpy/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,24 +1583,23 @@ def _filter_neighbors(self, indices, neighbors, distances, x_ind=None):
j = np.argsort(distances, axis=1)
distances = distances[i, j]
neighbors = neighbors[i, j]
num_sites = len(neighbors)

all_valid = np.any(distances != np.inf, axis=0)
if all(all_valid):
# No Invalid entries found. This usually happens for (2, 2, ..) systems
if not np.any(distances == np.inf):
# No invalid entries found. This usually happens for (2, 2, ..) systems
# which results in a bug for the neighbor data:
# Assume a 1D chain of 3 atoms. The outer two atoms each have only one
# neighbor (the center atom), whereas the center atom has two.
# In a model of two atoms no atom has two neighbors, which prevents
# the algorithm to create an array of size (N, 2).
# To prevent this a column of invalid values is appended to the data.
# Since the system size is small this doesn't create any memory issues.
col = np.full(shape=(len(neighbors), 1), fill_value=invalid_ind)
neighbors = np.append(neighbors, col, axis=1)

col = np.full(shape=(len(neighbors), 1), fill_value=np.inf)
distances = np.append(distances, col, axis=1)
shape = num_sites, 1
neighbors = np.append(neighbors, np.full(shape, invalid_ind), axis=1)
distances = np.append(distances, np.full(shape, np.inf), axis=1)
else:
# Remove columns containing only invalid data
all_valid = np.any(distances != np.inf, axis=0)
distances = distances[:, all_valid]
neighbors = neighbors[:, all_valid]

Expand Down Expand Up @@ -1678,15 +1677,8 @@ def compute_neighbors(self, indices: ArrayLike, positions: ArrayLike,
# Query and filter neighbors
neighbors, distances = tree.query(num_jobs=num_jobs,
decimals=self.DIST_DECIMALS)
print(neighbors)
neighbors, distances = self._filter_neighbors(indices, neighbors, distances)

# # Fix bug for two sites:
# # Add invalid indices and distances to data
# if len(indices) == 2:
# neighbors = np.array([[neighbors[0, 0], 2], [neighbors[1, 0], 2]])
# distances = np.array([[distances[0, 0], np.inf], [distances[1, 0], np.inf]])

return neighbors, distances

# ==================================================================================
Expand Down

0 comments on commit 5c4e517

Please sign in to comment.