Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Feb 21, 2025
1 parent fd285ff commit e42dba6
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 74 deletions.
109 changes: 38 additions & 71 deletions abipy/electrons/orbmag.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class OrbmagAnalyzer:
orban.plot_fatbands("bands_GSR.nc")
"""

def __init__(self, filepaths: list):
def __init__(self, filepaths: list, verbose=0):
"""
Args:
filepaths: List of filepaths to ORBMAG.nc files
Expand All @@ -67,7 +67,7 @@ def __init__(self, filepaths: list):
# @JOE TODO: One should store the direction in the netcdf file
# so that we can check that the files are given in the right order.
self.orb_files = [OrbmagFile(path) for path in filepaths]
self.verbose = 0
self.verbose = verbose

# This piece of code is taken from merge_orbmag_mesh. The main difference
# is that here ncroots[0] is replaced by the reader instance of the first OrbmagFile.
Expand All @@ -90,17 +90,18 @@ def __init__(self, filepaths: list):
# here ncroots have been replaced by a list of reader instances.
readers = [orb.r for orb in self.orb_files]

for iband in range(mband):
for isppol in range(nsppol):
for ikpt in range(nkpt):
for iterm in range(orbmag_nterms):
for idir, r in enumerate(readers):
for idir, r in enumerate(readers):
orbmag_mesh = r.read_value('orbmag_mesh')
for iterm in range(orbmag_nterms):
for isppol in range(nsppol):
for ikpt in range(nkpt):
for iband in range(mband):
# convert terms to Cart coords; formulae differ depending on term. First
# four were in k space, remaining two already in real space
if iterm < 4:
omtmp = ucvol * np.matmul(gprimd, r.read_variable('orbmag_mesh')[iterm,0:ndir,isppol,ikpt,iband])
omtmp = ucvol * np.matmul(gprimd, orbmag_mesh[iterm,0:ndir,isppol,ikpt,iband])
else:
omtmp = np.matmul(rprimd, r.read_variable('orbmag_mesh')[iterm,0:ndir,isppol,ikpt,iband])
omtmp = np.matmul(rprimd, orbmag_mesh[iterm,0:ndir,isppol,ikpt,iband])

self.orbmag_merge_mesh[iterm,idir,0:ndir,isppol,ikpt,iband] = omtmp

Expand All @@ -111,9 +112,9 @@ def __init__(self, filepaths: list):

self.orbmag_merge_sigij_mesh = np.zeros((orbmag_nterms, nsppol, nkpt, mband, ndir, ndir))

for iband in range(mband):
for isppol in range(nsppol):
for ikpt in range(nkpt):
for isppol in range(nsppol):
for ikpt in range(nkpt):
for iband in range(mband):
# weight factor for each band and k point
trnrm = occ[isppol,ikpt,iband] * wtk[ikpt] / ucvol
for iterm in range(orbmag_nterms):
Expand All @@ -123,13 +124,15 @@ def __init__(self, filepaths: list):
self.orbmag_merge_sigij_mesh[iterm,isppol,ikpt,iband,0:ndir,0:ndir] = \
ucvol * trnrm * self.orbmag_merge_mesh[iterm,0:ndir,0:ndir,isppol,ikpt,iband]

#def __str__(self) -> str:
# """String representation"""
# return self.to_string()
def __str__(self) -> str:
return self.to_string()

#def to_string(self, verbose: int = 0) -> str:
# lines = []; app = lines.append
# return "\n".join(lines)
def to_string(self, verbose: int = 0) -> str:
"""String representation with verbosity level verbose"""
lines = []; app = lines.append
for orb in self.orb_files:
app(orb.to_string(verbose=verbose))
return "\n".join(lines)

def __enter__(self):
return self
Expand All @@ -152,16 +155,20 @@ def structure(self) -> Structure:
raise RuntimeError("ORBMAG.nc files have different structures")
return structure

@lazy_property
def has_timrev(self) -> bool:
"""True if time-reversal symmetry is used in the BZ sampling."""
has_timrev = self.orb_files[0].ebands.has_timrev
if any(orb_file.ebands.has_timrev != has_timrev for orb_file in self.orb_files[1:]):
raise RuntimeError("ORBMAG.nc files have different values of timrev")
#@lazy_property
#def has_timrev(self) -> bool:
# """True if time-reversal symmetry is used in the BZ sampling."""
# has_timrev = self.orb_files[0].ebands.has_timrev
# if any(orb_file.ebands.has_timrev != has_timrev for orb_file in self.orb_files[1:]):
# raise RuntimeError("ORBMAG.nc files have different values of timrev")

@print_options_decorator(precision=2, suppress=True)
def report_eigvals(self, report_type) -> None:
"""
FIXME
Args:
report_type
"""
#np.set_printoptions(precision=2)

Expand Down Expand Up @@ -268,42 +275,10 @@ def insert_inbox(self, what: str, spin: int) -> tuple:
for iband in range(self.mband):
for ikpt, k_inds in zip(range(self.nkpt), k_indices, strict=True):
ix, iy, iz = k_inds
value = self.get_value(what, spin, ikpt, iband)
data[iband, ix, iy, iz] = value
data[iband, ix, iy, iz] = self.get_value(what, spin, ikpt, iband)

return data, ngkpt, shifts

def get_skw_interpolator(self, what: str, lpratio: int, filter_params: None):
"""
"""
orb = self.orb_files[0]
ebands = orb.ebands.kpoints
kpoints = orb.ebands.kpoints

# Get symmetries from abinit spacegroup (read from file).
if (abispg := self.structure.abi_spacegroup) is None:
abispg = self.structure.spgset_abi_spacegroup(has_timerev=self.has_timrev)
fm_symrel = [s for (s, afm) in zip(abispg.symrel, abispg.symafm, strict=True) if afm == 1]

from abipy.core.skw import SkwInterpolator
cell = (self.structure.lattice.matrix, self.structure.frac_coords, self.structure.atomic_numbers)

interp_spin = [None for _ in range(self.nsppol)]

for spin in range(self.nsppol):
values_kb = np.empty((self.nkpt, self.mband))
for ikpt in range(self.nkpt):
for iband in range(self.mband):
values_kb[ikpt, iband] = self.get_value(what, spin, ikpt, band)

skw = SkwInterpolator(lpratio, kpoints.frac_coords, self.eigens[:,:,bstart:bstop], ebands.fermie, ebands.nelect,
cell, fm_symrel, self.has_timrev,
filter_params=filter_params, verbose=self.verbose)
interp_spin[spin] = skw
#skw.eval_sk(spin, kpt, der1=None, der2=None) -> np.ndarray:

return interp_spin

@lazy_property
def has_full_bz(self) -> bool:
"""True if the list of k-points cover the full BZ."""
Expand Down Expand Up @@ -333,23 +308,15 @@ def get_bz_interpolator_spin(self, what: str, interp_method: str) -> list[BzRegu

interp_spin = [None for _ in range(self.nsppol)]

if self.has_full_bz:
for spin in range(self.nsppol):
data, ngkpt, shifts = self.insert_inbox(what, spin)
interp_spin[spin] = BzRegularGridInterpolator(self.structure, shifts, data, method=interp_method)
else:
if not self.has_full_bz:
raise NotImplementedError("k-points must cover the full BZ.")
ngkpt, shifts = self.ngkpt_and_shifts
orb = self.orb_files[0]
ibz = orb.kpoints.frac_coords
bz2ibz = map_grid2ibz(self.structure, ibz, ngkpt, shifts, self.has_timrev, pbc=True)

# Compute values in the IBZ
#self.get_value(self, what: str, spin: int, ikpt: int, band: int) -> float:
# Reconstruct BZ from IBZ

for spin in range(self.nsppol):
data, ngkpt, shifts = self.insert_inbox(what, spin)
interp_spin[spin] = BzRegularGridInterpolator(self.structure, shifts, data, method=interp_method)
return interp_spin


@add_fig_kwargs
def plot_fatbands(self, ebands_kpath,
what_list="isotropic",
Expand Down Expand Up @@ -439,7 +406,7 @@ class OrbmagFile(AbinitNcFile, Has_Header, Has_Structure, Has_ElectronBands):
Interface to the ORBMAG.nc file.
.. rubric:: Inheritance Diagram
.. inheritance-diagram::
.. inheritance-diagram:: OrbmagFile
"""

@classmethod
Expand Down
35 changes: 32 additions & 3 deletions abipy/flowtk/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def _execute(self, workdir, with_mpirun=False, exec_args=None) -> int:
)

# Write the script.
script_file = os.path.join(workdir, "run" + self.name + ".sh")
with open(script_file, "w") as fh:
script_file = os.path.join(workdir, "run_" + self.name + ".sh")
with open(script_file, "wt") as fh:
fh.write(script)
os.chmod(script_file, 0o740)

Expand Down Expand Up @@ -371,6 +371,35 @@ def run(self, nc_paths: list[str], workdir=None) -> int:
print(self.stdout_data)
print("stderr:")
print(self.stderr_data)
raise RuntimeError("Error while running lruj in %s" % workdir)
raise RuntimeError(f"Error while running lruj in {workdir}")

return retcode


class Abitk(ExecWrapper):
"""
Wraps the abitk Fortran executable.
"""
_name = "abitk"

stdin_fname = None

def run(self, exec_args: list, workdir=None) -> int:
"""
Execute abitk inside directory `workdir`.
"""
workdir = get_workdir(workdir)
#print("workdir", workdir)

self.stdout_fname, self.stderr_fname = \
map(os.path.join, 2 * [workdir], ["abitk.stdout", "abitk.stderr"])

retcode = self.execute(workdir, exec_args=exec_args)
if retcode != 0:
print("stdout:")
print(self.stdout_data)
print("stderr:")
print(self.stderr_data)
raise RuntimeError(f"Error while running abitk in {workdir}")

return retcode

0 comments on commit e42dba6

Please sign in to comment.