[feat] MultipleNegativesBatchSampler #2960

Draft · wants to merge 17 commits into master
Changes from all commits
4 changes: 2 additions & 2 deletions sentence_transformers/SentenceTransformer.py
@@ -658,7 +658,7 @@ def encode(

return all_embeddings

- def forward(self, input: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
+ def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
if self.module_kwargs is None:
return super().forward(input)

@@ -1023,7 +1023,7 @@ def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]]) -> dic
"""
return self._first_module().tokenize(texts)

- def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], torch.Tensor]:
+ def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], Tensor]:
return self._first_module().get_sentence_features(*features)

def get_sentence_embedding_dimension(self) -> int | None:
17 changes: 17 additions & 0 deletions sentence_transformers/evaluation/InformationRetrievalEvaluator.py
@@ -133,6 +133,10 @@ def __init__(
SimilarityFunction.DOT_PRODUCT.value: dot_score,
}, # Score function, higher=more similar
main_score_function: str | SimilarityFunction | None = None,
query_prompt: str | None = None,
query_prompt_name: str | None = None,
corpus_prompt: str | None = None,
corpus_prompt_name: str | None = None,
) -> None:
"""
Initializes the InformationRetrievalEvaluator.
@@ -154,6 +158,10 @@ def __init__(
truncate_dim (int, optional): The dimension to truncate the embeddings to. Defaults to None.
score_functions (Dict[str, Callable[[Tensor, Tensor], Tensor]]): A dictionary mapping score function names to score functions. Defaults to {SimilarityFunction.COSINE.value: cos_sim, SimilarityFunction.DOT_PRODUCT.value: dot_score}.
main_score_function (Union[str, SimilarityFunction], optional): The main score function to use for evaluation. Defaults to None.
query_prompt (str, optional): A prompt to use for the queries. Defaults to None.
query_prompt_name (str, optional): A name for the query prompt. Defaults to None.
corpus_prompt (str, optional): A prompt to use for the corpus. Defaults to None.
corpus_prompt_name (str, optional): A name for the corpus prompt. Defaults to None.
"""
super().__init__()
self.queries_ids = []
@@ -166,6 +174,11 @@ def __init__(
self.corpus_ids = list(corpus.keys())
self.corpus = [corpus[cid] for cid in self.corpus_ids]

self.query_prompt = query_prompt
self.query_prompt_name = query_prompt_name
self.corpus_prompt = corpus_prompt
self.corpus_prompt_name = corpus_prompt_name

self.relevant_docs = relevant_docs
self.corpus_chunk_size = corpus_chunk_size
self.mrr_at_k = mrr_at_k
@@ -294,6 +307,8 @@ def compute_metrices(
with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
query_embeddings = model.encode(
self.queries,
prompt_name=self.query_prompt_name,
prompt=self.query_prompt,
show_progress_bar=self.show_progress_bar,
batch_size=self.batch_size,
convert_to_tensor=True,
@@ -316,6 +331,8 @@
):
sub_corpus_embeddings = corpus_model.encode(
self.corpus[corpus_start_idx:corpus_end_idx],
prompt_name=self.corpus_prompt_name,
prompt=self.corpus_prompt,
show_progress_bar=False,
batch_size=self.batch_size,
convert_to_tensor=True,
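To illustrate the new prompt arguments above, here is a minimal sketch of how the evaluator might be configured once this branch is installed. The model name and the toy queries, corpus, and relevance judgments are placeholders, not part of the PR:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator

# Placeholder model; any embedding model that expects instruction prompts could be used.
model = SentenceTransformer("intfloat/multilingual-e5-small")

# Toy retrieval data, purely for illustration.
queries = {"q1": "how do prompts affect retrieval quality?"}
corpus = {
    "d1": "Prompts can steer instruction-tuned embedding models.",
    "d2": "An unrelated document about cooking pasta.",
}
relevant_docs = {"q1": {"d1"}}

evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="toy-ir",
    # New in this PR: prompts prepended when encoding queries and corpus documents.
    query_prompt="query: ",
    corpus_prompt="passage: ",
)
print(evaluator(model))
```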
2 changes: 1 addition & 1 deletion sentence_transformers/models/Pooling.py
@@ -57,7 +57,7 @@ def __init__(
pooling_mode_mean_sqrt_len_tokens: bool = False,
pooling_mode_weightedmean_tokens: bool = False,
pooling_mode_lasttoken: bool = False,
- include_prompt=True,
+ include_prompt: bool = True,
) -> None:
super().__init__()

92 changes: 92 additions & 0 deletions sentence_transformers/sampler.py
@@ -213,6 +213,98 @@ def __len__(self) -> int:
return (len(self.dataset) + self.batch_size - 1) // self.batch_size


class MultipleNegativesBatchSampler(SetEpochMixin, BatchSampler):
def __init__(
self,
dataset: Dataset,
batch_size: int,
drop_last: bool,
valid_label_columns: list[str] = [],
generator: torch.Generator = None,
seed: int = 0,
) -> None:
"""
This sampler creates batches such that a sample's negatives are never present among the positives
already sampled into the batch. This is useful when using a loss with in-batch negatives, as it
avoids a positive also appearing as a negative for the same anchor.
Using this sampler also avoids positives becoming duplicated in the batch, as each sample's
hard negatives are part of that same sample.

Recommended for:
- :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss`
- :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss`
- :class:`~sentence_transformers.losses.MegaBatchMarginLoss`
- :class:`~sentence_transformers.losses.GISTEmbedLoss`
- :class:`~sentence_transformers.losses.CachedGISTEmbedLoss`

Args:
dataset (Dataset): The dataset to sample from.
batch_size (int): Number of samples per batch.
drop_last (bool): If True, drop the last incomplete batch if the dataset size
is not divisible by the batch size.
valid_label_columns (List[str]): List of column names to check for labels.
The first column name from ``valid_label_columns`` found in the dataset will
be used as the label column.
generator (torch.Generator, optional): Optional random number generator for shuffling
the indices.
seed (int, optional): Seed for the random number generator to ensure reproducibility.
"""
super().__init__(dataset, batch_size, drop_last)
if label_columns := set(dataset.column_names) & (set(valid_label_columns) | {"dataset_name"}):
dataset = dataset.remove_columns(label_columns)
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.generator = generator
self.seed = seed

def __iter__(self) -> Iterator[list[int]]:
"""
Iterate over the remaining, not-yet-yielded indices. For each index, check whether the sample's values are
already in the batch. If not, add the sample's values to the batch and keep going until the batch is full.
Once the batch is full, yield the batch indices and continue with the next batch.
"""
if self.generator and self.seed:
self.generator.manual_seed(self.seed + self.epoch)
anchor_column = self.dataset.column_names[0]
positive_column = self.dataset.column_names[1]
negative_columns = [self.dataset.column_names[i] for i in range(2, len(self.dataset.column_names))]

remaining_indices = set(torch.randperm(len(self.dataset), generator=self.generator).tolist())

while remaining_indices:
batch_values = set()
batch_indices = []
for index in remaining_indices:
sample = self.dataset[index]
# Ensure the sample's negatives (or its positive, if there are no negative columns) are not among the anchors/positives already in the batch
if negative_columns:
negatives = set([sample[negative_column] for negative_column in negative_columns])
if negatives & batch_values:
continue
elif sample[positive_column] in batch_values:
continue
batch_indices.append(index)
if len(batch_indices) == self.batch_size:
yield batch_indices
break

batch_values.add(sample[anchor_column])
batch_values.add(sample[positive_column])
else:
# NOTE: the for-loop finished without filling the batch; conflicting indices were skipped and remain in remaining_indices for the next pass
if not self.drop_last:
yield batch_indices

remaining_indices -= set(batch_indices)

def __len__(self) -> int:
if self.drop_last:
return len(self.dataset) // self.batch_size
else:
return (len(self.dataset) + self.batch_size - 1) // self.batch_size


class RoundRobinBatchSampler(SetEpochMixin, BatchSampler):
"""
Batch sampler that yields batches in a round-robin fashion from multiple batch samplers, until one is exhausted.
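For a rough sense of the batching behaviour the new MultipleNegativesBatchSampler above implements, here is a sketch that runs it on a toy triplet dataset, assuming this branch is installed. The column names and values are invented; the only layout the sampler assumes is the (anchor, positive, *negatives) column order:

```python
import torch
from datasets import Dataset

from sentence_transformers.sampler import MultipleNegativesBatchSampler  # added in this PR

# Toy (anchor, positive, negative) triplets. Row 1's negative ("paris") is also
# row 0's positive, so the sampler may defer one of them to a later batch.
dataset = Dataset.from_dict(
    {
        "anchor": ["capital of france", "capital of italy", "capital of spain", "capital of england"],
        "positive": ["paris", "rome", "madrid", "london"],
        "negative": ["london", "paris", "lisbon", "dublin"],
    }
)

sampler = MultipleNegativesBatchSampler(
    dataset=dataset,
    batch_size=2,
    drop_last=False,
    generator=torch.Generator(),
    seed=12,
)

# A row is skipped for the current batch if any of its negatives matches an anchor
# or positive already placed in that batch; skipped rows are retried in later batches.
for batch_indices in sampler:
    print(batch_indices)
```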
9 changes: 9 additions & 0 deletions sentence_transformers/trainer.py
@@ -24,6 +24,7 @@
from sentence_transformers.sampler import (
DefaultBatchSampler,
GroupByLabelBatchSampler,
MultipleNegativesBatchSampler,
NoDuplicatesBatchSampler,
ProportionalBatchSampler,
RoundRobinBatchSampler,
@@ -526,6 +527,14 @@ def get_batch_sampler(
batch_size=batch_size,
drop_last=drop_last,
)
if self.args.batch_sampler == BatchSamplers.MULTIPLE_NEGATIVES:
return MultipleNegativesBatchSampler(
dataset=dataset,
batch_size=batch_size,
drop_last=drop_last,
valid_label_columns=valid_label_columns,
generator=generator,
)

def get_multi_dataset_batch_sampler(
self,
1 change: 1 addition & 0 deletions sentence_transformers/training_args.py
@@ -74,6 +74,7 @@ class BatchSamplers(ExplicitEnum):
BATCH_SAMPLER = "batch_sampler"
NO_DUPLICATES = "no_duplicates"
GROUP_BY_LABEL = "group_by_label"
MULTIPLE_NEGATIVES = "multiple_negatives"


class MultiDatasetBatchSamplers(ExplicitEnum):
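Finally, a hedged sketch of selecting the new sampler through the trainer, which is what the `trainer.py` and `training_args.py` changes above enable. The model, training triplets, and output directory below are placeholders:

```python
from datasets import Dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Placeholder (anchor, positive, negative) training triplets.
train_dataset = Dataset.from_dict(
    {
        "anchor": ["what is the capital of france?", "who wrote hamlet?"],
        "positive": ["Paris is the capital of France.", "Hamlet was written by Shakespeare."],
        "negative": ["Rome is the capital of Italy.", "Moby-Dick was written by Melville."],
    }
)

args = SentenceTransformerTrainingArguments(
    output_dir="tmp-multiple-negatives-demo",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    # New enum value from this PR; routes get_batch_sampler to MultipleNegativesBatchSampler.
    batch_sampler=BatchSamplers.MULTIPLE_NEGATIVES,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=MultipleNegativesRankingLoss(model),
)
trainer.train()
```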