Skip to content

Commit

Permalink
Added PlasmaBrightnessContrast
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Nov 18, 2024
1 parent 499ebb8 commit 8950edd
Show file tree
Hide file tree
Showing 6 changed files with 362 additions and 27 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [Normalize](https://explore.albumentations.ai/transform/Normalize)
- [PixelDistributionAdaptation](https://explore.albumentations.ai/transform/PixelDistributionAdaptation)
- [PlanckianJitter](https://explore.albumentations.ai/transform/PlanckianJitter)
- [PlasmaBrightnessContrast](https://explore.albumentations.ai/transform/PlasmaBrightnessContrast)
- [Posterize](https://explore.albumentations.ai/transform/Posterize)
- [RGBShift](https://explore.albumentations.ai/transform/RGBShift)
- [RandomBrightness](https://explore.albumentations.ai/transform/RandomBrightness)
Expand Down
160 changes: 160 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2192,3 +2192,163 @@ def apply_salt_and_pepper(
result[salt_mask] = MAX_VALUES_BY_DTYPE[img.dtype]
result[pepper_mask] = 0
return result


def get_grid_size(size: int, target_shape: tuple[int, int]) -> int:
    """Return the smallest power of two covering both ``size`` and ``target_shape``.

    Args:
        size: Requested grid size.
        target_shape: Shape (height, width) the grid has to cover.

    Returns:
        The smallest power of two that is >= max(size, height, width).
    """
    largest_extent = max(size, *target_shape)
    exponent = int(np.ceil(np.log2(largest_extent)))
    return 2**exponent


def random_offset(current_size: int, total_size: int, roughness: float, random_generator: np.random.Generator) -> float:
    """Draw a random displacement scaled by the current subdivision level.

    Args:
        current_size: Size of the cell currently being subdivided.
        total_size: Size of the whole grid.
        roughness: Exponent control; the scale is (current_size / total_size) ** (roughness / 2).
        random_generator: NumPy random generator; exactly one sample is drawn from it.

    Returns:
        A sample from [-0.5, 0.5) multiplied by the level-dependent scale.
    """
    centered_sample = random_generator.random() - 0.5
    level_scale = (current_size / total_size) ** (roughness / 2)
    return centered_sample * level_scale


def initialize_grid(grid_size: int, random_generator: np.random.Generator) -> np.ndarray:
    """Create a zero-filled float32 grid whose four corners hold random values.

    Args:
        grid_size: Number of cells per side; the grid has grid_size + 1 points per side.
        random_generator: NumPy random generator; one sample is drawn per corner.

    Returns:
        Array of shape (grid_size + 1, grid_size + 1) and dtype float32.
    """
    side = grid_size + 1
    grid = np.zeros((side, side), dtype=np.float32)
    # Visit corners in row-major order: (0, 0), (0, -1), (-1, 0), (-1, -1).
    for row in (0, -1):
        for col in (0, -1):
            grid[row, col] = random_generator.random()
    return grid


def square_step(
    pattern: np.ndarray,
    y: int,
    x: int,
    step: int,
    grid_size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> float:
    """Compute the value for the center of a square cell.

    The result is the average of the cell's four corner values plus a random
    displacement obtained from ``random_offset``.

    Args:
        pattern: Grid of values being filled in.
        y: Row of the cell's top-left corner.
        x: Column of the cell's top-left corner.
        step: Side length of the cell.
        grid_size: Passed to ``random_offset`` as the total size.
        roughness: Passed through to ``random_offset``.
        random_generator: NumPy random generator passed to ``random_offset``.

    Returns:
        Corner average plus random displacement.
    """
    far_y = y + step
    far_x = x + step
    corner_values = [
        pattern[y, x],  # top-left
        pattern[y, far_x],  # top-right
        pattern[far_y, x],  # bottom-left
        pattern[far_y, far_x],  # bottom-right
    ]
    average = sum(corner_values) / 4.0
    displacement = random_offset(step, grid_size, roughness, random_generator)
    return average + displacement


def diamond_step(
    pattern: np.ndarray,
    y: int,
    x: int,
    half: int,
    grid_size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> float:
    """Compute the value for an edge midpoint during the diamond step.

    Averages the neighbours located ``half`` away above, below, left and right
    of (y, x), skipping those outside the grid, and adds a random displacement
    from ``random_offset`` computed for a cell of size ``half * 2``.

    Args:
        pattern: Grid of values being filled in.
        y: Row of the point to compute.
        x: Column of the point to compute.
        half: Distance from the point to each of its neighbours.
        grid_size: Largest valid index along each axis; also the total size for ``random_offset``.
        roughness: Passed through to ``random_offset``.
        random_generator: NumPy random generator passed to ``random_offset``.

    Returns:
        Neighbour average plus random displacement.
    """
    # Each entry: (neighbour is inside the grid, row, column). The index is only
    # evaluated when the bounds flag is true.
    neighbours = [
        pattern[row, col]
        for in_bounds, row, col in (
            (y >= half, y - half, x),  # above
            (y + half <= grid_size, y + half, x),  # below
            (x >= half, y, x - half),  # left
            (x + half <= grid_size, y, x + half),  # right
        )
        if in_bounds
    ]
    average = sum(neighbours) / len(neighbours)
    return average + random_offset(half * 2, grid_size, roughness, random_generator)


def generate_plasma_pattern(
    target_shape: tuple[int, int],
    size: int,
    roughness: float,
    random_generator: np.random.Generator,
) -> np.ndarray:
    """Generate a plasma fractal pattern using the Diamond-Square algorithm.

    The Diamond-Square algorithm creates a natural-looking noise pattern by recursively
    subdividing a grid and adding random displacements at each step. The roughness
    parameter controls how quickly the random displacements decrease with each iteration.

    Args:
        target_shape: Final shape (height, width) of the pattern.
        size: Initial size of the pattern grid. The grid side is the smallest power of 2
            that is >= max(size, height, width).
        roughness: Exponent controlling how fast displacements shrink. Each step's random
            offset is scaled by (step / grid_size) ** (roughness / 2) (see ``random_offset``),
            so larger values damp fine-scale detail more strongly.
        random_generator: NumPy random generator.

    Returns:
        Normalized plasma pattern array of shape target_shape with values in [0, 1].
    """
    # Initialize grid
    grid_size = get_grid_size(size, target_shape)
    pattern = initialize_grid(grid_size, random_generator)

    # Diamond-Square algorithm
    step_size = grid_size
    while step_size > 1:
        # step_size is a power of two greater than 1 here, so half_step >= 1.
        half_step = step_size // 2

        # Square step: fill the center of every step_size x step_size cell.
        for y in range(0, grid_size, step_size):
            for x in range(0, grid_size, step_size):
                # The displacement must be scaled relative to the full grid
                # (grid_size), matching the diamond step, so that it shrinks
                # as the cells get smaller.
                pattern[y + half_step, x + half_step] = square_step(
                    pattern,
                    y,
                    x,
                    step_size,
                    grid_size,
                    roughness,
                    random_generator,
                )

        # Diamond step: fill the edge midpoints. The starting column alternates
        # between half_step and 0 from row to row.
        for y in range(0, grid_size + 1, half_step):
            for x in range((y + half_step) % step_size, grid_size + 1, step_size):
                pattern[y, x] = diamond_step(pattern, y, x, half_step, grid_size, roughness, random_generator)

        step_size = half_step

    # Normalize to [0, 1] range
    min_pattern = pattern.min()
    value_range = pattern.max() - min_pattern
    if value_range > 0:
        pattern = (pattern - min_pattern) / value_range
    else:
        # Constant pattern: avoid dividing by zero.
        pattern = np.zeros_like(pattern)

    return (
        fgeometric.resize(pattern, target_shape, interpolation=cv2.INTER_LINEAR)
        if pattern.shape != target_shape
        else pattern
    )


@clipped
def apply_plasma_brightness_contrast(
    img: np.ndarray,
    brightness_factor: float,
    contrast_factor: float,
    plasma_pattern: np.ndarray,
) -> np.ndarray:
    """Apply spatially varying brightness and contrast driven by a plasma pattern.

    Brightness: ``plasma_pattern * brightness_factor * max_value`` is added to the image.
    Contrast: each pixel's deviation from the image mean is scaled by
    ``plasma_pattern * contrast_factor + 1``. Both results are clipped to
    ``[0, max_value]``, where ``max_value`` comes from ``MAX_VALUES_BY_DTYPE`` for
    the image dtype. A zero factor skips the corresponding adjustment.

    Args:
        img: Input image.
        brightness_factor: Strength of the brightness adjustment.
        contrast_factor: Strength of the contrast adjustment.
        plasma_pattern: Pattern used as per-pixel weights; a trailing axis is
            added when ``img.ndim`` exceeds ``MONO_CHANNEL_DIMENSIONS``.

    Returns:
        The adjusted image; the input array is not modified.
    """
    adjusted = img.copy()
    upper_bound = MAX_VALUES_BY_DTYPE[img.dtype]

    # Add a trailing axis so the pattern broadcasts over the channel axis.
    if img.ndim > MONO_CHANNEL_DIMENSIONS:
        plasma_pattern = plasma_pattern[..., np.newaxis]

    # Brightness: additive, pattern-weighted shift.
    if brightness_factor != 0:
        shift = plasma_pattern * brightness_factor * upper_bound
        adjusted = np.clip(adjusted + shift, 0, upper_bound)

    # Contrast: pattern-weighted scaling of the deviation from the mean.
    if contrast_factor != 0:
        center = adjusted.mean()
        gain = plasma_pattern * contrast_factor + 1
        adjusted = np.clip(center + (adjusted - center) * gain, 0, upper_bound)

    return adjusted
2 changes: 1 addition & 1 deletion albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def resize(img: np.ndarray, target_shape: tuple[int, int], interpolation: int) -
if target_shape == img.shape[:2]:
return img

height, width = target_shape
height, width = target_shape[:2]
resize_fn = maybe_process_in_chunks(cv2.resize, dsize=(width, height), interpolation=interpolation)
return resize_fn(img)

Expand Down
35 changes: 15 additions & 20 deletions albumentations/augmentations/tk/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
jpeg_quality: tuple[int, int] = (50, 50),
always_apply: bool = False,
p: float = 0.5,
always_apply: bool = False,
):
warn(
"RandomJPEG is a specialized version of ImageCompression. "
Expand Down Expand Up @@ -173,7 +173,7 @@ def __init__(
UserWarning,
stacklevel=2,
)
super().__init__(p=p, always_apply=always_apply)
super().__init__(p=p)


class RandomVerticalFlip(VerticalFlip):
Expand Down Expand Up @@ -342,8 +342,6 @@ class RandomPerspective(Perspective):
- Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomPerspective
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
distortion_scale: float = Field(ge=0, le=1)
fill: ColorType
Expand Down Expand Up @@ -439,8 +437,6 @@ class RandomAffine(Affine):
- Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomAffine
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)

class InitSchema(BaseTransformInitSchema):
degrees: ScaleFloatType
translate: tuple[float, float]
Expand Down Expand Up @@ -760,8 +756,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
hue: tuple[float, float] = (0, 0),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomHue is a specialized version of ColorJitter. "
Expand Down Expand Up @@ -826,8 +822,8 @@ def __init__(
self,
clip_limit: float | tuple[float, float] = (1, 4),
tile_grid_size: tuple[int, int] = (8, 8),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomClahe is an alias for CLAHE transform. Consider using CLAHE directly from albumentations.CLAHE.",
Expand Down Expand Up @@ -892,8 +888,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
contrast: tuple[float, float] = (1, 1),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomContrast is a specialized version of RandomBrightnessContrast. "
Expand Down Expand Up @@ -958,8 +954,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
brightness: tuple[float, float] = (1, 1),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomBrightness is a specialized version of RandomBrightnessContrast. "
Expand Down Expand Up @@ -1027,8 +1023,8 @@ def __init__(
self,
num_drop_channels: int = 1,
fill_value: float = 0,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomChannelDropout is an alias for ChannelDropout transform. "
Expand Down Expand Up @@ -1087,8 +1083,8 @@ class InitSchema(BaseTransformInitSchema):

def __init__(
self,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomEqualize is a specialized version of Equalize transform. "
Expand Down Expand Up @@ -1159,8 +1155,8 @@ def __init__(
self,
kernel_size: ScaleIntType = (3, 7),
sigma: ScaleFloatType = 0,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomGaussianBlur is an alias for GaussianBlur transform. "
Expand All @@ -1172,7 +1168,6 @@ def __init__(
blur_limit=kernel_size,
sigma_limit=sigma,
p=p,
always_apply=always_apply,
)
self.kernel_size = kernel_size
self.sigma = sigma
Expand Down Expand Up @@ -1234,8 +1229,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
mode: Literal["blackbody", "cied"] = "blackbody",
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomPlanckianJitter is a specialized version of PlanckianJitter transform. "
Expand Down Expand Up @@ -1303,8 +1298,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
kernel_size: tuple[int, int] = (3, 3),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomMedianBlur is a specialized version of MedianBlur with a probability parameter. "
Expand Down Expand Up @@ -1370,8 +1365,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
thresholds: tuple[float, float] = (0.1, 0.1),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomSolarize is an alias for Solarize transform. "
Expand Down Expand Up @@ -1441,8 +1436,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
num_bits: tuple[int, int] = (3, 3),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"RandomPosterize is an alias for Posterize transform. "
Expand Down Expand Up @@ -1503,8 +1498,8 @@ class InitSchema(BaseTransformInitSchema):
def __init__(
self,
saturation: tuple[float, float] = (1.0, 1.0),
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
super().__init__(
brightness=(1.0, 1.0), # No brightness change
Expand Down Expand Up @@ -1568,8 +1563,8 @@ def __init__(
self,
mean: float = 0.0,
sigma: float = 0.1,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
warn(
"GaussianNoise is a specialized version of GaussNoise that follows torchvision's API. "
Expand Down
Loading

0 comments on commit 8950edd

Please sign in to comment.