Skip to content

Commit

Permalink
Add transforms (#2147)
Browse files Browse the repository at this point in the history
* Empty-Commit

* Updated pydantic min version

* Refactoring

* Fix
  • Loading branch information
ternaus authored Nov 17, 2024
1 parent d58da62 commit 30697e7
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 54 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [RandomRain](https://explore.albumentations.ai/transform/RandomRain)
- [RandomShadow](https://explore.albumentations.ai/transform/RandomShadow)
- [RandomSnow](https://explore.albumentations.ai/transform/RandomSnow)
- [RandomSolarize](https://explore.albumentations.ai/transform/RandomSolarize)
- [RandomSunFlare](https://explore.albumentations.ai/transform/RandomSunFlare)
- [RandomToneCurve](https://explore.albumentations.ai/transform/RandomToneCurve)
- [RingingOvershoot](https://explore.albumentations.ai/transform/RingingOvershoot)
Expand Down
15 changes: 11 additions & 4 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,28 +126,35 @@ def shift_hsv(img: np.ndarray, hue_shift: float, sat_shift: float, val_shift: fl


@clipped
def solarize(img: np.ndarray, threshold: int) -> np.ndarray:
def solarize(img: np.ndarray, threshold: float) -> np.ndarray:
"""Invert all pixel values above a threshold.
Args:
img: The image to solarize.
threshold: All pixels above this grayscale level are inverted.
img: The image to solarize. Can be uint8 or float32.
threshold: Normalized threshold value in range [0, 1].
For uint8 images: pixels at or above threshold * 255 are inverted
For float32 images: pixels at or above threshold are inverted
Returns:
Solarized image.
Note:
The threshold is normalized to [0, 1] range for both uint8 and float32 images.
For uint8 images, the threshold is internally scaled by 255.
"""
dtype = img.dtype
max_val = MAX_VALUES_BY_DTYPE[dtype]

if dtype == np.uint8:
lut = [(i if i < threshold else max_val - i) for i in range(int(max_val) + 1)]
lut = [(max_val - i if i >= threshold * max_val else i) for i in range(int(max_val) + 1)]

prev_shape = img.shape
img = sz_lut(img, np.array(lut, dtype=dtype), inplace=False)

return np.expand_dims(img, -1) if len(prev_shape) != img.ndim else img

img = img.copy()

cond = img >= threshold
img[cond] = max_val - img[cond]
return img
Expand Down
74 changes: 73 additions & 1 deletion albumentations/augmentations/tk/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
InvertImg,
PlanckianJitter,
RandomBrightnessContrast,
Solarize,
ToGray,
)
from albumentations.core.pydantic import InterpolationType, check_0plus, check_1plus, nondecreasing
from albumentations.core.pydantic import InterpolationType, check_0plus, check_01, check_1plus, nondecreasing
from albumentations.core.transforms_interface import BaseTransformInitSchema
from albumentations.core.types import PAIR, ColorType, ScaleFloatType, ScaleIntType, Targets

Expand All @@ -46,6 +47,7 @@
"RandomGaussianBlur",
"RandomPlanckianJitter",
"RandomMedianBlur",
"RandomSolarize",
]


Expand Down Expand Up @@ -1306,3 +1308,73 @@ def __init__(

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("kernel_size",)


class RandomSolarize(Solarize):
    """Invert all pixel values above a threshold.

    This transform is an alias for Solarize, provided for compatibility with
    Kornia API naming convention, but using Albumentations' parameter format.

    Args:
        thresholds (tuple[float, float]): Range for solarizing threshold as a fraction
            of maximum value. The thresholds should be in the range [0, 1] and will be multiplied by the
            maximum value of the image type (255 for uint8 images or 1.0 for float images).
            Must be given as a (min, max) tuple with min <= max; use (t, t) for a fixed threshold.
            Default: (0.1, 0.1).
        always_apply (bool | None): Accepted by the constructor but not forwarded to
            ``Solarize.__init__``.
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32

    Note:
        This transform differs from Kornia's RandomSolarize in parameter format:
        - Uses normalized thresholds [0, 1] for both uint8 and float32 images
        - No support for post-solarization brightness addition
        For brightness adjustment, use composition with RandomBrightness:
        ```python
        A.Compose([
            A.RandomSolarize(thresholds=(0.1, 0.1)),
            A.RandomBrightness(limit=0.1)
        ])
        ```

        Instantiating this class always emits a ``UserWarning`` pointing to ``Solarize``.

    Example:
        >>> # RandomSolarize with fixed threshold at 10% of max value
        >>> transform = A.RandomSolarize(thresholds=(0.1, 0.1))  # 25.5 for uint8, 0.1 for float32
        >>> # RandomSolarize with threshold range
        >>> transform = A.RandomSolarize(thresholds=(0.1, 0.3))

    References:
        - Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomSolarize
    """

    class InitSchema(BaseTransformInitSchema):
        # Both bounds must lie in [0, 1] (check_01) and be ordered min <= max (nondecreasing).
        thresholds: Annotated[tuple[float, float], AfterValidator(check_01), AfterValidator(nondecreasing)]

    def __init__(
        self,
        thresholds: tuple[float, float] = (0.1, 0.1),
        always_apply: bool | None = None,
        p: float = 0.5,
    ):
        # Adjacent literals are concatenated verbatim, so every sentence except the
        # last must end with a trailing space to keep the message readable.
        warn(
            "RandomSolarize is an alias for Solarize transform. "
            "Consider using Solarize directly from albumentations.Solarize. "
            "Note: parameter format differs from Kornia's implementation. "
            "For brightness addition, use composition with RandomBrightness.",
            UserWarning,
            stacklevel=2,
        )

        # Map the Kornia-style name onto Solarize's `threshold_range` parameter.
        super().__init__(
            threshold_range=thresholds,
            p=p,
        )
        # Stored under the alias name so it can be looked up via get_transform_init_args_names().
        self.thresholds = thresholds

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """Return the names of the constructor arguments that describe this transform."""
        return ("thresholds",)
79 changes: 55 additions & 24 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
check_01,
check_1plus,
nondecreasing,
repeat_if_scalar,
)
from albumentations.core.transforms_interface import (
BaseTransformInitSchema,
Expand Down Expand Up @@ -1815,11 +1814,10 @@ class Solarize(ImageOnlyTransform):
In this implementation, all pixel values above a threshold are inverted.
Args:
threshold (float | tuple[float, float]): Range for solarizing threshold.
If threshold is a single int, the range will be [threshold, threshold].
If it's a tuple of (min, max), the range will be [min, max].
The threshold should be in the range [0, 255] for uint8 images or [0, 1.0] for float images.
Default: 128.
threshold_range (tuple[float, float]): Range for solarizing threshold as a fraction
of maximum value. The threshold_range should be in the range [0, 1] and will be multiplied by the
maximum value of the image type (255 for uint8 images or 1.0 for float images).
Default: (0.5, 0.5) (corresponds to 127.5 for uint8 and 0.5 for float32).
p (float): Probability of applying the transform. Default: 0.5.
Targets:
Expand All @@ -1828,36 +1826,43 @@ class Solarize(ImageOnlyTransform):
Image types:
uint8, float32
Number of channels:
Any
Note:
- For uint8 images, pixel values at or above the threshold are inverted as: 255 - pixel_value
- For float32 images, pixel values at or above the threshold are inverted as: 1.0 - pixel_value
- The threshold is applied to each channel independently
- The threshold is calculated in two steps:
1. Sample a value from threshold_range
2. Multiply by the image's maximum value:
* For uint8: threshold = sampled_value * 255
* For float32: threshold = sampled_value * 1.0
- This transform can create interesting artistic effects or be used for data augmentation
Raises:
TypeError: If the input image data type is not supported.
Examples:
>>> import numpy as np
>>> import albumentations as A
>>>
# Solarize uint8 image with fixed threshold
# Solarize uint8 image with fixed threshold at 50% of max value (127.5)
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> transform = A.Solarize(threshold=128, p=1.0)
>>> transform = A.Solarize(threshold_range=(0.5, 0.5), p=1.0)
>>> solarized_image = transform(image=image)['image']
>>>
# Solarize uint8 image with random threshold
>>> transform = A.Solarize(threshold=(100, 200), p=1.0)
# Solarize uint8 image with random threshold between 40-60% of max value (102-153)
>>> transform = A.Solarize(threshold_range=(0.4, 0.6), p=1.0)
>>> solarized_image = transform(image=image)['image']
>>>
# Solarize float32 image
# Solarize float32 image at 50% of max value (0.5)
>>> image = np.random.rand(100, 100, 3).astype(np.float32)
>>> transform = A.Solarize(threshold=0.5, p=1.0)
>>> transform = A.Solarize(threshold_range=(0.5, 0.5), p=1.0)
>>> solarized_image = transform(image=image)['image']
Mathematical Formulation:
For each pixel value p and threshold t:
if p > t:
Let f be a value sampled from threshold_range (min, max).
For each pixel value p:
threshold = f * max_value
if p >= threshold:
p_new = max_value - p
else:
p_new = p
Expand All @@ -1869,20 +1874,46 @@ class Solarize(ImageOnlyTransform):
"""

class InitSchema(BaseTransformInitSchema):
threshold: Annotated[ScaleFloatType, AfterValidator(repeat_if_scalar), AfterValidator(check_0plus)]
threshold: ScaleFloatType | None = Field(
default=None,
deprecated="threshold parameter is deprecated. Use threshold_range instead.",
)
threshold_range: Annotated[tuple[float, float], AfterValidator(check_01), AfterValidator(nondecreasing)]

@staticmethod
def normalize_threshold(
threshold: ScaleFloatType | None,
threshold_range: tuple[float, float],
) -> tuple[float, float]:
"""Convert legacy threshold or use threshold_range, normalizing to [0,1] range."""
if threshold is None:
return threshold_range
value = to_tuple(threshold, threshold)
return (value[0] / 255, value[1] / 255) if value[1] > 1 else value

def __init__(self, threshold: ScaleFloatType = (128, 128), p: float = 0.5, always_apply: bool | None = None):
@model_validator(mode="after")
def process_threshold(self) -> Self:
self.threshold_range = self.normalize_threshold(self.threshold, self.threshold_range)
return self

def __init__(
self,
threshold: ScaleFloatType | None = None,
threshold_range: tuple[float, float] = (0.5, 0.5),
p: float = 0.5,
always_apply: bool | None = None,
):
super().__init__(p=p, always_apply=always_apply)
self.threshold = cast(tuple[float, float], threshold)
self.threshold_range = threshold_range

def apply(self, img: np.ndarray, threshold: int, **params: Any) -> np.ndarray:
def apply(self, img: np.ndarray, threshold: float, **params: Any) -> np.ndarray:
return fmain.solarize(img, threshold)

def get_params(self) -> dict[str, float]:
return {"threshold": self.py_random.uniform(*self.threshold)}
return {"threshold": self.py_random.uniform(*self.threshold_range)}

def get_transform_init_args_names(self) -> tuple[str]:
return ("threshold",)
def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("threshold_range",)


class Posterize(ImageOnlyTransform):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"scipy>=1.10.0",
"PyYAML",
"typing-extensions>=4.9.0; python_version<'3.10'",
"pydantic>=2.7.0",
"pydantic>=2.9.2",
"albucore==0.0.21",
"eval-type-backport",
]
Expand Down
1 change: 1 addition & 0 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,4 +402,5 @@
[A.RandomGaussianBlur, {}],
[A.RandomPlanckianJitter, {}],
[A.RandomMedianBlur, {}],
[A.RandomSolarize, {}],
]
39 changes: 17 additions & 22 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,30 +335,25 @@ def test_swap_tiles_on_image(img, tiles, mapping, expected):
np.testing.assert_array_equal(result_img, expected)


@pytest.mark.parametrize("dtype", [np.uint8, np.float32])
def test_solarize(dtype):
max_value = MAX_VALUES_BY_DTYPE[dtype]
@pytest.mark.parametrize("image", IMAGES)
@pytest.mark.parametrize("threshold", [0.0, 1/3, 2/3, 1.0, 1.1])
def test_solarize(image, threshold):
max_value = MAX_VALUES_BY_DTYPE[image.dtype]
check_img = image.copy()

if dtype == np.dtype("float32"):
img = np.arange(2**10, dtype=np.float32) / (2**10)
img = img.reshape([2**5, 2**5])
if image.dtype == np.uint8:
threshold_value = threshold * max_value
else:
max_count = 1024
count = min(max_value + 1, 1024)
step = max(1, (max_value + 1) // max_count)
shape = [int(np.sqrt(count))] * 2
img = np.arange(0, max_value + 1, step, dtype=dtype).reshape(shape)

for threshold in [0, max_value // 3, max_value // 3 * 2, max_value, max_value + 1]:
check_img = img.copy()
cond = check_img >= threshold
check_img[cond] = max_value - check_img[cond]

result_img = F.solarize(img, threshold=threshold)

assert np.all(np.isclose(result_img, check_img))
assert np.min(result_img) >= 0
assert np.max(result_img) <= max_value
threshold_value = threshold

cond = check_img >= threshold_value
check_img[cond] = max_value - check_img[cond]

result_img = F.solarize(image, threshold=threshold)

assert np.all(np.isclose(result_img, check_img))
assert np.min(result_img) >= 0
assert np.max(result_img) <= max_value


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,6 @@ def test_pad_if_needed_position(params, image_shape):
A.TemplateTransform,
A.OverlayElements,
A.TextImage,
A.Solarize,
A.RGBShift,
A.HueSaturationValue,
A.GaussNoise,
Expand Down
9 changes: 8 additions & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,14 @@ def get_all_init_schema_fields(model_cls: A.BasicTransform) -> Set[str]:
fields = set()
if hasattr(model_cls, "InitSchema"):
for field_name, field in model_cls.InitSchema.model_fields.items():
if not field.deprecated:
# Check if field is deprecated either directly or in its default annotation
is_deprecated = (
field.deprecated is not None
or (hasattr(field.default, "metadata")
and any(getattr(m, "deprecated", None) is not None
for m in field.default.metadata))
)
if not is_deprecated:
fields.add(field_name)

return fields
Expand Down

0 comments on commit 30697e7

Please sign in to comment.