-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsampler.py
113 lines (95 loc) · 4.1 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Copyright [2023] [Poutaraud]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Adapted from
https://github.com/sicara/easy-few-shot-learning/tree/master
"""
import random
from typing import List, Tuple, Iterator
import torch
from torch import Tensor
from torch.utils.data import Sampler
class TaskSampler(Sampler):
    """
    Samples batches in the shape of few-shot image classification tasks. At each iteration, it will sample
    n_way classes, and then sample support and query images from these classes.
    """

    def __init__(
        self,
        dataset,
        n_way: int,
        n_shot: int,
        n_query: int,
        n_tasks: int,
    ):
        """
        Args:
            dataset: dataset from which to sample classification tasks. Must expose a
                __getlabel__() method returning the labels of all items in the dataset,
                in item order.
            n_way: number of classes in one task
            n_shot: number of support images for each class in one task
            n_query: number of query images for each class in one task
            n_tasks: number of tasks to sample
        """
        super().__init__(data_source=None)
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_tasks = n_tasks

        # Map each label to the list of dataset indices carrying that label.
        # setdefault replaces the explicit "if label in keys" membership test.
        self.items_per_label = {}
        for item, label in enumerate(dataset.__getlabel__()):
            self.items_per_label.setdefault(label, []).append(item)

    def __len__(self) -> int:
        """Number of episodes (tasks) this sampler yields per epoch."""
        return self.n_tasks

    def __iter__(self) -> Iterator[List[int]]:
        """Yield n_tasks lists of dataset indices, one list per episode.

        Each yielded list holds n_way * (n_shot + n_query) indices, grouped by
        class: all indices sampled for one class are contiguous. episode()
        relies on this grouping when it reshapes the collated batch.
        """
        for _ in range(self.n_tasks):
            yield torch.cat(
                [
                    # pylint: disable=not-callable
                    torch.tensor(
                        random.sample(
                            self.items_per_label[label], self.n_shot + self.n_query
                        )
                    )
                    # pylint: enable=not-callable
                    # BUG FIX: the original passed self.items_per_label.keys()
                    # to random.sample(), which requires a sequence — dict_keys
                    # triggers a DeprecationWarning on Python 3.9/3.10 and a
                    # TypeError on Python 3.11+. sorted() yields a list and
                    # makes class selection deterministic under a fixed seed.
                    for label in random.sample(
                        sorted(self.items_per_label), self.n_way
                    )
                ]
            ).tolist()

    def episode(
        self, input_data: List[Tuple[Tensor, int]]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[int]]:
        """
        Collate function to be used as argument for the collate_fn parameter of episodic data loaders.
        Args:
            input_data: each element is a tuple containing:
                - an image as a torch Tensor
                - the label of this image
        Returns:
            tuple(Tensor, Tensor, Tensor, Tensor, list[int]): respectively:
                - support images,
                - their labels,
                - query images,
                - their labels,
                - the dataset class ids of the class sampled in the episode
        """
        # Dataset-level class ids present in this episode (set iteration order).
        true_class_ids = list({x[1] for x in input_data})
        all_images = torch.cat([x[0].unsqueeze(0) for x in input_data])
        # Regroup as (n_way, n_shot + n_query, *image_shape). Valid because
        # __iter__ yields indices grouped by class in exactly this layout.
        all_images = all_images.reshape(
            (self.n_way, self.n_shot + self.n_query, *all_images.shape[1:])
        )
        # First n_shot images of each class are support, the rest are query.
        support_images = all_images[:, : self.n_shot].reshape((-1, *all_images.shape[2:]))
        query_images = all_images[:, self.n_shot :].reshape((-1, *all_images.shape[2:]))
        # Remap dataset class ids to episode-local labels in [0, n_way).
        all_labels = torch.tensor([true_class_ids.index(x[1]) for x in input_data])
        all_labels = all_labels.reshape((self.n_way, self.n_shot + self.n_query))
        support_labels = all_labels[:, : self.n_shot].flatten()
        query_labels = all_labels[:, self.n_shot :].flatten()
        return (support_images, support_labels, query_images, query_labels, true_class_ids)