Skip to content

Commit

Permalink
precompute argwheres
Browse files Browse the repository at this point in the history
  • Loading branch information
jules-vanaret committed Jun 7, 2024
1 parent 0c43e8e commit 5331993
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions src/napari_spatial_correlation_plotter/_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,12 @@ def run(self):
self.labels_image = self.labels_layer_combo.value.data
else:
self.labels_image = None
if self.mask is not None:
self.argwheres = np.argwhere(self.mask)
else:
shape = self.quantityX.shape
# use np.mgrid
self.argwheres = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].reshape(3, -1).T

# Blur the layers
smoothedX, smoothedY = self._smooth_quantities(
Expand Down Expand Up @@ -837,7 +843,7 @@ def plot_from_smoothed(
labelY = quantityY_label

# Get figure from HeatmapPlotter
figure, _, sampling_indices = self.heatmap_plotter.get_heatmap_figure(
figure, _ = self.heatmap_plotter.get_heatmap_figure(
bins=(self.heatmap_binsX.value, self.heatmap_binsY.value),
show_individual_cells=self.show_individual_cells_checkbox.value,
show_linear_fit=self.show_linear_fit_checkbox.value,
Expand All @@ -849,8 +855,6 @@ def plot_from_smoothed(
label_Y=labelY,
)

self.sampling_indices = sampling_indices

# Display figure in graphics_widget
self.plot_heatmap(figure)

Expand Down Expand Up @@ -901,7 +905,7 @@ def parameters_changed(self):
labelY = self.quantityY_labels_choice if self.quantityY_is_labels else self.quantityY_label

# Get figure from HeatmapPlotter
figure, _, sampling_indices = self.heatmap_plotter.get_heatmap_figure(
figure, _ = self.heatmap_plotter.get_heatmap_figure(
bins=(self.heatmap_binsX.value, self.heatmap_binsY.value),
show_individual_cells=self.show_individual_cells_checkbox.value,
show_linear_fit=self.show_linear_fit_checkbox.value,
Expand All @@ -913,8 +917,6 @@ def parameters_changed(self):
label_Y=labelY,
)

self.sampling_indices = sampling_indices

# Display figure in graphics_widget -> Create a method "self.plot"
self.plot_heatmap(figure)

Expand Down Expand Up @@ -943,8 +945,8 @@ def draw_cluster_labels(
"""

# self.analysed_layer = self.labels_select.value
labels_layer = self.labels_layer_combo.value
mask_layer = self.mask_layer_combo.value
# labels_layer = self.labels_layer_combo.value
# mask_layer = self.mask_layer_combo.value
# self.graphics_widget.reset()

# fill all prediction nan values with -1
Expand Down Expand Up @@ -981,20 +983,18 @@ def draw_cluster_labels(
keep_selection = list(self._viewer.layers.selection)


if labels_layer is not None:
if self.labels_image is not None:
cluster_image = self.generate_cluster_image_from_labels(
labels_layer.data, self.cluster_ids
self.labels_image, self.cluster_ids
)

elif mask_layer is not None:
elif self.mask is not None:
cluster_image = self.generate_cluster_image_from_points(
mask_layer.data, self.cluster_ids, shape=mask_layer.data.shape,
sampling_indices=self.sampling_indices
self.argwheres, self.cluster_ids, shape=self.quantityX.shape,
)
else:
cluster_image = self.generate_cluster_image_from_points(
None, self.cluster_ids, shape=self.quantityX_layer_combo.value.data.shape,
sampling_indices=self.sampling_indices
self.argwheres, self.cluster_ids, shape=self.quantityX_layer_combo.value.data.shape,
)

# if the cluster image layer doesn't yet exist make it
Expand Down Expand Up @@ -1035,19 +1035,12 @@ def generate_cluster_image_from_labels(self, label_image, predictionlist):

return cluster_image

def generate_cluster_image_from_points(self, mask, predictionlist, shape, sampling_indices):

print(len(sampling_indices))

def generate_cluster_image_from_points(self, argwheres, predictionlist, shape):

cluster_image = np.zeros(shape, dtype='uint8')

t0 = time()
if mask is not None:
argwheres = np.argwhere(mask)
else:
# use np.mgrid
argwheres = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].reshape(3, -1).T
argwheres = argwheres
print('argwhere time:', time()-t0)
# if sampling_indices is not None:
# argwheres = argwheres[sampling_indices]
Expand Down

0 comments on commit 5331993

Please sign in to comment.