Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Display tract profiles grouped by var_interest #136

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion afqinsight/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from tqdm.auto import tqdm

from .utils import BUNDLE_MAT_2_PYTHON
from .tools import tractutils as tu
from .datasets import AFQDataset

import logging
logger = logging.getLogger(__name__)

POSITIONS = OrderedDict(
{
Expand Down Expand Up @@ -331,3 +333,85 @@ def plot_tract_profiles(
figs[metric] = fig

return figs


def plot_profiles_by_group(
    tracto_df: pd.DataFrame,
    var_df: pd.DataFrame,
    center_cut=True,
    test_plot=False,
    scalar_to_plot='dki_fa',
    var_to_plot: str = None,
    palette='viridis',
):
    """Plot tractometry profiles grouped by variable of interest

    Parameters
    ----------
    tracto_df : pd.DataFrame
        tractometry data; must contain the columns `subjectID`, `sessionID`,
        `tractID`, `nodeID` and `scalar_to_plot`
    var_df : pd.DataFrame
        dataframe with categorical or pseudocontinuous variable of interest;
        must contain `subjectID` and `sessionID` to merge on
    center_cut : bool, optional
        whether to cut the extremes of the tract out, by default True
    test_plot : bool, optional
        whether to subsample the data to get quick plot (also lowers the
        figure dpi from 300 to 150), by default False
    scalar_to_plot : str, optional
        scalar to plot (y axis), by default 'dki_fa'
    var_to_plot : str, optional
        variable of interest name if `var_df` contains several variables, by default None
    palette : str, optional
        palette forwarded to `sns.FacetGrid` for the hue levels, by default 'viridis'

    Returns
    -------
    sns.FacetGrid
        One facet per tract, one line per level of `var_to_plot`.

    Raises
    ------
    ValueError
        If `var_to_plot` cannot be inferred or is not a column of `var_df`,
        or if none of the tracts left after merging has a display name.
    """
    # TODO: consider letting `center_cut` be a fraction in (0, 1) that tells
    # how much of the centre of the tract to keep.
    id_cols = ['subjectID', 'sessionID']

    if var_to_plot is None:
        available = [col for col in var_df.columns if col not in id_cols]
        if len(available) > 1:
            raise ValueError("Please specify `var_to_plot` as there are "
                             "several options available in `var_df`.")
        if not available:
            # Without this guard the lookup below fails with a bare IndexError.
            raise ValueError("`var_df` contains no variable to plot besides "
                             "`subjectID` and `sessionID`.")
        var_to_plot = available[0]
    elif var_to_plot not in var_df.columns:
        raise ValueError(f"{var_to_plot=} not available in `var_df` columns.")

    # Work on trimmed copies: the helpers below convert dtypes, and the
    # caller's dataframes should not change as a side effect of plotting.
    tracto_df = tracto_df[id_cols + ['tractID', 'nodeID', scalar_to_plot]].copy()
    tracto_df = tu.optimize(tracto_df)

    # Keep only the variable being plotted so that missing values in unrelated
    # columns of `var_df` do not discard rows when the merge drops NaNs.
    var_cols = list(dict.fromkeys(id_cols + [var_to_plot]))
    var_df = var_df[var_cols].copy()

    if test_plot:
        tracto_df = tu.tracto_subsample(tracto_df)
        dpi = 150
    else:
        dpi = 300

    if center_cut:
        tracto_df = tu.center_cut(tracto_df)

    df = tu.merge_profiles_varinterest(tracto_df, var_df)
    df = tu.beautify(df)

    # Facets follow the order of `tu.BUNDLE_DICT`, restricted to the tracts
    # that are actually present in the merged data.
    available_tracts = set(df['tractID'].unique())
    tract_order = [
        name for key, name in tu.BUNDLE_DICT.items() if key in available_tracts
    ]
    if not tract_order:
        raise ValueError("None of the tracts left after merging is listed in "
                         "`BUNDLE_DICT`; nothing to plot.")
    # Size the grid on the facets that are drawn, not on every tract ID seen.
    col_wrap = min(5, len(tract_order))

    g = sns.FacetGrid(
        data=df,
        col_wrap=col_wrap,
        col='tractID_b', col_order=tract_order,
        hue=var_to_plot,
        margin_titles=True, palette=palette,
        sharey='row',
        height=3.5, aspect=0.9,
        legend_out=True,
    )
    # Mean profile per hue level, with a one-standard-error band.
    g.map(sns.lineplot, 'nodeID', scalar_to_plot, estimator=np.nanmean, errorbar=('se', 1))
    g.set_titles(col_template='{col_name}', row_template='{row_name}')
    g.add_legend(title=var_to_plot)
    g.fig.suptitle(f'{scalar_to_plot.upper()} values along select WM tracts for {var_to_plot}', y=1)
    g.tight_layout()
    g.fig.set_dpi(dpi)

    return g

Empty file added afqinsight/tools/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions afqinsight/tools/tractutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Utility functions to manipulate tractometry-type dataframes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that these functions could be added to the existing "utils.py" file.

"""

import numpy as np
import pandas as pd
import logging
logger = logging.getLogger(__name__)

# Maps the short tract identifiers stored in the `tractID` column to
# human-readable bundle names. `beautify` uses it to build the `tractID_b`
# display column, and `plot_profiles_by_group` iterates over it to order the
# facets, so the insertion order of the entries below is the display order.
BUNDLE_DICT = {
'CST_R': 'Right Corticospinal',
'CST_L': 'Left Corticospinal',
'UNC_R': 'Right Uncinate',
'UNC_L': 'Left Uncinate',
'IFO_R': 'Right IFOF',
'IFO_L': 'Left IFOF',
'ARC_R': 'Right Arcuate',
'ARC_L': 'Left Arcuate',
'ATR_R': 'Right Thalamic Radiation',
'ATR_L': 'Left Thalamic Radiation',
'CGC_R': 'Right Cingulum Cingulate',
'CGC_L': 'Left Cingulum Cingulate',
'HCC_R': 'Right Cingulum Hippocampus',
'HCC_L': 'Left Cingulum Hippocampus',
'FP': 'Callosum Forceps Major',
'FA': 'Callosum Forceps Minor',
'ILF_R': 'Right ILF',
'ILF_L': 'Left ILF',
'SLF_R': 'Right SLF',
'SLF_L': 'Left SLF',
'VOF_R': 'Right Vertical Occipital',
'VOF_L': 'Left Vertical Occipital',
'pARC_R': 'Right Posterior Arcuate',
'pARC_L': 'Left Posterior Arcuate',
# Callosal segments (no left/right suffix).
'AntFrontal': 'Callosum: AntFrontal',
'Motor': 'Callosum: Motor',
'Occipital': 'Callosum: Occipital',
'Orbital': 'Callosum: Orbital',
'PostParietal': 'Callosum: PostParietal',
'SupFrontal': 'Callosum: SupFrontal',
'SupParietal': 'Callosum: SupParietal',
'Temporal': 'Callosum: Temporal',
}


def merge_profiles_varinterest(
    tracto_df: pd.DataFrame,
    var_df: pd.DataFrame,
    show_info=True,
) -> pd.DataFrame:
    """Merge tractometry and variable of interest dataframes

    Parameters
    ----------
    tracto_df : pd.DataFrame
        tractometry dataframe
    var_df : pd.DataFrame
        dataframe containing variables of interest
    show_info : bool, optional
        Whether to log how many subjects and rows are dropped by the merge,
        by default True

    Returns
    -------
    pd.DataFrame
        Merged dataframe, without any row containing a missing value and
        with a fresh 0..n-1 index
    """

    var_df = optimize(var_df)
    tracto_df = optimize(tracto_df)

    # The outer merge keeps every row of both inputs; `dropna` then removes
    # all rows containing a missing value, which includes the unmatched rows
    # whenever the other frame contributes at least one non-key column.
    merged = pd.merge(
        left=tracto_df,
        right=var_df,
        on=['subjectID', 'sessionID'],
        how="outer",
    )
    df = merged.dropna().reset_index(drop=True)

    if show_info:
        logger.info(
            "Merged %d tractometry subject(s) with %d variable-of-interest "
            "subject(s): %d subject(s) kept, %d incomplete row(s) dropped.",
            tracto_df['subjectID'].nunique(),
            var_df['subjectID'].nunique(),
            df['subjectID'].nunique(),
            len(merged) - len(df),
        )
    return df



def center_cut(df: pd.DataFrame, cut: tuple = (25, 75)) -> pd.DataFrame:
    """Return the rows of `df` whose `nodeID` lies between two values

    Parameters
    ----------
    df : pd.DataFrame
        dataframe with column `nodeID` to be filtered on
    cut : tuple, optional
        ``(low, high)`` bounds; rows with ``low <= nodeID <= high`` are kept
        (both bounds inclusive), by default (25, 75)

    Returns
    -------
    pd.DataFrame
        filtered dataframe, re-indexed from 0
    """
    low, high = cut[0], cut[1]
    in_range = (df['nodeID'] >= low) & (df['nodeID'] <= high)
    return df[in_range].reset_index(drop=True)



def optimize(df: pd.DataFrame) -> pd.DataFrame:
    """Convert the identifier columns of a dataframe to compact dtypes

    The columns `subjectID`, `sessionID` and `tractID` (whichever are
    present) are converted to categoricals without unused categories, and
    `nodeID` (if present) is cast to ``np.int16``. All other columns are
    left untouched.

    Parameters
    ----------
    df : pd.DataFrame
        dataframe to convert; it is not modified

    Returns
    -------
    pd.DataFrame
        new dataframe with the same columns, in the same order, and the
        converted dtypes
    """
    converted = {
        col: df[col].astype('category').cat.remove_unused_categories()
        for col in ('subjectID', 'sessionID', 'tractID')
        if col in df.columns
    }
    if 'nodeID' in df.columns:
        converted['nodeID'] = df['nodeID'].astype(np.int16)
    # `assign` returns a new frame, so the caller's dataframe keeps its
    # original dtypes instead of being changed behind its back.
    return df.assign(**converted)


def beautify(tracto: pd.DataFrame) -> pd.DataFrame:
    """Make a tractometry dataframe ready for plotting

    Parameters
    ----------
    tracto : pd.DataFrame
        tractometry dataframe; it is not modified. Must contain a `tractID`
        column unless `tractID_b` is already present.

    Returns
    -------
    pd.DataFrame
        copy of the tractometry dataframe with embellishments for plotting:
        unused categories are removed from every categorical column, and a
        `tractID_b` column holding the `BUNDLE_DICT` display name of each
        `tractID` is added if it does not exist yet
    """
    # Always work on a copy so the result never aliases the caller's frame,
    # whether or not `tractID_b` has to be added.
    out = tracto.copy()
    for col in out.columns:
        if isinstance(out[col].dtype, pd.CategoricalDtype):
            out[col] = out[col].cat.remove_unused_categories()
    if "tractID_b" not in out.columns:
        out["tractID_b"] = out["tractID"].map(BUNDLE_DICT)
    return out


def _resolve_selection(column: pd.Series, selection, rng):
    """Turn an integer count into that many randomly drawn values of `column`."""
    if isinstance(selection, (int, np.integer)):
        levels = np.asarray(column.unique())
        # Never ask for more levels than exist: sampling without replacement
        # would raise a ValueError otherwise.
        return rng.choice(levels, min(selection, len(levels)), replace=False)
    return selection


def tracto_subsample(
    tracto_df: pd.DataFrame,
    subjects=5,
    sessions=1,
    tracts=2,
    random_state=None,
) -> pd.DataFrame:
    """Keep a subset of subjects, sessions and tracts of a tractometry dataframe

    Parameters
    ----------
    tracto_df : pd.DataFrame
        tractometry dataframe with columns `subjectID`, `sessionID` and
        `tractID`; it is not modified
    subjects : int or list-like, optional
        number of subjects to draw at random (without replacement), or the
        explicit subject IDs to keep, by default 5
    sessions : int or list-like, optional
        number of sessions to draw at random, or the explicit session IDs to
        keep, by default 1
    tracts : int or list-like, optional
        number of tracts to draw at random, or the explicit tract IDs to
        keep, by default 2
    random_state : int, optional
        seed for a reproducible draw; when None the global numpy random
        state is used, by default None

    Returns
    -------
    pd.DataFrame
        rows matching the selected subjects, sessions and tracts, with the
        original index labels and without unused categories

    Notes
    -----
    When a count exceeds the number of available values, all of them are
    kept.
    """
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    subjects = _resolve_selection(tracto_df['subjectID'], subjects, rng)
    sessions = _resolve_selection(tracto_df['sessionID'], sessions, rng)
    tracts = _resolve_selection(tracto_df['tractID'], tracts, rng)

    keep = (
        tracto_df['subjectID'].isin(subjects)
        & tracto_df['sessionID'].isin(sessions)
        & tracto_df['tractID'].isin(tracts)
    )
    # Copy before assigning columns so that pandas does not warn about
    # writing into a slice of `tracto_df`.
    df = tracto_df[keep].copy()
    for col in df.columns:
        if isinstance(df[col].dtype, pd.CategoricalDtype):
            df[col] = df[col].cat.remove_unused_categories()
    return df