-
Notifications
You must be signed in to change notification settings - Fork 0
/
ner.py
99 lines (80 loc) · 3.15 KB
/
ner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Arabic named entity recognition"""
from pyarabic.araby import strip_tashkeel
import data
def find_names(sentences: list[list[str]]):
    """
    Given a list of lists of tokens, yield one list of booleans per
    sentence, marking whether each diacritic-stripped token appears in
    the known-names lexicon (`data.known_names`).
    """
    for tokens in sentences:
        flags = []
        for token in tokens:
            # Drop tashkeel (vocalization marks) before the lexicon lookup.
            flags.append(strip_tashkeel(token) in data.known_names)
        yield flags
# from contextlib import redirect_stderr
# from os import devnull
# import torch
# from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
# # Load the model
# torch.device("cuda" if torch.cuda.is_available() else "cpu")
# tokenizer = AutoTokenizer.from_pretrained("hatmimoha/arabic-ner")
# model = AutoModelForTokenClassification.from_pretrained("hatmimoha/arabic-ner")
# nlp = pipeline("ner", model=model, tokenizer=tokenizer)
# def _tag_sentences(sentences: list[str]):
# with open(devnull, "w") as f, redirect_stderr(f):
# annotations = nlp(sentences)
# for sentence in annotations:
# entities = []
# tags = []
# for item in sentence:
# if item["word"].startswith("##"):
# entities[-1] = entities[-1] + item["word"].replace("##", "")
# else:
# entities.append(item["word"])
# tags.append(item["entity"])
# yield entities, tags
# def find_names(sentences: list[list[str]]):
# """
# Given a list of lists of unvocalized tokens,
# returns a generator of lists of whether each token is part of a name.
# """
# for sentence, entities in zip(
# sentences,
# _tag_sentences(
# [" ".join(token for token in sentence) for sentence in sentences]
# ),
# ):
# is_name_data = [False] * len(sentence)
# current_token = 0
# for entity, tag in zip(*entities):
# while current_token < len(sentence):
# # Possibly use fuzz matching here
# if sentence[current_token] == entity:
# break
# current_token += 1
# else:
# break
# is_name_data[current_token] = True
# print(f"{sentence[current_token]}, {tag}")
# yield is_name_data
# Tests: evaluate the lexicon-lookup baseline against a gold-standard
# NER annotation file (one "word tag" pair per line, blank lines ignored).
if __name__ == "__main__":
    from data import known_names

    wiki_input = "data/ner-gold-standard/wiki.txt"
    # Confusion-matrix counters for "is this word a name?".
    correct_positive = 0    # true positives
    correct_negative = 0    # true negatives
    incorrect_positive = 0  # false positives
    incorrect_negative = 0  # false negatives
    with open(wiki_input, encoding="utf-8") as i:
        # Iterate lazily instead of readlines() — no need to hold the
        # whole file in memory.
        for line in i:
            if not line.strip():
                continue
            word, tag = line.split()
            # CoNLL-style tagging: "O" means the token is outside any entity.
            is_name = tag != "O"
            test_is_name = word in known_names
            # Tally the confusion matrix with explicit branches.
            if test_is_name and is_name:
                correct_positive += 1
            elif test_is_name:
                incorrect_positive += 1
            elif is_name:
                incorrect_negative += 1
            else:
                correct_negative += 1
    s = correct_positive + correct_negative + incorrect_positive + incorrect_negative
    print(f"{correct_positive=}")
    print(f"{correct_negative=}")
    print(f"{incorrect_positive=}")
    print(f"{incorrect_negative=}")
    print("-" * 30)
    print(f"Sum: {s}")