Skip to content

Commit

Permalink
Some imporvement to the scf cycle plotting in vpq (#314)
Browse files Browse the repository at this point in the history
* sync with abipy/develop

* some beautification for plot_scf_cycle in vpq
  • Loading branch information
ezhique authored Feb 25, 2025
1 parent 1c97368 commit f856fd3
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions abipy/eph/vpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Entry:
Entry(name="E_ph", latex=r'$E_{ph}$', info="Phonon part", utype="energy"),
Entry(name="elph", latex=r'$E_{elph}$', info="e-ph term", utype="energy"),
Entry(name="epsilon", latex=r"$\varepsilon$", info="Polaron eigenvalue", utype="energy"),
Entry(name="gsr", latex=r"$|\nabla|$", info="||gradient||", utype="gradient"),
Entry(name="grs", latex=r"$|\nabla|$", info="||gradient||", utype="gradient"),
]

# Convert to dictionary: name --> Entry
Expand Down Expand Up @@ -516,7 +516,7 @@ def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:
fontsize: fontsize for legends and titles
"""
# Build grid of plots.
nrows, ncols = self.nstates, 2
nrows, ncols = self.nstates, 3
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
sharex=False, sharey=False, squeeze=False)

Expand All @@ -527,34 +527,48 @@ def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:

for iax, ax in enumerate(ax_mat[pstate]):
# Create a twin Axes sharing the x-axis
#grad_ax = ax
grad_ax = ax.twinx()

for ilab, (name, entry) in enumerate(_ALL_ENTRIES.items()):
# Convert energies to Hartree. Keep gradient as it is.
ys = df[name].to_numpy()

_ax, energy_like = ax, True
if entry.utype == "gradient":
_ax, energy_like = grad_ax, False

if iax == 0:
# Plot values linear scale.
_ax.plot(xs, ys, label=entry.latex)
# plot only the gradient residual on the 1st panel
if iax == 0:
energy_like = False
ax.plot(xs, ys, label=entry.latex, c='k')
ax.set_yscale("log")
else:
# Plot deltas in logscale.
_ax.plot(xs, np.abs(ys - ys[-1]), label=entry.latex)
_ax.set_yscale("log")

_ax.set_xlim(1, niter)
if iax == 1:
energy_like =True
# Plot values linear scale.
ax.plot(xs, ys, label=entry.latex)
elif iax == 2:
energy_like = True
# Plot deltas in logscale.
# (remove the last point for pretty-plotting)
ax.plot(xs[:-1], np.abs(ys - ys[-1])[:-1], label=entry.latex)
ax.set_yscale("log")

ax.set_xlim(1, niter)

if energy_like:
ylabel = "Energy (eV)" if iax == 0 else r"$|\Delta E|$ (eV)"
ylabel = "Energy (eV)" if iax == 1 else r"$|\Delta E|$ (eV)"
else:
ylabel = r"$|\nabla|$" if iax == 0 else r"$|\Delta |\nabla|$"

set_grid_legend(_ax, fontsize, xlabel="Iteration") #, ylabel=ylabel)
_ax.set_ylabel(ylabel)
set_grid_legend(ax, fontsize, xlabel="Iteration") #, ylabel=ylabel)
ax.set_ylabel(ylabel)
ax.legend()

if pstate == 0:
if iax == 0:
ax.set_title("Gradient norm")
elif iax == 1:
ax.set_title("Energy terms")
else:
ax.set_title("Log-scale difference")


fig.suptitle(self.get_title(with_gaps=True))
fig.tight_layout()
Expand Down Expand Up @@ -996,7 +1010,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath, with_legend=True,
if fill_dos:
y_common = np.linspace(ymin, ymax+span*0.1, 100)
xleft = np.zeros_like(y_common)
# skip eDOS, fill only BDOS
# skip phDOS, fill only BDOS
for dos, c in zip(dos_lines[1:], colors[1:]):
for line in dos:
x_data, y_data = line.get_xdata(), line.get_ydata()
Expand Down

0 comments on commit f856fd3

Please sign in to comment.