Skip to content

Commit

Permalink
[core] Implement Lie bracket of vector fields
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Nov 4, 2023
1 parent 9a35d1b commit af0e00b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
54 changes: 54 additions & 0 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,60 @@ def jacobian_triu(
return torch.cat([jac[(i, j)] for i, j in combinations_with_replacement(range(D), 2)], dim=1)


def lie_bracket(
u: Tensor,
v: Tensor,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
) -> Tensor:
r"""Lie bracket of two vector fields.
Args:
u: First vector field as tensor of shape ``(N, D, ..., X)``.
v: Second vector field as tensor of shape ``(N, D, ..., X)``.
mode: Mode of :func:`flow_derivatives()` approximation.
sigma: Standard deviation of Gaussian used for computing spatial derivatives.
spacing: Physical size of image voxels used to compute spatial derivatives.
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
Returns:
Lie bracket of vector fields [u, v] as tensor of shape ``(N, D, ..., X)``.
"""
for name, flow in [("u", u), ("v", v)]:
if flow.ndim < 4:
raise ValueError(f"lie_bracket() '{name}' must be vector field of shape (N, D, ..., X)")
if flow.shape[1] != flow.ndim - 2:
raise ValueError(f"lie_bracket() '{name}' must have shape (N, D, ..., X)")
if u.shape != v.shape:
raise ValueError(f"lie_bracket() 'u' and 'v' must have the same shape")
jac_u = jacobian_dict(
u,
mode=mode,
sigma=sigma,
spacing=spacing,
stride=stride,
)
jac_v = jacobian_dict(
v,
mode=mode,
sigma=sigma,
spacing=spacing,
stride=stride,
)
D = flow.ndim - 2
w = torch.zeros_like(u)
for i in range(D):
w_i = w.narrow(1, i, 1)
for j in range(D):
w_i = w_i.add_(jac_v[(i, j)].mul(u.narrow(1, j, 1)))
for j in range(D):
w_i = w_i.sub_(jac_u[(i, j)].mul(v.narrow(1, j, 1)))
return w


def normalize_flow(
data: Tensor,
size: Optional[Union[Tensor, torch.Size]] = None,
Expand Down
2 changes: 2 additions & 0 deletions src/deepali/core/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
from .flow import jacobian_dict
from .flow import jacobian_matrix
from .flow import jacobian_triu
from .flow import lie_bracket
from .flow import normalize_flow
from .flow import sample_flow
from .flow import warp_grid
Expand Down Expand Up @@ -218,6 +219,7 @@
"jacobian_dict",
"jacobian_matrix",
"jacobian_triu",
"lie_bracket",
"max_pool",
"min_pool",
"normalize_flow",
Expand Down

0 comments on commit af0e00b

Please sign in to comment.