Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interactive viewer #31

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"zarr",
"lmdb",
"kornia",
"mpl_interactions",
]

[[project.source]]
Expand Down
49 changes: 48 additions & 1 deletion src/nimbus_inference/example_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from typing import Union
import datasets
from alpineer.misc_utils import verify_in_list
import zipfile
import os
import requests

EXAMPLE_DATASET_REVISION: str = "main"

Expand Down Expand Up @@ -214,4 +217,48 @@ def get_example_dataset(dataset: str, save_dir: Union[str, pathlib.Path],
example_dataset.download_example_dataset()

# Move the dataset over to the save_dir from the user.
example_dataset.move_example_dataset(move_dir=save_dir)
example_dataset.move_example_dataset(move_dir=save_dir)


def download_and_unpack_gold_standard(save_dir: Union[str, pathlib.Path], overwrite_existing: bool = True):
"""
Downloads 'gold_standard_labelled.zip' from the Hugging Face dataset and unpacks it in the given folder
if the dataset is not already present there.

Args:
save_dir (Union[str, Path]): The path to save the dataset files in.
overwrite_existing (bool): The option to overwrite existing files. Defaults to True.
"""
url = "https://huggingface.co/datasets/JLrumberger/Pan-Multiplex-Gold-Standard/resolve/main/gold_standard_labelled.zip"
save_dir = pathlib.Path(save_dir)
zip_path = save_dir / "gold_standard_labelled.zip"

# Create the save directory if it doesn't exist
save_dir.mkdir(parents=True, exist_ok=True)

# Check if the dataset is already present
if zip_path.exists() and not overwrite_existing:
print(f"{zip_path} already exists. Skipping download.")
return

# Download the zip file
print(f"Downloading {url} to {zip_path}...")
response = requests.get(url, stream=True)
response.raise_for_status()

with open(zip_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)

print(f"Downloaded {zip_path}")

# Unpack the zip file
print(f"Unpacking {zip_path}...")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(save_dir)

print(f"Unpacked to {save_dir}")

# Optionally, remove the zip file after unpacking
os.remove(zip_path)
print(f"Removed {zip_path}")
15 changes: 12 additions & 3 deletions src/nimbus_inference/nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ def segmentation_naming_convention(fov_path):
Returns:
str: paths to segmentation fovs
"""
fov_name = os.path.basename(fov_path).replace(".ome.tiff", "")
return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff")

fov_name = os.path.basename(fov_path)
# remove suffix
fov_name = Path(fov_name).stem
# find all fnames which contain a superset of the fov_name
fnames = os.listdir(deepcell_output_dir)
# use re instead of glob
fnames = [os.path.join(deepcell_output_dir, f) for f in fnames if fov_name in f]
if len(fnames) == 0:
raise ValueError(f"No segmentation data found for fov {fov_name}")
if len(fnames) > 1:
raise ValueError(f"Multiple segmentation data found for fov {fov_name}")
return fnames[0]
return segmentation_naming_convention


Expand Down
57 changes: 56 additions & 1 deletion src/nimbus_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,4 +717,59 @@ def __getitem__(self, idx):
input_data = sample[:2]
groundtruth = sample[2:3]
inst_mask = sample[3:]
return input_data, groundtruth, inst_mask, self.keys[idx]
return input_data, groundtruth, inst_mask, self.keys[idx]


class InteractiveDataset(object):
"""Dataset for the InteractiveViewer class. This dataset class stores multiple objects of type
MultiplexedDataset, and allows to select a dataset and use its method for reading fovs and
channels from it.

Args:
datasets (dict): dictionary with dataset names as keys and dataset objects as values
"""
def __init__(self, datasets: dict):
self.datasets = datasets
self.dataset_names = list(datasets.keys())
self.dataset = None

def set_dataset(self, dataset_name: str):
"""Set the active dataset

Args:
dataset_name (str): name of the dataset
"""
self.dataset = self.datasets[dataset_name]
return self.dataset

def get_channel(self, fov: str, channel: str):
"""Get a channel from a fov

Args:
fov (str): name of a fov
channel (str): channel name
Returns:
np.array: channel image
"""
return self.dataset.get_channel(fov, channel)

def get_segmentation(self, fov: str):
"""Get the instance mask for a fov

Args:
fov (str): name of a fov
Returns:
np.array: instance mask
"""
return self.dataset.get_segmentation(fov)

def get_groundtruth(self, fov: str, channel: str):
"""Get the groundtruth for a fov / channel combination

Args:
fov (str): name of a fov
channel (str): channel name
Returns:
np.array: groundtruth activity mask (0: negative, 1: positive, 2: ambiguous)
"""
return self.dataset.get_groundtruth(fov, channel)
224 changes: 223 additions & 1 deletion src/nimbus_inference/viewer_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from natsort import natsorted
from skimage.segmentation import find_boundaries
from skimage.transform import rescale
from nimbus_inference.utils import MultiplexDataset
from nimbus_inference.utils import MultiplexDataset, InteractiveDataset
from mpl_interactions import panhandler
import matplotlib.pyplot as plt

class NimbusViewer(object):
"""Viewer for Nimbus application.
Expand Down Expand Up @@ -277,3 +279,223 @@ def display(self):
self.select_fov(None)
self.layout()
self.update_composite()


class InteractiveImageDuo(widgets.Image):
"""Interactive image viewer for Nimbus application.

Args:
figsize (tuple): Size of figure.
title_left (str): Title of left image.
title_right (str): Title of right image.
"""
def __init__(self, figsize=(10, 5), title_left='Multiplexed image', title_right='Groundtruth'):
super().__init__()
self.title_left = title_left
self.title_right = title_right

# Initialize matplotlib figure
with plt.ioff():
self.fig, self.ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=figsize)

# uncomment the following lines to enable zooming via scroll wheel
# self.zoom_handler = self.custom_zoom_factory(self.ax[0])
# self.pan_handler = panhandler(self.fig)

# Display the figure canvas
display(self.fig.canvas)

def custom_zoom_factory(self, ax, base_scale=1.1):
"""Enable zooming via scroll wheel on matplotlib axes.

Args:
ax (matplotlib ax): ax to enable zooming on.
base_scale (float): Scale factor for zooming.
"""
def zoom(event):
cur_xlim = ax.get_xlim()
cur_ylim = ax.get_ylim()
xdata = event.xdata # get event x location
ydata = event.ydata # get event y location

if event.button == 'up':
scale_factor = 1 / base_scale
elif event.button == 'down':
scale_factor = base_scale
else:
scale_factor = 1
print(event.button)

new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

relx = (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])
rely = (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])

ax.set_xlim([xdata - new_width * (1 - relx), xdata + new_width * (relx)])
ax.set_ylim([ydata - new_height * (1 - rely), ydata + new_height * (rely)])
ax.figure.canvas.draw_idle()

fig = ax.get_figure() # get the figure of interest
fig.canvas.mpl_connect('scroll_event', zoom)

return zoom

def update_left_image(self, image):
"""Update the left image displayed in the viewer.

Args:
image (np.array): Image to display.
"""
self.ax[0].imshow(image)
self.ax[0].title.set_text(self.title_left)
self.ax[0].set_xticks([])
self.ax[0].set_yticks([])
self.fig.canvas.draw_idle()

def update_right_image(self, image):
"""Update the right image displayed in the viewer.

Args:
image (np.array): Image to display.
"""
self.ax[1].imshow(image, vmin=0, vmax=255)
self.ax[1].title.set_text(self.title_right)
self.ax[1].set_xticks([])
self.ax[1].set_yticks([])
self.fig.canvas.draw_idle()


class NimbusInteractiveGTViewer(NimbusViewer):
"""Interactive viewer for Nimbus application that shows input data and ground truth
side by side.

Args:
dataset (MultiplexDataset): dataset object
output_dir (str): Path to directory containing output of Nimbus application.
figsize (tuple): Size of figure.
"""
def __init__(
self, datasets: InteractiveDataset, output_dir, figsize=(20, 10)
):
super().__init__(
datasets.datasets[datasets.dataset_names[0]], output_dir
)
self.image = InteractiveImageDuo(figsize=figsize)
self.dataset = datasets.datasets[datasets.dataset_names[0]]
self.datasets = datasets
self.dataset_select = widgets.Select(
options=datasets.dataset_names,
description='Dataset:',
disabled=False
)
self.dataset_select.observe(self.select_dataset, names='value')

def layout(self):
"""Creates layout for viewer."""
channel_selectors = widgets.HBox([
self.red_select,
self.green_select,
self.blue_select
])
layout = widgets.HBox([
# widgets.HBox([
self.dataset_select,
self.fov_select,
channel_selectors,
self.overlay_checkbox,
self.update_button
# ]),
])
display(layout)

def select_dataset(self, change):
"""Selects dataset to display.

Args:
change (dict): Change dictionary from ipywidgets.
"""
self.dataset = self.datasets.set_dataset(change['new'])
self.fov_names = natsorted(copy(self.dataset.fovs))
self.fov_select.options = self.fov_names
self.select_fov(None)


def update_img(self, image_fn, composite_image):
"""Updates image in viewer by saving it as png and loading it with the viewer widget.

Args:
ax (matplotlib ax): ax to update.
composite_image (np.array): Composite image to display.
"""
if composite_image.shape[0] > self.max_resolution[0] or composite_image.shape[1] > self.max_resolution[1]:
scale = float(np.max(self.max_resolution)/np.max(composite_image.shape))
composite_image = rescale(composite_image, (scale, scale, 1), preserve_range=True)
composite_image = composite_image.astype(np.uint8)
image_fn(composite_image)

def update_composite(self):
"""Updates composite image in viewer."""
path_dict = {
"red": None,
"green": None,
"blue": None
}
in_path_dict = copy(path_dict)
if self.red_select.value:
path_dict["red"] = os.path.join(
self.output_dir, self.fov_select.value, self.red_select.value + self.suffix
)
in_path_dict["red"] = {"fov": self.fov_select.value, "channel": self.red_select.value}
if self.green_select.value:
path_dict["green"] = os.path.join(
self.output_dir, self.fov_select.value, self.green_select.value + self.suffix
)
in_path_dict["green"] = {
"fov": self.fov_select.value, "channel": self.green_select.value
}
if self.blue_select.value:
path_dict["blue"] = os.path.join(
self.output_dir, self.fov_select.value, self.blue_select.value + self.suffix
)
in_path_dict["blue"] = {
"fov": self.fov_select.value, "channel": self.blue_select.value
}
non_none = [p for p in path_dict.values() if p]
if not non_none:
return

in_composite_image = self.create_composite_from_dataset(in_path_dict)
in_composite_image, seg_boundaries = self.overlay(
in_composite_image, add_boundaries=self.overlay_checkbox.value
)
in_composite_image = in_composite_image / np.quantile(
in_composite_image, 0.999, axis=(0,1)
)
in_composite_image = np.clip(in_composite_image*255, 0, 255).astype(np.uint8)
if seg_boundaries is not None:
in_composite_image[seg_boundaries] = [127, 127, 127]

img = in_composite_image[...,0].astype(np.float32) * 0
right_images = []
for c, s in {'red': self.red_select.value,
'green': self.green_select.value,
'blue': self.blue_select.value}.items():
if s:
composite_image = self.dataset.get_groundtruth(
self.fov_select.value, s
)
else:
composite_image = img
composite_image = np.squeeze(composite_image).astype(np.float32)
right_images.append(composite_image)
right_images = np.stack(right_images, axis=-1)
right_images = np.clip(right_images, 0, 2)
right_images[right_images == 2] = 0.3
right_images[seg_boundaries] = 0.0
right_images *= 255.0
right_images = right_images.astype(np.uint8)

# update image viewers
self.update_img(self.image.update_left_image, in_composite_image)
self.update_img(self.image.update_right_image, right_images)
Loading
Loading