-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathretrieve.py
executable file
·156 lines (132 loc) · 4.61 KB
/
retrieve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import argparse
import json
import sys
from tqdm import tqdm
from densephrases import (
DensePhrases,
) # note that DensePhrases is installed with editable mode
# Fixed retrieval settings used by both batch and online search.
DUMP_DIR = "DensePhrases/outputs/densephrases-multi_wiki-20181220/dump"  # handed to DensePhrases as dump_dir
R_UNIT = "sentence"  # retrieval_unit argument of every search call
TOP_K = 200  # top_k argument of every search call

# Runfiles are written below this directory; it is created as soon as the module loads.
RUNFILE_DIR = "runs"
os.makedirs(RUNFILE_DIR, exist_ok=True)
class Retriever:
    """Wrapper around a DensePhrases model offering batch and single-query search.

    Batch search writes a tab-separated runfile (query id, retrieved result)
    under ``RUNFILE_DIR``; online search returns the model's result directly.
    """

    def __init__(self, args):
        """Keep the parsed options and load the model.

        Args:
            args: Namespace-like object. This class reads the attributes
                ``query_encoder_name_or_dir``, ``index_name``, ``batch_size``,
                ``runfile_name`` and ``truecase`` from it.
        """
        self.args = args
        self.initialize_retriever()

    def initialize_retriever(self):
        """Instantiate the DensePhrases model from the configured encoder and index."""
        # load model
        self.model = DensePhrases(
            load_dir=self.args.query_encoder_name_or_dir,  # change query encoder after re-training
            dump_dir=DUMP_DIR,
            index_name=self.args.index_name,
        )

    def retrieve(self, single_query_or_queries_dict):
        """Run batch or online search depending on the input type.

        Args:
            single_query_or_queries_dict: Either a dict with the keys
                ``"queries"`` and ``"qids"`` (batch search) or a single query
                string (online search).

        Returns:
            ``None`` for batch search (results go to the runfile), otherwise
            whatever ``self.model.search`` returns for the single query.

        Raises:
            NotImplementedError: If the input is neither a dict nor a str.
            ValueError: If a batch supplies fewer qids than queries.
        """
        if isinstance(single_query_or_queries_dict, dict):  # batch search
            return self._retrieve_batch(
                single_query_or_queries_dict["queries"],
                single_query_or_queries_dict["qids"],
            )

        if isinstance(single_query_or_queries_dict, str):  # online search
            # NOTE(review): truecase is not forwarded here, so --truecase only
            # affects batch search and the library default applies — confirm
            # whether that asymmetry is intended.
            return self.model.search(
                single_query_or_queries_dict, retrieval_unit=R_UNIT, top_k=TOP_K
            )

        raise NotImplementedError(
            "expected a dict with 'queries'/'qids' or a str, got "
            f"{type(single_query_or_queries_dict).__name__}"
        )

    def _retrieve_batch(self, queries, qids):
        """Search ``queries`` in batches and write one runfile row per result entry."""
        # Fail before the (potentially long) search instead of hitting an
        # IndexError halfway through and leaving a truncated runfile behind.
        if len(qids) < len(queries):
            raise ValueError(
                f"got {len(queries)} queries but only {len(qids)} qids"
            )

        # batchify: start offsets of consecutive slices of batch_size queries
        batch_size = self.args.batch_size
        batch_starts = range(0, len(queries), batch_size)

        runfile_path = f"{RUNFILE_DIR}/{self.args.runfile_name}"
        # Retrieved text may contain non-ASCII characters, so pin the encoding
        # rather than depending on the platform default.
        with open(runfile_path, "w", encoding="utf-8") as fw:
            # generate runfile
            print(f"generating runfile: {runfile_path}")
            written = 0  # number of rows written so far == index of the next qid
            for start in tqdm(batch_starts):
                # retrieve
                results = self.model.search(
                    queries[start : start + batch_size],
                    retrieval_unit=R_UNIT,
                    top_k=TOP_K,
                    truecase=self.args.truecase,
                )
                # write to runfile; presumably the result holds one entry per
                # query of the batch — confirm against DensePhrases.search
                for row, retrieved in enumerate(results, start=written):
                    fw.write(f"{qids[row]}\t{retrieved}\n")
                written += len(results)
        return None
if __name__ == "__main__":
    # Script entry point: parse options, build the retriever input (a batch of
    # queries read from a JSON file, or one query given on the command line),
    # then run retrieval.

    # parse arguments
    parser = argparse.ArgumentParser(
        description="Retrieve query-relevant collection with varying topK."
    )
    parser.add_argument(
        "--query_encoder_name_or_dir",
        type=str,
        default="princeton-nlp/densephrases-multi-query-multi",
        help="query encoder name registered in huggingface model hub OR custom query encoder checkpoint directory",
    )
    parser.add_argument(
        "--index_name",
        type=str,
        default="start/1048576_flat_OPQ96",
        help="index name appended to index directory prefix",
    )
    parser.add_argument(
        "--query_list_path",
        type=str,
        default="DensePhrases/densephrases-data/open-qa/nq-open/test_preprocessed.json",
        help="use batch search by default",
    )
    parser.add_argument(
        "--single_query",
        type=str,
        default=None,
        help="if presented do online search instead of batch search",
    )
    parser.add_argument(
        "--runfile_name",
        type=str,
        default="run.tsv",
        help="output runfile name which includes query id and retrieved collection",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="#query to process with parallel processing",
    )
    parser.add_argument(
        "--truecase",
        action="store_true",
        help="set True when we use case-sensitive language model",
    )
    args = parser.parse_args()

    # to prevent collision with DensePhrase native argparser
    sys.argv = [sys.argv[0]]

    # define input for retriever: batch or online search
    if args.single_query is None:
        # Read the query file as UTF-8 explicitly so non-ASCII questions do not
        # depend on the platform's default encoding.
        with open(args.query_list_path, "r", encoding="utf-8") as fr:
            qa_data = json.load(fr)
        # get all query list: each sample under "data" supplies a question and its id
        samples = qa_data["data"]
        inputs = {
            "queries": [sample["question"] for sample in samples],
            "qids": [sample["id"] for sample in samples],
        }
    else:
        inputs = args.single_query

    # initialize retriever
    retriever = Retriever(args)

    # run
    result = retriever.retrieve(single_query_or_queries_dict=inputs)

    # batch search writes a runfile and returns None; only online search has a result to show
    if args.single_query is not None:
        print(f"query: {args.single_query}, result: {result}")