Skip to content

Commit

Permalink
Compute colors based on sign
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Feb 20, 2025
1 parent adddfb7 commit fd285ff
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 80 deletions.
3 changes: 1 addition & 2 deletions abipy/electrons/ebands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2261,8 +2261,7 @@ def plot(self, spin=None, band_range=None, klabels=None, e0="fermie", ax=None, y
if ylims is None:
set_axlims(ax, (-mgap - 5, +mgap + 5), "y")

gaps_string = self.get_gaps_string()
if gaps_string:
if gaps_string := self.get_gaps_string():
ax.set_title(gaps_string, fontsize=fontsize)

if max_phfreq is not None and (self.mband > self.nspinor * self.nelect // 2):
Expand Down
37 changes: 26 additions & 11 deletions abipy/electrons/orbmag.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Activated at the end of the with statement. It automatically closes all the files."""
self.close()

def close(self):
def close(self) -> None:
"""Close all the files."""
for orb in self.orb_files:
orb.close()

@lazy_property
def structure(self):
def structure(self) -> Structure:
"""Structure object."""
# Perform consistency check
structure = self.orb_files[0].structure
Expand All @@ -153,14 +153,14 @@ def structure(self):
return structure

@lazy_property
def has_timrev(self):
def has_timrev(self) -> bool:
"""True if time-reversal symmetry is used in the BZ sampling."""
has_timrev = self.orb_files[0].ebands.has_timrev
if any(orb_file.ebands.has_timrev != has_timrev for orb_file in self.orb_files[1:]):
raise RuntimeError("ORBMAG.nc files have different values of timrev")

@print_options_decorator(precision=2, suppress=True)
def report_eigvals(self, report_type):
def report_eigvals(self, report_type) -> None:
"""
"""
#np.set_printoptions(precision=2)
Expand Down Expand Up @@ -249,8 +249,8 @@ def insert_inbox(self, what: str, spin: int) -> tuple:
Return data, ngkpt, shifts where data is a (mband, nkx, nky, nkz)) array
Args:
spin: Spin index.
what: Strings defining the quantity to insert in the box
spin: Spin index.
"""
# Need to know the shape of the k-mesh.
ngkpt, shifts = self.ngkpt_and_shifts
Expand Down Expand Up @@ -339,6 +339,14 @@ def get_bz_interpolator_spin(self, what: str, interp_method: str) -> list[BzRegu
interp_spin[spin] = BzRegularGridInterpolator(self.structure, shifts, data, method=interp_method)
else:
raise NotImplementedError("k-points must cover the full BZ.")
ngkpt, shifts = self.ngkpt_and_shifts
orb = self.orb_files[0]
ibz = orb.kpoints.frac_coords
bz2ibz = map_grid2ibz(self.structure, ibz, ngkpt, shifts, self.has_timrev, pbc=True)

# Compute values in the IBZ
#self.get_value(self, what: str, spin: int, ikpt: int, band: int) -> float:
# Reconstruct BZ from IBZ

return interp_spin

Expand All @@ -349,12 +357,12 @@ def plot_fatbands(self, ebands_kpath,
marker_alpha=0.5, fontsize=12, interp_method="linear",
ax_mat=None, **kwargs) -> Figure:
"""
Plot fatbands ...
Plot fatbands FIXME ...
Args:
ebands_kpath: ElectronBands instance with energies along a k-path
or path to a netcdf file providing it.
what_list: string or list of strings defining the quantity to show.
what_list: string or list of strings defining the quantities to show.
ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
scale: Scaling factor for fatbands.
marker_color: Color for markers
Expand Down Expand Up @@ -394,14 +402,21 @@ def plot_fatbands(self, ebands_kpath,
for spin in range(self.nsppol):
for ik, kpoint in enumerate(ebands_kpath.kpoints):
enes_n = ebands_kpath.eigens[spin, ik]
for e, a2 in zip(enes_n, interp_spin[spin].eval_kpoint(kpoint), strict=True):
x.append(ik); y.append(e); s.append(scale * abs(a2))
for e, value in zip(enes_n, interp_spin[spin].eval_kpoint(kpoint), strict=True):
x.append(ik); y.append(e); s.append(scale * value)
ymin, ymax = min(ymin, e), max(ymax, e)

# Plot electron bands with markers.
points = Marker(x, y, s, color=marker_color, edgecolors=marker_edgecolor,

# Compute colors based on sign (e.g., red for positive, blue for negative)
y = np.array(y)
c = np.where(y >= 0, "red", "blue")

points = Marker(x, y, s,
c=c,
#color=marker_color,edgecolors=marker_edgecolor,
alpha=marker_alpha, label=what)

# Plot electron bands with markers.
ebands_kpath.plot(ax=ax, points=points, show=False, linewidth=1.0)
ax.legend(loc="best", shadow=True, fontsize=fontsize)

Expand Down
102 changes: 54 additions & 48 deletions abipy/eph/vpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,19 @@ def params(self) -> dict:
ngkpt, shifts = ksampling.mpdivs, ksampling.shifts
nkbz = np.prod(ngkpt)

od = dict([
("nkbz", nkbz),
("ngkpt", ngkpt),
("invsc_size", 1.0 / (nkbz * ((abu.Ang_Bohr * self.structure.lattice.volume) ** (1/3)))),
("frohl_ntheta", r.frohl_ntheta),
])
return od
d = dict(
nkbz=nkbz,
ngkpt=ngkpt,
nksmall=min(ngkpt),
cbrt_ngkpt=np.cbrt(np.prod(ngkpt)),
frohl_ntheta=r.frohl_ntheta,
#("invsc_size", 1.0 / (nkbz * ((abu.Ang_Bohr * self.structure.lattice.volume) ** (1/3)))),
)

#keys = ["e_pol", "e_el", "e_ph", "e_elph", "eps"]
#energies = np.array(scf[nstep - 1], dtype=float) * HA2EV

return d

def __str__(self) -> str:
return self.to_string()
Expand Down Expand Up @@ -1113,15 +1119,15 @@ class VpqRobot(Robot, RobotWithEbands):
def __str__(self) -> str:
return self.to_string()

def to_string(self, verbose=0) -> str:
def to_string(self, verbose: int = 0) -> str:
"""String representation with verbosiy level ``verbose``."""
lines = []; app = lines.append
df = self.get_final_results_df()
lines.append(str(df))

return "\n".join(lines)

def get_final_results_df(self, spin=None, sortby=None, with_params: bool = True) -> pd.DataFrame:
def get_final_results_df(self, spin: int = None, sortby: str = None, with_params: bool = True) -> pd.DataFrame:
"""
Return dataframe with the last iteration for all polaronic states.
NB: Energies are in eV.
Expand Down Expand Up @@ -1160,53 +1166,53 @@ def get_final_results_df(self, spin=None, sortby=None, with_params: bool = True)
# ax=None, fontsize=8, **kwargs)
# return fig

@add_fig_kwargs
def plot_kconv(self, colormap="jet", fontsize=12, **kwargs) -> Figure:
"""
Plot the convergence of the results wrt to the k-point sampling.
Args:
colormap: matplotlib color map.
fontsize: fontsize for legends and titles
"""
nsppol = self.getattr_alleq("nsppol")
#@add_fig_kwargs
#def plot_kconv(self, colormap="jet", fontsize=12, **kwargs) -> Figure:
# """
# Plot the convergence of the results wrt to the k-point sampling.

# Build grid of plots.
nrows, ncols = len(_ALL_ENTRIES), nsppol
ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=False, squeeze=False)
cmap = plt.get_cmap(colormap)
for spin in range(nsppol):
df = self.get_final_results_df(spin=spin, sortby=None)
xs = df["invsc_size"]
xvals = np.linspace(0.0, 1.1 * xs.max(), 100)

for ix, ylabel in enumerate(_ALL_ENTRIES):
ax = ax_mat[ix, spin]
ys = df[ylabel]

# Plot ab-initio points.
ax.scatter(xs, ys, color="red", marker="o")

# Plot fit using the first nn points.
for nn in range(1, len(xs)):
color = cmap((nn - 1) / len(xs))
p = np.poly1d(np.polyfit(xs[:nn+1], ys[:nn+1], deg=1))
ax.plot(xvals, p(xvals), color=color, ls="--")

xlabel = "Inverse supercell size (Bohr$^-1$)" if ix == len(_ALL_ENTRIES) - 1 else None
set_grid_legend(ax, fontsize, xlabel=xlabel, ylabel=f"{ylabel} (eV)", legend=False)
ax.tick_params(axis='x', color='black', labelsize='20', pad=5, length=5, width=2)
# Args:
# colormap: matplotlib color map.
# fontsize: fontsize for legends and titles
# """
# nsppol = self.getattr_alleq("nsppol")

# # Build grid of plots.
# nrows, ncols = len(_ALL_ENTRIES), nsppol
# ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
# sharex=True, sharey=False, squeeze=False)
# cmap = plt.get_cmap(colormap)
# for spin in range(nsppol):
# df = self.get_final_results_df(spin=spin, sortby=None)
# xs = df["invsc_size"]
# xvals = np.linspace(0.0, 1.1 * xs.max(), 100)

# for ix, ylabel in enumerate(_ALL_ENTRIES):
# ax = ax_mat[ix, spin]
# ys = df[ylabel]

# # Plot ab-initio points.
# ax.scatter(xs, ys, color="red", marker="o")

# # Plot fit using the first nn points.
# for nn in range(1, len(xs)):
# color = cmap((nn - 1) / len(xs))
# p = np.poly1d(np.polyfit(xs[:nn+1], ys[:nn+1], deg=1))
# ax.plot(xvals, p(xvals), color=color, ls="--")

# xlabel = "Inverse supercell size (Bohr$^-1$)" if ix == len(_ALL_ENTRIES) - 1 else None
# set_grid_legend(ax, fontsize, xlabel=xlabel, ylabel=f"{ylabel} (eV)", legend=False)
# ax.tick_params(axis='x', color='black', labelsize='20', pad=5, length=5, width=2)

return fig
# return fig

def yield_figs(self, **kwargs): # pragma: no cover
"""
This function *generates* a predefined list of matplotlib figures with minimal input from the user.
Used in abiview.py to get a quick look at the results.
"""
#yield self.plot_scf_cycle(show=False)
yield self.plot_kconv()
yield self.plot_scf_cycle(show=False)
#yield self.plot_kconv()

def write_notebook(self, nbpath=None) -> str:
"""
Expand Down
26 changes: 7 additions & 19 deletions abipy/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,12 +1056,11 @@ def plot(self, cplx_mode="abs", colormap="jet", fontsize=8, **kwargs) -> Figure:
return fig


#TODO Rename it to ScatterData?
class Marker:
"""
Stores the position and the size of the marker.
Stores the position and the size of the markers.
A marker is a list of tuple(x, y, s) where x, and y are the position
in the graph and s is the size of the marker.
in the plot and s is the size of the marker.
Used for plotting purpose e.g. QP data, energy derivatives...
Example::
Expand All @@ -1071,33 +1070,22 @@ class Marker:
"""

def __init__(self, x, y, s, **scatter_kwargs):
#marker: str = "o", color: str = "y", alpha: float = 1.0, label=None, self.edgecolors=None):
self.x, self.y, self.s = np.array(x), np.array(y), np.array(s)

if len(self.x) != len(self.y):
raise ValueError("len(self.x) != len(self.y)")
raise ValueError(f"{len(self.x)=} != {len(self.y)=}")

if len(self.y) != len(self.s):
raise ValueError("len(self.y) != len(self.s)")
raise ValueError(f"{len(self.y)=} != {len(self.s)=}")

#self.marker = marker
#self.color = color
#self.alpha = alpha
#self.label = label
#self.edgecolors = edgecolors
self.scatter_kwargs = scatter_kwargs

# Step 1: Normalize sizes to a suitable range for plotting
#min_size = 10 # Minimum size for points
#max_size = 100 # Maximum size for points
#normalized_s = min_size + (max_size - min_size) * (self.s - np.min(self.s)) / (np.max(self.s) - np.min(self.s))
#self.s = normalized_s

def __bool__(self):
return bool(len(self.s))

__nonzero__ = __bool__

def posneg_marker(self) -> tuple[Marker, Marker]:
def posneg_marker(self, threshold: float = 0.0) -> tuple[Marker, Marker]:
"""
Split data into two sets: the first one contains all the points with positive size.
The first set contains all the points with negative size.
Expand All @@ -1106,7 +1094,7 @@ def posneg_marker(self) -> tuple[Marker, Marker]:
neg_x, neg_y, neg_s = [], [], []

for x, y, s in zip(self.x, self.y, self.s):
if s >= 0.0:
if s >= threshold:
pos_x.append(x)
pos_y.append(y)
pos_s.append(s)
Expand Down

0 comments on commit fd285ff

Please sign in to comment.