-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_emotrans.py
96 lines (83 loc) · 3.35 KB
/
run_emotrans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
## cuda environment
# The environment variables below are set before the project imports further
# down, so they are already in place when those modules are loaded.
import warnings, os, wandb, yaml, sys
warnings.filterwarnings("ignore")  # suppress all Python warnings for this process
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # expose only GPU index 0 to CUDA
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # synchronous CUDA kernel launches (easier debugging, slower execution)
os.environ['TOKENIZERS_PARALLELISM']='false'  # presumably disables Hugging Face tokenizers parallelism -- TODO confirm
## import packages
from global_var import *  # NOTE(review): `utils_dir` used below is presumably provided by this star import -- confirm
sys.path.append(utils_dir)  # add utils_dir to the module search path
# NOTE(review): the modules imported below presumably live under utils_dir -- confirm
from config import config
from writer import JsonFile
from processor import Processor
from utils_processor import set_rng_seed
def run(args):
    """Run one training or evaluation pass and summarize it as a record.

    Args:
        args: Configuration object exposing ``train`` and ``model`` mappings.
            Reads ``train['wandb']``, ``train['tasks']``, ``train['seed']``,
            ``train['inference']``, ``train['epochs']``,
            ``train['early_stop']``, ``train['learning_rate']``,
            ``train['learning_rate_pre']``, ``train['batch_size']`` and
            ``model['drop_rate']``.

    Returns:
        dict: ``{'params': {...}, 'metric': {...}}`` where ``params`` holds
        the hyper-parameters of this run and ``metric`` holds the stopping
        epoch plus the train/valid/test ``f1`` values taken from the result
        returned by the processor.
    """
    # Optionally open a Weights & Biases run named after the tasks and seed.
    if args.train['wandb']:
        project_name = f"project: {'-'.join(args.train['tasks'])}"
        run_name = f"{'-'.join(args.train['tasks'])}-seed-{args.train['seed']}"
        wandb.init(project=project_name, name=run_name)

    # Fix the random seed before the model and dataset are built.
    set_rng_seed(args.train['seed'])

    # Import lazily, then build the model and dataset from the configuration.
    from Model_EmoTrans import import_model
    model, dataset = import_model(args)

    # Inference mode: restore saved state and evaluate on the test stage.
    # Otherwise: train.
    processor = Processor(args, model, dataset)
    if args.train['inference']:
        processor.loadState()
        result = processor._evaluate(stage='test')
    else:
        result = processor._train()

    if args.train['wandb']:
        wandb.finish()

    ## 2. output results
    train_cfg = args.train
    hyper_params = {
        'e': train_cfg['epochs'],
        'es': train_cfg['early_stop'],
        'lr': train_cfg['learning_rate'],
        'lr_pre': train_cfg['learning_rate_pre'],
        'bz': train_cfg['batch_size'],
        'dr': args.model['drop_rate'],
        'seed': train_cfg['seed'],
        # 'ekl': args.model['ekl'],
    }
    scores = {
        'stop': result['valid']['epoch'],
        'tr_mf1': result['train']['f1'],
        'tv_mf1': result['valid']['f1'],
        'te_mf1': result['test']['f1'],
    }
    return {'params': hyper_params, 'metric': scores}
if __name__ == '__main__':
    # Build the base configuration object for this script.
    args = config(task='', dataset='meld', framework=None, model='emotrans')

    ## Load the model-specific YAML file and overlay it on the base settings.
    # The encoding is given explicitly: without it open() decodes with the
    # platform's locale encoding, which can fail or mis-decode non-ASCII
    # content in the config file.
    config_path = f"./configs/{args.model['name']}.yaml"
    with open(config_path, 'r', encoding='utf-8') as f:
        run_config = yaml.safe_load(f)
    args.train.update(run_config['train'])
    args.model.update(run_config['model'])
    args.logger['display'].extend(['arch', 'scale', 'weight'])

    seeds = [2024, 2025, 2026]
    if seeds or args.train['inference']:  # run with the specified seeds
        # With no explicit seeds, fall back to the seed from the configuration.
        if not seeds:
            seeds = [args.train['seed']]
        record_path = f"{args.file['record']}{args.model['name']}_best.jsonl"
        record_show = JsonFile(record_path, mode_w='a', delete=True)
        # One run per seed; each run's record is written as soon as it ends.
        for seed in seeds:
            args.train['seed'] = seed
            record = run(args)
            record_show.write(record, space=False)
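# Disabled alternative driver, kept for reference. With a non-empty `seeds`
# list (or in inference mode) it runs the given seeds; otherwise it performs
# 100 runs with randomly drawn seeds.
# NOTE(review): the random-seed branch calls random.randint, but `random` is
# not imported explicitly in this file -- confirm it is in scope before
# re-enabling this code.
# seeds = []
# if seeds or args.train['inference']:  # run with the specified seeds
#     if not seeds: seeds = [args.train['seed']]
#     recoed_path = f"{args.file['record']}{args.model['name']}_best.jsonl"
#     record_show = JsonFile(recoed_path, mode_w='a', delete=True)
#     for seed in seeds:
#         args.train['seed'] = seed
#         record = run(args)
#         record_show.write(record, space=False)
# else:  # run with random seeds
#     recoed_path = f"{args.file['record']}{args.model['name']}_search.jsonl"
#     record_show = JsonFile(recoed_path, mode_w='a', delete=True)
#     for c in range(100):
#         args.train['seed'] = random.randint(1000,9999)+c
#         record = run(args)
#         record_show.write(record, space=False)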