From 524300a7e78ae647bb8f293630dbf2f7bf7cb514 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Thu, 12 Oct 2023 08:57:27 -0700 Subject: [PATCH] Handle case where centroid model doesn't detect any peaks --- sleap/nn/inference.py | 78 ++++++++++++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 0cabc91bb..4c1019179 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -408,12 +408,13 @@ def process_batch(ex): ex["frame_ind"] = ex["frame_ind"].numpy().flatten() # Adjust for potential SizeMatcher scaling. - offset_x = ex.get("offset_x", 0) - offset_y = ex.get("offset_y", 0) - ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) - ex["instance_peaks"] /= np.expand_dims( - np.expand_dims(ex["scale"], axis=1), axis=1 - ) + if ex["instance_peaks"].size > 0: + offset_x = ex.get("offset_x", 0) + offset_y = ex.get("offset_y", 0) + ex["instance_peaks"] -= np.reshape([offset_x, offset_y], [-1, 1, 1, 2]) + ex["instance_peaks"] /= np.expand_dims( + np.expand_dims(ex["scale"], axis=1), axis=1 + ) return ex @@ -795,6 +796,7 @@ def call(self, example_gt: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: crop_offsets=crop_offsets, centroids=example_gt["centroids"], centroid_vals=centroid_vals, + n_peaks=n_peaks, ) @@ -1907,7 +1909,9 @@ def call(self, inputs): centroid_vals, crop_sample_inds, nrows=samples ) - outputs = dict(centroids=centroids, centroid_vals=centroid_vals) + outputs = dict( + centroids=centroids, centroid_vals=centroid_vals, n_peaks=n_peaks + ) if self.return_confmaps: # Return confidence maps with outputs. cms = tf.RaggedTensor.from_value_rowids( @@ -2081,6 +2085,15 @@ def call( samples = tf.shape(crops)[0] crop_sample_inds = tf.range(samples, dtype=tf.int32) + outputs = {} + + if "centroids" in inputs: + outputs["centroids"] = inputs["centroids"] + if "centroid_vals" in inputs: + outputs["centroid_vals"] = inputs["centroid_vals"] + if "centroid_confmaps" in inputs: + outputs["centroid_confmaps"] = inputs["centroid_confmaps"] + # Preprocess inputs (scaling, padding, colorspace, int to float). crops = self.preprocess(crops) @@ -2140,13 +2153,8 @@ def call( ) # Build outputs. - outputs = {"instance_peaks": peaks, "instance_peak_vals": peak_vals} - if "centroids" in inputs: - outputs["centroids"] = inputs["centroids"] - if "centroid_vals" in inputs: - outputs["centroid_vals"] = inputs["centroid_vals"] - if "centroid_confmaps" in inputs: - outputs["centroid_confmaps"] = inputs["centroid_confmaps"] + outputs["instance_peaks"] = peaks + outputs["instance_peak_vals"] = peak_vals if self.return_confmaps: cms = tf.RaggedTensor.from_value_rowids( cms, crop_sample_inds, nrows=samples @@ -2253,17 +2261,39 @@ def call( crop_output = self.centroid_crop(example) - if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): - if "instances" in example: - peaks_output = self.instance_peaks(example, crop_output) - else: - raise ValueError( - "Ground truth data was not detected... " - "Please load both models when predicting on non-ground-truth data." - ) + if crop_output["n_peaks"] == 0: + samples = tf.shape(example["image"])[0] + output = { + "centroids": crop_output["centroids"], + "centroid_vals": crop_output["centroid_vals"], + "instance_peak_vals": tf.RaggedTensor.from_value_rowids( + tf.zeros(shape=(0,), dtype=tf.float32), + tf.zeros(shape=(0,), dtype=tf.int32), + nrows=samples, + ), + "instance_peaks": tf.RaggedTensor.from_value_rowids( + tf.zeros(shape=(0, 2), dtype=tf.float32), + tf.zeros(shape=(0,), dtype=tf.int32), + nrows=samples, + ), + } + + if self.instance_peaks.return_confmaps: + output["instance_confmaps"] = tf.zeros((0, 0, 0, 0), dtype=tf.float32) + + return output else: - peaks_output = self.instance_peaks(crop_output) - return peaks_output + if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): + if "instances" in example: + peaks_output = self.instance_peaks(example, crop_output) + else: + raise ValueError( + "Ground truth data was not detected... " + "Please load both models when predicting on non-ground-truth data." + ) + else: + peaks_output = self.instance_peaks(crop_output) + return peaks_output @attr.s(auto_attribs=True)