Flatten output dict, remove 'name' as we already know the dataset names
tomaarsen committed Oct 17, 2024
1 parent 2cfd817 commit daf25c1
Showing 3 changed files with 92 additions and 71 deletions.
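
For orientation, here is a minimal usage sketch of the evaluator after this commit. The checkpoint name, the explicit aggregate_key, and the key names in the comments are illustrative assumptions; the exact result keys depend on the configured score functions and the aggregation settings:

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import NanoBEIREvaluator

    # Any SentenceTransformer checkpoint works; this one is only an example.
    model = SentenceTransformer("all-MiniLM-L6-v2")

    evaluator = NanoBEIREvaluator(
        dataset_names=["QuoraRetrieval", "MSMARCO"],
        aggregate_key="mean",
    )
    results = evaluator(model)

    # After this commit the returned dict is flat: per-dataset keys such as
    # "NanoQuoraRetrieval_cosine_ndcg@10" sit next to aggregated keys such as
    # "mean_cosine_ndcg@10", instead of being nested per dataset.
    print(results[evaluator.primary_metric])
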
157 changes: 89 additions & 68 deletions sentence_transformers/evaluation/NanoBEIREvaluator.py
@@ -2,7 +2,7 @@

import logging
import os
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Literal

import numpy as np
from torch import Tensor
@@ -19,8 +19,24 @@

logger = logging.getLogger(__name__)


dataset_paths = {
DatasetNameType = Literal[
"climatefever",
"dbpedia",
"fever",
"fiqa2018",
"hotpotqa",
"msmarco",
"nfcorpus",
"nq",
"quoraretrieval",
"scidocs",
"arguana",
"scifact",
"touche2020",
]


dataset_name_to_id = {
"climatefever": "zeta-alpha-ai/NanoClimateFEVER",
"dbpedia": "zeta-alpha-ai/NanoDBPedia",
"fever": "zeta-alpha-ai/NanoFEVER",
@@ -36,24 +52,24 @@
"touche2020": "zeta-alpha-ai/NanoTouche2020",
}

clean_names = {
"zeta-alpha-ai/NanoClimateFEVER": "ClimateFEVER",
"zeta-alpha-ai/NanoDBPedia": "DBPedia",
"zeta-alpha-ai/NanoFEVER": "FEVER",
"zeta-alpha-ai/NanoFiQA2018": "FiQA2018",
"zeta-alpha-ai/NanoHotpotQA": "HotpotQA",
"zeta-alpha-ai/NanoMSMARCO": "MSMARCO",
"zeta-alpha-ai/NanoNFCorpus": "NFCorpus",
"zeta-alpha-ai/NanoNQ": "NQ",
"zeta-alpha-ai/NanoQuoraRetrieval": "QuoraRetrieval",
"zeta-alpha-ai/NanoSCIDOCS": "SCIDOCS",
"zeta-alpha-ai/NanoArguAna": "ArguAna",
"zeta-alpha-ai/NanoSciFact": "SciFact",
"zeta-alpha-ai/NanoTouche2020": "Touche2020",
dataset_name_to_human_readable = {
"climatefever": "ClimateFEVER",
"dbpedia": "DBPedia",
"fever": "FEVER",
"fiqa2018": "FiQA2018",
"hotpotqa": "HotpotQA",
"msmarco": "MSMARCO",
"nfcorpus": "NFCorpus",
"nq": "NQ",
"quoraretrieval": "QuoraRetrieval",
"scidocs": "SCIDOCS",
"arguana": "ArguAna",
"scifact": "SciFact",
"touche2020": "Touche2020",
}


class NanoBeIREvaluator(SentenceEvaluator):
class NanoBEIREvaluator(SentenceEvaluator):
"""
This class evaluates the performance of a SentenceTransformer Model on the NanoBEIR collection of datasets.
@@ -85,7 +101,7 @@ class NanoBeIREvaluator(SentenceEvaluator):
results = evaluator(model)
'''
NanoBEeIR Evaluation of the model on ['QuoraRetrieval', 'MSMARCO'] dataset:
NanoBEIR Evaluation of the model on ['QuoraRetrieval', 'MSMARCO'] dataset:
Evaluating NanoBeIRNanoQuoraRetrieval
Evaluating NanoBeIRNanoMSMARCO
@@ -131,15 +147,14 @@ class NanoBeIREvaluator(SentenceEvaluator):

def __init__(
self,
dataset_names: list[str] | None = None,
dataset_names: list[DatasetNameType] | None = None,
mrr_at_k: list[int] = [10],
ndcg_at_k: list[int] = [10],
accuracy_at_k: list[int] = [1, 3, 5, 10],
precision_recall_at_k: list[int] = [1, 3, 5, 10],
map_at_k: list[int] = [100],
show_progress_bar: bool = False,
batch_size: int = 32,
name: str = "",
write_csv: bool = True,
truncate_dim: int | None = None,
score_functions: dict[str, Callable[[Tensor, Tensor], Tensor]] = {
@@ -175,7 +190,7 @@ def __init__(
"""
super().__init__()
if dataset_names is None:
dataset_names = list(dataset_paths.keys())
dataset_names = list(dataset_name_to_id.keys())
self.dataset_names = dataset_names
self.aggregate_fn = aggregate_fn
self.aggregate_key = aggregate_key
@@ -188,6 +203,9 @@ def __init__(
self.score_function_names = sorted(list(self.score_functions.keys()))
self.main_score_function = main_score_function
self.truncate_dim = truncate_dim
self.name = f"NanoBEIR_{aggregate_key}"
if self.truncate_dim:
self.name += f"_{self.truncate_dim}"

self.mrr_at_k = mrr_at_k
self.ndcg_at_k = ndcg_at_k
@@ -212,11 +230,9 @@ def __init__(
"main_score_function": main_score_function,
}

self.name = name

self.evaluators = [self._load_dataset(name, **ir_evaluator_kwargs) for name in self.dataset_names]

self.csv_file: str = f"NanoBEIR_evaluation_{aggregate_key}{self.name}_results.csv"
self.csv_file: str = f"NanoBEIR_evaluation_{aggregate_key}_results.csv"
self.csv_headers = ["epoch", "steps"]

for score_name in self.score_function_names:
@@ -250,21 +266,24 @@ def __call__(
out_txt = ""
if self.truncate_dim is not None:
out_txt += f" (truncated to {self.truncate_dim})"
logger.info(f"NanoBEeIR Evaluation of the model on {self.dataset_names} dataset{out_txt}:")
logger.info(f"NanoBEIR Evaluation of the model on {self.dataset_names} dataset{out_txt}:")
for evaluator in tqdm(self.evaluators, desc="Evaluating datasets", disable=not self.show_progress_bar):
logger.info(f"Evaluating {evaluator.name}")
evaluation = evaluator(model, output_path, epoch, steps)
for k in evaluation:
dataset, metric = k.split("_", maxsplit=1)
if self.truncate_dim:
dataset, _, metric = k.split("_", maxsplit=2)
else:
dataset, metric = k.split("_", maxsplit=1)
if metric not in per_metric_results:
per_metric_results[metric] = []
if dataset not in per_dataset_results:
per_dataset_results[dataset] = {}
per_dataset_results[dataset][metric] = evaluation[k]
per_dataset_results[dataset + "_" + metric] = evaluation[k]
per_metric_results[metric].append(evaluation[k])
per_dataset_results[self.aggregate_key] = {}

agg_results = {}
for metric in per_metric_results:
per_dataset_results[self.aggregate_key][metric] = self.aggregate_fn(per_metric_results[metric])
agg_results[metric] = self.aggregate_fn(per_metric_results[metric])
per_dataset_results[self.aggregate_key + "_" + metric] = agg_results[metric]

if output_path is not None and self.write_csv:
csv_path = os.path.join(output_path, self.csv_file)
@@ -298,16 +317,17 @@ def __call__(
fOut.write("\n")
fOut.close()

agg_results = per_dataset_results[self.aggregate_key]
if not self.primary_metric:
if self.main_score_function is None:
score_function = max(
[(name, agg_results[f"{name}_ndcg@{max(self.ndcg_at_k)}"]) for name in self.score_function_names],
key=lambda x: x[1],
)[0]
self.primary_metric = f"{score_function}_ndcg@{max(self.ndcg_at_k)}"
self.primary_metric = f"{self.aggregate_key}_{score_function}_ndcg@{max(self.ndcg_at_k)}"
else:
self.primary_metric = f"{self.main_score_function.value}_ndcg@{max(self.ndcg_at_k)}"
self.primary_metric = (
f"{self.aggregate_key}_{self.main_score_function.value}_ndcg@{max(self.ndcg_at_k)}"
)

self.store_metrics_in_model_card_data(model, agg_results)

@@ -316,32 +336,34 @@ def __call__(
logger.info(f"\nAverage Queries: {avg_queries}")
logger.info(f"Average Corpus: {avg_corpus}\n")

scores = per_dataset_results[self.aggregate_key]
for name in self.score_function_names:
logger.info(f"Aggregated for Score Function: {name}")
for k in self.accuracy_at_k:
logger.info("Accuracy@{}: {:.2f}%".format(k, scores[f"{name}_accuracy@{k}"] * 100))
logger.info("Accuracy@{}: {:.2f}%".format(k, agg_results[f"{name}_accuracy@{k}"] * 100))

for k in self.precision_recall_at_k:
logger.info("Precision@{}: {:.2f}%".format(k, scores[f"{name}_precision@{k}"] * 100))
logger.info("Recall@{}: {:.2f}%".format(k, scores[f"{name}_recall@{k}"] * 100))
logger.info("Precision@{}: {:.2f}%".format(k, agg_results[f"{name}_precision@{k}"] * 100))
logger.info("Recall@{}: {:.2f}%".format(k, agg_results[f"{name}_recall@{k}"] * 100))

for k in self.mrr_at_k:
logger.info("MRR@{}: {:.4f}".format(k, scores[f"{name}_mrr@{k}"]))
logger.info("MRR@{}: {:.4f}".format(k, agg_results[f"{name}_mrr@{k}"]))

for k in self.ndcg_at_k:
logger.info("NDCG@{}: {:.4f}".format(k, scores[f"{name}_ndcg@{k}"]))
logger.info("NDCG@{}: {:.4f}".format(k, agg_results[f"{name}_ndcg@{k}"]))
return per_dataset_results

def __get_clean_name(self, dataset_name: str) -> str:
return f"Nano{clean_names[dataset_paths[dataset_name.lower()]]}"
def _get_human_readable_name(self, dataset_name: DatasetNameType) -> str:
human_readable_name = f"Nano{dataset_name_to_human_readable[dataset_name.lower()]}"
if self.truncate_dim is not None:
human_readable_name += f"_{self.truncate_dim}"
return human_readable_name

def _load_dataset(self, dataset_name: str, **ir_evaluator_kwargs) -> InformationRetrievalEvaluator:
def _load_dataset(self, dataset_name: DatasetNameType, **ir_evaluator_kwargs) -> InformationRetrievalEvaluator:
if not is_datasets_available():
raise ValueError("datasets is not available. Please install it to use the NanoBEIREvaluator.")
from datasets import load_dataset

dataset_path = dataset_paths[dataset_name.lower()]
dataset_path = dataset_name_to_id[dataset_name.lower()]
corpus = load_dataset(dataset_path, "corpus", split="train")
queries = load_dataset(dataset_path, "queries", split="train")
qrels = load_dataset(dataset_path, "qrels", split="train")
@@ -357,38 +379,37 @@ def _load_dataset(self, dataset_name: str, **ir_evaluator_kwargs) -> Information
ir_evaluator_kwargs["query_prompt"] = self.query_prompts.get(dataset_name, None)
if self.corpus_prompts is not None:
ir_evaluator_kwargs["corpus_prompt"] = self.corpus_prompts.get(dataset_name, None)
clean_name = self.__get_clean_name(dataset_name)
human_readable_name = self._get_human_readable_name(dataset_name)
return InformationRetrievalEvaluator(
queries=queries_dict,
corpus=corpus_dict,
relevant_docs=qrels_dict,
name=f"{self.name}{clean_name}",
name=human_readable_name,
**ir_evaluator_kwargs,
)

def _validate_dataset_names(self):
missing_datasets = []
for dataset_name in self.dataset_names:
if dataset_name.lower() not in dataset_paths:
missing_datasets.append(dataset_name)
if missing_datasets:
if missing_datasets := [
dataset_name for dataset_name in self.dataset_names if dataset_name.lower() not in dataset_name_to_id
]:
raise ValueError(
f"Dataset(s) {missing_datasets} not found in NanoBEIR collection."
f"Valid dataset names are: {dataset_paths.keys()}"
f"Dataset(s) {missing_datasets} not found in the NanoBEIR collection."
f"Valid dataset names are: {list(dataset_name_to_id.keys())}"
)

def _validate_prompts(self):
missing_query_prompts = []
missing_corpus_prompts = []
for dataset_name in self.dataset_names:
if self.query_prompts is not None and dataset_name not in self.query_prompts:
missing_query_prompts.append(dataset_name)
if self.corpus_prompts is not None and dataset_name not in self.corpus_prompts:
missing_corpus_prompts.append(dataset_name)
warning_msg = ""
if missing_query_prompts:
warning_msg += f"The following datasets are missing query prompts: {missing_query_prompts}\n"
if missing_corpus_prompts:
warning_msg += f"The following datasets are missing corpus prompts: {missing_corpus_prompts}\n"
if warning_msg:
raise ValueError(warning_msg)
error_msg = ""
if self.query_prompts is not None:
if missing_query_prompts := [
dataset_name for dataset_name in self.dataset_names if dataset_name not in self.query_prompts
]:
error_msg += f"The following datasets are missing query prompts: {missing_query_prompts}\n"

if self.corpus_prompts is not None:
if missing_corpus_prompts := [
dataset_name for dataset_name in self.dataset_names if dataset_name not in self.corpus_prompts
]:
error_msg += f"The following datasets are missing corpus prompts: {missing_corpus_prompts}\n"

if error_msg:
raise ValueError(error_msg.strip())
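
The core of the change above is that __call__ no longer nests metrics per dataset: every metric lands in one flat dict, with aggregated values stored under the aggregate key. A standalone sketch of that pattern, using made-up numbers and the key-naming scheme assumed from the diff:

    from statistics import mean

    # Hypothetical per-dataset results, as the per-dataset sub-evaluators might return them.
    evaluations = {
        "NanoQuoraRetrieval": {"cosine_ndcg@10": 0.92, "cosine_mrr@10": 0.90},
        "NanoMSMARCO": {"cosine_ndcg@10": 0.64, "cosine_mrr@10": 0.60},
    }

    flat: dict[str, float] = {}
    per_metric: dict[str, list[float]] = {}

    for dataset, metrics in evaluations.items():
        for metric, value in metrics.items():
            flat[f"{dataset}_{metric}"] = value            # e.g. "NanoQuoraRetrieval_cosine_ndcg@10"
            per_metric.setdefault(metric, []).append(value)

    # Aggregated values live next to the per-dataset ones, under the aggregate key.
    aggregate_key = "mean"
    for metric, values in per_metric.items():
        flat[f"{aggregate_key}_{metric}"] = mean(values)   # e.g. "mean_cosine_ndcg@10"

    print(flat)
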
4 changes: 2 additions & 2 deletions sentence_transformers/evaluation/__init__.py
@@ -6,7 +6,7 @@
from .LabelAccuracyEvaluator import LabelAccuracyEvaluator
from .MSEEvaluator import MSEEvaluator
from .MSEEvaluatorFromDataFrame import MSEEvaluatorFromDataFrame
from .NanoBEIREvaluator import NanoBeIREvaluator
from .NanoBEIREvaluator import NanoBEIREvaluator
from .ParaphraseMiningEvaluator import ParaphraseMiningEvaluator
from .RerankingEvaluator import RerankingEvaluator
from .SentenceEvaluator import SentenceEvaluator
@@ -29,5 +29,5 @@
"TranslationEvaluator",
"TripletEvaluator",
"RerankingEvaluator",
"NanoBeIREvaluator",
"NanoBEIREvaluator",
]
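
With the corrected export and the new Literal alias, callers import the fixed class name and get editor/type-checker support for dataset names. A small sketch (the lowercase names come from the dataset_name_to_id mapping above):

    from sentence_transformers.evaluation import NanoBEIREvaluator  # previously exported as NanoBeIREvaluator

    # dataset_names is now annotated with a Literal of the 13 NanoBEIR subsets,
    # so a static type checker can flag unknown names before runtime.
    evaluator = NanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus"])
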
2 changes: 1 addition & 1 deletion sentence_transformers/model_card.py
@@ -816,7 +816,7 @@ def try_to_pure_python(value: Any) -> Any:
task_name=description,
task_type=description.lower().replace(" ", "-"),
dataset_type=dataset_name or "unknown",
dataset_name=dataset_name.replace("_", " ").replace("-", " ") or "Unknown",
dataset_name=dataset_name.replace("_", " ").replace("-", " ") if dataset_name else "Unknown",
metric_name=metric_key.replace("_", " ").title(),
metric_type=metric_key,
metric_value=metric_value,
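
The model_card.py fix swaps an `or` fallback for a conditional expression, which matters when dataset_name is None. A small illustration with a hypothetical helper:

    def humanize(dataset_name: str | None) -> str:
        # The old form, dataset_name.replace("_", " ").replace("-", " ") or "Unknown",
        # raises AttributeError for None because .replace() runs before the
        # `or` fallback is ever evaluated.
        return dataset_name.replace("_", " ").replace("-", " ") if dataset_name else "Unknown"

    print(humanize("nano-beir_msmarco"))  # nano beir msmarco
    print(humanize(None))                 # Unknown
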