-
Notifications
You must be signed in to change notification settings - Fork 15
/
main.py
97 lines (83 loc) · 3.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from hydra_moe.utils import AttributeDict
import argparse
import os
import subprocess
import yaml
class Config:
    """Lightweight namespace built from a mapping.

    Every key of *dictionary* is installed as an attribute of the new instance
    (in the mapping's iteration order), holding the corresponding value.
    """

    def __init__(self, dictionary):
        entries = dictionary.items()
        for attr_name, attr_value in entries:
            setattr(self, attr_name, attr_value)
def load_config(config_file):
    """Load a YAML file and wrap its top-level mapping in a ``Config``.

    Args:
        config_file: Path to the YAML configuration file.

    Returns:
        A ``Config`` whose attributes are the file's top-level keys.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid YAML, or its top level is not a
            mapping (e.g. the file is empty or contains a bare list/scalar).
    """
    with open(config_file, "r", encoding="utf-8") as stream:
        try:
            config_dict = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Printing and returning None only deferred the failure to an
            # opaque AttributeError in the caller; fail here with context.
            raise ValueError(f"Invalid YAML in {config_file}: {exc}") from exc
    if not isinstance(config_dict, dict):
        # safe_load returns None for an empty file and a list/scalar for
        # non-mapping documents; neither can be turned into attributes.
        raise ValueError(
            f"Config file {config_file} must contain a YAML mapping at the top level"
        )
    return Config(config_dict)
def _short_name(identifier):
    """Return the final path component of a Hub id or local path.

    ``"org/model"`` -> ``"model"``, ``"model"`` -> ``"model"``,
    ``"./checkpoints/org/model/"`` -> ``"model"``.
    """
    # rstrip so a trailing slash does not yield an empty component.
    return identifier.rstrip("/").split("/")[-1]


def _prepare_config(config_file):
    """Load *config_file* and suffix ``output_dir`` with the model/dataset names.

    Returns:
        ``(config, model_name, dataset_name)``.
    """
    config = load_config(config_file)
    model_name = _short_name(config.model_name_or_path)
    dataset_name = _short_name(config.dataset)
    config.output_dir = f"{config.output_dir}_{model_name}_{dataset_name}"
    return config, model_name, dataset_name


def _build_command(script, config):
    """Turn every attribute of *config* into ``--key value`` arguments for *script*.

    An argument list (rather than a shell string) keeps values that contain
    spaces or shell metacharacters intact and avoids shell injection.
    """
    command = ["python", script]
    for key, value in vars(config).items():
        command += [f"--{key}", str(value)]
    return command


def _run_script(script, config):
    """Print, then run, ``python <script>`` with *config* as CLI flags.

    Returns:
        The ``subprocess.CompletedProcess`` of the child run (exit status is
        not checked, matching the previous fire-and-continue behaviour).
    """
    command = _build_command(script, config)
    print(f"Command:\n{command}")
    return subprocess.run(command)


def inference_runner(config_file):
    """Launch ``inference.py`` with the settings loaded from *config_file*."""
    config, _, _ = _prepare_config(config_file)
    return _run_script("inference.py", config)


def finetuner_runner(config_file):
    """Launch ``finetuner.py`` with the settings loaded from *config_file*.

    Besides ``output_dir``, ``hub_model_id`` is also suffixed with
    ``/<model>_<dataset>``.
    """
    config, model_name, dataset_name = _prepare_config(config_file)
    config.hub_model_id = f"{config.hub_model_id}/{model_name}_{dataset_name}"
    return _run_script("finetuner.py", config)


def webui_runner(config_file):
    """Launch ``server.py`` with the settings loaded from *config_file*."""
    config, _, _ = _prepare_config(config_file)
    return _run_script("server.py", config)
def main(argv=None):
    """Parse the command line and run each requested stage in order.

    Stages run in the fixed order finetune -> inference -> webui. When
    ``--config`` is omitted, each stage falls back to its own default config
    file rather than every stage reusing the first stage's default.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Raises:
        SystemExit: Via ``parser.error`` when no stage flag is given.
    """
    parser = argparse.ArgumentParser(description="MoE")
    parser.add_argument("--finetune", action="store_true", help="Finetune? T/F")
    parser.add_argument("--inference", action="store_true", help="Inference? T/F")
    parser.add_argument("--webui", action="store_true", help="Webui? T/F")
    parser.add_argument(
        "--config", type=str, required=False, help="Path to YAML config file"
    )
    args = parser.parse_args(argv)

    if not (args.finetune or args.inference or args.webui):
        # Previously the script exited silently having done nothing.
        parser.error("at least one of --finetune, --inference or --webui is required")

    # (requested?, runner, default config) in execution order.
    stages = (
        (args.finetune, finetuner_runner, "configs/default_ft_config.yaml"),
        (args.inference, inference_runner, "configs/inference_config.yaml"),
        (args.webui, webui_runner, "configs/inference_config.yaml"),
    )
    for requested, runner, default_config in stages:
        if requested:
            runner(args.config or default_config)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()