From d708d8efced97954f75684c1be790bf03c23da29 Mon Sep 17 00:00:00 2001 From: bosung Date: Fri, 16 Feb 2024 20:46:36 -0800 Subject: [PATCH] bug fix --- wrapper.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/wrapper.py b/wrapper.py index 75cb303..448d6e5 100644 --- a/wrapper.py +++ b/wrapper.py @@ -385,7 +385,8 @@ def predict(self, head_entity, tail_entity = first, second else: head_entity, tail_entity = second, first - new_sent = RelationSentence.from_spans(context, head=head_entity, tail=tail_entity, label=cur_label) + new_sent = RelationSentence(tokens=context.split(), head=[], tail=[], raw=batch["outputs"][jj], + head_text=head_entity, tail_text=tail_entity, label=cur_label) new_sent.head_text = head_entity new_sent.tail_text = tail_entity new_sent.raw = batch["outputs"][jj] @@ -458,7 +459,9 @@ def score(path_pred: str, path_gold: str) -> dict: for p in pred.sents[i].triplets: results_by_label[p.label]["n_pred"] += 1 for g in gold.sents[i].triplets: - if (p.head, p.tail, p.label) == (g.head, g.tail, g.label): + gold_head = " ".join([g.tokens[i] for i in g.head]) + gold_tail = " ".join([g.tokens[i] for i in g.tail]) + if (p.head_text, p.tail_text, p.label) == (gold_head, gold_tail, g.label): num_correct += 1 results_by_label[g.label]["tp"] += 1 if p.label == g.label: