Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Calculate bspline grids separately from colocation matrices #308

Merged
merged 2 commits into from
Dec 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 18 additions & 32 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class BSplineApprox(SimpleInterface):

def _run_interface(self, runtime):
from sklearn import linear_model as lm
from scipy.sparse import vstack as sparse_vstack

# Output name baseline
out_name = fname_presuffix(
Expand All @@ -147,21 +148,25 @@ def _run_interface(self, runtime):
else None
)

# Determine the shape of bspline coefficients
# This should not change with resizing, so do it first
bs_grids = [bspline_grid(fmapnii, control_zooms_mm=sp) for sp in self.inputs.bs_spacing]

need_resize = np.any(np.array(zooms) < self.inputs.zooms_min)
if need_resize:
from sdcflows.utils.tools import resample_to_zooms

zooms_min = np.maximum(zooms, self.inputs.zooms_min)
target_zooms = np.maximum(zooms, self.inputs.zooms_min)

LOGGER.info(
"Resampling image with resolution exceeding 'zooms_min' "
f"({'x'.join(str(s) for s in zooms)} → "
f"{'x'.join(str(s) for s in zooms_min)})."
f"{'x'.join(str(s) for s in target_zooms)})."
)
fmapnii = resample_to_zooms(fmapnii, zooms_min)
fmapnii = resample_to_zooms(fmapnii, target_zooms)

if masknii is not None:
masknii = resample_to_zooms(masknii, zooms_min)
masknii = resample_to_zooms(masknii, target_zooms)

data = fmapnii.get_fdata(dtype="float32")

Expand All @@ -171,9 +176,6 @@ def _run_interface(self, runtime):
else np.asanyarray(masknii.dataobj) > 1e-4
)

# Convert spacings to numpy arrays
bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing]

# Recenter the fieldmap
if self.inputs.recenter == "mode":
from scipy.stats import mode
Expand All @@ -187,13 +189,13 @@ def _run_interface(self, runtime):
elif self.inputs.recenter == "mean":
data -= np.mean(data[mask])

# Calculate collocation matrix & the spatial location of control points
colmat, bs_levels = _collocation_matrix(fmapnii, bs_spacing)
# Calculate collocation matrix from (possibly resized) image and knot grids
colmat = sparse_vstack(grid_bspline_weights(fmapnii, grid) for grid in bs_grids).T.tocsr()

bs_levels_str = ['x'.join(str(s) for s in level.shape) for level in bs_levels]
bs_levels_str[-1] = f"and {bs_levels_str[-1]}"
bs_grids_str = ['x'.join(str(s) for s in grid.shape) for grid in bs_grids]
bs_grids_str[-1] = f"and {bs_grids_str[-1]}"
LOGGER.info(
f"Approximating B-Splines grids ({', '.join(bs_levels_str)} [knots]) on a grid of "
f"Approximating B-Splines grids ({', '.join(bs_grids_str)} [knots]) on a grid of "
f"{'x'.join(str(s) for s in fmapnii.shape)} ({np.prod(fmapnii.shape)}) voxels,"
f" of which {mask.sum()} fall within the mask."
)
Expand All @@ -205,7 +207,7 @@ def _run_interface(self, runtime):
# Store coefficients
index = 0
self._results["out_coeff"] = []
for i, bsl in enumerate(bs_levels):
for i, bsl in enumerate(bs_grids):
n = bsl.dataobj.size
out_level = out_name.replace("_field.", f"_coeff{i:03}.")
bsl.__class__(
Expand All @@ -226,7 +228,9 @@ def _run_interface(self, runtime):
np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None
else np.asanyarray(nb.load(self.inputs.in_mask).dataobj) > 1e-4
)
colmat, _ = _collocation_matrix(fmapnii, bs_spacing)
colmat = sparse_vstack(
grid_bspline_weights(fmapnii, grid) for grid in bs_grids
).T.tocsr()

regressors = colmat[mask.reshape(-1), :]
interp_data = np.zeros_like(data)
Expand Down Expand Up @@ -509,24 +513,6 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)


def _collocation_matrix(image, knot_spacing):
from scipy.sparse import vstack as sparse_vstack

bs_levels = []
weights = None
for sp in knot_spacing:
level = bspline_grid(image, control_zooms_mm=sp)
bs_levels.append(level)

weights = (
grid_bspline_weights(image, level)
if weights is None
else sparse_vstack((weights, grid_bspline_weights(image, level)))
)

return weights.T.tocsr(), bs_levels


def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None):
"""Read in a coefficients file generated by TOPUP and fix x-form headers."""
from pathlib import Path
Expand Down