Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random grayscale #2128

Merged
merged 3 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [RandomFog](https://explore.albumentations.ai/transform/RandomFog)
- [RandomGamma](https://explore.albumentations.ai/transform/RandomGamma)
- [RandomGravel](https://explore.albumentations.ai/transform/RandomGravel)
- [RandomGrayscale](https://explore.albumentations.ai/transform/RandomGrayscale)
- [RandomJPEG](https://explore.albumentations.ai/transform/RandomJPEG)
- [RandomRain](https://explore.albumentations.ai/transform/RandomRain)
- [RandomShadow](https://explore.albumentations.ai/transform/RandomShadow)
Expand Down
76 changes: 72 additions & 4 deletions albumentations/augmentations/tk/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from pydantic import AfterValidator

from albumentations.augmentations.geometric.transforms import HorizontalFlip, VerticalFlip
from albumentations.augmentations.transforms import ImageCompression
from albumentations.augmentations.transforms import ImageCompression, ToGray
from albumentations.core.pydantic import check_0plus, nondecreasing
from albumentations.core.transforms_interface import BaseTransformInitSchema
from albumentations.core.types import Targets

__all__ = ["RandomJPEG", "RandomHorizontalFlip", "RandomVerticalFlip"]
__all__ = ["RandomJPEG", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomGrayscale"]


class RandomJPEG(ImageCompression):
Expand Down Expand Up @@ -105,7 +105,7 @@ class RandomHorizontalFlip(HorizontalFlip):
>>> transform = A.HorizontalFlip(p=0.5)

References:
- torchvision: https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.RandomHorizontalFlip
- torchvision: https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.RandomHorizontalFlip.html
- Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomHorizontalFlip
"""

Expand Down Expand Up @@ -158,7 +158,7 @@ class RandomVerticalFlip(VerticalFlip):
>>> transform = A.VerticalFlip(p=0.5)

References:
- torchvision: https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.RandomVerticalFlip
- torchvision: https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.RandomVerticalFlip.html
- Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomVerticalFlip
"""

Expand All @@ -179,3 +179,71 @@ def __init__(
stacklevel=2,
)
super().__init__(p=p, always_apply=always_apply)


class RandomGrayscale(ToGray):
    """Turn the input image into grayscale with probability ``p``.

    Compatibility alias for ToGray, kept so that pipelines written against the
    torchvision or Kornia APIs can be ported with minimal changes. New code
    should prefer albumentations.ToGray.

    The conversion follows the ITU-R 601-2 luma formula,
    grayscale = 0.299R + 0.587G + 0.114B, i.e. the same weights as
    torchvision.transforms.RandomGrayscale.

    Args:
        p (float): probability of converting the image to grayscale. Default: 0.1.

    Targets:
        image

    Image types:
        uint8, float32

    Number of channels:
        3

    Note:
        The behaviour corresponds to albumentations.ToGray with
        method="weighted_average"; the alias exists only to keep the
        torchvision / Kornia naming available during migration.

        albumentations.ToGray itself is more flexible and offers:
        - Several conversion methods ("weighted_average", "from_lab", "desaturation", etc.)
        - Methods that accept any number of channels ("desaturation", "average", "max", "pca")
        - A choice between different perceptual and performance trade-offs

        Specifics of this alias:
        - A 3-channel input yields a 3-channel grayscale output with r == g == b
        - Single-channel inputs are not supported (this differs from torchvision)
        - The ITU-R 601-2 weights (0.299, 0.587, 0.114) match torchvision

    Example:
        >>> transform = A.RandomGrayscale(p=0.1)
        >>> # Preferred equivalent:
        >>> transform = A.ToGray(p=0.1, method="weighted_average")

    References:
        - torchvision: https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.RandomGrayscale.html
        - Kornia: https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomGrayscale
        - ITU-R BT.601: https://en.wikipedia.org/wiki/Rec._601
    """

    class InitSchema(BaseTransformInitSchema):
        # The alias declares no fields of its own on top of the base schema.
        pass

    def __init__(
        self,
        p: float = 0.1,
        always_apply: bool | None = None,
    ):
        # warn() is called directly from __init__ (not through a helper) so
        # that stacklevel=2 keeps pointing at the code that instantiated the
        # transform.
        alias_notice = (
            "RandomGrayscale is an alias for ToGray transform. "
            "Consider using ToGray directly from albumentations.ToGray."
        )
        warn(alias_notice, UserWarning, stacklevel=2)
        super().__init__(always_apply=always_apply, p=p)

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """Return the names of the init arguments reported for this transform (none)."""
        no_init_args: tuple[str, ...] = ()
        return no_init_args
1 change: 1 addition & 0 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,4 +385,5 @@
[A.RandomJPEG, {"jpeg_quality": (50, 50)}],
[A.RandomHorizontalFlip, {}],
[A.RandomVerticalFlip, {}],
[A.RandomGrayscale, {}],
]
4 changes: 4 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def test_mask_fill_value(augmentation_cls, params):
A.Spatter,
A.ChromaticAberration,
A.PlanckianJitter,
A.RandomGrayscale,
},
),
)
Expand Down Expand Up @@ -695,6 +696,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params):
A.Spatter,
A.ChromaticAberration,
A.PlanckianJitter,
A.RandomGrayscale,
},
),
)
Expand Down Expand Up @@ -778,6 +780,7 @@ def test_float_multichannel_image_augmentations(augmentation_cls, params):
A.Spatter,
A.ChromaticAberration,
A.PlanckianJitter,
A.RandomGrayscale,
},
),
)
Expand Down Expand Up @@ -863,6 +866,7 @@ def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params
A.Spatter,
A.ChromaticAberration,
A.PlanckianJitter,
A.RandomGrayscale,
},
),
)
Expand Down