From 36441b21d1705c706071faf3ee3cf02845bcb6f8 Mon Sep 17 00:00:00 2001 From: Andreas Schuh Date: Thu, 14 Dec 2023 14:52:15 +0000 Subject: [PATCH] [core] Fix spatial derivatives when using mode='bspline' or 'gaussian' --- src/deepali/core/image.py | 45 +++++--- tests/_test_core_flow_deriv.py | 205 +++++++++++++++++++++++++++++++++ 2 files changed, 236 insertions(+), 14 deletions(-) create mode 100644 tests/_test_core_flow_deriv.py diff --git a/src/deepali/core/image.py b/src/deepali/core/image.py index 9f19786..0cd1040 100644 --- a/src/deepali/core/image.py +++ b/src/deepali/core/image.py @@ -1486,7 +1486,10 @@ def spatial_derivatives( If ``None``, ``forward_central_backward`` is used as default mode. sigma: Standard deviation of Gaussian kernel in grid units. If ``None`` or zero, no Gaussian smoothing is used for calculation of finite differences, and a - default standard deviation of 0.4 is used when ``mode="gaussian"``. + default standard deviation of 0.7355 is used when ``mode="gaussian"``. With a smaller + standard deviation, the magnitude of the derivative values starts to deviate between + ``mode="gaussian"`` and finite differences of a Gaussian smoothed input. This is likely + due to a too small discretized Gaussian filter and its derivative. spacing: Physical spacing between image grid points, e.g., ``(sx, sy, sz)``. When a scalar is given, the same spacing is used for each image and spatial dimension. If a sequence is given, it must be of length equal to the number of spatial dimensions ``D``, @@ -1556,7 +1559,7 @@ def spatial_derivatives( if mode in ("forward", "backward", "central", "forward_central_backward", "prewitt", "sobel"): if sigma and sigma > 0: blur = gaussian1d(sigma, dtype=torch.float, device=data.device) - data = conv(data, blur, padding=PaddingMode.ZEROS) + data = conv(data, blur, padding=PaddingMode.REPLICATE) if mode in ("prewitt", "sobel"): avg_kernel = torch.tensor([1, 1 if mode == "prewitt" else 2, 1], dtype=data.dtype) avg_kernel /= avg_kernel.sum() @@ -1589,7 +1592,7 @@ def spatial_derivatives( if sigma and sigma > 0: blur = gaussian1d(sigma, dtype=torch.float, device=data.device) - data = conv(data, blur, padding=PaddingMode.ZEROS) + data = conv(data, blur, padding=PaddingMode.REPLICATE) if stride is None: stride = 1 @@ -1616,27 +1619,41 @@ def bspline1d(s: int, d: int) -> Tensor: for spatial_dim in SpatialDerivativeKeys.split(code): order[spatial_dim] += 1 kernel = [bspline1d(s, d) for s, d in zip(stride, order)] - derivs[code] = evaluate_cubic_bspline(data, kernel=kernel) + deriv = evaluate_cubic_bspline(data, kernel=kernel) + if sum(order) > 0: + denom = torch.ones(N, dtype=spacing.dtype, device=spacing.device) + for delta, d in zip(spacing.transpose(0, 1), order): + if d > 0: + denom.mul_(delta.pow(d)) + denom = denom.reshape((N,) + (1,) * (deriv.ndim - 1)) + deriv = deriv.div_(denom.to(deriv)) + derivs[code] = deriv elif mode == "gaussian": + + def pad_spatial_dim(data: Tensor, sdim: int, padding: int) -> Tensor: + pad = [(padding, padding) if d == sdim else (0, 0) for d in range(data.ndim - 2)] + pad = [n for v in pad for n in v] + return F.pad(data, pad, mode="replicate") + if not sigma: - sigma = 0.4 - kernel_0 = gaussian1d(sigma, normalize=False, dtype=torch.float) - kernel_1 = gaussian1d_I(sigma, normalize=False, dtype=torch.float) - norm = kernel_0.sum() - kernel_0 = kernel_0.div_(norm).to(data.device) - kernel_1 = kernel_1.div_(norm).to(data.device) + sigma = 0.7355 # same default value as used in downsample() + kernel_0 = gaussian1d(sigma, normalize=False, dtype=torch.float, device=data.device) + kernel_1 = gaussian1d_I(sigma, normalize=False, dtype=torch.float, device=data.device) for i in range(max_order): for code in unique_keys: key = code[: i + 1] if i < len(code) and key not in derivs: sdim = SpatialDim.from_arg(code[i]) - result = data if i == 0 else derivs[code[:i]] + deriv = data if i == 0 else derivs[code[:i]] for d in range(D): - dim = SpatialDim(d).tensor_dim(result.ndim) + dim = SpatialDim(d).tensor_dim(deriv.ndim) kernel = kernel_1 if sdim == d else kernel_0 - result = conv1d(result, kernel, dim=dim, padding=len(kernel) // 2) - derivs[key] = result + deriv = pad_spatial_dim(deriv, d, len(kernel) // 2) + deriv = conv1d(deriv, kernel, dim=dim, padding=0) + denom = spacing.narrow(1, sdim, 1).reshape((N,) + (1,) * (deriv.ndim - 1)) + deriv = deriv.div_(denom.to(deriv)) + derivs[key] = deriv derivs = {key: derivs[SpatialDerivativeKeys.sorted(key)] for key in which} else: diff --git a/tests/_test_core_flow_deriv.py b/tests/_test_core_flow_deriv.py new file mode 100644 index 0000000..eecfe8e --- /dev/null +++ b/tests/_test_core_flow_deriv.py @@ -0,0 +1,205 @@ +r"""Interactive test and visualization of vector flow derivatives.""" + +# %% +# Imports +from typing import Dict, Optional, Sequence + +import matplotlib.pyplot as plt + +import torch +from torch import Tensor +from torch.random import Generator + +from deepali.core import Axes, Grid +import deepali.core.bspline as B +import deepali.core.functional as U + + +# %% +# Auxiliary functions +def change_axes(flow: Tensor, grid: Grid, axes: Axes, to_axes: Axes) -> Tensor: + if axes != to_axes: + flow = U.move_dim(flow, 1, -1) + flow = grid.transform_vectors(flow, axes=axes, to_axes=to_axes) + flow = U.move_dim(flow, -1, 1) + return flow + + +def flow_derivatives( + flow: Tensor, grid: Grid, axes: Axes, to_axes: Optional[Axes] = None, **kwargs +) -> Dict[str, Tensor]: + if to_axes is None: + to_axes = axes + flow = change_axes(flow, grid, axes, to_axes) + axes = to_axes + if "spacing" not in kwargs: + if axes == Axes.CUBE: + spacing = tuple(2 / n for n in grid.size()) + elif axes == Axes.CUBE_CORNERS: + spacing = tuple(2 / (n - 1) for n in grid.size()) + elif axes == Axes.GRID: + spacing = 1 + elif axes == Axes.WORLD: + spacing = grid.spacing() + else: + spacing = None + kwargs["spacing"] = spacing + return U.flow_derivatives(flow, **kwargs) + + +def random_svf( + size: Sequence[int], + stride: int = 1, + generator: Optional[Generator] = None, +) -> Tensor: + cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride) + cp_grid_size = tuple(reversed(cp_grid_size)) + data = torch.randn((1, 3) + cp_grid_size, generator=generator) + data = U.fill_border(data, margin=3, value=0, inplace=True) + return B.evaluate_cubic_bspline(data, size=size, stride=stride) + + +def visualize_flow( + ax: plt.Axes, + flow: Tensor, + grid: Optional[Grid] = None, + axes: Optional[Axes] = None, + label: Optional[str] = None, +) -> None: + if grid is None: + grid = Grid(shape=flow.shape[2:]) + if axes is None: + axes = grid.axes() + flow = change_axes(flow, grid, axes, grid.axes()) + x = grid.coords(channels_last=False, dtype=flow.dtype, device=flow.device) + x = U.move_dim(x.unsqueeze_(0).add_(flow), 1, -1) + target_grid = U.grid_image(shape=flow.shape[2:], inverted=True, stride=(5, 5)) + warped_grid = U.warp_image(target_grid, x, align_corners=grid.align_corners()) + ax.imshow(warped_grid[0, 0, flow.shape[2] // 2], cmap="gray") + if label: + ax.set_title(label, fontsize=24) + + +# %% +# Random velocity fields +generator = torch.Generator().manual_seed(42) +grid = Grid(size=(128, 128, 64), spacing=(0.5, 0.5, 1.0)) +flow = random_svf(grid.size(), stride=8, generator=generator).mul_(0.1) + +fig, axes = plt.subplots(1, 1, figsize=(4, 4)) + +ax = axes +ax.set_title("v", fontsize=24, pad=20) +visualize_flow(ax, flow, grid=grid, axes=grid.axes()) + + +# %% +# Visualise first order derivatives for different modes +configs = [ + dict(mode="forward_central_backward"), + dict(mode="bspline"), + dict(mode="gaussian", sigma=0.7355), +] + +fig, axes = plt.subplots(len(configs), 4, figsize=(16, 4 * len(configs))) + +for i, config in enumerate(configs): + derivs = flow_derivatives( + flow, + grid=grid, + axes=grid.axes(), + to_axes=Axes.GRID, + which=["du/dx", "du/dy", "dv/dx", "dv/dy"], + **config, + ) + for ax, (key, deriv) in zip(axes[i], derivs.items()): + if i == 0: + ax.set_title(key, fontsize=24, pad=20) + ax.imshow(deriv[0, 0, deriv.shape[2] // 2], vmin=-1, vmax=1) + + +# %% +# Compare magnitudes of first order derivatives for different modes +flow_axes = [Axes.GRID, Axes.WORLD, Axes.CUBE_CORNERS] + +sigma = 0.7355 + +configs = [ + dict(mode="bspline"), + dict(mode="gaussian", sigma=sigma), + dict(mode="forward_central_backward", sigma=sigma), + dict(mode="forward_central_backward"), +] + +for to_axes in flow_axes: + for config in configs: + print(f"axes={to_axes}, " + ", ".join(f"{k}={v!r}" for k, v in config.items())) + derivs = flow_derivatives( + flow, + grid=grid, + axes=grid.axes(), + to_axes=to_axes, + which=["du/dx", "du/dy", "dv/dx", "dv/dy"], + **config, + ) + for key, deriv in derivs.items(): + print(f"- max(abs({key})): {deriv.abs().max().item():.5f}") + print() + print("\n") + + +# %% +# Visualise second order derivatives for different modes +configs = [ + dict(mode="forward_central_backward"), + dict(mode="bspline"), + dict(mode="gaussian", sigma=0.7355), +] + +fig, axes = plt.subplots(len(configs), 4, figsize=(16, 4 * len(configs))) + +for i, config in enumerate(configs): + derivs = flow_derivatives( + flow, + grid=grid, + axes=grid.axes(), + to_axes=Axes.GRID, + which=["du/dxx", "du/dxy", "dv/dxy", "dv/dyy"], + **config, + ) + for ax, (key, deriv) in zip(axes[i], derivs.items()): + if i == 0: + ax.set_title(key, fontsize=24, pad=20) + ax.imshow(deriv[0, 0, deriv.shape[2] // 2], vmin=-0.4, vmax=0.4) + + +# %% +# Compare magnitudes of second order derivatives for different modes +flow_axes = [Axes.GRID, Axes.WORLD, Axes.CUBE_CORNERS] + +sigma = 0.7355 + +configs = [ + dict(mode="bspline"), + dict(mode="gaussian", sigma=sigma), + dict(mode="forward_central_backward", sigma=sigma), + dict(mode="forward_central_backward"), +] + +for to_axes in flow_axes: + for config in configs: + print(f"axes={to_axes}, " + ", ".join(f"{k}={v!r}" for k, v in config.items())) + derivs = flow_derivatives( + flow, + grid=grid, + axes=grid.axes(), + to_axes=to_axes, + which=["du/dxx", "du/dxy", "dv/dxy", "dv/dyy"], + **config, + ) + for key, deriv in derivs.items(): + print(f"- max(abs({key})): {deriv.abs().max().item():.5f}") + print() + print("\n") + +# %%