Skip to content

Commit

Permalink
[core] Implement composition of SVFs
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Nov 4, 2023
1 parent af0e00b commit 9f10f46
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
64 changes: 64 additions & 0 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,70 @@ def compose_flows(u: Tensor, v: Tensor, align_corners: bool = True) -> Tensor:
return w


def compose_svfs(
u: Tensor,
v: Tensor,
order: int = 1,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
) -> Tensor:
r"""Approximate stationary velocity field (SVF) of composite deformation.
The output velocity field is ``w = log(exp(v) o exp(u))``, where ``exp`` is the exponential map
of a stationary velocity field, and ``log`` its inverse. The velocity field ``w`` is given by the
`Baker-Campbell-Hausdorff (BCH) formula <https://en.wikipedia.org/wiki/Baker%E2%80%93Campbell%E2%80%93Hausdorff_formula>`_.
Args:
u: First stationary velocity field as tensor of shape ``(N, D, ..., X)``.
v: Second stationary velocity field as tensor of shape ``(N, D, ..., X)``.
order: Order of approximation. The highest implemented order is 3.
When 0, the returned velocity field is the sum of ``u`` and ``v``.
This approximation is accurate if the input velocity fields commute, i.e.,
the Lie bracket [u, v] = 0. When ``order=1``, the approximation is given by
``w = v + u + 1/2 [v, u]`` (note that deformation ``exp(u)`` is applied first).
mode: Mode of :func:`flow_derivatives()` approximation.
sigma: Standard deviation of Gaussian used for computing spatial derivatives.
spacing: Physical size of image voxels used to compute spatial derivatives.
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
Returns:
Approximation of BCH formula as tensor of shape ``(N, D, ..., X)``.
"""

def lb(a: Tensor, b: Tensor) -> Tensor:
return lie_bracket(a, b, mode=mode, sigma=sigma, spacing=spacing, stride=stride)

for name, flow in [("u", u), ("v", v)]:
if flow.ndim < 4:
raise ValueError(f"compose_svfs() '{name}' must be vector field of shape (N, D, ..., X)")
if flow.shape[1] != flow.ndim - 2:
raise ValueError(f"compose_svfs() '{name}' must have shape (N, D, ..., X)")
if u.shape != v.shape:
raise ValueError(f"compose_svfs() 'u' and 'v' must have the same shape")
if order < 0:
raise ValueError("compose_svfs() 'order' must not be negative")
elif order > 3:
raise NotImplementedError(f"compose_svfs() approximation of order={order} not implemented")

w = v.add(u)
if order >= 1:
vu = lb(v, u)
w = w.add(vu.mul_(0.5))
if order >= 2:
uv = lb(u, v)
vvu = lb(v, vu)
uuv = lb(u, uv)
w = w.add(vvu.add(uuv).mul_(1/12))
if order >= 3:
uvvu = lb(u, vvu)
w = w.add(uvvu.mul_(1/24))

return w


def curl(
flow: Tensor,
mode: Optional[str] = None,
Expand Down
2 changes: 2 additions & 0 deletions src/deepali/core/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@

from .flow import affine_flow
from .flow import compose_flows
from .flow import compose_svfs
from .flow import curl
from .flow import denormalize_flow
from .flow import divergence
Expand Down Expand Up @@ -184,6 +185,7 @@
"closest_point_distances",
"closest_point_indices",
"compose_flows",
"compose_svfs",
"conv",
"conv1d",
"crop",
Expand Down

0 comments on commit 9f10f46

Please sign in to comment.