Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/vandeplaslab/koyo
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-migas committed Feb 17, 2025
2 parents ff35ded + 6e907a3 commit 85939cd
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
"tabulate",
"matplotlib",
"psutil",
"click-groups"
"click-groups",
]

# extras
Expand Down
11 changes: 6 additions & 5 deletions src/koyo/click.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,15 @@ def __init__(self, name):
"""Define the special help messages after instantiating a `click.Command()`."""
click.Command.__init__(self, name)

util_name = os.path.basename(sys.argv and sys.argv[0] or __file__)
util_name = os.path.basename((sys.argv and sys.argv[0]) or __file__)

if os.environ.get("CLICK_PLUGINS_HONESTLY"): # pragma no cover
icon = "\U0001f4a9"
else:
icon = "\u2020"

self.help = (
"\nWarning: entry point could not be loaded. Contact "
"its author for help.\n\n\b\n" + traceback.format_exc()
"\nWarning: entry point could not be loaded. Contact its author for help.\n\n\b\n" + traceback.format_exc()
)
self.short_help = icon + f" Warning: could not load plugin. See `{util_name} {self.name} --help`."

Expand Down Expand Up @@ -240,7 +239,7 @@ def get_args_from_option(option: ty.Callable) -> str:
class Parameter:
"""Parameter object."""

__slots__ = ["description", "args", "value"]
__slots__ = ["args", "description", "value"]

def __init__(self, description: str, args: str | ty.Callable | None, value: ty.Any | None = None):
self.description = description
Expand Down Expand Up @@ -466,6 +465,8 @@ def parse_extra_args(extra_args: tuple[str, ...] | None) -> dict[str, ty.Any]:
continue
name, value = parse_arg(arg, "")
if name in kwargs:
if name in kwargs and not isinstance(kwargs[name], list):
kwargs[name] = [kwargs[name]]
if isinstance(kwargs[name], list):
kwargs[name].append(value)
elif isinstance(kwargs[name], (str, int, float, bool)):
Expand Down Expand Up @@ -530,7 +531,7 @@ def set_env_args(**kwargs: ty.Any) -> None:

for name, value in kwargs.items():
os.environ[name] = str(value)
logger.trace(f"Set environment variable: {name}={value}")
logger.info(f"Set environment variable: {name}={value}")


def filter_kwargs(*allowed: str, **kwargs: ty.Any) -> dict[str, ty.Any]:
Expand Down
22 changes: 22 additions & 0 deletions src/koyo/fig_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from koyo.pdf_mixin import PDFMixin
from koyo.pptx_mixin import PPTXMixin
from koyo.system import IS_WIN
from koyo.typing import PathLike

if ty.TYPE_CHECKING:
Expand Down Expand Up @@ -341,6 +342,27 @@ def add_or_export_mpl_figure(
)
elif override or not Path(filename).exists():
face_color = face_color if face_color is not None else fig.get_facecolor()
filename = _ensure_filename_is_not_too_long(filename)
fig.savefig(filename, dpi=dpi, facecolor=face_color, bbox_inches=bbox_inches, **kwargs)
if close:
plt.close(fig)


def _ensure_filename_is_not_too_long(filenme: PathLike) -> Path:
"""Ensures on Windows that filename is not too long."""
if not IS_WIN:
return Path(filenme)
filename = Path(filenme)
n = len(str(filename))
if n > 250:
parent = filename.parent
suffix = filename.suffix
if n - len(str(parent)) > 250:
raise ValueError("Filename is too long")
max_length = 250 - len(str(parent)) - len(suffix)
if max_length > len(filename.stem):
max_length = len(filename.stem)
filename_ = filename.stem[0:max_length] + suffix
logger.trace(f"Filename is too long, truncating to {filename_} from {filename.name}")
return parent / filename_
return filename
8 changes: 4 additions & 4 deletions src/koyo/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,14 @@ def reshape_array_batch_from_coordinates(
"""
if array.ndim != 2:
raise ValueError("Expected 2-D array.")
n = array.shape[1]
n_features = array.shape[1]
dtype = np.float32 if np.isnan(fill_value) else array.dtype
im = np.full((n, *image_shape), fill_value=fill_value, dtype=dtype)
im = np.full((n_features, *image_shape), fill_value=fill_value, dtype=dtype)
try:
for i in range(n):
for i in range(n_features):
im[i, coordinates[:, 1] - offset, coordinates[:, 0] - offset] = array[:, i]
except IndexError:
for i in range(n):
for i in range(n_features):
im[i, coordinates[:, 0] - offset, coordinates[:, 1] - offset] = array[:, i]
return im

Expand Down
2 changes: 1 addition & 1 deletion src/koyo/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def _get_mosaic_dims(n: int, width: int, height: int, n_cols: int = 0) -> tuple[
if n_cols > n:
n_cols = n
n_rows = ceil(n / n_cols)
if n_rows > ceil(n / n_cols):
while n_rows > ceil(n / n_cols):
n_rows -= 1
return n_rows, n_cols, _width, _height

Expand Down
22 changes: 22 additions & 0 deletions src/koyo/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Affine transformation functions."""

from __future__ import annotations

import numpy as np


def transform_xy_coordinates(
xy: np.ndarray, *, yx_affine: np.ndarray | None = None, xy_affine: np.ndarray | None = None
) -> np.ndarray:
"""Transform xy coordinates using either yx or xy affine matrix."""
if xy_affine is None and yx_affine is None:
raise ValueError("Either xy_affine or yx_affine should be provided.")
if xy_affine is not None and yx_affine is not None:
raise ValueError("Only one of xy_affine or yx_affine should be provided.")
xy = np.hstack([xy, np.ones((xy.shape[0], 1))])
if yx_affine is not None:
xy = np.dot(xy, yx_affine.T)
if xy_affine is not None:
xy = np.dot(xy, xy_affine)
xy = xy[:, :2]
return xy
43 changes: 39 additions & 4 deletions src/koyo/visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def add_contours(
contours = {"": contours}
color = color if color is not None else plt.rcParams["text.color"]
for _key, contour in contours.items():
ax.plot(contour[:, 0], contour[:, 1], lw=line_width, color=color)
ax.plot(contour[:, 0], contour[:, 1], lw=line_width, color=color, zorder=100)


def _sort_contour_order(contours: dict[str, np.ndarray]) -> tuple[dict[str, tuple[int, int]], dict[str, np.ndarray]]:
Expand All @@ -233,27 +233,43 @@ def _sort_contour_order(contours: dict[str, np.ndarray]) -> tuple[dict[str, tupl
return order, contours


def _get_alternate_locations(contours: dict[str, np.ndarray]) -> dict[str, str]:
"""Get alternative locations."""
_, contours = _sort_contour_order(contours)
locations = {}
is_top = True
for key in contours:
locations[key] = "top" if is_top else "bottom"
is_top = not is_top
return locations


def add_contour_labels(
ax: plt.Axes,
contours: np.ndarray | dict[str, np.ndarray],
labels: str | dict[str, str],
font_size: int = 12,
color: str | None = None,
where: ty.Literal["top", "bottom", "alternate"] = "alternate",
locations: dict[str, str] | None = None,
) -> None:
"""Add labels to the contours."""
if contours is None:
return

y_offset = font_size
is_top = where == "top"
is_alt = where == "alternate"
if is_alt:
is_top = True

if contours is None:
return
locations = locations or {}
if isinstance(contours, np.ndarray):
contours = {"": contours}
if isinstance(labels, str):
labels = {"": labels}
if isinstance(locations, str):
locations = {"": locations}

# get min x, y for each contour
if where == "alternate":
Expand All @@ -266,12 +282,22 @@ def add_contour_labels(
xs = contour[:, 0]
xmin, xmax = xs.min(), xs.max()
x = xmin + (xmax - xmin) / 2
is_top = locations.get(key, "top" if is_top else "bottom") == "top"
# find vertical top of the contour
if is_top:
y = np.min(contour[:, 1])
else:
y = np.max(contour[:, 1]) + y_offset
ax.text(x, y, labels[key], fontsize=font_size, color=color, va="bottom" if is_top else "top", ha="center")
ax.text(
x,
y,
labels[key],
fontsize=font_size,
color=color,
va="bottom" if is_top else "top",
ha="center",
zorder=100,
)
new_y_ax = min(y, new_y_ax) # if is_top else min(y, new_y_ax)
if is_alt:
is_top = not is_top
Expand Down Expand Up @@ -620,6 +646,15 @@ def fix_style(style: str) -> str:
return style


def shorten_style(style: str) -> str:
"""Shorten style name."""
style = style.replace("seaborn-v0_8-", "s-") # seaborn style is too long
style = style.replace("seaborn-", "s-")
style = style.replace("seaborn", "s")
style = style.replace("dark_background", "dark")
return style


def _override_seaborn_heatmap_annotations():
from seaborn.matrix import _HeatMapper

Expand Down

0 comments on commit 85939cd

Please sign in to comment.