Skip to content

Commit

Permalink
Merge pull request #450 from nipreps/fix/aparc-select
Browse files Browse the repository at this point in the history
FIX: Select function in segmentation resampling workflow
  • Loading branch information
mgxd authored Aug 21, 2024
2 parents f368a28 + 4738420 commit 60f26d2
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 22 deletions.
42 changes: 21 additions & 21 deletions smriprep/workflows/surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,11 +1170,11 @@ def init_hcp_morphometrics_wf(
def init_segs_to_native_wf(
*,
image_type: ty.Literal['T1w', 'T2w'] = 'T1w',
segmentation: ty.Literal['aseg', 'aparc_aseg', 'wmparc'] = 'aseg',
segmentation: ty.Literal['aseg', 'aparc_aseg', 'aparc_a2009s', 'aparc_dkt'] | str = 'aseg',
name: str = 'segs_to_native_wf',
) -> Workflow:
"""
Get a segmentation from FreeSurfer conformed space into native T1w space.
Get a segmentation from FreeSurfer conformed space into native anatomical space.
Workflow Graph
.. workflow::
Expand Down Expand Up @@ -1219,30 +1219,15 @@ def init_segs_to_native_wf(

lta = pe.Node(ConcatenateXFMs(out_fmt='fs'), name='lta', run_without_submitting=True)

# Resample from T1.mgz to T1w.nii.gz, applying any offset in fsnative2anat_xfm,
# Resample from Freesurfer anat to native anat, applying any offset in fsnative2anat_xfm,
# and convert to NIfTI while we're at it
resample = pe.Node(
fs.ApplyVolTransform(transformed_file='seg.nii.gz', interp='nearest'),
name='resample',
)

if segmentation.startswith('aparc'):
if segmentation == 'aparc_aseg':

def _sel(x):
return [parc for parc in x if 'aparc+' in parc][0] # noqa

elif segmentation == 'aparc_a2009s':

def _sel(x):
return [parc for parc in x if 'a2009s+' in parc][0] # noqa

elif segmentation == 'aparc_dkt':

def _sel(x):
return [parc for parc in x if 'DKTatlas+' in parc][0] # noqa

segmentation = (segmentation, _sel)
select_seg = pe.Node(niu.Function(function=_select_seg), name='select_seg')
select_seg.inputs.segmentation = segmentation

anat = 'T2' if image_type == 'T2w' else 'T1'

Expand All @@ -1254,7 +1239,8 @@ def _sel(x):
('fsnative2anat_xfm', 'in_xfms')]),
(fssource, lta, [(anat, 'moving')]),
(inputnode, resample, [('in_file', 'target_file')]),
(fssource, resample, [(segmentation, 'source_file')]),
(fssource, select_seg, [(segmentation, 'in_files')]),
(select_seg, resample, [('out', 'source_file')]),
(lta, resample, [('out_xfm', 'lta_file')]),
(resample, outputnode, [('transformed_file', 'out_file')]),
]) # fmt:skip
Expand Down Expand Up @@ -1678,3 +1664,17 @@ def _get_surfaces(subjects_dir: str, subject_id: str, surfaces: list[str]) -> tu

ret = tuple(all_surfs[surface] for surface in surfaces)
return ret if len(ret) > 1 else ret[0]


def _select_seg(in_files, segmentation):
if isinstance(in_files, str):
return in_files

seg_mapping = {'aparc_aseg': 'aparc+', 'aparc_a2009s': 'a2009s+', 'aparc_dkt': 'DKTatlas+'}
if segmentation in seg_mapping:
segmentation = seg_mapping[segmentation]

for fl in in_files:
if segmentation in fl:
return fl
raise FileNotFoundError(f'No segmentation containing "{segmentation}" was found.')
15 changes: 14 additions & 1 deletion smriprep/workflows/tests/test_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from smriprep.interfaces.tests.data import load as load_test_data

from ..surfaces import init_anat_ribbon_wf, init_gifti_surfaces_wf
from ..surfaces import _select_seg, init_anat_ribbon_wf, init_gifti_surfaces_wf


def test_ribbon_workflow(tmp_path: Path):
Expand Down Expand Up @@ -53,3 +53,16 @@ def test_ribbon_workflow(tmp_path: Path):
assert np.allclose(ribbon.affine, expected.affine)
# Mask data is binary, so we can use np.array_equal
assert np.array_equal(ribbon.dataobj, expected.dataobj)


@pytest.mark.parametrize(
('in_files', 'segmentation', 'expected'),
[
('aparc+aseg.mgz', 'aparc_aseg', 'aparc+aseg.mgz'),
(['a2009s+aseg.mgz', 'aparc+aseg.mgz'], 'aparc_aseg', 'aparc+aseg.mgz'),
(['a2009s+aseg.mgz', 'aparc+aseg.mgz'], 'aparc_a2009s', 'a2009s+aseg.mgz'),
('wmparc.mgz', 'wmparc.mgz', 'wmparc.mgz'),
],
)
def test_select_seg(in_files, segmentation, expected):
assert _select_seg(in_files, segmentation) == expected

0 comments on commit 60f26d2

Please sign in to comment.