diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 2b714eeb5..4fb7b6e1e 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -177,10 +177,10 @@ def get_occupancy_and_points_matrices( occupancy_matrix[track_i, frame_i] = 1 locations_matrix[frame_i, ..., track_i] = inst.numpy() + tracking_scores[frame_i, ..., track_i] = inst.tracking_score if type(inst) == PredictedInstance: point_scores[frame_i, ..., track_i] = inst.scores instance_scores[frame_i, ..., track_i] = inst.score - tracking_scores[frame_i, ..., track_i] = inst.tracking_score return ( occupancy_matrix, diff --git a/sleap/instance.py b/sleap/instance.py index 382ececf2..6b2595882 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -357,6 +357,7 @@ class Instance: frame: A back reference to the :class:`LabeledFrame` that this :class:`Instance` belongs to. This field is set when instances are added to :class:`LabeledFrame` objects. + tracking_score: The instance-level track matching score. """ skeleton: Skeleton = attr.ib() @@ -369,6 +370,8 @@ class Instance: # The underlying Point array type that this instances point array should be. _point_array_type = PointArray + tracking_score: float = attr.ib(default=0.0, converter=float) + @from_predicted.validator def _validate_from_predicted_( self, attribute, from_predicted: Optional["PredictedInstance"] @@ -662,7 +665,8 @@ def __repr__(self) -> str: f"video={self.video}, " f"frame_idx={self.frame_idx}, " f"points=[{pts}], " - f"track={self.track}" + f"track={self.track}, " + f"tracking_score={self.tracking_score:.2f}" ")" ) @@ -998,11 +1002,9 @@ class PredictedInstance(Instance): Args: score: The instance-level grouping prediction score. - tracking_score: The instance-level track matching score. """ score: float = attr.ib(default=0.0, converter=float) - tracking_score: float = attr.ib(default=0.0, converter=float) # The underlying Point array type that this instances point array should be. _point_array_type = PredictedPointArray @@ -1453,7 +1455,7 @@ def __len__(self) -> int: """Return number of instances associated with frame.""" return len(self.instances) - def __getitem__(self, index) -> Instance: + def __getitem__(self, index) -> Union[Instance, PredictedInstance]: """Return instance (retrieved by index).""" return self.instances.__getitem__(index) diff --git a/sleap/io/format/hdf5.py b/sleap/io/format/hdf5.py index 55a30d74f..e17941762 100644 --- a/sleap/io/format/hdf5.py +++ b/sleap/io/format/hdf5.py @@ -28,11 +28,12 @@ class LabelsV1Adaptor(format.adaptor.Adaptor): - FORMAT_ID = 1.2 + FORMAT_ID = 1.3 # 1.0 points with gridline coordinates, top left corner at (0, 0) # 1.1 points with midpixel coordinates, top left corner at (-0.5, -0.5) - # 1.2 adds track score to read and write functions + # 1.2 adds tracking score for PredictedInstance to read and write functions + # 1.3 adds tracking score for Instance to read and write functions @property def handles(self): @@ -222,6 +223,9 @@ def cast_as_compound(arr, dtype): skeleton=skeleton, track=track, points=points[i["point_id_start"] : i["point_id_end"]], + tracking_score=i["tracking_score"] + if (format_id is not None and format_id >= 1.3) + else 0.0, ) else: # PredictedInstance instance = PredictedInstance( @@ -481,11 +485,9 @@ def append_unique(old, new): if instance_type is PredictedInstance: score = instance.score pid = pred_point_id + pred_point_id_offset - tracking_score = instance.tracking_score else: score = np.nan pid = point_id + point_id_offset - tracking_score = np.nan # Keep track of any from_predicted instance links, we will # insert the correct instance_id in the dataset after we are @@ -494,6 +496,8 @@ def append_unique(old, new): instances_with_from_predicted.append(instance_id) instances_from_predicted.append(instance.from_predicted) + tracking_score = instance.tracking_score + # Copy all the data instances[instance_id] = ( instance_id + instance_id_offset, diff --git a/sleap/io/format/nix.py b/sleap/io/format/nix.py index 4c39ec8b6..fa2830464 100644 --- a/sleap/io/format/nix.py +++ b/sleap/io/format/nix.py @@ -239,13 +239,12 @@ def chunked_write( positions[index, :, node_map[m]] = np.array([np.nan, np.nan]) centroids[index, :] = inst.centroid + trackscore[index] = inst.tracking_score if hasattr(inst, "score"): instscore[index] = inst.score - trackscore[index] = inst.tracking_score pointscore[index, :] = inst.scores else: instscore[index] = 0.0 - trackscore[index] = 0.0 pointscore[index, :] = dflt_pointscore frameid_array[start:end] = indices[: end - start] diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 231b004f5..73a18316a 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1563,17 +1563,28 @@ def run_tracker(frames: List[LabeledFrame], tracker: BaseTracker) -> List[Labele for inst in lf.instances: inst.track = None - track_args = dict(untracked_instances=lf.instances) + # Prefer user instances over predicted instances + instances = [] + if lf.has_user_instances: + instances_to_track = lf.user_instances + if lf.has_predicted_instances: + instances = lf.predicted_instances + else: + instances_to_track = lf.predicted_instances + + track_args = {"untracked_instances": instances_to_track} + if tracker.uses_image: track_args["img"] = lf.video[lf.frame_idx] else: track_args["img"] = None track_args["img_hw"] = lf.image.shape[-3:-1] + instances.extend(tracker.track(**track_args)) new_lf = LabeledFrame( frame_idx=lf.frame_idx, video=lf.video, - instances=tracker.track(**track_args), + instances=instances, ) new_lfs.append(new_lf) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index d71d4cc83..0cb6e0fee 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -10,7 +10,6 @@ from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track from sleap.io.video import Video, MediaVideo from sleap.io.dataset import Labels, load_file -from sleap.io.legacy import load_labels_json_old from sleap.io.format.ndx_pose import NDXPoseAdaptor from sleap.io.format import filehandle from sleap.gui.suggestions import VideoFrameSuggestions, SuggestionFrame @@ -748,6 +747,36 @@ def test_dont_unify_skeletons(): labels.to_dict() +def test_instance_cattr(centered_pair_predictions: Labels, tmpdir: str): + labels = centered_pair_predictions + lf = labels.labeled_frames[0] + pred_inst: PredictedInstance = lf[0] + skeleton = pred_inst.skeleton + track = pred_inst.track + + # Initialize Instance + instance = Instance.from_pointsarray( + points=pred_inst.numpy(), skeleton=skeleton, track=track + ) + instance.from_predicted = pred_inst + assert instance.tracking_score == 0.0 + labels.add_instance(lf, instance) + + instance.tracking_score = 0.5 + pred_inst.tracking_score = 0.7 + + filename = str(PurePath(tmpdir, "labels.slp")) + labels.save(filename) + + labels_loaded = sleap.load_file(filename) + lf_loaded = labels_loaded.labeled_frames[0] + pred_inst_loaded = lf_loaded.predicted_instances[0] + instance_loaded = lf_loaded.user_instances[0] + + assert round(pred_inst_loaded.tracking_score, 1) == pred_inst.tracking_score + assert round(instance_loaded.tracking_score, 1) == instance.tracking_score + + def test_instance_access(): labels = Labels()