diff --git a/src/bioclip/predict.py b/src/bioclip/predict.py index 6f71dbc..354f7bc 100644 --- a/src/bioclip/predict.py +++ b/src/bioclip/predict.py @@ -346,16 +346,13 @@ def format_species_probs(self, image_path: str, probs: torch.Tensor, k: int = 5) result.append(item) return result - def create_name(self, classification_dict: dict[str, str]) -> str: - return " ".join(classification_dict.values()) - def format_grouped_probs(self, image_path: str, probs: torch.Tensor, rank: Rank, min_prob: float = 1e-9, k: int = 5) -> List[dict[str, float]]: output = collections.defaultdict(float) class_dict_lookup = {} name_to_class_dict = {} for i in torch.nonzero(probs > min_prob).squeeze(): classification_dict = create_classification_dict(self.txt_names[i], rank) - name = self.create_name(classification_dict) + name = join_names(classification_dict) class_dict_lookup[name] = classification_dict output[name] += probs[i] name_to_class_dict[name] = classification_dict