From c40fb691891c215f5452f5c810689c02c09750ed Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Fri, 20 Sep 2024 12:32:59 -0700 Subject: [PATCH] Added decorator to show way to process data (#31) --- .pre-commit-config.yaml | 2 +- albucore/__init__.py | 3 +- albucore/decorators.py | 51 +++++++++++++++++++++++ albucore/functions.py | 80 +++++++++++++++++++++++++++++++++++-- albucore/utils.py | 39 ------------------ tests/test_to_from_float.py | 20 ++++++++++ tests/test_utils.py | 69 +++++++++++++++++++++++++++++++- 7 files changed, 219 insertions(+), 45 deletions(-) create mode 100644 albucore/decorators.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e62de64..0bb3c15 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -59,7 +59,7 @@ repos: additional_dependencies: ["tomli"] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.6.5 + rev: v0.6.6 hooks: # Run the linter. - id: ruff diff --git a/albucore/__init__.py b/albucore/__init__.py index bfcc765..be31db0 100644 --- a/albucore/__init__.py +++ b/albucore/__init__.py @@ -1,4 +1,5 @@ -__version__ = "0.0.16" +__version__ = "0.0.17" +from .decorators import * from .functions import * from .utils import * diff --git a/albucore/decorators.py b/albucore/decorators.py new file mode 100644 index 0000000..10d44bf --- /dev/null +++ b/albucore/decorators.py @@ -0,0 +1,51 @@ +import sys +from functools import wraps +from typing import Callable + +import numpy as np + +from albucore.utils import MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS, P + +if sys.version_info >= (3, 10): + from typing import Concatenate +else: + from typing_extensions import Concatenate + + +def contiguous( + func: Callable[Concatenate[np.ndarray, P], np.ndarray], +) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: + """Ensure that input img is contiguous and the output array is also contiguous.""" + + @wraps(func) + def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: + # Ensure the input array is contiguous + img = np.require(img, requirements=["C_CONTIGUOUS"]) + # Call the original function with the contiguous input + result = func(img, *args, **kwargs) + # Ensure the output array is contiguous + if not result.flags["C_CONTIGUOUS"]: + return np.require(result, requirements=["C_CONTIGUOUS"]) + + return result + + return wrapped_function + + +def preserve_channel_dim( + func: Callable[Concatenate[np.ndarray, P], np.ndarray], +) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: + """Preserve dummy channel dim.""" + + @wraps(func) + def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: + shape = img.shape + result = func(img, *args, **kwargs) + if len(shape) == NUM_MULTI_CHANNEL_DIMENSIONS and shape[-1] == 1 and result.ndim == MONO_CHANNEL_DIMENSIONS: + return np.expand_dims(result, axis=-1) + + if len(shape) == MONO_CHANNEL_DIMENSIONS and result.ndim == NUM_MULTI_CHANNEL_DIMENSIONS: + return result[:, :, 0] + return result + + return wrapped_function diff --git a/albucore/functions.py b/albucore/functions.py index e8c2281..fc83b27 100644 --- a/albucore/functions.py +++ b/albucore/functions.py @@ -1,10 +1,12 @@ from __future__ import annotations -from typing import Literal +from functools import wraps +from typing import Any, Callable, Literal import cv2 import numpy as np +from albucore.decorators import contiguous, preserve_channel_dim from albucore.utils import ( MAX_OPENCV_WORKING_CHANNELS, MAX_VALUES_BY_DTYPE, @@ -13,11 +15,9 @@ ValueType, clip, clipped, - contiguous, convert_value, get_max_value, get_num_channels, - preserve_channel_dim, ) np_operations = {"multiply": np.multiply, "add": np.add, "power": np.power} @@ -570,6 +570,10 @@ def to_float_lut(img: np.ndarray, max_value: float | None = None) -> np.ndarray: def to_float(img: np.ndarray, max_value: float | None = None) -> np.ndarray: + if img.dtype == np.float64: + return img.astype(np.float32) + if img.dtype == np.float32: + return img if img.dtype == np.uint8: return to_float_lut(img, max_value) return to_float_numpy(img, max_value) @@ -620,6 +624,12 @@ def from_float(img: np.ndarray, target_dtype: np.dtype, max_value: float | None - For other input types, it falls back to a numpy-based implementation. - The function clips values to ensure they fit within the range of the target data type. """ + if target_dtype == np.float32: + return img + + if target_dtype == np.float64: + return img.astype(np.float32) + if img.dtype == np.float32: return from_float_opencv(img, target_dtype, max_value) @@ -652,3 +662,67 @@ def vflip_numpy(img: np.ndarray) -> np.ndarray: def vflip(img: np.ndarray) -> np.ndarray: return vflip_cv2(img) + + +def float32_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]: + """Decorator to ensure float32 input/output for image processing functions. + + This decorator converts the input image to float32 before passing it to the wrapped function, + and then converts the result back to the original dtype if it wasn't float32. + + Args: + func (Callable[..., np.ndarray]): The image processing function to be wrapped. + + Returns: + Callable[..., np.ndarray]: A wrapped function that handles float32 conversion. + + Example: + @float32_io + def some_image_function(img: np.ndarray) -> np.ndarray: + # Function implementation + return processed_img + """ + + @wraps(func) + def float32_wrapper(img: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray: + input_dtype = img.dtype + if input_dtype != np.float32: + img = to_float(img) + result = func(img, *args, **kwargs) + + return from_float(result, target_dtype=input_dtype) if input_dtype != np.float32 else result + + return float32_wrapper + + +def uint8_io(func: Callable[..., np.ndarray]) -> Callable[..., np.ndarray]: + """Decorator to ensure uint8 input/output for image processing functions. + + This decorator converts the input image to uint8 before passing it to the wrapped function, + and then converts the result back to the original dtype if it wasn't uint8. + + Args: + func (Callable[..., np.ndarray]): The image processing function to be wrapped. + + Returns: + Callable[..., np.ndarray]: A wrapped function that handles uint8 conversion. + + Example: + @uint8_io + def some_image_function(img: np.ndarray) -> np.ndarray: + # Function implementation + return processed_img + """ + + @wraps(func) + def uint8_wrapper(img: np.ndarray, *args: Any, **kwargs: Any) -> np.ndarray: + input_dtype = img.dtype + + if input_dtype != np.uint8: + img = from_float(img, target_dtype=np.uint8) + + result = func(img, *args, **kwargs) + + return to_float(result) if input_dtype != np.uint8 else result + + return uint8_wrapper diff --git a/albucore/utils.py b/albucore/utils.py index 84382c9..00aa68a 100644 --- a/albucore/utils.py +++ b/albucore/utils.py @@ -109,25 +109,6 @@ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.n return wrapped_function -def preserve_channel_dim( - func: Callable[Concatenate[np.ndarray, P], np.ndarray], -) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: - """Preserve dummy channel dim.""" - - @wraps(func) - def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: - shape = img.shape - result = func(img, *args, **kwargs) - if len(shape) == NUM_MULTI_CHANNEL_DIMENSIONS and shape[-1] == 1 and result.ndim == MONO_CHANNEL_DIMENSIONS: - return np.expand_dims(result, axis=-1) - - if len(shape) == MONO_CHANNEL_DIMENSIONS and result.ndim == NUM_MULTI_CHANNEL_DIMENSIONS: - return result[:, :, 0] - return result - - return wrapped_function - - def get_num_channels(image: np.ndarray) -> int: return image.shape[2] if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS else 1 @@ -151,26 +132,6 @@ def is_multispectral_image(image: np.ndarray) -> bool: return num_channels not in {1, 3} -def contiguous( - func: Callable[Concatenate[np.ndarray, P], np.ndarray], -) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: - """Ensure that input img is contiguous and the output array is also contiguous.""" - - @wraps(func) - def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: - # Ensure the input array is contiguous - img = np.require(img, requirements=["C_CONTIGUOUS"]) - # Call the original function with the contiguous input - result = func(img, *args, **kwargs) - # Ensure the output array is contiguous - if not result.flags["C_CONTIGUOUS"]: - return np.require(result, requirements=["C_CONTIGUOUS"]) - - return result - - return wrapped_function - - def convert_value(value: np.ndarray | float, num_channels: int) -> float | np.ndarray: """Convert a multiplier to a float / int or a numpy array. diff --git a/tests/test_to_from_float.py b/tests/test_to_from_float.py index 9566697..95b1816 100644 --- a/tests/test_to_from_float.py +++ b/tests/test_to_from_float.py @@ -324,3 +324,23 @@ def test_from_float_opencv_input_unchanged(dtype, channels): img_copy = img.copy() _ = from_float_opencv(img, dtype, max_value) np.testing.assert_array_equal(img, img_copy) + + +def test_to_float_returns_same_object_for_float32(): + float32_image = np.random.rand(10, 10, 3).astype(np.float32) + result = to_float(float32_image) + assert result is float32_image # Check if it's the same object + + +@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32]) +def test_to_float_from_float_roundtrip(dtype): + if dtype == np.float32: + original = np.random.rand(10, 10, 3).astype(dtype) + else: + original = np.random.randint(0, 256, (10, 10, 3)).astype(dtype) + + float_version = to_float(original) + roundtrip = from_float(float_version, dtype) + + assert roundtrip.dtype == dtype + np.testing.assert_allclose(original, roundtrip, rtol=1e-5, atol=1e-8) diff --git a/tests/test_utils.py b/tests/test_utils.py index c24ca45..422adb5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,9 @@ import numpy as np import pytest import cv2 -from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy, contiguous +from albucore.decorators import contiguous +from albucore.functions import float32_io, from_float, to_float, uint8_io +from albucore.utils import NPDTYPE_TO_OPENCV_DTYPE, clip, convert_value, get_opencv_dtype_from_numpy @pytest.mark.parametrize("input_img, dtype, expected", [ @@ -88,3 +90,68 @@ def test_contiguous_decorator(input_array): # Check if the content is correct (same as reversing the original array) expected_output = input_array[::-1, ::-1] np.testing.assert_array_equal(output_array, expected_output), "Output array content is not as expected" + + +# Sample functions to be wrapped +@float32_io +def dummy_float32_func(img): + return img * 2 + +@uint8_io +def dummy_uint8_func(img): + return np.clip(img + 10, 0, 255).astype(np.uint8) + +# Test data +@pytest.fixture(params=[ + np.uint8, np.float32 +]) +def test_image(request): + dtype = request.param + if np.issubdtype(dtype, np.integer): + return np.random.randint(0, 256, (10, 10, 3), dtype=dtype) + else: + return np.random.rand(10, 10, 3).astype(dtype) + +# Tests +@pytest.mark.parametrize("wrapper,func, image", [ + (float32_io, dummy_float32_func, np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)), + (uint8_io, dummy_uint8_func, np.random.rand(10, 10, 3).astype(np.float32)) +]) +def test_io_wrapper(wrapper, func, image): + input_dtype = image.dtype + result = func(image) + + # Check if the output dtype matches the input dtype + assert result.dtype == input_dtype + + # Check if the function was actually applied + if wrapper == float32_io: + expected = from_float(to_float(image) * 2, input_dtype) + else: # uint8_io + expected = to_float(from_float(image, np.uint8) + 10) + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + +@pytest.mark.parametrize("wrapper,func,expected_intermediate_dtype", [ + (float32_io, dummy_float32_func, np.float32), + (uint8_io, dummy_uint8_func, np.uint8) +]) +def test_intermediate_dtype(wrapper, func, expected_intermediate_dtype, test_image): + original_func = func.__wrapped__ # Access the original function + + def check_dtype(img): + assert img.dtype == expected_intermediate_dtype + return original_func(img) + + wrapped_func = wrapper(check_dtype) + wrapped_func(test_image) # This will raise an assertion error if the intermediate dtype is incorrect + +def test_float32_io_preserves_float32(test_image): + if test_image.dtype == np.float32: + result = dummy_float32_func(test_image) + assert result.dtype == np.float32 + +def test_uint8_io_preserves_uint8(test_image): + if test_image.dtype == np.uint8: + result = dummy_uint8_func(test_image) + assert result.dtype == np.uint8