Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix in Solarize
Browse files Browse the repository at this point in the history
ternaus committed Oct 16, 2024
1 parent cb433ba commit 0776043
Showing 5 changed files with 110 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -77,11 +77,11 @@ repos:
hooks:
- id: markdownlint
- repo: https://github.com/tox-dev/pyproject-fmt
rev: "2.2.4"
rev: "2.3.1"
hooks:
- id: pyproject-fmt
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.12.0
hooks:
- id: mypy
files: ^albumentations/
3 changes: 2 additions & 1 deletion albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@
ProbabilityType,
SymmetricRangeType,
ZeroOneRangeType,
ZeroPlusRangeType,
check_0plus,
check_01,
check_1plus,
@@ -1944,7 +1945,7 @@ class Solarize(ImageOnlyTransform):
"""

class InitSchema(BaseTransformInitSchema):
threshold: OnePlusFloatRangeType = (128, 128)
threshold: ZeroPlusRangeType = (128, 128)

def __init__(self, threshold: ScaleFloatType = (128, 128), p: float = 0.5, always_apply: bool | None = None):
super().__init__(p=p, always_apply=always_apply)
9 changes: 8 additions & 1 deletion albumentations/core/pydantic.py
Original file line number Diff line number Diff line change
@@ -130,4 +130,11 @@ def check_01(value: tuple[NumericType, NumericType]) -> tuple[NumericType, Numer
return value


ZeroOneRangeType = Annotated[ScaleType, AfterValidator(convert_to_0plus_range), AfterValidator(check_01)]
ZeroOneRangeType = Annotated[
ScaleType,
AfterValidator(convert_to_0plus_range),
AfterValidator(check_01),
AfterValidator(nondecreasing),
]

ZeroPlusRangeType = Annotated[ScaleType, AfterValidator(convert_to_0plus_range), AfterValidator(nondecreasing)]
83 changes: 83 additions & 0 deletions tests/functional/test_dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest
import numpy as np
from skimage.measure import label as ski_label
from albumentations.augmentations.dropout.functional import label as cv_label
from scipy import stats

from tests.utils import set_seed

@pytest.mark.parametrize("shape, dtype, connectivity", [
((8, 8), np.uint8, 1),
((10, 10), np.uint8, 2),
((12, 12), np.int32, 1),
((12, 12), np.int32, 2),
((14, 14), np.uint8, 1),
((35, 35), np.uint8, 2),
])
def test_label_function(shape, dtype, connectivity):
set_seed(42)
# Generate a random binary mask
mask = np.random.randint(0, 2, shape).astype(dtype)

# Compare results with scikit-image
ski_result = ski_label(mask, connectivity=connectivity)
cv_result = cv_label(mask, connectivity=connectivity)

np.testing.assert_array_equal(cv_result, ski_result), "Label results do not match"

@pytest.mark.parametrize("shape, dtype, connectivity", [
((10, 10), np.uint8, 1),
((20, 20), np.int32, 2),
((30, 30), np.uint8, 1),
])
def test_label_function_return_num(shape, dtype, connectivity):
mask = np.random.randint(0, 2, shape).astype(dtype)

ski_result, ski_num = ski_label(mask, connectivity=connectivity, return_num=True)
cv_result, cv_num = cv_label(mask, connectivity=connectivity, return_num=True)

np.testing.assert_array_equal(cv_result, ski_result), "Label results do not match"
assert ski_num == cv_num, "Number of labels do not match"

@pytest.mark.parametrize("shape, num_objects", [
((10, 10), 3),
((20, 20), 5),
((30, 30), 10),
])
def test_label_function_with_multiple_objects(shape, num_objects):
set_seed(43)
mask = np.zeros(shape, dtype=np.uint8)
for i in range(1, num_objects + 1):
x, y = np.random.randint(0, shape[0]), np.random.randint(0, shape[1])
mask[x:x+3, y:y+3] = i

ski_result, ski_num = ski_label(mask, return_num=True)
cv_result, cv_num = cv_label(mask, return_num=True)

# Check for one-to-one mapping
combined = np.stack((ski_result, cv_result))
unique_combinations = np.unique(combined.reshape(2, -1).T, axis=0)

assert len(unique_combinations) == len(np.unique(ski_result)) == len(np.unique(cv_result)), \
"Labels are not equal up to enumeration"

assert ski_num == cv_num, "Number of labels do not match"
assert cv_num == num_objects, f"Expected {num_objects} labels, got {cv_num}"

def test_label_function_empty_mask():
mask = np.zeros((10, 10), dtype=np.uint8)

ski_result, ski_num = ski_label(mask, return_num=True)
cv_result, cv_num = cv_label(mask, return_num=True)

np.testing.assert_array_equal(cv_result, ski_result), "Label results do not match for empty mask"
assert ski_num == cv_num == 0, "Number of labels should be 0 for empty mask"

def test_label_function_full_mask():
mask = np.ones((10, 10), dtype=np.uint8)

ski_result, ski_num = ski_label(mask, return_num=True)
cv_result, cv_num = cv_label(mask, return_num=True)

np.testing.assert_array_equal(cv_result, ski_result), "Label results do not match for full mask"
assert ski_num == cv_num, "Number of labels should be 2 for full mask (background + one object)"
15 changes: 15 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
@@ -1096,3 +1096,18 @@ def test_augmentations_match_uint8_float32(augmentation_cls, params):
transformed_float32 = transform(**data)["image"]

np.testing.assert_array_almost_equal(to_float(transformed_uint8), transformed_float32, decimal=2)


def test_solarize_threshold():
image = SQUARE_UINT8_IMAGE
image[20:40, 20:40] = 255
transform = A.Solarize(threshold=128, p=1)
transformed_image = transform(image=image)["image"]
assert (transformed_image[20:40, 20:40] == 0).all()

transform = A.Solarize(threshold=0.5, p=1)

float_image = SQUARE_FLOAT_IMAGE
float_image[20:40, 20:40] = 1
transformed_image = transform(image=float_image)["image"]
assert (transformed_image[20:40, 20:40] == 0).all()

0 comments on commit 0776043

Please sign in to comment.