diff --git a/torchgeo/transforms/temporal.py b/torchgeo/transforms/temporal.py index f5ecaf39f30..813bb187eec 100644 --- a/torchgeo/transforms/temporal.py +++ b/torchgeo/transforms/temporal.py @@ -4,8 +4,9 @@ """TorchGeo temporal transforms.""" from typing import Any, Literal -from einops import rearrange + import kornia.augmentation as K +from einops import rearrange from torch import Tensor @@ -19,7 +20,7 @@ class TemporalRearrange(K.IntensityAugmentationBase2D): def __init__( self, - mode: Literal["merge", "split"], + mode: Literal['merge', 'split'], num_temporal_channels: int, p: float = 1.0, p_batch: float = 1.0, @@ -40,19 +41,12 @@ def __init__( super().__init__( p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim ) - if mode not in ["merge", "split"]: + if mode not in ['merge', 'split']: raise ValueError("mode must be either 'merge' or 'split'") - self.flags = { - "mode": mode, - "num_temporal_channels": num_temporal_channels, - } + self.flags = {'mode': mode, 'num_temporal_channels': num_temporal_channels} - def apply_transform( - self, - input: Tensor, - flags: dict[str, Any], - ) -> Tensor: + def apply_transform(self, input: Tensor, flags: dict[str, Any]) -> Tensor: """Apply the transform. Args: @@ -65,25 +59,25 @@ def apply_transform( Raises: ValueError: If input tensor dimensions don't match expected shape """ - mode = flags["mode"] - t = flags["num_temporal_channels"] + mode = flags['mode'] + t = flags['num_temporal_channels'] - if mode == "merge": + if mode == 'merge': if input.ndim != 5: raise ValueError( - f"Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}" + f'Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}' ) - return rearrange(input, "b t c h w -> b (t c) h w") + return rearrange(input, 'b t c h w -> b (t c) h w') else: if input.ndim != 4: raise ValueError( - f"Expected 4D input tensor (B,TC,H,W), got shape {input.shape}" + f'Expected 4D input tensor (B,TC,H,W), got shape {input.shape}' ) tc = input.shape[1] if tc % t != 0: raise ValueError( - f"Input channels ({tc}) must be divisible by " - f"num_temporal_channels ({t})" + f'Input channels ({tc}) must be divisible by ' + f'num_temporal_channels ({t})' ) c = tc // t - return rearrange(input, "b (t c) h w -> b t c h w", t=t, c=c) \ No newline at end of file + return rearrange(input, 'b (t c) h w -> b t c h w', t=t, c=c)