# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import json
import logging
import os
import lm_eval
import torch
import wandb
from lm_eval import tasks
from lm_eval import utils as lm_eval_utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.models.huggingface import HFLM
from lm_eval.tasks import initialize_tasks
from slicegpt import gpu_utils, hf_utils, utils
from slicegpt.config import config
utils.configure_logging()
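# Give the wandb background service extra time (in seconds) to start before timing out;
# this helps on slow or heavily loaded machines.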
os.environ["WANDB__SERVICE_WAIT"] = "300"
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/opt-125m",
        help="Model to load",
    )
    path_group = parser.add_mutually_exclusive_group()
    path_group.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="Path to load the model and tokenizer from (required for local models, not required for HF models)",
    )
    path_group.add_argument(
        "--sliced-model-path",
        type=str,
        help="Path to load the model to fine-tune (sliced) and tokenizer from",
        default=None,
    )
    parser.add_argument(
        "--sparsity", type=float, default=0.0, help="A measure of how much slicing is applied (in the range [0, 1))"
    )
    parser.add_argument(
        "--round-interval",
        type=int,
        default=8,
        help="Interval for rounding the weights (the best value may depend on your hardware)",
    )
    parser.add_argument('--hf-token', type=str, default=os.getenv('HF_TOKEN', None))
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for evaluating with lm eval harness.")
    parser.add_argument(
        "--distribute-model",
        action="store_true",
        help="Use accelerate to put the model on multiple GPUs for evaluation. It is recommended to use it for models with 30B parameters and above.",
    )
    parser.add_argument('--wandb-project', type=str, default="slicegpt-lm-eval", help="wandb project name.")
    parser.add_argument('--no-wandb', action="store_true", help="Disable wandb.")
    parser.add_argument(
        '--tasks',
        nargs='+',
        default=["piqa", "hellaswag", "arc_easy", "arc_challenge", "winogrande"],
        choices=lm_eval_utils.MultiChoice(tasks.ALL_TASKS),
    )
    parser.add_argument('--num-fewshot', type=int, default=0, help="Number of fewshots for all tasks.")
    return parser.parse_args()
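# Illustrative invocations (model names and paths below are placeholders, not requirements of this script):
#   python run_zero_shot_tasks.py --model facebook/opt-125m --tasks piqa arc_easy --batch-size 8
#   python run_zero_shot_tasks.py --model meta-llama/Llama-2-7b-hf \
#       --sliced-model-path /path/to/sliced-model --sparsity 0.25 --no-wandb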
def main() -> None:
    logging.info("Running SliceGPT zeroshot tasks experiment.")
    initialize_tasks()
    args = parse_args()

    logging.info(f"PyTorch device: {config.device}")
    logging.info(f"Number of available cuda devices: {torch.cuda.device_count()}")

    try:
        wandb.init(project=args.wandb_project, config=args, mode='disabled' if args.no_wandb else None)
    except wandb.UsageError as e:
        # wandb.init will throw an error if the user is not logged in and the process is running in a non-shell
        # environment, e.g. notebook, IDE, no-shell process, etc. In this case, we want to continue without wandb.
        logging.info(f'Failed to initialize wandb: {e}, continuing without wandb')
        wandb.init(project=args.wandb_project, mode='disabled')
    if args.sliced_model_path:
        # load the sliced model
        logging.info(f"Loading sliced {args.model} model from {args.sliced_model_path} with sparsity {args.sparsity}")
        model_adapter, tokenizer = hf_utils.load_sliced_model(
            args.model,
            args.sliced_model_path,
            sparsity=args.sparsity,
            token=args.hf_token,
            round_interval=args.round_interval,
        )
    else:
        # load the original model
        logging.info(f"Loading {args.model} model")
        model_adapter, tokenizer = hf_utils.get_model_and_tokenizer(args.model, args.model_path, token=args.hf_token)
    # the lm eval harness ties the weights, but this should not be done for sliced models unless the lm_head was sliced
    model_adapter.model.tie_weights = lambda: None

    if args.distribute_model:
        # distribute model across available GPUs
        gpu_utils.distribute_model(model_adapter)
    else:
        model_adapter.model.to(config.device)

    ### LM Eval Harness ###
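    # Passing the already-loaded (and possibly sliced) model to HFLM lets the harness evaluate it
    # in place rather than loading a fresh copy from the hub; batch_size is the harness-side batch size.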
    hflm = HFLM(pretrained=model_adapter.model, tokenizer=tokenizer, batch_size=args.batch_size)
    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        task_names = lm_eval_utils.pattern_match(args.tasks, ALL_TASKS)
    logging.info(f"Selected Tasks: {task_names}")

    results = lm_eval.simple_evaluate(hflm, tasks=task_names, num_fewshot=args.num_fewshot, batch_size=args.batch_size)[
        'results'
    ]
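    # Illustrative shape of the `results` dict consumed below (task names and values are made up):
    #   {"piqa": {"acc,none": 0.79, "acc_norm,none": 0.80, ...},
    #    "winogrande": {"acc,none": 0.69, ...}}   # some tasks report no acc_norm
    # acc_norm is preferred where available, otherwise plain accuracy is used.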
    wandb.log(results)

    metric_vals = {task: round(result.get('acc_norm,none', result['acc,none']), 4) for task, result in results.items()}
    logging.info(json.dumps(metric_vals, indent=4))
    def calculate_avg_accuracy(task_names, results):
        n_tasks = len(task_names)
        acc_cumul = sum(
            result.get('acc_norm,none', result['acc,none']) for task, result in results.items() if 'mmlu' not in task
        )

        questions_per_mmlu_task = {
            task_name: lm_eval.tasks.get_task_dict([task_name])[task_name].dataset["test"].num_rows
            for task_name in task_names
            if 'mmlu' in task_name
        }

        if not questions_per_mmlu_task:
            return acc_cumul / n_tasks

        # Calculate average accuracy for mmlu tasks, weighted by number of questions in each task
        acc_mmlu = sum(
            result.get('acc_norm,none', result['acc,none']) * questions_per_mmlu_task[task]
            for task, result in results.items()
            if 'mmlu' in task
        )
        acc_mmlu_avg = acc_mmlu / sum(questions_per_mmlu_task.values())
        wandb.log({'acc_mmlu_avg': acc_mmlu_avg})

        return (acc_cumul + acc_mmlu_avg) / (n_tasks - len(questions_per_mmlu_task) + 1)
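    # Worked example with made-up numbers: tasks = piqa (0.79), arc_easy (0.75),
    # mmlu_abstract_algebra (0.30, 100 questions), mmlu_anatomy (0.45, 135 questions):
    #   acc_cumul    = 0.79 + 0.75 = 1.54
    #   acc_mmlu_avg = (0.30 * 100 + 0.45 * 135) / 235 ≈ 0.386
    #   average      = (1.54 + 0.386) / (4 - 2 + 1) ≈ 0.642
    # i.e. all mmlu subtasks together count as a single task in the overall average.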
    acc_avg = calculate_avg_accuracy(task_names, results)
    wandb.log({'acc_avg': acc_avg})
    logging.info(f"Average accuracy across tasks: {acc_avg}")
if __name__ == "__main__":
    main()