[feat] Add query prompts to Information Retrieval Evaluator (#2951)
* Added the possibility of masking the prompts if the tokenizer is left-padded.

* Simplify code

* Remove unrelated changes

* Move prompt_mask into the Transformer model

* Added query and corpus prompts to Information Retrieval Evaluator

* Fix for failing test

* Fix for pooling when mask is not passed

* Fix device placement for prompt_mask

* Revert left-padding changes

* Revert left-padding changes
ArthurCamara authored and tomaarsen committed Sep 30, 2024
1 parent 71be4e3 commit 6a3750e
Showing 3 changed files with 20 additions and 3 deletions.
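Before the diffs, a minimal usage sketch of the two new evaluator arguments. The model name and the toy query/corpus data are illustrative only, not part of this commit:

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator

# Toy data for illustration; real evaluations use full query/corpus dictionaries.
queries = {"q1": "how do prompts affect retrieval?"}
corpus = {"d1": "Prompts can be prepended to queries and passages before encoding."}
relevant_docs = {"q1": {"d1"}}

model = SentenceTransformer("intfloat/e5-base-v2")  # illustrative prompt-sensitive model

evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    query_prompt="query: ",     # prepended to every query before encoding
    corpus_prompt="passage: ",  # prepended to every corpus document
    name="toy-ir",
)
results = evaluator(model)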
4 changes: 2 additions & 2 deletions sentence_transformers/SentenceTransformer.py
@@ -658,7 +658,7 @@ def encode(

        return all_embeddings

-    def forward(self, input: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
+    def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
        if self.module_kwargs is None:
            return super().forward(input)

@@ -1023,7 +1023,7 @@ def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]]) -> dic
        """
        return self._first_module().tokenize(texts)

-    def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], torch.Tensor]:
+    def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], Tensor]:
        return self._first_module().get_sentence_features(*features)

    def get_sentence_embedding_dimension(self) -> int | None:
17 changes: 17 additions & 0 deletions sentence_transformers/evaluation/InformationRetrievalEvaluator.py
@@ -133,6 +133,10 @@ def __init__(
            SimilarityFunction.DOT_PRODUCT.value: dot_score,
        }, # Score function, higher=more similar
        main_score_function: str | SimilarityFunction | None = None,
+        query_prompt: str | None = None,
+        query_prompt_name: str | None = None,
+        corpus_prompt: str | None = None,
+        corpus_prompt_name: str | None = None,
    ) -> None:
        """
        Initializes the InformationRetrievalEvaluator.
@@ -154,6 +158,10 @@ def __init__(
            truncate_dim (int, optional): The dimension to truncate the embeddings to. Defaults to None.
            score_functions (Dict[str, Callable[[Tensor, Tensor], Tensor]]): A dictionary mapping score function names to score functions. Defaults to {SimilarityFunction.COSINE.value: cos_sim, SimilarityFunction.DOT_PRODUCT.value: dot_score}.
            main_score_function (Union[str, SimilarityFunction], optional): The main score function to use for evaluation. Defaults to None.
+            query_prompt (str, optional): A prompt to use for the queries. Defaults to None.
+            query_prompt_name (str, optional): A name for the query prompt. Defaults to None.
+            corpus_prompt (str, optional): A prompt to use for the corpus. Defaults to None.
+            corpus_prompt_name (str, optional): A name for the corpus prompt. Defaults to None.
        """
        super().__init__()
        self.queries_ids = []
@@ -166,6 +174,11 @@ def __init__(
        self.corpus_ids = list(corpus.keys())
        self.corpus = [corpus[cid] for cid in self.corpus_ids]

+        self.query_prompt = query_prompt
+        self.query_prompt_name = query_prompt_name
+        self.corpus_prompt = corpus_prompt
+        self.corpus_prompt_name = corpus_prompt_name
+
        self.relevant_docs = relevant_docs
        self.corpus_chunk_size = corpus_chunk_size
        self.mrr_at_k = mrr_at_k
@@ -294,6 +307,8 @@ def compute_metrices(
        with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
            query_embeddings = model.encode(
                self.queries,
+                prompt_name=self.query_prompt_name,
+                prompt=self.query_prompt,
                show_progress_bar=self.show_progress_bar,
                batch_size=self.batch_size,
                convert_to_tensor=True,
@@ -316,6 +331,8 @@
            ):
                sub_corpus_embeddings = corpus_model.encode(
                    self.corpus[corpus_start_idx:corpus_end_idx],
+                    prompt_name=self.corpus_prompt_name,
+                    prompt=self.corpus_prompt,
                    show_progress_bar=False,
                    batch_size=self.batch_size,
                    convert_to_tensor=True,
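The evaluator simply forwards these values to model.encode: prompt is a literal string prepended to each input text, while prompt_name looks up a prompt registered on the model. A minimal sketch, assuming a model configured with the illustrative prompts below:

from sentence_transformers import SentenceTransformer

# Assumed setup: named prompts registered on the model (names and strings are illustrative).
model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    prompts={"query": "query: ", "passage": "passage: "},
)

# Two equivalent ways to prepend the query prompt before encoding:
emb_a = model.encode(["what is dense retrieval?"], prompt="query: ")
emb_b = model.encode(["what is dense retrieval?"], prompt_name="query")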
2 changes: 1 addition & 1 deletion sentence_transformers/models/Pooling.py
@@ -57,7 +57,7 @@ def __init__(
        pooling_mode_mean_sqrt_len_tokens: bool = False,
        pooling_mode_weightedmean_tokens: bool = False,
        pooling_mode_lasttoken: bool = False,
-        include_prompt=True,
+        include_prompt: bool = True,
    ) -> None:
        super().__init__()

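For context on include_prompt (only its type hint changes in this commit): when set to False, the Pooling module excludes the prompt tokens from mean pooling, as Instructor-style models expect. A minimal sketch, with the base model chosen purely for illustration:

from sentence_transformers import SentenceTransformer, models

word_embedding_model = models.Transformer("sentence-transformers/all-MiniLM-L6-v2")
pooling = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode="mean",
    include_prompt=False,  # mask out prompt tokens before mean pooling
)
model = SentenceTransformer(modules=[word_embedding_model, pooling])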
