Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spatial] Reverse stride order when subsampling DenseVectorFieldTransform #133

Merged
merged 6 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@ def divergence(
raise TypeError("divergence() 'flow' must be of type torch.Tensor")
if flow.ndim < 4:
raise ValueError("divergence() 'flow' must be at least 4-dimensional tensor")
N = flow.shape[0]
D = flow.shape[1]
if flow.ndim != D + 2:
raise ValueError(
Expand All @@ -243,10 +242,10 @@ def divergence(
kwargs = dict(mode=mode, sigma=sigma, spacing=spacing, stride=stride)
which = FlowDerivativeKeys.divergence(spatial_dims=D)
deriv = flow_derivatives(flow, which=which, **kwargs)
ref = deriv["du/dx"]
div = torch.zeros((N, 1) + ref.shape[2:], dtype=ref.dtype, device=ref.device)
div: Optional[Tensor] = None
for value in deriv.values():
div = div.add_(value)
div = value if div is None else div.add_(value)
assert div is not None
return div


Expand All @@ -259,6 +258,9 @@ def divergence_free_flow(
) -> Tensor:
r"""Construct divergence-free vector field from D-1 scalar fields or one 3-dimensional vector field, respectively.

Experimental: This function may change in the future. Constructing a divergence-free field in 3D using curl() works best.
The construction of a divergence free field from one or two scalar fields, respectively, may need to be revised.

The input fields must be sufficiently smooth for the output vector field to have zero divergence. To produce a
3-dimensional vector field, a better result may be obtained using the :func:`curl()` of another 3-dimensional
vector field instead of two scalar fields concatenated along the channel dimension. Gaussian blurring with
Expand All @@ -274,7 +276,7 @@ def divergence_free_flow(
the cross product of the gradients of the two scalar fields. Otherwise, the input tensor must be of shape
``(N, 3, Z, Y, X)`` and the output is the curl of the vector field.
mode: Mode of :func:`flow_derivatives()` approximation.
sigma: Standard deviation of Gaussian used smooth input field.
sigma: Standard deviation of Gaussian used to smooth input field.
spacing: Physical size of image voxels used to compute finite differences.
stride: Number of output grid points between control points plus one for ``mode='bspline'``.

Expand Down Expand Up @@ -830,7 +832,7 @@ def sample_flow(
padding = PaddingMode.BORDER
x = coords.expand((N,) + coords.shape[1:])
t = flow.expand((N,) + flow.shape[1:])
g = x.reshape((N,) + (1,) * (t.ndim - 3) + (-1, D))
g = x if x.ndim == t.ndim else x.reshape((N,) + (1,) * (t.ndim - 3) + (-1, D))
u = grid_sample(t, g, padding=padding, align_corners=align_corners)
u = move_dim(u, 1, -1)
u = u.reshape(x.shape)
Expand Down
4 changes: 2 additions & 2 deletions src/deepali/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _torch_function_result(cls, func, data, grid: Optional[Sequence[Grid]]) -> A
data._grid = grid
else:
data = cls(data, grid)
elif type(data) != Tensor:
elif type(data) is not Tensor:
data = data.as_subclass(Tensor)
return data

Expand Down Expand Up @@ -1025,7 +1025,7 @@ def _torch_function_result(cls, func, data, grid: Optional[Grid]) -> Any:
data._grid = grid
else:
data = cls(data, grid)
elif type(data) != Tensor:
elif type(data) is not Tensor:
data = data.as_subclass(Tensor)
return data

Expand Down
10 changes: 6 additions & 4 deletions src/deepali/losses/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,6 @@ def grad_loss(
"grad_loss() not implemented for linear transformation and 'reduction'='none'"
)
return torch.tensor(0, dtype=u.dtype, device=u.device)
N = u.shape[0]
D = u.shape[1]
if u.ndim - 2 != D:
raise ValueError("grad_loss() 'u' must be tensor of shape (N, D, ..., X)")
Expand All @@ -1167,9 +1166,11 @@ def grad_loss(
deriv = {k: v.pow_(p) for k, v in deriv.items()}
else:
deriv = {k: v.abs_().pow_(p) for k, v in deriv.items()}
loss = torch.zeros((N, 1) + u.shape[2:], dtype=u.dtype, device=u.device)
loss: Optional[Tensor] = None
for value in deriv.values():
loss = loss.add_(value.sum(dim=1, keepdim=True))
value = value.sum(dim=1, keepdim=True)
loss = value if loss is None else loss.add_(value)
assert loss is not None
if q == 0:
loss.abs_()
elif q != 1:
Expand Down Expand Up @@ -1299,7 +1300,8 @@ def curvature_loss(
kwargs = dict(mode=mode or "sobel", sigma=sigma, spacing=spacing, stride=stride)
which = FlowDerivativeKeys.curvature(spatial_dims=D)
deriv = flow_derivatives(u, which=which, **kwargs)
loss = torch.zeros((N, D) + u.shape[2:], dtype=u.dtype, device=u.device)
shape = deriv["du/dxx"].shape
loss = torch.zeros((N, D) + shape[2:], dtype=u.dtype, device=u.device)
for i, j in itertools.product(range(D), repeat=2):
loss.narrow(1, i, 1).add_(deriv[FlowDerivativeKeys.symbol(i, j, j)])
loss = loss.square_().sum(dim=1, keepdim=True)
Expand Down
2 changes: 1 addition & 1 deletion src/deepali/spatial/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(
Args:
grid: Grid domain on which transformation is defined.
groups: Number of transformations. Must be either 1 or equal to batch size.
params: (nnormalized quaternion as 2-dimensional tensor of ``(N, 4)``.
params: (normalized quaternion as 2-dimensional tensor of ``(N, 4)``.

"""
if grid.ndim != 3:
Expand Down
32 changes: 28 additions & 4 deletions src/deepali/spatial/nonrigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def __init__(
Can be used to subsample the dense vector field with respect to the image grid of the fixed target
image. When ``grid.align_corners() is True``, the corner points of the ``grid`` and the resampled
vector field grid are aligned. Otherwise, the edges of the grid domains are aligned.
Must be either a single scalar if the grid is subsampled equally along each spatial dimension,
or a sequence with length less than or equal to ``grid.ndim``, with subsampling factors given
in the order (x, ...). When a sequence shorter than ``grid.ndim`` is given, remaining spatial
dimensions are not being subsampled.
resize: Whether to resize vector field during transformation update. If ``True``, the buffered vector
field ``u`` (and ``v`` if applicable) is resized to match the image ``grid`` size. This means that
transformation constraints defined on these resized vector fields, such as those based on finite
Expand All @@ -62,11 +66,15 @@ def __init__(
stride = 1
if isinstance(stride, (int, float)):
stride = (stride,) * grid.ndim
if len(stride) != grid.ndim:
if not isinstance(stride, Sequence):
raise TypeError(f"{type(self).__name__}() 'stride' must be float or Sequence[float]")
if len(stride) > grid.ndim:
raise ValueError(
f"{type(self).__name__}() 'stride' must be float or Sequence of length {grid.ndim}"
f"{type(self).__name__}() 'stride' sequence length"
f" ({len(stride)}) exceeds grid dimensions ({grid.ndim})"
)
self.stride = tuple(float(s) for s in stride)
stride = tuple(float(s) for s in stride) + (1.0,) * (grid.ndim - len(stride))
self.stride = stride
self._resize = resize
super().__init__(grid, groups=groups, params=params)

Expand All @@ -93,7 +101,19 @@ def data_grid_shape(
grid = self.grid()
if stride is None:
stride = self.stride
return tuple(int(math.ceil(n / s)) for n, s in zip(grid.shape, stride))
if isinstance(stride, (int, float)):
stride = (stride,) * grid.ndim
if not isinstance(stride, Sequence):
raise TypeError(
f"{type(self).__name__}.data_grid_shape() 'stride' must be float or Sequence[float]"
)
if len(stride) > grid.ndim:
raise ValueError(
f"{type(self).__name__}.data_grid_shape() 'stride' sequence length"
f" ({len(stride)}) exceeds grid dimensions ({grid.ndim})"
)
stride = tuple(float(s) for s in stride) + (1.0,) * (grid.ndim - len(stride))
return tuple(int(math.ceil(n / s)) for n, s in zip(grid.shape, reversed(stride)))

@torch.no_grad()
def grid_(self: TDenseVectorFieldTransform, grid: Grid) -> TDenseVectorFieldTransform:
Expand Down Expand Up @@ -208,6 +228,10 @@ def __init__(
Can be used to subsample the dense vector field with respect to the image grid of the fixed target
image. When ``grid.align_corners() is True``, the corner points of the ``grid`` and the resampled
vector field grid are aligned. Otherwise, the edges of the grid domains are aligned.
Must be either a single scalar if the grid is subsampled equally along each spatial dimension,
or a sequence with length less than or equal to ``grid.ndim``, with subsampling factors given
in the order (x, ...). When a sequence shorter than ``grid.ndim`` is given, remaining spatial
dimensions are not being subsampled.
resize: Whether to resize vector field during transformation update. If ``True``, the buffered vector
fields ``v`` and ``u`` are resized to match the image ``grid`` size. This means that transformation
constraints defined on these resized vector fields, such as those based on finite differences, are
Expand Down
37 changes: 21 additions & 16 deletions tests/test_core_flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.random import Generator

Expand All @@ -25,8 +24,7 @@ def periodic_flow(p: Tensor) -> Tensor:


def periodic_flow_du_dx(p: Tensor) -> Tensor:
q = p.mul(PERIODIC_FLOW_X_SCALE)
g = q.narrow(1, 0, 1).cos()
g = p.narrow(1, 0, 1).mul(PERIODIC_FLOW_X_SCALE).cos()
g = g.mul_(-PERIODIC_FLOW_X_SCALE * PERIODIC_FLOW_U_SCALE)
return g

Expand All @@ -36,8 +34,7 @@ def periodic_flow_du_dy(p: Tensor) -> Tensor:


def periodic_flow_du_dxx(p: Tensor) -> Tensor:
q = p.mul(PERIODIC_FLOW_X_SCALE)
g = q.narrow(1, 0, 1).sin()
g = p.narrow(1, 0, 1).mul(PERIODIC_FLOW_X_SCALE).sin()
g = g.mul_(PERIODIC_FLOW_X_SCALE * PERIODIC_FLOW_X_SCALE * PERIODIC_FLOW_U_SCALE)
return g

Expand All @@ -51,8 +48,7 @@ def periodic_flow_dv_dx(p: Tensor) -> Tensor:


def periodic_flow_dv_dy(p: Tensor) -> Tensor:
q = p.mul(PERIODIC_FLOW_X_SCALE)
g = q.narrow(1, 1, 1).sin()
g = p.narrow(1, 1, 1).mul(PERIODIC_FLOW_X_SCALE).sin()
g = g.mul_(-PERIODIC_FLOW_X_SCALE * PERIODIC_FLOW_U_SCALE)
return g

Expand All @@ -62,8 +58,7 @@ def periodic_flow_dv_dxx(p: Tensor) -> Tensor:


def periodic_flow_dv_dyy(p: Tensor) -> Tensor:
q = p.mul(PERIODIC_FLOW_X_SCALE)
g = q.narrow(1, 1, 1).cos()
g = p.narrow(1, 1, 1).mul(PERIODIC_FLOW_X_SCALE).cos()
g = g.mul_(-PERIODIC_FLOW_X_SCALE * PERIODIC_FLOW_X_SCALE * PERIODIC_FLOW_U_SCALE)
return g

Expand Down Expand Up @@ -192,6 +187,11 @@ def test_flow_derivatives() -> None:
assert difference(deriv["dw/dy"], x).abs().max().lt(1e-5)
assert deriv["dw/dz"].abs().max().lt(1e-5)

deriv = U.flow_derivatives(flow, which=["du/dxz", "dv/dzy", "dw/dxy"])
assert deriv["du/dxz"].sub(1).abs().max().lt(1e-4)
assert deriv["dv/dzy"].sub(1).abs().max().lt(1e-4)
assert deriv["dw/dxy"].sub(1).abs().max().lt(1e-4)


def test_flow_divergence() -> None:
grid = Grid(size=(16, 14))
Expand Down Expand Up @@ -234,13 +234,18 @@ def test_flow_divergence_free() -> None:
flow = U.divergence_free_flow(data, sigma=2.0)
assert flow.shape == (data.shape[0], 3) + data.shape[2:]
div = U.divergence(flow)
assert div[0, 0, 1:-1, 1:-1, 1:-1].abs().max().lt(1e-4)

coef = F.pad(data, (1, 2, 1, 2, 1, 2))
flow = U.divergence_free_flow(coef, mode="bspline", sigma=0.8)
assert flow.shape == (data.shape[0], 3) + data.shape[2:]
div = U.divergence(flow)
assert div[0, 0, 1:-1, 1:-1, 1:-1].abs().max().lt(1e-4)
assert div[0, 0, 1:-1, 1:-1, 1:-1].abs().max().lt(1e-3)

# coef = F.pad(data, (1, 2, 1, 2, 1, 2))
# flow = U.divergence_free_flow(coef, mode="bspline", sigma=1.0)
# assert flow.shape == (data.shape[0], 3) + data.shape[2:]
# div = U.divergence(flow)
# assert div[0, 0, 1:-1, 1:-1, 1:-1].abs().max().lt(1e-4)

# flow = U.divergence_free_flow(data, mode="gaussian", sigma=0.7355)
# assert flow.shape == (data.shape[0], 3) + data.shape[2:]
# div = U.divergence(flow)
# assert div[0, 0, 1:-1, 1:-1, 1:-1].abs().max().lt(1e-4)

# constructing a divergence-free field using curl() seems to work best given
# the higher magnitude and no need for Gaussian blurring of the random field
Expand Down