diff --git a/README.md b/README.md index 70302cdf0..eb7e76b05 100644 --- a/README.md +++ b/README.md @@ -198,8 +198,8 @@ Pixel-level transforms will change just an input image and will leave any additi - [FDA](https://explore.albumentations.ai/transform/FDA) - [FancyPCA](https://explore.albumentations.ai/transform/FancyPCA) - [FromFloat](https://explore.albumentations.ai/transform/FromFloat) -- [GaussNoise](https://explore.albumentations.ai/transform/GaussNoise) - [GaussianBlur](https://explore.albumentations.ai/transform/GaussianBlur) +- [GaussianNoise](https://explore.albumentations.ai/transform/GaussianNoise) - [GlassBlur](https://explore.albumentations.ai/transform/GlassBlur) - [HistogramMatching](https://explore.albumentations.ai/transform/HistogramMatching) - [HueSaturationValue](https://explore.albumentations.ai/transform/HueSaturationValue) @@ -230,7 +230,9 @@ Pixel-level transforms will change just an input image and will leave any additi - [RandomJPEG](https://explore.albumentations.ai/transform/RandomJPEG) - [RandomMedianBlur](https://explore.albumentations.ai/transform/RandomMedianBlur) - [RandomPlanckianJitter](https://explore.albumentations.ai/transform/RandomPlanckianJitter) +- [RandomPosterize](https://explore.albumentations.ai/transform/RandomPosterize) - [RandomRain](https://explore.albumentations.ai/transform/RandomRain) +- [RandomSaturation](https://explore.albumentations.ai/transform/RandomSaturation) - [RandomShadow](https://explore.albumentations.ai/transform/RandomShadow) - [RandomSnow](https://explore.albumentations.ai/transform/RandomSnow) - [RandomSolarize](https://explore.albumentations.ai/transform/RandomSolarize) diff --git a/albumentations/augmentations/tk/transform.py b/albumentations/augmentations/tk/transform.py index e4b783d77..0241debaa 100644 --- a/albumentations/augmentations/tk/transform.py +++ b/albumentations/augmentations/tk/transform.py @@ -17,14 +17,23 @@ CLAHE, ColorJitter, Equalize, + GaussNoise, ImageCompression, InvertImg, PlanckianJitter, + Posterize, RandomBrightnessContrast, Solarize, ToGray, ) -from albumentations.core.pydantic import InterpolationType, check_0plus, check_01, check_1plus, nondecreasing +from albumentations.core.pydantic import ( + InterpolationType, + check_0plus, + check_01, + check_1plus, + check_range_bounds, + nondecreasing, +) from albumentations.core.transforms_interface import BaseTransformInitSchema from albumentations.core.types import PAIR, ColorType, ScaleFloatType, ScaleIntType, Targets @@ -48,6 +57,9 @@ "RandomPlanckianJitter", "RandomMedianBlur", "RandomSolarize", + "RandomPosterize", + "RandomSaturation", + "GaussianNoise", ] @@ -1378,3 +1390,203 @@ def __init__( def get_transform_init_args_names(self) -> tuple[str, ...]: return ("thresholds",) + + +class RandomPosterize(Posterize): + """Reduce the number of bits for each color channel. + + This transform is an alias for Posterize, provided for compatibility with + Kornia API. For new code, it is recommended to use albumentations.Posterize directly. + + Args: + num_bits (tuple[int, int]): Range for number of bits to keep for each channel. + Values should be in range [0, 8] for uint8 images. + Default: (3, 3). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + + Number of channels: + Any + + Note: + This transform is a direct alias for Posterize with identical functionality. + For new projects, it is recommended to use Posterize directly as it + provides a more consistent interface within the Albumentations ecosystem. + + For float32 images: + 1. Image is converted to uint8 (multiplied by 255 and clipped) + 2. Posterization is applied + 3. Image is converted back to float32 (divided by 255) + + Example: + >>> # RandomPosterize way (Kornia compatibility) + >>> transform = A.RandomPosterize(num_bits=(3, 3)) # Fixed 3 bits per channel + >>> transform = A.RandomPosterize(num_bits=(3, 5)) # Random from 3 to 5 bits + >>> # Preferred Posterize way + >>> transform = A.Posterize(bits=(3, 3)) + >>> transform = A.Posterize(bits=(3, 5)) + + References: + - Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomPosterize + """ + + class InitSchema(BaseTransformInitSchema): + num_bits: Annotated[tuple[int, int], AfterValidator(check_range_bounds(0, 8)), AfterValidator(nondecreasing)] + + def __init__( + self, + num_bits: tuple[int, int] = (3, 3), + always_apply: bool | None = None, + p: float = 0.5, + ): + warn( + "RandomPosterize is an alias for Posterize transform. " + "Consider using Posterize directly from albumentations.Posterize.", + UserWarning, + stacklevel=2, + ) + + super().__init__( + num_bits=num_bits, + p=p, + ) + + def get_transform_init_args_names(self) -> tuple[str, ...]: + return ("num_bits",) + + +class RandomSaturation(ColorJitter): + """Randomly change the saturation of an RGB image. + + This is a specialized version of ColorJitter that only adjusts saturation. + + Args: + saturation (tuple[float, float]): Range for the saturation factor. + Values should be non-negative numbers. + A saturation factor of 0 will result in a grayscale image + A saturation factor of 1 will give the original image + A saturation factor of 2 will enhance the saturation by a factor of 2 + Default: (1.0, 1.0) + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + + Number of channels: + 1, 3 + + Note: + - This transform can only be applied to RGB/BGR images. + - The saturation adjustment is done by converting to HSV color space, + modifying the S channel, and converting back to RGB. + + Example: + >>> import albumentations as A + >>> transform = A.RandomSaturation(saturation_range=(0.5, 1.5), p=0.5) + >>> # Reduce saturation by 50% to increase by 50% + >>> + >>> transform = A.RandomSaturation(saturation_range=(0.0, 1.0), p=0.5) + >>> # Randomly convert to grayscale with 50% probability + """ + + class InitSchema(BaseTransformInitSchema): + saturation: Annotated[tuple[float, float], AfterValidator(check_0plus), AfterValidator(nondecreasing)] + + def __init__( + self, + saturation: tuple[float, float] = (1.0, 1.0), + always_apply: bool | None = None, + p: float = 0.5, + ): + super().__init__( + brightness=(1.0, 1.0), # No brightness change + contrast=(1.0, 1.0), # No contrast change + saturation=saturation, + hue=(0.0, 0.0), # No hue change + p=p, + ) + self.saturation = saturation + + def get_transform_init_args_names(self) -> tuple[str]: + return ("saturation",) + + +class GaussianNoise(GaussNoise): + """Add Gaussian noise to the input image. + + A specialized version of GaussNoise that follows torchvision's API. + + Args: + mean (float): Mean of the Gaussian noise as a fraction + of the maximum value (255 for uint8 images or 1.0 for float images). + Value should be in range [0, 1]. Default: 0.0. + sigma (float): Standard deviation of the Gaussian noise as a fraction + of the maximum value (255 for uint8 images or 1.0 for float images). + Value should be in range [0, 1]. Default: 0.1. + p (float): Probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + + Note: + - The noise parameters (sigma and mean) are normalized to [0, 1] range: + * For uint8 images, they are multiplied by 255 + * For float32 images, they are used directly + - Unlike GaussNoise, this transform: + * Uses fixed sigma and mean values (no ranges) + * Always applies same noise to all channels + * Does not support noise_scale_factor optimization + - For more flexibility, use GaussNoise which allows sampling both std and mean + from ranges and supports per-channel noise + + Example: + >>> import albumentations as A + >>> # Add noise with sigma=0.1 (10% of the image range) + >>> transform = A.GaussianNoise(mean=0.0, sigma=0.1, p=1.0) + + References: + - torchvision: https://pytorch.org/vision/master/generated/torchvision.transforms.v2.GaussianNoise.html + - kornia: https://kornia.readthedocs.io/en/latest/augmentation.module.html#kornia.augmentation.RandomGaussianNoise + """ + + class InitSchema(BaseTransformInitSchema): + mean: float = Field(ge=-1, le=1) + sigma: float = Field(ge=0, le=1) + + def __init__( + self, + mean: float = 0.0, + sigma: float = 0.1, + always_apply: bool | None = None, + p: float = 0.5, + ): + warn( + "GaussianNoise is a specialized version of GaussNoise that follows torchvision's API. " + "Consider using GaussNoise directly from albumentations.GaussNoise.", + UserWarning, + stacklevel=2, + ) + + super().__init__( + std_range=(sigma, sigma), # Fixed sigma value + mean_range=(mean, mean), # Fixed mean value + per_channel=False, # Always apply same noise to all channels + noise_scale_factor=1.0, # No noise scale optimization + p=p, + ) + self.mean = mean + self.sigma = sigma + + def get_transform_init_args_names(self) -> tuple[str, ...]: + return "mean", "sigma" diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 8f471af3c..78f4b8e4c 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -47,6 +47,7 @@ check_0plus, check_01, check_1plus, + check_range_bounds, nondecreasing, ) from albumentations.core.transforms_interface import ( @@ -2018,7 +2019,7 @@ def get_params(self) -> dict[str, Any]: num_bits = self.num_bits return {"num_bits": self.py_random.randint(int(num_bits[0]), int(num_bits[1]))} # type: ignore[arg-type] - def get_transform_init_args_names(self) -> tuple[str]: + def get_transform_init_args_names(self) -> tuple[str, ...]: return ("num_bits",) @@ -2367,9 +2368,16 @@ class GaussNoise(ImageOnlyTransform): """Apply Gaussian noise to the input image. Args: - var_limit (tuple[float, float] | float): Variance range for noise. If var_limit is a single float value, - the range will be (0, var_limit). Default: (10.0, 50.0). - mean (float): Mean of the noise. Default: 0. + std_range (tuple[float, float]): Range for noise standard deviation as a fraction + of the maximum value (255 for uint8 images or 1.0 for float images). + Values should be in range [0, 1]. Default: (0.2, 0.44). + mean_range (tuple[float, float]): Range for noise mean as a fraction + of the maximum value (255 for uint8 images or 1.0 for float images). + Values should be in range [-1, 1]. Default: (0.0, 0.0). + var_limit (tuple[float, float] | float): [Deprecated] Variance range for noise. + If var_limit is a single float value, the range will be (0, var_limit). + Default: (10.0, 50.0). + mean (float): [Deprecated] Mean of the noise. Default: 0. per_channel (bool): If True, noise will be sampled for each channel independently. Otherwise, the noise will be sampled once for all channels. Default: True. noise_scale_factor (float): Scaling factor for noise generation. Value should be in the range (0, 1]. @@ -2386,77 +2394,114 @@ class GaussNoise(ImageOnlyTransform): Number of channels: Any - Returns: - numpy.ndarray: Image with applied Gaussian noise. - Note: - - The noise is generated in the same range as the input image. - - For uint8 input images, the noise is generated in the range [0, 255]. - - For float32 input images, the noise is generated in the range [0, 1]. - - The resulting image is clipped to keep its values in the input range. - - Setting per_channel=False is faster but applies the same noise to all channels. - - The noise_scale_factor parameter allows for a trade-off between transform speed and noise granularity. + - The noise parameters (std_range and mean_range) are normalized to [0, 1] range: + * For uint8 images, they are multiplied by 255 + * For float32 images, they are used directly + - The behavior differs between old and new parameters: + * When using var_limit (deprecated): samples variance uniformly and takes sqrt to get std dev + * When using std_range: samples standard deviation directly (aligned with torchvision/kornia) + - Setting per_channel=False is faster but applies the same noise to all channels + - The noise_scale_factor parameter allows for a trade-off between transform speed and noise granularity Examples: >>> import numpy as np >>> import albumentations as A >>> image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) >>> - >>> # Apply Gaussian noise with default parameters - >>> transform = A.GaussNoise(p=1.0) + >>> # Apply Gaussian noise with normalized std_range + >>> transform = A.GaussNoise(std_range=(0.1, 0.2), p=1.0) # 10-20% of max value >>> noisy_image = transform(image=image)['image'] >>> - >>> # Apply Gaussian noise with custom variance range and mean + >>> # Using deprecated var_limit (will be converted to std_range) >>> transform = A.GaussNoise(var_limit=(50.0, 100.0), mean=10, p=1.0) >>> noisy_image = transform(image=image)['image'] - >>> - >>> # Apply the same noise to all channels - >>> transform = A.GaussNoise(per_channel=False, p=1.0) - >>> noisy_image = transform(image=image)['image'] - >>> - >>> # Apply noise with reduced granularity for faster processing - >>> transform = A.GaussNoise(noise_scale_factor=0.5, p=1.0) - >>> noisy_image = transform(image=image)['image'] - """ class InitSchema(BaseTransformInitSchema): - var_limit: NonNegativeFloatRangeType - mean: float + var_limit: ScaleFloatType | None = Field( + deprecated="var_limit parameter is deprecated. Use std_range instead.", + ) + mean: float | None = Field( + deprecated="mean parameter is deprecated. Use mean_range instead.", + ) + std_range: Annotated[tuple[float, float], AfterValidator(check_01), AfterValidator(nondecreasing)] + mean_range: Annotated[ + tuple[float, float], + AfterValidator(check_range_bounds(-1, 1)), + AfterValidator(nondecreasing), + ] per_channel: bool noise_scale_factor: float = Field(gt=0, le=1) + @model_validator(mode="after") + def check_range(self) -> Self: + if self.var_limit is not None: + self.var_limit = to_tuple(self.var_limit, 0) + if self.var_limit[1] > 1: + # Convert legacy uint8 variance to normalized std dev + self.std_range = (math.sqrt(10 / 255), math.sqrt(50 / 255)) + else: + # Already normalized variance, convert to std dev + self.std_range = (math.sqrt(self.var_limit[0]), math.sqrt(self.var_limit[1])) + if self.mean is not None: + self.mean_range = (0.0, 0.0) + + if self.mean is not None: + if self.mean >= 1: + # Convert legacy uint8 mean to normalized range + self.mean_range = (self.mean / 255, self.mean / 255) + else: + # Already normalized mean + self.mean_range = (self.mean, self.mean) + + return self + def __init__( self, - var_limit: ScaleFloatType = (10.0, 50.0), - mean: float = 0, + var_limit: ScaleFloatType | None = None, + mean: float | None = None, + std_range: tuple[float, float] = (0.2, 0.44), # sqrt(10 / 255), sqrt(50 / 255) + mean_range: tuple[float, float] = (0.0, 0.0), per_channel: bool = True, noise_scale_factor: float = 1, always_apply: bool | None = None, p: float = 0.5, ): super().__init__(p=p, always_apply=always_apply) - self.var_limit = cast(tuple[float, float], var_limit) - self.mean = mean + self.std_range = std_range + self.mean_range = mean_range self.per_channel = per_channel self.noise_scale_factor = noise_scale_factor + self.var_limit = var_limit + def apply(self, img: np.ndarray, gauss: np.ndarray, **params: Any) -> np.ndarray: return fmain.add_noise(img, gauss) def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, float]: image = data["image"] if "image" in data else data["images"][0] - var = self.py_random.uniform(*self.var_limit) - sigma = math.sqrt(var) + max_value = MAX_VALUES_BY_DTYPE[image.dtype] + + if self.var_limit is not None: + # Legacy behavior: sample variance uniformly then take sqrt + var = self.py_random.uniform(self.std_range[0] ** 2, self.std_range[1] ** 2) + sigma = math.sqrt(var) + else: + # New behavior: sample std dev directly (aligned with torchvision/kornia) + sigma = self.py_random.uniform(*self.std_range) + + sigma *= max_value + mean = self.py_random.uniform(*self.mean_range) * max_value if self.per_channel: target_shape = image.shape if self.noise_scale_factor == 1: - gauss = self.random_generator.normal(self.mean, sigma, target_shape) + gauss = self.random_generator.normal(mean, sigma, target_shape) else: gauss = fmain.generate_approx_gaussian_noise( target_shape, - self.mean, + mean, sigma, self.noise_scale_factor, self.random_generator, @@ -2464,11 +2509,11 @@ def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, A else: target_shape = image.shape[:2] if self.noise_scale_factor == 1: - gauss = self.random_generator.normal(self.mean, sigma, target_shape) + gauss = self.random_generator.normal(mean, sigma, target_shape) else: gauss = fmain.generate_approx_gaussian_noise( target_shape, - self.mean, + mean, sigma, self.noise_scale_factor, self.random_generator, @@ -2480,7 +2525,7 @@ def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, A return {"gauss": gauss} def get_transform_init_args_names(self) -> tuple[str, ...]: - return "var_limit", "per_channel", "mean", "noise_scale_factor" + return "std_range", "mean_range", "per_channel", "noise_scale_factor" class ISONoise(ImageOnlyTransform): diff --git a/albumentations/core/pydantic.py b/albumentations/core/pydantic.py index 0da0f3cf3..50f6cd3ec 100644 --- a/albumentations/core/pydantic.py +++ b/albumentations/core/pydantic.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from typing import Annotated, overload import cv2 @@ -144,3 +145,32 @@ def check_01(value: tuple[Number, Number]) -> tuple[Number, Number]: def repeat_if_scalar(value: ScaleType) -> tuple[float, float]: return (value, value) if isinstance(value, (int, float)) else value + + +def check_range_bounds( + min_val: Number, + max_val: Number | None = None, +) -> Callable[[tuple[Number, Number]], tuple[Number, Number]]: + """Validates that both values in a tuple are within specified bounds. + + Args: + min_val: Minimum allowed value (inclusive) + max_val: Maximum allowed value (inclusive). If None, only lower bound is checked. + + Returns: + Validator function that checks if both values in tuple are within bounds. + If max_val is None, only checks that values are >= min_val. + + Raises: + ValueError: If any value in tuple is outside the allowed range + """ + + def validator(value: tuple[Number, Number]) -> tuple[Number, Number]: + if max_val is None: + if not (value[0] >= min_val and value[1] >= min_val): + raise ValueError(f"All values in {value} must be >= {min_val}") + elif not (min_val <= value[0] <= max_val and min_val <= value[1] <= max_val): + raise ValueError(f"All values in {value} must be in range [{min_val}, {max_val}]") + return value + + return validator diff --git a/tests/aug_definitions.py b/tests/aug_definitions.py index 9caf42fff..7c319fd3c 100644 --- a/tests/aug_definitions.py +++ b/tests/aug_definitions.py @@ -22,7 +22,7 @@ [A.MotionBlur, {"blur_limit": 3}], [A.MedianBlur, {"blur_limit": 3}], [A.GaussianBlur, {"blur_limit": 3}], - [A.GaussNoise, {"var_limit": (20, 90), "mean": 10, "per_channel": False}], + [A.GaussNoise, {"std_range": (0.2, 0.44), "mean_range": (0.0, 0.0), "per_channel": False}], [A.CLAHE, {"clip_limit": 2, "tile_grid_size": (12, 12)}], [A.RandomGamma, {"gamma_limit": (10, 90)}], [A.CoarseDropout, {"num_holes_range": (2, 5), "hole_height_range": (3, 4), "hole_width_range": (4, 6)}], @@ -403,4 +403,7 @@ [A.RandomPlanckianJitter, {}], [A.RandomMedianBlur, {}], [A.RandomSolarize, {}], + [A.RandomPosterize, {}], + [A.RandomSaturation, {}], + [A.GaussianNoise, {}], ] diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 0abcb93c4..3583718bd 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -611,6 +611,7 @@ def test_mask_fill_value(augmentation_cls, params): A.RandomGrayscale, A.RandomHue, A.RandomClahe, + A.RandomSaturation, }, ), ) @@ -703,6 +704,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params): A.RandomGrayscale, A.RandomHue, A.RandomClahe, + A.RandomSaturation, }, ), ) @@ -789,6 +791,7 @@ def test_float_multichannel_image_augmentations(augmentation_cls, params): A.RandomGrayscale, A.RandomHue, A.RandomClahe, + A.RandomSaturation, }, ), ) @@ -877,6 +880,7 @@ def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params A.RandomGrayscale, A.RandomHue, A.RandomClahe, + A.RandomSaturation, }, ), ) @@ -1102,7 +1106,6 @@ def test_pad_if_needed_position(params, image_shape): A.TextImage, A.RGBShift, A.HueSaturationValue, - A.GaussNoise, A.ColorJitter, }, ), diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 56e5ff7e5..346ab1246 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -7,7 +7,7 @@ import cv2 import numpy as np import pytest -from albucore import to_float, clip +from albucore import to_float, clip, MAX_VALUES_BY_DTYPE from torchvision import transforms as torch_transforms @@ -1256,6 +1256,7 @@ def test_coarse_dropout_invalid_input(params): A.RandomHue: {"hue": (-0.2, 0.2)}, A.RandomContrast: {"contrast": (0.8, 1.2)}, A.RandomBrightness: {"brightness": (0.8, 1.2)}, + A.RandomSaturation: {"saturation": (0.8, 1.2)}, }, except_augmentations={ A.RandomCropNearBBox, @@ -1327,6 +1328,7 @@ def test_change_image(augmentation_cls, params): A.RandomHue: {"hue": (-0.2, 0.2)}, A.RandomContrast: {"contrast": (0.8, 1.2)}, A.RandomBrightness: {"brightness": (0.8, 1.2)}, + A.RandomSaturation: {"saturation": (0.8, 1.2)}, }, except_augmentations={ A.Crop, @@ -1765,9 +1767,9 @@ def test_random_fog_invalid_input(params): @pytest.mark.parametrize("image", IMAGES + [np.full((100, 100), 128, dtype=np.uint8)]) -@pytest.mark.parametrize("mean", (0, 10, -10)) +@pytest.mark.parametrize("mean", (0, 0.1, -0.1)) def test_gauss_noise(mean, image): - aug = A.GaussNoise(p=1, noise_scale_factor=1.0, mean=mean) + aug = A.GaussNoise(p=1, noise_scale_factor=1.0, mean_range=(mean, mean)) aug.set_random_seed(42) apply_params = aug.get_params_dependent_on_data( @@ -1775,7 +1777,7 @@ def test_gauss_noise(mean, image): data={"image": image}, ) - assert np.abs(mean - apply_params["gauss"].mean()) < 0.5 + assert np.abs(mean - apply_params["gauss"].mean() / MAX_VALUES_BY_DTYPE[image.dtype]) < 0.5 result = A.Compose([aug], seed=42)(image=image) assert not (result["image"] >= image).all()