Skip to content

Commit

Permalink
add visualization function for overlaying segmentation masks
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed May 27, 2024
1 parent 886083c commit ace345b
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions src/sparcscore/utils/vis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import h5py
import matplotlib.pyplot as plt
import os


def segmentation_mask(
project,
mask_channel=0,
image_channel=0,
selection=None,
cmap_image="Greys_r",
cmap_masks="jet",
alpha=0.5,
):
"""
Visualize the segmentation mask overlayed with a channel of the input image.
Parameters
----------
project : sparcspy.pipeline.project.Project
instance of a sparcspy project.
mask_channel : int, optional
The index of the channel to use for the segmentation mask (default: 0).
image_channel : int, optional
The index of the channel to use for the image (default: 0).
selection : tuple(slice, slice), optional
The selection coordinates for a specific region of interest (default: None).
cmap_image : str, optional
The colormap to use for the input image (default: "Greys_r").
cmap_masks : str, optional
The colormap to use for the segmentation mask (default: "jet").
alpha : float, optional
The transparency level of the segmentation mask (default: 0.5).
Returns
-------
fig : object
The generated figure object.
"""
segmentation_file = os.path.join(
project.seg_directory, project.segmentation_f.DEFAULT_OUTPUT_FILE
)

with h5py.File(segmentation_file, "r") as hf:
segmentation = hf.get(project.segmentation_f.DEFAULT_MASK_NAME)
channels = hf.get(project.segmentation_f.DEFAULT_CHANNELS_NAME)

if selection is None:
segmentation = segmentation[mask_channel, :, :]
image = channels[image_channel, :, :]
else:
segmentation = segmentation[mask_channel, selection[0], selection[1]]
image = channels[image_channel, selection[0], selection[1]]

fig = plt.figure()
plt.imshow(image, cmap=cmap_image)
plt.imshow(segmentation, alpha=alpha, cmap=cmap_masks)
plt.axis("off")
return fig

0 comments on commit ace345b

Please sign in to comment.