-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathargs.py
159 lines (155 loc) · 4.16 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import argparse
import os
import torch
def parse_arguments(argv=None):
    """Build the command-line parser and return the parsed options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            this function always did before the parameter existed.
            Passing an explicit list lets the function be called from
            tests or other entry points without touching ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options, with two adjustments made
        after parsing:

        * ``device`` is added and set to ``"cuda"`` when
          ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
        * ``load`` is unwrapped from a one-element list to the bare string
          when exactly one comma-separated entry was supplied; with
          several entries it stays a list, and it is ``None`` when the
          flag is absent.

    Raises:
        SystemExit: raised by argparse on unknown or malformed arguments
            (and on ``-h/--help``).
    """

    def comma_list(value):
        # Shared converter for the options that take "a,b,c" lists.
        return value.split(",")

    parser = argparse.ArgumentParser()

    # --- datasets and bookkeeping -------------------------------------
    parser.add_argument(
        "--data-location",
        type=str,
        default=os.path.expanduser("/mnt/data/dataset"),
        help="The root directory for the datasets.",
    )
    parser.add_argument(
        "--eval-datasets",
        default=None,
        type=comma_list,
        help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. ",
    )
    parser.add_argument(
        "--train-dataset",
        default=None,
        type=comma_list,
        help="Which dataset(s) to patch on.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Name of the experiment, for organization purposes only.",
    )
    parser.add_argument(
        "--results-db",
        type=str,
        default=None,
        help="Where to store the results, else does not store",
    )

    # --- model ----------------------------------------------------------
    parser.add_argument(
        "--model",
        type=str,
        default="ViT-B-32",
        help="The type of model (e.g. RN50, ViT-B-32).",
    )

    # --- options without help text in the original definition ----------
    parser.add_argument("--task_to_orth", type=str, default="DTD")
    parser.add_argument("--penalty", type=float, default=.1)
    parser.add_argument("--penalty_iter", type=int, default=-1)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--orth_batch_size", type=int, default=64)

    # --- optimisation ---------------------------------------------------
    parser.add_argument(
        "--num-grad-accumulation",
        type=int,
        default=1,
        help="Number of gradient accumulation steps.",
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--ls", type=float, default=0.0, help="Label smoothing.")
    parser.add_argument("--warmup_length", type=int, default=500)
    parser.add_argument("--epochs", type=int, default=10)

    # --- loading, saving and caching -----------------------------------
    parser.add_argument(
        "--load",
        type=comma_list,
        default=None,
        help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",  # noqa: E501
    )
    parser.add_argument(
        "--save",
        type=str,
        default=None,
        help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory for caching features and encoder",
    )
    parser.add_argument(
        "--openclip-cachedir",
        type=str,
        default=os.path.expanduser("~/openclip-cachedir/open_clip"),
        help="Directory for caching models from OpenCLIP",
    )

    # --- runtime / distributed -----------------------------------------
    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of processes for distributed training.",
    )
    parser.add_argument(
        "--checkpoint_every",
        type=int,
        default=-1,
        help="How often to checkpoint the model.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=12355,
        help="Port for distributed training.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed.",
    )
    parser.add_argument(
        "--device_number",
        type=int,
        default=0,
        help="device_number",
    )

    # --- fine-tuning / evaluation --------------------------------------
    parser.add_argument(
        "--finetuning-mode",
        choices=["standard", "linear", "posthoc", "none"],
        help="Whether to use linearized models or not.",
    )
    parser.add_argument(
        "--n-eval-points",
        type=int,
        default=21,
        help="Number of evaluation points used to find optimal coefficient in task arithmetic.",
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    parsed_args = parser.parse_args(argv)

    # The device is derived from the runtime, not from a CLI flag.
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # A single --load entry is exposed as a plain string rather than a
    # one-element list.
    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]

    return parsed_args