models.py
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from layers import ManifoldMixup
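

# layers.ManifoldMixup is project code and not shown here. A minimal,
# hypothetical layer consistent with how it is called below (two hidden
# states plus a mixing factor, where factor == 1 keeps x1 unchanged) might
# look like this sketch; the actual implementation may differ.
class _ManifoldMixupSketch(nn.Module):
    def forward(self, x1, x2, factor: float = 1.):
        # Convex combination of the two hidden states.
        return factor * x1 + (1 - factor) * x2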


class GenericMixupModel(nn.Module):
    """
    A generic model consisting of an embedding layer, an encoder,
    a pooler, and a classification head.
    """

    def __init__(self, embed: nn.Module, encoder: nn.Module,
                 pooler: nn.Module, cls_head: nn.Module, use_mixup: bool = True):
        super().__init__()
        self.embed = embed
        self.encoder = encoder
        self.pooler = pooler
        self.cls_head = cls_head
        if use_mixup:
            # ManifoldMixup takes (x1, x2, factor), which nn.Sequential
            # cannot forward, so the layer is kept bare and the ReLU is
            # applied in forward().
            self.mixup = ManifoldMixup()
    def forward(self, x1, x1_lens=None, x2=None, x2_lens=None, mixup_factor: float = 1.):
        """
        - x2: optional second example to mix with x1 (requires use_mixup=True)
        - mixup_factor: weight on x1 in the interpolation; 1.0 means no mixup
        """
        x1_embed = self.embed(x1)
        x1_encoded, _ = self.encoder(x1_embed, x1_lens)
        if x2 is not None:
            x2_embed = self.embed(x2)
            x2_encoded, _ = self.encoder(x2_embed, x2_lens)
            # Mix the two encoded hidden states, then apply the non-linearity.
            x_encoded = F.relu(self.mixup(x1_encoded, x2_encoded, mixup_factor))
        else:
            x_encoded = x1_encoded
        x_pooled = self.pooler(x_encoded)
        logits = self.cls_head(x_pooled)
        return logits
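

# Usage sketch (illustrative, not part of the original file): wires
# GenericMixupModel together with minimal stand-in modules. Only the
# interfaces used above are assumed: the encoder returns (output, state)
# and accepts optional lengths, and layers.ManifoldMixup interpolates two
# hidden states given a mixing factor. ToyEncoder, MeanPool, and all sizes
# are hypothetical; this runs only where layers.ManifoldMixup is
# importable, as for the rest of this file.
def _mixup_usage_example():
    class ToyEncoder(nn.Module):
        # Stand-in encoder: a linear projection that ignores lengths.
        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x, lens=None):
            return self.proj(x), None

    class MeanPool(nn.Module):
        # Pool over the sequence dimension: (B, T, D) -> (B, D).
        def forward(self, x):
            return x.mean(dim=1)

    model = GenericMixupModel(
        embed=nn.Embedding(1000, 64),
        encoder=ToyEncoder(64),
        pooler=MeanPool(),
        cls_head=nn.Linear(64, 2),
    )
    x1 = torch.randint(0, 1000, (4, 16))  # (batch, seq_len)
    x2 = torch.randint(0, 1000, (4, 16))
    logits = model(x1, x2=x2, mixup_factor=0.7)  # mixup path
    assert logits.shape == (4, 2)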


class EMAModel:
    """Keeps an exponential moving average (EMA) of a model's parameters."""

    def __init__(self, model):
        self.original = model
        self.model = copy.deepcopy(model)

    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def update_parameters(self, alpha, global_step):
        # Ramp the decay up from 0 early in training so the EMA copy
        # tracks the model closely at first.
        alpha = min(1 - 1 / (global_step + 1), alpha)
        for ema_p, p in zip(self.model.parameters(), self.original.parameters()):
            # ema = alpha * ema + (1 - alpha) * p
            ema_p.data.mul_(alpha).add_(p.data, alpha=1 - alpha)
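

# Usage sketch (illustrative, not part of the original file): keep an EMA
# copy alongside the trained model and update it after every optimizer
# step. alpha is the asymptotic decay rate; early steps use the reduced
# decay 1 - 1/(step + 1) so the averaged weights warm up quickly. The
# model, data, and hyperparameters below are hypothetical.
def _ema_usage_example():
    model = nn.Linear(4, 2)
    ema = EMAModel(model)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for step in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.update_parameters(alpha=0.999, global_step=step)
    # Evaluate with the smoothed weights via __call__.
    logits = ema(torch.randn(8, 4))
    assert logits.shape == (8, 2)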