This repository has been archived by the owner on Sep 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 124
/
config.py
167 lines (148 loc) · 4.23 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import argparse
# Shared top-level parser; every option group defined below is attached to it.
parser = argparse.ArgumentParser(description="RAM")
# Registry of the argument groups created through add_argument_group().
arg_lists = list()
def str2bool(v):
    """Convert a command-line option value to a boolean.

    Intended as the ``type=`` callable for boolean options, e.g.
    ``--shuffle false``. Matching is case-insensitive and ignores
    surrounding whitespace.

    Args:
        v: The raw option value. A ``bool`` is returned unchanged; any other
            value is converted with ``str()`` before matching.

    Returns:
        True for "true", "1", "yes", "y", "t" or "on"; False for "false",
        "0", "no", "n", "f" or "off".

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean
            spelling. argparse turns this into a normal usage error instead
            of silently treating a typo such as "ture" as False.
    """
    if isinstance(v, bool):
        return v
    value = str(v).strip().lower()
    if value in ("true", "1", "yes", "y", "t", "on"):
        return True
    if value in ("false", "0", "no", "n", "f", "off"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")
def add_argument_group(name):
    """Create a titled option group on the module parser and record it.

    Args:
        name: Title of the group, passed to ``parser.add_argument_group``.

    Returns:
        The newly created argparse group; it is also appended to
        ``arg_lists``.
    """
    arg_lists.append(parser.add_argument_group(name))
    return arg_lists[-1]
# ---- Glimpse network: patch extraction and its hidden layer sizes ----
glimpse_arg = add_argument_group("Glimpse Network Params")
glimpse_arg.add_argument("--patch_size", type=int, default=8,
                         help="size of extracted patch at highest res")
glimpse_arg.add_argument("--glimpse_scale", type=int, default=1,
                         help="scale of successive patches")
glimpse_arg.add_argument("--num_patches", type=int, default=1,
                         help="# of downscaled patches per glimpse")
glimpse_arg.add_argument("--loc_hidden", type=int, default=128,
                         help="hidden size of loc fc")
glimpse_arg.add_argument("--glimpse_hidden", type=int, default=128,
                         help="hidden size of glimpse fc")
# ---- Core (recurrent) network ----
core_arg = add_argument_group("Core Network Params")
core_arg.add_argument("--num_glimpses", type=int, default=6,
                      help="# of glimpses, i.e. BPTT iterations")
core_arg.add_argument("--hidden_size", type=int, default=256,
                      help="hidden size of rnn")
# ---- REINFORCE: Gaussian policy and Monte Carlo sampling ----
reinforce_arg = add_argument_group("Reinforce Params")
reinforce_arg.add_argument("--std", type=float, default=0.05,
                           help="gaussian policy standard deviation")
reinforce_arg.add_argument("--M", type=int, default=1,
                           help="Monte Carlo sampling for valid and test sets")
# ---- Data loading and train/validation split ----
data_arg = add_argument_group("Data Params")
data_arg.add_argument("--valid_size", type=float, default=0.1,
                      help="Proportion of training set used for validation")
data_arg.add_argument("--batch_size", type=int, default=128,
                      help="# of images in each batch of data")
data_arg.add_argument("--num_workers", type=int, default=4,
                      help="# of subprocesses to use for data loading")
# Boolean options take an explicit value parsed by str2bool,
# e.g. "--shuffle false".
data_arg.add_argument("--shuffle", type=str2bool, default=True,
                      help="Whether to shuffle the train and valid indices")
data_arg.add_argument("--show_sample", type=str2bool, default=False,
                      help="Whether to visualize a sample grid of the data")
# ---- Training: mode, optimiser settings and patience thresholds ----
train_arg = add_argument_group("Training Params")
train_arg.add_argument("--is_train", type=str2bool, default=True,
                       help="Whether to train or test the model")
train_arg.add_argument("--momentum", type=float, default=0.5,
                       help="Nesterov momentum value")
train_arg.add_argument("--epochs", type=int, default=200,
                       help="# of epochs to train for")
train_arg.add_argument("--init_lr", type=float, default=3e-4,
                       help="Initial learning rate value")
# Both patience values are counted in epochs.
train_arg.add_argument("--lr_patience", type=int, default=20,
                       help="Number of epochs to wait before reducing lr")
train_arg.add_argument("--train_patience", type=int, default=50,
                       help="Number of epochs to wait before stopping train")
# other params: device, checkpoint/log locations, reproducibility, reporting
misc_arg = add_argument_group("Misc.")
misc_arg.add_argument(
    "--use_gpu", type=str2bool, default=True, help="Whether to run on the GPU"
)
misc_arg.add_argument(
    "--best",
    type=str2bool,
    default=True,
    help="Load best model or most recent for testing",
)
misc_arg.add_argument(
    "--random_seed", type=int, default=1, help="Seed to ensure reproducibility"
)
misc_arg.add_argument(
    "--data_dir", type=str, default="./data", help="Directory in which data is stored"
)
misc_arg.add_argument(
    "--ckpt_dir",
    type=str,
    default="./ckpt",
    help="Directory in which to save model checkpoints",
)
# NOTE(review): unlike --data_dir/--ckpt_dir this default ends with "/";
# left unchanged in case callers build paths by string concatenation -
# confirm before normalizing.
misc_arg.add_argument(
    "--logs_dir",
    type=str,
    default="./logs/",
    help="Directory in which Tensorboard logs will be stored",
)
misc_arg.add_argument(
    "--use_tensorboard",
    type=str2bool,
    default=False,
    help="Whether to use tensorboard for visualization",
)
misc_arg.add_argument(
    "--resume",
    type=str2bool,
    default=False,
    help="Whether to resume training from checkpoint",
)
misc_arg.add_argument(
    "--print_freq",
    type=int,
    default=10,
    help="How frequently to print training details",
)
misc_arg.add_argument(
    "--plot_freq", type=int, default=1, help="How frequently to plot glimpses"
)
def get_config(args=None):
    """Parse command-line options into a config namespace.

    Unrecognised options do not raise; they are returned to the caller
    alongside the parsed namespace (``parse_known_args`` semantics).

    Args:
        args: Optional list of argument strings to parse. The default
            ``None`` makes argparse read ``sys.argv[1:]``, exactly as before.
            Pass an explicit list to build a config programmatically, e.g.
            in tests or notebooks.

    Returns:
        A ``(config, unparsed)`` tuple: the ``argparse.Namespace`` of known
        options and the list of leftover argument strings.
    """
    config, unparsed = parser.parse_known_args(args)
    return config, unparsed