From e7654783608e2f73b63970dbe468c4fce95132f4 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 11:40:31 -0400 Subject: [PATCH 01/11] Try using surfplot for parcellated surfaces. --- pyproject.toml | 1 + xcp_d/interfaces/plotting.py | 206 ++++++++++++++------------------ xcp_d/workflows/connectivity.py | 1 + 3 files changed, 94 insertions(+), 114 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3bd874ff0..97232c2f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "scipy <= 1.14.0,>= 1.14.0", # nipype needs networkx, which needs scipy > 1.8.0 "seaborn", # for plots "sentry-sdk ~= 2.10.0", # for usage reports + "surfplot ~= 0.2.0", # for surface plots "templateflow ~= 24.2.0", "toml", ] diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 7abf8b8ce..ee645a833 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import seaborn as sns +import svgutils.transform as sg from matplotlib.cm import ScalarMappable from matplotlib.colors import Normalize from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec @@ -27,6 +28,7 @@ traits, ) from nipype.interfaces.fsl.base import FSLCommand, FSLCommandInputSpec +from surfplot import Plot from templateflow.api import get as get_template from xcp_d.utils.confounds import load_motion @@ -886,6 +888,14 @@ class _PlotCiftiParcellationInputSpec(BaseInterfaceInputSpec): mandatory=True, desc="Labels for the CIFTI files.", ) + atlas_files = traits.List( + traits.Str, + mandatory=True, + desc=( + "The atlas files. Same length as 'labels' and to be reduced to match " + "'cortical_atlases'." + ), + ) out_file = File( exists=False, mandatory=False, @@ -963,158 +973,126 @@ def _run_interface(self, runtime): rh = self.inputs.rh_underlay lh = self.inputs.lh_underlay - # Create Figure and GridSpec. - # One subplot for each file. Each file will then have four subplots, arranged in a square. - cortical_files = [ + data_files = [ self.inputs.in_files[i] for i, atlas in enumerate(self.inputs.labels) if atlas in self.inputs.cortical_atlases ] - cortical_atlases = [ + atlas_names = [ atlas for atlas in self.inputs.labels if atlas in self.inputs.cortical_atlases ] - n_files = len(cortical_files) - fig = plt.figure(constrained_layout=False) - - if n_files == 1: - fig.set_size_inches(6.5, 6) - # Add an additional column for the colorbar - gs = GridSpec(1, 2, figure=fig, width_ratios=[1, 0.05]) - gs_list = [gs[0, 0]] - subplots = [fig.add_subplot(gs) for gs in gs_list] - cbar_gs_list = [gs[0, 1]] - else: - nrows = np.ceil(n_files / 2).astype(int) - fig.set_size_inches(12.5, 6 * nrows) - # Add an additional column for the colorbar - gs = GridSpec(nrows, 3, figure=fig, width_ratios=[1, 1, 0.05]) - gs_list = [gs[i, j] for i in range(nrows) for j in range(2)] - subplots = [fig.add_subplot(gs) for gs in gs_list] - cbar_gs_list = [gs[i, 2] for i in range(nrows)] - - for subplot in subplots: - subplot.set_axis_off() + atlas_files = [ + atlas for atlas in self.inputs.atlas_files if atlas in self.inputs.cortical_atlases + ] vmin, vmax = self.inputs.vmin, self.inputs.vmax - threshold = 0.01 if vmin == vmax: - threshold = None - # Define vmin and vmax based on all of the files vmin, vmax = np.inf, -np.inf - for cortical_file in cortical_files: - img_data = nb.load(cortical_file).get_fdata() + for data_file in data_files: + img_data = nb.load(data_file).get_fdata() vmin = np.min([np.nanmin(img_data), vmin]) vmax = np.max([np.nanmax(img_data), vmax]) vmin = 0 - for i_file in range(n_files): - subplot = subplots[i_file] - subplot.set_title(cortical_atlases[i_file]) - subplot_gridspec = gs_list[i_file] - - # Create 4 Axes (2 rows, 2 columns) from the subplot - gs_inner = GridSpecFromSubplotSpec(2, 2, subplot_spec=subplot_gridspec) - inner_subplots = [ - fig.add_subplot(gs_inner[i, j], projection="3d") - for i in range(2) - for j in range(2) - ] + figure_files = [] + for i_file, atlas_name in enumerate(atlas_names): + data_file = data_files[i_file] + atlas_file = atlas_files[i_file] + temp_file = fname_presuffix( + f"{atlas_name}.svg", + newpath=runtime.cwd, + ) - img = nb.load(cortical_files[i_file]) - img_data = img.get_fdata() - img_axes = [img.header.get_axis(i) for i in range(img.ndim)] - lh_surf_data = surf_data_from_cifti( + plot_obj = Plot(lh, rh) + + # add schaefer parcellation (no color bar needed) + data_img = nb.load(data_file) + img_data = data_img.get_fdata() + img_axes = [data_img.header.get_axis(i) for i in range(data_img.ndim)] + lh_data = surf_data_from_cifti( img_data, img_axes[1], "CIFTI_STRUCTURE_CORTEX_LEFT", ) - rh_surf_data = surf_data_from_cifti( + rh_data = surf_data_from_cifti( img_data, img_axes[1], "CIFTI_STRUCTURE_CORTEX_RIGHT", ) - - plot_surf_stat_map( - lh, - lh_surf_data, - threshold=threshold, - vmin=vmin, - vmax=vmax, - hemi="left", - view="lateral", - engine="matplotlib", + plot_obj.add_layer( + {"left": lh_data, "right": rh_data}, cmap="cool", - colorbar=False, - axes=inner_subplots[0], - figure=fig, + color_range=(vmin, vmax), + cbar=True, ) - plot_surf_stat_map( - rh, - rh_surf_data, - threshold=threshold, - vmin=vmin, - vmax=vmax, - hemi="right", - view="lateral", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[1], - figure=fig, + + # Add parcel boundaries + atlas_img = nb.load(atlas_file) + atlas_data = atlas_img.get_fdata() + atlas_axes = [atlas_img.header.get_axis(i) for i in range(atlas_img.ndim)] + lh_atlas = surf_data_from_cifti( + atlas_data, + atlas_axes[1], + "CIFTI_STRUCTURE_CORTEX_LEFT", ) - plot_surf_stat_map( - lh, - lh_surf_data, - threshold=threshold, - vmin=vmin, - vmax=vmax, - hemi="left", - view="medial", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[2], - figure=fig, + rh_atlas = surf_data_from_cifti( + atlas_data, + atlas_axes[1], + "CIFTI_STRUCTURE_CORTEX_RIGHT", ) - plot_surf_stat_map( - rh, - rh_surf_data, - threshold=threshold, - vmin=vmin, - vmax=vmax, - hemi="right", - view="medial", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[3], - figure=fig, + plot_obj.add_layer( + {"left": np.squeeze(lh_atlas), "right": np.squeeze(rh_atlas)}, + cmap="gray", + as_outline=True, + cbar=False, ) + fig = plot_obj.build() + fig.savefig(temp_file) + figure_files.append(temp_file) + plt.close(fig) + + # Now build the combined figure + # Load SVG files and get their sizes + direction = "vertical" + svg_objects = [sg.fromfile(svg_path) for svg_path in figure_files] + figures = [svg_obj.getroot() for svg_obj in svg_objects] + widths, heights = [], [] + for fig in figures: + widths.append(fig.width) + heights.append(fig.height) + + # Calculate total width and height for the new SVG + if direction == "vertical": + total_width = max(widths) + total_height = sum(heights) + y_offset = 0 + else: + total_width = sum(widths) + total_height = max(heights) + x_offset = 0 - for ax in inner_subplots: - ax.set_rasterized(True) + # Create new SVG figure + new_svg = sg.SVGFigure(total_width, total_height) - # Create a ScalarMappable with the "cool" colormap and the specified vmin and vmax - sm = ScalarMappable(cmap="cool", norm=Normalize(vmin=vmin, vmax=vmax)) + # Add each SVG to the new figure + for fig in figures: + if direction == "vertical": + fig.moveto(0, y_offset) + y_offset += fig.height + else: + fig.moveto(x_offset, 0) + x_offset += fig.width - for colorbar_gridspec in cbar_gs_list: - colorbar_ax = fig.add_subplot(colorbar_gridspec) - # Add a colorbar to colorbar_ax using the ScalarMappable - fig.colorbar(sm, cax=colorbar_ax) + new_svg.append(figures) self._results["out_file"] = fname_presuffix( - cortical_files[0], + data_files[0], suffix="_file.svg", newpath=runtime.cwd, use_ext=False, ) - fig.savefig( - self._results["out_file"], - bbox_inches="tight", - pad_inches=None, - format="svg", - ) + new_svg.savefig(self._results["out_file"]) plt.close(fig) return runtime diff --git a/xcp_d/workflows/connectivity.py b/xcp_d/workflows/connectivity.py index 1cf912f31..73d80a379 100644 --- a/xcp_d/workflows/connectivity.py +++ b/xcp_d/workflows/connectivity.py @@ -674,6 +674,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit workflow.connect([ (inputnode, plot_coverage, [ ("atlases", "labels"), + ("atlas_files", "atlas_files"), ("lh_midthickness", "lh_underlay"), ("rh_midthickness", "rh_underlay"), ]), From be7556c67faf39458586befbf29a7e160ffd4b56 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 12:08:18 -0400 Subject: [PATCH 02/11] Fix my mistake. --- xcp_d/interfaces/plotting.py | 11 ++++++----- xcp_d/workflows/connectivity.py | 2 ++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index ee645a833..5d74e4819 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -978,12 +978,13 @@ def _run_interface(self, runtime): for i, atlas in enumerate(self.inputs.labels) if atlas in self.inputs.cortical_atlases ] - atlas_names = [ - atlas for atlas in self.inputs.labels if atlas in self.inputs.cortical_atlases - ] - atlas_files = [ - atlas for atlas in self.inputs.atlas_files if atlas in self.inputs.cortical_atlases + keep_idx = [ + i + for i, atlas in enumerate(self.inputs.labels) + if atlas in self.inputs.cortical_atlases ] + atlas_names = [self.inputs.labels[i] for i in keep_idx] + atlas_files = [self.inputs.atlas_files[i] for i in keep_idx] vmin, vmax = self.inputs.vmin, self.inputs.vmax if vmin == vmax: diff --git a/xcp_d/workflows/connectivity.py b/xcp_d/workflows/connectivity.py index 73d80a379..d5f02f012 100644 --- a/xcp_d/workflows/connectivity.py +++ b/xcp_d/workflows/connectivity.py @@ -858,6 +858,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit workflow.connect([ (inputnode, plot_parcellated_reho, [ ("atlases", "labels"), + ("atlas_files", "atlas_files"), ("lh_midthickness", "lh_underlay"), ("rh_midthickness", "rh_underlay"), ]), @@ -913,6 +914,7 @@ def init_functional_connectivity_cifti_wf(mem_gb, exact_scans, name="connectivit workflow.connect([ (inputnode, plot_parcellated_alff, [ ("atlases", "labels"), + ("atlas_files", "atlas_files"), ("lh_midthickness", "lh_underlay"), ("rh_midthickness", "rh_underlay"), ]), From 4f5300e4d418efb96c4a93c9a7fab51f090d5eb1 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 12:36:07 -0400 Subject: [PATCH 03/11] Update plotting.py --- xcp_d/interfaces/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 5d74e4819..5cbe6b397 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -1012,12 +1012,12 @@ def _run_interface(self, runtime): img_data = data_img.get_fdata() img_axes = [data_img.header.get_axis(i) for i in range(data_img.ndim)] lh_data = surf_data_from_cifti( - img_data, + np.squeeze(img_data), img_axes[1], "CIFTI_STRUCTURE_CORTEX_LEFT", ) rh_data = surf_data_from_cifti( - img_data, + np.squeeze(img_data), img_axes[1], "CIFTI_STRUCTURE_CORTEX_RIGHT", ) From 4aaa4023c12f2ad5f9af6de257bc8c896a07d1d3 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 14:10:43 -0400 Subject: [PATCH 04/11] Get it running but it cuts stuff off. --- xcp_d/interfaces/plotting.py | 37 ++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 5cbe6b397..508dbb48b 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -3,6 +3,7 @@ """Plotting interfaces.""" import json import os +import re import matplotlib.pyplot as plt import nibabel as nb @@ -1012,17 +1013,17 @@ def _run_interface(self, runtime): img_data = data_img.get_fdata() img_axes = [data_img.header.get_axis(i) for i in range(data_img.ndim)] lh_data = surf_data_from_cifti( - np.squeeze(img_data), + img_data, img_axes[1], "CIFTI_STRUCTURE_CORTEX_LEFT", ) rh_data = surf_data_from_cifti( - np.squeeze(img_data), + img_data, img_axes[1], "CIFTI_STRUCTURE_CORTEX_RIGHT", ) plot_obj.add_layer( - {"left": lh_data, "right": rh_data}, + {"left": np.squeeze(lh_data), "right": np.squeeze(rh_data)}, cmap="cool", color_range=(vmin, vmax), cbar=True, @@ -1049,6 +1050,8 @@ def _run_interface(self, runtime): cbar=False, ) fig = plot_obj.build() + fig.suptitle(atlas_name, fontsize=16) + fig.tight_layout() fig.savefig(temp_file) figure_files.append(temp_file) plt.close(fig) @@ -1056,12 +1059,16 @@ def _run_interface(self, runtime): # Now build the combined figure # Load SVG files and get their sizes direction = "vertical" - svg_objects = [sg.fromfile(svg_path) for svg_path in figure_files] - figures = [svg_obj.getroot() for svg_obj in svg_objects] widths, heights = [], [] - for fig in figures: - widths.append(fig.width) - heights.append(fig.height) + for figure_file in figure_files: + svg_obj = sg.fromfile(figure_file) + fig = svg_obj.getroot() + + # Original size is represented as string (example: '600px'); convert to float + width = float(re.sub("[^0-9]", "", svg_obj.width)) + height = float(re.sub("[^0-9]", "", svg_obj.height)) + widths.append(width) + heights.append(height) # Calculate total width and height for the new SVG if direction == "vertical": @@ -1077,15 +1084,18 @@ def _run_interface(self, runtime): new_svg = sg.SVGFigure(total_width, total_height) # Add each SVG to the new figure - for fig in figures: + for i_fig, figure_file in enumerate(figure_files): + svg_obj = sg.fromfile(figure_file) + fig = svg_obj.getroot() + if direction == "vertical": fig.moveto(0, y_offset) - y_offset += fig.height + y_offset += heights[i_fig] else: fig.moveto(x_offset, 0) - x_offset += fig.width + x_offset += widths[i_fig] - new_svg.append(figures) + new_svg.append(fig) self._results["out_file"] = fname_presuffix( data_files[0], @@ -1093,8 +1103,7 @@ def _run_interface(self, runtime): newpath=runtime.cwd, use_ext=False, ) - new_svg.savefig(self._results["out_file"]) - plt.close(fig) + new_svg.save(self._results["out_file"]) return runtime From c46b8c7806160408ef8576cf3384d880f0e7b383 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 14:23:00 -0400 Subject: [PATCH 05/11] Fix things up. --- xcp_d/interfaces/plotting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 508dbb48b..74af666f0 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -1081,7 +1081,9 @@ def _run_interface(self, runtime): x_offset = 0 # Create new SVG figure - new_svg = sg.SVGFigure(total_width, total_height) + new_svg = sg.SVGFigure(width=f"{total_width}px", height=f"{total_height}px") + # for some reason, the width and height params aren't retained, so set them again + new_svg.set_size((f"{total_width}px", f"{total_height}px")) # Add each SVG to the new figure for i_fig, figure_file in enumerate(figure_files): From 0f2e36c3c630ce948f3a655d538f8da77dfbc73f Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 14:46:06 -0400 Subject: [PATCH 06/11] Update plotting.py --- xcp_d/interfaces/plotting.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 74af666f0..0a017cff9 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -1008,7 +1008,7 @@ def _run_interface(self, runtime): plot_obj = Plot(lh, rh) - # add schaefer parcellation (no color bar needed) + LOGGER.info(f"Adding {atlas_name} to the plot.") data_img = nb.load(data_file) img_data = data_img.get_fdata() img_axes = [data_img.header.get_axis(i) for i in range(data_img.ndim)] @@ -1022,6 +1022,7 @@ def _run_interface(self, runtime): img_axes[1], "CIFTI_STRUCTURE_CORTEX_RIGHT", ) + LOGGER.info(f"Data sizes: {lh_data.shape} {rh_data.shape}") plot_obj.add_layer( {"left": np.squeeze(lh_data), "right": np.squeeze(rh_data)}, cmap="cool", @@ -1043,6 +1044,8 @@ def _run_interface(self, runtime): atlas_axes[1], "CIFTI_STRUCTURE_CORTEX_RIGHT", ) + + LOGGER.info(f"Atlas sizes: {lh_atlas.shape} {rh_atlas.shape}") plot_obj.add_layer( {"left": np.squeeze(lh_atlas), "right": np.squeeze(rh_atlas)}, cmap="gray", From 11df13966948d6b40d7858bae75a503b118dafee Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 15:08:26 -0400 Subject: [PATCH 07/11] Log stuff. --- xcp_d/interfaces/plotting.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 0a017cff9..f27373d79 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -1005,6 +1005,10 @@ def _run_interface(self, runtime): f"{atlas_name}.svg", newpath=runtime.cwd, ) + lh_img = nb.load(lh) + rh_img = nb.load(rh) + LOGGER.info(f"Underlay files: {lh} {rh}") + LOGGER.info(f"Underlay sizes: {lh_img.shape} {rh_img.shape}") plot_obj = Plot(lh, rh) From 79bd148c80058d43936a07361673a19ef7a6008f Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 15:14:59 -0400 Subject: [PATCH 08/11] Update plotting.py --- xcp_d/interfaces/plotting.py | 39 +++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index f27373d79..07ca3be04 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -1008,7 +1008,10 @@ def _run_interface(self, runtime): lh_img = nb.load(lh) rh_img = nb.load(rh) LOGGER.info(f"Underlay files: {lh} {rh}") - LOGGER.info(f"Underlay sizes: {lh_img.shape} {rh_img.shape}") + LOGGER.info( + "Underlay sizes: " + f"{lh_img.agg_data()[0].shape[0]} {rh_img.agg_data()[0].shape[0]}" + ) plot_obj = Plot(lh, rh) @@ -1065,7 +1068,6 @@ def _run_interface(self, runtime): # Now build the combined figure # Load SVG files and get their sizes - direction = "vertical" widths, heights = [], [] for figure_file in figure_files: svg_obj = sg.fromfile(figure_file) @@ -1077,15 +1079,23 @@ def _run_interface(self, runtime): widths.append(width) heights.append(height) - # Calculate total width and height for the new SVG - if direction == "vertical": - total_width = max(widths) - total_height = sum(heights) - y_offset = 0 + cell_width, cell_height = max(widths), max(heights) + max_columns = 2 + if len(figure_files) == 1: + total_width = cell_width + total_height = cell_height + loc_idx = [(0, 0)] else: - total_width = sum(widths) - total_height = max(heights) - x_offset = 0 + n_rows = int(np.ceil(len(figure_files) / max_columns)) + n_columns = max_columns + total_width = cell_width * n_columns + total_height = cell_height * n_rows + + loc_idx = [] + for i_row in range(n_rows): + row_idx = [j for j in range(len(figure_files)) if (j // max_columns) == i_row] + for j_col, idx in enumerate(row_idx): + loc_idx.append((i_row, j_col)) # Create new SVG figure new_svg = sg.SVGFigure(width=f"{total_width}px", height=f"{total_height}px") @@ -1096,14 +1106,7 @@ def _run_interface(self, runtime): for i_fig, figure_file in enumerate(figure_files): svg_obj = sg.fromfile(figure_file) fig = svg_obj.getroot() - - if direction == "vertical": - fig.moveto(0, y_offset) - y_offset += heights[i_fig] - else: - fig.moveto(x_offset, 0) - x_offset += widths[i_fig] - + fig.moveto(loc_idx[i_fig][0] * cell_height, loc_idx[i_fig][1] * cell_width) new_svg.append(fig) self._results["out_file"] = fname_presuffix( From 104e2d07a37e85ba399dbbcd7170b7fa21242885 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 15:26:36 -0400 Subject: [PATCH 09/11] Update plotting.py --- xcp_d/interfaces/plotting.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 07ca3be04..5d08c526d 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -1094,7 +1094,7 @@ def _run_interface(self, runtime): loc_idx = [] for i_row in range(n_rows): row_idx = [j for j in range(len(figure_files)) if (j // max_columns) == i_row] - for j_col, idx in enumerate(row_idx): + for j_col in range(len(row_idx)): loc_idx.append((i_row, j_col)) # Create new SVG figure @@ -1106,7 +1106,9 @@ def _run_interface(self, runtime): for i_fig, figure_file in enumerate(figure_files): svg_obj = sg.fromfile(figure_file) fig = svg_obj.getroot() - fig.moveto(loc_idx[i_fig][0] * cell_height, loc_idx[i_fig][1] * cell_width) + offset0 = loc_idx[i_fig][1] * cell_width + offset1 = loc_idx[i_fig][0] * cell_height + fig.moveto(offset0, offset1) new_svg.append(fig) self._results["out_file"] = fname_presuffix( From 14e39d066d1850d5b2a08398f423d4da1b80b3aa Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 15:58:14 -0400 Subject: [PATCH 10/11] Update plotting.py --- xcp_d/interfaces/plotting.py | 114 ++++++----------------------------- 1 file changed, 18 insertions(+), 96 deletions(-) diff --git a/xcp_d/interfaces/plotting.py b/xcp_d/interfaces/plotting.py index 5d08c526d..86f288d14 100644 --- a/xcp_d/interfaces/plotting.py +++ b/xcp_d/interfaces/plotting.py @@ -11,10 +11,7 @@ import pandas as pd import seaborn as sns import svgutils.transform as sg -from matplotlib.cm import ScalarMappable -from matplotlib.colors import Normalize -from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec -from nilearn.plotting import plot_anat, plot_stat_map, plot_surf_stat_map +from nilearn.plotting import plot_anat, plot_stat_map from nipype import logging from nipype.interfaces.base import ( BaseInterfaceInputSpec, @@ -1183,106 +1180,31 @@ def _run_interface(self, runtime): rh = self.inputs.rh_underlay lh = self.inputs.lh_underlay - cifti = nb.load(self.inputs.in_file) - cifti_data = cifti.get_fdata() - cifti_axes = [cifti.header.get_axis(i) for i in range(cifti.ndim)] - - # Create Figure and GridSpec. - fig = plt.figure(constrained_layout=False) - fig.set_size_inches(6.5, 6) - # Add an additional column for the colorbar - gs = GridSpec(1, 2, figure=fig, width_ratios=[1, 0.05]) - subplot_gridspec = gs[0, 0] - subplot = fig.add_subplot(subplot_gridspec) - colorbar_gridspec = gs[0, 1] - - subplot.set_axis_off() - - # Create 4 Axes (2 rows, 2 columns) from the subplot - gs_inner = GridSpecFromSubplotSpec(2, 2, subplot_spec=subplot_gridspec) - inner_subplots = [ - fig.add_subplot(gs_inner[i, j], projection="3d") for i in range(2) for j in range(2) - ] + data_img = nb.load(self.inputs.in_file) + img_data = data_img.get_fdata() + data_axes = [data_img.header.get_axis(i) for i in range(data_img.ndim)] - lh_surf_data = surf_data_from_cifti( - cifti_data, - cifti_axes[1], + plot_obj = Plot(lh, rh) + lh_data = surf_data_from_cifti( + img_data, + data_axes[1], "CIFTI_STRUCTURE_CORTEX_LEFT", ) - rh_surf_data = surf_data_from_cifti( - cifti_data, - cifti_axes[1], + rh_data = surf_data_from_cifti( + img_data, + data_axes[1], "CIFTI_STRUCTURE_CORTEX_RIGHT", ) + vmax = np.nanmax([np.nanmax(lh_data), np.nanmax(rh_data)]) + vmin = np.nanmin([np.nanmin(lh_data), np.nanmin(rh_data)]) - vmax = np.nanmax([np.nanmax(lh_surf_data), np.nanmax(rh_surf_data)]) - vmin = np.nanmin([np.nanmin(lh_surf_data), np.nanmin(rh_surf_data)]) - - plot_surf_stat_map( - lh, - lh_surf_data, - vmin=vmin, - vmax=vmax, - hemi="left", - view="lateral", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[0], - figure=fig, - ) - plot_surf_stat_map( - rh, - rh_surf_data, - vmin=vmin, - vmax=vmax, - hemi="right", - view="lateral", - engine="matplotlib", + plot_obj.add_layer( + {"left": np.squeeze(lh_data), "right": np.squeeze(rh_data)}, cmap="cool", - colorbar=False, - axes=inner_subplots[1], - figure=fig, - ) - plot_surf_stat_map( - lh, - lh_surf_data, - vmin=vmin, - vmax=vmax, - hemi="left", - view="medial", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[2], - figure=fig, + color_range=(vmin, vmax), + cbar=True, ) - plot_surf_stat_map( - rh, - rh_surf_data, - vmin=vmin, - vmax=vmax, - hemi="right", - view="medial", - engine="matplotlib", - cmap="cool", - colorbar=False, - axes=inner_subplots[3], - figure=fig, - ) - - inner_subplots[0].set_title("Left Hemisphere", fontsize=10) - inner_subplots[1].set_title("Right Hemisphere", fontsize=10) - - for ax in inner_subplots: - ax.set_rasterized(True) - - # Create a ScalarMappable with the "cool" colormap and the specified vmin and vmax - sm = ScalarMappable(cmap="cool", norm=Normalize(vmin=vmin, vmax=vmax)) - - colorbar_ax = fig.add_subplot(colorbar_gridspec) - # Add a colorbar to colorbar_ax using the ScalarMappable - fig.colorbar(sm, cax=colorbar_ax) + fig = plot_obj.build() self._results["out_file"] = fname_presuffix( self.inputs.in_file, From ed8913bd3bdff0d5d54107e2baa3aff6fa2f8a3c Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Wed, 24 Jul 2024 16:11:19 -0400 Subject: [PATCH 11/11] Update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 97232c2f8..f021b0834 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "surfplot ~= 0.2.0", # for surface plots "templateflow ~= 24.2.0", "toml", + "vtk ~= 9.2.6", ] dynamic = ["version"]