forked from misads/AliProducts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path grid_search.py
99 lines (65 loc) · 2.08 KB
/
grid_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# encoding = utf-8
"""
借鉴自NVIDIA的深度学习实验管理工具runx(https://github.com/NVIDIA/runx)。自己重写了一下。
!!该程序会将yml配置文件中的所有命令执行一遍。!!
用法:
python grid_search.py --sweep sweep.yml --show # 默认
python grid_search.py --sweep sweep.yml --run
"""
import argparse
import misc_utils as utils
import random
import string
import yaml
import os
def load_yml(file='sweep.yml', op=None):
    """Load and parse a YAML configuration file.

    Args:
        file: Path to the YAML file. Defaults to ``'sweep.yml'``.
        op: Optional top-level key. When truthy, only ``cfg[op]`` is returned;
            otherwise the whole parsed document is returned.

    Returns:
        The document parsed by ``yaml.safe_load`` (or its ``op`` entry).

    Raises:
        FileNotFoundError: If ``file`` does not exist or is not a regular file.
        Exception: If the content is not valid YAML; the underlying
            ``yaml.YAMLError`` is chained as ``__cause__`` so its position
            information is not lost.
        KeyError: If ``op`` is given but is not a key of the parsed document.
    """
    if not os.path.isfile(file):
        raise FileNotFoundError('File "%s" not found' % file)
    # Decode explicitly as UTF-8 (the YAML default) instead of the locale's
    # encoding, which may not be UTF-8 (e.g. GBK on a Chinese Windows install)
    # and would then fail on files containing non-ASCII text.
    with open(file, 'r', encoding='utf-8') as f:
        try:
            cfg = yaml.safe_load(f)
        except yaml.YAMLError as err:
            raise Exception('Error parsing YAML file: %s' % (file,)) from err
    if op:
        return cfg[op]
    return cfg
def hash(n):
    """Build a random string of ``n`` lowercase hex digits.

    Used as a short, probably-unique tag for each run. Returns ``''`` when
    ``n <= 0``. Note: this module-level name shadows the builtin ``hash``
    throughout this file; it is kept because callers use it by this name.
    """
    alphabet = '0123456789abcdef'
    return ''.join(random.choice(alphabet) for _ in range(n))
def parse_args(argv=None):
    """Parse the command-line options of the grid-search script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace`` with:
          * ``sweep`` (str): path of the sweep YAML file (default ``'sweep.yml'``);
          * ``show`` (bool): always ``True``. The flag uses ``store_true`` with
            ``default=True``, so passing it changes nothing; it only exists so
            ``--show`` can be written explicitly on the command line;
          * ``run`` (bool): whether the expanded commands are actually executed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sweep', type=str, default='sweep.yml',
                        help='configure file, default is "sweep.yml"')
    parser.add_argument('--show', action='store_true', default=True,
                        help='print the expanded commands (always on; this is the default mode)')
    parser.add_argument('--run', action='store_true',
                        help='also execute every expanded command')
    return parser.parse_args(argv)
# Parsed at module level (not only under __main__) so `opt` remains a
# module attribute, exactly as before.
opt = parse_args()

if __name__ == '__main__':
    # Expand every combination of the sweep's hyper-parameters into one
    # command line each; always print them, and execute them only with --run.
    import itertools  # only the script entry point needs it

    cfg = load_yml(opt.sweep)
    cmd = cfg['cmd']
    # Ordered (name, choices) pairs in the YAML mapping's order. A missing or
    # empty `hparams:` section yields a single run of the bare command.
    hparams = list((cfg.get('hparams') or {}).items())
    names = [name for name, _ in hparams]
    # A scalar value such as `lr: 0.01` or `model: resnet` is a single choice;
    # iterating it directly would split a string into characters or raise on
    # a number.
    choice_lists = [[choices] if isinstance(choices, (str, int, float)) else choices
                    for _, choices in hparams]
    # Cartesian product of all choice lists: the first hyper-parameter varies
    # slowest and the last one fastest.
    runs = list(itertools.product(*choice_lists))

    for i, one_run in enumerate(runs):
        # Every run gets a fresh random 8-hex-digit tag.
        command = cmd + ' --tag %s' % hash(8)
        for name, choice in zip(names, one_run):
            command += ' --%s %s' % (name, choice)
        utils.color_print(('%d/%d: ' % ((i + 1), len(runs)) + command), 4)
        if opt.run:
            # NOTE(review): the command string is handed to the shell unquoted,
            # so the sweep file must come from a trusted source.
            status = os.system(command)
            if status != 0:
                print('Command exited with non-zero status %d: %s' % (status, command))