-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
155 lines (127 loc) · 5.57 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import logging
import os
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader
from tqdm import tqdm
from feature_embeddings import BertPrep
from train import BertForNegationCueClassification, NegCueDataset
import config
def _merge_subword_tags(tags, token_ids):
    """
    Collapse subword-level tags back to one tag per original token.

    Consecutive positions that share the same entry in ``token_ids`` are treated
    as pieces of one original token; only the tag of the first piece is kept.
    :param tags: sequence of tag ids aligned position-by-position with ``token_ids``
    :param token_ids: sequence identifying which original token each position
        belongs to (runs of equal consecutive ids are merged)
    :return: list with one tag per run of equal consecutive token ids
    """
    merged = []
    prev_id = None
    for tag, tok_id in zip(tags, token_ids):
        if tok_id != prev_id:
            merged.append(tag)
        prev_id = tok_id
    return merged


def _write_error_entry(fd, idx, sentence, true_tags, pred_tags, inv_tag_enc):
    """
    Write one sentence whose post-processed predictions differ from the gold tags.

    :param fd: open text file to write to
    :param idx: index of the sentence in the dataset
    :param sentence: list of original words of the sentence
    :param true_tags: post-processed gold tag ids
    :param pred_tags: post-processed predicted tag ids (same length as ``true_tags``)
    :param inv_tag_enc: mapping from tag id to tag name
    :return: None
    """
    print(f"---------------------{idx}--------------------------", file=fd)
    print(sentence, file=fd)
    mismatch = np.array(true_tags) != np.array(pred_tags)
    # Mask the words whose tags agree so only the mismatching words are shown.
    print(np.ma.array(sentence, mask=~mismatch), file=fd)
    print("TRUE |", [inv_tag_enc[k] for k in true_tags], file=fd)
    print("PRED |", [inv_tag_enc[k] for k in pred_tags], file=fd)


def _write_classification_metrics(fname, true_tags, pred_tags, inv_tag_enc):
    """
    Write the classification report and confusion matrix for post-processed tags.

    Both metrics are computed over the full, explicitly ordered set of tag ids,
    so report rows and matrix rows/columns line up with ``inv_tag_enc`` even when
    some tag never occurs in the evaluated data.
    :param fname: path of the output text file
    :param true_tags: post-processed gold tag ids
    :param pred_tags: post-processed predicted tag ids
    :param inv_tag_enc: mapping from tag id to tag name
    :return: None
    """
    label_ids = sorted(inv_tag_enc)
    label_names = [inv_tag_enc[k] for k in label_ids]
    with open(fname, "w", encoding="utf-8") as cr_fd:
        print("--------- Post-processed Classification Report --------- ", file=cr_fd)
        print(
            classification_report(
                true_tags,
                pred_tags,
                labels=label_ids,
                target_names=label_names,
                digits=4,
            ),
            file=cr_fd,
        )
        print("--------- Post-processed Confusion Matrix --------- ", file=cr_fd)
        # Rows/columns follow label_ids, i.e. ascending tag id as listed here.
        print(inv_tag_enc, file=cr_fd)
        print(confusion_matrix(true_tags, pred_tags, labels=label_ids), file=cr_fd)


def evaluate(
    ckpt: str,
    dataset_file: str,
    error_analysis_fname: str,
    classification_metrics_fname: str,
) -> None:
    """
    Evaluate a model checkpoint on a dataset and write error analysis and metrics.

    Predictions are made per subword token and then post-processed to match the
    initial tokenization: each original token keeps only the tag of its first
    subword. Both the error analysis and the metrics use these post-processed tags.
    :param ckpt: path to model checkpoint; if its final path component contains
        "_lex", lexical features (pos, possible_prefix, possible_suffix) are used
    :param dataset_file: path to dataset in CoNLL format
    :param error_analysis_fname: path to file where error analysis data will be stored
    :param classification_metrics_fname: path to file where classification metrics
        data will be stored
    :return: None
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # normpath strips a trailing separator, so "ckpt_lex/" is still detected.
    use_lexicals = "_lex" in os.path.basename(os.path.normpath(ckpt))
    lexicals = ["pos", "possible_prefix", "possible_suffix"] if use_lexicals else []
    # TODO: the training set is preprocessed again only to recover the tag
    # encoding, tokenizer and lexical vector size used at training time; these
    # should be stored alongside the checkpoint instead.
    train_prep = BertPrep(config.TRAIN_FEATURES, lexicals)
    n_lexicals = train_prep.lexicals_vec_size
    logging.info("Using lexicals: %s.", lexicals)
    model = BertForNegationCueClassification.from_pretrained(
        ckpt, num_labels=len(train_prep.tag2idx), n_lexicals=n_lexicals
    )
    logging.info("Running model on device: %s", device)
    model.to(device)
    # Inference mode (e.g. no dropout); explicit rather than relying on the loader.
    model.eval()

    dataset_prep = BertPrep(dataset_file, lexicals)
    dataset = NegCueDataset(dataset_prep.preprocess_dataset(), n_lexicals=n_lexicals)
    dataset_loader = DataLoader(dataset, batch_size=1, num_workers=1)

    sep_id = train_prep.tokenizer.vocab["[SEP]"]
    inv_tag_enc = {v: k for k, v in train_prep.tag2idx.items()}
    true_tags_parsed = []
    pred_tags_parsed = []

    logging.info("Running prediction...")
    with torch.no_grad(), open(error_analysis_fname, "w", encoding="utf-8") as ea_fd:
        for i, data in tqdm(
            enumerate(dataset_loader), total=len(dataset_prep.sentences)
        ):
            sentence = [el[0] for el in dataset_prep.sentences[i]]
            # batch_size=1, so flatten each tensor to one sequence.
            tokenized_sentence = data["input_ids"].numpy().reshape(-1)
            # Index of the first [SEP]; positions from it onwards are discarded.
            # Raises IndexError if the sequence has no [SEP].
            end = np.flatnonzero(tokenized_sentence == sep_id)[0]
            # [1:end] also drops position 0 (presumably [CLS]).
            token_ids = data["token_ids"].numpy().reshape(-1)[1:end]

            true_sent_tags = data["labels"].numpy().reshape(-1)[1:end]
            true_tag_parsed = _merge_subword_tags(true_sent_tags, token_ids)
            true_tags_parsed.extend(true_tag_parsed)

            # NOTE(review): lexicals are passed without .to(device); presumably the
            # model moves them itself (it receives device=) — confirm.
            batch_lexicals = None if n_lexicals == 0 else data["lexicals"]
            outputs = model(
                data["input_ids"].to(device),
                batch_lexicals,
                attention_mask=data["attention_mask"].to(device),
                labels=data["labels"].to(device),
                device=device,
            )
            # argmax over axis 2 — assumes logits are (batch, seq, num_labels).
            pred_sent_tags = outputs.logits.argmax(2).cpu().numpy().reshape(-1)[1:end]
            pred_tag_parsed = _merge_subword_tags(pred_sent_tags, token_ids)
            pred_tags_parsed.extend(pred_tag_parsed)

            if not np.array_equal(true_tag_parsed, pred_tag_parsed):
                _write_error_entry(
                    ea_fd, i, sentence, true_tag_parsed, pred_tag_parsed, inv_tag_enc
                )
    logging.info("Saved data for error analysis to: %s", error_analysis_fname)

    _write_classification_metrics(
        classification_metrics_fname, true_tags_parsed, pred_tags_parsed, inv_tag_enc
    )
    logging.info("Saved classification metrics to: %s", classification_metrics_fname)