Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thin plate spline #2156

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [SafeRotate](https://explore.albumentations.ai/transform/SafeRotate) | ✓ | ✓ | ✓ | ✓ |
| [ShiftScaleRotate](https://explore.albumentations.ai/transform/ShiftScaleRotate) | ✓ | ✓ | ✓ | ✓ |
| [SmallestMaxSize](https://explore.albumentations.ai/transform/SmallestMaxSize) | ✓ | ✓ | ✓ | ✓ |
| [ThinPlateSpline](https://explore.albumentations.ai/transform/ThinPlateSpline) | ✓ | ✓ | ✓ | ✓ |
| [TimeMasking](https://explore.albumentations.ai/transform/TimeMasking) | ✓ | ✓ | ✓ | ✓ |
| [TimeReverse](https://explore.albumentations.ai/transform/TimeReverse) | ✓ | ✓ | ✓ | ✓ |
| [Transpose](https://explore.albumentations.ai/transform/Transpose) | ✓ | ✓ | ✓ | ✓ |
Expand Down
95 changes: 95 additions & 0 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2978,3 +2978,98 @@ def shuffle_tiles_within_shape_groups(
mapping[old] = new

return mapping


def compute_tps_weights(src_points: np.ndarray, dst_points: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Compute Thin Plate Spline weights.

Args:
src_points: Source control points with shape (num_points, 2)
dst_points: Destination control points with shape (num_points, 2)

Returns:
tuple of:
- nonlinear_weights: TPS kernel weights for nonlinear deformation (num_points, 2)
- affine_weights: Weights for affine transformation (3, 2)
[constant term, x scale/shear, y scale/shear]

Note:
The TPS interpolation is decomposed into:
1. Nonlinear part (controlled by kernel weights)
2. Affine part (global scaling, rotation, translation)
"""
num_points = src_points.shape[0]

# Compute kernel matrix of pairwise distances
kernel_matrix = np.zeros((num_points, num_points))
for i in range(num_points):
ternaus marked this conversation as resolved.
Show resolved Hide resolved
for j in range(num_points):
if i != j:
# U(r) = r² log(r) is the TPS kernel function
dist = np.linalg.norm(src_points[i] - src_points[j])
kernel_matrix[i, j] = dist * dist * np.log(dist + 1e-6)

# Construct affine terms matrix
affine_terms = np.ones((num_points, 3))
affine_terms[:, 1:] = src_points # [1, x, y] for each point

# Build system matrix
system_matrix = np.zeros((num_points + 3, num_points + 3))
system_matrix[:num_points, :num_points] = kernel_matrix
system_matrix[:num_points, num_points:] = affine_terms
system_matrix[num_points:, :num_points] = affine_terms.T

# Right-hand side of the system
target_coords = np.zeros((num_points + 3, 2))
target_coords[:num_points] = dst_points

# Solve the system for both x and y coordinates
all_weights = np.linalg.solve(system_matrix, target_coords)

# Split weights into nonlinear and affine components
nonlinear_weights = all_weights[:num_points]
affine_weights = all_weights[num_points:]

return nonlinear_weights, affine_weights


def tps_transform(
target_points: np.ndarray,
control_points: np.ndarray,
nonlinear_weights: np.ndarray,
affine_weights: np.ndarray,
) -> np.ndarray:
"""Apply Thin Plate Spline transformation to points.

Args:
target_points: Points to transform with shape (num_targets, 2)
control_points: Original control points with shape (num_controls, 2)
nonlinear_weights: TPS kernel weights with shape (num_controls, 2)
affine_weights: Affine transformation weights with shape (3, 2)

Returns:
Transformed points with shape (num_targets, 2)

Note:
The transformation combines:
1. Nonlinear warping based on distances to control points
2. Global affine transformation (scale, rotation, translation)
"""
num_controls = control_points.shape[0]
num_targets = target_points.shape[0]

# Compute kernel matrix of distances to control points
kernel_matrix = np.zeros((num_targets, num_controls))
for i in range(num_targets):
for j in range(num_controls):
dist = np.linalg.norm(target_points[i] - control_points[j])
if dist > 0:
# U(r) = r² log(r) is the TPS kernel function
kernel_matrix[i, j] = dist * dist * np.log(dist)

# Prepare affine terms [1, x, y] for each point
affine_terms = np.ones((num_targets, 3))
affine_terms[:, 1:] = target_points

# Combine nonlinear and affine transformations
return kernel_matrix @ nonlinear_weights + affine_terms @ affine_weights
177 changes: 169 additions & 8 deletions albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
InterpolationType,
NonNegativeFloatRangeType,
SymmetricRangeType,
check_01,
check_1plus,
)
from albumentations.core.transforms_interface import BaseTransformInitSchema, DualTransform
Expand Down Expand Up @@ -53,6 +54,7 @@
"GridElasticDeform",
"RandomGridShuffle",
"Pad",
"ThinPlateSpline",
]

NUM_PADS_XY = 2
Expand All @@ -73,7 +75,7 @@ class BaseDistortion(DualTransform):
border_mode (int): Border mode to be used for handling pixels outside the image boundaries.
Should be one of the OpenCV border types (e.g., cv2.BORDER_REFLECT_101,
cv2.BORDER_CONSTANT). Default: cv2.BORDER_REFLECT_101
value (int, float, list of int, list of float, optional): Padding value if border_mode is
value (ColorType | None): Padding value if border_mode is
cv2.BORDER_CONSTANT. Default: None
mask_value (ColorType | None): Padding value for mask if
border_mode is cv2.BORDER_CONSTANT. Default: None
Expand Down Expand Up @@ -241,21 +243,20 @@ def __init__(
border_mode: int = cv2.BORDER_REFLECT_101,
value: ColorType | None = None,
mask_value: ColorType | None = None,
always_apply: bool | None = None,
approximate: bool = False,
same_dxdy: bool = False,
mask_interpolation: int = cv2.INTER_NEAREST,
noise_distribution: Literal["gaussian", "uniform"] = "gaussian",
p: float = 0.5,
always_apply: bool | None = None,
):
super().__init__(
interpolation=interpolation,
border_mode=border_mode,
value=value,
mask_value=mask_value,
always_apply=always_apply,
p=p,
mask_interpolation=mask_interpolation,
p=p,
)
self.alpha = alpha
self.sigma = sigma
Expand Down Expand Up @@ -1464,15 +1465,14 @@ def __init__(
value: ColorType | None = None,
mask_value: ColorType | None = None,
mask_interpolation: int = cv2.INTER_NEAREST,
always_apply: bool | None = None,
p: float = 0.5,
always_apply: bool | None = None,
):
super().__init__(
interpolation=interpolation,
border_mode=border_mode,
value=value,
mask_value=mask_value,
always_apply=always_apply,
mask_interpolation=mask_interpolation,
p=p,
)
Expand Down Expand Up @@ -1563,6 +1563,8 @@ class InitSchema(BaseDistortion.InitSchema):
num_steps: Annotated[int, Field(ge=1)]
distort_limit: SymmetricRangeType
normalized: bool
value: ColorType | None
mask_value: ColorType | None

@field_validator("distort_limit")
@classmethod
Expand Down Expand Up @@ -1590,13 +1592,14 @@ def __init__(
border_mode=border_mode,
value=value,
mask_value=mask_value,
always_apply=always_apply,
mask_interpolation=mask_interpolation,
p=p,
)
self.num_steps = num_steps
self.distort_limit = cast(tuple[float, float], distort_limit)
self.normalized = normalized
self.value = value
self.mask_value = mask_value

def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
image_shape = params["shape"][:2]
Expand Down Expand Up @@ -1754,7 +1757,7 @@ def __init__(
p: float = 1.0,
always_apply: bool | None = None,
):
super().__init__(p=p, always_apply=always_apply)
super().__init__(p=p)
self.num_grid_xy = num_grid_xy
self.magnitude = magnitude
self.interpolation = interpolation
Expand Down Expand Up @@ -2233,3 +2236,161 @@ def get_transform_init_args_names(self) -> tuple[str, ...]:
"value",
"mask_value",
)


class ThinPlateSpline(BaseDistortion):
r"""Apply Thin Plate Spline (TPS) transformation to create smooth, non-rigid deformations.

Imagine the image printed on a thin metal plate that can be bent and warped smoothly:
- Control points act like pins pushing or pulling the plate
- The plate resists sharp bending, creating smooth deformations
- The transformation maintains continuity (no tears or folds)
- Areas between control points are interpolated naturally

The transform works by:
1. Creating a regular grid of control points (like pins in the plate)
2. Randomly displacing these points (like pushing/pulling the pins)
3. Computing a smooth interpolation (like the plate bending)
4. Applying the resulting deformation to the image


Args:
scale_range (tuple[float, float]): Range for random displacement of control points.
Values should be in [0.0, 1.0]:
- 0.0: No displacement (identity transform)
- 0.1: Subtle warping
- 0.2-0.4: Moderate deformation (recommended range)
- 0.5+: Strong warping
Default: (0.2, 0.4)

num_control_points (int): Number of control points per side.
Creates a grid of num_control_points x num_control_points points.
- 2: Minimal deformation (affine-like)
- 3-4: Moderate flexibility (recommended)
- 5+: More local deformation control
Must be >= 2. Default: 4

interpolation (int): OpenCV interpolation flag. Used for image sampling.
See also: cv2.INTER_*
Default: cv2.INTER_LINEAR

border_mode (int): OpenCV border mode. Determines how to fill areas
outside the image.
See also: cv2.BORDER_*
Default: cv2.BORDER_CONSTANT

value (int | float | list[int] | list[float] | None): Padding value if
border_mode is cv2.BORDER_CONSTANT.
Default: None

mask_value (int | float | list[int] | list[float] | None): Padding value
for mask if border_mode is cv2.BORDER_CONSTANT.
Default: None

p (float): Probability of applying the transform. Default: 0.5

Targets:
image, mask, keypoints, bboxes

Image types:
uint8, float32

Note:
- The transformation preserves smoothness and continuity
- Stronger scale values may create more extreme deformations
- Higher number of control points allows more local deformations
- The same deformation is applied consistently to all targets

Example:
>>> import albumentations as A
>>> # Basic usage
>>> transform = A.ThinPlateSpline()
>>>
>>> # Subtle deformation
>>> transform = A.ThinPlateSpline(
... scale_range=(0.1, 0.2),
... num_control_points=3
... )
>>>
>>> # Strong warping with fine control
>>> transform = A.ThinPlateSpline(
... scale_range=(0.3, 0.5),
... num_control_points=5,
... border_mode=cv2.BORDER_REFLECT_101
... )

References:
- "Principal Warps: Thin-Plate Splines and the Decomposition of Deformations"
by F.L. Bookstein
https://doi.org/10.1109/34.24792

- Thin Plate Splines in Computer Vision:
https://en.wikipedia.org/wiki/Thin_plate_spline

- Similar implementation in Kornia:
https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomThinPlateSpline

See Also:
- ElasticTransform: For different type of non-rigid deformation
- GridDistortion: For grid-based warping
- OpticalDistortion: For lens-like distortions
"""

class InitSchema(BaseDistortion.InitSchema):
scale_range: Annotated[tuple[float, float], AfterValidator(check_01)]
num_control_points: int = Field(ge=2)

def __init__(
self,
scale_range: tuple[float, float] = (0.2, 0.4),
num_control_points: int = 4,
interpolation: int = cv2.INTER_LINEAR,
mask_interpolation: int = cv2.INTER_NEAREST,
border_mode: int = cv2.BORDER_CONSTANT,
value: ColorType | None = None,
mask_value: ColorType | None = None,
always_apply: bool | None = None,
p: float = 0.5,
):
super().__init__(
interpolation=interpolation,
mask_interpolation=mask_interpolation,
border_mode=border_mode,
value=value,
mask_value=mask_value,
p=p,
)
self.scale_range = scale_range
ternaus marked this conversation as resolved.
Show resolved Hide resolved
self.num_control_points = num_control_points

def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
height, width = params["shape"][:2]

# Create regular grid of control points
grid_size = self.num_control_points
x = np.linspace(0, 1, grid_size)
y = np.linspace(0, 1, grid_size)
src_points = np.stack(np.meshgrid(x, y), axis=-1).reshape(-1, 2)

# Add random displacement to destination points
scale = self.py_random.uniform(*self.scale_range)
dst_points = src_points + self.random_generator.normal(0, scale, src_points.shape)
ternaus marked this conversation as resolved.
Show resolved Hide resolved

# Compute TPS weights
weights, affine = fgeometric.compute_tps_weights(src_points, dst_points)

# Create grid of points
x, y = np.meshgrid(np.arange(width), np.arange(height))
points = np.stack([x.flatten(), y.flatten()], axis=1).astype(np.float32)

# Transform points
transformed = fgeometric.tps_transform(points / [width, height], src_points, weights, affine)
transformed *= [width, height]

return {
"map_x": transformed[:, 0].reshape(height, width).astype(np.float32),
"map_y": transformed[:, 1].reshape(height, width).astype(np.float32),
}

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ("scale_range", "num_control_points", *super().get_transform_init_args_names())
1 change: 1 addition & 0 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,5 @@
[A.PlasmaBrightnessContrast, {}],
[A.PlasmaShadow, {}],
[A.Illumination, {}],
[A.ThinPlateSpline, {}],
ternaus marked this conversation as resolved.
Show resolved Hide resolved
]