Skip to content

Commit

Permalink
Handle one shift in kpoints_indices
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Feb 18, 2025
1 parent 4f401df commit f3cbc40
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 27 deletions.
6 changes: 1 addition & 5 deletions abipy/abio/tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,11 +615,7 @@ def test_dfpt_methods(self):
gs_inp.abiget_irred_phperts()

with self.assertRaises(gs_inp.Error):
try:
ddk_inputs = gs_inp.make_ddk_inputs(tolerance={"tolfoo": 1e10})
except Exception as exc:
print(exc)
raise
ddk_inputs = gs_inp.make_ddk_inputs(tolerance={"tolfoo": 1e10})

phg_inputs = gs_inp.make_ph_inputs_qpoint(qpt=(0, 0, 0), tolerance=None)
#print("phonon inputs at Gamma\n", phg_inputs)
Expand Down
14 changes: 10 additions & 4 deletions abipy/core/kpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def map_kpoints(other_kpoints, other_lattice, ref_lattice, ref_kpoints, ref_symr
# #return irred_map


def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray:
def kpoints_indices(frac_coords, ngkpt, shift, check_mesh=0) -> np.ndarray:
"""
This function is used when we need to insert k-dependent quantities in a (nx, ny, nz) array.
It computes and returns the indices of the k-points assuming these points
Expand All @@ -422,28 +422,34 @@ def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray:
Args:
frac_coords: array with the fractional coordinates of the k-points.
ngkpt: Number of divisions of the mesh.
shift: Grid shift (only one shift is supported here)
check_mesh: > 0 to activate debugging sections.
"""
shift = np.reshape(shift, (3,))
if np.any(np.abs(shift) > 1e-6):
# Unshift the points
frac_coords = np.array(frac_coords) - shift

# Transforms kpt in its corresponding reduced number in the interval [0,1[
k_indices = [np.round((kpt % 1) * ngkpt) for kpt in frac_coords]
k_indices = np.array(k_indices, dtype=int)

# Debug secction.
if check_mesh:
print(f"kpoints_indices: Testing whether k-points belong to the {ngkpt =} mesh")
print(f"kpoints_indices: Testing whether k-points belong to the {ngkpt=} mesh")
ierr = 0
for kpt, inds in zip(frac_coords, k_indices):
if check_mesh > 1: print("kpt:", kpt, "inds:", inds)
same_k = np.array((inds[0]/ngkpt[0], inds[1]/ngkpt[1], inds[2]/ngkpt[2]))
if not issamek(kpt, same_k):
ierr += 1; print(kpt, "-->", same_k)
if ierr:
raise ValueError("Wrong mapping")
raise ValueError(f"Wrong mapping, {ierr=}")

#for kpt, inds in zip(frac_coords, k_indices):
# if np.any(inds >= ngkpt):
# raise ValueError(f"inds >= nkgpt for {kpt=}, {np.round(kpt % 1)=} {inds=})")
print("Check succesfull!")
print("check_mesh succesfull!")

return k_indices

Expand Down
55 changes: 54 additions & 1 deletion abipy/core/tests/test_kpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abipy import abilab
from abipy.core.kpoints import (wrap_to_ws, wrap_to_bz, issamek, Kpoint, KpointList, IrredZone, Kpath, KpointsReader,
has_timrev_from_kptopt, KSamplingInfo, as_kpoints, rc_list, kmesh_from_mpdivs, map_grid2ibz,
set_atol_kdiff, set_spglib_tols, kpath_from_bounds_and_ndivsm, build_segments) #Ktables,
set_atol_kdiff, set_spglib_tols, kpath_from_bounds_and_ndivsm, build_segments, kpoints_indices) #Ktables,
from abipy.core.testing import AbipyTest


Expand Down Expand Up @@ -654,3 +654,56 @@ def test_map_grid2ibz(self):
# k = Ktables(self.mgb2, mesh, is_shift, has_timrev)
# repr(k); str(k)
# k.print_bz2ibz()

def test_kpoints_indices(self):
"""Testing kpoints_indices"""
# test basic_functionality
frac_coords = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
ngkpt = [4, 4, 4]
shift = [0.0, 0.0, 0.0]

expected_indices = np.array([[0, 0, 0], [2, 2, 2]])
computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=1)

self.assert_equal(computed_indices, expected_indices)

# test with shift
frac_coords = np.array([[0.25, 0.25, 0.25], [0.75, 0.75, 0.75]])
ngkpt = [4, 4, 4]
shift = [0.25, 0.25, 0.25]

# Shift is removed
expected_indices = np.array([[0, 0, 0], [2, 2, 2]])
computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=1)

self.assert_equal(computed_indices, expected_indices)

# test periodic boundary conditions
frac_coords = np.array([[1.0, 1.0, 1.0], [-0.25, -0.25, -0.25]])
ngkpt = [4, 4, 4]
shift = [0.0, 0.0, 0.0]

expected_indices = np.array([[0, 0, 0], [3, 3, 3]]) # 1.0 and -0.25 wrap correctly
computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=1)

self.assert_equal(computed_indices, expected_indices)

# test rounding behavior
frac_coords = np.array([[0.49, 0.49, 0.49], [0.51, 0.51, 0.51]])
ngkpt = [10, 10, 10]
shift = [0.0, 0.0, 0.0]

# Both round to 5
#expected_indices = np.array([[5, 5, 5], [5, 5, 5]])
#computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=1)
#self.assert_equal(computed_indices, expected_indices)

# test_check_mesh
frac_coords = np.array([[0.2, 0.4, 0.6]])
ngkpt = [5, 5, 5]
shift = [0.0, 0.0, 0.0]

computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=2)

# Ensure shape is correct
self.assert_equal(computed_indices.shape, (1, 3))
8 changes: 5 additions & 3 deletions abipy/electrons/orbmag.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ def __init__(self, filepaths: list):
Args:
filepaths: List of filepaths to ORBMAG.nc files
"""
if not isinstance(filepaths, (list, tuple):
raise TypeError(f"Expecting list or tuple with paths but got {type(filepaths)")
if not isinstance(filepaths, (list, tuple)):
raise TypeError(f"Expecting list or tuple with paths but got {type(filepaths)=}")
if len(filepaths) != 3:
raise ValueError(f"{len(filepaths)=} != 3")

# TODO: One should store the direction in the netcdf file
# so that we can check that the files are given in the right order.
self.orb_files = [OrbmagFile(path) for path in filepaths]

# This piece of code is taken from merge_orbmag_mesh. The main difference
Expand Down Expand Up @@ -220,7 +222,7 @@ def insert_inbox(self, what: str, spin: int) -> tuple:
# Need to know the shape of the k-mesh.
ngkpt, shifts = self.ngkpt_and_shifts
orb = self.orb_files[0]
k_indices = kpoints_indices(orb.kpoints.frac_coords, ngkpt)
k_indices = kpoints_indices(orb.kpoints.frac_coords, ngkpt, shifts)
nx, ny, nz = ngkpt

# I'd start by weighting each band and kpt by trace(sigij)/3.0, the isotropic part of sigij,
Expand Down
4 changes: 2 additions & 2 deletions abipy/eph/gstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def get_g2q_interpolator_kpoint(self, kpoint, method="linear", check_mesh=1) ->

# Compute indices of qpoints in the ngqpt mesh.
ngqpt, shifts = r.ngqpt, [0, 0, 0]
q_indices = kpoints_indices(r.qbz, ngqpt, check_mesh=check_mesh)
q_indices = kpoints_indices(r.qbz, ngqpt, shifts, check_mesh=check_mesh)

natom3 = 3 * len(self.structure)
nb = self.nb
Expand Down Expand Up @@ -585,4 +585,4 @@ def write_notebook(self, nbpath=None) -> str:
#nb.cells.extend(self.get_baserobot_code_cells())
#nb.cells.extend(self.get_ebands_code_cells())

return self._write_nb_nbpath(nb, nbpath)
return self._write_nb_nbpath(nb, nbpath)
2 changes: 1 addition & 1 deletion abipy/eph/gwan.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def get_g2q_interpolator_kpoint(self, kpoint, method="linear", check_mesh=1):

# Compute indices of qpoints in the ngqpt mesh.
ngqpt, shifts = r.ngqpt, [0, 0, 0]
q_indices = kpoints_indices(r.qbz, ngqpt, check_mesh=check_mesh)
q_indices = kpoints_indices(r.qbz, ngqpt, shifts, check_mesh=check_mesh)

natom3 = 3 * len(self.structure)
nb = self.nb
Expand Down
17 changes: 7 additions & 10 deletions abipy/eph/vpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ def insert_a_inbox(self, fill_value=None) -> tuple:
"""
# Need to know the shape of the k-mesh.
ngkpt, shifts = self.ngkpt_and_shifts
k_indices = kpoints_indices(self.kpoints, ngkpt)
k_indices = kpoints_indices(self.kpoints, ngkpt, shifts)
#print(f"{k_indices=}")
nx, ny, nz = ngkpt

shape = (self.nstates, self.nb, nx, ny, nz)
Expand All @@ -422,7 +423,7 @@ def insert_b_inbox(self, fill_value=None) -> tuple:
"""
# Need to know the shape of the q-mesh (always Gamma-centered)
ngqpt, shifts = self.varpeq.r.ngqpt, [0, 0, 0]
q_indices = kpoints_indices(self.qpoints, ngqpt)
q_indices = kpoints_indices(self.qpoints, ngqpt, shifts)

natom3 = 3 * len(self.structure)
nx, ny, nz = ngqpt
Expand Down Expand Up @@ -595,16 +596,14 @@ def plot_ank_with_ebands(self, ebands_kpath,
gridspec_kw = {'width_ratios': [2, 1]}
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
sharex=False, sharey=True, squeeze=False, gridspec_kw=gridspec_kw)
# Get interpolators for A_nk
# Get interpolators for |A_nk|^2
a2_interp_state = self.get_a2_interpolator_state(interp_method)

df = self.get_final_results_df()

ebands_kpath = ElectronBands.as_ebands(ebands_kpath)
ymin, ymax = +np.inf, -np.inf

a_data, *_ = self.insert_a_inbox(fill_value=0)

pkind = self.varpeq.r.vpq_pkind
vbm_or_cbm = "vbm" if pkind == "hole" else "cbm"
bm = self.ebands.get_edge_state(vbm_or_cbm, self.spin).eig
Expand All @@ -613,7 +612,7 @@ def plot_ank_with_ebands(self, ebands_kpath,
for pstate in range(self.nstates):
x, y, s = [], [], []

a2_max = np.max(np.abs(a_data[pstate]))**2
a2_max = a2_interp_state[pstate].get_max_abs_data()
scale *= 1. / a2_max

for ik, kpoint in enumerate(ebands_kpath.kpoints):
Expand Down Expand Up @@ -858,11 +857,9 @@ def plot_bqnu_with_phbands(self, phbands_qpath, with_legend=True,

phbands_qpath = PhononBands.as_phbands(phbands_qpath)

# Get interpolators for B_qnu
# Get interpolators for |B_qnu|^2
b2_interp_state = self.get_b2_interpolator_state(interp_method)

b_data, *_ = self.insert_b_inbox(fill_value=0)

# TODO: need to fix this hardcoded representation
units = 'meV'
units_scale = 1e3 if units == 'meV' else 1
Expand All @@ -872,7 +869,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath, with_legend=True,
for pstate in range(self.nstates):
x, y, s = [], [], []

b2_max = np.max(np.abs(b_data[pstate]))**2
b2_max = b2_interp_state[pstate].get_max_abs_data()
scale *= 1. / b2_max

for iq, qpoint in enumerate(phbands_qpath.qpoints):
Expand Down
8 changes: 7 additions & 1 deletion abipy/tools/numtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def __init__(self, structure, shifts, datak, add_replicas=True, **kwargs):
Args:
structure: :class:`Structure` object.
datak: [ndat, nx, ny, nz] array.
shifts: Shift of the mesh.
shifts: Shifts of the mesh (only one shift is supported here)
add_replicas: If True, data is padded with redundant data points.
in order to have a periodic 3D array of shape=[ndat, nx+1, ny+1, nz+1].
kwargs: Extra arguments are passed to RegularGridInterpolator e.g.: method
Expand All @@ -576,6 +576,7 @@ def __init__(self, structure, shifts, datak, add_replicas=True, **kwargs):

if self.shifts.shape[0] != 1:
raise ValueError(f"Multiple shifts are not supported! {self.shifts.shape[0]=}")

if np.any(self.shifts[0] != 0):
raise ValueError(f"Shift should be zero but got: {self.shifts=}")

Expand Down Expand Up @@ -608,6 +609,8 @@ def __init__(self, structure, shifts, datak, add_replicas=True, **kwargs):
self.abs_data_max_idat[idat] = np.max(np.abs(datak[idat]))

def get_max_abs_data(self, idat=None) -> tuple:
"""
"""
if idat is None:
return self.abs_data_max_idat.max()
return self.abs_data_max_idat[idat]
Expand All @@ -631,6 +634,9 @@ def eval_kpoint(self, frac_coords, cartesian=False, **kwargs) -> np.ndarray:
red_from_cart = self.structure.reciprocal_lattice.inv_matrix.T
frac_coords = np.dot(red_from_cart, frac_coords)

# Remove the shift here
frac_coords -= self.shifts[0]

uc_coords = np.reshape(frac_coords, (3,)) % 1

values = np.empty(self.ndat, dtype=self.dtype)
Expand Down

0 comments on commit f3cbc40

Please sign in to comment.