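# test_part_c_students.py
# Smoke-test harness for Part C submissions. It verifies that required files
# exist, loads the benchmark data and training queries, optionally downloads a
# pretrained model, and then exercises each search_engine_* module: index build
# time, per-query latency, tweet-id validity of results, and MAP against the
# labeled benchmark. All findings are written to part_c_tests.log.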
if __name__ == '__main__':
    import os
    import sys
    import re
    from datetime import datetime
    import pandas as pd
    import pyarrow.parquet as pq
    import time
    import timeit
    import importlib
    import logging
    logging.basicConfig(filename='part_c_tests.log', level=logging.DEBUG,
                        filemode='w', format='%(levelname)s %(asctime)s: %(message)s')
    import metrics

    def test_file_exists(fn):
        if os.path.exists(fn):
            return True
        logging.error(f'{fn} does not exist.')
        return False

    # Tweet ids are strings of digits only; use a raw string so \d is not
    # interpreted as an escape sequence.
    tid_ptrn = re.compile(r'\d+')

    def invalid_tweet_id(tid):
        if not isinstance(tid, str):
            tid = str(tid)
        if tid_ptrn.fullmatch(tid) is None:
            return True
        return False
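    # For example, invalid_tweet_id('1291423468331991040') is False (all digits),
    # while invalid_tweet_id('12a34') is True (contains a non-digit character).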

    bench_data_path = os.path.join('data', 'benchmark_data_train.snappy.parquet')
    bench_lbls_path = os.path.join('data', 'benchmark_lbls_train.csv')
    queries_path = os.path.join('data', 'queries_train.tsv')
    model_dir = os.path.join('.', 'model')
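    # Expected layout (inferred from the paths above): the benchmark parquet,
    # labels CSV, and queries TSV live under ./data, and any pretrained model
    # is extracted into ./model.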

    start = datetime.now()
    try:
        # is the report there?
        test_file_exists('report_part_c.docx')
        # is the benchmark data under the 'data' folder?
        bench_lbls = None
        if not test_file_exists(bench_data_path) or \
                not test_file_exists(bench_lbls_path):
            logging.error("Benchmark data does not exist under the 'data' folder.")
            sys.exit(-1)
        else:
            bench_lbls = pd.read_csv(bench_lbls_path,
                                     dtype={'query': int, 'tweet': str, 'y_true': int})
            logging.info("Successfully loaded benchmark labels data.")
        # is the queries file under 'data'?
        queries = None
        if not test_file_exists(queries_path):
            logging.error("Queries data not found ~> skipping some tests.")
        else:
            queries = pd.read_csv(queries_path, sep='\t')
            logging.info("Successfully loaded queries data.")

        import configuration
        config = configuration.ConfigClass()
        # do we need to download a pretrained model?
        model_url = config.get_model_url()
        if model_url is not None and config.get_download_model():
            import utils
            dest_path = 'model.zip'
            utils.download_file_from_google_drive(model_url, dest_path)
            if not os.path.exists(model_dir):
                os.mkdir(model_dir)
            if os.path.exists(dest_path):
                utils.unzip_file(dest_path, model_dir)
                logging.info(f'Successfully downloaded and extracted pretrained model into {model_dir}.')
            else:
                logging.error('model.zip file does not exist.')

        # test each search engine module
        engine_modules = ['search_engine_' + name for name in ['1', '2', 'best']]
        for engine_module in engine_modules:
            try:
                # does the module file exist?
                if not test_file_exists(engine_module + '.py'):
                    continue
                # try importing the module
                se = importlib.import_module(engine_module)
                logging.info(f"Successfully imported module {engine_module}.")
                engine = se.SearchEngine(config=config)
                # test building an index, and that doing so takes under a minute
                build_idx_time = timeit.timeit(
                    "engine.build_index_from_parquet(bench_data_path)",
                    globals=globals(), number=1
                )
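                # timeit with number=1 runs the build exactly once and returns
                # the elapsed time in seconds; `engine` and `bench_data_path`
                # are visible to the timed statement via globals() because this
                # script runs at module scope.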
                logging.debug(f"Building the index in {engine_module} for benchmark data took {build_idx_time} seconds.")
                if build_idx_time > 60:
                    logging.error('Parsing and indexing our *small* benchmark dataset took over a minute!')
                # test loading the precomputed model
                engine.load_precomputed_model(model_dir)
                # test that we can run one query and get results in the expected format
                n_res, res = engine.search('bioweapon')
                if n_res is None or res is None or n_res < 1 or len(res) < 1:
                    logging.error("Basic query for the word 'bioweapon' returned no results.")
                else:
                    logging.debug(f"{engine_module} successfully returned {n_res} results for the query 'bioweapon'.")
                    invalid_tweet_ids = [doc_id for doc_id in res if invalid_tweet_id(doc_id)]
                    if len(invalid_tweet_ids) > 0:
                        logging.error("The query 'bioweapon' returned results that are not valid tweet ids: " +
                                      str(invalid_tweet_ids[:10]))
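                # (Interface inferred from the call above: SearchEngine.search(query)
                # returns a tuple of (number_of_results, list_of_tweet_ids).)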
                # run multiple queries and test that no query takes more than 10 seconds
                queries_results = []
                if queries is not None:
                    for i, row in queries.iterrows():
                        q_id = row['query_id']
                        q_keywords = row['keywords']
                        start_time = time.time()
                        q_n_res, q_res = engine.search(q_keywords)
                        end_time = time.time()
                        q_time = end_time - start_time
                        if q_n_res is None or q_res is None or q_n_res < 1 or len(q_res) < 1:
                            logging.error(f"Query {q_id} with keywords '{q_keywords}' returned no results.")
                        else:
                            logging.debug(f"{engine_module} successfully returned {q_n_res} results for query number {q_id}.")
                            invalid_tweet_ids = [doc_id for doc_id in q_res if invalid_tweet_id(doc_id)]
                            if len(invalid_tweet_ids) > 0:
                                logging.error(f"Query {q_id} returned results that are not valid tweet ids: " +
                                              str(invalid_tweet_ids[:10]))
                            queries_results.extend(
                                [(q_id, str(doc_id)) for doc_id in q_res if not invalid_tweet_id(doc_id)])
                        if q_time > 10:
                            logging.error(f"Query {q_id} with keywords '{q_keywords}' took more than 10 seconds.")
                queries_results = pd.DataFrame(queries_results, columns=['query', 'tweet'])
                # merge query results with the labels benchmark
                q_results_labeled = None
                if bench_lbls is not None and len(queries_results) > 0:
                    q_results_labeled = pd.merge(queries_results, bench_lbls,
                                                 on=['query', 'tweet'], how='inner', suffixes=('_result', '_bench'))
                    # q_results_labeled.rename(columns={'y_true': 'label'})
                # test that MAP is in the valid range (0, 1]
                if q_results_labeled is not None:
                    # logging.debug(q_results_labeled.head())
                    results_map = metrics.map(q_results_labeled)
                    logging.debug(f"{engine_module} results have MAP value of {results_map}.")
                    if results_map <= 0 or results_map > 1:
                        logging.error(f'{engine_module} results MAP value is out of the valid range (0, 1].')
                # TODO: test that the average across queries of precision,
                # precision@5, precision@10, precision@50, and recall
                # is in (0, 1)
                if engine_module == 'search_engine_best' and \
                        test_file_exists('idx_bench.pkl'):
                    logging.info('idx_bench.pkl found!')
                    engine.load_index('idx_bench.pkl')
                    logging.info('Successfully loaded idx_bench.pkl using search_engine_best.')
            except Exception as e:
                logging.error(f'The following error occurred while testing the module {engine_module}.')
                logging.error(e, exc_info=True)
    except Exception as e:
        logging.error(e, exc_info=True)

    run_time = datetime.now() - start
    logging.debug(f'Total runtime was: {run_time}')
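
    # Usage note: run `python test_part_c_students.py` from the project root with
    # the data/ folder and the search_engine_* modules in place; progress and
    # errors are written to part_c_tests.log (overwritten each run, per filemode='w').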