Skip to content

Commit

Permalink
Added test that ToTensorV2 converts mask
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Oct 23, 2024
1 parent 1277efc commit 2b10b9f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
3 changes: 0 additions & 3 deletions albumentations/pytorch/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor:

return torch.from_numpy(img.transpose(2, 0, 1))

def apply_to_images(self, images: list[np.ndarray], **params: Any) -> list[torch.Tensor]:
return [self.apply(image, **params) for image in images]

def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor:
if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
mask = mask.transpose(2, 0, 1)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from tests.conftest import RECTANGULAR_UINT8_IMAGE, UINT8_IMAGES
from tests.conftest import RECTANGULAR_UINT8_IMAGE, SQUARE_UINT8_IMAGE, UINT8_IMAGES

from .utils import set_seed

Expand Down Expand Up @@ -283,3 +283,23 @@ def test_to_tensor_v2_on_non_contiguous_array_with_random_rotate90():
assert isinstance(transformed["masks"][0], torch.Tensor)
assert transformed["image"].numpy().shape in ((3, 640, 480), (3, 480, 640))
assert transformed["masks"][0].shape in ((640, 480), (480, 640))


def test_to_tensor_v2_images_masks():
transform = A.Compose([ToTensorV2(p=1)])
image = SQUARE_UINT8_IMAGE
mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)

transformed = transform(
image=image,
mask=mask,
masks=[mask] * 2,
images=[image] * 2
)

# Check all outputs are torch.Tensor
for key in ['image', 'mask']:
assert isinstance(transformed[key], torch.Tensor)

for key in ['masks', 'images']:
assert all(isinstance(t, torch.Tensor) for t in transformed[key])

0 comments on commit 2b10b9f

Please sign in to comment.