Skip to content

Commit

Permalink
Implement ToTensorv2 for multiple images (#2014)
Browse files Browse the repository at this point in the history
* implement totensorv2 for imageS

* precommits
  • Loading branch information
loicmagne authored Oct 23, 2024
1 parent 13a510d commit 96bb6ee
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion albumentations/pytorch/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ def __init__(self, transpose_mask: bool = False, p: float = 1.0, always_apply: b

@property
def targets(self) -> dict[str, Any]:
return {"image": self.apply, "mask": self.apply_to_mask, "masks": self.apply_to_masks}
return {
"image": self.apply,
"images": self.apply_to_images,
"mask": self.apply_to_mask,
"masks": self.apply_to_masks,
}

def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor:
if len(img.shape) not in [2, 3]:
Expand All @@ -43,6 +48,9 @@ def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor:

return torch.from_numpy(img.transpose(2, 0, 1))

def apply_to_images(self, images: list[np.ndarray], **params: Any) -> list[torch.Tensor]:
return [self.apply(image, **params) for image in images]

def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor:
if self.transpose_mask and mask.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
mask = mask.transpose(2, 0, 1)
Expand Down

0 comments on commit 96bb6ee

Please sign in to comment.