diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index 9af12f4..cea7054 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -357,6 +357,7 @@ def jacobian_det( sigma: Optional[float] = None, spacing: Optional[Union[Scalar, Array]] = None, stride: Optional[ScalarOrTuple[int]] = None, + add_identity: bool = True, ) -> Tensor: r"""Evaluate Jacobian determinant of spatial deformation. @@ -366,6 +367,9 @@ def jacobian_det( sigma: Standard deviation of Gaussian used for computing spatial derivatives. spacing: Physical size of image voxels used to compute spatial derivatives. stride: Number of output grid points between control points plus one for ``mode='bspline'``. + add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the spatial + deformation given by :math:`x + u(x)` (True), where :math:`u` is the flow field, + by adding the identity matrix to the Jacobian of :math:`u`. Returns: Scalar field of Jacobian determinant values as tensor of shape ``(N, 1, ..., X)``. @@ -380,8 +384,9 @@ def jacobian_det( which = FlowDerivativeKeys.jacobian(spatial_dims=D) deriv = flow_derivatives(flow, which=which, **kwargs) # Add 1 to diagonal elements of Jacobian matrix, because T(x) = x + u(x) - for i in range(D): - deriv[FlowDerivativeKeys.symbol(i, i)].add_(1) + if add_identity: + for i in range(D): + deriv[FlowDerivativeKeys.symbol(i, i)].add_(1) if D == 2: a = deriv["du/dx"] b = deriv["du/dy"] @@ -452,12 +457,13 @@ def jacobian_dict( kwargs = dict(mode=mode, sigma=sigma, spacing=spacing, stride=stride) which = FlowDerivativeKeys.jacobian(spatial_dims=D) deriv = flow_derivatives(flow, which=which, **kwargs) + # Optionally, add 1 to diagonal elements of Jacobian matrix, because T(x) = x + u(x). + if add_identity: + for i in range(D): + deriv[FlowDerivativeKeys.symbol(i, i)].add_(1) jac = {} for i, j in combinations_with_replacement(range(D), 2): - dij = deriv[FlowDerivativeKeys.symbol(i, j)] - if add_identity and i == j: - dij = dij.add_(1) # T(x) = x + u(x) - jac[(i, j)] = dij + jac[(i, j)] = deriv[FlowDerivativeKeys.symbol(i, j)] return {(i, j): jac[(i, j) if i < j else (j, i)] for i, j in product(range(D), repeat=2)}