From 96bb6ee7b273191c90d2e856e989f5271f7cd65e Mon Sep 17 00:00:00 2001 From: lm <53355258+loicmagne@users.noreply.github.com> Date: Wed, 23 Oct 2024 23:13:00 +0200 Subject: [PATCH] Implement ToTensorv2 for multiple images (#2014) * implement totensorv2 for imageS * precommits --- albumentations/pytorch/transforms.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/albumentations/pytorch/transforms.py b/albumentations/pytorch/transforms.py index 323a102f1..3053578aa 100644 --- a/albumentations/pytorch/transforms.py +++ b/albumentations/pytorch/transforms.py @@ -31,7 +31,12 @@ def __init__(self, transpose_mask: bool = False, p: float = 1.0, always_apply: b @property def targets(self) -> dict[str, Any]: - return {"image": self.apply, "mask": self.apply_to_mask, "masks": self.apply_to_masks} + return { + "image": self.apply, + "images": self.apply_to_images, + "mask": self.apply_to_mask, + "masks": self.apply_to_masks, + } def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor: if len(img.shape) not in [2, 3]: @@ -43,6 +48,9 @@ def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor: return torch.from_numpy(img.transpose(2, 0, 1)) + def apply_to_images(self, images: list[np.ndarray], **params: Any) -> list[torch.Tensor]: + return [self.apply(image, **params) for image in images] + def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor: if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS: mask = mask.transpose(2, 0, 1)