# nq-bm25-fid.py
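"""Evaluate a BM25 + FiD question-answering pipeline on the KILT Natural Questions task.

The script retrieves passages from an Elasticsearch index with BM25, reranks them with a
cross-encoder, generates answers with a Fusion-in-Decoder (FiD) reader, and reports KILT
downstream/retrieval metrics together with per-component latency.
"""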
import logging

import kilt.eval_retrieval as retrieval_metrics
import pandas as pd
import torch
from datasets import load_dataset
from haystack import Pipeline
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.nodes import BM25Retriever, PromptModel, SentenceTransformersRanker
from haystack.nodes.prompt import AnswerParser, PromptNode
from haystack.nodes.prompt.prompt_template import PromptTemplate
from kilt.eval_downstream import _calculate_metrics, validate_input
from tqdm import tqdm

from fastrag.prompters.invocation_layers import fid
from fastrag.utils import get_timing_from_pipeline


def evaluate(gold_records, guess_records):
    # 0. validate input
    gold_records, guess_records = validate_input(gold_records, guess_records)
    # 1. downstream + kilt
    result = _calculate_metrics(gold_records, guess_records)
    # 2. retrieval performance
    retrieval_results = retrieval_metrics.compute(
        gold_records, guess_records, ks=[1, 5], rank_keys=["wikipedia_id"]
    )
    result["retrieval"] = {
        "Rprec": retrieval_results["Rprec"],
        "recall@5": retrieval_results["recall@5"],
    }
    return result


def create_json_entry(jid, input_text, answer, documents):
    return {
        "id": jid,
        "input": input_text,
        "output": [{"answer": answer, "provenance": [{"wikipedia_id": d.id} for d in documents]}],
    }


def create_records(test_dataset, result_collection):
    guess_records = []
    for i in range(len(test_dataset)):
        example = test_dataset[i]
        results = result_collection[i]
        guess_records.append(
            create_json_entry(
                example["id"], example["input"], results["answers"][0].answer, results["documents"]
            )
        )
    return guess_records


def evaluate_from_answers(gold_records, result_collection):
    guess_records = create_records(gold_records, result_collection)
    return evaluate(gold_records, guess_records)

logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Create Components
document_store = ElasticsearchDocumentStore(
    host="localhost", index="index_name", port=80, search_fields=["content", "title"]
)
retriever = BM25Retriever(document_store=document_store)
reranker = SentenceTransformersRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2")
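
# FiD (Fusion-in-Decoder) reader: each retrieved passage is encoded together with the query,
# and the decoder attends over all encoded passages to generate a short answer. The model is
# loaded through fastRAG's FiD invocation layer.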
PrompterModel = PromptModel(
    model_name_or_path="Intel/fid_flan_t5_base_nq",
    use_gpu=True,
    invocation_layer_class=fid.FiDHFLocalInvocationLayer,
    model_kwargs=dict(
        model_kwargs=dict(device_map={"": 0}, torch_dtype=torch.bfloat16, do_sample=False),
        generation_kwargs=dict(max_length=10),
    ),
)
reader = PromptNode(
    model_name_or_path=PrompterModel,
    default_prompt_template=PromptTemplate("{query}", output_parser=AnswerParser()),
)

# Build Pipeline
p = Pipeline()
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
p.add_node(component=reranker, name="Reranker", inputs=["Retriever"])
p.add_node(component=reader, name="Reader", inputs=["Reranker"])

# Load Dataset
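# The KILT-formatted NQ validation split provides both the input questions and the gold
# records (answers plus Wikipedia provenance) used for evaluation below.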
data = load_dataset("kilt_tasks", "nq")
validation_data = data["validation"]
# Run Pipeline
retriever_top_k = 100
reranker_top_k = 50
all_results = []
efficiency_metrics = []
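
# Run each validation question through the pipeline, recording the answers and the
# per-component latency reported by get_timing_from_pipeline.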
for example in tqdm(validation_data):
    results = p.run(
        query=example["input"],
        params={"Retriever": {"top_k": retriever_top_k}, "Reranker": {"top_k": reranker_top_k}},
    )

    pipeline_latency_report = get_timing_from_pipeline(p)
    efficiency_metrics.append(
        {
            component_name: component_time[1]
            for component_name, component_time in pipeline_latency_report.items()
        }
    )

    all_results.append(results)
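
# Score the predictions with KILT: downstream answer metrics plus retrieval R-precision and recall@5.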
kilt_metrics = evaluate_from_answers(validation_data, all_results)

# Show Results
efficiency_metrics_df = pd.DataFrame(efficiency_metrics)
efficiency_metrics_df_mean = efficiency_metrics_df.mean()
for metric in efficiency_metrics_df.columns:
    logging.info(f"Mean latency for {metric}: {efficiency_metrics_df_mean[metric]} sec")
logging.info(
    f"""
    Accuracy: {kilt_metrics['downstream']['accuracy']}
    EM: {kilt_metrics['downstream']['em']}
    F1: {kilt_metrics['downstream']['f1']}
    ROUGE-L: {kilt_metrics['downstream']['rougel']}
    """
)