From 41322295f19adc1393b691bea647440cc96191ab Mon Sep 17 00:00:00 2001 From: Yi-Chia Chang Date: Wed, 18 Dec 2024 10:44:27 -0600 Subject: [PATCH 1/3] init temporal --- torchgeo/transforms/temporal.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 torchgeo/transforms/temporal.py diff --git a/torchgeo/transforms/temporal.py b/torchgeo/transforms/temporal.py new file mode 100644 index 00000000000..1d9696b9721 --- /dev/null +++ b/torchgeo/transforms/temporal.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TorchGeo temporal transforms.""" + +from typing import Any + +import kornia.augmentation as K +import torch +from einops import rearrange +from kornia.contrib import extract_tensor_patches +from kornia.geometry import crop_by_indices +from kornia.geometry.boxes import Boxes +from torch import Tensor +from torch.nn.modules import Module \ No newline at end of file From 93509a6bc366224b66a21bb9a20d2800bb51cbe3 Mon Sep 17 00:00:00 2001 From: Yi-Chia Chang Date: Thu, 30 Jan 2025 12:27:12 -0600 Subject: [PATCH 2/3] temporal rearrange --- torchgeo/transforms/temporal.py | 90 ++++++++++++++++++++++++++++++--- 1 file changed, 82 insertions(+), 8 deletions(-) diff --git a/torchgeo/transforms/temporal.py b/torchgeo/transforms/temporal.py index 1d9696b9721..f5ecaf39f30 100644 --- a/torchgeo/transforms/temporal.py +++ b/torchgeo/transforms/temporal.py @@ -3,13 +3,87 @@ """TorchGeo temporal transforms.""" -from typing import Any - -import kornia.augmentation as K -import torch +from typing import Any, Literal from einops import rearrange -from kornia.contrib import extract_tensor_patches -from kornia.geometry import crop_by_indices -from kornia.geometry.boxes import Boxes +import kornia.augmentation as K from torch import Tensor -from torch.nn.modules import Module \ No newline at end of file + + +class TemporalRearrange(K.IntensityAugmentationBase2D): + """Rearrange temporal and channel dimensions. + + This transform allows conversion between: + - B x T x C x H x W (temporal-explicit) + - B x (T*C) x H x W (temporal-channel) + """ + + def __init__( + self, + mode: Literal["merge", "split"], + num_temporal_channels: int, + p: float = 1.0, + p_batch: float = 1.0, + same_on_batch: bool = False, + keepdim: bool = False, + ) -> None: + """Initialize a new TemporalRearrange instance. + + Args: + mode: Whether to 'merge' (B x T x C x H x W -> B x TC x H x W) or + 'split' (B x TC x H x W -> B x T x C x H x W) temporal dimensions + num_temporal_channels: Number of temporal channels (T) in the sequence + p: Probability for applying the transform element-wise + p_batch: Probability for applying the transform batch-wise + same_on_batch: Apply the same transformation across the batch + keepdim: Whether to keep the output shape the same as input + """ + super().__init__( + p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim + ) + if mode not in ["merge", "split"]: + raise ValueError("mode must be either 'merge' or 'split'") + + self.flags = { + "mode": mode, + "num_temporal_channels": num_temporal_channels, + } + + def apply_transform( + self, + input: Tensor, + flags: dict[str, Any], + ) -> Tensor: + """Apply the transform. + + Args: + input: Input tensor + flags: Static parameters including mode and number of temporal channels + + Returns: + Transformed tensor with rearranged dimensions + + Raises: + ValueError: If input tensor dimensions don't match expected shape + """ + mode = flags["mode"] + t = flags["num_temporal_channels"] + + if mode == "merge": + if input.ndim != 5: + raise ValueError( + f"Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}" + ) + return rearrange(input, "b t c h w -> b (t c) h w") + else: + if input.ndim != 4: + raise ValueError( + f"Expected 4D input tensor (B,TC,H,W), got shape {input.shape}" + ) + tc = input.shape[1] + if tc % t != 0: + raise ValueError( + f"Input channels ({tc}) must be divisible by " + f"num_temporal_channels ({t})" + ) + c = tc // t + return rearrange(input, "b (t c) h w -> b t c h w", t=t, c=c) \ No newline at end of file From 545c2c9794ea5616475bee3638fcada2d188dccd Mon Sep 17 00:00:00 2001 From: Yi-Chia Chang Date: Thu, 30 Jan 2025 16:54:49 -0600 Subject: [PATCH 3/3] fix style --- torchgeo/transforms/temporal.py | 36 ++++++++++++++------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/torchgeo/transforms/temporal.py b/torchgeo/transforms/temporal.py index f5ecaf39f30..813bb187eec 100644 --- a/torchgeo/transforms/temporal.py +++ b/torchgeo/transforms/temporal.py @@ -4,8 +4,9 @@ """TorchGeo temporal transforms.""" from typing import Any, Literal -from einops import rearrange + import kornia.augmentation as K +from einops import rearrange from torch import Tensor @@ -19,7 +20,7 @@ class TemporalRearrange(K.IntensityAugmentationBase2D): def __init__( self, - mode: Literal["merge", "split"], + mode: Literal['merge', 'split'], num_temporal_channels: int, p: float = 1.0, p_batch: float = 1.0, @@ -40,19 +41,12 @@ def __init__( super().__init__( p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim ) - if mode not in ["merge", "split"]: + if mode not in ['merge', 'split']: raise ValueError("mode must be either 'merge' or 'split'") - self.flags = { - "mode": mode, - "num_temporal_channels": num_temporal_channels, - } + self.flags = {'mode': mode, 'num_temporal_channels': num_temporal_channels} - def apply_transform( - self, - input: Tensor, - flags: dict[str, Any], - ) -> Tensor: + def apply_transform(self, input: Tensor, flags: dict[str, Any]) -> Tensor: """Apply the transform. Args: @@ -65,25 +59,25 @@ def apply_transform( Raises: ValueError: If input tensor dimensions don't match expected shape """ - mode = flags["mode"] - t = flags["num_temporal_channels"] + mode = flags['mode'] + t = flags['num_temporal_channels'] - if mode == "merge": + if mode == 'merge': if input.ndim != 5: raise ValueError( - f"Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}" + f'Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}' ) - return rearrange(input, "b t c h w -> b (t c) h w") + return rearrange(input, 'b t c h w -> b (t c) h w') else: if input.ndim != 4: raise ValueError( - f"Expected 4D input tensor (B,TC,H,W), got shape {input.shape}" + f'Expected 4D input tensor (B,TC,H,W), got shape {input.shape}' ) tc = input.shape[1] if tc % t != 0: raise ValueError( - f"Input channels ({tc}) must be divisible by " - f"num_temporal_channels ({t})" + f'Input channels ({tc}) must be divisible by ' + f'num_temporal_channels ({t})' ) c = tc // t - return rearrange(input, "b (t c) h w -> b t c h w", t=t, c=c) \ No newline at end of file + return rearrange(input, 'b (t c) h w -> b t c h w', t=t, c=c)