def get_grid_size(size: int, target_shape: tuple[int, int]) -> int:
    """Return the smallest power of two covering ``size`` and both target dimensions.

    Args:
        size: Requested plasma grid size.
        target_shape: ``(height, width)`` the final pattern has to cover.

    Returns:
        ``2 ** k`` for the smallest integer ``k`` such that
        ``2 ** k >= max(size, height, width)``.
    """
    return 2 ** int(np.ceil(np.log2(max(size, *target_shape))))


def random_offset(current_size: int, total_size: int, roughness: float, random_generator: np.random.Generator) -> float:
    """Draw a zero-centred random displacement for the current subdivision level.

    The draw ``random_generator.random() - 0.5`` is scaled by
    ``(current_size / total_size) ** (roughness / 2)``. Because the ratio is at most 1,
    a larger ``roughness`` makes fine-scale displacements shrink faster.

    Args:
        current_size: Edge length of the cell being subdivided.
        total_size: Edge length of the whole grid.
        roughness: Exponent controlling how quickly displacements decay with scale.
        random_generator: Source of randomness; one value is consumed per call.

    Returns:
        The scaled displacement.
    """
    return (random_generator.random() - 0.5) * (current_size / total_size) ** (roughness / 2)


def initialize_grid(grid_size: int, random_generator: np.random.Generator) -> np.ndarray:
    """Create a zeroed ``(grid_size + 1, grid_size + 1)`` float32 grid with random corners.

    The four corners are seeded in the fixed order top-left, top-right, bottom-left,
    bottom-right, so the random stream is consumed deterministically.
    """
    pattern = np.zeros((grid_size + 1, grid_size + 1), dtype=np.float32)
    for corner in [(0, 0), (0, -1), (-1, 0), (-1, -1)]:
        pattern[corner] = random_generator.random()
    return pattern


def square_step(
    pattern: np.ndarray,
    y: int,
    x: int,
    step: int,
    grid_size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> float:
    """Compute the centre value of one cell during the square step.

    Args:
        pattern: Grid being filled in.
        y: Row of the cell's top-left corner.
        x: Column of the cell's top-left corner.
        step: Edge length of the cell.
        grid_size: Edge length of the whole grid (used to scale the displacement).
        roughness: Displacement decay exponent, see ``random_offset``.
        random_generator: Source of randomness.

    Returns:
        Mean of the four cell corners plus a random displacement.
    """
    corners = [
        pattern[y, x],  # top-left
        pattern[y, x + step],  # top-right
        pattern[y + step, x],  # bottom-left
        pattern[y + step, x + step],  # bottom-right
    ]
    return sum(corners) / 4.0 + random_offset(step, grid_size, roughness, random_generator)


def diamond_step(
    pattern: np.ndarray,
    y: int,
    x: int,
    half: int,
    grid_size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> float:
    """Compute an edge-midpoint value during the diamond step.

    Only neighbours that lie inside the grid contribute, so points on the border are
    averaged over three values instead of four.

    Args:
        pattern: Grid being filled in.
        y: Row of the point to compute.
        x: Column of the point to compute.
        half: Distance from the point to each of its diamond neighbours.
        grid_size: Edge length of the whole grid.
        roughness: Displacement decay exponent, see ``random_offset``.
        random_generator: Source of randomness.

    Returns:
        Mean of the available neighbours plus a random displacement.
    """
    points = []
    if y >= half:
        points.append(pattern[y - half, x])
    if y + half <= grid_size:
        points.append(pattern[y + half, x])
    if x >= half:
        points.append(pattern[y, x - half])
    if x + half <= grid_size:
        points.append(pattern[y, x + half])

    return sum(points) / len(points) + random_offset(half * 2, grid_size, roughness, random_generator)


def _run_diamond_square(
    grid_size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> np.ndarray:
    """Fill a ``(grid_size + 1)``-square grid with un-normalised Diamond-Square noise."""
    pattern = initialize_grid(grid_size, random_generator)

    step_size = grid_size
    while step_size > 1:
        half_step = step_size // 2

        # Square step: centre of every step_size x step_size cell.
        for y in range(0, grid_size, step_size):
            for x in range(0, grid_size, step_size):
                # The displacement must be scaled against the FULL grid size so it
                # shrinks as cells get smaller (passing half_step here would make
                # the scale a constant 2 ** (roughness / 2) at every level).
                pattern[y + half_step, x + half_step] = square_step(
                    pattern,
                    y,
                    x,
                    step_size,
                    grid_size,
                    roughness,
                    random_generator,
                )

        # Diamond step: midpoints of the cell edges, staggered row by row.
        for y in range(0, grid_size + 1, half_step):
            for x in range((y + half_step) % step_size, grid_size + 1, step_size):
                pattern[y, x] = diamond_step(pattern, y, x, half_step, grid_size, roughness, random_generator)

        step_size = half_step

    return pattern


def generate_plasma_pattern(
    target_shape: tuple[int, int],
    size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> np.ndarray:
    """Generate a plasma fractal pattern using the Diamond-Square algorithm.

    A square grid is recursively subdivided; every new point is the mean of its
    neighbours plus a random displacement whose amplitude is
    ``(cell_size / grid_size) ** (roughness / 2)`` times a draw from ``[-0.5, 0.5)``.

    Args:
        target_shape: Final shape (height, width) of the pattern.
        size: Requested grid size. The actual grid edge is the smallest power of two
            that is at least ``max(size, height, width)``.
        roughness: Displacement decay exponent. Larger values make fine-scale
            displacements smaller relative to coarse ones.
        random_generator: NumPy random generator.

    Returns:
        The pattern rescaled to [0, 1] and resized to ``target_shape``. If the raw
        pattern is constant, an all-zero pattern is used instead of dividing by zero.
    """
    grid_size = get_grid_size(size, target_shape)
    pattern = _run_diamond_square(grid_size, roughness, random_generator)

    # Normalize to [0, 1]; guard against a degenerate constant pattern.
    min_pattern = pattern.min()
    value_range = pattern.max() - min_pattern
    pattern = (pattern - min_pattern) / value_range if value_range > 0 else np.zeros_like(pattern)

    return (
        fgeometric.resize(pattern, target_shape, interpolation=cv2.INTER_LINEAR)
        if pattern.shape != target_shape
        else pattern
    )
@clipped
def apply_plasma_brightness_contrast(
    img: np.ndarray,
    brightness_factor: float,
    contrast_factor: float,
    plasma_pattern: np.ndarray,
) -> np.ndarray:
    """Modulate brightness and contrast of an image with a plasma pattern.

    Both adjustments are spatially varying and are applied one after the other:

    1. Brightness: ``plasma_pattern * brightness_factor * max_value`` is added.
    2. Contrast: every pixel is moved relative to the global mean of the
       (already brightness-adjusted) image by the weight
       ``1 + plasma_pattern * contrast_factor``.

    Each stage is skipped when its factor is exactly 0 and is clipped to
    ``[0, max_value]`` otherwise.

    Args:
        img: Input image; its dtype must be a key of ``MAX_VALUES_BY_DTYPE``.
        brightness_factor: Strength of the additive brightness modulation.
        contrast_factor: Strength of the multiplicative contrast modulation.
        plasma_pattern: 2D pattern; presumably matches the image height/width
            (it is broadcast against the image) - confirm against the caller.

    Returns:
        The adjusted image (a copy of ``img`` when both factors are 0).
    """
    dtype_max = MAX_VALUES_BY_DTYPE[img.dtype]

    # Add a trailing axis so the 2D pattern broadcasts across image channels.
    if img.ndim > MONO_CHANNEL_DIMENSIONS:
        plasma_pattern = plasma_pattern[..., np.newaxis]

    adjusted = img.copy()

    if brightness_factor != 0:
        shift = plasma_pattern * brightness_factor * dtype_max
        adjusted = np.clip(adjusted + shift, 0, dtype_max)

    if contrast_factor != 0:
        # A single mean over the whole array is used as the contrast pivot.
        pivot = adjusted.mean()
        gain = plasma_pattern * contrast_factor + 1
        adjusted = np.clip(pivot + (adjusted - pivot) * gain, 0, dtype_max)

    return adjusted
" @@ -173,7 +173,7 @@ def __init__( UserWarning, stacklevel=2, ) - super().__init__(p=p, always_apply=always_apply) + super().__init__(p=p) class RandomVerticalFlip(VerticalFlip): @@ -342,8 +342,6 @@ class RandomPerspective(Perspective): - Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomPerspective """ - _targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS) - class InitSchema(BaseTransformInitSchema): distortion_scale: float = Field(ge=0, le=1) fill: ColorType @@ -439,8 +437,6 @@ class RandomAffine(Affine): - Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomAffine """ - _targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS) - class InitSchema(BaseTransformInitSchema): degrees: ScaleFloatType translate: tuple[float, float] @@ -760,8 +756,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, hue: tuple[float, float] = (0, 0), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomHue is a specialized version of ColorJitter. " @@ -826,8 +822,8 @@ def __init__( self, clip_limit: float | tuple[float, float] = (1, 4), tile_grid_size: tuple[int, int] = (8, 8), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomClahe is an alias for CLAHE transform. Consider using CLAHE directly from albumentations.CLAHE.", @@ -892,8 +888,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, contrast: tuple[float, float] = (1, 1), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomContrast is a specialized version of RandomBrightnessContrast. 
" @@ -958,8 +954,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, brightness: tuple[float, float] = (1, 1), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomBrightness is a specialized version of RandomBrightnessContrast. " @@ -1027,8 +1023,8 @@ def __init__( self, num_drop_channels: int = 1, fill_value: float = 0, - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomChannelDropout is an alias for ChannelDropout transform. " @@ -1087,8 +1083,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomEqualize is a specialized version of Equalize transform. " @@ -1159,8 +1155,8 @@ def __init__( self, kernel_size: ScaleIntType = (3, 7), sigma: ScaleFloatType = 0, - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomGaussianBlur is an alias for GaussianBlur transform. " @@ -1172,7 +1168,6 @@ def __init__( blur_limit=kernel_size, sigma_limit=sigma, p=p, - always_apply=always_apply, ) self.kernel_size = kernel_size self.sigma = sigma @@ -1234,8 +1229,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, mode: Literal["blackbody", "cied"] = "blackbody", - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomPlanckianJitter is a specialized version of PlanckianJitter transform. " @@ -1303,8 +1298,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, kernel_size: tuple[int, int] = (3, 3), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomMedianBlur is a specialized version of MedianBlur with a probability parameter. 
" @@ -1370,8 +1365,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, thresholds: tuple[float, float] = (0.1, 0.1), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomSolarize is an alias for Solarize transform. " @@ -1441,8 +1436,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, num_bits: tuple[int, int] = (3, 3), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "RandomPosterize is an alias for Posterize transform. " @@ -1503,8 +1498,8 @@ class InitSchema(BaseTransformInitSchema): def __init__( self, saturation: tuple[float, float] = (1.0, 1.0), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): super().__init__( brightness=(1.0, 1.0), # No brightness change @@ -1568,8 +1563,8 @@ def __init__( self, mean: float = 0.0, sigma: float = 0.1, - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): warn( "GaussianNoise is a specialized version of GaussNoise that follows torchvision's API. 
" diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 1df90c1c1..342ececd6 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -121,6 +121,7 @@ "ShotNoise", "AdditiveNoise", "SaltAndPepper", + "PlasmaBrightnessContrast", ] NUM_BITS_ARRAY_LENGTH = 3 @@ -5036,8 +5037,8 @@ def __init__( mode: Literal["blackbody", "cied"] = "blackbody", temperature_limit: tuple[int, int] | None = None, sampling_method: Literal["uniform", "gaussian"] = "uniform", - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ) -> None: super().__init__(p=p, always_apply=always_apply) @@ -5154,7 +5155,7 @@ class ShotNoise(ImageOnlyTransform): class InitSchema(BaseTransformInitSchema): scale_range: Annotated[tuple[float, float], AfterValidator(nondecreasing), AfterValidator(check_0plus)] - def __init__(self, scale_range: tuple[float, float] = (0.1, 0.3), always_apply: bool = False, p: float = 0.5): + def __init__(self, scale_range: tuple[float, float] = (0.1, 0.3), p: float = 0.5, always_apply: bool = False): super().__init__(p=p, always_apply=always_apply) self.scale_range = scale_range @@ -5355,8 +5356,8 @@ def __init__( spatial_mode: Literal["constant", "per_pixel", "shared"] = "constant", noise_params: dict[str, Any] | None = None, approximation: float = 1.0, - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): super().__init__(always_apply=always_apply, p=p) self.noise_type = noise_type @@ -5480,8 +5481,8 @@ def __init__( r_shift_limit: ScaleFloatType = (-20, 20), g_shift_limit: ScaleFloatType = (-20, 20), b_shift_limit: ScaleFloatType = (-20, 20), - always_apply: bool | None = None, p: float = 0.5, + always_apply: bool | None = None, ): # Convert RGB shift limits to normalized ranges if needed def normalize_range(limit: tuple[float, float]) -> tuple[float, float]: @@ -5602,10 +5603,10 @@ def __init__( self, 
class PlasmaBrightnessContrast(ImageOnlyTransform):
    """Vary image brightness and contrast spatially using a plasma fractal pattern.

    A Diamond-Square plasma pattern is generated for every call and used as a
    per-pixel weight for a brightness shift followed by a contrast change, giving a
    smooth, non-uniform modification of the image.

    Args:
        brightness_range ((float, float)): Range from which the brightness strength
            is sampled. Both bounds must lie in [-1, 1]:
            - Positive values increase brightness
            - Negative values decrease brightness
            - 0 means no brightness change
            Default: (-0.3, 0.3)

        contrast_range ((float, float)): Range from which the contrast strength
            is sampled. Both bounds must lie in [-1, 1]:
            - Positive values increase contrast
            - Negative values decrease contrast
            - 0 means no contrast change
            Default: (-0.3, 0.3)

        plasma_size (int): Requested size of the plasma grid; it is rounded up to a
            power of 2 by the generator. Must be greater than 0. Default: 256

        roughness (float): Exponent that controls how the random displacement of the
            plasma pattern scales between coarse and fine subdivision levels.
            Must be greater than 0. Typical values are between 1.0 and 5.0.
            Default: 3.0

        p (float): Probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32

    Number of channels:
        Any

    Mathematical Formulation:
        1. Plasma Pattern Generation:
           The Diamond-Square algorithm produces a pattern P(x,y) ∈ [0,1] by:
           - Starting from random corner values
           - Recursively computing midpoints as
             M = (V1 + V2 + V3 + V4)/4 + R(d)
             where V1..V4 are the neighbouring values and R(d) is random noise whose
             amplitude depends on the cell size d and the roughness parameter.

        2. Brightness Adjustment:
           For each pixel (x,y):
           O(x,y) = I(x,y) + b·P(x,y)·max_value
           where I is the input image, b the sampled brightness factor, P the plasma
           pattern and max_value the maximum value of the image dtype.

        3. Contrast Adjustment:
           For each pixel (x,y):
           O(x,y) = μ + (I(x,y) - μ)·(1 + c·P(x,y))
           where μ is the mean pixel value, c the sampled contrast factor and P the
           plasma pattern.

    Note:
        - Brightness and contrast modifications are applied sequentially
        - Values are clipped to the valid range [0, max_value]
        - One plasma pattern is shared by the brightness and the contrast step

    Examples:
        >>> import albumentations as A
        >>> import numpy as np

        # Default parameters
        >>> transform = A.PlasmaBrightnessContrast(p=1.0)

        # Custom ranges, grid size and roughness
        >>> transform = A.PlasmaBrightnessContrast(
        ...     brightness_range=(-0.5, 0.5),
        ...     contrast_range=(-0.3, 0.3),
        ...     plasma_size=512,
        ...     roughness=2.5,
        ...     p=1.0
        ... )

    References:
        .. [1] Fournier, Fussell, and Carpenter, "Computer rendering of stochastic models,"
               Communications of the ACM, 1982.
               Paper introducing the Diamond-Square algorithm.

        .. [2] Miller, "The definition and rendering of terrain maps,"
               SIGGRAPH '86 Proceedings, 1986.
               Analysis of midpoint-displacement artifacts (citation to be verified).

        .. [3] Ebert et al., "Texturing & Modeling: A Procedural Approach,"
               Chapter 12: Noise, Hypertexture, Antialiasing, and Gesture.
               Detailed coverage of procedural noise patterns.

        .. [4] Diamond-Square algorithm:
               https://en.wikipedia.org/wiki/Diamond-square_algorithm

        .. [5] Plasma effect:
               https://lodev.org/cgtutor/plasma.html

    See Also:
        - RandomBrightnessContrast: For uniform brightness/contrast adjustments
        - CLAHE: For contrast limited adaptive histogram equalization
        - FancyPCA: For color-based contrast enhancement
        - HistogramMatching: For reference-based contrast adjustment
    """

    class InitSchema(BaseTransformInitSchema):
        # Both range bounds are validated to lie within [-1, 1].
        brightness_range: Annotated[tuple[float, float], AfterValidator(check_range_bounds(-1, 1))]
        contrast_range: Annotated[tuple[float, float], AfterValidator(check_range_bounds(-1, 1))]
        plasma_size: int = Field(default=256, gt=0)
        roughness: float = Field(default=3.0, gt=0)

    # NOTE(review): other transforms in this change list ``p`` before
    # ``always_apply``; the order here is kept as-is to avoid changing the signature.
    def __init__(
        self,
        brightness_range: tuple[float, float] = (-0.3, 0.3),
        contrast_range: tuple[float, float] = (-0.3, 0.3),
        plasma_size: int = 256,
        roughness: float = 3.0,
        always_apply: bool | None = None,
        p: float = 0.5,
    ):
        super().__init__(p=p, always_apply=always_apply)
        self.plasma_size = plasma_size
        self.roughness = roughness
        self.brightness_range = brightness_range
        self.contrast_range = contrast_range

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Sample both adjustment strengths and build a plasma pattern sized to the input."""
        # Fall back to the first entry of "images" when no single "image" is given.
        if "image" in data:
            reference = data["image"]
        else:
            reference = data["images"][0]

        # Strengths are drawn first (brightness, then contrast) from the Python RNG;
        # the pattern itself uses the NumPy generator.
        brightness_factor = self.py_random.uniform(*self.brightness_range)
        contrast_factor = self.py_random.uniform(*self.contrast_range)

        plasma_pattern = fmain.generate_plasma_pattern(
            target_shape=reference.shape[:2],
            size=self.plasma_size,
            roughness=self.roughness,
            random_generator=self.random_generator,
        )

        return {
            "brightness_factor": brightness_factor,
            "contrast_factor": contrast_factor,
            "plasma_pattern": plasma_pattern,
        }

    def apply(
        self,
        img: np.ndarray,
        brightness_factor: float,
        contrast_factor: float,
        plasma_pattern: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Apply the sampled plasma brightness/contrast adjustment to ``img``."""
        adjusted = fmain.apply_plasma_brightness_contrast(
            img,
            brightness_factor,
            contrast_factor,
            plasma_pattern,
        )
        return adjusted

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """Names of the constructor arguments used to serialise this transform."""
        return ("brightness_range", "contrast_range", "plasma_size", "roughness")