# Reformatted and revised Python payload of the patch
# "display tract profiles grouped by var_interest", which touches:
#   - afqinsight/plot.py              (adds `plot_profiles_by_group`)
#   - afqinsight/tools/__init__.py    (new, empty file)
#   - afqinsight/tools/tractutils.py  (new helper module)

import logging

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# afqinsight/plot.py -- addition
#
# The patch adds the following lines to the import section of plot.py:
#     from .tools import tractutils as tu
#     import logging
#     logger = logging.getLogger(__name__)
# `sns` is not introduced by the patch; it is presumably already imported
# in plot.py (seaborn) -- confirm against the full file.
# ---------------------------------------------------------------------------


def plot_profiles_by_group(
    tracto_df: pd.DataFrame,
    var_df: pd.DataFrame,
    center_cut=True,
    test_plot=False,
    scalar_to_plot='dki_fa',
    var_to_plot: str = None,
    palette='viridis',
):
    """Plot tractometry profiles grouped by variable of interest

    Neither `tracto_df` nor `var_df` is modified.

    Parameters
    ----------
    tracto_df : pd.DataFrame
        tractometry data, with columns `subjectID`, `sessionID`, `tractID`,
        `nodeID` and `scalar_to_plot`
    var_df : pd.DataFrame
        dataframe with categorical or pseudocontinuous variable of interest,
        keyed by `subjectID` and `sessionID`
    center_cut : bool, optional
        whether to cut the extremes of the tract out, by default True
    test_plot : bool, optional
        whether to subsample the data to get quick plot, by default False
    scalar_to_plot : str, optional
        scalar to plot (y axis), by default 'dki_fa'
    var_to_plot : str, optional
        variable of interest name if `var_df` contains several variables,
        by default None
    palette : optional
        palette forwarded to `sns.FacetGrid`, by default 'viridis'

    Returns
    -------
    sns.FacetGrid
        FacetGrid

    Raises
    ------
    ValueError
        If `var_to_plot` cannot be determined or is not a column of
        `var_df`, or if no row is left after merging both dataframes.
    """
    id_cols = ['subjectID', 'sessionID']

    if var_to_plot is None:
        candidates = [col for col in var_df.columns if col not in id_cols]
        if not candidates:
            # Previously surfaced as a bare IndexError on `available[0]`.
            raise ValueError(
                "`var_df` contains no variable of interest besides "
                "`subjectID` and `sessionID`."
            )
        if len(candidates) > 1:
            raise ValueError(
                "Please specify `var_to_plot` as there are "
                "several options available in `var_df`."
            )
        var_to_plot = candidates[0]
    elif var_to_plot not in var_df.columns:
        raise ValueError(f"{var_to_plot=} not available in `var_df` columns.")

    # Select the needed columns first: the selection is a new frame, so the
    # caller's data is never touched by the dtype optimisation below.
    tracto_df = tracto_df[id_cols + ['tractID', 'nodeID', scalar_to_plot]]
    tracto_df = tu.optimize(tracto_df)

    # Only merge the plotted variable; otherwise the `dropna` performed by
    # the merge would discard rows because of NaNs in unrelated variables.
    var_df = var_df[list(dict.fromkeys(id_cols + [var_to_plot]))]

    if test_plot:
        tracto_df = tu.tracto_subsample(tracto_df)
        dpi = 150
    else:
        dpi = 300

    if center_cut:
        tracto_df = tu.center_cut(tracto_df)

    df = tu.merge_profiles_varinterest(tracto_df, var_df)
    df = tu.beautify(df)
    if df.empty:
        raise ValueError(
            "No data left to plot after merging `tracto_df` and `var_df`."
        )

    # Facets follow the BUNDLE_DICT order; labels unknown to BUNDLE_DICT are
    # appended afterwards so that their tracts are still displayed.
    labels_present = set(df['tractID_b'].dropna().unique())
    tract_order = [
        label for label in tu.BUNDLE_DICT.values() if label in labels_present
    ]
    tract_order += sorted(labels_present - set(tract_order), key=str)
    col_wrap = min(5, len(tract_order))

    g = sns.FacetGrid(
        data=df,
        col_wrap=col_wrap,
        col='tractID_b', col_order=tract_order,
        hue=var_to_plot,
        margin_titles=True, palette=palette,
        sharey='row',
        height=3.5, aspect=0.9,
        legend_out=True,
    )
    g.map(
        sns.lineplot, 'nodeID', scalar_to_plot,
        estimator=np.nanmean, errorbar=('se', 1),
    )
    g.set_titles(col_template='{col_name}', row_template='{row_name}')
    g.add_legend(title=var_to_plot)
    # `figure` is the preferred accessor; `fig` is documented as deprecated.
    g.figure.suptitle(
        f'{scalar_to_plot.upper()} values along select WM tracts '
        f'for {var_to_plot}',
        y=1,
    )
    g.tight_layout()
    g.figure.set_dpi(dpi)

    return g


# ---------------------------------------------------------------------------
# afqinsight/tools/tractutils.py -- new module
# Utility functions to manipulate tractometry type dataframe
# ---------------------------------------------------------------------------

# Maps tract identifiers (`tractID` values) to human-readable labels.
BUNDLE_DICT = {
    'CST_R': 'Right Corticospinal',
    'CST_L': 'Left Corticospinal',
    'UNC_R': 'Right Uncinate',
    'UNC_L': 'Left Uncinate',
    'IFO_R': 'Right IFOF',
    'IFO_L': 'Left IFOF',
    'ARC_R': 'Right Arcuate',
    'ARC_L': 'Left Arcuate',
    'ATR_R': 'Right Thalamic Radiation',
    'ATR_L': 'Left Thalamic Radiation',
    'CGC_R': 'Right Cingulum Cingulate',
    'CGC_L': 'Left Cingulum Cingulate',
    'HCC_R': 'Right Cingulum Hippocampus',
    'HCC_L': 'Left Cingulum Hippocampus',
    'FP': 'Callosum Forceps Major',
    'FA': 'Callosum Forceps Minor',
    'ILF_R': 'Right ILF',
    'ILF_L': 'Left ILF',
    'SLF_R': 'Right SLF',
    'SLF_L': 'Left SLF',
    'VOF_R': 'Right Vertical Occipital',
    'VOF_L': 'Left Vertical Occipital',
    'pARC_R': 'Right Posterior Arcuate',
    'pARC_L': 'Left Posterior Arcuate',
    'AntFrontal': 'Callosum: AntFrontal',
    'Motor': 'Callosum: Motor',
    'Occipital': 'Callosum: Occipital',
    'Orbital': 'Callosum: Orbital',
    'PostParietal': 'Callosum: PostParietal',
    'SupFrontal': 'Callosum: SupFrontal',
    'SupParietal': 'Callosum: SupParietal',
    'Temporal': 'Callosum: Temporal',
}


def merge_profiles_varinterest(
    tracto_df: pd.DataFrame,
    var_df: pd.DataFrame,
    show_info=True,
) -> pd.DataFrame:
    """Merge tractometry and variable of interest dataframes

    Both inputs are merged on `subjectID` and `sessionID`; every row with a
    missing value in any column is then dropped. The inputs are left
    untouched.

    Parameters
    ----------
    tracto_df : pd.DataFrame
        tractometry dataframe
    var_df : pd.DataFrame
        dataframe containing variables of interest
    show_info : bool, optional
        Whether to log how many subjects are dropped by the merge,
        by default True

    Returns
    -------
    pd.DataFrame
        Merged dataframe
    """
    var_df = optimize(var_df)
    tracto_df = optimize(tracto_df)

    merged = pd.merge(
        left=tracto_df,
        right=var_df,
        on=['subjectID', 'sessionID'],
        how="outer",
    )
    # The outer merge leaves NaNs in rows present on one side only, so
    # `dropna` removes them together with rows that were already incomplete.
    merged = merged.dropna().reset_index(drop=True)

    if show_info:
        kept = set(merged['subjectID'].unique())
        lost_tracto = set(tracto_df['subjectID'].unique()) - kept
        lost_var = set(var_df['subjectID'].unique()) - kept
        logger.info(
            "Merge kept %d subject(s); dropped %d subject(s) from the "
            "tractometry data and %d subject(s) from `var_df`.",
            len(kept), len(lost_tracto), len(lost_var),
        )
    return merged


def center_cut(df: pd.DataFrame, cut: tuple = (25, 75)) -> pd.DataFrame:
    """Return the rows whose `nodeID` lies between two indicated values

    Parameters
    ----------
    df : pd.DataFrame
        dataframe with column `nodeID` to be filtered on
    cut : tuple, optional
        extreme values (both inclusive) to filter `nodeID` on,
        by default (25, 75)

    Returns
    -------
    pd.DataFrame
        filtered dataframe, with a fresh index
    """
    # `between` is inclusive on both ends by default.
    in_range = df['nodeID'].between(cut[0], cut[1])
    return df[in_range].reset_index(drop=True)


def optimize(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` with compact dtypes for the ID columns

    `subjectID`, `sessionID` and `tractID` (when present) are converted to
    categoricals without unused categories, and `nodeID` (when present) is
    cast to `np.int16`.

    Parameters
    ----------
    df : pd.DataFrame
        dataframe to convert; it is not modified

    Returns
    -------
    pd.DataFrame
        converted copy of `df`
    """
    # Work on a copy: assigning columns on `df` itself would change the
    # caller's dataframe (or write into a slice of another dataframe).
    out = df.copy()
    for col in ('subjectID', 'sessionID', 'tractID'):
        if col in out.columns:
            out[col] = (
                out[col].astype('category').cat.remove_unused_categories()
            )
    if 'nodeID' in out.columns:
        out['nodeID'] = out['nodeID'].astype(np.int16)
    return out


def beautify(tracto: pd.DataFrame) -> pd.DataFrame:
    """Make a tractometry dataframe ready for plotting

    Unused categories are removed from categorical columns and, unless it
    already exists, a `tractID_b` column holding the `BUNDLE_DICT` label of
    each `tractID` is added. Tract IDs missing from `BUNDLE_DICT` keep
    their raw ID as label instead of becoming NaN.

    Parameters
    ----------
    tracto : pd.DataFrame
        tractometry dataframe; it is not modified

    Returns
    -------
    pd.DataFrame
        tractometry dataframe, with embellishments for plotting
    """
    tracto = tracto.copy()
    for col in tracto.columns:
        if isinstance(tracto[col].dtype, pd.CategoricalDtype):
            tracto[col] = tracto[col].cat.remove_unused_categories()
    if "tractID_b" not in tracto.columns:
        tracto["tractID_b"] = tracto["tractID"].map(
            lambda tract: BUNDLE_DICT.get(tract, tract)
        )
    return tracto


def _sample_ids(values: pd.Series, how_many: int) -> np.ndarray:
    """Randomly pick at most `how_many` distinct values of `values`."""
    pool = np.asarray(values.unique())
    # Cap the request: sampling without replacement raises when asked for
    # more items than the pool holds.
    size = min(int(how_many), len(pool))
    if size <= 0:
        return pool[:0]
    return np.random.choice(pool, size, replace=False)


def tracto_subsample(
    tracto_df: pd.DataFrame,
    subjects=5,
    sessions=1,
    tracts=2,
) -> pd.DataFrame:
    """Return a random subset of a tractometry dataframe

    Parameters
    ----------
    tracto_df : pd.DataFrame
        tractometry dataframe with columns `subjectID`, `sessionID` and
        `tractID`; it is not modified
    subjects : int or list-like, optional
        number of subjects to draw at random (capped at the number
        available), or explicit subject IDs to keep, by default 5
    sessions : int or list-like, optional
        same as `subjects`, for `sessionID`, by default 1
    tracts : int or list-like, optional
        same as `subjects`, for `tractID`, by default 2

    Returns
    -------
    pd.DataFrame
        rows of `tracto_df` matching the selected subjects, sessions and
        tracts, with unused categories removed
    """
    if isinstance(subjects, (int, np.integer)):
        subjects = _sample_ids(tracto_df['subjectID'], subjects)
    if isinstance(sessions, (int, np.integer)):
        sessions = _sample_ids(tracto_df['sessionID'], sessions)
    if isinstance(tracts, (int, np.integer)):
        tracts = _sample_ids(tracto_df['tractID'], tracts)

    selected = (
        tracto_df['subjectID'].isin(subjects)
        & tracto_df['sessionID'].isin(sessions)
        & tracto_df['tractID'].isin(tracts)
    )
    # Explicit copy so the column assignments below do not write into a
    # slice of the caller's dataframe.
    subset = tracto_df[selected].copy()
    for col in subset.columns:
        if isinstance(subset[col].dtype, pd.CategoricalDtype):
            subset[col] = subset[col].cat.remove_unused_categories()
    return subset