Skip to content

Commit

Permalink
temporal rearrange
Browse files Browse the repository at this point in the history
  • Loading branch information
yichiac committed Jan 30, 2025
1 parent 2cc6447 commit 93509a6
Showing 1 changed file with 82 additions and 8 deletions.
90 changes: 82 additions & 8 deletions torchgeo/transforms/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,87 @@

"""TorchGeo temporal transforms."""

from typing import Any

import kornia.augmentation as K
import torch
from typing import Any, Literal
from einops import rearrange
from kornia.contrib import extract_tensor_patches
from kornia.geometry import crop_by_indices
from kornia.geometry.boxes import Boxes
import kornia.augmentation as K
from torch import Tensor

Check failure on line 9 in torchgeo/transforms/temporal.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

torchgeo/transforms/temporal.py:6:1: I001 Import block is un-sorted or un-formatted
from torch.nn.modules import Module


class TemporalRearrange(K.IntensityAugmentationBase2D):
    """Rearrange temporal and channel dimensions.

    This transform allows conversion between:

    * B x T x C x H x W (temporal-explicit)
    * B x (T*C) x H x W (temporal-channel)
    """

    def __init__(
        self,
        mode: Literal["merge", "split"],
        num_temporal_channels: int,
        p: float = 1.0,
        p_batch: float = 1.0,
        same_on_batch: bool = False,
        keepdim: bool = False,
    ) -> None:
        """Initialize a new TemporalRearrange instance.

        Args:
            mode: Whether to 'merge' (B x T x C x H x W -> B x TC x H x W) or
                'split' (B x TC x H x W -> B x T x C x H x W) temporal dimensions.
            num_temporal_channels: Number of temporal channels (T) in the sequence.
            p: Probability for applying the transform element-wise.
            p_batch: Probability for applying the transform batch-wise.
            same_on_batch: Apply the same transformation across the batch.
            keepdim: Whether to keep the output shape the same as input.

        Raises:
            ValueError: If *mode* is not ``'merge'`` or ``'split'``, or if
                *num_temporal_channels* is not a positive integer.
        """
        super().__init__(
            p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
        )
        if mode not in ("merge", "split"):
            raise ValueError("mode must be either 'merge' or 'split'")
        # Guard early: a non-positive T would otherwise surface later as a
        # confusing modulo/shape error inside apply_transform.
        if num_temporal_channels <= 0:
            raise ValueError("num_temporal_channels must be a positive integer")

        # Static (non-sampled) parameters, forwarded by kornia to apply_transform.
        self.flags = {
            "mode": mode,
            "num_temporal_channels": num_temporal_channels,
        }

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Tensor | None = None,
    ) -> Tensor:
        """Apply the transform.

        The signature matches kornia's ``AugmentationBase2D.apply_transform``
        callback, which is invoked positionally as
        ``apply_transform(input, params, flags, transform)`` — ``params`` and
        ``transform`` must therefore be accepted even though they are unused.

        Args:
            input: Input tensor.
            params: Sampled augmentation parameters (unused by this transform).
            flags: Static parameters including mode and number of temporal channels.
            transform: Geometric transformation matrix (unused by this transform).

        Returns:
            Transformed tensor with rearranged dimensions.

        Raises:
            ValueError: If input tensor dimensions don't match expected shape.
        """
        mode = flags["mode"]
        t = flags["num_temporal_channels"]

        if mode == "merge":
            if input.ndim != 5:
                raise ValueError(
                    f"Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}"
                )
            # Fold the temporal axis into the channel axis.
            return rearrange(input, "b t c h w -> b (t c) h w")
        else:
            if input.ndim != 4:
                raise ValueError(
                    f"Expected 4D input tensor (B,TC,H,W), got shape {input.shape}"
                )
            tc = input.shape[1]
            if tc % t != 0:
                raise ValueError(
                    f"Input channels ({tc}) must be divisible by "
                    f"num_temporal_channels ({t})"
                )
            # Recover the temporal axis from the flattened channel axis.
            c = tc // t
            return rearrange(input, "b (t c) h w -> b t c h w", t=t, c=c)

0 comments on commit 93509a6

Please sign in to comment.