forked from sicara/easy-few-shot-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prototypical_networks.py
81 lines (65 loc) · 2.72 KB
/
prototypical_networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
See original implementation (quite far from this one)
at https://github.com/jakesnell/prototypical-networks
"""
import torch
from easyfsl.methods import AbstractMetaLearner
from easyfsl.utils import compute_prototypes
class PrototypicalNetworks(AbstractMetaLearner):
    """
    Jake Snell, Kevin Swersky, and Richard S. Zemel.
    "Prototypical networks for few-shot learning." (2017)
    https://arxiv.org/abs/1703.05175

    Prototypical networks extract feature vectors for both support and query images. They then
    compute the mean of support features for each class (called prototypes), and predict
    classification scores for query images based on their euclidean distance to the prototypes.

    Typical usage is two-step: call process_support_set() once per task to store the
    prototypes, then call forward() on the query images of that task.
    """

    def __init__(self, *args, **kwargs):
        """
        Build Prototypical Networks by calling the constructor of AbstractMetaLearner.

        Args:
            *args: positional arguments forwarded unchanged to AbstractMetaLearner
            **kwargs: keyword arguments forwarded unchanged to AbstractMetaLearner

        Raises:
            ValueError: if the backbone is not a feature extractor,
                i.e. if its output for a given image is not a 1-dim tensor.
        """
        super().__init__(*args, **kwargs)

        # backbone_output_shape is not assigned in this class; it is presumably set by the
        # parent constructor -- TODO confirm against AbstractMetaLearner.
        if len(self.backbone_output_shape) != 1:
            raise ValueError(
                "Illegal backbone for Prototypical Networks. "
                "Expected output for an image is a 1-dim tensor."
            )

        # Here we create the field so that the model can store the prototypes for a support set.
        # It stays None until process_support_set() is called.
        self.prototypes = None

    def process_support_set(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
    ) -> None:
        """
        Overrides process_support_set of AbstractMetaLearner.
        Extract feature vectors from the support set and store class prototypes.

        Any prototypes stored by a previous call are replaced.

        Args:
            support_images: images of the support set
            support_labels: labels of support set images
        """
        # Call the module itself rather than .forward() so that hooks registered on the
        # backbone are executed (assumes the backbone is a torch.nn.Module).
        support_features = self.backbone(support_images)
        self.prototypes = compute_prototypes(support_features, support_labels)

    def forward(
        self,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Overrides forward method of AbstractMetaLearner.
        Predict query labels based on their distance to class prototypes in the feature space.
        Classification scores are the negative of euclidean distances.

        Args:
            query_images: images of the query set

        Returns:
            a prediction of classification scores for query images
            (presumably one row per query and one column per prototype)

        Raises:
            RuntimeError: if process_support_set() has not been called beforehand,
                i.e. if no prototypes are stored.
        """
        # Fail early with an actionable message instead of letting torch.cdist
        # choke on a None argument.
        if self.prototypes is None:
            raise RuntimeError(
                "No prototypes available. "
                "Call process_support_set() before forward()."
            )

        # Extract the features of query images
        z_query = self.backbone(query_images)

        # Compute the euclidean distance from queries to prototypes (torch.cdist defaults to p=2)
        dists = torch.cdist(z_query, self.prototypes)

        # The closer a query is to a prototype, the higher its score for that class
        return -dists