diff --git a/smriprep/workflows/surfaces.py b/smriprep/workflows/surfaces.py index ec44155b7f..48895377a3 100644 --- a/smriprep/workflows/surfaces.py +++ b/smriprep/workflows/surfaces.py @@ -1170,11 +1170,11 @@ def init_hcp_morphometrics_wf( def init_segs_to_native_wf( *, image_type: ty.Literal['T1w', 'T2w'] = 'T1w', - segmentation: ty.Literal['aseg', 'aparc_aseg', 'wmparc'] = 'aseg', + segmentation: ty.Literal['aseg', 'aparc_aseg', 'aparc_a2009s', 'aparc_dkt'] | str = 'aseg', name: str = 'segs_to_native_wf', ) -> Workflow: """ - Get a segmentation from FreeSurfer conformed space into native T1w space. + Get a segmentation from FreeSurfer conformed space into native anatomical space. Workflow Graph .. workflow:: @@ -1219,30 +1219,15 @@ def init_segs_to_native_wf( lta = pe.Node(ConcatenateXFMs(out_fmt='fs'), name='lta', run_without_submitting=True) - # Resample from T1.mgz to T1w.nii.gz, applying any offset in fsnative2anat_xfm, + # Resample from Freesurfer anat to native anat, applying any offset in fsnative2anat_xfm, # and convert to NIfTI while we're at it resample = pe.Node( fs.ApplyVolTransform(transformed_file='seg.nii.gz', interp='nearest'), name='resample', ) - if segmentation.startswith('aparc'): - if segmentation == 'aparc_aseg': - - def _sel(x): - return [parc for parc in x if 'aparc+' in parc][0] # noqa - - elif segmentation == 'aparc_a2009s': - - def _sel(x): - return [parc for parc in x if 'a2009s+' in parc][0] # noqa - - elif segmentation == 'aparc_dkt': - - def _sel(x): - return [parc for parc in x if 'DKTatlas+' in parc][0] # noqa - - segmentation = (segmentation, _sel) + select_seg = pe.Node(niu.Function(function=_select_seg), name='select_seg') + select_seg.inputs.segmentation = segmentation anat = 'T2' if image_type == 'T2w' else 'T1' @@ -1254,7 +1239,8 @@ def _sel(x): ('fsnative2anat_xfm', 'in_xfms')]), (fssource, lta, [(anat, 'moving')]), (inputnode, resample, [('in_file', 'target_file')]), - (fssource, resample, [(segmentation, 'source_file')]), + (fssource, select_seg, [(segmentation, 'in_files')]), + (select_seg, resample, [('out', 'source_file')]), (lta, resample, [('out_xfm', 'lta_file')]), (resample, outputnode, [('transformed_file', 'out_file')]), ]) # fmt:skip @@ -1678,3 +1664,17 @@ def _get_surfaces(subjects_dir: str, subject_id: str, surfaces: list[str]) -> tu ret = tuple(all_surfs[surface] for surface in surfaces) return ret if len(ret) > 1 else ret[0] + + +def _select_seg(in_files, segmentation): + if isinstance(in_files, str): + return in_files + + seg_mapping = {'aparc_aseg': 'aparc+', 'aparc_a2009s': 'a2009s+', 'aparc_dkt': 'DKTatlas+'} + if segmentation in seg_mapping: + segmentation = seg_mapping[segmentation] + + for fl in in_files: + if segmentation in fl: + return fl + raise FileNotFoundError(f'No segmentation containing "{segmentation}" was found.') diff --git a/smriprep/workflows/tests/test_surfaces.py b/smriprep/workflows/tests/test_surfaces.py index 95076da759..da28ba1914 100644 --- a/smriprep/workflows/tests/test_surfaces.py +++ b/smriprep/workflows/tests/test_surfaces.py @@ -9,7 +9,7 @@ from smriprep.interfaces.tests.data import load as load_test_data -from ..surfaces import init_anat_ribbon_wf, init_gifti_surfaces_wf +from ..surfaces import _select_seg, init_anat_ribbon_wf, init_gifti_surfaces_wf def test_ribbon_workflow(tmp_path: Path): @@ -53,3 +53,16 @@ def test_ribbon_workflow(tmp_path: Path): assert np.allclose(ribbon.affine, expected.affine) # Mask data is binary, so we can use np.array_equal assert np.array_equal(ribbon.dataobj, expected.dataobj) + + +@pytest.mark.parametrize( + ('in_files', 'segmentation', 'expected'), + [ + ('aparc+aseg.mgz', 'aparc_aseg', 'aparc+aseg.mgz'), + (['a2009s+aseg.mgz', 'aparc+aseg.mgz'], 'aparc_aseg', 'aparc+aseg.mgz'), + (['a2009s+aseg.mgz', 'aparc+aseg.mgz'], 'aparc_a2009s', 'a2009s+aseg.mgz'), + ('wmparc.mgz', 'wmparc.mgz', 'wmparc.mgz'), + ], +) +def test_select_seg(in_files, segmentation, expected): + assert _select_seg(in_files, segmentation) == expected