diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index a580abb..fa24903 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -62,6 +62,70 @@ def compose_flows(u: Tensor, v: Tensor, align_corners: bool = True) -> Tensor: return w +def compose_svfs( + u: Tensor, + v: Tensor, + order: int = 1, + mode: Optional[str] = None, + sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, + stride: Optional[ScalarOrTuple[int]] = None, +) -> Tensor: + r"""Approximate stationary velocity field (SVF) of composite deformation. + + The output velocity field is ``w = log(exp(v) o exp(u))``, where ``exp`` is the exponential map + of a stationary velocity field, and ``log`` its inverse. The velocity field ``w`` is given by the + `Baker-Campbell-Hausdorff (BCH) formula `_. + + Args: + u: First stationary velocity field as tensor of shape ``(N, D, ..., X)``. + v: Second stationary velocity field as tensor of shape ``(N, D, ..., X)``. + order: Order of approximation. The highest implemented order is 3. + When 0, the returned velocity field is the sum of ``u`` and ``v``. + This approximation is accurate if the input velocity fields commute, i.e., + the Lie bracket [u, v] = 0. When ``order=1``, the approximation is given by + ``w = v + u + 1/2 [v, u]`` (note that deformation ``exp(u)`` is applied first). + mode: Mode of :func:`flow_derivatives()` approximation. + sigma: Standard deviation of Gaussian used for computing spatial derivatives. + spacing: Physical size of image voxels used to compute spatial derivatives. + stride: Number of output grid points between control points plus one for ``mode='bspline'``. + + Returns: + Approximation of BCH formula as tensor of shape ``(N, D, ..., X)``. + + """ + + def lb(a: Tensor, b: Tensor) -> Tensor: + return lie_bracket(a, b, mode=mode, sigma=sigma, spacing=spacing, stride=stride) + + for name, flow in [("u", u), ("v", v)]: + if flow.ndim < 4: + raise ValueError(f"compose_svfs() '{name}' must be vector field of shape (N, D, ..., X)") + if flow.shape[1] != flow.ndim - 2: + raise ValueError(f"compose_svfs() '{name}' must have shape (N, D, ..., X)") + if u.shape != v.shape: + raise ValueError(f"compose_svfs() 'u' and 'v' must have the same shape") + if order < 0: + raise ValueError("compose_svfs() 'order' must not be negative") + elif order > 3: + raise NotImplementedError(f"compose_svfs() approximation of order={order} not implemented") + + w = v.add(u) + if order >= 1: + vu = lb(v, u) + w = w.add(vu.mul_(0.5)) + if order >= 2: + uv = lb(u, v) + vvu = lb(v, vu) + uuv = lb(u, uv) + w = w.add(vvu.add(uuv).mul_(1/12)) + if order >= 3: + uvvu = lb(u, vvu) + w = w.add(uvvu.mul_(1/24)) + + return w + + def curl( flow: Tensor, mode: Optional[str] = None, diff --git a/src/deepali/core/functional.py b/src/deepali/core/functional.py index 4b7d01b..5324840 100644 --- a/src/deepali/core/functional.py +++ b/src/deepali/core/functional.py @@ -93,6 +93,7 @@ from .flow import affine_flow from .flow import compose_flows +from .flow import compose_svfs from .flow import curl from .flow import denormalize_flow from .flow import divergence @@ -184,6 +185,7 @@ "closest_point_distances", "closest_point_indices", "compose_flows", + "compose_svfs", "conv", "conv1d", "crop",