Visium HD #211

Open · wants to merge 3 commits into main
94 changes: 88 additions & 6 deletions src/spatialdata_io/readers/visium_hd.py
@@ -17,9 +17,11 @@
from geopandas import GeoDataFrame
from imageio import imread as imread2
from multiscale_spatial_image import MultiscaleSpatialImage
from numpy.random import default_rng
from skimage.transform import estimate_transform
from spatial_image import SpatialImage
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, ShapesModel, TableModel
from spatialdata.models import Image2DModel, Labels2DModel, ShapesModel, TableModel
from spatialdata.transformations import (
Affine,
Identity,
@@ -40,6 +42,7 @@ def visium_hd(
filtered_counts_file: bool = True,
bin_size: int | list[int] | None = None,
bins_as_squares: bool = True,
annotate_table_by_labels: bool = False,
fullres_image_file: str | Path | None = None,
load_all_images: bool = False,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
@@ -68,6 +71,8 @@
bins_as_squares
If `True`, the bins are represented as squares. If `False`, the bins are represented as circles. For a correct
visualization one should use squares.
annotate_table_by_labels
If `True`, the table is annotated by the labels layer representing the bins; if `False`, the table is annotated by the shapes layer.
fullres_image_file
Path to the full-resolution image. By default the image is searched in the ``{vx.MICROSCOPE_IMAGE!r}``
directory.
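As a usage sketch of the new flag (the path, dataset id, and bin size below are placeholders, and the reader is assumed to be exposed as spatialdata_io.visium_hd as in the released package):

from spatialdata_io import visium_hd

# hypothetical Space Ranger output folder and dataset id
sdata = visium_hd(
    "path/to/visium_hd_output",
    dataset_id="my_sample",
    bin_size=16,
    annotate_table_by_labels=True,
)
print(list(sdata.labels))  # one f"{dataset_id}_{bin_size_str}_labels" element per requested bin size
print(list(sdata.shapes))  # the square (or circle) bin shapes are still present
# with annotate_table_by_labels=False (the default) the tables annotate the shapes layer instead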
@@ -90,6 +95,7 @@
tables = {}
shapes = {}
images: dict[str, Any] = {}
labels: dict[str, Any] = {}

if dataset_id is None:
dataset_id = _infer_dataset_id(path)
Expand Down Expand Up @@ -191,7 +197,17 @@ def _get_bins(path: Path) -> list[str]:
VisiumHDKeys.LOCATIONS_X,
]
)
# shift the instance key to start at 1 so that 0 remains free as the background value of the labels raster
assert isinstance(coords.index, pd.RangeIndex)
assert coords.index.start == 0
coords.index = coords.index + 1
dtype = _get_uint_dtype(coords.index.stop)

coords = coords.reset_index().rename(columns={"index": VisiumHDKeys.INSTANCE_KEY})
coords[VisiumHDKeys.INSTANCE_KEY] = coords[VisiumHDKeys.INSTANCE_KEY].astype(dtype)

coords.set_index(VisiumHDKeys.BARCODE, inplace=True, drop=True)

coords_filtered = coords.loc[adata.obs.index]
adata.obs = pd.merge(adata.obs, coords_filtered, how="left", left_index=True, right_index=True)
# compatibility to legacy squidpy
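The index shift above makes the instance key 1-based, leaving 0 free as the background value of the labels raster, and picks the smallest unsigned dtype that can hold all keys. A toy version of the same steps (the column names here are illustrative, not the reader's actual keys):

import numpy as np
import pandas as pd

coords = pd.DataFrame({"barcode": ["AAA-1", "AAC-1", "AAG-1"], "x": [10, 20, 30], "y": [5, 5, 6]})
assert isinstance(coords.index, pd.RangeIndex) and coords.index.start == 0

coords.index = coords.index + 1  # 1-based: 0 stays reserved for the background label
dtype = "uint16" if coords.index.stop <= np.iinfo(np.uint16).max else "uint32"

coords = coords.reset_index().rename(columns={"index": "instance_id"})
coords["instance_id"] = coords["instance_id"].astype(dtype)
coords.set_index("barcode", inplace=True, drop=True)
print(coords["instance_id"].tolist())  # [1, 2, 3]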
@@ -204,7 +220,6 @@ def _get_bins(path: Path) -> list[str]:
],
inplace=True,
)
adata.obs[VisiumHDKeys.INSTANCE_KEY] = np.arange(len(adata))

# scaling
transform_original = Identity()
Expand Down Expand Up @@ -249,13 +264,64 @@ def _get_bins(path: Path) -> list[str]:
GeoDataFrame(geometry=squares_series), transformations=transformations
)

# parse table
adata.obs[VisiumHDKeys.REGION_KEY] = shapes_name
# add labels layer (rasterized bins).
labels_name = f"{dataset_id}_{bin_size_str}_labels"

min_row, min_col = adata.obs[VisiumHDKeys.ARRAY_ROW].min(), adata.obs[VisiumHDKeys.ARRAY_COL].min()
n_rows, n_cols = (
adata.obs[VisiumHDKeys.ARRAY_ROW].max() - min_row + 1,
adata.obs[VisiumHDKeys.ARRAY_COL].max() - min_col + 1,
)
y = (adata.obs[VisiumHDKeys.ARRAY_ROW] - min_row).values
x = (adata.obs[VisiumHDKeys.ARRAY_COL] - min_col).values

labels_element = np.zeros((n_rows, n_cols), dtype=dtype)

# fill the raster: each pixel holds the instance key of the bin located at that (row, col) position
labels_element[y, x] = adata.obs[VisiumHDKeys.INSTANCE_KEY].values

# estimate the affine transformation mapping this raster back to the original full-resolution pixel coordinates
RNG = default_rng(0)

# get the transformation
if adata.n_obs < 6:
raise ValueError("At least 6 bins are needed to estimate the transformation.")

# subsample up to 100 bins without replacement; an affine fit does not need more point pairs
random_indices = RNG.choice(adata.n_obs, min(100, adata.n_obs), replace=False)

sub_adata_for_transform = adata[random_indices]

src = np.stack(
[
sub_adata_for_transform.obs[VisiumHDKeys.ARRAY_COL] - min_col,
sub_adata_for_transform.obs[VisiumHDKeys.ARRAY_ROW] - min_row,
],
axis=1,
)
dst = sub_adata_for_transform.obsm["spatial"] # this is x, y

to_bins = Sequence(
[
Affine(
estimate_transform(ttype="affine", src=src, dst=dst).params,
input_axes=("x", "y"),
output_axes=("x", "y"),
)
]
)

labels_transformations = {cs: to_bins.compose_with(t) for cs, t in transformations.items()}
labels_element = Labels2DModel.parse(
data=labels_element, dims=("y", "x"), transformations=labels_transformations
)
labels[labels_name] = labels_element

adata.obs[VisiumHDKeys.REGION_KEY] = labels_name if annotate_table_by_labels else shapes_name
adata.obs[VisiumHDKeys.REGION_KEY] = adata.obs[VisiumHDKeys.REGION_KEY].astype("category")

tables[bin_size_str] = TableModel.parse(
adata,
region=shapes_name,
region=labels_name if annotate_table_by_labels else shapes_name,
region_key=str(VisiumHDKeys.REGION_KEY),
instance_key=str(VisiumHDKeys.INSTANCE_KEY),
)
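The block above rasterizes the bins into a labels image indexed by (array_row, array_col) and then estimates an affine map from that raster back to the spot coordinates in pixel space. A self-contained sketch of the same recipe on synthetic bins (the grid size, spacing, and offset are made up; in the reader they come from the Space Ranger output):

import numpy as np
from skimage.transform import estimate_transform

rng = np.random.default_rng(0)

# synthetic 50x40 bin grid with a known scale and offset in pixel space
rows, cols = np.meshgrid(np.arange(50), np.arange(40), indexing="ij")
rows, cols = rows.ravel(), cols.ravel()
spatial = np.stack([cols * 4.2 + 100.0, rows * 4.2 + 250.0], axis=1)  # (x, y) in pixels

# labels raster: pixel (row, col) holds the 1-based bin index, 0 is background
labels = np.zeros((50, 40), dtype="uint16")
labels[rows, cols] = np.arange(1, rows.size + 1, dtype="uint16")

# estimate raster (col, row) -> pixel (x, y) from a random subset of bins
idx = rng.choice(rows.size, 100, replace=False)
src = np.stack([cols[idx], rows[idx]], axis=1)
dst = spatial[idx]
affine = estimate_transform(ttype="affine", src=src, dst=dst)

# round-trip check: the affine maps every bin's grid position onto its pixel position
pred = affine(np.stack([cols, rows], axis=1))
assert np.allclose(pred, spatial)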
@@ -349,7 +415,7 @@ def _get_bins(path: Path) -> list[str]:
affine1 = transform_matrices["spot_colrow_to_microscope_colrow"]
set_transformation(image, Sequence([affine0, affine1]), "global")

return SpatialData(tables=tables, images=images, shapes=shapes)
return SpatialData(tables=tables, images=images, shapes=shapes, labels=labels)


def _infer_dataset_id(path: Path) -> str:
Expand Down Expand Up @@ -424,3 +490,19 @@ def _get_transform_matrices(metadata: dict[str, Any], hd_layout: dict[str, Any])
transform_matrices[key.value] = _get_affine(data)

return transform_matrices


def _get_uint_dtype(value: int) -> str:
max_uint64 = np.iinfo(np.uint64).max
max_uint32 = np.iinfo(np.uint32).max
max_uint16 = np.iinfo(np.uint16).max

if max_uint16 >= value:
dtype = "uint16"
elif max_uint32 >= value:
dtype = "uint32"
elif max_uint64 >= value:
dtype = "uint64"
else:
raise ValueError(f"Maximum bin index is {value}. Values higher than {max_uint64} are not supported.")
return dtype
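A quick sanity check of the helper at the type boundaries (assuming the private function is imported directly from the module, e.g. in a test):

from spatialdata_io.readers.visium_hd import _get_uint_dtype

assert _get_uint_dtype(65_535) == "uint16"  # fits in uint16
assert _get_uint_dtype(65_536) == "uint32"  # first value needing uint32
assert _get_uint_dtype(2**32) == "uint64"   # first value needing uint64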