diff --git a/albumentations/pytorch/transforms.py b/albumentations/pytorch/transforms.py index 323a102f1..3053578aa 100644 --- a/albumentations/pytorch/transforms.py +++ b/albumentations/pytorch/transforms.py @@ -31,7 +31,12 @@ def __init__(self, transpose_mask: bool = False, p: float = 1.0, always_apply: b @property def targets(self) -> dict[str, Any]: - return {"image": self.apply, "mask": self.apply_to_mask, "masks": self.apply_to_masks} + return { + "image": self.apply, + "images": self.apply_to_images, + "mask": self.apply_to_mask, + "masks": self.apply_to_masks, + } def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor: if len(img.shape) not in [2, 3]: @@ -43,6 +48,9 @@ def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor: return torch.from_numpy(img.transpose(2, 0, 1)) + def apply_to_images(self, images: list[np.ndarray], **params: Any) -> list[torch.Tensor]: + return [self.apply(image, **params) for image in images] + def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor: if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS: mask = mask.transpose(2, 0, 1)