Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #2734 on branch 1.9.x (Make _validate_palette work with arrays) #2735

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/1.9.7.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

```{rubric} Bug fixes
```
- Fix handling of numpy array palettes (e.g. after write-read cycle) {pr}`2734` {smaller}`P Angerer`
7 changes: 4 additions & 3 deletions scanpy/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def default_palette(
return palette


def _validate_palette(adata, key):
def _validate_palette(adata: anndata.AnnData, key: str) -> None:
"""
checks if the list of colors in adata.uns[f'{key}_colors'] is valid
and updates the color list in adata.uns[f'{key}_colors'] if needed.
Expand Down Expand Up @@ -354,8 +354,9 @@ def _validate_palette(adata, key):
break
_palette.append(color)
# Don't modify if nothing changed
if _palette is not None and list(_palette) != list(adata.uns[color_key]):
adata.uns[color_key] = _palette
if _palette is None or np.equal(_palette, adata.uns[color_key]).all():
return
adata.uns[color_key] = _palette


def _set_colors_for_categorical_obs(
Expand Down
28 changes: 28 additions & 0 deletions scanpy/tests/test_plotting_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import cast
import numpy as np
import pytest

from anndata import AnnData
from matplotlib import colormaps
from matplotlib.colors import ListedColormap

from scanpy.plotting._utils import _validate_palette


viridis = cast(ListedColormap, colormaps["viridis"])


@pytest.mark.parametrize(
"palette",
[
pytest.param(viridis.colors, id="viridis"),
pytest.param(["b", "#cccccc", "r", "yellow", "lightblue"], id="named"),
pytest.param([(1, 0, 0, 1), (0, 0, 1, 1)], id="rgba"),
],
)
@pytest.mark.parametrize("typ", [np.asarray, list])
def test_validate_palette_no_mod(palette, typ):
palette = typ(palette)
adata = AnnData(uns=dict(test_colors=palette))
_validate_palette(adata, "test")
assert palette is adata.uns["test_colors"], "Palette should not be modified"
Loading