-
Notifications
You must be signed in to change notification settings - Fork 3
/
arguments.py
117 lines (95 loc) · 4.06 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import os
import torch
import numpy as np
import torch
import random
import re
import yaml
import shutil
import warnings
from datetime import datetime
class Namespace(object):
    """Attribute-style view over a (possibly nested) configuration mapping.

    Every key of ``somedict`` becomes an attribute of the instance. Values that
    are ``dict`` are wrapped recursively, so ``cfg['train']['lr']`` is read as
    ``cfg.train.lr``. All other values (lists included) are stored unchanged.
    Reading an attribute that was never set raises ``AttributeError`` with a
    hint to add it to the YAML config file.
    """

    def __init__(self, somedict):
        """Populate attributes from ``somedict``.

        Args:
            somedict: mapping whose keys become attribute names.

        Raises:
            TypeError: if a key is not a string.
            ValueError: if a key does not start with a letter, ``_`` or ``-``.
        """
        for key, value in somedict.items():
            # Explicit checks rather than ``assert`` so validation still runs
            # under ``python -O``.
            if not isinstance(key, str):
                raise TypeError(f"Config keys must be strings, got {key!r}")
            # re.match anchors only at the start, so just the first character
            # is validated (same rule as before).
            if not re.match("[A-Za-z_-]", key):
                raise ValueError(
                    f"Invalid config key {key!r}: must start with a letter, '_' or '-'"
                )
            if isinstance(value, dict):
                self.__dict__[key] = Namespace(value)
            else:
                self.__dict__[key] = value

    def __getattr__(self, attribute):
        # Only called when normal lookup fails, i.e. the key was not in the config.
        raise AttributeError(f"Can not find {attribute} in namespace. Please write {attribute} in your config file(xxx.yaml)!")

    def __repr__(self):
        # Show the stored settings so printed/logged configs are readable.
        fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
        return f"{type(self).__name__}({fields})"
def set_deterministic(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed, or None (the intended default), in which case
            nothing is seeded and cuDNN settings are left untouched.
    """
    if seed is None:
        return

    print(f"Deterministic with seed = {seed}")
    # Each library keeps its own generator, so every one must be seeded.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Pin cuDNN to deterministic kernels and disable autotuned kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _build_parser():
    """Return the command-line parser used by ``get_args``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml")
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--debug_subset_size', type=int, default=8)
    parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web")
    # Directory defaults come from the DATA / LOG / CHECKPOINT environment variables.
    parser.add_argument('--data_dir', type=str, default=os.getenv('DATA'))
    parser.add_argument('--log_dir', type=str, default=os.getenv('LOG'))
    parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT'))
    parser.add_argument('--ckpt_dir_1', type=str, default=os.getenv('CHECKPOINT'))
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--eval_from', type=str, default=None)
    parser.add_argument('--hide_progress', action='store_true')
    parser.add_argument('--cl_default', action='store_true')
    parser.add_argument('--server', action='store_true')
    parser.add_argument('--hcl', action='store_true')
    parser.add_argument('--buffer_qdi', action='store_true')
    parser.add_argument('--validation', action='store_true',
                        help='Test on the validation set')
    parser.add_argument('--ood_eval', action='store_true',
                        help='Test on the OOD set')
    parser.add_argument('--alpha', type=float, default=0.3)
    return parser


def _apply_debug_overrides(args):
    """Shrink a ``--debug`` run: batch size 2, one epoch, no loader workers."""
    # ``train`` and ``eval`` sections are optional in the config; skip absent ones
    # instead of crashing with AttributeError.
    train_cfg = getattr(args, 'train', None)
    if train_cfg:
        train_cfg.batch_size = 2
        train_cfg.num_epochs = 1
        train_cfg.stop_at_epoch = 1
    eval_cfg = getattr(args, 'eval', None)
    if eval_cfg:
        eval_cfg.batch_size = 2
        eval_cfg.num_epochs = 1
    args.dataset.num_workers = 0


def _prepare_run_dirs(args):
    """Create a fresh timestamped log dir (rebinding ``args.log_dir``), ensure
    ``args.ckpt_dir`` exists, and copy the config file into the log dir."""
    run_name = f"in-progress_{datetime.now().strftime('%m%d%H%M%S_')}{args.name}"
    args.log_dir = os.path.join(args.log_dir, run_name)
    # exist_ok=False: never write into an existing run's directory.
    os.makedirs(args.log_dir, exist_ok=False)
    print(f'creating file {args.log_dir}')
    os.makedirs(args.ckpt_dir, exist_ok=True)
    # Keep a copy of the exact config next to the run's logs.
    shutil.copy2(args.config_file, args.log_dir)


def get_args():
    """Parse CLI options, merge the YAML config and prepare the run.

    Every top-level key of the config file becomes an attribute of the returned
    ``argparse.Namespace``, overriding a same-named CLI option. Nested mappings
    become ``Namespace`` objects. The config must provide ``name``,
    ``model.name``, ``dataset.name``, ``dataset.image_size`` and
    ``dataset.num_workers``. ``seed``, ``train`` and ``eval`` are optional.

    Side effects: creates a timestamped log directory under ``log_dir`` (and
    rebinds ``args.log_dir`` to it), creates ``ckpt_dir``, copies the config
    file into the log directory, and seeds the RNGs when ``seed`` is set.

    Returns:
        The merged namespace, plus derived ``aug_kwargs``, ``dataset_kwargs``
        and ``dataloader_kwargs`` dicts.

    Exits (via ``parser.error``) if ``log_dir``, ``data_dir``, ``ckpt_dir`` or
    ``name`` is unset.
    """
    parser = _build_parser()
    args = parser.parse_args()

    with open(args.config_file, 'r', encoding='utf-8') as f:
        # NOTE(review): FullLoader resolves more tags than SafeLoader; PyYAML
        # recommends yaml.safe_load for untrusted input. Configs are assumed trusted.
        config = yaml.load(f, Loader=yaml.FullLoader)
    if config is None:  # an empty YAML file loads as None
        config = {}
    for key, value in Namespace(config).__dict__.items():
        vars(args)[key] = value

    if args.debug:
        _apply_debug_overrides(args)

    # Explicit check (not ``assert``) so it survives ``python -O`` and also
    # catches a missing ``name`` instead of raising AttributeError.
    missing = [key for key in ('log_dir', 'data_dir', 'ckpt_dir', 'name')
               if getattr(args, key, None) is None]
    if missing:
        parser.error(
            "missing required setting(s): " + ", ".join(missing)
            + " (set log_dir/data_dir/ckpt_dir via --flag, $LOG/$DATA/$CHECKPOINT"
            " or the config file; set name in the config file)"
        )

    _prepare_run_dirs(args)
    # ``seed`` is optional; None means "not deterministic".
    set_deterministic(getattr(args, 'seed', None))

    vars(args)['aug_kwargs'] = {
        'name': args.model.name,
        'image_size': args.dataset.image_size,
        'cl_default': args.cl_default,
    }
    vars(args)['dataset_kwargs'] = {
        'dataset': args.dataset.name,
        'data_dir': args.data_dir,
        'download': args.download,
        # Only subsample the dataset in debug runs.
        'debug_subset_size': args.debug_subset_size if args.debug else None,
    }
    vars(args)['dataloader_kwargs'] = {
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.dataset.num_workers,
    }
    return args