From 2b10b9f66320812543822262d35bcc08aeb01ffa Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Wed, 23 Oct 2024 14:22:19 -0700 Subject: [PATCH] Added test that ToTensorV2 converts mask --- albumentations/pytorch/transforms.py | 3 --- tests/test_pytorch.py | 22 +++++++++++++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/albumentations/pytorch/transforms.py b/albumentations/pytorch/transforms.py index 3053578aa..94c796929 100644 --- a/albumentations/pytorch/transforms.py +++ b/albumentations/pytorch/transforms.py @@ -48,9 +48,6 @@ def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor: return torch.from_numpy(img.transpose(2, 0, 1)) - def apply_to_images(self, images: list[np.ndarray], **params: Any) -> list[torch.Tensor]: - return [self.apply(image, **params) for image in images] - def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor: if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS: mask = mask.transpose(2, 0, 1) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index bebcad6a0..703f63091 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -6,7 +6,7 @@ import albumentations as A from albumentations.pytorch.transforms import ToTensorV2 -from tests.conftest import RECTANGULAR_UINT8_IMAGE, UINT8_IMAGES +from tests.conftest import RECTANGULAR_UINT8_IMAGE, SQUARE_UINT8_IMAGE, UINT8_IMAGES from .utils import set_seed @@ -283,3 +283,23 @@ def test_to_tensor_v2_on_non_contiguous_array_with_random_rotate90(): assert isinstance(transformed["masks"][0], torch.Tensor) assert transformed["image"].numpy().shape in ((3, 640, 480), (3, 480, 640)) assert transformed["masks"][0].shape in ((640, 480), (480, 640)) + + +def test_to_tensor_v2_images_masks(): + transform = A.Compose([ToTensorV2(p=1)]) + image = SQUARE_UINT8_IMAGE + mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8) + + transformed = transform( + image=image, + mask=mask, + masks=[mask] * 2, + images=[image] * 2 + ) + + # Check all outputs are torch.Tensor + for key in ['image', 'mask']: + assert isinstance(transformed[key], torch.Tensor) + + for key in ['masks', 'images']: + assert all(isinstance(t, torch.Tensor) for t in transformed[key])