diff --git a/src/koyo/image.py b/src/koyo/image.py index adf7151..634793d 100644 --- a/src/koyo/image.py +++ b/src/koyo/image.py @@ -93,10 +93,19 @@ def reshape_array_from_coordinates( def flatten_array_from_coordinates(array: np.ndarray, coordinates: np.ndarray, offset: int = 0) -> np.ndarray: """Flatten array based on xy coordinates.""" - try: - return array[coordinates[:, 1] - offset, coordinates[:, 0] - offset] - except IndexError: - return array[coordinates[:, 0] - offset, coordinates[:, 1] - offset] + if array.ndim == 2: + try: + return array[coordinates[:, 1] - offset, coordinates[:, 0] - offset] + except IndexError: + return array[coordinates[:, 0] - offset, coordinates[:, 1] - offset] + else: + try: + res = array[:, coordinates[:, 1] - offset, coordinates[:, 0] - offset] + except IndexError: + res = array[:, coordinates[:, 0] - offset, coordinates[:, 1] - offset] + # need to swap axes + res = np.swapaxes(res, 0, 1) + return res def reshape_array_batch( diff --git a/src/koyo/rand.py b/src/koyo/rand.py index 70500c6..32137a4 100644 --- a/src/koyo/rand.py +++ b/src/koyo/rand.py @@ -1,3 +1,7 @@ +"""Random number utilities.""" + +from __future__ import annotations + from contextlib import contextmanager import numpy as np @@ -8,6 +12,14 @@ def get_random_seed() -> int: return np.random.randint(0, np.iinfo(np.int32).max - 1, 1)[0] +def get_random_state(n: int = 1) -> int | list[int]: + """Retrieve random state(s).""" + from random import randint + + state = [randint(0, 2**32 - 1) for _ in range(n)] + return state if n > 1 else state[0] + + @contextmanager def temporary_seed(seed: int, skip_if_negative_one: bool = False): """Temporarily set numpy seed.""" diff --git a/src/koyo/system.py b/src/koyo/system.py index 8cac455..6bc64ee 100644 --- a/src/koyo/system.py +++ b/src/koyo/system.py @@ -1,5 +1,6 @@ """System utilities.""" +import inspect import os import platform import sys @@ -81,3 +82,18 @@ def get_cli_path(name: str, env_key: str = "", default: str = "") -> str: if default: return default raise RuntimeError(f"Could not find '{name}' executable.") + + +def who_called_me() -> tuple[str, str, int]: + """Get the file name, function name, and line number of the caller.""" + # Get the current frame + current_frame = inspect.currentframe() + # Get the caller's frame + caller_frame = current_frame.f_back + + # Extract file name, line number, and function name + file_name = caller_frame.f_code.co_filename + line_number = caller_frame.f_lineno + function_name = caller_frame.f_code.co_name + print(f"Called from file: {file_name}, function: {function_name}, line: {line_number}") + return file_name, function_name, line_number