Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-migas committed Oct 2, 2024
1 parent 2936409 commit 57c89eb
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/koyo/click.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,13 @@ def parse_extra_args(extra_args: tuple[str, ...] | None) -> dict[str, ty.Any]:
info_msg(f"Skipping argument {arg} as it does not contain '='")
continue
name, value = parse_arg(arg, "")
kwargs[name] = value
if name in kwargs:
if isinstance(kwargs[name], list):
kwargs[name].append(value)
else:
kwargs[name] = [kwargs[name], value]
else:
kwargs[name] = value
return kwargs


Expand Down
10 changes: 10 additions & 0 deletions src/koyo/fig_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _add_or_export_mpl_figure(
override: bool = False,
pdf: PdfPages | None = None,
pptx: Presentation | None = None,
close: bool = False,
**kwargs: ty.Any,
) -> None:
"""Export figure to file."""
Expand All @@ -128,6 +129,8 @@ def _add_or_export_mpl_figure(
self._add_mpl_figure_to_pptx(filename, fig, pptx=pptx, **kwargs)
elif override or not filename.exists():
fig.savefig(filename, dpi=dpi, facecolor=face_color, bbox_inches=bbox_inches, **kwargs)
if close:
plt.close(fig)

def _add_or_export_pil_image(
self,
Expand All @@ -138,6 +141,7 @@ def _add_or_export_pil_image(
override: bool = False,
pdf: PdfPages | None = None,
pptx: Presentation | None = None,
close: bool = False,
**kwargs: ty.Any,
) -> None:
"""Export PIL image to file."""
Expand All @@ -147,6 +151,8 @@ def _add_or_export_pil_image(
self._add_pil_image_to_pptx(filename, image, pptx=pptx, **kwargs)
elif override or not filename.exists():
image.save(filename, dpi=(dpi, dpi), format=fmt, **kwargs)
if close:
image.close()


class FigureExporter(FigureMixin):
Expand Down Expand Up @@ -222,6 +228,10 @@ def make_directory_if_not_exporting(self, directory: PathLike) -> Path:
directory.mkdir(parents=True, exist_ok=True)
return directory

def figure_exists(self, filename: Path, override: bool = False) -> bool:
"""Check whether figure exists."""
return (filename.exists() and not override) and not self.as_pptx_or_pdf

@property
def as_pptx_or_pdf(self) -> bool:
"""Check whether export is enabled."""
Expand Down
32 changes: 31 additions & 1 deletion src/koyo/visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,36 @@ def make_legend_handles(
return handles


def add_patches(
axs: ty.List, windows: ty.List[ty.Tuple[float, float]], colors: ty.Optional[ty.List] = None, alpha: float = 0.5
):
"""Add rectangular patches associated with the peak."""
from matplotlib.patches import Rectangle

if colors is None:
colors = plt.cm.get_cmap("viridis", len(axs)).colors
assert len(axs) == len(windows) == len(colors), "The number of axes does not match the number of windows."

for ax, (xmin, xmax), color in zip(axs, windows, colors):
ax.add_patch(Rectangle((xmin, 0), xmax - xmin, ax.get_ylim()[1], alpha=alpha, color=color))


def add_scalebar(ax, px_size: ty.Optional[float], color="k"):
"""Add scalebar to figure."""
try:
from matplotlib_scalebar.scalebar import ScaleBar
except ImportError:
return

scalebar = ScaleBar(px_size, "um", frameon=False, color=color, font_properties={"size": 20})
ax.add_artist(scalebar)
return scalebar


def add_legend(
fig: plt.Figure,
ax: plt.Axes,
legend_palettes: dict[str, dict[str, str]],
legend_palettes: dict[str, str] | dict[str, dict[str, str]],
fontsize: float = 14,
labelsize: float = 16,
x_pad: float = 0.01,
Expand Down Expand Up @@ -302,6 +328,10 @@ def _make_legend(n_col=1, loc="best"):
ncol=n_col,
)

# check if legend_palettes is a nested dictionary
if not all(isinstance(v, dict) for v in legend_palettes.values()):
legend_palettes = {"": legend_palettes}

n_palettes = len(legend_palettes) > 1
rend = fig.canvas.get_renderer()
x_offset = ax.get_tightbbox(rend).transformed(ax.transAxes.inverted()).xmax + x_pad
Expand Down

0 comments on commit 57c89eb

Please sign in to comment.