Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve docs docstrings #134

Merged
merged 6 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions src/scportrait/pipeline/_utils/sdata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xarray
from alphabase.io import tempmmap
from spatialdata import SpatialData
from spatialdata.models import PointsModel, TableModel
from spatialdata.models import Image2DModel, PointsModel, TableModel
from spatialdata.transformations.transformations import Identity

from scportrait.pipeline._base import Logable
Expand All @@ -19,7 +19,8 @@
get_chunk_size,
)

ChunkSize: TypeAlias = tuple[int, int]
ChunkSize2D: TypeAlias = tuple[int, int]
ChunkSize3D: TypeAlias = tuple[int, int, int]
ObjectType: TypeAlias = Literal["images", "labels", "points", "tables"]


Expand Down Expand Up @@ -143,6 +144,58 @@ def _get_input_image(self, sdata: SpatialData) -> xarray.DataArray:
return input_image

## write elements to sdata object
def _write_image_sdata(
self,
image,
image_name: str,
channel_names: list[str] = None,
scale_factors: list[int] = None,
chunks: ChunkSize3D = (1, 1000, 1000),
overwrite=False,
):
"""
Write the supplied image to the spatialdata object.

Args:
image (dask.array): Image to be written to the spatialdata object.
image_name (str): Name of the image to be written to the spatialdata object.
channel_names list[str]: List of channel names for the image. Default is None.
scale_factors list[int]: List of scale factors for the image. Default is [2, 4, 8]. This will load the image at 4 different resolutions to allow for fluid visualization.
chunks (tuple): Chunk size for the image. Default is (1, 1000, 1000).
overwrite (bool): Whether to overwrite existing data. Default is False.
"""

if scale_factors is None:
scale_factors = [2, 4, 8]
if scale_factors is None:
scale_factors = [2, 4, 8]

_sdata = self._read_sdata()

if channel_names is None:
channel_names = [f"channel_{i}" for i in range(image.shape[0])]

# transform to spatialdata image model
transform_original = Identity()
image = Image2DModel.parse(
image,
dims=["c", "y", "x"],
chunks=chunks,
c_coords=channel_names,
scale_factors=scale_factors,
transformations={"global": transform_original},
rgb=False,
)

if overwrite:
self._force_delete_object(_sdata, image_name, "images")

_sdata.images[image_name] = image
_sdata.write_element(image_name, overwrite=True)

self.log(f"Image {image_name} written to sdata object.")
self._check_sdata_status()

def _write_segmentation_object_sdata(
self,
segmentation_object: spLabels2DModel,
Expand Down Expand Up @@ -177,7 +230,7 @@ def _write_segmentation_sdata(
segmentation: xarray.DataArray | np.ndarray,
segmentation_label: str,
classes: set[str] | None = None,
chunks: ChunkSize = (1000, 1000),
chunks: ChunkSize2D = (1000, 1000),
overwrite: bool = False,
) -> None:
"""Write segmentation data to SpatialData.
Expand Down Expand Up @@ -268,6 +321,7 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None
centroids_object = self._get_centers(_sdata, segmentation_label)
self._write_points_object_sdata(centroids_object, self.centers_name, overwrite=overwrite)

## load elements from sdata to a memory mapped array
def _load_input_image_to_memmap(
self, tmp_dir_abs_path: str | Path, image: np.typing.NDArray[Any] | None = None
) -> str:
Expand Down
6 changes: 4 additions & 2 deletions src/scportrait/pipeline/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,8 +842,10 @@ def process(self, partial=False, n_cells=None, seed=42):
self.log("Loading input images to memory mapped arrays...")
start_data_transfer = timeit.default_timer()

self.path_seg_masks = self.project._load_seg_to_memmap(seg_name=self.masks, tmp_dir_abs_path=self._tmp_dir_path)
self.path_image_data = self.project._load_input_image_to_memmap(tmp_dir_abs_path=self._tmp_dir_path)
self.path_seg_masks = self.filehandler._load_seg_to_memmap(
seg_name=self.masks, tmp_dir_abs_path=self._tmp_dir_path
)
self.path_image_data = self.filehandler._load_input_image_to_memmap(tmp_dir_abs_path=self._tmp_dir_path)

stop_data_transfer = timeit.default_timer()
time_data_transfer = stop_data_transfer - start_data_transfer
Expand Down
Loading
Loading