Skip to content

Commit

Permalink
Port napari widgets from magicgui to qt (#9)
Browse files Browse the repository at this point in the history
* create simple dropdown with qt

* connected qt mask widget

* finished first draft of mask qt widget

* made a chest of widgets

* Update collapsible widget API to latest brainglobe-utils (#8)

Co-authored-by: IgorTatarnikov <[email protected]>

* finish porting all existing widgets to qt

* merged utils and preproc modules

---------

Co-authored-by: IgorTatarnikov <[email protected]>
  • Loading branch information
niksirbi and IgorTatarnikov authored Jan 10, 2024
1 parent 5016bbf commit 27c03c4
Show file tree
Hide file tree
Showing 8 changed files with 480 additions and 376 deletions.
10 changes: 2 additions & 8 deletions brainglobe_template_builder/napari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,9 @@
__version__ = "unknown"

from brainglobe_template_builder.napari._reader import napari_get_reader
from brainglobe_template_builder.napari._widget import (
mask_widget,
points_widget,
transform_widget,
)
from brainglobe_template_builder.napari._widget import PreprocWidgets

__all__ = (
"napari_get_reader",
"mask_widget",
"points_widget",
"transform_widget",
"PreprocWidgets",
)
239 changes: 16 additions & 223 deletions brainglobe_template_builder/napari/_widget.py
Original file line number Diff line number Diff line change
@@ -1,229 +1,22 @@
"""
This module is an example of a barebones QWidget plugin for napari
from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer
from napari.viewer import Viewer

It implements the Widget specification.
see: https://napari.org/stable/plugins/guides.html?#widgets
from brainglobe_template_builder.napari.mask_widget import CreateMask
from brainglobe_template_builder.napari.midline_widget import FindMidline

Replace code below according to your needs.
"""
from typing import Literal

import numpy as np
from magicgui import magic_factory
from napari.layers import Image, Labels, Points
from napari_plugin_engine import napari_hook_implementation
class PreprocWidgets(CollapsibleWidgetContainer):
def __init__(self, napari_viewer: Viewer, parent=None):
super().__init__()

from brainglobe_template_builder.utils import (
extract_largest_object,
fit_plane_to_points,
get_midline_points,
threshold_image,
)

# 9 colors taken from ColorBrewer2.org Set3 palette
POINTS_COLOR_CYCLE = [
"#8dd3c7",
"#ffffb3",
"#bebada",
"#fb8072",
"#80b1d3",
"#fdb462",
"#b3de69",
"#fccde5",
"#d9d9d9",
]


@magic_factory(
call_button="generate mask",
gauss_sigma={"widget_type": "SpinBox", "max": 20, "min": 0},
threshold_method={"choices": ["triangle", "otsu", "isodata"]},
erosion_size={"widget_type": "SpinBox", "max": 20, "min": 0},
)
def mask_widget(
image: Image,
gauss_sigma: float = 3,
threshold_method: Literal["triangle", "otsu", "isodata"] = "triangle",
erosion_size: int = 5,
) -> Labels:
"""Threshold image and create a mask for the largest object.
The mask is generated by applying a Gaussian filter to the image,
thresholding the smoothed image, keeping only the largest object, and
eroding the resulting mask.
Parameters
----------
image : Image
A napari image layer to threshold.
gauss_sigma : float
Standard deviation for Gaussian kernel (in pixels) to smooth image
before thresholding. Set to 0 to skip smoothing.
threshold_method : str
Thresholding method to use. One of 'triangle', 'otsu', and 'isodata'
(corresponding to methods from the skimage.filters module).
Defaults to 'triangle'.
erosion_size : int
Size of the erosion footprint (in pixels) to apply to the mask.
Set to 0 to skip erosion.
Returns
-------
napari.layers.Labels
A napari labels layer containing the mask.
"""

if image is not None:
assert isinstance(image, Image), "image must be a napari Image layer"
else:
print("Please select an image layer")
return None

from skimage import filters, morphology

# Apply gaussian filter to image
if gauss_sigma > 0:
data_smoothed = filters.gaussian(image.data, sigma=gauss_sigma)
else:
data_smoothed = image.data

# Threshold the (smoothed) image
binary = threshold_image(data_smoothed, method=threshold_method)

# Keep only the largest object in the binary image
mask = extract_largest_object(binary)

# Erode the mask
if erosion_size > 0:
mask = morphology.binary_erosion(
mask, footprint=np.ones((erosion_size,) * image.ndim)
self.add_widget(
CreateMask(napari_viewer, parent=self),
collapsible=True,
widget_title="Create mask",
)
return Labels(mask, opacity=0.5, name="mask")


@magic_factory(
call_button="Estimate midline points",
)
def points_widget(
mask: Labels,
) -> Points:
"""Create a points layer with 9 midline points.
Parameters
----------
mask : Labels
A napari labels layer to use as a reference for the points.
Returns
-------
napari.layers.Points
A napari points layer containing the midline points.
"""

# Estimate 9 midline points
points = get_midline_points(mask.data)

point_labels = np.arange(1, points.shape[0] + 1)

point_attrs = {
"properties": {"label": point_labels},
"face_color": "label",
"face_color_cycle": POINTS_COLOR_CYCLE,
"symbol": "cross",
"edge_width": 0,
"opacity": 0.6,
"size": 6,
"ndim": mask.ndim,
"name": "midline points",
}

# Make mask layer invisible
mask.visible = False

return Points(points, **point_attrs)


@magic_factory(
call_button="Align midline",
image={"label": "Image"},
points={"label": "Midline points"},
axis={"label": "Axis", "choices": ["x", "y", "z"]},
)
def transform_widget(
image: Image,
points: Points,
axis: Literal["x", "y", "z"] = "x",
) -> Image:
"""Transform image to align points with midline of the specified axis.
It first fits a plane to the points, then rigidly transforms the image
such that the fitted plane is aligned with the axis midline.
Parameters
----------
image : Image
A napari image layer to align.
points : Points
A napari points layer containing points.
axis : str
Axis to align the midline with. One of 'x', 'y', and 'z'.
Defaults to 'x'.
Returns
-------
napari.layers.Image
A napari image layer containing the transformed image.
"""

from scipy.ndimage import affine_transform
from scipy.spatial.transform import Rotation

points_data = points.data
normal_vector = fit_plane_to_points(points_data)
assert normal_vector.shape == (3,)

# Compute centroid of the midline points
centroid = np.mean(points_data, axis=0)

# Translation of the centroid to the origin
translation_to_origin = np.eye(4)
translation_to_origin[:3, 3] = -centroid

# Rotation to align normal vector with unit vector along the specified axis
axis_vec = np.zeros(3)
axis_index = {"z": 0, "y": 1, "x": 2}[axis] # axis order is zyx in napari
axis_vec[axis_index] = 1
rotation_to_axis = Rotation.align_vectors(
axis_vec.reshape(1, 3),
normal_vector.reshape(1, 3),
)[0].as_matrix()
assert rotation_to_axis.shape == (3, 3)
rotation_4x4 = np.eye(4)
rotation_4x4[:3, :3] = rotation_to_axis

# Translation back, so that the plane is in the middle of axis
translation_to_mid_axis = np.eye(4)
translation_to_mid_axis[axis_index, 3] = (
image.data.shape[axis_index] // 2 - centroid[axis_index]
)

# Combine transformations into a single 4x4 matrix
transformation_matrix = (
np.linalg.inv(translation_to_origin)
@ rotation_4x4
@ translation_to_origin
@ translation_to_mid_axis
)

# Apply the transformation to the image
transformed_image = affine_transform(
image,
transformation_matrix[:3, :3],
offset=transformation_matrix[:3, 3],
)
return Image(transformed_image, name="aligned image")


@napari_hook_implementation
def napari_experimental_provide_dock_widget():
return [mask_widget, points_widget, transform_widget]
self.add_widget(
FindMidline(napari_viewer, parent=self),
collapsible=True,
widget_title="Find midline",
)
63 changes: 63 additions & 0 deletions brainglobe_template_builder/napari/mask_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from napari.layers import Image
from napari.utils.notifications import show_info
from napari.viewer import Viewer
from qtpy.QtWidgets import (
QComboBox,
QFormLayout,
QPushButton,
QSpinBox,
QWidget,
)

from brainglobe_template_builder.preproc import create_mask


class CreateMask(QWidget):
"""Widget to create a mask from a selected image layer."""

def __init__(self, napari_viewer: Viewer, parent=None):
super().__init__(parent=parent)
self.viewer = napari_viewer
self.setLayout(QFormLayout())

self.gauss_sigma = QSpinBox(parent=self)
self.gauss_sigma.setRange(0, 20)
self.gauss_sigma.setValue(3)
self.layout().addRow("gauss sigma:", self.gauss_sigma)

self.threshold_method = QComboBox(parent=self)
self.threshold_method.addItems(["triangle", "otsu", "isodata"])
self.layout().addRow("threshold method:", self.threshold_method)

self.erosion_size = QSpinBox(parent=self)
self.erosion_size.setRange(0, 20)
self.erosion_size.setValue(5)
self.layout().addRow("erosion size:", self.erosion_size)

self.generate_mask_button = QPushButton("Create mask", parent=self)
self.layout().addRow(self.generate_mask_button)
self.generate_mask_button.clicked.connect(self._on_button_click)

def _on_button_click(self):
"""Create a mask from the selected image layer, using the parameters
specified in the widget, and add it to the napari viewer.
"""

if len(self.viewer.layers.selection) != 1:
show_info("Please select exactly one Image layer")
return None

image = list(self.viewer.layers.selection)[0]

if not isinstance(image, Image):
show_info("The selected layer is not an Image layer")
return None

mask_data = create_mask(
image.data,
gauss_sigma=self.gauss_sigma.value(),
threshold_method=self.threshold_method.currentText(),
erosion_size=self.erosion_size.value(),
)

self.viewer.add_labels(mask_data, name="mask", opacity=0.5)
Loading

0 comments on commit 27c03c4

Please sign in to comment.