From f3cbc40cc96e1258b69c88f2f80b6c21b38a0e06 Mon Sep 17 00:00:00 2001 From: Matteo Giantomassi Date: Tue, 18 Feb 2025 15:55:11 +0100 Subject: [PATCH] Handle one shift in kpoints_indices --- abipy/abio/tests/test_inputs.py | 6 +--- abipy/core/kpoints.py | 14 +++++--- abipy/core/tests/test_kpoints.py | 55 +++++++++++++++++++++++++++++++- abipy/electrons/orbmag.py | 8 +++-- abipy/eph/gstore.py | 4 +-- abipy/eph/gwan.py | 2 +- abipy/eph/vpq.py | 17 ++++------ abipy/tools/numtools.py | 8 ++++- 8 files changed, 87 insertions(+), 27 deletions(-) diff --git a/abipy/abio/tests/test_inputs.py b/abipy/abio/tests/test_inputs.py index 3698a48ce..ef1e2970f 100644 --- a/abipy/abio/tests/test_inputs.py +++ b/abipy/abio/tests/test_inputs.py @@ -615,11 +615,7 @@ def test_dfpt_methods(self): gs_inp.abiget_irred_phperts() with self.assertRaises(gs_inp.Error): - try: - ddk_inputs = gs_inp.make_ddk_inputs(tolerance={"tolfoo": 1e10}) - except Exception as exc: - print(exc) - raise + ddk_inputs = gs_inp.make_ddk_inputs(tolerance={"tolfoo": 1e10}) phg_inputs = gs_inp.make_ph_inputs_qpoint(qpt=(0, 0, 0), tolerance=None) #print("phonon inputs at Gamma\n", phg_inputs) diff --git a/abipy/core/kpoints.py b/abipy/core/kpoints.py index 006cb1e01..401f6c72c 100644 --- a/abipy/core/kpoints.py +++ b/abipy/core/kpoints.py @@ -413,7 +413,7 @@ def map_kpoints(other_kpoints, other_lattice, ref_lattice, ref_kpoints, ref_symr # #return irred_map -def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray: +def kpoints_indices(frac_coords, ngkpt, shift, check_mesh=0) -> np.ndarray: """ This function is used when we need to insert k-dependent quantities in a (nx, ny, nz) array. It computes and returns the indices of the k-points assuming these points @@ -422,15 +422,21 @@ def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray: Args: frac_coords: array with the fractional coordinates of the k-points. ngkpt: Number of divisions of the mesh. + shift: Grid shift (only one shift is supported here) check_mesh: > 0 to activate debugging sections. """ + shift = np.reshape(shift, (3,)) + if np.any(np.abs(shift) > 1e-6): + # Unshift the points + frac_coords = np.array(frac_coords) - shift + # Transforms kpt in its corresponding reduced number in the interval [0,1[ k_indices = [np.round((kpt % 1) * ngkpt) for kpt in frac_coords] k_indices = np.array(k_indices, dtype=int) # Debug secction. if check_mesh: - print(f"kpoints_indices: Testing whether k-points belong to the {ngkpt =} mesh") + print(f"kpoints_indices: Testing whether k-points belong to the {ngkpt=} mesh") ierr = 0 for kpt, inds in zip(frac_coords, k_indices): if check_mesh > 1: print("kpt:", kpt, "inds:", inds) @@ -438,12 +444,12 @@ def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray: if not issamek(kpt, same_k): ierr += 1; print(kpt, "-->", same_k) if ierr: - raise ValueError("Wrong mapping") + raise ValueError(f"Wrong mapping, {ierr=}") #for kpt, inds in zip(frac_coords, k_indices): # if np.any(inds >= ngkpt): # raise ValueError(f"inds >= nkgpt for {kpt=}, {np.round(kpt % 1)=} {inds=})") - print("Check succesfull!") + print("check_mesh succesfull!") return k_indices diff --git a/abipy/core/tests/test_kpoints.py b/abipy/core/tests/test_kpoints.py index 66778b7ef..46d7c7993 100644 --- a/abipy/core/tests/test_kpoints.py +++ b/abipy/core/tests/test_kpoints.py @@ -9,7 +9,7 @@ from abipy import abilab from abipy.core.kpoints import (wrap_to_ws, wrap_to_bz, issamek, Kpoint, KpointList, IrredZone, Kpath, KpointsReader, has_timrev_from_kptopt, KSamplingInfo, as_kpoints, rc_list, kmesh_from_mpdivs, map_grid2ibz, - set_atol_kdiff, set_spglib_tols, kpath_from_bounds_and_ndivsm, build_segments) #Ktables, + set_atol_kdiff, set_spglib_tols, kpath_from_bounds_and_ndivsm, build_segments, kpoints_indices) #Ktables, from abipy.core.testing import AbipyTest @@ -654,3 +654,56 @@ def test_map_grid2ibz(self): # k = Ktables(self.mgb2, mesh, is_shift, has_timrev) # repr(k); str(k) # k.print_bz2ibz() + + def test_kpoints_indices(self): + """Testing kpoints_indices""" + # test basic_functionality + frac_coords = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) + ngkpt = [4, 4, 4] + shift = [0.0, 0.0, 0.0] + + expected_indices = np.array([[0, 0, 0], [2, 2, 2]]) + computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=1) + + self.assert_equal(computed_indices, expected_indices) + + # test with shift + frac_coords = np.array([[0.25, 0.25, 0.25], [0.75, 0.75, 0.75]]) + ngkpt = [4, 4, 4] + shift = [0.25, 0.25, 0.25] + + # Shift is removed + expected_indices = np.array([[0, 0, 0], [2, 2, 2]]) + computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=1) + + self.assert_equal(computed_indices, expected_indices) + + # test periodic boundary conditions + frac_coords = np.array([[1.0, 1.0, 1.0], [-0.25, -0.25, -0.25]]) + ngkpt = [4, 4, 4] + shift = [0.0, 0.0, 0.0] + + expected_indices = np.array([[0, 0, 0], [3, 3, 3]]) # 1.0 and -0.25 wrap correctly + computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=1) + + self.assert_equal(computed_indices, expected_indices) + + # test rounding behavior + frac_coords = np.array([[0.49, 0.49, 0.49], [0.51, 0.51, 0.51]]) + ngkpt = [10, 10, 10] + shift = [0.0, 0.0, 0.0] + + # Both round to 5 + #expected_indices = np.array([[5, 5, 5], [5, 5, 5]]) + #computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=1) + #self.assert_equal(computed_indices, expected_indices) + + # test_check_mesh + frac_coords = np.array([[0.2, 0.4, 0.6]]) + ngkpt = [5, 5, 5] + shift = [0.0, 0.0, 0.0] + + computed_indices = kpoints_indices(frac_coords, ngkpt, shift, check_mesh=2) + + # Ensure shape is correct + self.assert_equal(computed_indices.shape, (1, 3)) diff --git a/abipy/electrons/orbmag.py b/abipy/electrons/orbmag.py index 76799af23..19a7baf6f 100644 --- a/abipy/electrons/orbmag.py +++ b/abipy/electrons/orbmag.py @@ -62,11 +62,13 @@ def __init__(self, filepaths: list): Args: filepaths: List of filepaths to ORBMAG.nc files """ - if not isinstance(filepaths, (list, tuple): - raise TypeError(f"Expecting list or tuple with paths but got {type(filepaths)") + if not isinstance(filepaths, (list, tuple)): + raise TypeError(f"Expecting list or tuple with paths but got {type(filepaths)=}") if len(filepaths) != 3: raise ValueError(f"{len(filepaths)=} != 3") + # TODO: One should store the direction in the netcdf file + # so that we can check that the files are given in the right order. self.orb_files = [OrbmagFile(path) for path in filepaths] # This piece of code is taken from merge_orbmag_mesh. The main difference @@ -220,7 +222,7 @@ def insert_inbox(self, what: str, spin: int) -> tuple: # Need to know the shape of the k-mesh. ngkpt, shifts = self.ngkpt_and_shifts orb = self.orb_files[0] - k_indices = kpoints_indices(orb.kpoints.frac_coords, ngkpt) + k_indices = kpoints_indices(orb.kpoints.frac_coords, ngkpt, shifts) nx, ny, nz = ngkpt # I'd start by weighting each band and kpt by trace(sigij)/3.0, the isotropic part of sigij, diff --git a/abipy/eph/gstore.py b/abipy/eph/gstore.py index a88f76a57..dae491dd9 100644 --- a/abipy/eph/gstore.py +++ b/abipy/eph/gstore.py @@ -291,7 +291,7 @@ def get_g2q_interpolator_kpoint(self, kpoint, method="linear", check_mesh=1) -> # Compute indices of qpoints in the ngqpt mesh. ngqpt, shifts = r.ngqpt, [0, 0, 0] - q_indices = kpoints_indices(r.qbz, ngqpt, check_mesh=check_mesh) + q_indices = kpoints_indices(r.qbz, ngqpt, shifts, check_mesh=check_mesh) natom3 = 3 * len(self.structure) nb = self.nb @@ -585,4 +585,4 @@ def write_notebook(self, nbpath=None) -> str: #nb.cells.extend(self.get_baserobot_code_cells()) #nb.cells.extend(self.get_ebands_code_cells()) - return self._write_nb_nbpath(nb, nbpath) \ No newline at end of file + return self._write_nb_nbpath(nb, nbpath) diff --git a/abipy/eph/gwan.py b/abipy/eph/gwan.py index f66ab1624..a373b35a4 100644 --- a/abipy/eph/gwan.py +++ b/abipy/eph/gwan.py @@ -245,7 +245,7 @@ def get_g2q_interpolator_kpoint(self, kpoint, method="linear", check_mesh=1): # Compute indices of qpoints in the ngqpt mesh. ngqpt, shifts = r.ngqpt, [0, 0, 0] - q_indices = kpoints_indices(r.qbz, ngqpt, check_mesh=check_mesh) + q_indices = kpoints_indices(r.qbz, ngqpt, shifts, check_mesh=check_mesh) natom3 = 3 * len(self.structure) nb = self.nb diff --git a/abipy/eph/vpq.py b/abipy/eph/vpq.py index f52773fad..8712b4934 100644 --- a/abipy/eph/vpq.py +++ b/abipy/eph/vpq.py @@ -401,7 +401,8 @@ def insert_a_inbox(self, fill_value=None) -> tuple: """ # Need to know the shape of the k-mesh. ngkpt, shifts = self.ngkpt_and_shifts - k_indices = kpoints_indices(self.kpoints, ngkpt) + k_indices = kpoints_indices(self.kpoints, ngkpt, shifts) + #print(f"{k_indices=}") nx, ny, nz = ngkpt shape = (self.nstates, self.nb, nx, ny, nz) @@ -422,7 +423,7 @@ def insert_b_inbox(self, fill_value=None) -> tuple: """ # Need to know the shape of the q-mesh (always Gamma-centered) ngqpt, shifts = self.varpeq.r.ngqpt, [0, 0, 0] - q_indices = kpoints_indices(self.qpoints, ngqpt) + q_indices = kpoints_indices(self.qpoints, ngqpt, shifts) natom3 = 3 * len(self.structure) nx, ny, nz = ngqpt @@ -595,7 +596,7 @@ def plot_ank_with_ebands(self, ebands_kpath, gridspec_kw = {'width_ratios': [2, 1]} ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols, sharex=False, sharey=True, squeeze=False, gridspec_kw=gridspec_kw) - # Get interpolators for A_nk + # Get interpolators for |A_nk|^2 a2_interp_state = self.get_a2_interpolator_state(interp_method) df = self.get_final_results_df() @@ -603,8 +604,6 @@ def plot_ank_with_ebands(self, ebands_kpath, ebands_kpath = ElectronBands.as_ebands(ebands_kpath) ymin, ymax = +np.inf, -np.inf - a_data, *_ = self.insert_a_inbox(fill_value=0) - pkind = self.varpeq.r.vpq_pkind vbm_or_cbm = "vbm" if pkind == "hole" else "cbm" bm = self.ebands.get_edge_state(vbm_or_cbm, self.spin).eig @@ -613,7 +612,7 @@ def plot_ank_with_ebands(self, ebands_kpath, for pstate in range(self.nstates): x, y, s = [], [], [] - a2_max = np.max(np.abs(a_data[pstate]))**2 + a2_max = a2_interp_state[pstate].get_max_abs_data() scale *= 1. / a2_max for ik, kpoint in enumerate(ebands_kpath.kpoints): @@ -858,11 +857,9 @@ def plot_bqnu_with_phbands(self, phbands_qpath, with_legend=True, phbands_qpath = PhononBands.as_phbands(phbands_qpath) - # Get interpolators for B_qnu + # Get interpolators for |B_qnu|^2 b2_interp_state = self.get_b2_interpolator_state(interp_method) - b_data, *_ = self.insert_b_inbox(fill_value=0) - # TODO: need to fix this hardcoded representation units = 'meV' units_scale = 1e3 if units == 'meV' else 1 @@ -872,7 +869,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath, with_legend=True, for pstate in range(self.nstates): x, y, s = [], [], [] - b2_max = np.max(np.abs(b_data[pstate]))**2 + b2_max = b2_interp_state[pstate].get_max_abs_data() scale *= 1. / b2_max for iq, qpoint in enumerate(phbands_qpath.qpoints): diff --git a/abipy/tools/numtools.py b/abipy/tools/numtools.py index 5891be12c..778752926 100644 --- a/abipy/tools/numtools.py +++ b/abipy/tools/numtools.py @@ -564,7 +564,7 @@ def __init__(self, structure, shifts, datak, add_replicas=True, **kwargs): Args: structure: :class:`Structure` object. datak: [ndat, nx, ny, nz] array. - shifts: Shift of the mesh. + shifts: Shifts of the mesh (only one shift is supported here) add_replicas: If True, data is padded with redundant data points. in order to have a periodic 3D array of shape=[ndat, nx+1, ny+1, nz+1]. kwargs: Extra arguments are passed to RegularGridInterpolator e.g.: method @@ -576,6 +576,7 @@ def __init__(self, structure, shifts, datak, add_replicas=True, **kwargs): if self.shifts.shape[0] != 1: raise ValueError(f"Multiple shifts are not supported! {self.shifts.shape[0]=}") + if np.any(self.shifts[0] != 0): raise ValueError(f"Shift should be zero but got: {self.shifts=}") @@ -608,6 +609,8 @@ def __init__(self, structure, shifts, datak, add_replicas=True, **kwargs): self.abs_data_max_idat[idat] = np.max(np.abs(datak[idat])) def get_max_abs_data(self, idat=None) -> tuple: + """ + """ if idat is None: return self.abs_data_max_idat.max() return self.abs_data_max_idat[idat] @@ -631,6 +634,9 @@ def eval_kpoint(self, frac_coords, cartesian=False, **kwargs) -> np.ndarray: red_from_cart = self.structure.reciprocal_lattice.inv_matrix.T frac_coords = np.dot(red_from_cart, frac_coords) + # Remove the shift here + frac_coords -= self.shifts[0] + uc_coords = np.reshape(frac_coords, (3,)) % 1 values = np.empty(self.ndat, dtype=self.dtype)