Skip to content

Commit

Permalink
update License and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanljones committed Jan 19, 2022
1 parent 75217ff commit 119e42e
Show file tree
Hide file tree
Showing 13 changed files with 179 additions and 132 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2021, Dylan Jones
Copyright (c) 2022, Dylan Jones

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
2 changes: 1 addition & 1 deletion lattpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This code is part of lattpy.
#
# Copyright (c) 2021, Dylan Jones
# Copyright (c) 2022, Dylan Jones
#
# This code is licensed under the MIT License. The copyright notice in the
# LICENSE file in the root directory and this permission notice shall
Expand Down
105 changes: 53 additions & 52 deletions lattpy/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This code is part of lattpy.
#
# Copyright (c) 2021, Dylan Jones
# Copyright (c) 2022, Dylan Jones
#
# This code is licensed under the MIT License. The copyright notice in the
# LICENSE file in the root directory and this permission notice shall
Expand Down Expand Up @@ -72,9 +72,8 @@ def onsite(self, alpha: Optional[int] = None) -> np.ndarray:
Parameters
----------
alpha : int, optional
Index of the atom in the unitcell. If `None`a mask for all atoms is returned.
The default is `None`.
Index of the atom in the unitcell. If `None` a mask for all atoms
is returned. The default is `None`.
Returns
-------
mask : np.ndarray
Expand All @@ -90,7 +89,8 @@ def hopping(self, distidx: Optional[int] = None) -> np.ndarray:
----------
distidx : int, optional
Index of distance to neighboring sites, default is 0 (nearest neighbors).
If `None` a mask for neighbor-connections is returned. The default is `None`.
If `None` a mask for neighbor-connections is returned. The default is
`None`.
Returns
-------
mask : np.ndarray
Expand All @@ -114,10 +114,9 @@ def fill(self, array: np.ndarray, hop: ArrayLike,
eps : array_like, optional
The onsite values used for the lattice sites. If there are multiple atoms
in the unitcell the length of the values must match. The default is 0.
Returns
-------
filled: np.ndarray
filled : np.ndarray
"""
eps = np.atleast_1d(eps)
hop = np.atleast_1d(hop)
Expand Down Expand Up @@ -206,13 +205,13 @@ def set(self, indices: Sequence[Iterable[int]],
Parameters
----------
indices: array_like of iterable of int
indices : array_like of iterable of int
The lattice indices of the sites.
positions: array_like of iterable of int
positions : array_like of iterable of int
The positions of the sites.
neighbors: iterable of iterable of of int
neighbors : iterable of iterable of of int
The neighbors of the sites.
distances: iterabe of iterable of int
distances : iterabe of iterable of int
The distances of the neighbors.
"""
logger.debug("Setting data")
Expand All @@ -234,10 +233,11 @@ def get_limits(self) -> np.ndarray:
Returns
-------
limits: np.ndarray
limits : np.ndarray
The minimum and maximum value for each axis of the position data.
"""
return np.array([np.min(self.positions, axis=0), np.max(self.positions, axis=0)])
return np.array([np.min(self.positions, axis=0),
np.max(self.positions, axis=0)])

def get_index_limits(self) -> np.ndarray:
"""Computes the geometric limits of the lattice indices of the stored sites.
Expand All @@ -254,7 +254,7 @@ def get_translation_limits(self) -> np.ndarray:
Returns
-------
limits: np.ndarray
limits : np.ndarray
The minimum and maximum value for each axis of the lattice indices.
"""
return self.get_index_limits()[:, :-1]
Expand All @@ -266,21 +266,21 @@ def neighbor_mask(self, site: int, distidx: Optional[int] = None,
Parameters
----------
site: int
site : int
The index of the site.
distidx: int, optional
The index of the distance. If ``None`` the data for all distances is returned.
distidx : int, optional
The index of the distance. If `None` the data for all distances is returned.
The default is `None` (all neighbors).
periodic: bool, optional
Periodic neighbor flag. If ``None`` the data for all neighbors is returned.
If a bool is passed either the periodic or non-periodic neighbors are masked.
The default is ``None`` (all neighbors).
unique: bool, optional
If 'True', each unique pair is only return once. The defualt is ``False``.
periodic : bool, optional
Periodic neighbor flag. If `None` the data for all neighbors is returned.
If a bool is passed either the periodic or non-periodic neighbors
are masked. The default is `None` (all neighbors).
unique : bool, optional
If 'True', each unique pair is only return once. The defualt is `False`.
Returns
-------
mask: np.ndarray
mask : np.ndarray
"""
if distidx is None:
mask = self.distances[site] < self.invalid_distidx
Expand All @@ -297,15 +297,15 @@ def neighbor_mask(self, site: int, distidx: Optional[int] = None,
return mask

def set_periodic(self, indices: dict, distances: dict, axes: dict) -> None:
""" Adds periodic neighbors to the invalid slots of the neighbor data
"""Adds periodic neighbors to the invalid slots of the neighbor data
Parameters
----------
indices: dict
indices : dict
Indices of the periodic neighbors.
distances: dict
distances : dict
The distances of the periodic neighbors.
axes: dict
axes : dict
Index of the translation axis of the periodic neighbors.
"""
for i, pidx in indices.items():
Expand Down Expand Up @@ -397,12 +397,12 @@ def append(self, *args, copy=False):
cols = max(cols1, cols2)
if cols1 < cols:
widths = ((0, 0), (0, cols - cols1))
neighbors1 = np.pad(neighbors1, pad_width=widths, constant_values=invalid_idx)
distances1 = np.pad(distances1, pad_width=widths, constant_values=np.inf)
neighbors1 = np.pad(neighbors1, widths, constant_values=invalid_idx)
distances1 = np.pad(distances1, widths, constant_values=np.inf)
if cols2 < cols:
widths = ((0, 0), (0, cols - cols2))
neighbors2 = np.pad(neighbors2, pad_width=widths, constant_values=invalid_idx)
distances2 = np.pad(distances2, pad_width=widths, constant_values=np.inf)
neighbors2 = np.pad(neighbors2, widths, constant_values=invalid_idx)
distances2 = np.pad(distances2, widths, constant_values=np.inf)

# Join data
indices = np.append(self.indices, indices2, axis=0)
Expand All @@ -427,13 +427,13 @@ def get_positions(self, alpha):
def get_neighbors(self, site: int, distidx: Optional[int] = None,
periodic: Optional[bool] = None,
unique: Optional[bool] = False) -> np.ndarray:
"""Returns all neighbors or the neighbors for a certain distance of a lattice site.
"""Returns the neighbors of a lattice site.
See the `neighbor_mask`-method for more information on parameters
Returns
-------
neighbors: np.ndarray
neighbors : np.ndarray
The indices of the neighbors.
"""
mask = self.neighbor_mask(site, distidx, periodic, unique)
Expand All @@ -448,7 +448,7 @@ def get_neighbor_pos(self, site: int, distidx: Optional[int] = None,
Returns
-------
neighbor_positions: np.ndarray
neighbor_positions : np.ndarray
The positions of the neighbors.
"""
ind = self.get_neighbors(site, distidx, periodic, unique)
Expand All @@ -463,19 +463,19 @@ def iter_neighbors(self, site: int, unique: Optional[bool] = False) -> np.ndarra
Yields
-------
distidx: int
neighbors: np.ndarray
distidx : int
neighbors : np.ndarray
"""
for distidx in np.unique(self.distances[site]):
if distidx != self.invalid_distidx:
yield distidx, self.get_neighbors(site, distidx, unique=unique)

def map(self) -> DataMap:
""" Builds a map containing the atom-indices, site-pairs and corresponding distances.
"""Builds a map containing the atom-indices, site-pairs and distances.
Returns
-------
datamap: DataMap
datamap : DataMap
"""
if self._dmap is None:
alphas = self.indices[:, -1].astype(np.int8)
Expand All @@ -499,16 +499,16 @@ def site_mask(self, mins: Optional[Sequence[Union[float, None]]] = None,
Parameters
----------
mins: sequence or float or None, optional
mins : sequence or float or None, optional
Optional lower bound for the positions. The default is no lower bound.
maxs: sequence or float or None, optional
maxs : sequence or float or None, optional
Optional upper bound for the positions. The default is no upper bound.
invert: bool, optional
invert : bool, optional
If `True`, the mask is inverted. The default is `False`.
Returns
-------
mask: np.ndarray
mask : np.ndarray
The mask containing a boolean value for each site.
"""
if mins is None:
Expand All @@ -535,11 +535,11 @@ def find_sites(self, mins: Optional[Sequence[Union[float, None]]] = None,
Parameters
----------
mins: sequence or float or None, optional
mins : sequence or float or None, optional
Optional lower bound for the positions. The default is no lower bound.
maxs: sequence or float or None, optional
maxs : sequence or float or None, optional
Optional upper bound for the positions. The default is no upper bound.
invert: bool, optional
invert : bool, optional
If `True`, the mask is inverted and the positions outside of the bounds
will be returned. The default is `False`.
Expand All @@ -556,14 +556,14 @@ def find_outer_sites(self, ax: int, offset: int) -> np.ndarray:
Parameters
----------
ax: int
ax : int
The geometrical axis.
offset: int
offset : int
The width of the outer slices.
Returns
-------
indices: np.ndarray
indices : np.ndarray
The indices of the masked sites.
"""
limits = self.get_limits()
Expand All @@ -578,15 +578,16 @@ def __bool__(self) -> bool:
return bool(len(self.indices))

def __str__(self) -> str:
widths = 9, 15, 10
w = 9, 15, 10
delim = " | "
headers = "Indices", "Positions", "Neighbours"
lines = list()
s = f"{headers[0]:<{widths[0]}}{delim}{headers[1]:<{widths[1]}}{delim}{headers[2]}"
s = f"{headers[0]:<{w[0]}}{delim}{headers[1]:<{w[1]}}{delim}{headers[2]}"
lines.append(s)
for site in range(self.num_sites):
pos = "[" + ", ".join(f"{x:.1f}" for x in self.positions[site]) + "]"
idx = str(self.indices[site])
neighbors = str(self.neighbors[site])
lines.append(f"{idx:<{widths[0]}}{delim}{pos:<{widths[1]}}{delim}{neighbors}")
s = f"{idx:<{w[0]}}{delim}{pos:<{w[1]}}{delim}{neighbors}"
lines.append(s)
return "\n".join(lines)
16 changes: 9 additions & 7 deletions lattpy/disptools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This code is part of lattpy.
#
# Copyright (c) 2021, Dylan Jones
# Copyright (c) 2022, Dylan Jones
#
# This code is licensed under the MIT License. The copyright notice in the
# LICENSE file in the root directory and this permission notice shall
Expand Down Expand Up @@ -37,7 +37,7 @@ def _color_list(color, num_bands):
def _scale_xaxis(num_points, disp, scales=None):
sect_size = len(disp) / (num_points - 1)
scales = np.ones(num_points - 1) if scales is None else scales
k0, k, ticks = 0, list(), [0]
k0, k, ticks = 0, [], [0]
for scale in scales:
k.extend(k0 + np.arange(sect_size) * scale)
k0 = k[-1]
Expand Down Expand Up @@ -79,8 +79,9 @@ def _draw_dispersion(ax, k, disp, color=None, fill=False, alpha=0.2, lw=1.0):
ax.fill_between(x, min(band), max(band), color=col, alpha=alpha)


def plot_dispersion(disp, labels, xlabel="$k$", ylabel="$E(k)$", grid="both", color=None,
alpha=0.2, lw=1.0, scales=None, fill=False, ax=None, show=True):
def plot_dispersion(disp, labels, xlabel="$k$", ylabel="$E(k)$", grid="both",
color=None, alpha=0.2, lw=1.0, scales=None, fill=False,
ax=None, show=True):
num_points = len(labels)
k, ticks = _scale_xaxis(num_points, disp, scales)
if ax is None:
Expand Down Expand Up @@ -132,7 +133,8 @@ def plot_disp_dos(disp, dos_data, labels, xlabel="k", ylabel="E(k)", doslabel="n
num_points = len(labels)
k, ticks = _scale_xaxis(num_points, disp, scales)
if axs is None:
fig, axs = disp_dos_subplots(ticks, labels, xlabel, ylabel, doslabel, wratio, grid)
fig, axs = disp_dos_subplots(ticks, labels, xlabel, ylabel, doslabel, wratio,
grid)
ax1, ax2 = axs
else:
ax1, ax2 = axs
Expand Down Expand Up @@ -163,8 +165,8 @@ def plot_disp_dos(disp, dos_data, labels, xlabel="k", ylabel="E(k)", doslabel="n
return axs


def plot_bands(kgrid, bands, k_label="k", disp_label="E(k)", grid="both", contour_grid=False,
bz=None, pi_ticks=True, ax=None, show=True):
def plot_bands(kgrid, bands, k_label="k", disp_label="E(k)", grid="both",
contour_grid=False, bz=None, pi_ticks=True, ax=None, show=True):
if ax is None:
fig, ax = plt.subplots()
else:
Expand Down
15 changes: 10 additions & 5 deletions lattpy/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This code is part of lattpy.
#
# Copyright (c) 2021, Dylan Jones
# Copyright (c) 2022, Dylan Jones
#
# This code is licensed under the MIT License. The copyright notice in the
# LICENSE file in the root directory and this permission notice shall
Expand Down Expand Up @@ -96,7 +96,8 @@ def __init__(self, vectors: Union[float, Sequence[float], Sequence[Sequence[floa
self.data = LatticeData()
self.shape = None
self.periodic_axes = list()
logger.debug("Lattice initialized (D=%i)\n vectors:\n%s", self.dim, self._vectors)
logger.debug("Lattice initialized (D=%i)\nvectors:\n%s",
self.dim, self._vectors)

@classmethod
def chain(cls, a: Optional[float] = 1.0, **kwargs) -> 'Lattice':
Expand All @@ -107,7 +108,8 @@ def square(cls, a: Optional[float] = 1.0, **kwargs) -> 'Lattice':
return cls(a * np.eye(2), **kwargs)

@classmethod
def rectangular(cls, a1: Optional[float] = 1., a2: Optional[float] = 1., **kwargs) -> 'Lattice':
def rectangular(cls, a1: Optional[float] = 1., a2: Optional[float] = 1.,
**kwargs) -> 'Lattice':
return cls(np.array([[a1, 0], [0, a2]]), **kwargs)

@classmethod
Expand All @@ -122,8 +124,11 @@ def oblique(cls, alpha: float, a1: Optional[float] = 1.0,
return cls(vectors, **kwargs)

@classmethod
def hexagonal3D(cls, a: Optional[float] = 1., az: Optional[float] = 1., **kwargs) -> 'Lattice': # noqa
vectors = a / 2 * np.array([[3, np.sqrt(3), 0], [3, -np.sqrt(3), 0], [0, 0, az]])
def hexagonal3D(cls, a: Optional[float] = 1., az: Optional[float] = 1.,
**kwargs) -> 'Lattice': # noqa
vectors = a / 2 * np.array([[3, np.sqrt(3), 0],
[3, -np.sqrt(3), 0],
[0, 0, az]])
return cls(vectors, **kwargs)

@classmethod
Expand Down
Loading

0 comments on commit 119e42e

Please sign in to comment.