import io

import chess.pgn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from chess import Board

# The 13 square states: white pieces, the empty square, then black pieces,
# ordered so that reversing the string swaps white and black.
PIECE_CHARS = "♔♕♖♗♘♙⭘♟♞♝♜♛♚"

def encode_board(board: Board) -> np.ndarray:
    """Encode a board as an 8x8 grid of indices into PIECE_CHARS."""
    # board.turn is True when white is to move; step flips both the board
    # string and the piece alphabet so positions are always encoded from
    # the side-to-move's perspective.
    step = 1 - 2 * board.turn
    unicode = board.unicode().replace(' ', '').replace('\n', '')[::step]
    return np.array([PIECE_CHARS[::step].index(c) for c in unicode], dtype=int).reshape(8, 8)
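
# Example (illustrative, not part of the original file): the encoder returns an
# 8x8 integer grid, e.g. for the starting position:
#
#     grid = encode_board(Board())
#     grid.shape  # -> (8, 8); each entry indexes into PIECE_CHARS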

class SEBlock(nn.Module):
    """Squeeze-and-Excitation Block"""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pool to one scalar per channel.
        y = self.squeeze(x).view(b, c)
        # Excite: per-channel gates in (0, 1) that rescale the input.
        y = self.excitation(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
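
# Shape check (illustrative, not part of the original file): the SE block
# rescales channels but never changes the tensor shape:
#
#     se = SEBlock(channels=64)
#     se(torch.randn(2, 64, 8, 8)).shape  # -> torch.Size([2, 64, 8, 8])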

class SelfAttention(nn.Module):
    """Self-Attention Module"""

    def __init__(self, channels):
        super().__init__()
        self.query = nn.Conv2d(channels, channels // 8, 1)
        self.key = nn.Conv2d(channels, channels // 8, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        # Learned residual weight, zero-initialized so the block starts as
        # the identity and learns how much attention to mix in.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Flatten the spatial grid to H*W tokens and attend across them.
        proj_query = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        proj_key = self.key(x).view(batch_size, -1, H * W)
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value(x).view(batch_size, -1, H * W)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        return self.gamma * out + x
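
# Note (illustrative, not part of the original file): with gamma starting at
# zero, the block is an identity map at initialization:
#
#     attn = SelfAttention(64)
#     x = torch.randn(2, 64, 8, 8)
#     torch.allclose(attn(x), x)  # -> True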

class Residual(nn.Module):
    """Residual block with SE recalibration and spatial self-attention."""

    def __init__(self, outer_channels, inner_channels, use_1x1conv, dropout):
        super().__init__()
        self.conv1 = nn.Conv2d(outer_channels, inner_channels, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(inner_channels, outer_channels, kernel_size=3, padding=1, stride=1)
        self.se = SEBlock(outer_channels)
        self.attention = SelfAttention(outer_channels)
        # Optional 1x1 conv on the skip path; channel counts already match,
        # so this adds a learned linear map rather than a reshape.
        if use_1x1conv:
            self.conv3 = nn.Conv2d(outer_channels, outer_channels, kernel_size=1, stride=1)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(inner_channels)
        self.bn2 = nn.BatchNorm2d(outer_channels)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, X):
        identity = X
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.dropout(self.bn2(self.conv2(Y)))
        Y = self.se(Y)
        Y = self.attention(Y)
        if self.conv3 is not None:
            identity = self.conv3(identity)
        Y += identity
        return F.relu(Y)
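
# Shape check (illustrative, not part of the original file): the block maps
# (B, outer_channels, H, W) back to the same shape, so any number of blocks
# can be stacked in nn.Sequential:
#
#     block = Residual(outer_channels=64, inner_channels=32, use_1x1conv=True, dropout=0.1)
#     block(torch.randn(2, 64, 8, 8)).shape  # -> torch.Size([2, 64, 8, 8])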

class Model(nn.Module):
    # Note: `device` is accepted but currently unused in this module.
    def __init__(self, nlayers, embed_dim, inner_dim, use_1x1conv, dropout, device='cpu'):
        super().__init__()
        self.vocab = PIECE_CHARS
        self.embed_dim = embed_dim
        self.inner_dim = inner_dim
        self.use_1x1conv = use_1x1conv
        self.dropout = dropout
        # One embedding vector per square state; an (8, 8) board of indices
        # becomes an (8, 8, embed_dim) feature map.
        self.embedder = nn.Embedding(len(self.vocab), self.embed_dim)
        self.convnet = nn.Sequential(*[
            Residual(self.embed_dim, self.inner_dim, self.use_1x1conv, self.dropout)
            for _ in range(nlayers)
        ])
        # An 8x8 kernel with no padding collapses the board to 1x1, pooling
        # the spatial grid into a single embed_dim-wide vector.
        self.accumulator = nn.Conv2d(self.embed_dim, self.embed_dim, kernel_size=8, padding=0, stride=1)
        self.decoder = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim // 2),
            nn.ReLU(),
            nn.Linear(self.embed_dim // 2, 1)
        )
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, inputs):
        # (B, 8, 8) integer boards -> (B, 8, 8, E) embeddings -> channels-first.
        inputs = self.embedder(inputs)
        inputs = torch.permute(inputs, (0, 3, 1, 2)).contiguous()
        inputs = self.convnet(inputs)
        # squeeze() drops the 1x1 spatial dims (and, for a batch of one, the
        # batch dim as well, which the decoder and flatten() still tolerate).
        inputs = F.relu(self.accumulator(inputs).squeeze())
        scores = self.decoder(inputs).flatten()
        return scores
    def score(self, pgn, move):
        """Score the SAN `move` played at the end of the game given by `pgn`."""
        game = chess.pgn.read_game(io.StringIO(pgn))
        board = Board()
        # Replay the mainline, then apply the candidate move.
        for past_move in game.mainline_moves():
            board.push(past_move)
        board.push_san(move)
        # Embedding lookup needs integer indices; cast explicitly to long so
        # the dtype does not depend on the platform's default int size.
        board_tensor = torch.tensor(encode_board(board), dtype=torch.long).unsqueeze(0)
        return self.forward(board_tensor).item()
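

# Usage sketch (illustrative, not part of the original file): the
# hyperparameters and the PGN below are assumptions for demonstration only,
# not values the author prescribes. With randomly initialized weights the
# printed score is meaningless; load trained weights for real use.
if __name__ == "__main__":
    model = Model(nlayers=4, embed_dim=64, inner_dim=32, use_1x1conv=True, dropout=0.1)
    model.eval()  # disable dropout and use BatchNorm running statistics
    with torch.no_grad():
        print(model.score("1. e4 e5 2. Nf3", "Nc6"))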