-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo.py
93 lines (77 loc) · 3.49 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from core import *
from torch_backend import *
def Computational(model, device=0, input_res=(3, 32, 32)):
    """Print the FLOPs and parameter count of `model` via ptflops.

    Args:
        model: a torch.nn.Module to analyse.
        device: CUDA device index to run the analysis on (default 0;
            previously hard-coded).
        input_res: (C, H, W) input resolution fed to the model
            (default is the CIFAR-10 shape, previously hard-coded).
    """
    import torch
    from ptflops import get_model_complexity_info
    with torch.cuda.device(device):
        flops, params = get_model_complexity_info(
            model, input_res, as_strings=True, print_per_layer_stat=True)
        print('{:<30} {:<8}'.format('Computational complexity: ', flops))
        print('{:<30} {:<8}'.format('Number of parameters: ', params))
def conv_bn(c_in, c_out):
    """Return a conv->batchnorm->relu stage as a named-layer dict.

    The convolution is 3x3, stride 1, padding 1, bias-free (the following
    BatchNorm supplies the affine shift).
    """
    convolution = nn.Conv2d(c_in, c_out, kernel_size=3, stride=1,
                            padding=1, bias=False)
    return {'conv': convolution, 'bn': BatchNorm(c_out), 'relu': nn.ReLU(True)}
def residual(c):
    """Return a residual branch: identity skip + two conv_bn stages, summed.

    The 'add' node merges the saved input ('in') with the output of the
    second conv stage ('res2/relu').
    """
    branch = {'in': Identity()}
    branch['res1'] = conv_bn(c, c)
    branch['res2'] = conv_bn(c, c)
    branch['add'] = (Add(), ['in', 'res2/relu'])
    return branch
def net(channels=None, weight=0.125, pool=None, extra_layers=(), res_layers=('layer1', 'layer3')):
    """Build the CIFAR-10 network as a nested layer-graph dict.

    Args:
        channels: per-stage widths; defaults to
            {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}.
        weight: scale applied to the final logits by the Mul node.
        pool: pooling module appended to each conv stage. Defaults to
            nn.MaxPool2d(2), created per call — the previous
            `pool=nn.MaxPool2d(2)` default built one module instance at
            definition time and shared it across every call (mutable
            default-argument pitfall).
        extra_layers: stage names that receive an extra conv_bn block.
        res_layers: stage names that receive a residual branch.

    Returns:
        A dict graph consumable by Network()/build_graph().
    """
    if pool is None:
        pool = nn.MaxPool2d(2)
    channels = channels or {'prep': 64, 'layer1': 128, 'layer2': 256, 'layer3': 512}
    n = {
        'input': (None, []),
        'prep': conv_bn(3, channels['prep']),
        'layer1': dict(conv_bn(channels['prep'], channels['layer1']), pool=pool),
        'layer2': dict(conv_bn(channels['layer1'], channels['layer2']), pool=pool),
        'layer3': dict(conv_bn(channels['layer2'], channels['layer3']), pool=pool),
        'pool': nn.MaxPool2d(4),
        'flatten': Flatten(),
        'linear': nn.Linear(channels['layer3'], 10, bias=False),
        'logits': Mul(weight),
    }
    for layer in res_layers:
        n[layer]['residual'] = residual(channels[layer])
    for layer in extra_layers:
        n[layer]['extra'] = conv_bn(channels[layer], channels[layer])
    return n
# ---- Data preparation -------------------------------------------------
DATA_DIR = './data'
dataset = cifar10(root=DATA_DIR)

timer = Timer()

# Shared normalise/transpose pipeline; training additionally pads by 4 px
# so random 32x32 crops can shift the image.
normalise_fn = partial(normalise, mean=np.array(cifar10_mean, dtype=np.float32),
                       std=np.array(cifar10_std, dtype=np.float32))
transpose_fn = partial(transpose, source='NHWC', target='NCHW')
transforms = [normalise_fn, transpose_fn]

print('Preprocessing training data')
train_set = list(zip(*preprocess(dataset['train'], [partial(pad, border=4)] + transforms).values()))
print(f'Finished in {timer():.2} seconds')

print('Preprocessing test data')
valid_set = list(zip(*preprocess(dataset['valid'], transforms).values()))
print(f'Finished in {timer():.2} seconds')

# Optional graph visualisation (kept for reference):
# colors = ColorMap()
# draw = lambda graph: DotGraph({p: ({'fillcolor': colors[type(v)], 'tooltip': repr(v)}, inputs) for p, (v, inputs) in graph.items() if v is not None})
# draw(build_graph(net()))

# ---- Training configuration ------------------------------------------
batch_size = 512
epochs=24
N_runs = 3
lr_schedule = PiecewiseLinear([0, 5, epochs], [0, 0.4, 0])
train_transforms = [Crop(32, 32), FlipLR(), Cutout(8, 8)]

train_batches = DataLoader(Transform(train_set, train_transforms), batch_size,
                           shuffle=True, set_random_choices=True, drop_last=True)
valid_batches = DataLoader(valid_set, batch_size, shuffle=False, drop_last=False)

# Per-step LR from the piecewise schedule, divided by batch size because the
# optimizer update sums gradients over the batch.
lr = lambda step: lr_schedule(step/len(train_batches))/batch_size

# Optional FLOPs/params report:
# Computational(Network(net()))

# ---- Training runs ----------------------------------------------------
summaries = []
for i in range(N_runs):
    print(f'Starting Run {i} at {localtime()}')
    model = Network(net()).to(device).half()
    opts = [SGD(trainable_params(model).values(),
                {'lr': lr, 'weight_decay': Const(5e-4*batch_size), 'momentum': Const(0.9)})]
    logs, state = Table(), {MODEL: model, LOSS: x_ent_loss, OPTS: opts}
    for epoch in range(epochs):
        logs.append(union({'epoch': epoch+1},
                          train_epoch(state, Timer(torch.cuda.synchronize), train_batches, valid_batches)))
# Summarise final-epoch accuracy of the last run.
# NOTE(review): scrape lost indentation — this line may originally have been
# inside the run loop (e.g. feeding `summaries`); confirm against upstream.
logs.df().query(f'epoch=={epochs}')[['train_acc', 'valid_acc']].describe()