diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 9ea745c64c..31e65bd157 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -40,14 +40,13 @@ OutputMultiObject, ) -from sdcflows.transform import grid_bspline_weights as gbsw +from sdcflows.transform import grid_bspline_weights LOW_MEM_BLOCK_SIZE = 1000 DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm DEFAULT_HF_ZOOMS_MM = (16.0, 16.0, 10.0) # For human adults (high-frequency), in mm -BSPLINE_SUPPORT = 2 - 1.82e-3 # Disallows weights < 1e-9 LOGGER = logging.getLogger("nipype.interface") @@ -76,6 +75,13 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec): usedefault=True, desc="generate a field, extrapolated outside the brain mask", ) + zooms_min = traits.Union( + traits.Float, + traits.Tuple(traits.Float, traits.Float, traits.Float), + default_value=1.95, + usedefault=True, + desc="limit minimum image zooms, set 0.0 to use the original image", + ) debug = traits.Bool(False, usedefault=True, desc="generate extra assets for debugging") @@ -124,27 +130,48 @@ class BSplineApprox(SimpleInterface): def _run_interface(self, runtime): from sklearn import linear_model as lm - from scipy.sparse import vstack as sparse_vstack + + # Output name baseline + out_name = fname_presuffix( + self.inputs.in_data, suffix="_field", newpath=runtime.cwd + ) # Load in the fieldmap fmapnii = nb.load(self.inputs.in_data) - data = fmapnii.get_fdata(dtype="float32") + zooms = fmapnii.header.get_zooms() - # Generate the output naming base - out_name = fname_presuffix( - self.inputs.in_data, suffix="_field", newpath=runtime.cwd + # Get a mask (or define on the spot to cover the full extent) + masknii = ( + nb.load(self.inputs.in_mask) + if isdefined(self.inputs.in_mask) + else None ) - # Create a copy of the header for use below - hdr = fmapnii.header.copy() - hdr.set_data_dtype("float32") + need_resize = np.any(np.array(zooms) < self.inputs.zooms_min) + if need_resize: + from sdcflows.utils.tools import resample_to_zooms + + zooms_min = np.maximum(zooms, self.inputs.zooms_min) + + LOGGER.info( + "Resampling image with resolution exceeding 'zooms_min' " + f"({'x'.join(str(s) for s in zooms)} → " + f"{'x'.join(str(s) for s in zooms_min)})." + ) + fmapnii = resample_to_zooms(fmapnii, zooms_min) + + if masknii is not None: + masknii = resample_to_zooms(masknii, zooms_min) + + data = fmapnii.get_fdata(dtype="float32") + + # Generate a numpy array with the mask mask = ( - nb.load(self.inputs.in_mask).get_fdata() > 0 - if isdefined(self.inputs.in_mask) - else np.ones_like(data, dtype=bool) + np.ones(fmapnii.shape, dtype=bool) if masknii is None + else np.asanyarray(masknii.dataobj) > 1e-4 ) - # Massage bs_spacing input + # Convert spacings to numpy arrays bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing] # Recenter the fieldmap @@ -157,45 +184,29 @@ def _run_interface(self, runtime): elif self.inputs.recenter == "mean": data -= np.mean(data[mask]) - # Calculate the spatial location of control points - bs_levels = [] - ncoeff = [] - weights = None - for sp in bs_spacing: - level = bspline_grid(fmapnii, control_zooms_mm=sp) - bs_levels.append(level) - ncoeff.append(level.dataobj.size) - - weights = ( - gbsw(fmapnii, level) - if weights is None - else sparse_vstack((weights, gbsw(fmapnii, level))) - ) + # Calculate collocation matrix & the spatial location of control points + colmat, bs_levels = _collocation_matrix(fmapnii, bs_spacing) - regressors = weights.T.tocsr()[mask.reshape(-1), :] + bs_levels_str = ['x'.join(str(s) for s in level.shape) for level in bs_levels] + bs_levels_str[-1] = f"and {bs_levels_str[-1]}" + LOGGER.info( + f"Approximating B-Splines grids ({', '.join(bs_levels_str)} [knots]) on a grid of " + f"{'x'.join(str(s) for s in fmapnii.shape)} ({np.prod(fmapnii.shape)}) voxels," + f" of which {mask.sum()} fall within the mask." + ) # Fit the model model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False) - model.fit(regressors, data[mask]) - - interp_data = np.zeros_like(data) - interp_data[mask] = np.array(model.coef_) @ regressors.T # Interpolation - - # Store outputs - out_name = fname_presuffix( - self.inputs.in_data, suffix="_field", newpath=runtime.cwd - ) - hdr = fmapnii.header.copy() - hdr.set_data_dtype("float32") - fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name) - self._results["out_field"] = out_name + model.fit(colmat[mask.reshape(-1), :], data[mask]) + # Store coefficients index = 0 self._results["out_coeff"] = [] - for i, (n, bsl) in enumerate(zip(ncoeff, bs_levels)): + for i, bsl in enumerate(bs_levels): + n = bsl.dataobj.size out_level = out_name.replace("_field.", f"_coeff{i:03}.") bsl.__class__( - np.array(model.coef_, dtype="float32")[index : index + n].reshape( + np.array(model.coef_, dtype="float32")[index:index + n].reshape( bsl.shape ), bsl.affine, @@ -204,6 +215,27 @@ def _run_interface(self, runtime): index += n self._results["out_coeff"].append(out_level) + # Interpolating in the original grid will require a new collocation matrix + if need_resize: + fmapnii = nb.load(self.inputs.in_data) + data = fmapnii.get_fdata(dtype="float32") + mask = ( + np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None + else np.asanyarray(nb.load(self.inputs.in_mask).dataobj) > 1e-4 + ) + colmat, _ = _collocation_matrix(fmapnii, bs_spacing) + + regressors = colmat[mask.reshape(-1), :] + interp_data = np.zeros_like(data) + # Interpolate the field from the coefficients just calculated + interp_data[mask] = regressors @ model.coef_ + + # Store interpolated field + hdr = fmapnii.header.copy() + hdr.set_data_dtype("float32") + fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name) + self._results["out_field"] = out_name + # Write out fitting-error map self._results["out_error"] = out_name.replace("_field.", "_error.") fmapnii.__class__( @@ -217,8 +249,8 @@ def _run_interface(self, runtime): self._results["out_extrapolated"] = self._results["out_field"] return runtime - extrapolators = weights.tocsc()[:, ~mask.reshape(-1)] - interp_data[~mask] = np.array(model.coef_) @ extrapolators # Extrapolation + extrapolators = colmat[~mask.reshape(-1), :] + interp_data[~mask] = extrapolators @ model.coef_ # Extrapolation self._results["out_extrapolated"] = out_name.replace("_field.", "_extra.") fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename( self._results["out_extrapolated"] @@ -474,6 +506,24 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM): return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine) +def _collocation_matrix(image, knot_spacing): + from scipy.sparse import vstack as sparse_vstack + + bs_levels = [] + weights = None + for sp in knot_spacing: + level = bspline_grid(image, control_zooms_mm=sp) + bs_levels.append(level) + + weights = ( + grid_bspline_weights(image, level) + if weights is None + else sparse_vstack((weights, grid_bspline_weights(image, level))) + ) + + return weights.T.tocsr(), bs_levels + + def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None): """Read in a coefficients file generated by TOPUP and fix x-form headers.""" from pathlib import Path diff --git a/sdcflows/utils/tools.py b/sdcflows/utils/tools.py index 4942f6a05e..5fe23051eb 100644 --- a/sdcflows/utils/tools.py +++ b/sdcflows/utils/tools.py @@ -23,6 +23,42 @@ """Image processing tools.""" +def resample_to_zooms(in_file, zooms, order=3, prefilter=True): + """Resample the input data to a new grid with the requested zooms.""" + from pathlib import Path + import numpy as np + import nibabel as nb + from nibabel.affines import rescale_affine + from nitransforms.linear import Affine + + if isinstance(in_file, (str, Path)): + in_file = nb.load(in_file) + + # Prepare output x-forms + sform, scode = in_file.get_sform(coded=True) + qform, qcode = in_file.get_qform(coded=True) + + hdr = in_file.header.copy() + zooms = np.array(zooms) + + pre_zooms = np.array(in_file.header.get_zooms()[:3]) + # Could use `np.ceil` if we prefer + new_shape = np.rint(np.array(in_file.shape[:3]) * pre_zooms / zooms) + affine = rescale_affine(in_file.affine, in_file.shape[:3], zooms, new_shape) + + # Generate new reference + hdr.set_sform(affine, scode) + hdr.set_qform(affine, qcode) + newref = in_file.__class__( + np.zeros(new_shape.astype(int), dtype=hdr.get_data_dtype()), + affine, + hdr, + ) + + # Resample via identity transform + return Affine(reference=newref).apply(in_file, order=order, prefilter=prefilter) + + def ensure_positive_cosines(img): """ Reorient axes polarity to have all positive direction cosines.