-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port napari widgets from magicgui to qt (#9)
* create simple dropdown with qt * connected qt mask widget * finished first draft of mask qt widget * made a chest of widgets * Update collapsible widget API to latest brainglobe-utils (#8) Co-authored-by: IgorTatarnikov <[email protected]> * finish porting all existing widgets to qt * merged utils and preproc modules --------- Co-authored-by: IgorTatarnikov <[email protected]>
- Loading branch information
1 parent
5016bbf
commit 27c03c4
Showing
8 changed files
with
480 additions
and
376 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,229 +1,22 @@ | ||
""" | ||
This module is an example of a barebones QWidget plugin for napari | ||
from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer | ||
from napari.viewer import Viewer | ||
|
||
It implements the Widget specification. | ||
see: https://napari.org/stable/plugins/guides.html?#widgets | ||
from brainglobe_template_builder.napari.mask_widget import CreateMask | ||
from brainglobe_template_builder.napari.midline_widget import FindMidline | ||
|
||
Replace code below according to your needs. | ||
""" | ||
from typing import Literal | ||
|
||
import numpy as np | ||
from magicgui import magic_factory | ||
from napari.layers import Image, Labels, Points | ||
from napari_plugin_engine import napari_hook_implementation | ||
class PreprocWidgets(CollapsibleWidgetContainer): | ||
def __init__(self, napari_viewer: Viewer, parent=None): | ||
super().__init__() | ||
|
||
from brainglobe_template_builder.utils import ( | ||
extract_largest_object, | ||
fit_plane_to_points, | ||
get_midline_points, | ||
threshold_image, | ||
) | ||
|
||
# 9 colors taken from ColorBrewer2.org Set3 palette | ||
POINTS_COLOR_CYCLE = [ | ||
"#8dd3c7", | ||
"#ffffb3", | ||
"#bebada", | ||
"#fb8072", | ||
"#80b1d3", | ||
"#fdb462", | ||
"#b3de69", | ||
"#fccde5", | ||
"#d9d9d9", | ||
] | ||
|
||
|
||
@magic_factory( | ||
call_button="generate mask", | ||
gauss_sigma={"widget_type": "SpinBox", "max": 20, "min": 0}, | ||
threshold_method={"choices": ["triangle", "otsu", "isodata"]}, | ||
erosion_size={"widget_type": "SpinBox", "max": 20, "min": 0}, | ||
) | ||
def mask_widget( | ||
image: Image, | ||
gauss_sigma: float = 3, | ||
threshold_method: Literal["triangle", "otsu", "isodata"] = "triangle", | ||
erosion_size: int = 5, | ||
) -> Labels: | ||
"""Threshold image and create a mask for the largest object. | ||
The mask is generated by applying a Gaussian filter to the image, | ||
thresholding the smoothed image, keeping only the largest object, and | ||
eroding the resulting mask. | ||
Parameters | ||
---------- | ||
image : Image | ||
A napari image layer to threshold. | ||
gauss_sigma : float | ||
Standard deviation for Gaussian kernel (in pixels) to smooth image | ||
before thresholding. Set to 0 to skip smoothing. | ||
threshold_method : str | ||
Thresholding method to use. One of 'triangle', 'otsu', and 'isodata' | ||
(corresponding to methods from the skimage.filters module). | ||
Defaults to 'triangle'. | ||
erosion_size : int | ||
Size of the erosion footprint (in pixels) to apply to the mask. | ||
Set to 0 to skip erosion. | ||
Returns | ||
------- | ||
napari.layers.Labels | ||
A napari labels layer containing the mask. | ||
""" | ||
|
||
if image is not None: | ||
assert isinstance(image, Image), "image must be a napari Image layer" | ||
else: | ||
print("Please select an image layer") | ||
return None | ||
|
||
from skimage import filters, morphology | ||
|
||
# Apply gaussian filter to image | ||
if gauss_sigma > 0: | ||
data_smoothed = filters.gaussian(image.data, sigma=gauss_sigma) | ||
else: | ||
data_smoothed = image.data | ||
|
||
# Threshold the (smoothed) image | ||
binary = threshold_image(data_smoothed, method=threshold_method) | ||
|
||
# Keep only the largest object in the binary image | ||
mask = extract_largest_object(binary) | ||
|
||
# Erode the mask | ||
if erosion_size > 0: | ||
mask = morphology.binary_erosion( | ||
mask, footprint=np.ones((erosion_size,) * image.ndim) | ||
self.add_widget( | ||
CreateMask(napari_viewer, parent=self), | ||
collapsible=True, | ||
widget_title="Create mask", | ||
) | ||
return Labels(mask, opacity=0.5, name="mask") | ||
|
||
|
||
@magic_factory( | ||
call_button="Estimate midline points", | ||
) | ||
def points_widget( | ||
mask: Labels, | ||
) -> Points: | ||
"""Create a points layer with 9 midline points. | ||
Parameters | ||
---------- | ||
mask : Labels | ||
A napari labels layer to use as a reference for the points. | ||
Returns | ||
------- | ||
napari.layers.Points | ||
A napari points layer containing the midline points. | ||
""" | ||
|
||
# Estimate 9 midline points | ||
points = get_midline_points(mask.data) | ||
|
||
point_labels = np.arange(1, points.shape[0] + 1) | ||
|
||
point_attrs = { | ||
"properties": {"label": point_labels}, | ||
"face_color": "label", | ||
"face_color_cycle": POINTS_COLOR_CYCLE, | ||
"symbol": "cross", | ||
"edge_width": 0, | ||
"opacity": 0.6, | ||
"size": 6, | ||
"ndim": mask.ndim, | ||
"name": "midline points", | ||
} | ||
|
||
# Make mask layer invisible | ||
mask.visible = False | ||
|
||
return Points(points, **point_attrs) | ||
|
||
|
||
@magic_factory( | ||
call_button="Align midline", | ||
image={"label": "Image"}, | ||
points={"label": "Midline points"}, | ||
axis={"label": "Axis", "choices": ["x", "y", "z"]}, | ||
) | ||
def transform_widget( | ||
image: Image, | ||
points: Points, | ||
axis: Literal["x", "y", "z"] = "x", | ||
) -> Image: | ||
"""Transform image to align points with midline of the specified axis. | ||
It first fits a plane to the points, then rigidly transforms the image | ||
such that the fitted plane is aligned with the axis midline. | ||
Parameters | ||
---------- | ||
image : Image | ||
A napari image layer to align. | ||
points : Points | ||
A napari points layer containing points. | ||
axis : str | ||
Axis to align the midline with. One of 'x', 'y', and 'z'. | ||
Defaults to 'x'. | ||
Returns | ||
------- | ||
napari.layers.Image | ||
A napari image layer containing the transformed image. | ||
""" | ||
|
||
from scipy.ndimage import affine_transform | ||
from scipy.spatial.transform import Rotation | ||
|
||
points_data = points.data | ||
normal_vector = fit_plane_to_points(points_data) | ||
assert normal_vector.shape == (3,) | ||
|
||
# Compute centroid of the midline points | ||
centroid = np.mean(points_data, axis=0) | ||
|
||
# Translation of the centroid to the origin | ||
translation_to_origin = np.eye(4) | ||
translation_to_origin[:3, 3] = -centroid | ||
|
||
# Rotation to align normal vector with unit vector along the specified axis | ||
axis_vec = np.zeros(3) | ||
axis_index = {"z": 0, "y": 1, "x": 2}[axis] # axis order is zyx in napari | ||
axis_vec[axis_index] = 1 | ||
rotation_to_axis = Rotation.align_vectors( | ||
axis_vec.reshape(1, 3), | ||
normal_vector.reshape(1, 3), | ||
)[0].as_matrix() | ||
assert rotation_to_axis.shape == (3, 3) | ||
rotation_4x4 = np.eye(4) | ||
rotation_4x4[:3, :3] = rotation_to_axis | ||
|
||
# Translation back, so that the plane is in the middle of axis | ||
translation_to_mid_axis = np.eye(4) | ||
translation_to_mid_axis[axis_index, 3] = ( | ||
image.data.shape[axis_index] // 2 - centroid[axis_index] | ||
) | ||
|
||
# Combine transformations into a single 4x4 matrix | ||
transformation_matrix = ( | ||
np.linalg.inv(translation_to_origin) | ||
@ rotation_4x4 | ||
@ translation_to_origin | ||
@ translation_to_mid_axis | ||
) | ||
|
||
# Apply the transformation to the image | ||
transformed_image = affine_transform( | ||
image, | ||
transformation_matrix[:3, :3], | ||
offset=transformation_matrix[:3, 3], | ||
) | ||
return Image(transformed_image, name="aligned image") | ||
|
||
|
||
@napari_hook_implementation | ||
def napari_experimental_provide_dock_widget(): | ||
return [mask_widget, points_widget, transform_widget] | ||
self.add_widget( | ||
FindMidline(napari_viewer, parent=self), | ||
collapsible=True, | ||
widget_title="Find midline", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from napari.layers import Image | ||
from napari.utils.notifications import show_info | ||
from napari.viewer import Viewer | ||
from qtpy.QtWidgets import ( | ||
QComboBox, | ||
QFormLayout, | ||
QPushButton, | ||
QSpinBox, | ||
QWidget, | ||
) | ||
|
||
from brainglobe_template_builder.preproc import create_mask | ||
|
||
|
||
class CreateMask(QWidget): | ||
"""Widget to create a mask from a selected image layer.""" | ||
|
||
def __init__(self, napari_viewer: Viewer, parent=None): | ||
super().__init__(parent=parent) | ||
self.viewer = napari_viewer | ||
self.setLayout(QFormLayout()) | ||
|
||
self.gauss_sigma = QSpinBox(parent=self) | ||
self.gauss_sigma.setRange(0, 20) | ||
self.gauss_sigma.setValue(3) | ||
self.layout().addRow("gauss sigma:", self.gauss_sigma) | ||
|
||
self.threshold_method = QComboBox(parent=self) | ||
self.threshold_method.addItems(["triangle", "otsu", "isodata"]) | ||
self.layout().addRow("threshold method:", self.threshold_method) | ||
|
||
self.erosion_size = QSpinBox(parent=self) | ||
self.erosion_size.setRange(0, 20) | ||
self.erosion_size.setValue(5) | ||
self.layout().addRow("erosion size:", self.erosion_size) | ||
|
||
self.generate_mask_button = QPushButton("Create mask", parent=self) | ||
self.layout().addRow(self.generate_mask_button) | ||
self.generate_mask_button.clicked.connect(self._on_button_click) | ||
|
||
def _on_button_click(self): | ||
"""Create a mask from the selected image layer, using the parameters | ||
specified in the widget, and add it to the napari viewer. | ||
""" | ||
|
||
if len(self.viewer.layers.selection) != 1: | ||
show_info("Please select exactly one Image layer") | ||
return None | ||
|
||
image = list(self.viewer.layers.selection)[0] | ||
|
||
if not isinstance(image, Image): | ||
show_info("The selected layer is not an Image layer") | ||
return None | ||
|
||
mask_data = create_mask( | ||
image.data, | ||
gauss_sigma=self.gauss_sigma.value(), | ||
threshold_method=self.threshold_method.currentText(), | ||
erosion_size=self.erosion_size.value(), | ||
) | ||
|
||
self.viewer.add_labels(mask_data, name="mask", opacity=0.5) |
Oops, something went wrong.