Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix): resolve data ordering to match axis for stacked violin plots #3196

Merged
merged 12 commits into from
Aug 6, 2024
1 change: 1 addition & 0 deletions docs/release-notes/1.10.3.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
* Add compatibility with {mod}`numpy` 2.0 {pr}`3065` and {pr}`3115` {smaller}`P Angerer`
* Fix `legend_loc` argument in {func}`scanpy.pl.embedding` not accepting matplotlib parameters {pr}`3163` {smaller}`P Angerer`
* Fix dispersion cutoff in {func}`~scanpy.pp.highly_variable_genes` in presence of `NaN`s {pr}`3176` {smaller}`P Angerer`
* Fix axis labeling for swapped axes in {func}`~scanpy.pl.rank_genes_groups_stacked_violin` {pr}`3196` {smaller}`Ilan Gold`

#### Performance
39 changes: 30 additions & 9 deletions src/scanpy/plotting/_stacked_violin.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,27 @@ def _mainplot(self, ax: Axes):
colormap_array = cmap(normalize(_color_df.values))
x_spacer_size = self.plot_x_padding
y_spacer_size = self.plot_y_padding

# Column names must be unique, but gene names are frequently
# repeated in self.var_names. Append the column index to each name
# so that the violin plot can tell repeated genes apart.
_matrix.columns = [f"{x}_{idx}" for idx, x in enumerate(_matrix.columns)]

# Ensure the categorical axis is always ordered identically to its tick labels.
# If the axes are not swapped, the violin plot uses the unique names set in `_matrix.columns` above.
# If they are swapped, it uses the same `labels` that are applied as x tick labels below.
# Without this, `_make_rows_of_violinplots` does not know the order of the categories in `labels`.
labels = _color_df.columns
x_axis_order = labels if self.are_axes_swapped else _matrix.columns

self._make_rows_of_violinplots(
ax, _matrix, colormap_array, _color_df, x_spacer_size, y_spacer_size
ax,
_matrix,
colormap_array,
_color_df,
x_spacer_size,
y_spacer_size,
x_axis_order,
)

# turn on axis for `ax` as this is turned off
Expand All @@ -434,7 +453,6 @@ def _mainplot(self, ax: Axes):
# 0.5 to position the ticks on the center of the violins
x_ticks = np.arange(_color_df.shape[1]) + 0.5
ax.set_xticks(x_ticks)
labels = _color_df.columns
ax.set_xticklabels(labels, minor=False, ha="center")
# rotate x tick labels if they are longer than 2 characters
if max([len(x) for x in labels]) > 2:
Expand All @@ -445,7 +463,14 @@ def _mainplot(self, ax: Axes):
return normalize

def _make_rows_of_violinplots(
self, ax, _matrix, colormap_array, _color_df, x_spacer_size, y_spacer_size
self,
ax,
_matrix,
colormap_array,
_color_df,
x_spacer_size,
y_spacer_size,
x_axis_order,
):
import seaborn as sns # Slow import, only import if called

Expand All @@ -460,11 +485,6 @@ def _make_rows_of_violinplots(
else:
row_colors = [None] * _color_df.shape[0]

# All columns should have a unique name, yet, frequently
# gene names are repeated in self.var_names, otherwise the
# violin plot will not distinguish those genes
_matrix.columns = [f"{x}_{idx}" for idx, x in enumerate(_matrix.columns)]

# transform the dataframe into a dataframe having three columns:
# the categories name (from groupby),
# the gene name
Expand Down Expand Up @@ -543,9 +563,10 @@ def _make_rows_of_violinplots(
hue=None if palette_colors is None else x,
palette=palette_colors,
color=row_colors[idx],
order=x_axis_order,
hue_order=x_axis_order,
**self.kwds,
)

if self.stripplot:
row_ax = sns.stripplot(
x=x,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,30 @@ def test_stacked_violin_obj(image_comparer, plt):
save_and_compare_images("stacked_violin_return_fig")


# checking for https://github.com/scverse/scanpy/issues/3152
def test_stacked_violin_swap_axes_match(image_comparer):
    """Compare a swapped-axes stacked violin plot against its reference image.

    Runs ``sc.tl.rank_genes_groups`` on the ``pbmc68k_reduced`` dataset grouped
    by ``bulk_labels`` (Wilcoxon method, results stored under ``"wilcoxon"``),
    draws ``sc.pl.rank_genes_groups_stacked_violin`` with ``swap_axes=True``,
    and checks the shown figure against the stored image
    ``stacked_violin_swap_axes_pbmc68k_reduced``.
    """
    adata = pbmc68k_reduced()
    compare_to_reference = partial(image_comparer, ROOT, tol=10)

    # Rank genes per group; the plot below reads the results via `key`.
    ranking_options = dict(
        method="wilcoxon",
        tie_correct=True,
        pts=True,
        key_added="wilcoxon",
    )
    sc.tl.rank_genes_groups(adata, "bulk_labels", **ranking_options)

    # `return_fig=True` hands back the plot object instead of drawing it,
    # so it has to be shown explicitly before the image comparison.
    violin_plot = sc.pl.rank_genes_groups_stacked_violin(
        adata,
        n_genes=2,
        key="wilcoxon",
        groupby="bulk_labels",
        swap_axes=True,
        return_fig=True,
    )
    violin_plot.show()

    compare_to_reference("stacked_violin_swap_axes_pbmc68k_reduced")


def test_tracksplot(image_comparer):
save_and_compare_images = partial(image_comparer, ROOT, tol=15)

Expand Down
Loading