diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 8a880113ca..bc03b24758 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -130,6 +130,7 @@ class BSplineApprox(SimpleInterface): def _run_interface(self, runtime): from sklearn import linear_model as lm + from scipy.sparse import vstack as sparse_vstack # Output name baseline out_name = fname_presuffix( @@ -147,21 +148,25 @@ def _run_interface(self, runtime): else None ) + # Determine the shape of bspline coefficients + # This should not change with resizing, so do it first + bs_grids = [bspline_grid(fmapnii, control_zooms_mm=sp) for sp in self.inputs.bs_spacing] + need_resize = np.any(np.array(zooms) < self.inputs.zooms_min) if need_resize: from sdcflows.utils.tools import resample_to_zooms - zooms_min = np.maximum(zooms, self.inputs.zooms_min) + target_zooms = np.maximum(zooms, self.inputs.zooms_min) LOGGER.info( "Resampling image with resolution exceeding 'zooms_min' " f"({'x'.join(str(s) for s in zooms)} → " - f"{'x'.join(str(s) for s in zooms_min)})." + f"{'x'.join(str(s) for s in target_zooms)})." ) - fmapnii = resample_to_zooms(fmapnii, zooms_min) + fmapnii = resample_to_zooms(fmapnii, target_zooms) if masknii is not None: - masknii = resample_to_zooms(masknii, zooms_min) + masknii = resample_to_zooms(masknii, target_zooms) data = fmapnii.get_fdata(dtype="float32") @@ -171,9 +176,6 @@ def _run_interface(self, runtime): else np.asanyarray(masknii.dataobj) > 1e-4 ) - # Convert spacings to numpy arrays - bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing] - # Recenter the fieldmap if self.inputs.recenter == "mode": from scipy.stats import mode @@ -187,13 +189,13 @@ def _run_interface(self, runtime): elif self.inputs.recenter == "mean": data -= np.mean(data[mask]) - # Calculate collocation matrix & the spatial location of control points - colmat, bs_levels = _collocation_matrix(fmapnii, bs_spacing) + # Calculate collocation matrix from (possibly resized) image and knot grids + colmat = sparse_vstack(grid_bspline_weights(fmapnii, grid) for grid in bs_grids).T.tocsr() - bs_levels_str = ['x'.join(str(s) for s in level.shape) for level in bs_levels] - bs_levels_str[-1] = f"and {bs_levels_str[-1]}" + bs_grids_str = ['x'.join(str(s) for s in grid.shape) for grid in bs_grids] + bs_grids_str[-1] = f"and {bs_grids_str[-1]}" LOGGER.info( - f"Approximating B-Splines grids ({', '.join(bs_levels_str)} [knots]) on a grid of " + f"Approximating B-Splines grids ({', '.join(bs_grids_str)} [knots]) on a grid of " f"{'x'.join(str(s) for s in fmapnii.shape)} ({np.prod(fmapnii.shape)}) voxels," f" of which {mask.sum()} fall within the mask." ) @@ -205,7 +207,7 @@ def _run_interface(self, runtime): # Store coefficients index = 0 self._results["out_coeff"] = [] - for i, bsl in enumerate(bs_levels): + for i, bsl in enumerate(bs_grids): n = bsl.dataobj.size out_level = out_name.replace("_field.", f"_coeff{i:03}.") bsl.__class__( @@ -226,7 +228,9 @@ def _run_interface(self, runtime): np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None else np.asanyarray(nb.load(self.inputs.in_mask).dataobj) > 1e-4 ) - colmat, _ = _collocation_matrix(fmapnii, bs_spacing) + colmat = sparse_vstack( + grid_bspline_weights(fmapnii, grid) for grid in bs_grids + ).T.tocsr() regressors = colmat[mask.reshape(-1), :] interp_data = np.zeros_like(data) @@ -509,24 +513,6 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM): return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine) -def _collocation_matrix(image, knot_spacing): - from scipy.sparse import vstack as sparse_vstack - - bs_levels = [] - weights = None - for sp in knot_spacing: - level = bspline_grid(image, control_zooms_mm=sp) - bs_levels.append(level) - - weights = ( - grid_bspline_weights(image, level) - if weights is None - else sparse_vstack((weights, grid_bspline_weights(image, level))) - ) - - return weights.T.tocsr(), bs_levels - - def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None): """Read in a coefficients file generated by TOPUP and fix x-form headers.""" from pathlib import Path