From 6a0ad9b5891e434f43b3e167366ae110271559e8 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Fri, 13 Dec 2024 13:02:11 -0800 Subject: [PATCH] Add pad if needed3d (#2196) * Empty-Commit * Added PadIfNeeded * Fix * Fixes in some tests * Tests pass * Added PadIfNeeded3D and ToTensor3D * Sourcery fixes * Sourcery fixes --------- Co-authored-by: Vladimir Iglovikov --- .pre-commit-config.yaml | 2 +- README.md | 8 + albumentations/augmentations/__init__.py | 1 + .../augmentations/dropout/coarse_dropout.py | 2 +- .../augmentations/geometric/functional.py | 6 +- .../augmentations/transforms3d/__init__.py | 2 + .../augmentations/transforms3d/functional.py | 86 +++++++++ .../augmentations/transforms3d/transforms.py | 163 ++++++++++++++++++ albumentations/core/composition.py | 107 +++++++++--- albumentations/core/pydantic.py | 43 +++++ albumentations/core/transforms_interface.py | 91 +++++++--- albumentations/core/types.py | 1 + albumentations/py.typed | 0 albumentations/pytorch/transforms.py | 89 +++++++++- tests/aug_definitions.py | 3 +- tests/test_augmentations.py | 32 ++-- tests/test_bbox.py | 3 +- tests/test_core.py | 123 +++++++++---- tests/test_crop.py | 6 +- tests/test_pytorch.py | 21 ++- tests/test_serialization.py | 43 +++-- tests/test_targets.py | 12 +- tests/test_transforms.py | 54 ++++-- tests/transforms3d/test_pytorch.py | 37 ++++ tests/transforms3d/test_targets.py | 88 ++++++++++ tests/transforms3d/test_transforms.py | 137 +++++++++++++++ tests/utils.py | 56 +++++- tools/make_transforms_docs.py | 64 +++++-- 28 files changed, 1103 insertions(+), 177 deletions(-) create mode 100644 albumentations/augmentations/transforms3d/__init__.py create mode 100644 albumentations/augmentations/transforms3d/functional.py create mode 100644 albumentations/augmentations/transforms3d/transforms.py delete mode 100644 albumentations/py.typed create mode 100644 tests/transforms3d/test_pytorch.py create mode 100644 tests/transforms3d/test_targets.py create mode 100644 tests/transforms3d/test_transforms.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d7b8ed47..c1c52bf32 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: - id: check-docstrings name: Check Docstrings for '---' sequences entry: python tools/check_docstrings.py - language: system + language: python types: [python] - repo: local hooks: diff --git a/README.md b/README.md index 8f55caf40..131d0f0c5 100644 --- a/README.md +++ b/README.md @@ -297,6 +297,14 @@ Spatial-level transforms will simultaneously change both an input image as well | [VerticalFlip](https://explore.albumentations.ai/transform/VerticalFlip) | ✓ | ✓ | ✓ | ✓ | | [XYMasking](https://explore.albumentations.ai/transform/XYMasking) | ✓ | ✓ | ✓ | ✓ | +### 3D transforms + +3D transforms operate on volumetric data and can modify both the input volume and associated 3D mask. 
+ +| Transform | Image | Mask | +| -------------------------------------------------------------------------- | :---: | :--: | +| [PadIfNeeded3D](https://explore.albumentations.ai/transform/PadIfNeeded3D) | ✓ | ✓ | + ## A few more examples of **augmentations** ### Semantic segmentation on the Inria dataset diff --git a/albumentations/augmentations/__init__.py b/albumentations/augmentations/__init__.py index c46226f21..d93d4eb19 100644 --- a/albumentations/augmentations/__init__.py +++ b/albumentations/augmentations/__init__.py @@ -21,4 +21,5 @@ from .text.functional import * from .text.transforms import * from .transforms import * +from .transforms3d.transforms import * from .utils import * diff --git a/albumentations/augmentations/dropout/coarse_dropout.py b/albumentations/augmentations/dropout/coarse_dropout.py index 27568184f..09b3d164e 100644 --- a/albumentations/augmentations/dropout/coarse_dropout.py +++ b/albumentations/augmentations/dropout/coarse_dropout.py @@ -40,7 +40,7 @@ class CoarseDropout(BaseDropout): - 'inpaint_telea': uses OpenCV Telea inpainting method - 'inpaint_ns': uses OpenCV Navier-Stokes inpainting method Default: 0 - mask_fill_value (ColorType | None): Fill value for dropout regions in the mask. + fill_mask (ColorType | None): Fill value for dropout regions in the mask. If None, mask regions corresponding to image dropouts are unchanged. Default: None p (float): Probability of applying the transform. Default: 0.5 diff --git a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py index 797f54dfe..f470db7c8 100644 --- a/albumentations/augmentations/geometric/functional.py +++ b/albumentations/augmentations/geometric/functional.py @@ -2888,7 +2888,7 @@ def bboxes_piecewise_affine( return bboxes -def _get_dimension_padding( +def get_dimension_padding( current_size: int, min_size: int | None, divisor: int | None, @@ -2940,12 +2940,12 @@ def get_padding_params( """ rows, cols = image_shape[:2] - h_pad_top, h_pad_bottom = _get_dimension_padding( + h_pad_top, h_pad_bottom = get_dimension_padding( rows, min_height, pad_height_divisor, ) - w_pad_left, w_pad_right = _get_dimension_padding(cols, min_width, pad_width_divisor) + w_pad_left, w_pad_right = get_dimension_padding(cols, min_width, pad_width_divisor) return h_pad_top, h_pad_bottom, w_pad_left, w_pad_right diff --git a/albumentations/augmentations/transforms3d/__init__.py b/albumentations/augmentations/transforms3d/__init__.py new file mode 100644 index 000000000..9d447f1d7 --- /dev/null +++ b/albumentations/augmentations/transforms3d/__init__.py @@ -0,0 +1,2 @@ +from .functional import * +from .transforms import * diff --git a/albumentations/augmentations/transforms3d/functional.py b/albumentations/augmentations/transforms3d/functional.py new file mode 100644 index 000000000..d91aac733 --- /dev/null +++ b/albumentations/augmentations/transforms3d/functional.py @@ -0,0 +1,86 @@ +import random +from typing import Literal + +import numpy as np + +from albumentations.core.types import NUM_VOLUME_DIMENSIONS, ColorType + + +def adjust_padding_by_position3d( + paddings: list[tuple[int, int]], # [(front, back), (top, bottom), (left, right)] + position: Literal["center", "random"], + py_random: random.Random, +) -> tuple[int, int, int, int, int, int]: + """Adjust padding values based on desired position for 3D data. + + Args: + paddings: List of tuples containing padding pairs for each dimension [(d_pad), (h_pad), (w_pad)] + position: Position of the image after padding. 
Either 'center' or 'random'
+        py_random: Random number generator
+
+    Returns:
+        tuple[int, int, int, int, int, int]: Final padding values (d_front, d_back, h_top, h_bottom, w_left, w_right)
+    """
+    if position == "center":
+        return (
+            paddings[0][0],  # d_front
+            paddings[0][1],  # d_back
+            paddings[1][0],  # h_top
+            paddings[1][1],  # h_bottom
+            paddings[2][0],  # w_left
+            paddings[2][1],  # w_right
+        )
+
+    # For random position, redistribute the total padding of each dimension
+    # between the two sides. A single draw per dimension keeps front + back
+    # equal to the required total, so the output size is still guaranteed.
+    d_pad = sum(paddings[0])
+    h_pad = sum(paddings[1])
+    w_pad = sum(paddings[2])
+
+    d_front = py_random.randint(0, d_pad)
+    h_top = py_random.randint(0, h_pad)
+    w_left = py_random.randint(0, w_pad)
+
+    return (
+        d_front,
+        d_pad - d_front,  # d_back
+        h_top,
+        h_pad - h_top,  # h_bottom
+        w_left,
+        w_pad - w_left,  # w_right
+    )
+
+
+def pad_3d_with_params(
+    img: np.ndarray,
+    padding: tuple[int, int, int, int, int, int],  # (d_front, d_back, h_top, h_bottom, w_left, w_right)
+    value: ColorType,
+) -> np.ndarray:
+    """Pad 3D image with given parameters.
+
+    Args:
+        img: Input image with shape (depth, height, width) or (depth, height, width, channels)
+        padding: Padding values (d_front, d_back, h_top, h_bottom, w_left, w_right)
+        value: Padding value
+
+    Returns:
+        Padded image with same number of dimensions as input
+    """
+    d_front, d_back, h_top, h_bottom, w_left, w_right = padding
+
+    # Skip if no padding is needed
+    if d_front == d_back == h_top == h_bottom == w_left == w_right == 0:
+        return img
+
+    # Handle both 3D and 4D arrays
+    pad_width = [
+        (d_front, d_back),  # depth padding
+        (h_top, h_bottom),  # height padding
+        (w_left, w_right),  # width padding
+    ]
+
+    # Add channel padding if 4D array
+    if img.ndim == NUM_VOLUME_DIMENSIONS:
+        pad_width.append((0, 0))  # no padding for channels
+
+    return np.pad(
+        img,
+        pad_width=pad_width,
+        mode="constant",
+        constant_values=value,
+    )
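Taken together, `adjust_padding_by_position3d` splits each dimension's total padding between the two sides and `pad_3d_with_params` applies it with `np.pad`. A minimal sketch of the contract (shapes and padding pairs are illustrative; assumes the module above is importable from this patch):

```python
import random

import numpy as np

from albumentations.augmentations.transforms3d.functional import (
    adjust_padding_by_position3d,
    pad_3d_with_params,
)

volume = np.zeros((10, 100, 100), dtype=np.uint8)  # (D, H, W)

# Per-dimension (before, after) padding pairs, e.g. to grow to (16, 128, 128).
paddings = [(3, 3), (14, 14), (14, 14)]

padding = adjust_padding_by_position3d(paddings, "random", random.Random(0))
padded = pad_3d_with_params(volume, padding, value=0)

# "random" repositions the volume inside the padded box, but the per-dimension
# totals are preserved, so the output shape is deterministic.
assert padded.shape == (16, 128, 128)
```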
diff --git a/albumentations/augmentations/transforms3d/transforms.py b/albumentations/augmentations/transforms3d/transforms.py
new file mode 100644
index 000000000..5095f5bc2
--- /dev/null
+++ b/albumentations/augmentations/transforms3d/transforms.py
@@ -0,0 +1,163 @@
+from __future__ import annotations
+
+from typing import Annotated, Any, Literal, cast
+
+import numpy as np
+from pydantic import AfterValidator, model_validator
+from typing_extensions import Self
+
+from albumentations.augmentations.geometric import functional as fgeometric
+from albumentations.augmentations.transforms3d import functional as f3d
+from albumentations.core.pydantic import check_range_bounds_3d
+from albumentations.core.transforms_interface import Transform3D
+from albumentations.core.types import ColorType, Targets
+
+__all__ = ["PadIfNeeded3D"]
+
+
+class PadIfNeeded3D(Transform3D):
+    """Pads the sides of a 3D volume if its dimensions are less than the specified minimum dimensions.
+    If pad_divisor_zyx is specified, the transform additionally ensures that the volume
+    dimensions are divisible by these values.
+
+    Args:
+        min_zyx (tuple[int, int, int] | None): Minimum desired size as (depth, height, width).
+            Ensures volume dimensions are at least these values.
+            If not specified, pad_divisor_zyx must be provided.
+        pad_divisor_zyx (tuple[int, int, int] | None): If set, pads each dimension to make it
+            divisible by the corresponding value in format (depth_div, height_div, width_div).
+            If not specified, min_zyx must be provided.
+        position (Literal["center", "random"]): Position where the volume is to be placed after padding.
+            Default is 'center'.
+        fill (ColorType): Value to fill the border voxels for images. Default: 0
+        fill_mask (ColorType): Value to fill the border voxels for masks. Default: 0
+        p (float): Probability of applying the transform. Default: 1.0
+
+    Targets:
+        images, masks
+
+    Image types:
+        uint8, float32
+
+    Note:
+        - At least one of min_zyx or pad_divisor_zyx must be set.
+        - The transform will maintain consistency across all targets (images and masks).
+        - Input volumes can be either 3D arrays (depth, height, width) or
+          4D arrays (depth, height, width, channels).
+        - Padding is always applied using constant values specified by fill/fill_mask.
+
+    Example:
+        >>> import albumentations as A
+        >>> transform = A.Compose([
+        ...     A.PadIfNeeded3D(
+        ...         min_zyx=(64, 128, 128),  # Minimum size for each dimension
+        ...         fill=0,  # Fill value for images
+        ...         fill_mask=0,  # Fill value for masks
+        ...     ),
+        ... ])
+        >>> # For divisible dimensions
+        >>> transform = A.Compose([
+        ...     A.PadIfNeeded3D(
+        ...         pad_divisor_zyx=(16, 16, 16),  # Make dimensions divisible by 16
+        ...         fill=0,
+        ...     ),
+        ... ])
+        >>> transformed = transform(images=volume, masks=masks)
+        >>> padded_volume = transformed['images']
+        >>> padded_masks = transformed['masks']
+    """
+
+    _targets = (Targets.IMAGE, Targets.MASK)
+
+    class InitSchema(Transform3D.InitSchema):
+        min_zyx: Annotated[tuple[int, int, int] | None, AfterValidator(check_range_bounds_3d(0, None))]
+        pad_divisor_zyx: Annotated[tuple[int, int, int] | None, AfterValidator(check_range_bounds_3d(1, None))]
+        position: Literal["center", "random"]
+        fill: ColorType
+        fill_mask: ColorType
+
+        @model_validator(mode="after")
+        def validate_params(self) -> Self:
+            if self.min_zyx is None and self.pad_divisor_zyx is None:
+                msg = "At least one of min_zyx or pad_divisor_zyx must be set"
+                raise ValueError(msg)
+            return self
+
+    def __init__(
+        self,
+        min_zyx: tuple[int, int, int] | None = None,
+        pad_divisor_zyx: tuple[int, int, int] | None = None,
+        position: Literal["center", "random"] = "center",
+        fill: ColorType = 0,
+        fill_mask: ColorType = 0,
+        p: float = 1.0,
+        always_apply: bool | None = None,
+    ):
+        super().__init__(p=p, always_apply=always_apply)
+        self.min_zyx = min_zyx
+        self.pad_divisor_zyx = pad_divisor_zyx
+        self.position = position
+        self.fill = fill
+        self.fill_mask = fill_mask
+
+    def get_params_dependent_on_data(
+        self,
+        params: dict[str, Any],
+        data: dict[str, Any],
+    ) -> dict[str, Any]:
+        depth, height, width = data["images"].shape[:3]
+        sizes = (depth, height, width)
+
+        paddings = [
+            fgeometric.get_dimension_padding(
+                current_size=size,
+                min_size=self.min_zyx[i] if self.min_zyx else None,
+                divisor=self.pad_divisor_zyx[i] if self.pad_divisor_zyx else None,
+            )
+            for i, size in enumerate(sizes)
+        ]
+
+        padding = f3d.adjust_padding_by_position3d(
+            paddings=paddings,
+            position=self.position,
+            py_random=self.py_random,
+        )
+
+        return {"padding": padding}  # (d_front, d_back, h_top, h_bottom, w_left, w_right)
+
+    def apply_to_images(
+        self,
+        images: np.ndarray,
+        padding: tuple[int, int, int, int, int, int],
+        **params: Any,
+    ) -> np.ndarray:
+        if padding == (0, 0, 0, 0, 0, 0):
+            return images
+        return f3d.pad_3d_with_params(
+            img=images,
+            padding=padding,  # (d_front, d_back, h_top, h_bottom, w_left, w_right)
+            value=cast(ColorType, self.fill),
+        )
+
+    def apply_to_masks(
+        self,
+        masks: np.ndarray,
+        padding: tuple[int, int, int, int, int, int],
+        **params: Any,
+    ) -> np.ndarray:
+        if padding == (0, 0, 0, 0, 0, 0):
+            return masks
+        return f3d.pad_3d_with_params(
+            img=masks,
+            padding=padding,  # (d_front, d_back, h_top, h_bottom, w_left, w_right)
+            value=cast(ColorType, self.fill_mask),
+        )
+
+    def get_transform_init_args_names(self) -> tuple[str, ...]:
+        return (
+            "min_zyx",
+            "pad_divisor_zyx",
+            "position",
+            "fill",
+            "fill_mask",
+        )
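A quick end-to-end sketch of the new transform (hypothetical data; assumes this patch is applied — the direct call mirrors the tests added further below, while the Compose changes in the next file additionally let "images"/"masks" be passed as lists of slices):

```python
import numpy as np

import albumentations as A

volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
masks = np.zeros((10, 100, 100), dtype=np.uint8)

transform = A.PadIfNeeded3D(min_zyx=(16, 128, 128), fill=0, fill_mask=0, p=1)
out = transform(images=volume, masks=masks)

# The default "center" position splits the extra voxels evenly per axis.
assert out["images"].shape == (16, 128, 128)
assert out["masks"].shape == (16, 128, 128)
```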
diff --git a/albumentations/core/composition.py b/albumentations/core/composition.py
index 3a5a435b2..fbecedc96 100644
--- a/albumentations/core/composition.py
+++ b/albumentations/core/composition.py
@@ -9,8 +9,6 @@
 import cv2
 import numpy as np
 
-from albumentations.core.types import NUM_MULTI_CHANNEL_DIMENSIONS
-
 from .bbox_utils import BboxParams, BboxProcessor
 from .hub_mixin import HubMixin
 from .keypoints_utils import KeypointParams, KeypointsProcessor
@@ -378,6 +376,8 @@ def __init__(
         self._set_processors_for_transforms(self.transforms)
 
         self.save_applied_params = save_applied_params
+        self._images_was_list = False
+        self._masks_was_list = False
 
     def _set_processors_for_transforms(self, transforms: TransformsSeqType) -> None:
         for transform in transforms:
@@ -428,28 +428,76 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[s
         return self.postprocess(data)
 
     def preprocess(self, data: Any) -> None:
-        if self.strict:
-            for data_name in data:
-                if (
-                    data_name not in self._available_keys
-                    and data_name not in MASK_KEYS
-                    and data_name not in IMAGE_KEYS
-                    and data_name != "applied_transforms"
-                ):
-                    msg = f"Key {data_name} is not in available keys."
-                    raise ValueError(msg)
+        """Preprocess input data before applying transforms."""
+        self._validate_data(data)
+        self._preprocess_processors(data)
+        self._preprocess_arrays(data)
+
+    def _validate_data(self, data: dict[str, Any]) -> None:
+        """Validate input data keys and arguments."""
+        if not self.strict:
+            return
+
+        for data_name in data:
+            if not self._is_valid_key(data_name):
+                raise ValueError(f"Key {data_name} is not in available keys.")
+
         if self.is_check_args:
             self._check_args(**data)
-        if self.main_compose:
-            for p in self.processors.values():
-                p.ensure_data_valid(data)
-            for p in self.processors.values():
-                p.preprocess(data)
+
+    def _is_valid_key(self, key: str) -> bool:
+        """Check if the key is valid for processing."""
+        return key in self._available_keys or key in MASK_KEYS or key in IMAGE_KEYS or key == "applied_transforms"
+
+    def _preprocess_processors(self, data: dict[str, Any]) -> None:
+        """Run preprocessors if this is the main compose."""
+        if not self.main_compose:
+            return
+
+        for processor in self.processors.values():
+            processor.ensure_data_valid(data)
+        for processor in self.processors.values():
+            processor.preprocess(data)
+
+    def _preprocess_arrays(self, data: dict[str, Any]) -> None:
+        """Convert lists to numpy arrays for images and masks."""
+        self._preprocess_images(data)
+        self._preprocess_masks(data)
+
+    def _preprocess_images(self, data: dict[str, Any]) -> None:
+        """Convert image lists to numpy arrays."""
+        if "images" not in data:
+            return
+
+        if isinstance(data["images"], (list, tuple)):
+            self._images_was_list = True
+            data["images"] = np.stack(data["images"])
+        else:
+            self._images_was_list = False
+
+    def _preprocess_masks(self, data: dict[str, Any]) -> None:
+        """Convert mask lists to numpy arrays."""
+        if "masks" not in data:
+            return
+
+        if isinstance(data["masks"], (list, tuple)):
+            self._masks_was_list = True
+            data["masks"] = np.stack(data["masks"])
+        else:
+            self._masks_was_list = False
 
     def postprocess(self, data: dict[str, Any]) -> dict[str, Any]:
         if 
self.main_compose: for p in self.processors.values(): p.postprocess(data) + + # Convert back to list if original input was a list + if "images" in data and self._images_was_list: + data["images"] = list(data["images"]) + + if "masks" in data and self._masks_was_list: + data["masks"] = list(data["masks"]) + return data def to_dict_private(self) -> dict[str, Any]: @@ -489,11 +537,28 @@ def _check_single_data(data_name: str, data: Any) -> tuple[int, int]: @staticmethod def _check_masks_data(data_name: str, data: Any) -> tuple[int, int]: + """Check masks data format and return shape. + + Args: + data_name: Name of the data field being checked + data: Input data in one of these formats: + - List of numpy arrays, each of shape (H, W) or (H, W, C) + - Numpy array of shape (N, H, W) or (N, H, W, C) + + Returns: + tuple: (height, width) of the first mask + + Raises: + TypeError: If data format is invalid + """ if isinstance(data, np.ndarray): - if data.ndim not in [3, 4]: - raise TypeError(f"{data_name} must be a 3D or 4D numpy array") - return data.shape[1:3] if data.ndim == NUM_MULTI_CHANNEL_DIMENSIONS else data.shape[:2] - if isinstance(data, Sequence): + if data.ndim not in [3, 4]: # (N,H,W) or (N,H,W,C) + raise TypeError(f"{data_name} as numpy array must be 3D or 4D") + return data.shape[1:3] # Return (H,W) + + if isinstance(data, (list, tuple)): + if not data: + raise ValueError(f"{data_name} cannot be empty") if not all(isinstance(m, np.ndarray) for m in data): raise TypeError(f"All elements in {data_name} must be numpy arrays") if any(m.ndim not in [2, 3] for m in data): diff --git a/albumentations/core/pydantic.py b/albumentations/core/pydantic.py index 50f6cd3ec..6d392a2bb 100644 --- a/albumentations/core/pydantic.py +++ b/albumentations/core/pydantic.py @@ -174,3 +174,46 @@ def validator(value: tuple[Number, Number]) -> tuple[Number, Number]: return value return validator + + +def check_range_bounds_3d( + min_val: Number, + max_val: Number | None = None, + min_inclusive: bool = True, + max_inclusive: bool = True, +) -> Callable[[tuple[Number, Number, Number] | None], tuple[Number, Number, Number] | None]: + """Validates that all three values in a tuple are within specified bounds. + + Args: + min_val: Minimum allowed value + max_val: Maximum allowed value. If None, only lower bound is checked. + min_inclusive: If True, min_val is inclusive (>=). If False, exclusive (>). + max_inclusive: If True, max_val is inclusive (<=). If False, exclusive (<). + + Returns: + Validator function that checks if all values in tuple are within bounds. + Returns None if input is None. 
+ + Raises: + ValueError: If any value in tuple is outside the allowed range + """ + + def validator(value: tuple[Number, Number, Number] | None) -> tuple[Number, Number, Number] | None: + if value is None: + return None + + min_op = (lambda x, y: x >= y) if min_inclusive else (lambda x, y: x > y) + max_op = (lambda x, y: x <= y) if max_inclusive else (lambda x, y: x < y) + + if max_val is None: + if not all(min_op(x, min_val) for x in value): + op_symbol = ">=" if min_inclusive else ">" + raise ValueError(f"All values in {value} must be {op_symbol} {min_val}") + else: + min_symbol = ">=" if min_inclusive else ">" + max_symbol = "<=" if max_inclusive else "<" + if not all(min_op(x, min_val) and max_op(x, max_val) for x in value): + raise ValueError(f"All values in {value} must be {min_symbol} {min_val} and {max_symbol} {max_val}") + return value + + return validator diff --git a/albumentations/core/transforms_interface.py b/albumentations/core/transforms_interface.py index 34b6f80a8..63f1703a2 100644 --- a/albumentations/core/transforms_interface.py +++ b/albumentations/core/transforms_interface.py @@ -1,7 +1,6 @@ from __future__ import annotations import random -from collections.abc import Sequence from copy import deepcopy from typing import Any, Callable from warnings import warn @@ -17,14 +16,13 @@ from .serialization import Serializable, SerializableMeta, get_shortest_class_fullname from .types import ( - NUM_MULTI_CHANNEL_DIMENSIONS, ColorType, DropoutFillValue, Targets, ) from .utils import ensure_contiguous_output, format_args -__all__ = ["BasicTransform", "DualTransform", "ImageOnlyTransform", "NoOp"] +__all__ = ["BasicTransform", "DualTransform", "ImageOnlyTransform", "NoOp", "Transform3D"] class Interpolation: @@ -226,25 +224,22 @@ def apply(self, img: np.ndarray, *args: Any, **params: Any) -> np.ndarray: """Apply transform on image.""" raise NotImplementedError - def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray | list[np.ndarray]: + def apply_to_images(self, images: np.ndarray, *args: Any, **params: Any) -> np.ndarray: """Apply transform on images. Args: - images: Input images in one of these formats: - - List-like of images - - Numpy array of shape (num_images, height, width, channels) - - Numpy array of shape (num_images, height, width) for grayscale + images: Input images as numpy array of shape: + - (num_images, height, width, channels) + - (num_images, height, width) for grayscale + *args: Additional positional arguments **params: Additional parameters specific to the transform Returns: - Transformed images in the same format as input + Transformed images as numpy array in the same format as input """ - if isinstance(images, np.ndarray) and images.ndim in [3, 4]: - # Handle batched numpy array input - transformed = np.stack([self.apply(image, **params) for image in images]) - return np.require(transformed, requirements=["C_CONTIGUOUS"]) - # Handle list-like input - return [self.apply(image, **params) for image in images] + # Handle batched numpy array input + transformed = np.stack([self.apply(image, **params) for image in images]) + return np.require(transformed, requirements=["C_CONTIGUOUS"]) def get_params(self) -> dict[str, Any]: """Returns parameters independent of input.""" @@ -414,6 +409,9 @@ class DualTransform(BasicTransform): Returns Transformed bounding boxes array of shape (N, 4+). + apply_to_images(images: np.ndarray, **params: Any) -> np.ndarray: + Apply the transform to multiple images. 
+ Note: - All `apply_*` methods should maintain the input shape and format of the data. - The `apply_to_mask` and `apply_to_masks` methods handle both single arrays and sequences of arrays. @@ -465,20 +463,12 @@ def apply_to_bboxes(self, bboxes: np.ndarray, *args: Any, **params: Any) -> np.n raise NotImplementedError(f"BBoxes not implemented for {self.__class__.__name__}") def apply_to_mask(self, mask: np.ndarray, *args: Any, **params: Any) -> np.ndarray: - return self.apply(mask, **{k: cv2.INTER_NEAREST if k == "interpolation" else v for k, v in params.items()}) - - def apply_to_masks(self, masks: np.ndarray | Sequence[np.ndarray], **params: Any) -> list[np.ndarray] | np.ndarray: - if isinstance(masks, np.ndarray): - if masks.ndim == NUM_MULTI_CHANNEL_DIMENSIONS: - # Transpose from (num_channels, height, width) to (height, width, num_channels) - masks = np.transpose(masks, (1, 2, 0)) - masks = np.require(masks, requirements=["C_CONTIGUOUS"]) - transformed_masks = self.apply_to_mask(masks, **params) - # Transpose back to (num_channels, height, width) - return np.require(np.transpose(transformed_masks, (2, 0, 1)), requirements=["C_CONTIGUOUS"]) + return self.apply(mask, *args, **params) - return self.apply_to_mask(masks, **params) - return [self.apply_to_mask(mask, **params) for mask in masks] + def apply_to_masks(self, masks: np.ndarray, *args: Any, **params: Any) -> np.ndarray: + """Apply transform to masks by applying to each slice separately.""" + transformed_slices = [self.apply_to_mask(mask_slice, *args, **params) for mask_slice in masks] + return np.stack(transformed_slices) class ImageOnlyTransform(BasicTransform): @@ -514,3 +504,48 @@ def apply_to_mask(self, mask: np.ndarray, **params: Any) -> np.ndarray: def get_transform_init_args_names(self) -> tuple[str, ...]: return () + + +class Transform3D(DualTransform): + """Base class for all 3D transforms. + + Transform3D inherits from DualTransform because 3D transforms can be applied to both + images (volumes) and masks, similar to how 2D DualTransforms work with images and masks. 
+ + Targets: + images: 3D numpy array of shape (D, H, W) or (D, H, W, C) + masks: 3D numpy array of shape (D, H, W) or sequence of such arrays + """ + + def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: + raise NotImplementedError("Use 'images' target instead") + + def apply_to_mask(self, mask: np.ndarray, **params: Any) -> np.ndarray: + raise NotImplementedError("Use 'masks' target instead") + + def apply_to_images(self, images: np.ndarray, *args: Any, **params: Any) -> np.ndarray: + """Apply transform to 3D volume.""" + raise NotImplementedError + + def apply_to_masks( + self, + masks: np.ndarray, + *args: Any, + **params: Any, + ) -> np.ndarray: + """Apply transform to 3D mask or sequence of 3D masks.""" + raise NotImplementedError + + def apply_to_bboxes(self, bboxes: np.ndarray, **params: Any) -> np.ndarray: + raise NotImplementedError("3D bounding boxes not implemented") + + def apply_to_keypoints(self, keypoints: np.ndarray, **params: Any) -> np.ndarray: + raise NotImplementedError("3D keypoints not implemented") + + @property + def targets(self) -> dict[str, Callable[..., Any]]: + """Define valid targets for 3D transforms.""" + return { + "images": self.apply_to_images, + "masks": self.apply_to_masks, + } diff --git a/albumentations/core/types.py b/albumentations/core/types.py index e14878c89..bac43ddd1 100644 --- a/albumentations/core/types.py +++ b/albumentations/core/types.py @@ -49,6 +49,7 @@ class Targets(Enum): KEYPOINTS = "Keypoints" +NUM_VOLUME_DIMENSIONS = 4 NUM_MULTI_CHANNEL_DIMENSIONS = 3 MONO_CHANNEL_DIMENSIONS = 2 NUM_RGB_CHANNELS = 3 diff --git a/albumentations/py.typed b/albumentations/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/albumentations/pytorch/transforms.py b/albumentations/pytorch/transforms.py index 98aa7bdb4..dc5fb5db1 100644 --- a/albumentations/pytorch/transforms.py +++ b/albumentations/pytorch/transforms.py @@ -6,9 +6,14 @@ import torch from albumentations.core.transforms_interface import BasicTransform -from albumentations.core.types import MONO_CHANNEL_DIMENSIONS, NUM_MULTI_CHANNEL_DIMENSIONS, Targets +from albumentations.core.types import ( + MONO_CHANNEL_DIMENSIONS, + NUM_MULTI_CHANNEL_DIMENSIONS, + NUM_VOLUME_DIMENSIONS, + Targets, +) -__all__ = ["ToTensorV2"] +__all__ = ["ToTensor3D", "ToTensorV2"] class ToTensorV2(BasicTransform): @@ -53,8 +58,84 @@ def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor: mask = mask.transpose(2, 0, 1) return torch.from_numpy(mask) - def apply_to_masks(self, masks: list[np.ndarray], **params: Any) -> list[torch.Tensor]: - return [self.apply_to_mask(mask, **params) for mask in masks] + def apply_to_masks(self, masks: np.ndarray, **params: Any) -> torch.Tensor: + """Convert numpy array of masks to torch tensor. + + Args: + masks: numpy array of shape (N, H, W) or (N, H, W, C) + params: Additional parameters + + Returns: + torch.Tensor: If transpose_mask is True and input is (N, H, W, C), + returns tensor of shape (N, C, H, W). + Otherwise returns tensor with same shape as input. 
+ """ + if self.transpose_mask and masks.ndim == NUM_VOLUME_DIMENSIONS: # (N, H, W, C) + masks = np.transpose(masks, (0, 3, 1, 2)) # -> (N, C, H, W) + return torch.from_numpy(masks) + + def apply_to_images(self, images: np.ndarray, **params: Any) -> torch.Tensor: + """Convert batch of images from (N, H, W, C) to (N, C, H, W).""" + if images.ndim != NUM_VOLUME_DIMENSIONS: # N,H,W,C + raise ValueError(f"Expected 4D array (N,H,W,C), got {images.ndim}D array") + return torch.from_numpy(images.transpose(0, 3, 1, 2)) # -> (N,C,H,W) + + def get_transform_init_args_names(self) -> tuple[str, ...]: + return ("transpose_mask",) + + +class ToTensor3D(BasicTransform): + """Convert 3D volumes and masks to PyTorch tensors. + + This transform is designed for 3D medical imaging data. It handles the conversion + of numpy arrays to PyTorch tensors and performs necessary channel transpositions. + + For volumes: + - Input: (D, H, W, C) - depth, height, width, channels + - Output: (C, D, H, W) - channels, depth, height, width + + For masks: + - If transpose_mask=False: + - Input: (D, H, W) or (D, H, W, C) + - Output: Same shape as input + - If transpose_mask=True and mask has channels: + - Input: (D, H, W, C) + - Output: (C, D, H, W) + + Args: + transpose_mask (bool): If True and masks have channels, transposes masks from + (D, H, W, C) to (C, D, H, W). Default: False + p (float): Probability of applying the transform. Default: 1.0 + """ + + _targets = (Targets.IMAGE, Targets.MASK) + + def __init__(self, transpose_mask: bool = False, p: float = 1.0, always_apply: bool | None = None): + super().__init__(p=p, always_apply=always_apply) + self.transpose_mask = transpose_mask + + @property + def targets(self) -> dict[str, Any]: + return { + "images": self.apply_to_images, + "masks": self.apply_to_masks, + } + + def apply_to_images(self, images: np.ndarray, **params: Any) -> torch.Tensor: + """Convert 3D volume from (D,H,W,C) to (C,D,H,W).""" + if images.ndim != NUM_VOLUME_DIMENSIONS: # D,H,W,C + raise ValueError(f"Expected 4D array (D,H,W,C), got {images.ndim}D array") + return torch.from_numpy(images.transpose(3, 0, 1, 2)) + + def apply_to_masks(self, masks: np.ndarray, **params: Any) -> torch.Tensor: + """Convert 3D mask to tensor. + + If transpose_mask is True and mask has channels (D,H,W,C), + converts to (C,D,H,W). Otherwise keeps the original shape. 
+ """ + if self.transpose_mask and masks.ndim == NUM_VOLUME_DIMENSIONS: # D,H,W,C + masks = masks.transpose(3, 0, 1, 2) + return torch.from_numpy(masks) def get_transform_init_args_names(self) -> tuple[str, ...]: return ("transpose_mask",) diff --git a/tests/aug_definitions.py b/tests/aug_definitions.py index 88c5be3a6..933f98f4d 100644 --- a/tests/aug_definitions.py +++ b/tests/aug_definitions.py @@ -192,7 +192,7 @@ { "scale": 0.5, "keep_size": False, - "pad_mode": cv2.BORDER_REFLECT_101, + "border_mode": cv2.BORDER_REFLECT_101, "fill": 10, "fill_mask": 100, "fit_output": True, @@ -404,4 +404,5 @@ [A.Illumination, {}], [A.ThinPlateSpline, {}], [A.AutoContrast, {}], + [A.PadIfNeeded3D, {"min_zyx": (300, 200, 400), "pad_divisor_zyx": (10, 10, 10), "position": "center", "fill": 10, "fill_mask": 20}], ] diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index e1735d3a0..53ff00f11 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -15,7 +15,7 @@ SQUARE_UINT8_IMAGE, ) -from .utils import get_dual_transforms, get_image_only_transforms, get_transforms, set_seed +from .utils import get_2d_transforms, get_dual_transforms, get_image_only_transforms, get_transforms, set_seed @pytest.mark.parametrize( @@ -192,7 +192,7 @@ def test_dual_augmentations_with_float_values(augmentation_cls, params): @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.HistogramMatching: { "reference_images": [SQUARE_UINT8_IMAGE], @@ -253,13 +253,13 @@ def test_augmentations_wont_change_input(augmentation_cls, params): else: aug(image=image, mask=mask) - assert np.array_equal(image, image_copy) - assert np.array_equal(mask, mask_copy) + np.testing.assert_array_equal(image, image_copy) + np.testing.assert_array_equal(mask, mask_copy) @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.HistogramMatching: { "reference_images": [SQUARE_FLOAT_IMAGE], @@ -326,12 +326,12 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params): aug(**data) - assert np.array_equal(image, float_image_copy) + np.testing.assert_array_equal(image, float_image_copy) @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.HistogramMatching: { "reference_images": [np.random.randint(0, 255, [100, 100], dtype=np.uint8)], @@ -420,7 +420,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, sha @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.HistogramMatching: { "reference_images": [SQUARE_UINT8_IMAGE], @@ -544,7 +544,7 @@ def test_mask_fill_value(augmentation_cls, params): @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.HistogramMatching: { "reference_images": [SQUARE_MULTI_UINT8_IMAGE], @@ -632,7 +632,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params): @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.HistogramMatching: { "reference_images": [SQUARE_MULTI_FLOAT_IMAGE], @@ -721,7 +721,7 @@ def test_float_multichannel_image_augmentations(augmentation_cls, params): @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10}, A.CenterCrop: 
{"height": 10, "width": 10}, @@ -806,7 +806,7 @@ def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10}, A.CenterCrop: {"height": 10, "width": 10}, @@ -1047,7 +1047,7 @@ def test_pad_if_needed_position(params, image_shape): @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10}, A.CenterCrop: {"height": 10, "width": 10}, @@ -1092,8 +1092,7 @@ def test_augmentations_match_uint8_float32(augmentation_cls, params): image_uint8 = RECTANGULAR_UINT8_IMAGE image_float32 = to_float(image_uint8) - transform = A.Compose([augmentation_cls(p=1, **params)]) - + transform = A.Compose([augmentation_cls(p=1, **params)], seed=42) data = {"image": image_uint8} if augmentation_cls == A.MaskDropout: @@ -1101,11 +1100,8 @@ def test_augmentations_match_uint8_float32(augmentation_cls, params): mask[:20, :20] = 1 data["mask"] = mask - set_seed(42) - transform.set_random_seed(42) transformed_uint8 = transform(**data)["image"] - set_seed(42) data["image"] = image_float32 transform.set_random_seed(42) diff --git a/tests/test_bbox.py b/tests/test_bbox.py index 19f1160f6..81d36b511 100644 --- a/tests/test_bbox.py +++ b/tests/test_bbox.py @@ -1025,12 +1025,13 @@ def test_bounding_box_vflip(bbox, expected_bbox) -> None: @pytest.mark.parametrize( "get_transform", [ - lambda sign: A.Affine(translate_px=sign * 2, mode=cv2.BORDER_CONSTANT, cval=255), + lambda sign: A.Affine(translate_px=sign * 2, mode=cv2.BORDER_CONSTANT, fill=255), lambda sign: A.ShiftScaleRotate( shift_limit=(sign * 0.02, sign * 0.02), scale_limit=0, rotate_limit=0, border_mode=cv2.BORDER_CONSTANT, + fill=255, ), ], ) diff --git a/tests/test_core.py b/tests/test_core.py index eda8a3b61..b2a1a9157 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -155,25 +155,31 @@ def test_image_only_transform(image): @pytest.mark.parametrize("image", IMAGES) def test_dual_transform(image): mask = image.copy() - image_call = call( - image, - interpolation=cv2.INTER_LINEAR, - cols=image.shape[1], - rows=image.shape[0], - shape=image.shape, - ) - mask_call = call( - mask, - interpolation=cv2.INTER_NEAREST, - cols=mask.shape[1], - rows=mask.shape[0], - shape=mask.shape, - ) + with mock.patch.object(DualTransform, "apply") as mocked_apply: - with mock.patch.object(DualTransform, "get_params", return_value={"interpolation": cv2.INTER_LINEAR}): + with mock.patch.object(DualTransform, "get_params", return_value={}): # Empty params aug = DualTransform(p=1) aug(image=image, mask=mask) - mocked_apply.assert_has_calls([image_call, mask_call], any_order=True) + + # Get the actual calls + calls = mocked_apply.call_args_list + assert len(calls) == 2 # Should be called twice + + # Check each call has correct structure + for call_args in calls: + args, kwargs = call_args + + # Check kwargs contain correct keys and values + assert "cols" in kwargs + assert "rows" in kwargs + assert "shape" in kwargs + assert kwargs["cols"] == image.shape[1] + assert kwargs["rows"] == image.shape[0] + assert kwargs["shape"] == image.shape + + # Check input array is either image or mask + input_array = args[0] + assert np.array_equal(input_array, image) or np.array_equal(input_array, mask) @pytest.mark.parametrize("image", IMAGES) @@ -1099,7 +1105,7 @@ def 
test_transform_always_apply_warning() -> None: "reference_images": [np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)], "read_fn": lambda x: x, "transform_type": "standard", - }, + } }, except_augmentations={ A.FDA, @@ -1111,6 +1117,7 @@ def test_transform_always_apply_warning() -> None: A.BBoxSafeRandomCrop, A.OverlayElements, A.TextImage, + A.PadIfNeeded3D, }, ), ) @@ -1143,6 +1150,7 @@ def test_images_as_target(augmentation_cls, params, as_array, shape): aug = A.Compose( [augmentation_cls(p=1, **params)], + p=1, ) transformed = aug(**data) @@ -1219,6 +1227,7 @@ def test_images_as_target(augmentation_cls, params, as_array, shape): }, A.TextImage: dict(font_path="./tests/files/LiberationSerif-Bold.ttf"), A.GridElasticDeform: {"num_grid_xy": (10, 10), "magnitude": 10}, + A.PadIfNeeded3D: {"min_zyx": (300, 200, 400), "pad_divisor_zyx": (10, 10, 10), "position": "center", "fill": 10, "fill_mask": 20}, }, ), ) @@ -1230,26 +1239,47 @@ def test_non_contiguous_input_with_compose(augmentation_cls, params, bboxes): assert not image.flags["C_CONTIGUOUS"] assert not mask.flags["C_CONTIGUOUS"] + transforms3d = {A.PadIfNeeded3D} + if augmentation_cls == A.RandomCropNearBBox: # requires "cropping_bbox" arg aug = A.Compose([augmentation_cls(p=1, **params)]) - transformed = aug(image=image, mask=mask, cropping_bbox=bboxes[0]) + + data = { + "image": image, + "mask": mask, + "cropping_bbox": bboxes[0], + } elif augmentation_cls in [A.RandomSizedBBoxSafeCrop, A.BBoxSafeRandomCrop]: # requires "bboxes" arg aug = A.Compose([augmentation_cls(p=1, **params)], bbox_params=A.BboxParams(format="pascal_voc")) - transformed = aug(image=image, mask=mask, bboxes=bboxes) + data = { + "image": image, + "mask": mask, + "bboxes": bboxes, + } elif augmentation_cls == A.TextImage: aug = A.Compose([augmentation_cls(p=1, **params)], bbox_params=A.BboxParams(format="pascal_voc")) - transformed = aug( - image=image, - mask=mask, - bboxes=bboxes, - textimage_metadata={"text": "Hello, world!", "bbox": (0.1, 0.1, 0.9, 0.2)}, - ) + data = { + "image": image, + "mask": mask, + "bboxes": bboxes, + "textimage_metadata": {"text": "Hello, world!", "bbox": (0.1, 0.1, 0.9, 0.2)}, + } elif augmentation_cls == A.OverlayElements: # requires "metadata" arg aug = A.Compose([augmentation_cls(p=1, **params)]) - transformed = aug(image=image, overlay_metadata=[], mask=mask) + data = { + "image": image, + "mask": mask, + "overlay_metadata": [], + } + elif augmentation_cls in transforms3d: + aug = A.Compose([augmentation_cls(p=1, **params)], p=1) + data = { + "images": np.stack([image] * 2), + "masks": np.stack([mask] * 2), + } else: # standard args: image and mask if augmentation_cls == A.FromFloat: @@ -1260,16 +1290,25 @@ def test_non_contiguous_input_with_compose(augmentation_cls, params, bboxes): # requires single channel mask mask = mask[:, :, 0] - aug = augmentation_cls(p=1, **params) - transformed = aug(image=image, mask=mask) + aug = A.Compose([augmentation_cls(p=1, **params)], p=1) + data = { + "image": image, + "mask": mask, + } + transformed = aug(**data) + + if augmentation_cls in transforms3d: + assert transformed["masks"].flags["C_CONTIGUOUS"], f"{augmentation_cls.__name__} did not return a C_CONTIGUOUS masks" + assert transformed["images"].flags["C_CONTIGUOUS"], f"{augmentation_cls.__name__} did not return a C_CONTIGUOUS images" - assert transformed["image"].flags["C_CONTIGUOUS"] + else: + assert transformed["image"].flags["C_CONTIGUOUS"], f"{augmentation_cls.__name__} did not return a C_CONTIGUOUS image" - # Check if the 
augmentation is not an ImageOnlyTransform and mask is in the output - if not issubclass(augmentation_cls, ImageOnlyTransform) and "mask" in transformed: - assert transformed["mask"].flags[ - "C_CONTIGUOUS" - ], f"{augmentation_cls.__name__} did not return a C_CONTIGUOUS mask" + # Check if the augmentation is not an ImageOnlyTransform and mask is in the output + if not issubclass(augmentation_cls, ImageOnlyTransform) and "mask" in transformed: + assert transformed["mask"].flags[ + "C_CONTIGUOUS" + ], f"{augmentation_cls.__name__} did not return a C_CONTIGUOUS mask" @pytest.mark.parametrize( @@ -1305,6 +1344,7 @@ def test_non_contiguous_input_with_compose(augmentation_cls, params, bboxes): "read_fn": lambda x: x, "transform_type": "standard", }, + A.PadIfNeeded3D: {"min_zyx": (300, 200, 400), "pad_divisor_zyx": (10, 10, 10), "position": "center", "fill": 10, "fill_mask": 20}, }, except_augmentations={ A.FDA, @@ -1332,11 +1372,24 @@ def test_non_contiguous_input_with_compose(augmentation_cls, params, bboxes): def test_masks_as_target(augmentation_cls, params, masks): image = SQUARE_UINT8_IMAGE + transforms3d = {A.PadIfNeeded3D} + + data = { + "image": image, + "masks": masks, + } + + if augmentation_cls in transforms3d: + data = { + "images": np.stack([image] * 2), + "masks": masks, + } + aug = A.Compose( [augmentation_cls(p=1, **params)], ) - transformed = aug(image=image, masks=masks) + transformed = aug(**data) np.testing.assert_array_equal(transformed["masks"][0], transformed["masks"][1]) diff --git a/tests/test_crop.py b/tests/test_crop.py index 2b55dcd75..e9ea8c949 100644 --- a/tests/test_crop.py +++ b/tests/test_crop.py @@ -167,8 +167,8 @@ def test_pad_position_equivalence( crop_cls( **crop_params, pad_if_needed=True, - pad_mode=pad_mode, - pad_cval=0, + border_mode=pad_mode, + fill=0, pad_position=pad_position, ) ], keypoint_params=A.KeypointParams(format="xyas"), bbox_params=A.BboxParams(format="pascal_voc")) @@ -179,7 +179,7 @@ def test_pad_position_equivalence( min_height=crop_params["height"], min_width=crop_params["width"], border_mode=pad_mode, - value=0, + fill=0, position=pad_position, ), crop_cls( diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 359c07b85..8b4be2df4 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -276,13 +276,18 @@ def test_to_tensor_v2_images_masks(): transformed = transform( image=image, mask=mask, - masks=[mask] * 2, - images=[image] * 2 + masks=np.stack([mask] * 2), # Now passing stacked numpy array + images=np.stack([image] * 2) # Stacked numpy array ) - # Check all outputs are torch.Tensor - for key in ['image', 'mask']: - assert isinstance(transformed[key], torch.Tensor) - - for key in ['masks', 'images']: - assert all(isinstance(t, torch.Tensor) for t in transformed[key]) + # Check outputs are torch.Tensor + assert isinstance(transformed["image"], torch.Tensor) + assert isinstance(transformed["mask"], torch.Tensor) + assert isinstance(transformed["masks"], torch.Tensor) + assert isinstance(transformed["images"], torch.Tensor) # Now checking single tensor + + # Check shapes + assert transformed["image"].shape == (3, 100, 100) # (C, H, W) + assert transformed["mask"].shape == (100, 100) # (H, W) + assert transformed["masks"].shape == (2, 100, 100) # (N, H, W) + assert transformed["images"].shape == (2, 3, 100, 100) # (N, C, H, W) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 0813a971a..c4175f5b1 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -23,6 +23,7 @@ 
from .utils import ( OpenMock, check_all_augs_exists, + get_2d_transforms, get_image_only_transforms, get_transforms, ) @@ -129,6 +130,7 @@ def test_augmentations_serialization_with_custom_parameters( mask = image[:, :, 0].copy() aug = augmentation_cls(p=p, **params) aug.set_random_seed(seed) + transforms3d = {A.PadIfNeeded3D} serialized_aug = A.to_dict(aug) deserialized_aug = A.from_dict(serialized_aug) @@ -144,11 +146,19 @@ def test_augmentations_serialization_with_custom_parameters( data["cropping_bbox"] = [10, 20, 40, 50] elif augmentation_cls == A.TextImage: data["textimage_metadata"] = [] + elif augmentation_cls in transforms3d: + data["images"] = np.array([image] * 10) + data["masks"] = np.array([mask] * 10) aug_data = aug(**data) deserialized_aug_data = deserialized_aug(**data) - np.testing.assert_array_equal(aug_data["image"], deserialized_aug_data["image"]) - np.testing.assert_array_equal(aug_data["mask"], deserialized_aug_data["mask"]) + + if augmentation_cls not in transforms3d: + np.testing.assert_array_equal(aug_data["image"], deserialized_aug_data["image"]) + np.testing.assert_array_equal(aug_data["mask"], deserialized_aug_data["mask"]) + else: + np.testing.assert_array_equal(aug_data["images"], deserialized_aug_data["images"]) + np.testing.assert_array_equal(aug_data["masks"], deserialized_aug_data["masks"]) @pytest.mark.parametrize("image", UINT8_IMAGES) @@ -167,6 +177,7 @@ def test_augmentations_serialization_to_file_with_custom_parameters( image, data_format, ): + transforms3d = {A.PadIfNeeded3D} mask = image[:, :, 0].copy() with patch("builtins.open", OpenMock()): aug = augmentation_cls(p=p, **params) @@ -188,16 +199,24 @@ def test_augmentations_serialization_to_file_with_custom_parameters( data["cropping_bbox"] = [10, 20, 40, 50] elif augmentation_cls == A.TextImage: data["textimage_metadata"] = [] + elif augmentation_cls in transforms3d: + data["images"] = np.array([image] * 10) + data["masks"] = np.array([mask] * 10) aug_data = aug(**data) deserialized_aug_data = deserialized_aug(**data) - np.testing.assert_array_equal(aug_data["image"], deserialized_aug_data["image"]) - np.testing.assert_array_equal(aug_data["mask"], deserialized_aug_data["mask"]) + + if augmentation_cls not in transforms3d: + np.testing.assert_array_equal(aug_data["image"], deserialized_aug_data["image"]) + np.testing.assert_array_equal(aug_data["mask"], deserialized_aug_data["mask"]) + else: + np.testing.assert_array_equal(aug_data["images"], deserialized_aug_data["images"]) + np.testing.assert_array_equal(aug_data["masks"], deserialized_aug_data["masks"]) @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10}, A.CenterCrop: {"height": 10, "width": 10}, @@ -229,6 +248,7 @@ def test_augmentations_serialization_to_file_with_custom_parameters( A.CropNonEmptyMaskIfExists, A.OverlayElements, A.TextImage, + A.PadIfNeeded3D, }, ), ) @@ -263,7 +283,7 @@ def test_augmentations_for_bboxes_serialization( @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10}, A.CenterCrop: {"height": 10, "width": 10}, @@ -355,13 +375,15 @@ def test_augmentations_serialization_with_call_params( ): aug = augmentation_cls(p=p, **params) aug.set_random_seed(seed) - annotations = {"image": image, **call_params} + data = {"image": image, **call_params} serialized_aug = A.to_dict(aug) 
deserialized_aug = A.from_dict(serialized_aug) deserialized_aug.set_random_seed(seed) - aug_data = aug(**annotations) - deserialized_aug_data = deserialized_aug(**annotations) - assert np.array_equal(aug_data["image"], deserialized_aug_data["image"]) + + aug_data = aug(**data) + deserialized_aug_data = deserialized_aug(**data) + + np.testing.assert_array_equal(aug_data["image"], deserialized_aug_data["image"]) @pytest.mark.parametrize("image", FLOAT32_IMAGES) @@ -817,6 +839,7 @@ def test_template_transform_serialization( A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10}, A.TextImage: dict(font_path="./tests/files/LiberationSerif-Bold.ttf"), A.GridElasticDeform: {"num_grid_xy": (10, 10), "magnitude": 10}, + A.PadIfNeeded3D: {"min_zyx": (512, 512, 512)}, }, except_augmentations={ A.FDA, diff --git a/tests/test_targets.py b/tests/test_targets.py index d45a31fba..9de9c876b 100644 --- a/tests/test_targets.py +++ b/tests/test_targets.py @@ -31,20 +31,18 @@ def get_targets_from_methods(cls): def extract_targets_from_docstring(cls): # Access the class's docstring - docstring = cls.__doc__ - if not docstring: + if not (docstring := cls.__doc__): return [] # Return an empty list if there's no docstring # Regular expression to match the 'Targets:' section in the docstring targets_pattern = r"Targets:\s*([^\n]+)" - # Search for the pattern in the docstring - matches = re.search(targets_pattern, docstring) - if matches: + # Search for the pattern in the docstring and extract targets if found + if matches := re.search(targets_pattern, docstring): # Extract the targets string and split it by commas or spaces - targets_str = matches.group(1) - targets = re.split(r"[,\s]+", targets_str) # Split by comma or whitespace + targets = re.split(r"[,\s]+", matches[1]) # Using subscript notation instead of group() return [target.strip() for target in targets if target.strip()] # Remove any extra whitespace + return [] # Return an empty list if the 'Targets:' section isn't found diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 5fe4be074..ee863ea99 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -23,7 +23,7 @@ RECTANGULAR_UINT8_IMAGE, ) -from .utils import get_dual_transforms, get_image_only_transforms, get_transforms +from .utils import get_2d_transforms, get_dual_transforms, get_image_only_transforms, get_transforms def test_transpose_both_image_and_mask(): @@ -193,6 +193,7 @@ def __test_multiprocessing_support_proc(args): A.OverlayElements, A.TextImage, A.MaskDropout, + A.PadIfNeeded3D, }, ), ) @@ -1061,7 +1062,7 @@ def test_safe_rotate(angle: float, targets: dict, expected: dict): @pytest.mark.parametrize( "aug_cls", [ - (lambda rotate: A.Affine(rotate=rotate, p=1, mode=cv2.BORDER_CONSTANT, cval=0)), + (lambda rotate: A.Affine(rotate=rotate, p=1, border_mode=cv2.BORDER_CONSTANT, fill=0)), ( lambda rotate: A.ShiftScaleRotate( shift_limit=(0, 0), @@ -1463,6 +1464,7 @@ def test_coarse_dropout_invalid_input(params): "spatial_mode": "constant", "noise_params": {"ranges": [(-0.2, 0.2), (-0.1, 0.1), (-0.1, 0.1)]}, }, + A.PadIfNeeded3D: {"min_zyx": (300, 200, 400), "pad_divisor_zyx": (10, 10, 10), "position": "center", "fill": 10, "fill_mask": 20}, }, except_augmentations={ A.RandomCropNearBBox, @@ -1479,6 +1481,8 @@ def test_change_image(augmentation_cls, params): """Checks whether resulting image is different from the original one.""" aug = A.Compose([augmentation_cls(p=1, **params)], seed=0) + transforms3d = {A.PadIfNeeded3D} + image = SQUARE_UINT8_IMAGE 
original_image = image.copy() @@ -1502,14 +1506,23 @@ def test_change_image(augmentation_cls, params): mask = np.zeros_like(image)[:, :, 0] mask[:20, :20] = 1 data["mask"] = mask + elif augmentation_cls == A.PadIfNeeded3D: + data["images"] = np.array([image] * 10) + data["masks"] = np.array([image[:, :, 0]] * 10) + + transformed = aug(**data) - np.testing.assert_array_equal(image, original_image) - assert not np.array_equal(aug(**data)["image"], image) + if augmentation_cls not in transforms3d: + np.testing.assert_array_equal(image, original_image) + assert not np.array_equal(transformed["image"], image) + else: + assert not np.array_equal(transformed["images"], data["images"]) + assert not np.array_equal(transformed["masks"], data["masks"]) @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.XYMasking: { "num_masks_x": (1, 3), @@ -1788,7 +1801,7 @@ def test_random_snow_invalid_input(params): @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10}, A.CenterCrop: {"height": 10, "width": 10}, @@ -1885,7 +1898,7 @@ def test_dual_transforms_methods(augmentation_cls, params): ], ) @pytest.mark.parametrize( - "pad_cval", + "fill", [ 0, (0, 255), @@ -1907,10 +1920,11 @@ def test_dual_transforms_methods(augmentation_cls, params): ], ) @pytest.mark.parametrize("image", IMAGES) -def test_crop_and_pad(px, percent, pad_cval, keep_size, sample_independently, image): - pad_cval_mask = 255 if isinstance(pad_cval, list) else pad_cval +def test_crop_and_pad(px, percent, fill, keep_size, sample_independently, image): + fill_mask = 255 if isinstance(fill, list) else fill + interpolation = cv2.INTER_LINEAR - pad_mode = cv2.BORDER_CONSTANT + border_mode = cv2.BORDER_CONSTANT if (px is None) == (percent is None): # Skip the test case where both px and percent are None or both are not None return @@ -1920,9 +1934,9 @@ def test_crop_and_pad(px, percent, pad_cval, keep_size, sample_independently, im A.CropAndPad( px=px, percent=percent, - pad_mode=pad_mode, - pad_cval=pad_cval, - pad_cval_mask=pad_cval_mask, + border_mode=border_mode, + fill=fill, + fill_mask=fill_mask, keep_size=keep_size, sample_independently=sample_independently, interpolation=interpolation, @@ -1963,8 +1977,8 @@ def test_crop_and_pad_percent(percent, expected_shape): A.CropAndPad( px=None, percent=percent, - pad_mode=cv2.BORDER_CONSTANT, - pad_cval=0, + border_mode=cv2.BORDER_CONSTANT, + fill=0, keep_size=False, ) ], @@ -1994,8 +2008,8 @@ def test_crop_and_pad_px_pixel_values(px, expected_shape): A.CropAndPad( px=px, percent=None, - pad_mode=cv2.BORDER_CONSTANT, - pad_cval=0, + border_mode=cv2.BORDER_CONSTANT, + fill=0, keep_size=False, ) ], @@ -2197,6 +2211,7 @@ def test_random_sun_flare_invalid_input(params): "read_fn": lambda x: x, }, A.TextImage: dict(font_path="./tests/files/LiberationSerif-Bold.ttf"), + A.PadIfNeeded3D: {"min_zyx": (300, 200, 400), "pad_divisor_zyx": (10, 10, 10), "position": "center", "fill": 10, "fill_mask": 20}, }, except_augmentations={ A.RandomCropNearBBox, @@ -2233,6 +2248,9 @@ def test_return_nonzero(augmentation_cls, params): mask = np.zeros_like(image)[:, :, 0] mask[:20, :20] = 1 data["mask"] = mask + elif augmentation_cls == A.PadIfNeeded3D: + data["images"] = np.ones((10, 100, 100), dtype=np.uint8) + data["masks"] = np.ones((10, 100, 100), dtype=np.uint8) result = aug(**data) @@ -2389,7 +2407,7 @@ def 
test_mask_dropout_bboxes(remove_invisible, expected_keypoints): @pytest.mark.parametrize( ["augmentation_cls", "params"], - get_transforms( + get_2d_transforms( custom_arguments={ A.Crop: {"y_min": 5, "y_max": 95, "x_min": 7, "x_max": 93}, A.CenterCrop: {"height": 90, "width": 95}, diff --git a/tests/transforms3d/test_pytorch.py b/tests/transforms3d/test_pytorch.py new file mode 100644 index 000000000..65b14fa2e --- /dev/null +++ b/tests/transforms3d/test_pytorch.py @@ -0,0 +1,37 @@ +import numpy as np +import torch +import albumentations as A +from albumentations.pytorch.transforms import ToTensor3D + +def test_to_tensor_3d(): + transform = A.Compose([ToTensor3D(p=1)]) + + # Create sample 3D data + images = np.random.randint(0, 256, (64, 64, 64, 3), dtype=np.uint8) # (D, H, W, C) + masks = np.random.randint(0, 2, (64, 64, 64), dtype=np.uint8) # (D, H, W) + + transformed = transform( + images=images, + masks=masks, + ) + + # Check outputs are torch.Tensor + assert isinstance(transformed["images"], torch.Tensor) + assert isinstance(transformed["masks"], torch.Tensor) + + # Check shapes + assert transformed["images"].shape == (3, 64, 64, 64) # (C, D, H, W) + assert transformed["masks"].shape == (64, 64, 64) # (D, H, W) + + # Test with transpose_mask=True and channeled mask + transform = A.Compose([ToTensor3D(p=1, transpose_mask=True)]) + mask_with_channels = np.random.randint(0, 2, (64, 64, 64, 4), dtype=np.uint8) # (D, H, W, C) + + transformed = transform( + images=images, + masks=mask_with_channels, + ) + + assert isinstance(transformed["images"], torch.Tensor) + assert isinstance(transformed["masks"], torch.Tensor) + assert transformed["masks"].shape == (4, 64, 64, 64) # (C, D, H, W) diff --git a/tests/transforms3d/test_targets.py b/tests/transforms3d/test_targets.py new file mode 100644 index 000000000..9ee9193e8 --- /dev/null +++ b/tests/transforms3d/test_targets.py @@ -0,0 +1,88 @@ +import re + +import numpy as np +import pytest + +import albumentations as A +from albumentations.core.types import Targets +from tests.utils import get_3d_transforms + + +def extract_targets_from_docstring(cls): + # Access the class's docstring + docstring = cls.__doc__ + if not docstring: + return [] # Return an empty list if there's no docstring + + # Regular expression to match the 'Targets:' section in the docstring + targets_pattern = r"Targets:\s*([^\n]+)" + + # Search for the pattern in the docstring + matches = re.search(targets_pattern, docstring) + if matches: + # Extract the targets string and split it by commas or spaces + targets_str = matches.group(1) + targets = re.split(r"[,\s]+", targets_str) # Split by comma or whitespace + return [target.strip() for target in targets if target.strip()] # Remove any extra whitespace + return [] # Return an empty list if the 'Targets:' section isn't found + + +def get_targets_from_methods(cls): + targets = {Targets.IMAGE, Targets.MASK} + + has_images_method = any( + hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.Transform3D, attr, None) + for attr in ["apply_to_images"] + ) + if has_images_method: + targets.add(Targets.IMAGE) + + has_masks_method = any( + hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.Transform3D, attr, None) + for attr in ["apply_to_masks"] + ) + if has_masks_method: + targets.add(Targets.MASK) + + has_bboxes_method = any( + hasattr(cls, attr) and getattr(cls, attr) is not getattr(A.Transform3D, attr, None) + for attr in ["apply_to_bboxes"] + ) + if has_bboxes_method: + targets.add(Targets.BBOXES) + + 
+
+
+TRANSFORM_3D_DUAL_TARGETS = {
+    A.PadIfNeeded3D: (Targets.IMAGE, Targets.MASK),
+}
+
+
+str2target = {
+    "images": Targets.IMAGE,
+    "masks": Targets.MASK,
+}
+
+
+@pytest.mark.parametrize(
+    ["augmentation_cls", "params"],
+    get_3d_transforms(custom_arguments={
+        A.PadIfNeeded3D: {"min_zyx": (4, 250, 230), "position": "center", "fill": 0, "fill_mask": 0},
+    })
+)
+def test_dual(augmentation_cls, params):
+    aug = augmentation_cls(p=1, **params)
+    assert set(aug._targets) == set(TRANSFORM_3D_DUAL_TARGETS.get(augmentation_cls, {Targets.IMAGE, Targets.MASK}))
+    assert set(aug._targets) <= get_targets_from_methods(augmentation_cls)
+
+    targets_from_docstring = {str2target[target] for target in extract_targets_from_docstring(augmentation_cls)}
+
+    assert set(aug._targets) == targets_from_docstring
diff --git a/tests/transforms3d/test_transforms.py b/tests/transforms3d/test_transforms.py
new file mode 100644
index 000000000..57fb85ec4
--- /dev/null
+++ b/tests/transforms3d/test_transforms.py
@@ -0,0 +1,137 @@
+import pytest
+import numpy as np
+import albumentations as A
+import cv2
+
+from tests.conftest import RECTANGULAR_UINT8_IMAGE
+from tests.utils import get_3d_transforms
+
+@pytest.mark.parametrize(
+    ["volume_shape", "min_zyx", "pad_divisor_zyx", "expected_shape"],
+    [
+        # Test no padding needed
+        ((10, 100, 100), (10, 100, 100), None, (10, 100, 100)),
+
+        # Test 2D-like behavior (no z padding)
+        ((10, 100, 100), (10, 128, 128), None, (10, 128, 128)),
+
+        # Test padding in all dimensions
+        ((10, 100, 100), (16, 128, 128), None, (16, 128, 128)),
+
+        # Test divisibility padding
+        ((10, 100, 100), None, (8, 32, 32), (16, 128, 128)),
+
+        # Test mixed min_size and divisibility
+        ((10, 100, 100), (16, 128, 128), (8, 32, 32), (16, 128, 128)),
+    ]
+)
+def test_pad_if_needed_3d_shapes(volume_shape, min_zyx, pad_divisor_zyx, expected_shape):
+    volume = np.random.randint(0, 256, volume_shape, dtype=np.uint8)
+    transform = A.PadIfNeeded3D(
+        min_zyx=min_zyx,
+        pad_divisor_zyx=pad_divisor_zyx,
+        position="center",
+        border_mode=cv2.BORDER_CONSTANT,
+        fill=0,
+        fill_mask=0
+    )
+    transformed = transform(images=volume)
+    assert transformed["images"].shape == expected_shape
+
+@pytest.mark.parametrize("position", ["center", "random"])
+def test_pad_if_needed_3d_positions(position):
+    volume = np.ones((5, 50, 50), dtype=np.uint8)
+    transform = A.PadIfNeeded3D(
+        min_zyx=(10, 100, 100),
+        position=position,
+        border_mode=cv2.BORDER_CONSTANT,
+        fill=0,
+        fill_mask=0
+    )
+    transformed = transform(images=volume)
+    # Check that the original volume is preserved somewhere in the padded volume
+    assert np.any(transformed["images"] == 1)
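+
+# Divisor-only sketch, with p left at its default (the shape tests above build
+# the transform without p and assert deterministic output, so the default is
+# taken to be 1.0): each dimension rounds up to the next multiple.
+def test_pad_if_needed_3d_divisor_rounding():
+    volume = np.zeros((5, 50, 50), dtype=np.uint8)
+    transform = A.PadIfNeeded3D(
+        min_zyx=None,
+        pad_divisor_zyx=(4, 32, 32),
+        position="center",
+        border_mode=cv2.BORDER_CONSTANT,
+        fill=0,
+        fill_mask=0
+    )
+    # 5 -> 8 (next multiple of 4), 50 -> 64 (next multiple of 32)
+    assert transform(images=volume)["images"].shape == (8, 64, 64)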
+
+def test_pad_if_needed_3d_2d_equivalence():
+    """Test that PadIfNeeded3D behaves like PadIfNeeded when no z-padding is needed."""
+    # Create a volume with multiple identical slices
+    slice_2d = np.random.randint(0, 256, (100, 100), dtype=np.uint8)
+    volume_3d = np.stack([slice_2d] * 10)
+
+    # Apply 3D padding with no z-axis changes
+    transform_3d = A.PadIfNeeded3D(
+        min_zyx=(10, 128, 128),
+        position="center",
+        border_mode=cv2.BORDER_CONSTANT,
+        fill=0,
+        fill_mask=0
+    )
+    transformed_3d = transform_3d(images=volume_3d)
+
+    # Apply 2D padding to a single slice
+    transform_2d = A.PadIfNeeded(
+        min_height=128,
+        min_width=128,
+        position="center",
+        border_mode=cv2.BORDER_CONSTANT,
+        value=0,
+        mask_value=0
+    )
+    transformed_2d = transform_2d(image=slice_2d)
+
+    # Compare each slice of 3D result with 2D result
+    for slice_idx in range(10):
+        np.testing.assert_array_equal(
+            transformed_3d["images"][slice_idx],
+            transformed_2d["image"]
+        )
+
+def test_pad_if_needed_3d_fill_values():
+    volume = np.zeros((5, 50, 50), dtype=np.uint8)
+    mask = np.ones((5, 50, 50), dtype=np.uint8)
+
+    transform = A.PadIfNeeded3D(
+        min_zyx=(10, 100, 100),
+        position="center",
+        border_mode=cv2.BORDER_CONSTANT,
+        fill=255,
+        fill_mask=128
+    )
+
+    transformed = transform(images=volume, masks=mask)
+
+    # Check fill values in padded regions
+    assert np.all(transformed["images"][:, :25, :] == 255)  # top padding
+    assert np.all(transformed["masks"][:, :25, :] == 128)  # top padding in mask
+
+@pytest.mark.parametrize(
+    ["augmentation_cls", "params"],
+    get_3d_transforms(
+        custom_arguments={
+            A.PadIfNeeded3D: {"min_zyx": (4, 250, 230), "position": "center", "fill": 0, "fill_mask": 0},
+        },
+    ),
+)
+def test_augmentations_match_uint8_float32(augmentation_cls, params):
+    image_uint8 = RECTANGULAR_UINT8_IMAGE
+    image_float32 = image_uint8 / 255.0
+
+    transform = A.Compose([augmentation_cls(p=1, **params)], seed=42)
+
+    images = np.stack([image_uint8, image_uint8])
+
+    data = {"images": images}
+
+    transformed_uint8 = transform(**data)["images"]
+
+    data["images"] = np.stack([image_float32, image_float32])
+
+    transform.set_random_seed(42)
+    transformed_float32 = transform(**data)["images"]
+
+    np.testing.assert_array_almost_equal(transformed_uint8 / 255.0, transformed_float32, decimal=2)
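+
+# Sketch of the channel handling exercised implicitly by
+# test_augmentations_match_uint8_float32 above (which stacks (H, W, 3) images
+# into a (D, H, W, C) batch): only the three spatial dimensions are padded, so
+# the channel axis is assumed to pass through unchanged.
+def test_pad_if_needed_3d_keeps_channels():
+    volume = np.random.randint(0, 256, (2, 100, 100, 3), dtype=np.uint8)
+    transform = A.Compose([A.PadIfNeeded3D(min_zyx=(4, 128, 128), position="center", fill=0, fill_mask=0, p=1)])
+    assert transform(images=volume)["images"].shape == (4, 128, 128, 3)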
diff --git a/tests/utils.py b/tests/utils.py
index fd2a1f3b2..15046b78b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -90,13 +90,28 @@ def get_filtered_transforms(
     base_classes,
     custom_arguments=None,
     except_augmentations=None,
+    exclude_base_classes=None,
 ):
     custom_arguments = custom_arguments or {}
     except_augmentations = except_augmentations or set()
+    exclude_base_classes = exclude_base_classes or ()
 
     result = []
     for cls in get_all_valid_transforms():
-        if not issubclass(cls, base_classes) or any(cls == i for i in base_classes) or cls in except_augmentations:
+        # Skip if class is in except_augmentations
+        if cls in except_augmentations:
+            continue
+
+        # Skip if class is one of the base classes
+        if any(cls == i for i in base_classes):
+            continue
+
+        # Skip if class inherits from any excluded base classes
+        if exclude_base_classes and issubclass(cls, exclude_base_classes):
+            continue
+
+        # Check if class inherits from any of the required base classes
+        if not issubclass(cls, base_classes):
             continue
 
         result.append((cls, custom_arguments.get(cls, {})))
@@ -114,19 +129,36 @@ def get_dual_transforms(
     custom_arguments: dict[type[albumentations.DualTransform], dict] | None = None,
     except_augmentations: set[type[albumentations.DualTransform]] | None = None,
 ) -> list[tuple[type, dict]]:
-    return get_filtered_transforms((albumentations.DualTransform,), custom_arguments, except_augmentations)
+    """Get all 2D dual transforms, excluding 3D transforms."""
+    return get_filtered_transforms(
+        base_classes=(albumentations.DualTransform,),
+        custom_arguments=custom_arguments,
+        except_augmentations=except_augmentations,
+        exclude_base_classes=(albumentations.Transform3D,),
+    )
 
 
 def get_transforms(
     custom_arguments: dict[type[albumentations.BasicTransform], dict] | None = None,
     except_augmentations: set[type[albumentations.BasicTransform]] | None = None,
 ) -> list[tuple[type, dict]]:
+    """Get all transforms (2D and 3D)."""
     return get_filtered_transforms(
-        (albumentations.ImageOnlyTransform, albumentations.DualTransform),
-        custom_arguments,
-        except_augmentations,
+        base_classes=(albumentations.ImageOnlyTransform, albumentations.DualTransform, albumentations.Transform3D),
+        custom_arguments=custom_arguments,
+        except_augmentations=except_augmentations,
     )
 
+
+def get_2d_transforms(
+    custom_arguments: dict[type[albumentations.BasicTransform], dict] | None = None,
+    except_augmentations: set[type[albumentations.BasicTransform]] | None = None,
+) -> list[tuple[type, dict]]:
+    """Get all 2D transforms (both ImageOnly and Dual transforms), excluding 3D transforms."""
+    return get_filtered_transforms(
+        base_classes=(albumentations.ImageOnlyTransform, albumentations.DualTransform),
+        custom_arguments=custom_arguments,
+        except_augmentations=except_augmentations,
+        exclude_base_classes=(albumentations.Transform3D,),  # Exclude Transform3D and its children
+    )
+
 
 def check_all_augs_exists(
     augmentations: list[list],
@@ -145,3 +177,15 @@ def check_all_augs_exists(
         raise ValueError(f"These augmentations do not exist in augmentations and except_augmentations: {not_existed}")
 
     return augmentations
+
+
+def get_3d_transforms(
+    custom_arguments: dict[type[albumentations.Transform3D], dict] | None = None,
+    except_augmentations: set[type[albumentations.Transform3D]] | None = None,
+) -> list[tuple[type, dict]]:
+    """Get all 3D transforms."""
+    return get_filtered_transforms(
+        base_classes=(albumentations.Transform3D,),
+        custom_arguments=custom_arguments,
+        except_augmentations=except_augmentations,
+    )
diff --git a/tools/make_transforms_docs.py b/tools/make_transforms_docs.py
index 1eb28795a..dba483316 100644
--- a/tools/make_transforms_docs.py
+++ b/tools/make_transforms_docs.py
@@ -13,6 +13,7 @@
     "BasicTransform",
     "DualTransform",
     "ImageOnlyTransform",
+    "Transform3D",
 }
 
 
@@ -62,18 +63,24 @@ def get_image_only_transforms_info():
     image_only_info = {}
     members = inspect.getmembers(albumentations)
     for name, cls in members:
-        if inspect.isclass(cls) and issubclass(cls, albumentations.ImageOnlyTransform) and name not in IGNORED_CLASSES:
-            if not is_deprecated(cls):
-                image_only_info[name] = {
-                    "docs_link": make_augmentation_docs_link(cls)
-                }
+        if (inspect.isclass(cls) and
+            issubclass(cls, albumentations.ImageOnlyTransform) and
+            not issubclass(cls, albumentations.Transform3D) and
+            name not in IGNORED_CLASSES) and not 
is_deprecated(cls): + transforms_3d_info[name] = { + "targets": cls._targets, + "docs_link": make_augmentation_docs_link(cls) + } + return transforms_3d_info + + def make_transforms_targets_table(transforms_info, header): rows = [header] for transform, info in sorted(transforms_info.items(), key=lambda kv: kv[0]): @@ -116,13 +137,14 @@ def make_transforms_targets_links(transforms_info): ) -def check_docs(filepath, image_only_transforms_links, dual_transforms_table) -> None: +def check_docs(filepath, image_only_transforms_links, dual_transforms_table, transforms_3d_table) -> None: with open(filepath, encoding="utf8") as f: text = f.read() outdated_docs = set() image_only_lines_not_in_text = [] dual_lines_not_in_text = [] + transforms_3d_lines_not_in_text = [] for line in image_only_transforms_links.split("\n"): if line not in text: @@ -134,6 +156,11 @@ def check_docs(filepath, image_only_transforms_links, dual_transforms_table) -> dual_lines_not_in_text.append(line) outdated_docs.update(["Spatial-level"]) + for line in transforms_3d_table.split("\n"): + if line not in text: + transforms_3d_lines_not_in_text.append(line) + outdated_docs.update(["3D"]) + if outdated_docs: msg = ( "Docs for the following transform types are outdated: {outdated_docs_headers}. " @@ -142,18 +169,20 @@ def check_docs(filepath, image_only_transforms_links, dual_transforms_table) -> "# Pixel-level transforms lines not in file:\n" "{image_only_lines}\n" "# Spatial-level transforms lines not in file:\n" - "{dual_lines}\n".format( + "{dual_lines}\n" + "# 3D transforms lines not in file:\n" + "{transforms_3d_lines}\n".format( outdated_docs_headers=", ".join(outdated_docs), py_file=Path(os.path.realpath(__file__)).name, filename=os.path.basename(filepath), image_only_lines="\n".join(image_only_lines_not_in_text), dual_lines="\n".join(dual_lines_not_in_text), + transforms_3d_lines="\n".join(transforms_3d_lines_not_in_text), ) ) raise ValueError(msg) - def main() -> None: args = parse_args() command = args.command @@ -162,12 +191,15 @@ def main() -> None: image_only_transforms = get_image_only_transforms_info() dual_transforms = get_dual_transforms_info() + transforms_3d = get_3d_transforms_info() image_only_transforms_links = make_transforms_targets_links(image_only_transforms) - dual_transforms_table = make_transforms_targets_table( dual_transforms, header=["Transform"] + [target.value for target in Targets] ) + transforms_3d_table = make_transforms_targets_table( + transforms_3d, header=["Transform"] + [target.value for target in [Targets.IMAGE, Targets.MASK]] + ) if command == "make": print("===== COPY THIS TABLE TO README.MD BELOW ### Pixel-level transforms =====") @@ -177,9 +209,17 @@ def main() -> None: print("===== COPY THIS TABLE TO README.MD BELOW ### Spatial-level transforms =====") print(dual_transforms_table) print("===== END OF COPY =====") - + print() + print("===== COPY THIS TABLE TO README.MD BELOW ### 3D transforms =====") + print(transforms_3d_table) + print("===== END OF COPY =====") else: - check_docs(args.filepath, image_only_transforms_links, dual_transforms_table) + check_docs( + args.filepath, + image_only_transforms_links, + dual_transforms_table, + transforms_3d_table + ) if __name__ == "__main__":