-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
30 lines (28 loc) · 1.69 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import argparse
import os
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--quantization_bit', default=4, type=int)
parser.add_argument('--do_train', action='store_true')
parser.add_argument('--train_file', default='我加入的群聊.txt.json', type=str)
parser.add_argument('--prompt_column', default='prompt', type=str)
parser.add_argument('--response_column', default='response', type=str)
parser.add_argument('--history_column', default='history', type=str)
parser.add_argument('--overwrite_cache', action='store_true')
parser.add_argument('--model_name_or_path', default='..\\chatglm-6b', type=str)
parser.add_argument('--output_dir', default='chatglm_qq', type=str)
parser.add_argument('--overwrite_output_dir', action='store_true')
parser.add_argument('--max_source_length', default=128, type=int)
parser.add_argument('--max_target_length', default=128, type=int)
parser.add_argument('--per_device_train_batch_size', default=4, type=int)
parser.add_argument('--per_device_eval_batch_size', default=1, type=int)
parser.add_argument('--gradient_accumulation_steps', default=2, type=int)
parser.add_argument('--predict_with_generate', action='store_true')
parser.add_argument('--max_steps', default=3000, type=int)
parser.add_argument('--logging_steps', default=10, type=int)
parser.add_argument('--save_steps', default=200, type=int)
parser.add_argument('--learning_rate', default=0.002, type=float)
parser.add_argument('--pre_seq_len', default=128, type=int)
args = parser.parse_args()
args = " ".join([f"--{k} {v}" for k, v in vars(args).items()])
os.system(f"python main.py {args}")