Skip to content

Commit

Permalink
Color circles and polygons with annotations (#73)
Browse files Browse the repository at this point in the history
* added support for circles

* added affine transformations

* minor fix

* cleanup

* added color circles by obs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added polygons

* pre-commit fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup

* cleanup

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rahulbshrestha and pre-commit-ci[bot] authored May 23, 2023
1 parent 51eccdf commit 306e9a2
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 5 deletions.
12 changes: 12 additions & 0 deletions examples/spatialdata_visium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# In this script, we are reading the Visium dataset into a SpatialData object and visualising
# it with the Interactive class from napari_spatialdata.

# The dataset can be downloaded from https://spatialdata.scverse.org/en/latest/tutorials/notebooks/datasets/README.html

from napari_spatialdata import Interactive
from spatialdata import SpatialData

if __name__ == "__main__":
sdata = SpatialData.read("../data/visium/data.zarr") # Change this path!
i = Interactive(sdata)
i.run()
4 changes: 4 additions & 0 deletions src/napari_spatialdata/_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,15 @@ def _add_circles(self, key: str) -> None:
self._sdata.table.obs[self._sdata.table.uns["spatialdata_attrs"]["region_key"]] == key
],
"shapes_key": self._sdata.table.uns["spatialdata_attrs"]["region_key"],
"shapes_type": "circles",
},
)

def _add_polygons(self, key: str) -> None:
polygons = []
df = self._sdata.shapes[key]
affine = _get_transform(self._sdata.shapes[key], self.coordinate_system_widget._system)

# when mulitpolygons are present, we select the largest ones
if "MultiPolygon" in np.unique(df.geometry.type):
logger.info("Multipolygons are present in the data. Only the largest polygon per cell is retained.")
Expand Down Expand Up @@ -133,6 +135,7 @@ def _add_polygons(self, key: str) -> None:
self._sdata.table.obs[self._sdata.table.uns["spatialdata_attrs"]["region_key"]] == key
],
"shapes_key": self._sdata.table.uns["spatialdata_attrs"]["region_key"],
"shapes_type": "polygons",
},
)

Expand Down Expand Up @@ -185,6 +188,7 @@ def _add_points(self, key: str) -> None:
logger.info("Subsampling points because the number of points exceeds the currently supported 100 000.")
gen = np.random.default_rng()
subsample = gen.choice(len(points), size=100000, replace=False)

self._viewer.add_points(
points[["y", "x"]].values[subsample],
name=key,
Expand Down
12 changes: 11 additions & 1 deletion src/napari_spatialdata/_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,24 @@ def _select_layer(self) -> None:
# if layer is not None and "adata" in layer.metadata:
self.model.adata = layer.metadata["adata"]

self.model.coordinates = np.insert(self.model.adata.obsm[Key.obsm.spatial][:, ::-1][:, :2], 0, values=0, axis=1)
if self.model.adata.shape == (0, 0):
return

if "spatial" in self.model.adata.obsm:
self.model.coordinates = np.insert(
self.model.adata.obsm[Key.obsm.spatial][:, ::-1][:, :2], 0, values=0, axis=1
)

if "points" in layer.metadata:
# TODO: Check if this can be removed
self.model.points_coordinates = layer.metadata["points"].X
self.model.points_var = layer.metadata["points"].obs["gene"]
self.model.point_diameter = np.array([0.0] + [layer.metadata["point_diameter"]] * 2) * self.model.scale

self.model.spot_diameter = np.array([0.0, 10.0, 10.0])
self.model.labels_key = layer.metadata["labels_key"] if isinstance(layer, Labels) else None
self.model.system_name = self.model.layer.name

if "colormap" in layer.metadata:
self.model.cmap = layer.metadata["colormap"]
if hasattr(
Expand Down
29 changes: 25 additions & 4 deletions src/napari_spatialdata/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import pandas as pd
from loguru import logger
from napari.layers import Labels, Layer, Points
from napari.layers import Labels, Layer, Points, Shapes
from napari.viewer import Viewer
from qtpy import QtCore, QtWidgets
from qtpy.QtCore import Qt, Signal
Expand Down Expand Up @@ -108,8 +108,8 @@ def __init__(self, viewer: Viewer | None, model: ImageModel, attr: str, **kwargs
self._model = model

self._attr = attr
self._getter = getattr(self.model, f"get_{attr}")

self._getter = getattr(self.model, f"get_{attr}")
self.layerChanged.connect(self._onChange)
self._onChange()

Expand All @@ -125,6 +125,7 @@ def _onAction(self, items: Iterable[str]) -> None:
except Exception as e: # noqa: BLE001
logger.error(e)
continue

if vec.ndim == 2:
self.viewer.add_points(
vec,
Expand All @@ -136,7 +137,8 @@ def _onAction(self, items: Iterable[str]) -> None:
)
else:
properties = self._get_points_properties(vec, key=item, layer=self.model.layer)
if isinstance(self.model.layer, Points):

if isinstance(self.model.layer, (Points, Shapes)):
self.model.layer.name = (
"" if self.model.system_name is None else self.model.system_name + ":"
) + item
Expand Down Expand Up @@ -201,11 +203,26 @@ def _(self, vec: NDArrayA, **kwargs: Any) -> dict[str, Any]:
norm_vec = _min_max_norm(vec)
color_vec = cmap(norm_vec)
if layer is not None and isinstance(layer, Labels):
cmap = plt.get_cmap(self.model.cmap)
norm_vec = _min_max_norm(vec)
color_vec = cmap(norm_vec)

return {
"color": dict(zip(self.model.adata.obs[self.model.labels_key].values, color_vec)),
"properties": {"value": vec},
"text": None,
}

if layer is not None and isinstance(layer, Shapes):
cmap = plt.get_cmap(self.model.cmap)
norm_vec = _min_max_norm(vec)
color_vec = cmap(norm_vec)

return {
"text": None,
"face_color": color_vec,
}

return {
"text": None,
"face_color": color_vec,
Expand All @@ -217,8 +234,12 @@ def _(self, vec: pd.Series, key: str, layer: Layer) -> dict[str, Any]:
face_color = _get_categorical(
self.model.adata, key=key, palette=self.model.palette, colordict=colortypes, vec=vec
)

if layer is not None and isinstance(layer, Labels):
return {"color": dict(zip(self.model.adata.obs[self.model.labels_key].values, face_color))}
return {"color": dict(zip(self.model.adata.obs[self.model.labels_key].values, face_color)), "text": None}

if layer is not None and isinstance(layer, Shapes):
return {"face_color": face_color, "metadata": None, "text": None}

cluster_labels = _position_cluster_labels(self.model.coordinates, vec)
return {
Expand Down

0 comments on commit 306e9a2

Please sign in to comment.