Skip to content

Commit

Permalink
adding QA plots for error propagation validation
Browse files Browse the repository at this point in the history
  • Loading branch information
ajmejia committed Feb 5, 2025
1 parent 7aa706b commit 0e05b4f
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 15 deletions.
69 changes: 66 additions & 3 deletions python/lvmdrp/core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,76 @@ def plot_detrend(ori_image, det_image, axs, mbias=None, mdark=None, labels=False

# add labels if requested
if labels:
fig = axs[0].get_figure()
fig.supxlabel(f"counts ({unit})")
fig.supylabel("#")
axs[2].set_xlabel(f"counts ({unit})")
axs[3].set_xlabel(f"Counts ({unit})")
axs[0].set_ylabel("#")
axs[2].set_ylabel("#")

return axs


def plot_error(frame, axs, counts_threshold=(1000, 20000), ref_value=1.0, labels=False):
"""Create plot to validate Poisson error propagation
It takes the given frame data and compares sqrt(data) / error to a given
reference value. Optionally a 3-tuple of quantiles can be given for the
reference value.
Parameters
----------
frame : lvmdrp.core.image.Image|lvmdrp.core.rss.RSS
2D or RSS frame containing data and error attributes
axs : plt.Axes
Axes where to make the plots
counts_threshold : tuple[int], optional
levels of counts above/below which the Poisson statistic holds, by default (1000, 20000)
ref_value : float|tuple[float], optional
Reference value(s) expected for the sqrt(data) / error ratio, by default 1.0
labels : bool, optional
Whether to add titles or not to the axes, by default False
"""

unit = frame._header["BUNIT"]

if isinstance(ref_value, (float, int)):
mu = ref_value
sig1 = sig2 = None
elif isinstance(ref_value, (tuple, list, np.ndarray)) and len(ref_value) == 3:
sig1, mu, sig2 = sorted(ref_value)
else:
raise ValueError(f"Wrong value for {ref_value = }, expected `float` or `3-tuple` for percentile levels")

data = frame._data.copy()
error = frame._error.copy()

pcut = (data >= counts_threshold[0])&(data<=counts_threshold[1])
data[~pcut] = np.nan
error[~pcut] = np.nan

n_pixels = pcut.sum()
median_ratio = np.nanmedian(np.sqrt(np.nanmedian(data, axis=0))/np.nanmedian(error, axis=0))

xs = data[pcut]
ys = np.sqrt(xs) / error[pcut]

axs[0].plot(xs, ys, ".", ms=4, color="tab:blue")
axs[0].axhline(mu, ls="--", lw=1, color="0.2")
axs[1].hist(ys, color="tab:blue", bins=500, range=(mu*0.9, mu*1.1), orientation="horizontal")
if sig1 is not None and sig2 is not None:
axs[0].axhspan(sig1, sig2, lw=0, color="0.2", alpha=0.2)
axs[1].axhspan(sig1, sig2, lw=0, color="0.2", alpha=0.2)
axs[1].axhline(mu, ls="--", lw=1, color="0.2")

axs[0].set_ylim(mu*0.9, mu*1.1)
axs[1].set_ylim(mu*0.9, mu*1.1)

if labels:
axs[0].set_title(f"{n_pixels = } | {median_ratio = :.2f} | {mu = :.2f}", loc="left")
axs[0].set_xlabel(f"Counts ({unit})")
axs[1].set_xlabel("#")
axs[0].set_ylabel(r"$\sqrt{\mathrm{Counts}} / \mathrm{Error}$")


def plot_wavesol_residuals(fiber, ref_waves, lines_pixels, poly_cls, coeffs, ax=None, labels=False):
"""Plot residuals in wavelength polynomial fitting
Expand Down
38 changes: 27 additions & 11 deletions python/lvmdrp/functions/imageMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
glueImages,
loadImage,
)
from lvmdrp.core.plot import plt, create_subplots, plot_detrend, plot_strips, plot_image_shift, plot_fiber_thermal_shift, save_fig
from lvmdrp.core.plot import plt, create_subplots, plot_detrend, plot_error, plot_strips, plot_image_shift, plot_fiber_thermal_shift, save_fig
from lvmdrp.core.rss import RSS
from lvmdrp.core.spectrum1d import Spectrum1D, _spec_from_lines, _cross_match
from lvmdrp.core.tracemask import TraceMask
Expand Down Expand Up @@ -2747,6 +2747,16 @@ def extract_spectra(
rss.add_header_comment(f"{in_model}, fiber model used for {camera}")
rss.add_header_comment(f"{in_acorr}, fiber aperture correction used for {camera}")

# create error propagation plot
fig = plt.figure(figsize=(15, 5), layout="constrained")
gs = GridSpec(1, 14, figure=fig)

ax_1 = fig.add_subplot(gs[0, :-4])
ax_2 = fig.add_subplot(gs[0, -4:])

plot_error(frame=rss, axs=[ax_1, ax_2], counts_threshold=(3000, 60000), labels=True)
save_fig(fig, product_path=out_rss, to_display=display_plots, figure_path="qa", label="extracted_error")

# save extracted RSS
log.info(f"writing extracted spectra to {os.path.basename(out_rss)}")
rss.writeFitsData(out_rss)
Expand Down Expand Up @@ -3873,16 +3883,22 @@ def detrend_frame(
# show plots
log.info("plotting results")
# detrending process
fig, axs = create_subplots(
to_display=display_plots,
nrows=2,
ncols=2,
figsize=(15, 15),
sharex=True,
sharey=True,
)
plt.subplots_adjust(wspace=0.15, hspace=0.1)
plot_detrend(ori_image=org_img, det_image=detrended_img, axs=axs, mbias=mbias_img, mdark=mdark_img, labels=True)
fig = plt.figure(figsize=(15, 10), layout="constrained")
gs = GridSpec(3, 14, figure=fig)

ax1 = fig.add_subplot(gs[0, :7])
ax2 = fig.add_subplot(gs[0, 7:], sharex=ax1, sharey=ax1)
ax3 = fig.add_subplot(gs[1, :7], sharex=ax1, sharey=ax1)
ax4 = fig.add_subplot(gs[1, 7:], sharex=ax1, sharey=ax1)
ax1.tick_params(labelbottom=False)
ax2.tick_params(labelbottom=False)
ax2.tick_params(labelleft=False)
ax4.tick_params(labelleft=False)
ax_1 = fig.add_subplot(gs[2, :-4])
ax_2 = fig.add_subplot(gs[2, -4:], sharey=ax_1)
plot_detrend(ori_image=org_img, det_image=detrended_img, axs=[ax1, ax2, ax3, ax4], mbias=mbias_img, mdark=mdark_img, labels=True)
# Poisson error
plot_error(frame=detrended_img, axs=[ax_1, ax_2], labels=True)
save_fig(
fig,
product_path=out_image,
Expand Down
16 changes: 15 additions & 1 deletion python/lvmdrp/functions/rssMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from lvmdrp.core.image import loadImage
from lvmdrp.core.passband import PassBand
from lvmdrp.core.plot import (plt, create_subplots, save_fig,
plot_error,
plot_wavesol_coeffs, plot_wavesol_residuals,
plot_wavesol_spec, plot_wavesol_wave,
plot_wavesol_lsf)
Expand Down Expand Up @@ -1151,7 +1152,7 @@ def correctPixTable_drp(
@skip_if_drpqual_flags(["BADTRACE", "EXTRACTBAD"], "in_rss")
def resample_wavelength(in_rss: str, out_rss: str, method: str = "linear",
wave_range: Tuple[float,float] = None, wave_disp: float = None,
convert_to_density: bool = False) -> RSS:
convert_to_density: bool = False, display_plots: bool = False) -> RSS:
"""Resamples the RSS wavelength solutions to a common wavelength solution
A common wavelength solution is computed for the RSS by resampling the
Expand All @@ -1177,6 +1178,8 @@ def resample_wavelength(in_rss: str, out_rss: str, method: str = "linear",
The "optimal" dispersion will be used if the parameter is empty.
convert_to_density : string of boolean, optional with default: False
If True, the resampled RSS will be converted to density units.
display_plots : bool, optional
If True, display plots to screen, by default False
Returns
-------
Expand All @@ -1201,6 +1204,17 @@ def resample_wavelength(in_rss: str, out_rss: str, method: str = "linear",
log.info("resampling the spectra ...")
new_rss = rss.rectify_wave(wave_range=wave_range, wave_disp=wave_disp)

# create error propagation plot
fig = plt.figure(figsize=(15, 5), layout="constrained")
gs = gridspec.GridSpec(1, 14, figure=fig)

ax_1 = fig.add_subplot(gs[0, :-4])
ax_2 = fig.add_subplot(gs[0, -4:])
dlambda = numpy.gradient(rss._wave, axis=1)
ref_value = numpy.percentile(dlambda / numpy.sqrt(dlambda), q=[25, 50, 75])
plot_error(frame=new_rss, axs=[ax_1, ax_2], counts_threshold=(3000, 60000), ref_value=ref_value, labels=True)
save_fig(fig, product_path=out_rss, to_display=display_plots, figure_path="qa", label="resampled_error")

# write output RSS
log.info(f"writing resampled RSS to '{os.path.basename(out_rss)}'")
new_rss.writeFitsData(out_rss)
Expand Down

0 comments on commit 0e05b4f

Please sign in to comment.