Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
yichiac committed Jan 30, 2025
1 parent 87f3123 commit 545c2c9
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions torchgeo/transforms/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""TorchGeo temporal transforms."""

from typing import Any, Literal
from einops import rearrange

import kornia.augmentation as K
from einops import rearrange
from torch import Tensor


Expand All @@ -19,7 +20,7 @@ class TemporalRearrange(K.IntensityAugmentationBase2D):

def __init__(
self,
mode: Literal["merge", "split"],
mode: Literal['merge', 'split'],
num_temporal_channels: int,
p: float = 1.0,
p_batch: float = 1.0,
Expand All @@ -40,19 +41,12 @@ def __init__(
super().__init__(
p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
)
if mode not in ["merge", "split"]:
if mode not in ['merge', 'split']:
raise ValueError("mode must be either 'merge' or 'split'")

self.flags = {
"mode": mode,
"num_temporal_channels": num_temporal_channels,
}
self.flags = {'mode': mode, 'num_temporal_channels': num_temporal_channels}

def apply_transform(
self,
input: Tensor,
flags: dict[str, Any],
) -> Tensor:
def apply_transform(self, input: Tensor, flags: dict[str, Any]) -> Tensor:
"""Apply the transform.
Args:
Expand All @@ -65,25 +59,25 @@ def apply_transform(
Raises:
ValueError: If input tensor dimensions don't match expected shape
"""
mode = flags["mode"]
t = flags["num_temporal_channels"]
mode = flags['mode']
t = flags['num_temporal_channels']

if mode == "merge":
if mode == 'merge':
if input.ndim != 5:
raise ValueError(
f"Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}"
f'Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}'
)
return rearrange(input, "b t c h w -> b (t c) h w")
return rearrange(input, 'b t c h w -> b (t c) h w')
else:
if input.ndim != 4:
raise ValueError(
f"Expected 4D input tensor (B,TC,H,W), got shape {input.shape}"
f'Expected 4D input tensor (B,TC,H,W), got shape {input.shape}'
)
tc = input.shape[1]
if tc % t != 0:
raise ValueError(
f"Input channels ({tc}) must be divisible by "
f"num_temporal_channels ({t})"
f'Input channels ({tc}) must be divisible by '
f'num_temporal_channels ({t})'
)
c = tc // t
return rearrange(input, "b (t c) h w -> b t c h w", t=t, c=c)
return rearrange(input, 'b (t c) h w -> b t c h w', t=t, c=c)

0 comments on commit 545c2c9

Please sign in to comment.