Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Integrate downsampling in BSplineApprox when the input is high-res #301

Merged
merged 7 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 96 additions & 46 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,13 @@
OutputMultiObject,
)

from sdcflows.transform import grid_bspline_weights as gbsw
from sdcflows.transform import grid_bspline_weights


LOW_MEM_BLOCK_SIZE = 1000
DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm
DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm
DEFAULT_HF_ZOOMS_MM = (16.0, 16.0, 10.0) # For human adults (high-frequency), in mm
BSPLINE_SUPPORT = 2 - 1.82e-3 # Disallows weights < 1e-9
LOGGER = logging.getLogger("nipype.interface")


Expand Down Expand Up @@ -76,6 +75,13 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
usedefault=True,
desc="generate a field, extrapolated outside the brain mask",
)
zooms_min = traits.Union(
traits.Float,
traits.Tuple(traits.Float, traits.Float, traits.Float),
default_value=1.95,
usedefault=True,
desc="limit minimum image zooms, set 0.0 to use the original image",
)
debug = traits.Bool(False, usedefault=True, desc="generate extra assets for debugging")


Expand Down Expand Up @@ -124,27 +130,48 @@ class BSplineApprox(SimpleInterface):

def _run_interface(self, runtime):
from sklearn import linear_model as lm
from scipy.sparse import vstack as sparse_vstack

# Output name baseline
out_name = fname_presuffix(
self.inputs.in_data, suffix="_field", newpath=runtime.cwd
)

# Load in the fieldmap
fmapnii = nb.load(self.inputs.in_data)
data = fmapnii.get_fdata(dtype="float32")
zooms = fmapnii.header.get_zooms()

# Generate the output naming base
out_name = fname_presuffix(
self.inputs.in_data, suffix="_field", newpath=runtime.cwd
# Get a mask (or define on the spot to cover the full extent)
masknii = (
nb.load(self.inputs.in_mask)
if isdefined(self.inputs.in_mask)
else None
)
# Create a copy of the header for use below
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")

need_resize = np.any(np.array(zooms) < self.inputs.zooms_min)
if need_resize:
from sdcflows.utils.tools import resample_to_zooms

zooms_min = np.maximum(zooms, self.inputs.zooms_min)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for clarity, I think we want to rename from zooms_min, since it now diverges from self.inputs.zooms_min.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would you suggest? (I have no particular preference)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

target_zooms made sense to me.


LOGGER.info(
"Resampling image with resolution exceeding 'zooms_min' "
f"({'x'.join(str(s) for s in zooms)} → "
f"{'x'.join(str(s) for s in zooms_min)})."
)
fmapnii = resample_to_zooms(fmapnii, zooms_min)

if masknii is not None:
masknii = resample_to_zooms(masknii, zooms_min)

data = fmapnii.get_fdata(dtype="float32")

# Generate a numpy array with the mask
mask = (
nb.load(self.inputs.in_mask).get_fdata() > 0
if isdefined(self.inputs.in_mask)
else np.ones_like(data, dtype=bool)
np.ones(fmapnii.shape, dtype=bool) if masknii is None
else np.asanyarray(masknii.dataobj) > 1e-4
)

# Massage bs_spacing input
# Convert spacings to numpy arrays
bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing]

# Recenter the fieldmap
Expand All @@ -157,45 +184,29 @@ def _run_interface(self, runtime):
elif self.inputs.recenter == "mean":
data -= np.mean(data[mask])

# Calculate the spatial location of control points
bs_levels = []
ncoeff = []
weights = None
for sp in bs_spacing:
level = bspline_grid(fmapnii, control_zooms_mm=sp)
bs_levels.append(level)
ncoeff.append(level.dataobj.size)

weights = (
gbsw(fmapnii, level)
if weights is None
else sparse_vstack((weights, gbsw(fmapnii, level)))
)
# Calculate collocation matrix & the spatial location of control points
colmat, bs_levels = _collocation_matrix(fmapnii, bs_spacing)

regressors = weights.T.tocsr()[mask.reshape(-1), :]
bs_levels_str = ['x'.join(str(s) for s in level.shape) for level in bs_levels]
bs_levels_str[-1] = f"and {bs_levels_str[-1]}"
LOGGER.info(
f"Approximating B-Splines grids ({', '.join(bs_levels_str)} [knots]) on a grid of "
f"{'x'.join(str(s) for s in fmapnii.shape)} ({np.prod(fmapnii.shape)}) voxels,"
f" of which {mask.sum()} fall within the mask."
)

# Fit the model
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
model.fit(regressors, data[mask])

interp_data = np.zeros_like(data)
interp_data[mask] = np.array(model.coef_) @ regressors.T # Interpolation

# Store outputs
out_name = fname_presuffix(
self.inputs.in_data, suffix="_field", newpath=runtime.cwd
)
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
self._results["out_field"] = out_name
model.fit(colmat[mask.reshape(-1), :], data[mask])

# Store coefficients
index = 0
self._results["out_coeff"] = []
for i, (n, bsl) in enumerate(zip(ncoeff, bs_levels)):
for i, bsl in enumerate(bs_levels):
n = bsl.dataobj.size
out_level = out_name.replace("_field.", f"_coeff{i:03}.")
bsl.__class__(
np.array(model.coef_, dtype="float32")[index : index + n].reshape(
np.array(model.coef_, dtype="float32")[index:index + n].reshape(
bsl.shape
),
bsl.affine,
Expand All @@ -204,6 +215,27 @@ def _run_interface(self, runtime):
index += n
self._results["out_coeff"].append(out_level)

# Interpolating in the original grid will require a new collocation matrix
if need_resize:
fmapnii = nb.load(self.inputs.in_data)
data = fmapnii.get_fdata(dtype="float32")
mask = (
np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None
else np.asanyarray(nb.load(self.inputs.in_mask).dataobj) > 1e-4
)
colmat, _ = _collocation_matrix(fmapnii, bs_spacing)
oesteban marked this conversation as resolved.
Show resolved Hide resolved

regressors = colmat[mask.reshape(-1), :]
interp_data = np.zeros_like(data)
# Interpolate the field from the coefficients just calculated
interp_data[mask] = regressors @ model.coef_

# Store interpolated field
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
self._results["out_field"] = out_name

# Write out fitting-error map
self._results["out_error"] = out_name.replace("_field.", "_error.")
fmapnii.__class__(
Expand All @@ -217,8 +249,8 @@ def _run_interface(self, runtime):
self._results["out_extrapolated"] = self._results["out_field"]
return runtime

extrapolators = weights.tocsc()[:, ~mask.reshape(-1)]
interp_data[~mask] = np.array(model.coef_) @ extrapolators # Extrapolation
extrapolators = colmat[~mask.reshape(-1), :]
interp_data[~mask] = extrapolators @ model.coef_ # Extrapolation
self._results["out_extrapolated"] = out_name.replace("_field.", "_extra.")
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(
self._results["out_extrapolated"]
Expand Down Expand Up @@ -474,6 +506,24 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)


def _collocation_matrix(image, knot_spacing):
    """
    Build the sparse collocation matrix for one or more B-Spline grids.

    For each requested knot spacing, a B-Spline control-point grid covering
    *image* is generated and its basis weights evaluated at every voxel of
    *image*.  The per-level weight matrices are stacked so that each row of
    the returned matrix corresponds to one voxel and each column to one
    control point (across all levels, in order).

    Parameters
    ----------
    image : :obj:`nibabel.spatialimages.SpatialImage`
        The image defining the target grid where the basis is evaluated.
    knot_spacing : :obj:`list` of spacings
        One entry per B-Spline level, giving the control-point spacing in mm.

    Returns
    -------
    colmat : sparse matrix (CSR)
        Voxels x control-points collocation matrix.
    bs_levels : :obj:`list`
        The control-point grid image of each level.

    """
    from scipy.sparse import vstack as sparse_vstack

    # One control-point grid per requested spacing
    bs_levels = [
        bspline_grid(image, control_zooms_mm=spacing) for spacing in knot_spacing
    ]
    # Evaluate the basis of every level on the image's voxels and stack them
    level_weights = [grid_bspline_weights(image, level) for level in bs_levels]
    return sparse_vstack(level_weights).T.tocsr(), bs_levels


def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None):
"""Read in a coefficients file generated by TOPUP and fix x-form headers."""
from pathlib import Path
Expand Down
36 changes: 36 additions & 0 deletions sdcflows/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,42 @@
"""Image processing tools."""


def resample_to_zooms(in_file, zooms, order=3, prefilter=True):
    """Resample the input data to a new grid with the requested zooms."""
    from pathlib import Path
    import numpy as np
    import nibabel as nb
    from nibabel.affines import rescale_affine
    from nitransforms.linear import Affine

    # Accept either a path or an already-loaded image
    if isinstance(in_file, (str, Path)):
        in_file = nb.load(in_file)

    # Keep the original x-form codes so the output advertises the same validity
    _, sform_code = in_file.get_sform(coded=True)
    _, qform_code = in_file.get_qform(coded=True)

    header = in_file.header.copy()
    target_zooms = np.array(zooms)
    current_zooms = np.array(in_file.header.get_zooms()[:3])

    # New grid size covering (approximately) the same field of view
    # (np.ceil would guarantee full coverage at the cost of a slightly larger FoV)
    new_shape = np.rint(np.array(in_file.shape[:3]) * current_zooms / target_zooms)
    new_affine = rescale_affine(
        in_file.affine, in_file.shape[:3], target_zooms, new_shape
    )

    # Build an empty reference image on the target grid
    header.set_sform(new_affine, sform_code)
    header.set_qform(new_affine, qform_code)
    reference = in_file.__class__(
        np.zeros(new_shape.astype(int), dtype=header.get_data_dtype()),
        new_affine,
        header,
    )

    # Applying an identity transform onto the new reference performs the regridding
    return Affine(reference=reference).apply(in_file, order=order, prefilter=prefilter)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at scipy.ndimage.map_coordinates, I don't think prefilter filters more or less based on the down-sampling factor.

I suspect for the factors we'll be dealing with (mostly <2, almost all <3) it's probably fine, but I don't really know how to evaluate this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at scipy.ndimage.map_coordinates, I don't think prefilter filters more or less based on the down-sampling factor.

No, that's the b-spline filtering, which blurs with the width of the B-Spline basis (cubic, then 4 voxels). I believe that is enough for the purpose of this interpolation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of SyN, ANTs is already giving you a displacements field that has been smoothed to a certain kernel width (the third parameter of -t [Syn]). We are not in a more general case where you have a signal with additive noise and you want to subsample it safely.

So I'm quite convinced no extra smoothing should be added.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(to answer the question, you are right, it is independent of the size ratio) -- but that should be okay, IMHO, for the reasons above.



def ensure_positive_cosines(img):
"""
Reorient axes polarity to have all positive direction cosines.
Expand Down