From 40a9fb566477b697af6f16d90f9dad24f4f101a8 Mon Sep 17 00:00:00 2001 From: jules-vanaret Date: Sat, 10 Aug 2024 05:35:47 +0200 Subject: [PATCH] blacked and ruffed --- .../_nice_colormap.py | 2 +- .../_widget.py | 422 +++++++++++------- 2 files changed, 261 insertions(+), 163 deletions(-) diff --git a/src/napari_spatial_correlation_plotter/_nice_colormap.py b/src/napari_spatial_correlation_plotter/_nice_colormap.py index 76928c4..3c4e2e7 100644 --- a/src/napari_spatial_correlation_plotter/_nice_colormap.py +++ b/src/napari_spatial_correlation_plotter/_nice_colormap.py @@ -257,4 +257,4 @@ def get_nice_colormap(): "#bcbcbc", ] - return colours_w_old_colors \ No newline at end of file + return colours_w_old_colors diff --git a/src/napari_spatial_correlation_plotter/_widget.py b/src/napari_spatial_correlation_plotter/_widget.py index 6b8cf7c..f7e7078 100644 --- a/src/napari_spatial_correlation_plotter/_widget.py +++ b/src/napari_spatial_correlation_plotter/_widget.py @@ -6,44 +6,45 @@ import napari import numpy as np from magicgui.widgets import Container, EmptyWidget, create_widget -from matplotlib.backends.backend_qt5agg import \ - FigureCanvasQTAgg as FigureCanvas -from matplotlib.backends.backend_qt5agg import \ - NavigationToolbar2QT as NavigationToolbar +from matplotlib.backends.backend_qt5agg import ( + FigureCanvasQTAgg as FigureCanvas, +) +from matplotlib.backends.backend_qt5agg import ( + NavigationToolbar2QT as NavigationToolbar, +) from matplotlib.figure import Figure from matplotlib.path import Path from matplotlib.widgets import LassoSelector, RectangleSelector from napari.layers import Image, Labels, Layer from napari.utils import DirectLabelColormap -from napari_spatial_correlation_plotter._nice_colormap import get_nice_colormap -from tapenade.analysis.spatial_correlation import \ - SpatialCorrelationPlotter -from tapenade.preprocessing import \ - masked_gaussian_smooth_dense_two_arrays_gpu from qtpy.QtCore import Qt from qtpy.QtGui import QGuiApplication, QIcon from skimage.measure import regionprops +from tapenade.analysis.spatial_correlation import SpatialCorrelationPlotter +from tapenade.preprocessing import masked_gaussian_smooth_dense_two_arrays_gpu from vispy.color import Color -ICON_ROOT = PathL(__file__).parent / "icons" +from napari_spatial_correlation_plotter._nice_colormap import get_nice_colormap +ICON_ROOT = PathL(__file__).parent / "icons" # TODO: # - add log scale to heatmap colors - colors = get_nice_colormap() cmap = [Color(hex_name).RGBA.astype("float") / 255 for hex_name in colors] + def in_bbox(min_x, max_x, min_y, max_y, xys): mins = np.array([min_x, min_y]).reshape(1, 2) maxs = np.array([max_x, max_y]).reshape(1, 2) foo = np.logical_and(xys >= mins, xys <= maxs) - return np.logical_and(foo[:,0], foo[:,1]) + return np.logical_and(foo[:, 0], foo[:, 1]) + # Class below was based upon matplotlib lasso selection example: # https://matplotlib.org/stable/gallery/widgets/lasso_selector_demo_sgskip.html @@ -95,7 +96,6 @@ def onselect(self, verts): ind_mask[ind_mask] = path.contains_points(self.xys[ind_mask]) self.ind_mask = ind_mask - self.canvas.draw_idle() # self.selected_coordinates = self.xys[self.ind].data @@ -108,8 +108,15 @@ def disconnect(self): class MplCanvas(FigureCanvas): - def __init__(self, xys, parent=None, width=7, height=4, - manual_clustering_method=None, create_selectors=False): + def __init__( + self, + xys, + parent=None, + width=7, + height=4, + manual_clustering_method=None, + create_selectors=False, + ): self.xys = xys @@ -130,7 +137,6 @@ def __init__(self, xys, parent=None, width=7, height=4, self.reset_params(create_selectors=create_selectors, xys=xys) - def reset_params(self, create_selectors, xys): self.axes = self.fig.axes[0] @@ -170,11 +176,11 @@ def reset_params(self, create_selectors, xys): # set colorbar tick color cb.ax.yaxis.set_tick_params(color="white") - # set colorbar edgecolor + # set colorbar edgecolor cb.outline.set_edgecolor("white") # set colorbar ticklabels - plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color="white") + plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white") if create_selectors: self.selector = SelectFromCollection(self, self.axes, xys) @@ -183,14 +189,13 @@ def reset_params(self, create_selectors, xys): self.axes, self.draw_rectangle, useblit=True, - props=dict(edgecolor='#1f77b4', fill=False), + props=dict(edgecolor="#1f77b4", fill=False), button=3, # right button minspanx=5, minspany=5, spancoords="pixels", interactive=False, ) - def draw_rectangle(self, eclick, erelease): """eclick and erelease are the press and release events""" @@ -210,6 +215,7 @@ def reset(self): self.axes.clear() self.is_pressed = None + class FigureToolbar(NavigationToolbar): def __init__(self, canvas): super().__init__(canvas, None) @@ -224,7 +230,9 @@ def _update_buttons_checked(self): QIcon(os.path.join(ICON_ROOT, "Pan_checked.png")) ) else: - self._actions["pan"].setIcon(QIcon(os.path.join(ICON_ROOT, "Pan.png"))) + self._actions["pan"].setIcon( + QIcon(os.path.join(ICON_ROOT, "Pan.png")) + ) if "zoom" in self._actions: if self._actions["zoom"].isChecked(): self._actions["zoom"].setIcon( @@ -264,11 +272,11 @@ def save_figure(self): # set colorbar tick color cb.ax.yaxis.set_tick_params(color="black") - # set colorbar edgecolor + # set colorbar edgecolor cb.outline.set_edgecolor("black") # set colorbar ticklabels - plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color="black") + plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="black") super().save_figure() @@ -300,16 +308,15 @@ def save_figure(self): # set colorbar tick color cb.ax.yaxis.set_tick_params(color="white") - # set colorbar edgecolor + # set colorbar edgecolor cb.outline.set_edgecolor("white") # set colorbar ticklabels - plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color="white") + plt.setp(plt.getp(cb.ax.axes, "yticklabels"), color="white") self.canvas.draw() - class PlotterWidget(Container): def __init__(self, napari_viewer): super().__init__() @@ -328,15 +335,14 @@ def __init__(self, napari_viewer): self.figure = None - self.labels_method_choices = ['cellular density', 'volume fraction'] + self.labels_method_choices = ["cellular density", "volume fraction"] self._hidden_features = {} - # Canvas Widget that displays the 'figure', it takes the 'figure' instance if True: self.graphics_widget = MplCanvas( manual_clustering_method=self.manual_clustering_method, - xys=None + xys=None, ) self.toolbar = FigureToolbar(self.graphics_widget) @@ -353,12 +359,13 @@ def __init__(self, napari_viewer): self.toolbar, self.graphics_widget, ], - labels=False + labels=False, ) self.quantityX_layer_combo = create_widget( - annotation=Layer, label="Quantity X", - options={'choices': self._image_labels_layers_filter} + annotation=Layer, + label="Quantity X", + options={"choices": self._image_labels_layers_filter}, ) self.quantityX_layer_combo.changed.connect( @@ -367,17 +374,18 @@ def __init__(self, napari_viewer): self.quantityX_labels_choices_combo = create_widget( widget_type="ComboBox", - options={'choices': self.labels_method_choices} + options={"choices": self.labels_method_choices}, ) self.quantityX_labels_choices_container = Container( widgets=[self.quantityX_labels_choices_combo], labels=False, - layout='horizontal' + layout="horizontal", ) self.quantityY_layer_combo = create_widget( - annotation=Layer, label="Quantity Y", - options={'choices': self._image_labels_layers_filter} + annotation=Layer, + label="Quantity Y", + options={"choices": self._image_labels_layers_filter}, ) self.quantityY_layer_combo.changed.connect( @@ -386,47 +394,60 @@ def __init__(self, napari_viewer): self.quantityY_labels_choices_combo = create_widget( widget_type="ComboBox", - options={'choices': self.labels_method_choices} + options={"choices": self.labels_method_choices}, ) self.quantityY_labels_choices_container = Container( widgets=[self.quantityY_labels_choices_combo], labels=False, - layout='horizontal' + layout="horizontal", ) self.mask_layer_combo = create_widget( - annotation=Image, label="Mask layer", - options={'nullable': True, 'choices': self._bool_layers_filter} + annotation=Image, + label="Mask layer", + options={ + "nullable": True, + "choices": self._bool_layers_filter, + }, ) self.labels_layer_combo = create_widget( - annotation=Labels, label="Labels layer", - options={'nullable': True} + annotation=Labels, + label="Labels layer", + options={"nullable": True}, ) self.blur_sigma_slider = create_widget( - widget_type="IntSlider", label="Blur sigma", - options={'min':0, 'max':50, 'value':1} + widget_type="IntSlider", + label="Blur sigma", + options={"min": 0, "max": 50, "value": 1}, ) # self.blur_sigma_slider.changed.connect(self.sigma_changed) self.run_button = create_widget( - widget_type="PushButton", label="Compute correlation heatmap", + widget_type="PushButton", + label="Compute correlation heatmap", ) self.run_button.clicked.connect(self.run) self.show_individual_cells_checkbox = create_widget( - annotation=bool, label="Show individual cells", + annotation=bool, + label="Show individual cells", ) - self.show_individual_cells_checkbox.changed.connect(self.parameters_changed) + self.show_individual_cells_checkbox.changed.connect( + self.parameters_changed + ) self.show_linear_fit_checkbox = create_widget( - annotation=bool, label="Show linear fit", + annotation=bool, + label="Show linear fit", ) - self.show_linear_fit_checkbox.changed.connect(self.parameters_changed) + self.show_linear_fit_checkbox.changed.connect( + self.parameters_changed + ) #! normalize is currently broken with manual selection # self.normalize_quantities_checkbox = create_widget( @@ -436,8 +457,9 @@ def __init__(self, napari_viewer): # self.normalize_quantities_checkbox.changed.connect(self.parameters_changed) self.display_quadrants = create_widget( - annotation=bool, label="Display quadrants", - ) + annotation=bool, + label="Display quadrants", + ) self.display_quadrants.changed.connect(self.parameters_changed) @@ -447,7 +469,7 @@ def __init__(self, napari_viewer): self.show_linear_fit_checkbox, ], labels=False, - layout='horizontal' + layout="horizontal", ) self.options_container2 = Container( widgets=[ @@ -455,19 +477,23 @@ def __init__(self, napari_viewer): self.display_quadrants, ], labels=False, - layout='horizontal' + layout="horizontal", ) self.heatmap_binsX = create_widget( - widget_type="IntSlider", label="X", - value=20, options={'min':2, 'max':100, 'tracking': True} + widget_type="IntSlider", + label="X", + value=20, + options={"min": 2, "max": 100, "tracking": True}, ) self.heatmap_binsX.changed.connect(self.parameters_changed) self.heatmap_binsY = create_widget( - widget_type="IntSlider", label="Y", - value=20, options={'min':2, 'max':100, 'tracking': True} + widget_type="IntSlider", + label="Y", + value=20, + options={"min": 2, "max": 100, "tracking": True}, ) self.heatmap_binsY.changed.connect(self.parameters_changed) @@ -478,20 +504,32 @@ def __init__(self, napari_viewer): self.heatmap_binsY, ], labels=True, - label='Heatmap bins', - layout='horizontal' + label="Heatmap bins", + layout="horizontal", ) self.percentilesX = create_widget( - widget_type="FloatRangeSlider", label="X", - options={'min':0, 'max':100, 'value':[0,100], 'tracking': True} + widget_type="FloatRangeSlider", + label="X", + options={ + "min": 0, + "max": 100, + "value": [0, 100], + "tracking": True, + }, ) self.percentilesX.changed.connect(self.parameters_changed) self.percentilesY = create_widget( - widget_type="FloatRangeSlider", label="Y", - options={'min':0, 'max':100, 'value':[0,100], 'tracking': True} + widget_type="FloatRangeSlider", + label="Y", + options={ + "min": 0, + "max": 100, + "value": [0, 100], + "tracking": True, + }, ) self.percentilesY.changed.connect(self.parameters_changed) @@ -502,13 +540,15 @@ def __init__(self, napari_viewer): self.percentilesY, ], labels=True, - label='Percentiles', - layout='horizontal' + label="Percentiles", + layout="horizontal", ) - parameters_text = EmptyWidget(label='Parameters:') + parameters_text = EmptyWidget(label="Parameters:") - display_parameters_text = EmptyWidget(label='Display Parameters:') + display_parameters_text = EmptyWidget( + label="Display Parameters:" + ) self.extend( [ @@ -531,23 +571,24 @@ def __init__(self, napari_viewer): # takes care of case where this isn't set yet directly after init self.plot_cluster_name = None - self.id=0 + self.id = 0 def manual_clustering_method(self, inside): inside = np.array(inside) # leads to errors sometimes otherwise if len(inside) == 0: return # if nothing was plotted yet, leave - + clustering_ID = "MANUAL_CLUSTER_ID" modifiers = QGuiApplication.keyboardModifiers() - if modifiers == Qt.ShiftModifier and clustering_ID in self._hidden_features.keys(): + if ( + modifiers == Qt.ShiftModifier + and clustering_ID in self._hidden_features.keys() + ): former_clusters = self._hidden_features[clustering_ID] former_clusters[inside] = np.max(former_clusters) + 1 - self._hidden_features.update( - {clustering_ID: former_clusters} - ) + self._hidden_features.update({clustering_ID: former_clusters}) else: self._hidden_features[clustering_ID] = inside.astype(int) @@ -557,34 +598,49 @@ def manual_clustering_method(self, inside): plot_cluster_name=clustering_ID, ) - def run(self): # Check if all necessary layers are specified if self.quantityX_layer_combo.value is None: - napari.utils.notifications.show_warning("Please specify quantityX_layer") + napari.utils.notifications.show_warning( + "Please specify quantityX_layer" + ) return else: self.quantityX = self.quantityX_layer_combo.value.data self.quantityX_label = self.quantityX_layer_combo.value.name if isinstance(self.quantityX_layer_combo.value, Labels): - self.quantityX_colormap = 'inferno' + self.quantityX_colormap = "inferno" else: - self.quantityX_colormap = self.quantityX_layer_combo.value.colormap - self.quantityX_is_labels = isinstance(self.quantityX_layer_combo.value, Labels) - self.quantityX_labels_choice = self.quantityX_labels_choices_combo.value + self.quantityX_colormap = ( + self.quantityX_layer_combo.value.colormap + ) + self.quantityX_is_labels = isinstance( + self.quantityX_layer_combo.value, Labels + ) + self.quantityX_labels_choice = ( + self.quantityX_labels_choices_combo.value + ) if self.quantityY_layer_combo.value is None: - napari.utils.notifications.show_warning("Please specify quantityY_layer") + napari.utils.notifications.show_warning( + "Please specify quantityY_layer" + ) return else: self.quantityY = self.quantityY_layer_combo.value.data self.quantityY_label = self.quantityY_layer_combo.value.name if isinstance(self.quantityY_layer_combo.value, Labels): - self.quantityY_colormap = 'inferno' + self.quantityY_colormap = "inferno" else: - self.quantityY_colormap = self.quantityY_layer_combo.value.colormap - self.quantityY_is_labels = isinstance(self.quantityY_layer_combo.value, Labels) - self.quantityY_labels_choice = self.quantityY_labels_choices_combo.value + self.quantityY_colormap = ( + self.quantityY_layer_combo.value.colormap + ) + self.quantityY_is_labels = isinstance( + self.quantityY_layer_combo.value, Labels + ) + self.quantityY_labels_choice = ( + self.quantityY_labels_choices_combo.value + ) if self.mask_layer_combo.value is not None: self.mask = self.mask_layer_combo.value.data @@ -600,50 +656,69 @@ def run(self): else: shape = self.quantityX.shape # use np.mgrid - self.argwheres = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].reshape(3, -1).T + self.argwheres = ( + np.mgrid[0 : shape[0], 0 : shape[1], 0 : shape[2]] + .reshape(3, -1) + .T + ) # Blur the layers smoothedX, smoothedY = self._smooth_quantities( - self.quantityX, self.quantityX_is_labels, self.quantityX_labels_choice, - self.quantityY, self.quantityY_is_labels, self.quantityY_labels_choice, - self.mask + self.quantityX, + self.quantityX_is_labels, + self.quantityX_labels_choice, + self.quantityY, + self.quantityY_is_labels, + self.quantityY_labels_choice, + self.mask, ) - self._update_smoothed_layers(smoothedX, self.quantityX_colormap, - smoothedY, self.quantityY_colormap) + self._update_smoothed_layers( + smoothedX, + self.quantityX_colormap, + smoothedY, + self.quantityY_colormap, + ) self.plot_from_smoothed( - smoothedX, self.quantityX_is_labels, self.quantityX_label, self.quantityX_labels_choice, - smoothedY, self.quantityY_is_labels, self.quantityY_label, self.quantityY_labels_choice, - self.mask, self.labels_image + smoothedX, + self.quantityX_is_labels, + self.quantityX_label, + self.quantityX_labels_choice, + smoothedY, + self.quantityY_is_labels, + self.quantityY_label, + self.quantityY_labels_choice, + self.mask, + self.labels_image, ) # Set a parameter "self.histogram_displayed" to True self.histogram_displayed = True if self.cluster_labels_layer is not None: - self.cluster_labels_layer.data = np.zeros_like(self.cluster_labels_layer.data) + self.cluster_labels_layer.data = np.zeros_like( + self.cluster_labels_layer.data + ) - def _update_smoothed_layers(self, - blurredX, X_colormap, - blurredY, Y_colormap): + def _update_smoothed_layers( + self, blurredX, X_colormap, blurredY, Y_colormap + ): if ( - self.quantityX_smoothed_layer is None or \ - self.quantityX_smoothed_layer not in self._viewer.layers + self.quantityX_smoothed_layer is None + or self.quantityX_smoothed_layer not in self._viewer.layers ): self.quantityX_smoothed_layer = self._viewer.add_image( - blurredX, - colormap=X_colormap + blurredX, colormap=X_colormap ) else: self.quantityX_smoothed_layer.data = blurredX if ( - self.quantityY_smoothed_layer is None or \ - self.quantityY_smoothed_layer not in self._viewer.layers + self.quantityY_smoothed_layer is None + or self.quantityY_smoothed_layer not in self._viewer.layers ): self.quantityY_smoothed_layer = self._viewer.add_image( - blurredY, - colormap=Y_colormap + blurredY, colormap=Y_colormap ) else: self.quantityY_smoothed_layer.data = blurredY @@ -654,12 +729,12 @@ def _update_quantities_labels_choices(self, event): if not self.quantityX_labels_choices_displayed: self.insert( - self.index(self.quantityX_layer_combo) + 1, - self.quantityX_labels_choices_container + self.index(self.quantityX_layer_combo) + 1, + self.quantityX_labels_choices_container, ) self.quantityX_labels_choices_displayed = True - + else: if self.quantityX_labels_choices_displayed: @@ -670,54 +745,55 @@ def _update_quantities_labels_choices(self, event): if not self.quantityY_labels_choices_displayed: self.insert( - self.index(self.quantityY_layer_combo) + 1, - self.quantityY_labels_choices_container + self.index(self.quantityY_layer_combo) + 1, + self.quantityY_labels_choices_container, ) - + self.quantityY_labels_choices_displayed = True - + else: if self.quantityY_labels_choices_displayed: self.remove(self.quantityY_labels_choices_container) self.quantityY_labels_choices_displayed = False - - + def _transform_labels_to_density(self, labels, method): self.test_value = True if method == self.labels_method_choices[0]: props = regionprops(labels) centroids = np.array([prop.centroid for prop in props]).astype(int) - + labels = np.zeros(labels.shape, dtype=bool) labels[centroids[:, 0], centroids[:, 1], centroids[:, 2]] = True return labels - + elif method == self.labels_method_choices[1]: return labels.astype(bool) - - - def _smooth_quantities(self, - quantityX, quantityX_is_labels, quantityX_labels_choice, - quantityY, quantityY_is_labels, quantityY_labels_choice, - mask): + def _smooth_quantities( + self, + quantityX, + quantityX_is_labels, + quantityX_labels_choice, + quantityY, + quantityY_is_labels, + quantityY_labels_choice, + mask, + ): masks_volume = [] if quantityX_is_labels: quantityX = self._transform_labels_to_density( - quantityX, - quantityX_labels_choice + quantityX, quantityX_labels_choice ) masks_volume.append(None) if quantityY_is_labels: quantityY = self._transform_labels_to_density( - quantityY, - quantityY_labels_choice + quantityY, quantityY_labels_choice ) masks_volume.append(None) @@ -737,10 +813,17 @@ def _smooth_quantities(self, return smoothedX, smoothedY def plot_from_smoothed( - self, - smoothedX, quantityX_is_labels, quantityX_label, quantityX_labels_choice, - smoothedY, quantityY_is_labels, quantityY_label, quantityY_labels_choice, - mask, labels + self, + smoothedX, + quantityX_is_labels, + quantityX_label, + quantityX_labels_choice, + smoothedY, + quantityY_is_labels, + quantityY_label, + quantityY_labels_choice, + mask, + labels, ): # Construct HeatmapPlotter self.heatmap_plotter = SpatialCorrelationPlotter( @@ -754,7 +837,7 @@ def plot_from_smoothed( labelX = quantityX_labels_choice else: labelX = quantityX_label - + if quantityY_is_labels: labelY = quantityY_labels_choice else: @@ -766,7 +849,7 @@ def plot_from_smoothed( show_individual_cells=self.show_individual_cells_checkbox.value, show_linear_fit=self.show_linear_fit_checkbox.value, # normalize_quantities=self.normalize_quantities_checkbox.value, - normalize_quantities=False, #! normalize is currently broken with manual selection + normalize_quantities=False, #! normalize is currently broken with manual selection percentiles_X=self.percentilesX.value, percentiles_Y=self.percentilesY.value, figsize=self.graphics_widget.figure.get_size_inches(), @@ -788,9 +871,10 @@ def plot_heatmap(self, figure): xys = self.heatmap_plotter.xys self.graphics_widget = MplCanvas( - parent=figure, manual_clustering_method=self.manual_clustering_method, - create_selectors=True,#labels_layer_exists, - xys=xys + parent=figure, + manual_clustering_method=self.manual_clustering_method, + create_selectors=True, # labels_layer_exists, + xys=xys, ) self.toolbar = FigureToolbar(self.graphics_widget) @@ -807,7 +891,7 @@ def plot_heatmap(self, figure): self.toolbar, self.graphics_widget, ], - labels=False + labels=False, ) widget_index = self.index(self.graph_container) @@ -816,13 +900,20 @@ def plot_heatmap(self, figure): self.graph_container = new_graph_container self.graphics_widget.draw() - def parameters_changed(self): if self.histogram_displayed: - labelX = self.quantityX_labels_choice if self.quantityX_is_labels else self.quantityX_label - labelY = self.quantityY_labels_choice if self.quantityY_is_labels else self.quantityY_label - + labelX = ( + self.quantityX_labels_choice + if self.quantityX_is_labels + else self.quantityX_label + ) + labelY = ( + self.quantityY_labels_choice + if self.quantityY_is_labels + else self.quantityY_label + ) + t0 = time() # Get figure from HeatmapPlotter figure, _ = self.heatmap_plotter.get_heatmap_figure( @@ -836,9 +927,9 @@ def parameters_changed(self): figsize=self.graphics_widget.figure.get_size_inches(), label_X=labelX, label_Y=labelY, - display_quadrants=self.display_quadrants.value + display_quadrants=self.display_quadrants.value, ) - print("Time to get figure:", time()-t0) + print("Time to get figure:", time() - t0) # Display figure in graphics_widget -> Create a method "self.plot" self.plot_heatmap(figure) @@ -857,7 +948,6 @@ def _bool_layers_filter(self, wdg): if (isinstance(layer, Image) and layer.data.dtype == bool) ] - def draw_cluster_labels( self, features, @@ -871,9 +961,9 @@ def draw_cluster_labels( # labels_layer = self.labels_layer_combo.value # mask_layer = self.mask_layer_combo.value # self.graphics_widget.reset() - + # fill all prediction nan values with -1 - self.cluster_ids = features[plot_cluster_name]#.fillna(-1) + self.cluster_ids = features[plot_cluster_name] # .fillna(-1) self.graphics_widget.selector.disconnect() self.graphics_widget.selector = SelectFromCollection( @@ -899,23 +989,26 @@ def draw_cluster_labels( keep_selection = list(self._viewer.layers.selection) - if self.labels_image is not None: cluster_image = self.generate_cluster_image_from_labels( self.labels_image, self.cluster_ids ) - + elif self.mask is not None: cluster_image = self.generate_cluster_image_from_points( - self.argwheres, self.cluster_ids, shape=self.quantityX.shape, + self.argwheres, + self.cluster_ids, + shape=self.quantityX.shape, ) else: cluster_image = self.generate_cluster_image_from_points( - self.argwheres, self.cluster_ids, shape=self.quantityX_layer_combo.value.data.shape, + self.argwheres, + self.cluster_ids, + shape=self.quantityX_layer_combo.value.data.shape, ) # if the cluster image layer doesn't yet exist make it - # otherwise just update it + # otherwise just update it if ( self.cluster_labels_layer is None or self.cluster_labels_layer not in self._viewer.layers @@ -925,7 +1018,7 @@ def draw_cluster_labels( cluster_image, # self.analysed_layer.data colormap=napari_cmap, # cluster_id_dict name="clustered labels", - opacity=1 + opacity=1, ) else: # updating data @@ -936,26 +1029,31 @@ def draw_cluster_labels( for s in keep_selection: self._viewer.layers.selection.add(s) - def generate_cluster_image_from_labels(self, label_image, predictionlist): props = regionprops(label_image) - cluster_image = np.zeros(label_image.shape, dtype='uint8') + cluster_image = np.zeros(label_image.shape, dtype="uint8") - argwheres = np.argwhere(predictionlist>0).flatten() + argwheres = np.argwhere(predictionlist > 0).flatten() for index in argwheres: prop = props[index] roi_data = label_image[prop.slice] - cluster_image[prop.slice][roi_data==prop.label] = predictionlist[index]+1 + cluster_image[prop.slice][roi_data == prop.label] = ( + predictionlist[index] + 1 + ) return cluster_image - - def generate_cluster_image_from_points(self, argwheres, predictionlist, shape): - cluster_image = np.zeros(shape, dtype='uint8') - points_to_display = argwheres[predictionlist>0] + def generate_cluster_image_from_points( + self, argwheres, predictionlist, shape + ): + + cluster_image = np.zeros(shape, dtype="uint8") + points_to_display = argwheres[predictionlist > 0] - cluster_image[tuple(points_to_display.T)] = predictionlist[predictionlist>0] + 1 - - return cluster_image \ No newline at end of file + cluster_image[tuple(points_to_display.T)] = ( + predictionlist[predictionlist > 0] + 1 + ) + + return cluster_image