-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpvtv2.py
163 lines (136 loc) · 6.91 KB
/
pvtv2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
class DWConv(nn.Module):
    """Depth-wise 3x3 convolution applied to a flattened token sequence.

    The (B, N, C) tokens are viewed as a (B, C, H, W) grid, convolved
    per-channel (groups == dim, stride 1, padding 1, so H and W are
    preserved) and flattened back to (B, H*W, C).
    """

    def __init__(self, dim):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def forward(self, x, H, W):
        # Requires N == H * W, otherwise the view below fails.
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)
        return self.dwconv(grid).flatten(2).transpose(1, 2)
class Mlp(nn.Module):
    """Token-wise feed-forward: fc1 -> depth-wise conv -> GELU -> fc2.

    Args:
        in_features: token width of the input and of the output.
        hidden_features: width of the intermediate representation.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        # Creation order is kept (fc1, dwconv, fc2) so seeded init is unchanged.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x, H, W):
        # H and W describe the token grid and are forwarded to the depth-wise conv.
        hidden = self.fc1(x)
        hidden = self.dwconv(hidden, H, W)
        return self.fc2(F.gelu(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    Queries come from every token. When ``sr_ratio > 1`` the keys and values are
    computed from a downsampled copy of the token grid (a strided Conv2d followed
    by LayerNorm), which shortens the key/value sequence by ``sr_ratio ** 2``.

    Args:
        dim: token width C; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        sr_ratio: spatial-reduction kernel size and stride; ``<= 1`` disables it.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()
        # An explicit raise (not assert) so the check survives ``python -O``.
        if dim % num_heads != 0:
            raise ValueError(f"dim {dim} should be divided by num_heads {num_heads}.")
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** (-0.5)
        # Creation order is kept (q, kv, proj, sr, norm) so seeded init is unchanged.
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        """Attend over ``x`` of shape (B, N, C), where N == H * W; return (B, N, C)."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)

        # Keys/values are projected from either the full tokens or the reduced grid.
        kv_src = x
        if self.sr_ratio > 1:
            kv_src = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_src = self.sr(kv_src).reshape(B, C, -1).permute(0, 2, 1)
            kv_src = self.norm(kv_src)
        # (2, B, heads, N_kv, head_dim)
        kv = self.kv(kv_src).reshape(B, -1, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
class Block(nn.Module):
    """Pre-norm transformer block: an attention branch then an Mlp branch.

    Each branch is added back residually, after a shared drop-path module.

    Args:
        dim: token width.
        num_heads: attention heads.
        mlp_ratio: Mlp hidden width as a multiple of ``dim``.
        drop_path: stochastic-depth rate; ``0`` uses an identity instead of DropPath.
        sr_ratio: spatial-reduction ratio handed to Attention.
    """

    def __init__(self, dim, num_heads, mlp_ratio, drop_path, sr_ratio):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, num_heads=num_heads, sr_ratio=sr_ratio)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden)

    def forward(self, x, H, W):
        attn_branch = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(mlp_branch)
class OverlapPatchEmbed(nn.Module):
    """Strided Conv2d patch embedding followed by LayerNorm over the tokens.

    Padding is ``patch_size // 2`` on each side, so patches overlap whenever
    ``patch_size > stride``.

    Returns:
        (tokens, H, W): tokens of shape (B, H*W, embed_dim) plus the grid size
        produced by the convolution.
    """

    def __init__(self, patch_size, stride, in_chans, embed_dim):
        super().__init__()
        pad = patch_size // 2
        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size, stride=stride, padding=(pad, pad),
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        feat = self.proj(x)
        _, _, H, W = feat.shape
        tokens = self.norm(feat.flatten(2).transpose(1, 2))
        return tokens, H, W
class PVT(nn.Module):
    """Four-stage Pyramid Vision Transformer v2 backbone.

    Each stage embeds its input with an OverlapPatchEmbed (kernel 7 / stride 4
    for stage 1, kernel 3 / stride 2 afterwards), runs a stack of Blocks,
    applies a LayerNorm and reshapes the tokens back to a (B, C, H, W) map.

    Args:
        embed_dims: channel width of stages 1-4.
        mlp_ratios: Mlp hidden-width multiplier of stages 1-4.
        depths: number of Blocks in stages 1-4.
        snapshot: checkpoint path read by ``initialize``.
        sr_ratios: attention spatial-reduction ratio of stages 1-4.
        num_heads: attention heads of stages 1-4. The default matches the
            values this class previously hard-coded.
        drop_path_rate: stochastic-depth rate of the last Block. Rates rise
            linearly from 0 across all ``sum(depths)`` Blocks.
    """

    def __init__(self, embed_dims, mlp_ratios, depths, snapshot, sr_ratios=(8, 4, 2, 1),
                 num_heads=(1, 2, 5, 8), drop_path_rate=0.1):
        super().__init__()
        self.depths = depths
        self.snapshot = snapshot
        # Patch embeddings. Module creation order (embeds, then block/norm per
        # stage) is unchanged, so seeded init and parameter order are preserved.
        self.patch_embed1 = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
        # Stochastic depth: per-Block drop rates increase linearly from 0 to drop_path_rate.
        dpr = [rate.item() for rate in torch.linspace(0, drop_path_rate, sum(depths))]

        def make_stage(i):
            # Blocks of stage i; they use the slice of dpr that follows earlier stages.
            start = sum(depths[:i])
            return nn.ModuleList([
                Block(dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i],
                      drop_path=dpr[start + j], sr_ratio=sr_ratios[i])
                for j in range(depths[i])
            ])

        # Attribute names block{n}/norm{n} are kept so checkpoint keys still match.
        self.block1 = make_stage(0)
        self.norm1 = nn.LayerNorm(embed_dims[0], eps=1e-6)
        self.block2 = make_stage(1)
        self.norm2 = nn.LayerNorm(embed_dims[1], eps=1e-6)
        self.block3 = make_stage(2)
        self.norm3 = nn.LayerNorm(embed_dims[2], eps=1e-6)
        self.block4 = make_stage(3)
        self.norm4 = nn.LayerNorm(embed_dims[3], eps=1e-6)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Names an optimiser setup may exclude from weight decay. This class
        # defines none of them (no position embeddings or class token); the set
        # is kept unchanged for compatibility with existing callers.
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}

    @staticmethod
    def _forward_stage(x, patch_embed, blocks, norm):
        """Run one stage on ``x`` and return its (B, C, H, W) feature map."""
        tokens, H, W = patch_embed(x)
        for blk in blocks:
            tokens = blk(tokens, H, W)
        B = tokens.shape[0]
        return norm(tokens).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

    def forward(self, x):
        """Return the four stage outputs ``(out1, out2, out3, out4)`` for image batch ``x``."""
        out1 = self._forward_stage(x, self.patch_embed1, self.block1, self.norm1)
        out2 = self._forward_stage(out1, self.patch_embed2, self.block2, self.norm2)
        out3 = self._forward_stage(out2, self.patch_embed3, self.block3, self.norm3)
        out4 = self._forward_stage(out3, self.patch_embed4, self.block4, self.norm4)
        return out1, out2, out3, out4

    def initialize(self):
        """Load weights from ``self.snapshot``; keys that do not match are ignored (strict=False)."""
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors into this module's parameters.
        state_dict = torch.load(self.snapshot, map_location="cpu")
        self.load_state_dict(state_dict, strict=False)
def pvt_v2_b2():
    """Construct the PVTv2-B2 configuration of PVT.

    Weights are not loaded here; call ``initialize()`` on the result to read
    the checkpoint at the configured snapshot path.
    """
    config = dict(
        embed_dims=[64, 128, 320, 512],
        mlp_ratios=[8, 8, 4, 4],
        depths=[3, 4, 6, 3],
        snapshot='./pretrained/pvt_v2_b2.pth',
    )
    return PVT(**config)
if __name__ == "__main__":
    # Shape smoke test: build PVTv2-B2 and print the four stage output shapes.
    model = pvt_v2_b2()
    x = torch.ones((16, 3, 512, 512))  # renamed from `input` to avoid shadowing the builtin
    # No gradients are needed for a shape check, so skip building the autograd graph.
    with torch.no_grad():
        out1, out2, out3, out4 = model(x)
    print(out1.shape, out2.shape, out3.shape, out4.shape)