-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
77 lines (61 loc) · 2.68 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
from cfg.default_config import get_cfg_defaults
import numpy as np
from utils.comm import setup_seed
from tools.infer import infer_with_model_data_build
import pandas as pd
from sklearn.metrics import accuracy_score,classification_report
def get_args(argv=None):
    """Parse command-line arguments for ensemble inference.

    Args:
        argv: Optional list of argument strings (e.g. for tests); when None,
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with data_dir, test_file, model_pths (list of
        checkpoint paths), cfg_files (list of YAML config paths, one per
        checkpoint) and save_path.
    """
    arg = argparse.ArgumentParser()
    arg.add_argument('--data_dir', type=str,
                     default='/workspace/nCoV_sentence_simi/data/')
    arg.add_argument('--test_file', type=str,
                     default='/workspace/nCoV_sentence_simi/data/val_fold_0.csv')
    # nargs='+' accepts multiple paths from the CLI; the previous `type=list`
    # would have split a single argument string into individual characters.
    arg.add_argument('--model_pths', nargs='+', type=str,
                     default=[
                         '/workspace/wkdir/ernie_2/fold_0/best.pth',
                         '/workspace/wkdir/ernie_2/fold_1/best.pth',
                         '/workspace/wkdir/ernie_2/fold_2/best.pth',
                         '/workspace/wkdir/ernie_2/fold_3/best.pth',
                         '/workspace/wkdir/ernie_2/fold_4/best.pth',
                         '/workspace/wkdir/ernie_2/fold_5/best.pth'
                     ])
    # One config file per checkpoint, indexed in lockstep with --model_pths.
    arg.add_argument('--cfg_files', nargs='+', type=str,
                     default=[
                         '/workspace/nCoV_sentence_simi/cfgs/ernie.yml',
                         '/workspace/nCoV_sentence_simi/cfgs/ernie.yml',
                         '/workspace/nCoV_sentence_simi/cfgs/ernie.yml',
                         '/workspace/nCoV_sentence_simi/cfgs/ernie.yml',
                         '/workspace/nCoV_sentence_simi/cfgs/ernie.yml',
                         '/workspace/nCoV_sentence_simi/cfgs/ernie.yml',
                     ])
    arg.add_argument('--save_path', type=str, default='/test.pred.csv')
    return arg.parse_args(argv)
if __name__ == '__main__':
    # Ensemble inference: run each fold's checkpoint over the same test file,
    # sum the per-model logits, and write the argmax predictions to CSV.
    import logging
    logging.basicConfig(level=logging.WARNING)
    setup_seed(1029)  # fixed seed for reproducible inference
    args = get_args()
    cfg = get_cfg_defaults()

    labels = None  # ground-truth labels from the last fold, kept for optional scoring
    all_pred_logits = []
    for i, path in enumerate(args.model_pths):
        # Re-merge the fold-specific config, then re-apply the CLI overrides:
        # previously the data_dir/test_file overrides were set only once before
        # the loop, so the first merge_from_file could silently clobber them.
        cfg.merge_from_file(args.cfg_files[i])
        cfg.DATA.data_dir = args.data_dir
        cfg.DATA.test_file = args.test_file
        _, labels, pred_logits = infer_with_model_data_build(cfg, path, args.test_file)
        all_pred_logits.append(pred_logits)

    # Sum logits across models (equivalent to unnormalized soft voting).
    summed_logits = np.array(all_pred_logits).sum(axis=0)
    print(summed_logits.shape)
    preds = np.argmax(summed_logits, axis=1)

    df = pd.DataFrame({
        'id': range(len(preds)),
        'label': preds
    })
    # index=False (not index=None) is the documented way to drop the row index.
    df.to_csv(args.save_path, index=False)