# field_test.py
import time
from itertools import compress

from rapidfuzz.fuzz import (
    ratio,
    partial_token_set_ratio,
    # alternative scorers, kept importable for easy swapping during experiments
    token_ratio,
    token_set_ratio,
    token_sort_ratio,
    partial_token_sort_ratio,
    partial_ratio_alignment,
    partial_ratio,
    WRatio,
    QRatio,
)
import pandas as pd
import numpy as np
import pickle
import boto3

from fuzzup.fuzz import fuzzy_cluster, compute_prominence, match_whitelist
from fuzzup.whitelists import get_politicians, get_cities

# helper for loading pickled NER predictions from S3 (not used in the run below)
def load_preds_from_s3(file="ner_preds_v1.pickle"):
    s3 = boto3.resource('s3')
    preds = pickle.loads(s3.Bucket("nerbonanza").Object(file).get()['Body'].read())
    return preds
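
# Example usage (default key; assumes AWS credentials with read access to the
# "nerbonanza" bucket are configured, e.g. via ~/.aws/credentials):
# preds = load_preds_from_s3("ner_preds_v1.pickle")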

# load NER predictions (entities) and test articles from S3
s3 = boto3.resource('s3')
entities = pd.read_csv(s3.Bucket("nerbonanza").Object('entities.csv').get()['Body'])
articles = pd.read_csv(s3.Bucket("nerbonanza").Object('test_articles.csv').get()['Body'])

# Danish company names (for whitelist experiments)
def load_danish_companies(file="companies-name-municipality.json"):
    s3 = boto3.resource('s3')
    companies = pd.read_json(s3.Bucket("nerbonanza").Object(file).get()['Body'])
    return companies

#### WHITELIST EXPERIMENTS
# companies = load_danish_companies()
# whitelist = companies.name.tolist()
# whitelist = list(get_politicians().keys())
whitelist = list(get_cities().keys())
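# get_cities() returns a dict keyed by city name, so `whitelist` ends up as a
# plain list of city-name strings for match_whitelist() below.
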
# run the full pipeline on a single (random) article
def run_random(articles,
               entities,
               id=None,
               scorer=partial_token_set_ratio,
               cutoff=75):
    # pick a random article unless a specific content_id is given
    if id is None:
        id = np.random.choice(articles.content_id.tolist())
    article = articles[articles.content_id == id]
    article = article[['content_id', 'title', 'lead', 'body']]

    # keep only the entities predicted for the article body
    article_ents = entities[entities.content_id == id]
    article_ents = article_ents[article_ents.placement == "body"]
    preds = article_ents.to_dict(orient="records")

    # time the clustering -> prominence -> whitelist pipeline
    t1 = time.time()
    clusters = fuzzy_cluster(preds,
                             scorer=scorer,
                             workers=4,
                             cutoff=cutoff,
                             merge_output=True)
    # pd.DataFrame.from_dict(clusters)
    clusters = compute_prominence(clusters,
                                  merge_output=True,
                                  weight_position=0.5)

    # subset location entities (for matching with cities)
    location_mask = [x["entity_group"] == "LOC" for x in clusters]
    clusters = list(compress(clusters, location_mask))

    # match the location clusters against the city whitelist
    clusters = match_whitelist(clusters,
                               whitelist=whitelist,
                               scorer=ratio,
                               score_cutoff=95,
                               merge_output=True,
                               aggregate_cluster=True,
                               workers=1)
    t2 = time.time()

    if len(clusters) > 0:
        clusters = pd.DataFrame.from_dict(clusters).sort_values(by="prominence_rank")
    print(id)
    # print(article.title.tolist()[0])
    # print(article.lead.tolist()[0])
    print(article.body.tolist()[0])
    print(clusters)
    return t2 - t1
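
# smoke test: cluster, rank and whitelist-match the entities of one random article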
run_random(articles,
           entities,
           scorer=partial_token_set_ratio,
           cutoff=75)

# n_trials = 500
# timings = [run_random(articles, entities) for _ in range(n_trials)]
# print(f"Avg. time for {n_trials} trials: {np.round(np.nanmean(timings), 4)}s")
# print(f"Median time for {n_trials} trials: {np.round(np.nanmedian(timings), 4)}s")