Skip to content

Commit

Permalink
Additional developments related to orbmag
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Feb 18, 2025
1 parent 88dbd23 commit 4f401df
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 68 deletions.
11 changes: 5 additions & 6 deletions abipy/core/kpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,14 @@ def map_kpoints(other_kpoints, other_lattice, ref_lattice, ref_kpoints, ref_symr
def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray:
"""
This function is used when we need to insert k-dependent quantities in a (nx, ny, nz) array.
It computes the indices of the k-points assuming these points belong to a mesh with ngkpt divisions.
It computes and returns the indices of the k-points assuming these points
belong to a mesh with ngkpt divisions.
Args:
frac_coords
ngkpt:
check_mesh:
frac_coords: array with the fractional coordinates of the k-points.
ngkpt: Number of divisions of the mesh.
check_mesh: > 0 to activate debugging sections.
"""

# Transforms kpt in its corresponding reduced number in the interval [0,1[
k_indices = [np.round((kpt % 1) * ngkpt) for kpt in frac_coords]
k_indices = np.array(k_indices, dtype=int)
Expand All @@ -443,7 +443,6 @@ def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray:
#for kpt, inds in zip(frac_coords, k_indices):
# if np.any(inds >= ngkpt):
# raise ValueError(f"inds >= nkgpt for {kpt=}, {np.round(kpt % 1)=} {inds=})")

print("Check succesfull!")

return k_indices
Expand Down
91 changes: 58 additions & 33 deletions abipy/electrons/orbmag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
"""Classes for the analysis of electronic fatbands and projected DOSes."""
from __future__ import annotations

import traceback
import numpy as np

#from tabulate import tabulate
from numpy.linalg import inv, det, eigvals
from monty.termcolor import cprint
#from monty.termcolor import cprint
from monty.functools import lazy_property
from monty.string import is_string, list_strings, marquee
#from pymatgen.core.periodic_table import Element
from monty.string import list_strings, marquee
from abipy.core.mixins import AbinitNcFile, Has_Header, Has_Structure, Has_ElectronBands, NotebookWriter
from abipy.core.structure import Structure
from abipy.tools.numtools import BzRegularGridInterpolator
Expand All @@ -19,8 +16,8 @@
#from abipy.tools.numtools import gaussian
from abipy.tools.typing import Figure
from abipy.tools.plotting import set_axlims, get_ax_fig_plt, get_axarray_fig_plt, add_fig_kwargs, Marker
#from abipy.tools.plotting import (set_axlims, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt,
# rotate_ticklabels, set_visible, plot_unit_cell, set_ax_xylabels
#from abipy.tools.plotting import (set_axlims, add_fig_kwargs, get_ax_fig_plt
# rotate_ticklabels, set_visible, plot_unit_cell, set_ax_xylabels)


def print_options_decorator(**kwargs):
Expand All @@ -45,19 +42,35 @@ def wrapper(*args, **kwargs_inner):

class OrbmagAnalyzer:
"""
TODO
This object gather three ORBMAG.nc files, post-process the data and
provides tools to analyze/plot the results.
Usage example:
.. code-block:: python
from abipy.electrons.orbmag import OrbmagAnalyzer
orban = OrbmagAnalyzer(["gso_DS1_ORBMAG.nc", "gso_DS2_ORBMAG.nc", "gso_DS3_ORBMAG.nc"])
print(orban)
orban.report_eigvals(report_type="S")
orban.plot_fatbands("bands_GSR.nc")
"""

def __init__(self, filepaths):
def __init__(self, filepaths: list):
"""
Args:
filepaths:
filepaths: List of filepaths to ORBMAG.nc files
"""
if not isinstance(filepaths, (list, tuple):
raise TypeError(f"Expecting list or tuple with paths but got {type(filepaths)")
if len(filepaths) != 3:
raise ValueError(f"{len(filepaths)=} != 3")

self.orb_files = [OrbmagFile(path) for path in filepaths]

# This piece of code is taken from merge_orbmag_mesh
# the main difference is that here ncroots[0] is replaced by
# the reader instance of the first OrbmagFile.
# This piece of code is taken from merge_orbmag_mesh. The main difference
# is that here ncroots[0] is replaced by the reader instance of the first OrbmagFile.
r0 = self.orb_files[0].r

self.mband = mband = r0.read_dimvalue('mband')
Expand Down Expand Up @@ -85,7 +98,7 @@ def __init__(self, filepaths):
# convert terms to Cart coords; formulae differ depending on term. First
# four were in k space, remaining two already in real space
if iterm < 4:
omtmp = ucvol*np.matmul(gprimd, r.read_variable('orbmag_mesh')[iterm,0:ndir,isppol,ikpt,iband])
omtmp = ucvol * np.matmul(gprimd, r.read_variable('orbmag_mesh')[iterm,0:ndir,isppol,ikpt,iband])
else:
omtmp = np.matmul(rprimd, r.read_variable('orbmag_mesh')[iterm,0:ndir,isppol,ikpt,iband])

Expand All @@ -102,7 +115,7 @@ def __init__(self, filepaths):
for isppol in range(nsppol):
for ikpt in range(nkpt):
# weight factor for each band and k point
trnrm = occ[0,ikpt,iband] * wtk[ikpt] / ucvol
trnrm = occ[isppol,ikpt,iband] * wtk[ikpt] / ucvol
for iterm in range(orbmag_nterms):
# sigij = \sigma_ij the 3x3 shielding tensor term for each sppol, kpt, and band
# additional ucvol factor converts to induced dipole moment (was dipole moment density,
Expand Down Expand Up @@ -141,7 +154,7 @@ def report_eigvals(self, report_type):
eigenvalues = -1.0E6 * np.real(eigvals(total_sigij))
isotropic = eigenvalues.sum() / 3.0
span = eigenvalues.max() - eigenvalues.min()
skew = 3.0 * (eigenvalues.sum() - eigenvalues.max() -eigenvalues.min() -isotropic) / span
skew = 3.0 * (eigenvalues.sum() - eigenvalues.max() - eigenvalues.min() - isotropic) / span

print('\nShielding tensor eigenvalues, ppm : ', eigenvalues)
print('Shielding tensor iso, span, skew, ppm : %6.2f %6.2f %6.2f \n' % (isotropic,span,skew))
Expand Down Expand Up @@ -185,25 +198,24 @@ def ngkpt_and_shifts(self) -> tuple:
if ngkpt is None:
raise ValueError("Non diagonal k-meshes are not supported!")
if len(shifts) > 1:
raise ValueError("Multiple k-shifts are not supported!")
raise ValueError("Multiple shifts are not supported!")

# check that all files have the same value.
if ifile == 0:
_ngkpt, _shifts = ngkpt, shifts
else:
# check that all files have the same value.
if np.any(ngkpt != _ngkpt) or np.any(shifts != _shifts):
raise ValueError(f"ORBMAG files have different values of ngkpt: {ngkpt=} {_ngkpt=} or shifts {shifts=}, {_shifts=}")

return ngkpt, shifts

def insert_inbox(self, spin: int, what: str) -> tuple:
def insert_inbox(self, what: str, spin: int) -> tuple:
"""
Return data, ngkpt, shifts where data is a
(mband, nkx, nky, nkz)) array with A_{pnk} with p the polaron index.
Return data, ngkpt, shifts where data is a (mband, nkx, nky, nkz)) array
Args:
spin: Spin index.
what:
what: Strings defining the quantity to insert in the box
"""
# Need to know the shape of the k-mesh.
ngkpt, shifts = self.ngkpt_and_shifts
Expand All @@ -227,7 +239,7 @@ def insert_inbox(self, spin: int, what: str) -> tuple:
eigenvalues = -1.0E6 * vals
value = eigenvalues.sum() / 3.0
#span = eigenvalues.max() - eigenvalues.min()
#skew = 3.0 * (eigenvalues.sum() - eigenvalues.max() -eigenvalues.min() -isotropic) / span
#skew = 3.0 * (eigenvalues.sum() - eigenvalues.max() - eigenvalues.min() - isotropic) / span
else:
raise ValueError(f"Invalid {what=}")

Expand All @@ -240,12 +252,13 @@ def get_bz_interpolator_spin(self, what: str, interp_method: str) -> list[BzRegu
Build and return an interpolator for
Args:
what: Strings defining the quantity to insert in the box.
interp_method: The method of interpolation. Supported are “linear”, “nearest”,
“slinear”, “cubic”, “quintic” and “pchip”.
"""
interp_spin = [None for _ in range(self.nsppol)]
for spin in range(self.nsppol):
data, ngkpt, shifts = self.insert_inbox(spin, what)
data, ngkpt, shifts = self.insert_inbox(what, spin)
interp_spin[spin] = BzRegularGridInterpolator(self.structure, shifts, data, method=interp_method)

return interp_spin
Expand All @@ -260,15 +273,17 @@ def plot_fatbands(self, ebands_kpath,
Plot fatbands.
Args:
ebands_kpath
what_list: string or list of strings defining the quantity to compute and show.
ebands_kpath: ElectronBands instance with energies along a k-path
or path to a netcdf file providing it.
what_list: string or list of strings defining the quantity to show.
ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
scale: Scaling factor for
scale: Scaling factor for fatbands.
marker_color: Color for markers
marker_edgecolor: Color for marker edges.
marker_edgecolor: Marker transparency.
fontsize: fontsize for legends and titles
interp_method: Interpolation method.
interp_method: The method of interpolation. Supported are “linear”, “nearest”,
“slinear”, “cubic”, “quintic” and “pchip”.
ax_mat: matrix of |matplotlib-Axes| or None if a new figure should be created.
"""
what_list = list_strings(what_list)
Expand All @@ -291,10 +306,8 @@ def plot_fatbands(self, ebands_kpath,
# Get interpolator for `what` quantity.
interp_spin = self.get_bz_interpolator_spin(what, interp_method)

#a2_max = np.max(np.abs(data[pstate]))**2
#a2_max = max((interp.get_max_abs_data2() for interp in interp_spin))
#scale *= 1. / a2_max
scale = 1e-1
abs_max = max((interp.get_max_abs_data() for interp in interp_spin))
scale *= 1. / abs_max

ymin, ymax = +np.inf, -np.inf
x, y, s = [], [], []
Expand All @@ -303,7 +316,7 @@ def plot_fatbands(self, ebands_kpath,
for ik, kpoint in enumerate(ebands_kpath.kpoints):
enes_n = ebands_kpath.eigens[spin, ik]
for e, a2 in zip(enes_n, interp_spin[spin].eval_kpoint(kpoint), strict=True):
x.append(ik); y.append(e); s.append(scale * a2)
x.append(ik); y.append(e); s.append(scale * abs(a2))
ymin, ymax = min(ymin, e), max(ymax, e)

# Plot electron bands with markers.
Expand All @@ -313,11 +326,23 @@ def plot_fatbands(self, ebands_kpath,
ebands_kpath.plot(ax=ax, points=points, show=False, linewidth=1.0)
ax.legend(loc="best", shadow=True, fontsize=fontsize)

e0 = self.orb_files[0].ebands.fermie
if ylims is None:
# Automatic ylims.
span = ymax - ymin
ymin -= 0.1 * span
ymax += 0.1 * span
ylims = [ymin - e0, ymax - e0]

for ax in ax_list:
set_axlims(ax, ylims, "y")

return fig


class OrbmagFile(AbinitNcFile, Has_Header, Has_Structure, Has_ElectronBands):
"""
Interface to the ORBMAG.nc file.
.. rubric:: Inheritance Diagram
.. inheritance-diagram::
Expand Down
2 changes: 1 addition & 1 deletion abipy/eph/vpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def plot_ank_with_ebands(self, ebands_kpath,
# Handle filtering
allowed = True
if filter_value:
energy_window = filter_value*1.1
energy_window = filter_value * 1.1
if pkind == "hole" and bm - e > energy_window:
allowed = False
elif pkind == "electron" and e - bm > energy_window:
Expand Down
16 changes: 8 additions & 8 deletions abipy/tools/numtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,20 +597,20 @@ def __init__(self, structure, shifts, datak, add_replicas=True, **kwargs):
from scipy.interpolate import RegularGridInterpolator
self._interpolators = [None] * self.ndat

self.abs_data2_min_idat = np.empty(self.ndat)
self.abs_data2_max_idat = np.empty(self.ndat)
self.abs_data_min_idat = np.empty(self.ndat)
self.abs_data_max_idat = np.empty(self.ndat)

for idat in range(self.ndat):
self._interpolators[idat] = RegularGridInterpolator((x, y, z), datak[idat], **kwargs)

# Compute min and max of |f|^2 to be used to scale markers in matplotlib plots.
self.abs_data2_min_idat[idat] = np.min(np.abs(datak[idat])) ** 2
self.abs_data2_max_idat[idat] = np.max(np.abs(datak[idat])) ** 2
# Compute min and max of |f| to be used to scale markers in matplotlib plots.
self.abs_data_min_idat[idat] = np.min(np.abs(datak[idat]))
self.abs_data_max_idat[idat] = np.max(np.abs(datak[idat]))

def get_max_abs_data2(self, idat=None) -> tuple:
def get_max_abs_data(self, idat=None) -> tuple:
if idat is None:
return self.abs_data2_max_idat.max()
return self.abs_data2_max_idat[idat]
return self.abs_data_max_idat.max()
return self.abs_data_max_idat[idat]

def eval_kpoint(self, frac_coords, cartesian=False, **kwargs) -> np.ndarray:
"""
Expand Down
46 changes: 26 additions & 20 deletions abipy/tools/tests/test_numtools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import numpy as np
import pytest

Expand Down Expand Up @@ -131,9 +132,12 @@ def test_api(self):
structure = Structure(lattice, species=["Si"], coords=[[0, 0, 0]])

# Test that BzRegularGridInterpolator initializes correctly.
ndat, nx, ny, nz = 2, 4, 5, 6
shifts = [0, 0, 0]
datak = np.zeros((1, 4, 4, 4)) # All zeros except one point # (ndat=1, nx=4, ny=4, nz=4)
datak[0, 2, 2, 2] = 1.0 # Set one known value
shape = (ndat, nx, ny ,nz) # (ndat=1, nx=4, ny=4, nz=4)
datak = np.zeros(shape) # All zeros except one point
datak[0, 2, 2, 2] = 1.0 # Set one known value
datak[1, 2, 2, 2] = 2.0 # Set one known value

# Multiple shifts should raise an error
with pytest.raises(ValueError, match="Multiple shifts are not supported"):
Expand All @@ -144,33 +148,35 @@ def test_api(self):
BzRegularGridInterpolator(structure, [0.1, 0.2, 0.3], datak)

interp = BzRegularGridInterpolator(structure, shifts, datak)
assert interp.ndat == 1
assert interp.ndat == ndat
assert interp.dtype == datak.dtype

# Test interpolation at known fractional coordinates."""
result = interp.eval_kpoint([0.5, 0.5, 0.5]) # Middle of the grid
values = interp.eval_kpoint([0.5, 0.5, 0.5]) # Middle of the grid

assert isinstance(result, np.ndarray)
assert result.shape == (1,)
assert 0 <= result[0] <= 1 # Ensure interpolation is reasonable
assert isinstance(values, np.ndarray)
assert values.shape == (ndat,)
assert 0 <= values[0] <= 1 # Ensure interpolation is reasonable
assert 0 <= values[1] <= 2 # Ensure interpolation is reasonable

# Test interpolation with Cartesian coordinates.
cart_coords = structure.reciprocal_lattice.matrix @ [0.5, 0.5, 0.5] # Convert to Cartesian

result = interp.eval_kpoint(cart_coords, cartesian=True)
assert isinstance(result, np.ndarray)
assert result.shape == (1,)
assert 0 <= result[0] <= 1
values = interp.eval_kpoint(cart_coords, cartesian=True)
assert isinstance(values, np.ndarray)
assert values.shape == (ndat,)
assert 0 <= values[0] <= 1
assert 0 <= values[1] <= 2

# Test that interpolation handles periodic boundaries correctly."""
result1 = interp.eval_kpoint([1.0, 1.0, 1.0])
result2 = interp.eval_kpoint([0.0, 0.0, 0.0])
values1 = interp.eval_kpoint([1.0, 1.0, 1.0])
values2 = interp.eval_kpoint([0.0, 0.0, 0.0])

np.testing.assert_allclose(result1, result2, atol=1e-6)
np.testing.assert_allclose(values1, values2, atol=1e-6)

# DEBUG SECTION
#ref_akn = np.abs(self.a_kn) ** 2
#for ik, kpoint in enumerate(self.kpoints):
# interp = a2_interp_state[0].eval_kpoint(kpoint)
# print("MAX (A2 ref - A2 interp) at qpoint", kpoint)
# print((np.abs(ref_akn[ik] - interp)).max())
# Compare interpolated and initial reference value.
for ix, iy, iz in itertools.product(range(nx), range(ny), range(nz)):
kpoint = [ix/nx, iy/ny, iz/nz]
values = interp.eval_kpoint(kpoint)
ref_values = datak[:, ix, iy, iz]
self.assert_almost_equal(values, ref_values)

0 comments on commit 4f401df

Please sign in to comment.