Skip to content

Commit

Permalink
Merge pull request #408 from KevinMusgrave/dev
Browse files Browse the repository at this point in the history
v1.1
  • Loading branch information
Kevin Musgrave authored Dec 28, 2021
2 parents b895277 + 37b177e commit 429a309
Show file tree
Hide file tree
Showing 18 changed files with 953 additions and 99 deletions.
12 changes: 12 additions & 0 deletions docs/accuracy_calculation.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
AccuracyCalculator(include=(),
exclude=(),
avg_of_avgs=False,
return_per_class=False,
k=None,
label_comparison_fn=None,
device=None,
Expand All @@ -18,6 +19,7 @@ AccuracyCalculator(include=(),
* **include**: Optional. A list or tuple of strings, which are the names of metrics you want to calculate. If left empty, all default metrics will be calculated.
* **exclude**: Optional. A list or tuple of strings, which are the names of metrics you **do not** want to calculate.
* **avg_of_avgs**: If True, the average accuracy per class is computed, and then the average of those averages is returned. This can be useful if your dataset has unbalanced classes. If False, the global average will be returned.
* **return_per_class**: If True, the average accuracy per class is computed and returned.
* **k**: The number of nearest neighbors that will be retrieved for metrics that require k-nearest neighbors. The allowed values are:
* ```None```. This means k will be set to the total number of reference embeddings.
* An integer greater than 0. This means k will be set to the input integer.
Expand Down Expand Up @@ -72,6 +74,12 @@ def get_accuracy(self,
Note that labels can be 2D if a [custom label comparison function](#using-a-custom-label-comparison-function) is used.


### Lone query labels
If some query labels don't appear in the reference set, then it's impossible for those labels to have non-zero k-nn accuracy. Zero accuracy for these labels doesn't indicate anything about the quality of the embedding space. So these lone query labels are excluded from k-nn based accuracy calculations.

For example, if the input ```query_labels``` is ```[0,0,1,1]``` and ```reference_labels``` is ```[1,1,1,2,2]```, then 0 is considered a lone query label.


### CPU/GPU usage

* If you installed ```faiss-cpu``` then the CPU will always be used.
Expand Down Expand Up @@ -100,6 +108,10 @@ If your dataset is large, you might find the k-nn search is very slow. This is b

- [See section 3.2 of A Metric Learning Reality Check](https://arxiv.org/pdf/2003.08505.pdf)

- **mean_reciprocal_rank**:

- [Slides from Stanford](https://web.stanford.edu/class/cs276/handouts/EvaluationNew-handout-1-per.pdf)

- **precision_at_1**:

- Fancy way of saying "is the 1st nearest neighbor correct?"
Expand Down
4 changes: 2 additions & 2 deletions docs/extend/losses.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ from pytorch_metric_learning.losses import BaseMetricLossFunction
import torch

class BarebonesLoss(BaseMetricLossFunction):
def compute_loss(self, embeddings, labels, indices_tuple):
def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
# perform some calculation #
some_loss = torch.mean(embeddings)

Expand All @@ -33,7 +33,7 @@ from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
import torch

class FullFeaturedLoss(BaseMetricLossFunction):
def compute_loss(self, embeddings, labels, indices_tuple):
def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
indices_tuple = lmu.convert_to_triplets(indices_tuple, labels)
anchors, positives, negatives = indices_tuple
if len(anchors) == 0:
Expand Down
Binary file added docs/imgs/vicreg_covariance.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/imgs/vicreg_invariance.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/imgs/vicreg_total.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/imgs/vicreg_variance.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/imgs/vicreg_variance_detail.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
83 changes: 82 additions & 1 deletion docs/losses.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,31 @@ def compute_loss(self, embeddings, labels, indices_tuple=None):
```


## CentroidTripletLoss
[On the Unreasonable Effectiveness of Centroids in Image Retrieval](https://arxiv.org/pdf/2104.13643.pdf){target=_blank}

This is like [TripletMarginLoss](losses.md#tripletmarginloss), except the positives and negatives are class centroids.

```python
losses.CentroidTripletLoss(margin=0.05,
swap=False,
smooth_loss=False,
triplets_per_anchor="all",
**kwargs)
```
**Parameters**:

See [TripletMarginLoss](losses.md#tripletmarginloss)

**Default distance**:

See [TripletMarginLoss](losses.md#tripletmarginloss)

**Default reducer**:

- [AvgNonZeroReducer](reducers.md#avgnonzeroreducer)


## CircleLoss
[Circle Loss: A Unified Perspective of Pair Similarity Optimization](https://arxiv.org/pdf/2002.10857.pdf){target=_blank}

Expand Down Expand Up @@ -1024,4 +1049,60 @@ Extended by:
* [ProxyAnchorLoss](losses.md#proxyanchorloss)
* [ProxyNCALoss](losses.md#proxyncaloss)
* [SoftTripleLoss](losses.md#softtripleloss)
* [SphereFaceLoss](losses.md#spherefaceloss)
* [SphereFaceLoss](losses.md#spherefaceloss)


## VICRegLoss
[VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning](https://arxiv.org/pdf/2105.04906.pdf){target=_blank}
```python
losses.VICRegLoss(invariance_lambda=25,
variance_mu=25,
covariance_v=1,
eps=1e-4,
**kwargs)
```

**Usage**:

Unlike other loss functions, ```VICRegLoss``` does not accept ```labels``` or ```indices_tuple```:

```python
loss_fn = VICRegLoss()
loss = loss_fn(embeddings, ref_emb)
```

**Equations**:

![vicreg_total](imgs/vicreg_total.png){: style="height:40px"}

where

![vicreg_total](imgs/vicreg_invariance.png){: style="height:70px"}

![vicreg_total](imgs/vicreg_variance.png){: style="height:90px"}

![vicreg_total](imgs/vicreg_variance_detail.png){: style="height:40px"}

![vicreg_total](imgs/vicreg_covariance.png){: style="height:70px"}

**Parameters**:

* **invariance_lambda**: The weight of the invariance term.
* **variance_mu**: The weight of the variance term.
* **covariance_v**: The weight of the covariance term.
* **eps**: Small scalar to prevent numerical instability.

**Default distance**:

- Not applicable. You cannot pass in a distance function.

**Default reducer**:

- [MeanReducer](reducers.md#meanreducer)

**Reducer input**:

* **invariance_loss**: The MSE loss between ```embeddings[i]``` and ```ref_emb[i]```. Reduction type is ```"element"```.
* **variance_loss1**: The variance loss for ```embeddings```. Reduction type is ```"element"```.
* **variance_loss2**: The variance loss for ```ref_emb```. Reduction type is ```"element"```.
* **covariance_loss**: The covariance loss. This loss is already reduced to a single value.
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.1.0"
2 changes: 2 additions & 0 deletions src/pytorch_metric_learning/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .angular_loss import AngularLoss
from .arcface_loss import ArcFaceLoss
from .base_metric_loss_function import BaseMetricLossFunction, MultipleLosses
from .centroid_triplet_loss import CentroidTripletLoss
from .circle_loss import CircleLoss
from .contrastive_loss import ContrastiveLoss
from .cosface_loss import CosFaceLoss
Expand All @@ -25,3 +26,4 @@
from .supcon_loss import SupConLoss
from .triplet_margin_loss import TripletMarginLoss
from .tuplet_margin_loss import TupletMarginLoss
from .vicreg_loss import VICRegLoss
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class BaseMetricLossFunction(
EmbeddingRegularizerMixin, ModuleWithRecordsReducerAndDistance
):
def compute_loss(self, embeddings, labels, indices_tuple=None):
def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
"""
This has to be implemented and is what actually computes the loss.
"""
Expand Down
180 changes: 180 additions & 0 deletions src/pytorch_metric_learning/losses/centroid_triplet_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from collections import defaultdict

import numpy as np
import torch

from ..reducers import AvgNonZeroReducer
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction
from .triplet_margin_loss import TripletMarginLoss


def concat_indices_tuple(x):
return [torch.cat(y) for y in zip(*x)]


class CentroidTripletLoss(BaseMetricLossFunction):
def __init__(
self,
margin=0.05,
swap=False,
smooth_loss=False,
triplets_per_anchor="all",
**kwargs
):
super().__init__(**kwargs)
self.triplet_loss = TripletMarginLoss(
margin=margin,
swap=swap,
smooth_loss=smooth_loss,
triplets_per_anchor=triplets_per_anchor,
**kwargs
)

def compute_loss(
self, embeddings, labels, indices_tuple=None, ref_emb=None, ref_labels=None
):
c_f.indices_tuple_not_supported(indices_tuple)
c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
"""
"During training stage each mini-batch contains 𝑃 distinct item
classes with 𝑀 samples per class, resulting in batch size of 𝑃 × 𝑀."
"""
masks, class_masks, labels_list, query_indices = self.create_masks_train(labels)

P = len(labels_list)
M = max([len(instances) for instances in labels_list])
DIM = embeddings.size(-1)

"""
"...each sample from S𝑘 is used as a query 𝑞𝑘 and the rest
𝑀 −1 samples are used to build a prototype centroid"
i.e. for each class k of M items, we make M pairs of (query, centroid),
making a total of P*M total pairs.
masks = (M*P x len(embeddings)) matrix
labels_list[i] = indicies of embeddings belonging to ith class
centroids_emd.shape == (M*P, DIM)
i.e. centroids_emb[0] == centroid vector for 0th class, where the first embedding is the query vector
centroids_emb[1] == centroid vector for 0th class, where the second embedding is the query vector
centroids_emb[M+1] == centroid vector for 1th class, where the first embedding is the query vector
"""

masks_float = masks.type(embeddings.type()).to(embeddings.device)
class_masks_float = class_masks.type(embeddings.type()).to(embeddings.device)
inst_counts = masks_float.sum(-1)
class_inst_counts = class_masks_float.sum(-1)

valid_mask = inst_counts > 0
padded = masks_float.unsqueeze(-1) * embeddings.unsqueeze(0)
class_padded = class_masks_float.unsqueeze(-1) * embeddings.unsqueeze(0)

positive_centroids_emb = padded.sum(-2) / inst_counts.masked_fill(
inst_counts == 0, 1
).unsqueeze(-1)

negative_centroids_emb = class_padded.sum(-2) / class_inst_counts.masked_fill(
class_inst_counts == 0, 1
).unsqueeze(-1)

query_indices = torch.tensor(query_indices).to(embeddings.device)
query_embeddings = embeddings.index_select(0, query_indices)
query_labels = labels.index_select(0, query_indices)
assert positive_centroids_emb.size() == (M * P, DIM)
assert negative_centroids_emb.size() == (P, DIM)
assert query_embeddings.size() == (M * P, DIM)

query_indices = query_indices.view((P, M)).transpose(0, 1)
query_embeddings = query_embeddings.view((P, M, -1)).transpose(0, 1)
query_labels = query_labels.view((P, M)).transpose(0, 1)
positive_centroids_emb = positive_centroids_emb.view((P, M, -1)).transpose(0, 1)
valid_mask = valid_mask.view((P, M)).transpose(0, 1)

labels_collect = []
embeddings_collect = []
tuple_indices_collect = []
starting_idx = 0
for inst_idx in range(M):
one_mask = valid_mask[inst_idx]
if torch.sum(one_mask) > 1:
anchors = query_embeddings[inst_idx][one_mask]
pos_centroids = positive_centroids_emb[inst_idx][one_mask]
one_labels = query_labels[inst_idx][one_mask]

embeddings_concat = torch.cat(
(anchors, pos_centroids, negative_centroids_emb)
)
labels_concat = torch.cat(
(one_labels, one_labels, query_labels[inst_idx])
)
indices_tuple = lmu.get_all_triplets_indices(labels_concat)

"""
Right now indices tuple considers all embeddings in
embeddings_concat as anchors, pos_example, neg_examples.
1. make only query vectors be anchor vectors
2. make pos_centroids be only used as a positive example
3. negative as so
"""
# make only query vectors be anchor vectors
indices_tuple = [x[: len(x) // 3] + starting_idx for x in indices_tuple]

# make only pos_centroids be postive examples
indices_tuple = [x.view(len(one_labels), -1) for x in indices_tuple]
indices_tuple = [x.chunk(2, dim=1)[0] for x in indices_tuple]

# make only neg_centroids be negative examples
indices_tuple = [
x.chunk(len(one_labels), dim=1)[-1].flatten() for x in indices_tuple
]

tuple_indices_collect.append(indices_tuple)
embeddings_collect.append(embeddings_concat)
labels_collect.append(labels_concat)
starting_idx += len(labels_concat)

indices_tuple = concat_indices_tuple(tuple_indices_collect)

if len(indices_tuple) == 0:
return self.zero_losses()

final_embeddings = torch.cat(embeddings_collect)
final_labels = torch.cat(labels_collect)

loss = self.triplet_loss.compute_loss(
final_embeddings, final_labels, indices_tuple, ref_emb=None, ref_labels=None
)
return loss

def create_masks_train(self, class_labels):
labels_dict = defaultdict(list)
class_labels = class_labels.detach().cpu().numpy()
for idx, pid in enumerate(class_labels):
labels_dict[pid].append(idx)

unique_classes = list(labels_dict.keys())
labels_list = list(labels_dict.values())
lens_list = [len(item) for item in labels_list]
lens_list_cs = np.cumsum(lens_list)

M = max(len(instances) for instances in labels_list)
P = len(unique_classes)

query_indices = []
class_masks = torch.zeros((P, len(class_labels)), dtype=bool)
masks = torch.zeros((M * P, len(class_labels)), dtype=bool)
for class_idx, class_insts in enumerate(labels_list):
class_masks[class_idx, class_insts] = 1
for instance_idx in range(M):
matrix_idx = class_idx * M + instance_idx
if instance_idx < len(class_insts):
query_indices.append(class_insts[instance_idx])
ones = class_insts[:instance_idx] + class_insts[instance_idx + 1 :]
masks[matrix_idx, ones] = 1
else:
query_indices.append(0)
return masks, class_masks, labels_list, query_indices

def get_default_reducer(self):
return AvgNonZeroReducer()
Loading

0 comments on commit 429a309

Please sign in to comment.