diff --git a/docs/accuracy_calculation.md b/docs/accuracy_calculation.md index 38a2a6c2..2bd2a995 100644 --- a/docs/accuracy_calculation.md +++ b/docs/accuracy_calculation.md @@ -7,6 +7,7 @@ from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator AccuracyCalculator(include=(), exclude=(), avg_of_avgs=False, + return_per_class=False, k=None, label_comparison_fn=None, device=None, @@ -18,6 +19,7 @@ AccuracyCalculator(include=(), * **include**: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all default metrics will be calculated. * **exclude**: Optional. A list or tuple of strings, which are the names of metrics you **do not** want to calculate. * **avg_of_avgs**: If True, the average accuracy per class is computed, and then the average of those averages is returned. This can be useful if your dataset has unbalanced classes. If False, the global average will be returned. +* **return_per_class**: If True, the average accuracy per class is computed and returned. * **k**: The number of nearest neighbors that will be retrieved for metrics that require k-nearest neighbors. The allowed values are: * ```None```. This means k will be set to the total number of reference embeddings. * An integer greater than 0. This means k will be set to the input integer. @@ -72,6 +74,12 @@ def get_accuracy(self, Note that labels can be 2D if a [custom label comparison function](#using-a-custom-label-comparison-function) is used. +### Lone query labels +If some query labels don't appear in the reference set, then it's impossible for those labels to have non-zero k-nn accuracy. Zero accuracy for these labels doesn't indicate anything about the quality of the embedding space. So these lone query labels are excluded from k-nn based accuracy calculations. + +For example, if the input ```query_labels``` is ```[0,0,1,1]``` and ```reference_labels``` is ```[1,1,1,2,2]```, then 0 is considered a lone query label. 
+ + ### CPU/GPU usage * If you installed ```faiss-cpu``` then the CPU will always be used. @@ -100,6 +108,10 @@ If your dataset is large, you might find the k-nn search is very slow. This is b - [See section 3.2 of A Metric Learning Reality Check](https://arxiv.org/pdf/2003.08505.pdf) +- **mean_reciprocal_rank**: + + - [Slides from Stanford](https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf) + - **precision_at_1**: - Fancy way of saying "is the 1st nearest neighbor correct?" diff --git a/docs/extend/losses.md b/docs/extend/losses.md index fb7808ba..4add13d0 100644 --- a/docs/extend/losses.md +++ b/docs/extend/losses.md @@ -7,7 +7,7 @@ from pytorch_metric_learning.losses import BaseMetricLossFunction import torch class BarebonesLoss(BaseMetricLossFunction): - def compute_loss(self, embeddings, labels, indices_tuple): + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): # perform some calculation # some_loss = torch.mean(embeddings) @@ -33,7 +33,7 @@ from pytorch_metric_learning.utils import loss_and_miner_utils as lmu import torch class FullFeaturedLoss(BaseMetricLossFunction): - def compute_loss(self, embeddings, labels, indices_tuple): + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): indices_tuple = lmu.convert_to_triplets(indices_tuple, labels) anchors, positives, negatives = indices_tuple if len(anchors) == 0: diff --git a/docs/imgs/vicreg_covariance.png b/docs/imgs/vicreg_covariance.png new file mode 100644 index 00000000..15fc2835 Binary files /dev/null and b/docs/imgs/vicreg_covariance.png differ diff --git a/docs/imgs/vicreg_invariance.png b/docs/imgs/vicreg_invariance.png new file mode 100644 index 00000000..f2b43219 Binary files /dev/null and b/docs/imgs/vicreg_invariance.png differ diff --git a/docs/imgs/vicreg_total.png b/docs/imgs/vicreg_total.png new file mode 100644 index 00000000..10b09c7c Binary files /dev/null and b/docs/imgs/vicreg_total.png differ diff --git 
a/docs/imgs/vicreg_variance.png b/docs/imgs/vicreg_variance.png new file mode 100644 index 00000000..01150bb2 Binary files /dev/null and b/docs/imgs/vicreg_variance.png differ diff --git a/docs/imgs/vicreg_variance_detail.png b/docs/imgs/vicreg_variance_detail.png new file mode 100644 index 00000000..a89aeca6 Binary files /dev/null and b/docs/imgs/vicreg_variance_detail.png differ diff --git a/docs/losses.md b/docs/losses.md index 97a366d0..0f88597e 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -152,6 +152,31 @@ def compute_loss(self, embeddings, labels, indices_tuple=None): ``` +## CentroidTripletLoss +[On the Unreasonable Effectiveness of Centroids in Image Retrieval](https://arxiv.org/pdf/2104.13643.pdf){target=_blank} + +This is like [TripletMarginLoss](losses.md#tripletmarginloss), except the positives and negatives are class centroids. + +```python +losses.CentroidTripletLoss(margin=0.05, + swap=False, + smooth_loss=False, + triplets_per_anchor="all", + **kwargs) +``` +**Parameters**: + +See [TripletMarginLoss](losses.md#tripletmarginloss) + +**Default distance**: + +See [TripletMarginLoss](losses.md#tripletmarginloss) + +**Default reducer**: + + - [AvgNonZeroReducer](reducers.md#avgnonzeroreducer) + + ## CircleLoss [Circle Loss: A Unified Perspective of Pair Similarity Optimization](https://arxiv.org/pdf/2002.10857.pdf){target=_blank} @@ -1024,4 +1049,60 @@ Extended by: * [ProxyAnchorLoss](losses.md#proxyanchorloss) * [ProxyNCALoss](losses.md#proxyncaloss) * [SoftTripleLoss](losses.md#softtripleloss) -* [SphereFaceLoss](losses.md#spherefaceloss) \ No newline at end of file +* [SphereFaceLoss](losses.md#spherefaceloss) + + +## VICRegLoss +[VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning](https://arxiv.org/pdf/2105.04906.pdf){target=_blank} +```python +losses.VICRegLoss(invariance_lambda=25, + variance_mu=25, + covariance_v=1, + eps=1e-4, + **kwargs) +``` + +**Usage**: + +Unlike other loss functions, ```VICRegLoss``` 
does not accept ```labels``` or ```indices_tuple```: + +```python +loss_fn = VICRegLoss() +loss = loss_fn(embeddings, ref_emb) +``` + +**Equations**: + +![vicreg_total](imgs/vicreg_total.png){: style="height:40px"} + +where + +![vicreg_invariance](imgs/vicreg_invariance.png){: style="height:70px"} + +![vicreg_variance](imgs/vicreg_variance.png){: style="height:90px"} + +![vicreg_variance_detail](imgs/vicreg_variance_detail.png){: style="height:40px"} + +![vicreg_covariance](imgs/vicreg_covariance.png){: style="height:70px"} + +**Parameters**: + +* **invariance_lambda**: The weight of the invariance term. +* **variance_mu**: The weight of the variance term. +* **covariance_v**: The weight of the covariance term. +* **eps**: Small scalar to prevent numerical instability. + +**Default distance**: + + - Not applicable. You cannot pass in a distance function. + +**Default reducer**: + + - [MeanReducer](reducers.md#meanreducer) + +**Reducer input**: + +* **invariance_loss**: The MSE loss between ```embeddings[i]``` and ```ref_emb[i]```. Reduction type is ```"element"```. +* **variance_loss1**: The variance loss for ```embeddings```. Reduction type is ```"element"```. +* **variance_loss2**: The variance loss for ```ref_emb```. Reduction type is ```"element"```. +* **covariance_loss**: The covariance loss. This loss is already reduced to a single value. 
diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index 5becc17c..6849410a 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "1.0.0" +__version__ = "1.1.0" diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index 54911ae2..ffadf54c 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -1,6 +1,7 @@ from .angular_loss import AngularLoss from .arcface_loss import ArcFaceLoss from .base_metric_loss_function import BaseMetricLossFunction, MultipleLosses +from .centroid_triplet_loss import CentroidTripletLoss from .circle_loss import CircleLoss from .contrastive_loss import ContrastiveLoss from .cosface_loss import CosFaceLoss @@ -25,3 +26,4 @@ from .supcon_loss import SupConLoss from .triplet_margin_loss import TripletMarginLoss from .tuplet_margin_loss import TupletMarginLoss +from .vicreg_loss import VICRegLoss diff --git a/src/pytorch_metric_learning/losses/base_metric_loss_function.py b/src/pytorch_metric_learning/losses/base_metric_loss_function.py index b69b4333..9aa1c91e 100644 --- a/src/pytorch_metric_learning/losses/base_metric_loss_function.py +++ b/src/pytorch_metric_learning/losses/base_metric_loss_function.py @@ -12,7 +12,7 @@ class BaseMetricLossFunction( EmbeddingRegularizerMixin, ModuleWithRecordsReducerAndDistance ): - def compute_loss(self, embeddings, labels, indices_tuple=None): + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): """ This has to be implemented and is what actually computes the loss. 
""" diff --git a/src/pytorch_metric_learning/losses/centroid_triplet_loss.py b/src/pytorch_metric_learning/losses/centroid_triplet_loss.py new file mode 100644 index 00000000..5198a0b4 --- /dev/null +++ b/src/pytorch_metric_learning/losses/centroid_triplet_loss.py @@ -0,0 +1,180 @@ +from collections import defaultdict + +import numpy as np +import torch + +from ..reducers import AvgNonZeroReducer +from ..utils import common_functions as c_f +from ..utils import loss_and_miner_utils as lmu +from .base_metric_loss_function import BaseMetricLossFunction +from .triplet_margin_loss import TripletMarginLoss + + +def concat_indices_tuple(x): + return [torch.cat(y) for y in zip(*x)] + + +class CentroidTripletLoss(BaseMetricLossFunction): + def __init__( + self, + margin=0.05, + swap=False, + smooth_loss=False, + triplets_per_anchor="all", + **kwargs + ): + super().__init__(**kwargs) + self.triplet_loss = TripletMarginLoss( + margin=margin, + swap=swap, + smooth_loss=smooth_loss, + triplets_per_anchor=triplets_per_anchor, + **kwargs + ) + + def compute_loss( + self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None + ): + c_f.indices_tuple_not_supported(indices_tuple) + c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels) + """ + "During training stage each mini-batch contains 𝑃 distinct item + classes with 𝑀 samples per class, resulting in batch size of 𝑃 × 𝑀." + """ + masks, class_masks, labels_list, query_indices = self.create_masks_train(labels) + + P = len(labels_list) + M = max([len(instances) for instances in labels_list]) + DIM = embeddings.size(-1) + + """ + "...each sample from S𝑘 is used as a query 𝑞𝑘 and the rest + 𝑀 −1 samples are used to build a prototype centroid" + i.e. for each class k of M items, we make M pairs of (query, centroid), + making a total of P*M total pairs. 
+ masks = (M*P x len(embeddings)) matrix + labels_list[i] = indices of embeddings belonging to ith class + centroids_emb.shape == (M*P, DIM) + i.e. centroids_emb[0] == centroid vector for 0th class, where the first embedding is the query vector + centroids_emb[1] == centroid vector for 0th class, where the second embedding is the query vector + centroids_emb[M] == centroid vector for 1st class, where the first embedding is the query vector + """ + + masks_float = masks.type(embeddings.type()).to(embeddings.device) + class_masks_float = class_masks.type(embeddings.type()).to(embeddings.device) + inst_counts = masks_float.sum(-1) + class_inst_counts = class_masks_float.sum(-1) + + valid_mask = inst_counts > 0 + padded = masks_float.unsqueeze(-1) * embeddings.unsqueeze(0) + class_padded = class_masks_float.unsqueeze(-1) * embeddings.unsqueeze(0) + + positive_centroids_emb = padded.sum(-2) / inst_counts.masked_fill( + inst_counts == 0, 1 + ).unsqueeze(-1) + + negative_centroids_emb = class_padded.sum(-2) / class_inst_counts.masked_fill( + class_inst_counts == 0, 1 + ).unsqueeze(-1) + + query_indices = torch.tensor(query_indices).to(embeddings.device) + query_embeddings = embeddings.index_select(0, query_indices) + query_labels = labels.index_select(0, query_indices) + assert positive_centroids_emb.size() == (M * P, DIM) + assert negative_centroids_emb.size() == (P, DIM) + assert query_embeddings.size() == (M * P, DIM) + + query_indices = query_indices.view((P, M)).transpose(0, 1) + query_embeddings = query_embeddings.view((P, M, -1)).transpose(0, 1) + query_labels = query_labels.view((P, M)).transpose(0, 1) + positive_centroids_emb = positive_centroids_emb.view((P, M, -1)).transpose(0, 1) + valid_mask = valid_mask.view((P, M)).transpose(0, 1) + + labels_collect = [] + embeddings_collect = [] + tuple_indices_collect = [] + starting_idx = 0 + for inst_idx in range(M): + one_mask = valid_mask[inst_idx] + if torch.sum(one_mask) > 1: + anchors = 
query_embeddings[inst_idx][one_mask] + pos_centroids = positive_centroids_emb[inst_idx][one_mask] + one_labels = query_labels[inst_idx][one_mask] + + embeddings_concat = torch.cat( + (anchors, pos_centroids, negative_centroids_emb) + ) + labels_concat = torch.cat( + (one_labels, one_labels, query_labels[inst_idx]) + ) + indices_tuple = lmu.get_all_triplets_indices(labels_concat) + + """ + Right now indices tuple considers all embeddings in + embeddings_concat as anchors, pos_example, neg_examples. + + 1. make only query vectors be anchor vectors + 2. make pos_centroids be only used as a positive example + 3. make neg_centroids be only used as a negative example + """ + # make only query vectors be anchor vectors + indices_tuple = [x[: len(x) // 3] + starting_idx for x in indices_tuple] + + # make only pos_centroids be positive examples + indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple] + indices_tuple = [x.chunk(2, dim=1)[0] for x in indices_tuple] + + # make only neg_centroids be negative examples + indices_tuple = [ + x.chunk(len(one_labels), dim=1)[-1].flatten() for x in indices_tuple + ] + + tuple_indices_collect.append(indices_tuple) + embeddings_collect.append(embeddings_concat) + labels_collect.append(labels_concat) + starting_idx += len(labels_concat) + + indices_tuple = concat_indices_tuple(tuple_indices_collect) + + if len(indices_tuple) == 0: + return self.zero_losses() + + final_embeddings = torch.cat(embeddings_collect) + final_labels = torch.cat(labels_collect) + + loss = self.triplet_loss.compute_loss( + final_embeddings, final_labels, indices_tuple, ref_emb=None, ref_labels=None + ) + return loss + + def create_masks_train(self, class_labels): + labels_dict = defaultdict(list) + class_labels = class_labels.detach().cpu().numpy() + for idx, pid in enumerate(class_labels): + labels_dict[pid].append(idx) + + unique_classes = list(labels_dict.keys()) + labels_list = list(labels_dict.values()) + lens_list = [len(item) for item in labels_list] + lens_list_cs = 
np.cumsum(lens_list) + + M = max(len(instances) for instances in labels_list) + P = len(unique_classes) + + query_indices = [] + class_masks = torch.zeros((P, len(class_labels)), dtype=bool) + masks = torch.zeros((M * P, len(class_labels)), dtype=bool) + for class_idx, class_insts in enumerate(labels_list): + class_masks[class_idx, class_insts] = 1 + for instance_idx in range(M): + matrix_idx = class_idx * M + instance_idx + if instance_idx < len(class_insts): + query_indices.append(class_insts[instance_idx]) + ones = class_insts[:instance_idx] + class_insts[instance_idx + 1 :] + masks[matrix_idx, ones] = 1 + else: + query_indices.append(0) + return masks, class_masks, labels_list, query_indices + + def get_default_reducer(self): + return AvgNonZeroReducer() diff --git a/src/pytorch_metric_learning/losses/vicreg_loss.py b/src/pytorch_metric_learning/losses/vicreg_loss.py new file mode 100644 index 00000000..0be29b23 --- /dev/null +++ b/src/pytorch_metric_learning/losses/vicreg_loss.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F + +from ..utils import common_functions as c_f +from .base_metric_loss_function import BaseMetricLossFunction + + +class VICRegLoss(BaseMetricLossFunction): + def __init__( + self, invariance_lambda=25, variance_mu=25, covariance_v=1, eps=1e-4, **kwargs + ): + if "distance" in kwargs: + raise ValueError("VICRegLoss cannot use a distance function") + if "embedding_regularizer" in kwargs: + raise ValueError("VICRegLoss cannot use a regularizer") + super().__init__(**kwargs) + """ + The overall loss function is a weighted average of the invariance, variance and covariance terms: + L(Z, Z') = λs(Z, Z') + µ[v(Z) + v(Z')] + ν[c(Z) + c(Z')], + where λ, µ and ν are hyper-parameters controlling the importance of each term in the loss. 
+ """ + self.invariance_lambda = invariance_lambda + self.variance_mu = variance_mu + self.covariance_v = covariance_v + self.eps = eps + + def forward(self, embeddings, ref_emb): + """ + embeddings and ref_emb should both have shape (N, embedding_size) + """ + self.reset_stats() + loss_dict = self.compute_loss(embeddings, ref_emb) + return self.reducer( + loss_dict, embeddings, c_f.torch_arange_from_size(embeddings) + ) + + def compute_loss(self, embeddings, ref_emb): + invariance_loss = self.invariance_lambda * self.invariance_loss( + embeddings, ref_emb + ) + variance_loss1, variance_loss2 = self.variance_loss(embeddings, ref_emb) + covariance_loss = self.covariance_v * self.covariance_loss(embeddings, ref_emb) + var_loss_size = c_f.torch_arange_from_size(variance_loss1) + return { + "invariance_loss": { + "losses": invariance_loss, + "indices": c_f.torch_arange_from_size(invariance_loss), + "reduction_type": "element", + }, + "variance_loss1": { + "losses": self.variance_mu * variance_loss1, + "indices": var_loss_size, + "reduction_type": "element", + }, + "variance_loss2": { + "losses": self.variance_mu * variance_loss2, + "indices": var_loss_size, + "reduction_type": "element", + }, + "covariance_loss": { + "losses": covariance_loss, + "indices": None, + "reduction_type": "already_reduced", + }, + } + + def invariance_loss(self, emb, ref_emb): + return torch.mean((emb - ref_emb) ** 2, dim=1) + + def variance_loss(self, emb, ref_emb): + std_emb = torch.sqrt(emb.var(dim=0) + self.eps) + std_ref_emb = torch.sqrt(ref_emb.var(dim=0) + self.eps) + return F.relu(1 - std_emb), F.relu(1 - std_ref_emb) + + def covariance_loss(self, emb, ref_emb): + _, D = emb.size() + cov_emb = torch.cov(emb.T) + cov_ref_emb = torch.cov(ref_emb.T) + + diag = torch.eye(D, device=cov_emb.device) + cov_loss = ( + cov_emb[~diag.bool()].pow_(2).sum() / D + + cov_ref_emb[~diag.bool()].pow_(2).sum() / D + ) + return cov_loss + + def _sub_loss_names(self): + return [ + "invariance_loss", + "variance_loss1", + 
"variance_loss2", + "covariance_loss", + ] diff --git a/src/pytorch_metric_learning/utils/accuracy_calculator.py b/src/pytorch_metric_learning/utils/accuracy_calculator.py index a4858f08..81b236bf 100644 --- a/src/pytorch_metric_learning/utils/accuracy_calculator.py +++ b/src/pytorch_metric_learning/utils/accuracy_calculator.py @@ -10,9 +10,15 @@ EQUALITY = torch.eq -def maybe_get_avg_of_avgs(accuracy_per_sample, sample_labels, avg_of_avgs): - if avg_of_avgs: - unique_labels = torch.unique(sample_labels, dim=0) +def get_unique_labels(labels): + return torch.unique(labels, dim=0) + + +def maybe_get_avg_of_avgs( + accuracy_per_sample, sample_labels, avg_of_avgs, return_per_class +): + if avg_of_avgs or return_per_class: + unique_labels = get_unique_labels(sample_labels) mask = c_f.torch_all_from_dim_to_end( sample_labels == unique_labels.unsqueeze(1), 2 ) @@ -20,6 +26,8 @@ def maybe_get_avg_of_avgs(accuracy_per_sample, sample_labels, avg_of_avgs): acc_sum_per_class = torch.sum(accuracy_per_sample.unsqueeze(1) * mask, dim=0) mask_sum_per_class = torch.sum(mask, dim=0) average_per_class = acc_sum_per_class / mask_sum_per_class + if return_per_class: + return average_per_class.cpu().tolist() return torch.mean(average_per_class).item() return torch.mean(accuracy_per_sample).item() @@ -48,6 +56,7 @@ def r_precision( embeddings_come_from_same_source, label_counts, avg_of_avgs, + return_per_class, label_comparison_fn, ): relevance_mask = get_relevance_mask( @@ -64,7 +73,9 @@ def r_precision( c_f.to_dtype(matches_per_row, dtype=torch.float64) / max_possible_matches_per_row ) - return maybe_get_avg_of_avgs(accuracy_per_sample, gt_labels, avg_of_avgs) + return maybe_get_avg_of_avgs( + accuracy_per_sample, gt_labels, avg_of_avgs, return_per_class + ) def mean_average_precision( @@ -72,6 +83,7 @@ def mean_average_precision( gt_labels, embeddings_come_from_same_source, avg_of_avgs, + return_per_class, label_comparison_fn, relevance_mask=None, at_r=False, @@ -97,7 +109,34 @@ def 
mean_average_precision( max_possible_matches_per_row = torch.sum(equality, dim=1) max_possible_matches_per_row[max_possible_matches_per_row == 0] = 1 accuracy_per_sample = summed_precision_per_row / max_possible_matches_per_row - return maybe_get_avg_of_avgs(accuracy_per_sample, gt_labels, avg_of_avgs) + return maybe_get_avg_of_avgs( + accuracy_per_sample, gt_labels, avg_of_avgs, return_per_class + ) + + +def mean_reciprocal_rank( + knn_labels, + gt_labels, + avg_of_avgs, + return_per_class, + label_comparison_fn, +): + device = gt_labels.device + is_same_label = label_comparison_fn(gt_labels, knn_labels) + + # find & remove cases where it has 0 correct results + sum_per_row = is_same_label.sum(-1) + zero_remove_mask = sum_per_row > 0 + indices = torch.arange(is_same_label.shape[1], 0, -1, device=device) + tmp = is_same_label * indices + indices = torch.argmax(tmp, 1, keepdim=True) + 1.0 + + indices[zero_remove_mask] = 1.0 / indices[zero_remove_mask] + indices[~zero_remove_mask] = 0.0 + + indices = indices.flatten() + + return maybe_get_avg_of_avgs(indices, gt_labels, avg_of_avgs, return_per_class) def mean_average_precision_at_r( @@ -106,6 +145,7 @@ def mean_average_precision_at_r( embeddings_come_from_same_source, label_counts, avg_of_avgs, + return_per_class, label_comparison_fn, ): relevance_mask = get_relevance_mask( @@ -120,23 +160,28 @@ def mean_average_precision_at_r( gt_labels, embeddings_come_from_same_source, avg_of_avgs, + return_per_class, label_comparison_fn, relevance_mask=relevance_mask, at_r=True, ) -def precision_at_k(knn_labels, gt_labels, k, avg_of_avgs, label_comparison_fn): +def precision_at_k( + knn_labels, gt_labels, k, avg_of_avgs, return_per_class, label_comparison_fn +): curr_knn_labels = knn_labels[:, :k] same_label = label_comparison_fn(gt_labels, curr_knn_labels) accuracy_per_sample = ( c_f.to_dtype(torch.sum(same_label, dim=1), dtype=torch.float64) / k ) - return maybe_get_avg_of_avgs(accuracy_per_sample, gt_labels, avg_of_avgs) + 
return maybe_get_avg_of_avgs( + accuracy_per_sample, gt_labels, avg_of_avgs, return_per_class + ) def get_label_match_counts(query_labels, reference_labels, label_comparison_fn): - unique_query_labels = torch.unique(query_labels, dim=0) + unique_query_labels = get_unique_labels(query_labels) if label_comparison_fn is EQUALITY: comparison = unique_query_labels[:, None] == reference_labels match_counts = torch.sum(c_f.torch_all_from_dim_to_end(comparison, 2), dim=1) @@ -191,12 +236,19 @@ def try_getting_not_lone_labels(knn_labels, query_labels, not_lone_query_mask): ) +def zero_accuracy(unique_labels, return_per_class): + if return_per_class: + return [0 for _ in range(len(unique_labels))] + return 0 + + class AccuracyCalculator: def __init__( self, include=(), exclude=(), avg_of_avgs=False, + return_per_class=False, k=None, label_comparison_fn=None, device=None, @@ -212,7 +264,12 @@ def __init__( self.check_primary_metrics(include, exclude) self.original_function_dict = self.get_function_dict(include, exclude) self.curr_function_dict = self.get_function_dict() + + if avg_of_avgs and return_per_class: + raise ValueError("avg_of_avgs and return_per_class are mutually exclusive") self.avg_of_avgs = avg_of_avgs + self.return_per_class = return_per_class + self.device = c_f.use_cuda_if_available() if device is None else device self.knn_func = FaissKNN() if knn_func is None else knn_func self.kmeans_func = ( @@ -277,18 +334,19 @@ def calculate_AMI(self, query_labels, cluster_labels, **kwargs): return adjusted_mutual_info_score(query_labels, cluster_labels) def calculate_precision_at_1( - self, knn_labels, query_labels, not_lone_query_mask, **kwargs + self, knn_labels, query_labels, not_lone_query_mask, label_counts, **kwargs ): knn_labels, query_labels = try_getting_not_lone_labels( knn_labels, query_labels, not_lone_query_mask ) if knn_labels is None: - return 0 + return zero_accuracy(label_counts[0], self.return_per_class) return precision_at_k( knn_labels, 
query_labels[:, None], 1, self.avg_of_avgs, + self.return_per_class, self.label_comparison_fn, ) @@ -305,13 +363,14 @@ def calculate_mean_average_precision_at_r( knn_labels, query_labels, not_lone_query_mask ) if knn_labels is None: - return 0 + return zero_accuracy(label_counts[0], self.return_per_class) return mean_average_precision_at_r( knn_labels, query_labels[:, None], embeddings_come_from_same_source, label_counts, self.avg_of_avgs, + self.return_per_class, self.label_comparison_fn, ) @@ -321,19 +380,43 @@ def calculate_mean_average_precision( query_labels, not_lone_query_mask, embeddings_come_from_same_source, + label_counts, **kwargs, ): knn_labels, query_labels = try_getting_not_lone_labels( knn_labels, query_labels, not_lone_query_mask ) if knn_labels is None: - return 0 + return zero_accuracy(label_counts[0], self.return_per_class) return mean_average_precision( knn_labels, query_labels[:, None], embeddings_come_from_same_source, self.avg_of_avgs, + self.return_per_class, + self.label_comparison_fn, + ) + + def calculate_mean_reciprocal_rank( + self, + knn_labels, + query_labels, + not_lone_query_mask, + label_counts, + **kwargs, + ): + knn_labels, query_labels = try_getting_not_lone_labels( + knn_labels, query_labels, not_lone_query_mask + ) + if knn_labels is None: + return zero_accuracy(label_counts[0], self.return_per_class) + + return mean_reciprocal_rank( + knn_labels, + query_labels[:, None], + self.avg_of_avgs, + self.return_per_class, self.label_comparison_fn, ) @@ -350,13 +433,14 @@ def calculate_r_precision( knn_labels, query_labels, not_lone_query_mask ) if knn_labels is None: - return 0 + return zero_accuracy(label_counts[0], self.return_per_class) return r_precision( knn_labels, query_labels[:, None], embeddings_come_from_same_source, label_counts, self.avg_of_avgs, + self.return_per_class, self.label_comparison_fn, ) diff --git a/src/pytorch_metric_learning/utils/common_functions.py b/src/pytorch_metric_learning/utils/common_functions.py 
index eb95b474..ecf825c7 100644 --- a/src/pytorch_metric_learning/utils/common_functions.py +++ b/src/pytorch_metric_learning/utils/common_functions.py @@ -507,6 +507,11 @@ def ref_not_supported(embeddings, labels, ref_emb, ref_labels): raise ValueError("ref_emb is not supported for this loss function") +def indices_tuple_not_supported(indices_tuple): + if indices_tuple is not None: + raise ValueError("indices_tuple is not supported for this loss function") + + def concatenate_indices_tuples(it1, it2): return tuple([torch.cat([x, to_device(y, x)], dim=0) for x, y in zip(it1, it2)]) diff --git a/tests/losses/test_centroid_triplet_loss.py b/tests/losses/test_centroid_triplet_loss.py new file mode 100644 index 00000000..0028b858 --- /dev/null +++ b/tests/losses/test_centroid_triplet_loss.py @@ -0,0 +1,246 @@ +import unittest +from itertools import chain + +import torch + +from pytorch_metric_learning.distances import CosineSimilarity +from pytorch_metric_learning.losses import CentroidTripletLoss +from pytorch_metric_learning.reducers import MeanReducer + +from .. 
import TEST_DEVICE, TEST_DTYPES +from ..zzz_testing_utils.testing_utils import angle_to_coord + + +def normalize(embeddings): + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=0) + return embeddings + + +class TestCentroidTripletLoss(unittest.TestCase): + def test_centroid_triplet_loss(self): + for dtype in TEST_DTYPES: + embedding_angles = [0, 10, 20, 30, 40, 50] + centroid_makers = [ + [[10, 20], [30, 40, 50]], + [[0, 20], [30, 40, 50]], + [[0, 10], [30, 40, 50]], + [[40, 50], [0, 10, 20]], + [[30, 50], [0, 10, 20]], + [[30, 40], [0, 10, 20]], + ] + triplets = [ + (0, (0, 0), (0, 1)), + (1, (1, 0), (1, 1)), + (2, (2, 0), (2, 1)), + (3, (3, 0), (3, 1)), + (4, (4, 0), (4, 1)), + (5, (5, 0), (5, 1)), + ] + + labels = torch.LongTensor([0, 0, 0, 1, 1, 1]) + + self.helper(embedding_angles, centroid_makers, labels, triplets, dtype) + + def test_sorting_invariance(self): + for dtype in TEST_DTYPES: + centroid_makers = [ + [[10, 20], [30, 40, 50]], + [[40, 50], [0, 10, 20]], + [[30, 50], [0, 10, 20]], + [[0, 20], [30, 40, 50]], + [[0, 10], [30, 40, 50]], + [[30, 40], [0, 10, 20]], + ] + + embedding_angles = [0, 30, 40, 10, 20, 50] + labels = torch.LongTensor([0, 1, 1, 0, 0, 1]) + + triplets = [ + (0, (0, 0), (0, 1)), + (1, (1, 0), (1, 1)), + (2, (2, 0), (2, 1)), + (3, (3, 0), (3, 1)), + (4, (4, 0), (4, 1)), + (5, (5, 0), (5, 1)), + ] + + self.helper(embedding_angles, centroid_makers, labels, triplets, dtype) + + def test_imbalanced(self): + for dtype in TEST_DTYPES: + per_class_angles = [ + [0, 10, 20], # class A + [30, 40, 50, 55], # class B + [60, 70, 80, 90], # class C + ] + embedding_angles = chain(*per_class_angles) + labels = torch.LongTensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]) + + centroid_makers = [ + [[10, 20], [30, 40, 50, 55]], + [[10, 20], [60, 70, 80, 90]], + [[0, 20], [30, 40, 50, 55]], + [[0, 20], [60, 70, 80, 90]], + [[0, 10], [30, 40, 50, 55]], + [[0, 10], [60, 70, 80, 90]], + [[40, 50, 55], [0, 10, 20]], + [[40, 50, 55], [60, 70, 80, 90]], 
+ [[30, 50, 55], [0, 10, 20]], + [[30, 50, 55], [60, 70, 80, 90]], + [[30, 40, 55], [0, 10, 20]], + [[30, 40, 55], [60, 70, 80, 90]], + [[30, 40, 50], [0, 10, 20]], + [[30, 40, 50], [60, 70, 80, 90]], + [[70, 80, 90], [0, 10, 20]], + [[70, 80, 90], [30, 40, 50, 55]], + [[60, 80, 90], [0, 10, 20]], + [[60, 80, 90], [30, 40, 50, 55]], + [[60, 70, 90], [0, 10, 20]], + [[60, 70, 90], [30, 40, 50, 55]], + [[60, 70, 80], [0, 10, 20]], + [[60, 70, 80], [30, 40, 50, 55]], + ] + + triplets = [ + (0, (0, 0), (0, 1)), + (0, (1, 0), (1, 1)), + (1, (2, 0), (2, 1)), + (1, (3, 0), (3, 1)), + (2, (4, 0), (4, 1)), + (2, (5, 0), (5, 1)), + (3, (6, 0), (6, 1)), + (3, (7, 0), (7, 1)), + (4, (8, 0), (8, 1)), + (4, (9, 0), (9, 1)), + (5, (10, 0), (10, 1)), + (5, (11, 0), (11, 1)), + (6, (12, 0), (12, 1)), + (6, (13, 0), (13, 1)), + (7, (14, 0), (14, 1)), + (7, (15, 0), (15, 1)), + (8, (16, 0), (16, 1)), + (8, (17, 0), (17, 1)), + (9, (18, 0), (18, 1)), + (9, (19, 0), (19, 1)), + (10, (20, 0), (20, 1)), + (10, (21, 0), (21, 1)), + ] + + self.helper(embedding_angles, centroid_makers, labels, triplets, dtype) + + def helper( + self, + embedding_angles, + centroid_makers, + labels, + triplets, + dtype, + ref_emb=None, + ref_labels=None, + ): + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + + centroids = [ + [ + torch.stack( + [ + torch.tensor( + angle_to_coord(a), requires_grad=True, dtype=dtype + ).to(TEST_DEVICE) + for a in coords + ] + ).mean(-2) + for coords in one_maker + ] + for one_maker in centroid_makers + ] + + margin = 0.2 + loss_funcA = CentroidTripletLoss(margin=margin) + loss_funcB = CentroidTripletLoss(margin=margin, reducer=MeanReducer()) + loss_funcC = CentroidTripletLoss(margin=margin, distance=CosineSimilarity()) + loss_funcD = CentroidTripletLoss( + margin=margin, reducer=MeanReducer(), distance=CosineSimilarity() + ) + loss_funcE = 
CentroidTripletLoss(margin=margin, smooth_loss=True) + + [lossA, lossB, lossC, lossD, lossE] = [ + x(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels) + for x in [loss_funcA, loss_funcB, loss_funcC, loss_funcD, loss_funcE] + ] + + correct_loss = 0 + correct_loss_cosine = 0 + correct_smooth_loss = 0 + num_non_zero_triplets = 0 + num_non_zero_triplets_cosine = 0 + + for a, pc, nc in triplets: + anchor = embeddings[a] + + positive = centroids[pc[0]][pc[1]] + negative = centroids[nc[0]][nc[1]] + + anchor = normalize(anchor) + positive = normalize(positive) + negative = normalize(negative) + + ap_dist = torch.sqrt(torch.sum((anchor - positive) ** 2)) + an_dist = torch.sqrt(torch.sum((anchor - negative) ** 2)) + curr_loss = torch.relu(ap_dist - an_dist + margin) + + curr_loss_cosine = torch.relu( + torch.sum(anchor * negative) - torch.sum(anchor * positive) + margin + ) + correct_smooth_loss += torch.nn.functional.softplus( + ap_dist - an_dist + margin + ) + if curr_loss > 0: + num_non_zero_triplets += 1 + if curr_loss_cosine > 0: + num_non_zero_triplets_cosine += 1 + correct_loss += curr_loss + correct_loss_cosine += curr_loss_cosine + + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue( + torch.isclose(lossA, correct_loss / num_non_zero_triplets, rtol=rtol) + ) + self.assertTrue(torch.isclose(lossB, correct_loss / len(triplets), rtol=rtol)) + self.assertTrue( + torch.isclose( + lossC, correct_loss_cosine / num_non_zero_triplets_cosine, rtol=rtol + ) + ) + self.assertTrue( + torch.isclose(lossD, correct_loss_cosine / len(triplets), rtol=rtol) + ) + self.assertTrue( + torch.isclose(lossE, correct_smooth_loss / len(triplets), rtol=rtol) + ) + + def test_backward(self): + margin = 0.2 + loss_funcA = CentroidTripletLoss(margin=margin) + loss_funcB = CentroidTripletLoss(margin=margin, reducer=MeanReducer()) + loss_funcC = CentroidTripletLoss(smooth_loss=True) + for dtype in TEST_DTYPES: + for loss_func in [loss_funcA, loss_funcB, loss_funcC]: + 
embedding_angles = [0, 20, 40, 60, 80, 85] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([0, 0, 1, 1, 2, 2]) + + loss = loss_func(embeddings, labels) + loss.backward() diff --git a/tests/losses/test_vicreg_loss.py b/tests/losses/test_vicreg_loss.py new file mode 100644 index 00000000..102ae7ae --- /dev/null +++ b/tests/losses/test_vicreg_loss.py @@ -0,0 +1,66 @@ +import unittest + +import torch +import torch.nn.functional as F + +from pytorch_metric_learning.losses import VICRegLoss + +from .. import TEST_DEVICE, TEST_DTYPES + +HYPERPARAMETERS = [[25, 25, 1, 1e-4], [10, 10, 2, 1e-5], [5, 5, 5, 1e-6]] + + +class TestVICRegLoss(unittest.TestCase): + def test_vicreg_loss(self): + torch.manual_seed(3459) + for dtype in TEST_DTYPES: + for hyp in HYPERPARAMETERS: + loss_func = VICRegLoss( + invariance_lambda=hyp[0], + variance_mu=hyp[1], + covariance_v=hyp[2], + eps=hyp[3], + ) + ref_emb_ = torch.randn( + 32, 64, device=TEST_DEVICE, dtype=dtype, requires_grad=True + ) + augmentation_noise = torch.normal( + 0, 0.1, size=(32, 64), device=TEST_DEVICE, dtype=dtype + ) + emb_ = ref_emb_ + augmentation_noise + + for emb, ref_emb in [(emb_, ref_emb_), (ref_emb_, emb_)]: + loss = loss_func(ref_emb, emb) + loss.backward() + + # invariance_loss + invariance_loss = F.mse_loss(emb, ref_emb) + + # variance_loss + std_emb = torch.sqrt(emb.var(dim=0) + hyp[3]) + std_ref_emb = torch.sqrt(ref_emb.var(dim=0) + hyp[3]) + variance_loss = torch.mean(F.relu(1 - std_emb)) + torch.mean( + F.relu(1 - std_ref_emb) + ) + + # covariance loss, a more manual version + N, D = emb.size() + emb = emb - emb.mean(dim=0) + ref_emb = ref_emb - ref_emb.mean(dim=0) + cov_emb = (emb.T @ emb) / (N - 1) + cov_ref_emb = (ref_emb.T @ ref_emb) / (N - 1) + diag = torch.eye(D, device=emb.device) + covariance_loss = ( + cov_emb[~diag.bool()].pow_(2).sum() / D + + 
cov_ref_emb[~diag.bool()].pow_(2).sum() / D + ) + + correct_loss = ( + hyp[0] * invariance_loss + + hyp[1] * variance_loss + + hyp[2] * covariance_loss + ) + + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + + self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) diff --git a/tests/utils/test_calculate_accuracies.py b/tests/utils/test_calculate_accuracies.py index f55127fd..a4e4744b 100644 --- a/tests/utils/test_calculate_accuracies.py +++ b/tests/utils/test_calculate_accuracies.py @@ -11,18 +11,17 @@ from .. import TEST_DEVICE -def isclose(x, y): - rtol = 0 - if TEST_DEVICE == torch.device("cpu"): - atol = 1e-15 - else: - atol = 1e-7 +def isclose(x, y, many=False): + rtol = 1e-6 + atol = 0 + if many: + return np.allclose(x, y, atol=atol, rtol=rtol) return np.isclose(x, y, atol=atol, rtol=rtol) class TestCalculateAccuracies(unittest.TestCase): def test_accuracy_calculator(self): - query_labels = torch.tensor([1, 1, 2, 3, 4], device=TEST_DEVICE) + query_labels = torch.tensor([1, 1, 2, 3, 4, 0], device=TEST_DEVICE) knn_labels1 = torch.tensor( [ @@ -31,103 +30,153 @@ def test_accuracy_calculator(self): [4, 4, 4, 4, 2], [3, 1, 3, 1, 3], [0, 0, 4, 2, 2], + [1, 2, 3, 4, 5], ], device=TEST_DEVICE, ) - label_counts1 = ([1, 2, 3, 4], [3, 5, 4, 5]) + label_counts1 = ([0, 1, 2, 3, 4], [2, 3, 5, 4, 5]) knn_labels2 = knn_labels1 + 5 - label_counts2 = ([6, 7, 8, 9], [3, 5, 4, 5]) + label_counts2 = ([5, 6, 7, 8, 9], [2, 3, 5, 4, 5]) for avg_of_avgs in [False, True]: - for i, (knn_labels, label_counts) in enumerate( - [(knn_labels1, label_counts1), (knn_labels2, label_counts2)] - ): - - AC = accuracy_calculator.AccuracyCalculator( - exclude=("NMI", "AMI"), avg_of_avgs=avg_of_avgs, device=TEST_DEVICE - ) - kwargs = { - "query_labels": query_labels, - "label_counts": label_counts, - "knn_labels": knn_labels, - "not_lone_query_mask": torch.ones(5, dtype=torch.bool) - if i == 0 - else torch.zeros(5, dtype=torch.bool), - } - - function_dict = AC.get_function_dict() - - for 
ecfss in [False, True]: - if ecfss: - kwargs["knn_labels"] = kwargs["knn_labels"][:, 1:] - kwargs["embeddings_come_from_same_source"] = ecfss - acc = AC._get_accuracy(function_dict, **kwargs) - if i == 1: - self.assertTrue(acc["precision_at_1"] == 0) - self.assertTrue(acc["r_precision"] == 0) - self.assertTrue(acc["mean_average_precision_at_r"] == 0) - self.assertTrue(acc["mean_average_precision"] == 0) - else: - self.assertTrue( - isclose( - acc["precision_at_1"], - self.correct_precision_at_1(ecfss, avg_of_avgs), - ) + for return_per_class in [False, True]: + for i, (knn_labels, label_counts) in enumerate( + [(knn_labels1, label_counts1), (knn_labels2, label_counts2)] + ): + init_kwargs = { + "exclude": ("NMI", "AMI"), + "avg_of_avgs": avg_of_avgs, + "return_per_class": return_per_class, + "device": TEST_DEVICE, + } + + if avg_of_avgs and return_per_class: + self.assertRaises( + ValueError, + lambda: accuracy_calculator.AccuracyCalculator( + **init_kwargs + ), ) - self.assertTrue( - isclose( - acc["r_precision"], - self.correct_r_precision(ecfss, avg_of_avgs), + continue + + AC = accuracy_calculator.AccuracyCalculator(**init_kwargs) + kwargs = { + "query_labels": query_labels, + "label_counts": label_counts, + "knn_labels": knn_labels, + "not_lone_query_mask": torch.ones(6, dtype=torch.bool) + if i == 0 + else torch.zeros(6, dtype=torch.bool), + } + + function_dict = AC.get_function_dict() + + for ecfss in [False, True]: + if ecfss: + kwargs["knn_labels"] = kwargs["knn_labels"][:, 1:] + kwargs["embeddings_come_from_same_source"] = ecfss + acc = AC._get_accuracy(function_dict, **kwargs) + if i == 1: + zero_acc = 0 if not return_per_class else [0, 0, 0, 0, 0] + self.assertTrue(acc["precision_at_1"] == zero_acc) + self.assertTrue(acc["r_precision"] == zero_acc) + self.assertTrue( + acc["mean_average_precision_at_r"] == zero_acc ) - ) - self.assertTrue( - isclose( - acc["mean_average_precision_at_r"], - self.correct_mean_average_precision_at_r( - ecfss, avg_of_avgs - 
), + self.assertTrue(acc["mean_average_precision"] == zero_acc) + self.assertTrue(acc["mean_reciprocal_rank"] == zero_acc) + else: + self.assertTrue( + isclose( + acc["precision_at_1"], + self.correct_precision_at_1( + ecfss, avg_of_avgs, return_per_class + ), + many=return_per_class, + ) ) - ) - self.assertTrue( - isclose( - acc["mean_average_precision"], - self.correct_mean_average_precision(ecfss, avg_of_avgs), + self.assertTrue( + isclose( + acc["r_precision"], + self.correct_r_precision( + ecfss, avg_of_avgs, return_per_class + ), + many=return_per_class, + ) + ) + self.assertTrue( + isclose( + acc["mean_average_precision_at_r"], + self.correct_mean_average_precision_at_r( + ecfss, avg_of_avgs, return_per_class + ), + many=return_per_class, + ) + ) + self.assertTrue( + isclose( + acc["mean_average_precision"], + self.correct_mean_average_precision( + ecfss, avg_of_avgs, return_per_class + ), + many=return_per_class, + ) + ) + self.assertTrue( + isclose( + acc["mean_reciprocal_rank"], + self.correct_mean_reciprocal_rank( + ecfss, avg_of_avgs, return_per_class + ), + many=return_per_class, + ) ) - ) - def correct_precision_at_1(self, embeddings_come_from_same_source, avg_of_avgs): + def correct_precision_at_1( + self, embeddings_come_from_same_source, avg_of_avgs, return_per_class + ): if not embeddings_come_from_same_source: - if not avg_of_avgs: - return 0.4 - else: - return (0.5 + 0 + 1 + 0) / 4 + accs = [0, 0.5, 0, 1, 0] + if not (avg_of_avgs or return_per_class): + return 2.0 / 6 else: - if not avg_of_avgs: - return 1.0 / 5 - else: - return (0.5 + 0 + 0 + 0) / 4 + accs = [0, 0.5, 0, 0, 0] + if not (avg_of_avgs or return_per_class): + return 1.0 / 6 + + if avg_of_avgs: + return np.mean(accs) + if return_per_class: + return accs - def correct_r_precision(self, embeddings_come_from_same_source, avg_of_avgs): + def correct_r_precision( + self, embeddings_come_from_same_source, avg_of_avgs, return_per_class + ): if not embeddings_come_from_same_source: acc0 = 
2.0 / 3 acc1 = 2.0 / 3 acc2 = 1.0 / 5 acc3 = 2.0 / 4 acc4 = 1.0 / 5 + acc5 = 0 else: acc0 = 1.0 / 1 acc1 = 1.0 / 2 acc2 = 1.0 / 4 acc3 = 1.0 / 3 acc4 = 1.0 / 4 - if not avg_of_avgs: - return np.mean([acc0, acc1, acc2, acc3, acc4]) + acc5 = 0 + accs = [acc5, (acc0 + acc1) / 2, acc2, acc3, acc4] + if avg_of_avgs: + return np.mean(accs) + elif return_per_class: + return accs else: - return np.mean([(acc0 + acc1) / 2, acc2, acc3, acc4]) + return np.mean([acc0, acc1, acc2, acc3, acc4, acc5]) def correct_mean_average_precision_at_r( - self, embeddings_come_from_same_source, avg_of_avgs + self, embeddings_come_from_same_source, avg_of_avgs, return_per_class ): if not embeddings_come_from_same_source: acc0 = (1.0 / 2 + 2.0 / 3) / 3 @@ -135,19 +184,24 @@ def correct_mean_average_precision_at_r( acc2 = (1.0 / 5) / 5 acc3 = (1 + 2.0 / 3) / 4 acc4 = (1.0 / 3) / 5 + acc5 = 0 else: acc0 = 1 acc1 = (1.0 / 2) / 2 acc2 = (1.0 / 4) / 4 acc3 = (1.0 / 2) / 3 acc4 = (1.0 / 2) / 4 - if not avg_of_avgs: - return np.mean([acc0, acc1, acc2, acc3, acc4]) + acc5 = 0 + accs = [acc5, (acc0 + acc1) / 2, acc2, acc3, acc4] + if avg_of_avgs: + return np.mean(accs) + elif return_per_class: + return accs else: - return np.mean([(acc0 + acc1) / 2, acc2, acc3, acc4]) + return np.mean([acc0, acc1, acc2, acc3, acc4, acc5]) def correct_mean_average_precision( - self, embeddings_come_from_same_source, avg_of_avgs + self, embeddings_come_from_same_source, avg_of_avgs, return_per_class ): if not embeddings_come_from_same_source: acc0 = (1.0 / 2 + 2.0 / 3) / 2 @@ -155,16 +209,47 @@ def correct_mean_average_precision( acc2 = (1.0 / 5) / 1 acc3 = (1 + 2.0 / 3 + 3.0 / 5) / 3 acc4 = (1.0 / 3) / 1 + acc5 = 0 else: acc0 = 1 acc1 = (1.0 / 2 + 2.0 / 3) / 2 acc2 = 1.0 / 4 acc3 = (1.0 / 2 + 2.0 / 4) / 2 acc4 = 1.0 / 2 - if not avg_of_avgs: - return np.mean([acc0, acc1, acc2, acc3, acc4]) + acc5 = 0 + accs = [acc5, (acc0 + acc1) / 2, acc2, acc3, acc4] + if avg_of_avgs: + return np.mean(accs) + elif return_per_class: + 
return accs + else: + return np.mean([acc0, acc1, acc2, acc3, acc4, acc5]) + + def correct_mean_reciprocal_rank( + self, embeddings_come_from_same_source, avg_of_avgs, return_per_class + ): + if not embeddings_come_from_same_source: + acc0 = 1 / 2 + acc1 = 1 + acc2 = 1 / 5 + acc3 = 1 + acc4 = 1 / 3 + acc5 = 0 + else: + acc0 = 1 + acc1 = 1 / 2 + acc2 = 1 / 4 + acc3 = 1 / 2 + acc4 = 1 / 2 + acc5 = 0 + + accs = [acc5, (acc0 + acc1) / 2, acc2, acc3, acc4] + if avg_of_avgs: + return np.mean(accs) + elif return_per_class: + return accs else: - return np.mean([(acc0 + acc1) / 2, acc2, acc3, acc4]) + return np.mean([acc0, acc1, acc2, acc3, acc4, acc5]) def test_get_lone_query_labels_custom(self): def fn1(x, y):