From 57c89eb7c6a3924c664650d2c86a15b2b6bcbf23 Mon Sep 17 00:00:00 2001 From: Lukasz Migas Date: Wed, 2 Oct 2024 12:10:02 +0200 Subject: [PATCH] Fixes --- src/koyo/click.py | 8 +++++++- src/koyo/fig_mixin.py | 10 ++++++++++ src/koyo/visuals.py | 32 +++++++++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/koyo/click.py b/src/koyo/click.py index cdc8a6c..c24faaf 100644 --- a/src/koyo/click.py +++ b/src/koyo/click.py @@ -450,7 +450,13 @@ def parse_extra_args(extra_args: tuple[str, ...] | None) -> dict[str, ty.Any]: info_msg(f"Skipping argument {arg} as it does not contain '='") continue name, value = parse_arg(arg, "") - kwargs[name] = value + if name in kwargs: + if isinstance(kwargs[name], list): + kwargs[name].append(value) + else: + kwargs[name] = [kwargs[name], value] + else: + kwargs[name] = value return kwargs diff --git a/src/koyo/fig_mixin.py b/src/koyo/fig_mixin.py index a26354e..311f22a 100644 --- a/src/koyo/fig_mixin.py +++ b/src/koyo/fig_mixin.py @@ -115,6 +115,7 @@ def _add_or_export_mpl_figure( override: bool = False, pdf: PdfPages | None = None, pptx: Presentation | None = None, + close: bool = False, **kwargs: ty.Any, ) -> None: """Export figure to file.""" @@ -128,6 +129,8 @@ def _add_or_export_mpl_figure( self._add_mpl_figure_to_pptx(filename, fig, pptx=pptx, **kwargs) elif override or not filename.exists(): fig.savefig(filename, dpi=dpi, facecolor=face_color, bbox_inches=bbox_inches, **kwargs) + if close: + plt.close(fig) def _add_or_export_pil_image( self, @@ -138,6 +141,7 @@ def _add_or_export_pil_image( override: bool = False, pdf: PdfPages | None = None, pptx: Presentation | None = None, + close: bool = False, **kwargs: ty.Any, ) -> None: """Export PIL image to file.""" @@ -147,6 +151,8 @@ def _add_or_export_pil_image( self._add_pil_image_to_pptx(filename, image, pptx=pptx, **kwargs) elif override or not filename.exists(): image.save(filename, dpi=(dpi, dpi), format=fmt, **kwargs) + if close: + image.close() class FigureExporter(FigureMixin): @@ -222,6 +228,10 @@ def make_directory_if_not_exporting(self, directory: PathLike) -> Path: directory.mkdir(parents=True, exist_ok=True) return directory + def figure_exists(self, filename: Path, override: bool = False) -> bool: + """Check whether figure exists.""" + return (filename.exists() and not override) and not self.as_pptx_or_pdf + @property def as_pptx_or_pdf(self) -> bool: """Check whether export is enabled.""" diff --git a/src/koyo/visuals.py b/src/koyo/visuals.py index a0da6af..cd308c9 100644 --- a/src/koyo/visuals.py +++ b/src/koyo/visuals.py @@ -257,10 +257,36 @@ def make_legend_handles( return handles +def add_patches( + axs: ty.List, windows: ty.List[ty.Tuple[float, float]], colors: ty.Optional[ty.List] = None, alpha: float = 0.5 +): + """Add rectangular patches associated with the peak.""" + from matplotlib.patches import Rectangle + + if colors is None: + colors = plt.cm.get_cmap("viridis", len(axs)).colors + assert len(axs) == len(windows) == len(colors), "The number of axes does not match the number of windows." + + for ax, (xmin, xmax), color in zip(axs, windows, colors): + ax.add_patch(Rectangle((xmin, 0), xmax - xmin, ax.get_ylim()[1], alpha=alpha, color=color)) + + +def add_scalebar(ax, px_size: ty.Optional[float], color="k"): + """Add scalebar to figure.""" + try: + from matplotlib_scalebar.scalebar import ScaleBar + except ImportError: + return + + scalebar = ScaleBar(px_size, "um", frameon=False, color=color, font_properties={"size": 20}) + ax.add_artist(scalebar) + return scalebar + + def add_legend( fig: plt.Figure, ax: plt.Axes, - legend_palettes: dict[str, dict[str, str]], + legend_palettes: dict[str, str] | dict[str, dict[str, str]], fontsize: float = 14, labelsize: float = 16, x_pad: float = 0.01, @@ -302,6 +328,10 @@ def _make_legend(n_col=1, loc="best"): ncol=n_col, ) + # check if legend_palettes is a nested dictionary + if not all(isinstance(v, dict) for v in legend_palettes.values()): + legend_palettes = {"": legend_palettes} + n_palettes = len(legend_palettes) > 1 rend = fig.canvas.get_renderer() x_offset = ax.get_tightbbox(rend).transformed(ax.transAxes.inverted()).xmax + x_pad