v1.0.0
## Reference embeddings for tuple losses
You can separate the source of anchors and positives/negatives. In the example below, anchors will be selected from `embeddings`, and positives/negatives will be selected from `ref_emb`.
```python
from pytorch_metric_learning.losses import TripletMarginLoss

loss_fn = TripletMarginLoss()
loss = loss_fn(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)
```
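Here is a self-contained version of the above with dummy data (the tensor shapes, embedding dimension, and number of classes are illustrative assumptions, not part of the release):

```python
import torch
from pytorch_metric_learning.losses import TripletMarginLoss

# Illustrative data: 32 anchors, 100 reference embeddings, dim 64, 5 classes.
embeddings = torch.randn(32, 64)
labels = torch.randint(0, 5, (32,))
ref_emb = torch.randn(100, 64)
ref_labels = torch.randint(0, 5, (100,))

loss_fn = TripletMarginLoss()
# Anchors come from `embeddings`; positives/negatives come from `ref_emb`.
loss = loss_fn(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)
```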
## Efficient mode for DistributedLossWrapper
- `efficient=True`: each process uses its own embeddings for anchors, and the gathered embeddings for positives/negatives. Gradients will not be equal to those in non-distributed code, but the benefit is reduced memory and faster training.
- `efficient=False`: each process uses the gathered embeddings for both anchors and positives/negatives. Gradients will be equal to those in non-distributed code, but at the cost of doing unnecessary operations (i.e. computations where both anchors and positives/negatives have no gradient).
The default is `False`. You can set it to `True` like this:
```python
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

loss_func = losses.ContrastiveLoss()
loss_func = pml_dist.DistributedLossWrapper(loss_func, efficient=True)
```
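For context, here is a sketch of where the wrapped loss sits in a typical DDP training step. The `train_step` function and the commented-out model/optimizer setup are illustrative assumptions, not part of the library:

```python
from torch.nn.parallel import DistributedDataParallel as DDP
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

def train_step(model, loss_func, optimizer, data, labels):
    # Each process embeds only its local shard of the batch. With
    # efficient=True, these local embeddings are the anchors, and the
    # wrapper gathers embeddings from all processes to serve as
    # positives/negatives.
    optimizer.zero_grad()
    embeddings = model(data)
    loss = loss_func(embeddings, labels)
    loss.backward()
    optimizer.step()

# Hypothetical setup (assumes torch.distributed.init_process_group has
# already been called and my_model / local_rank are defined):
# model = DDP(my_model.to(local_rank), device_ids=[local_rank])
# loss_func = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss(), efficient=True)
```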
Documentation: https://kevinmusgrave.github.io/pytorch-metric-learning/distributed/
## Customizing k-nearest-neighbors for AccuracyCalculator
You can use a different type of faiss index:
```python
import faiss
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import FaissKNN

knn_func = FaissKNN(index_init_fn=faiss.IndexFlatIP, gpus=[0, 1, 2])
ac = AccuracyCalculator(knn_func=knn_func)
```
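One usage note (a property of faiss, not something specific to this release): `faiss.IndexFlatIP` ranks neighbors by inner product, so if you want cosine similarity you'll typically want to L2-normalize your embeddings first.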
You can also use a custom distance function:
```python
from pytorch_metric_learning.distances import SNRDistance
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import CustomKNN

knn_func = CustomKNN(SNRDistance())
ac = AccuracyCalculator(knn_func=knn_func)
```
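Either customization plugs into the usual accuracy computation. Here is a sketch with random data, assuming the v1.x `get_accuracy` signature (the shapes and label counts are illustrative):

```python
import torch
from pytorch_metric_learning.distances import SNRDistance
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import CustomKNN

ac = AccuracyCalculator(knn_func=CustomKNN(SNRDistance()))

# Illustrative data: 100 queries, 500 references, dim 64, 10 classes.
query = torch.randn(100, 64)
reference = torch.randn(500, 64)
query_labels = torch.randint(0, 10, (100,))
reference_labels = torch.randint(0, 10, (500,))

accuracies = ac.get_accuracy(
    query, reference, query_labels, reference_labels,
    embeddings_come_from_same_source=False,
)
print(accuracies)  # dict of metrics, e.g. {"precision_at_1": ..., ...}
```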
Relevant docs:
- https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/
- https://kevinmusgrave.github.io/pytorch-metric-learning/inference_models/
## Issues resolved
- #204
- #251
- #256
- #292
- #330
- #337
- #345
- #347
- #349
- #353
- #359
- #361
- #362
- #363
- #368
- #376
- #380
## Contributors
Thanks to @yutanakamura-tky and @KinglittleQ for their pull requests, and to @mensaochun for providing helpful code in #380.