Skip to content

Commit

Permalink
Updated image compression
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Jan 7, 2025
1 parent 5956993 commit d3edf9f
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 69 deletions.
50 changes: 1 addition & 49 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,53 +342,10 @@ class InitSchema(BaseTransformInitSchema):
AfterValidator(nondecreasing),
]

quality_lower: int | None = Field(
ge=1,
le=100,
)
quality_upper: int | None = Field(
ge=1,
le=100,
)
compression_type: Literal["jpeg", "webp"]

@model_validator(mode="after")
def validate_ranges(self) -> Self:
# Update the quality_range based on the non-None values of quality_lower and quality_upper
if self.quality_lower is not None or self.quality_upper is not None:
if self.quality_lower is not None:
warn(
"`quality_lower` is deprecated. Use `quality_range` as tuple"
" (quality_lower, quality_upper) instead.",
DeprecationWarning,
stacklevel=2,
)
if self.quality_upper is not None:
warn(
"`quality_upper` is deprecated. Use `quality_range` as tuple"
" (quality_lower, quality_upper) instead.",
DeprecationWarning,
stacklevel=2,
)
lower = self.quality_lower if self.quality_lower is not None else self.quality_range[0]
upper = self.quality_upper if self.quality_upper is not None else self.quality_range[1]
self.quality_range = (lower, upper)
# Clear the deprecated individual quality settings
self.quality_lower = None
self.quality_upper = None

# Validate the quality_range
if not (1 <= self.quality_range[0] <= MAX_JPEG_QUALITY and 1 <= self.quality_range[1] <= MAX_JPEG_QUALITY):
raise ValueError(
f"Quality range values should be within [1, {MAX_JPEG_QUALITY}] range.",
)

return self

def __init__(
self,
quality_lower: int | None = None,
quality_upper: int | None = None,
compression_type: Literal["jpeg", "webp"] = "jpeg",
quality_range: tuple[int, int] = (99, 100),
p: float = 0.5,
Expand All @@ -407,12 +364,7 @@ def apply(
return fmain.image_compression(img, quality, image_type)

def get_params(self) -> dict[str, int | str]:
if self.compression_type == "jpeg":
image_type = ".jpg"
elif self.compression_type == "webp":
image_type = ".webp"
else:
raise ValueError(f"Unknown image compression type: {self.compression_type}")
image_type = ".jpg" if self.compression_type == "jpeg" else ".webp"

return {
"quality": self.py_random.randint(*self.quality_range),
Expand Down
20 changes: 0 additions & 20 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,26 +1200,6 @@ def test_random_crop_from_borders(
assert aug(image=image, mask=image, bboxes=bboxes, keypoints=keypoints)


@pytest.mark.parametrize(
"params, expected",
[
# Test default initialization values
({}, {"quality_range": (99, 100), "compression_type": "jpeg"}),
# Test custom quality range and compression type
(
{"quality_range": (10, 90), "compression_type": "webp"},
{"quality_range": (10, 90), "compression_type": "webp"},
),
# Deprecated quality values handling
({"quality_lower": 75}, {"quality_range": (75, 100)}),
],
)
def test_image_compression_initialization(params, expected):
img_comp = A.ImageCompression(**params)
for key, value in expected.items():
assert getattr(img_comp, key) == value, f"Failed on {key} with value {value}"


@pytest.mark.parametrize(
"params",
[
Expand Down

0 comments on commit d3edf9f

Please sign in to comment.