-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathutil.py
96 lines (72 loc) · 2.27 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import json
import torch
import os
import logging
def save_config(obj, path):
    """Write ``obj.args`` to *path* as JSON, followed by a trailing newline.

    The JSON is indented with a single space per nesting level.

    Args:
        obj: Any object exposing a JSON-serializable ``args`` attribute.
        path: Destination file path; it is created or truncated.
    """
    # ``with`` guarantees the handle is closed even if serialization fails
    # (e.g. a non-JSON-serializable value in obj.args).
    with open(path, 'w') as f:
        json.dump(obj.args, f, indent=' ')
        f.write('\n')
def load_config(Model, path):
    """Build a model from the JSON configuration stored at *path*.

    Args:
        Model: Callable (typically a model class) that accepts the decoded
            JSON value as its single positional argument.
        path: Path of a JSON file.

    Returns:
        Whatever ``Model(config)`` returns.
    """
    # Close the file before constructing the model; the original never
    # closed the handle at all.
    with open(path, 'r') as f:
        config = json.load(f)
    return Model(config)
def save_snapshot(model, ws, id):
    """Save ``model.state_dict()`` to ``<ws>/snapshots/model.<id>``.

    Args:
        model: Object with a ``state_dict()`` method (e.g. ``torch.nn.Module``).
        ws: Workspace directory; its ``snapshots`` subdirectory must exist.
        id: Snapshot identifier, converted with ``str`` for the file suffix.
    """
    filename = os.path.join(ws, 'snapshots', 'model.%s' % str(id))
    # ``with`` closes the handle even if state_dict()/torch.save raises.
    with open(filename, 'wb') as f:
        torch.save(model.state_dict(), f)
def load_snapshot(model, ws, id):
    """Load ``<ws>/snapshots/model.<id>`` into *model* via ``load_state_dict``.

    Args:
        model: Object with a ``load_state_dict()`` method (e.g. ``torch.nn.Module``).
        ws: Workspace directory containing a ``snapshots`` subdirectory.
        id: Snapshot identifier, converted with ``str`` for the file suffix.
    """
    filename = os.path.join(ws, 'snapshots', 'model.%s' % str(id))
    # ``with`` closes the handle even if deserialization fails.
    with open(filename, 'rb') as f:
        # Returning the storage unchanged keeps tensors on the CPU, so a
        # snapshot written on a GPU machine can be loaded without one.
        state = torch.load(f, map_location=lambda s, loc: s)
    model.load_state_dict(state)
def load_last_snapshot(model, ws):
    """Load the highest-numbered snapshot in ``<ws>/snapshots`` into *model*.

    Only files named exactly ``model.<N>``, where ``N`` is a decimal integer,
    are considered. Other files are ignored rather than causing a
    ``ValueError``; examples are ``mymodel.pt``, ``model.best`` and
    ``model.5.bak``. A snapshot numbered 0 is never loaded.

    Args:
        model: Passed through to ``load_snapshot``.
        ws: Workspace directory containing a ``snapshots`` subdirectory.

    Returns:
        The number of the loaded snapshot, or 0 if none was loaded.

    Raises:
        FileNotFoundError: If ``<ws>/snapshots`` does not exist.
    """
    prefix = 'model.'
    epochs = [
        int(name[len(prefix):])
        for name in os.listdir(os.path.join(ws, 'snapshots'))
        # isdecimal (not isdigit) so every accepted suffix is int()-parsable.
        if name.startswith(prefix) and name[len(prefix):].isdecimal()
    ]
    last = max(epochs, default=0)
    if last > 0:
        load_snapshot(model, ws, last)
    return last
def open_result(ws, name, id):
    """Open ``<ws>/results/<name>.<id>`` for writing and return the handle.

    The caller owns the returned file object and must close it.
    """
    result_name = str(name) + '.' + str(id)
    result_path = os.path.join(ws, 'results', result_name)
    return open(result_path, 'w')
# Evaluated once at import time; Variable() consults it on every call.
use_cuda = torch.cuda.is_available()


def Variable(*args, **kwargs):
    """Build a ``torch.autograd.Variable`` and move it to the GPU if one exists.

    All arguments are forwarded unchanged to ``torch.autograd.Variable``.
    """
    wrapped = torch.autograd.Variable(*args, **kwargs)
    return wrapped.cuda() if use_cuda else wrapped
class bcolors:
    """ANSI SGR escape sequences for styling terminal output.

    Put a sequence before the text and ``ENDC`` after it to restore the
    terminal's default style.
    """

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Bright foreground colors.
    HEADER = '\033[95m'   # magenta
    OKBLUE = '\033[94m'   # blue
    OKGREEN = '\033[92m'  # green
    WARNING = '\033[93m'  # yellow
    FAIL = '\033[91m'     # red

    # Reset every attribute.
    ENDC = '\033[0m'
def colored(text, color, bold=False):
    """Return *text* preceded by the escape *color* and followed by a reset.

    When *bold* is true, the bold escape is emitted before *color*.
    """
    head = bcolors.BOLD + color if bold else color
    return head + text + bcolors.ENDC
# Level name -> escape sequence that ColoredFormatter applies to the level tag.
# Note: CRITICAL reuses the WARNING (yellow) sequence rather than FAIL.
LOG_COLORS = dict(
    WARNING=bcolors.WARNING,
    INFO=bcolors.OKGREEN,
    DEBUG=bcolors.OKBLUE,
    CRITICAL=bcolors.WARNING,
    ERROR=bcolors.FAIL,
)
class ColoredFormatter(logging.Formatter):
def __init__(self, msg, datefmt, use_color=True):
logging.Formatter.__init__(self, msg, datefmt=datefmt)
self.use_color = use_color
def format(self, record):
levelname = record.levelname
if self.use_color and levelname in LOG_COLORS:
record.levelname = colored(record.levelname[0],
LOG_COLORS[record.levelname])
return logging.Formatter.format(self, record)