diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index c2d476782..b38c11d0b 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -614,7 +614,23 @@ def encode(
         all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
 
         if precision and precision != "float32":
-            all_embeddings = quantize_embeddings(all_embeddings, precision=precision)
+            if output_value:
+                all_embeddings = quantize_embeddings(all_embeddings, precision=precision)
+            else:
+                # output_value=None means we want both token and sentence embeddings.
+                # all_embeddings is now a list of dictionaries, so we temporarily build
+                # a flat list of token embeddings and sentence embeddings, quantize
+                # them together, and then recombine them into a list of dictionaries.
+                combined_embeddings = []
+                for emb in all_embeddings:
+                    combined_embeddings.append(emb["token_embeddings"])
+                    combined_embeddings.append(emb["sentence_embedding"].reshape(1, -1))
+                combined_embeddings = quantize_embeddings(combined_embeddings, precision=precision)
+
+                # Reconstruct the list of dictionaries with quantized embeddings
+                for i, emb in enumerate(all_embeddings):
+                    emb["token_embeddings"] = combined_embeddings[2 * i]
+                    emb["sentence_embedding"] = combined_embeddings[2 * i + 1].reshape(-1)
 
         if convert_to_tensor:
             if len(all_embeddings):
diff --git a/sentence_transformers/quantization.py b/sentence_transformers/quantization.py
index 37402cae7..342400d42 100644
--- a/sentence_transformers/quantization.py
+++ b/sentence_transformers/quantization.py
@@ -394,17 +394,23 @@ def quantize_embeddings(
     Returns:
         Quantized embeddings with the specified precision
     """
+    outputs, lengths = None, None
     if isinstance(embeddings, Tensor):
         embeddings = embeddings.cpu().numpy()
     elif isinstance(embeddings, list):
         if isinstance(embeddings[0], Tensor):
             embeddings = [embedding.cpu().numpy() for embedding in embeddings]
+        if not isinstance(embeddings[0], list) and len(embeddings[0].shape) == 2:
+            # This happens when token embeddings were requested: the per-sentence
+            # arrays can have different lengths, so flatten them before quantizing.
+            lengths = [embedding.shape[0] for embedding in embeddings]
+            embeddings = np.concatenate(embeddings)
         embeddings = np.array(embeddings)
     if embeddings.dtype in (np.uint8, np.int8):
         raise Exception("Embeddings to quantize must be float rather than int8 or uint8.")
 
     if precision == "float32":
-        return embeddings.astype(np.float32)
+        outputs = embeddings.astype(np.float32)
 
     if precision.endswith("int8"):
         # Either use the 1. provided ranges, 2. the calibration dataset or 3. the provided embeddings
@@ -423,14 +429,20 @@ def quantize_embeddings(
         steps = (ranges[1, :] - ranges[0, :]) / 255
 
         if precision == "uint8":
-            return ((embeddings - starts) / steps).astype(np.uint8)
+            outputs = ((embeddings - starts) / steps).astype(np.uint8)
         elif precision == "int8":
-            return ((embeddings - starts) / steps - 128).astype(np.int8)
+            outputs = ((embeddings - starts) / steps - 128).astype(np.int8)
 
     if precision == "binary":
-        return (np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1) - 128).astype(np.int8)
+        outputs = (np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1) - 128).astype(np.int8)
 
     if precision == "ubinary":
-        return np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1)
+        outputs = np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1)
 
-    raise ValueError(f"Precision {precision} is not supported")
+    if outputs is None:
+        raise ValueError(f"Precision {precision} is not supported")
+
+    if lengths is not None:
+        # Restore the per-sentence token embedding arrays
+        outputs = np.split(outputs, np.cumsum(lengths)[:-1])
+
+    return outputs
diff --git a/tests/test_compute_embeddings.py b/tests/test_compute_embeddings.py
index 5b0bf6aaa..4f2eccac1 100644
--- a/tests/test_compute_embeddings.py
+++ b/tests/test_compute_embeddings.py
@@ -4,7 +4,10 @@
 
 from __future__ import annotations
 
+from typing import Literal
+
 import numpy as np
+import pytest
 
 from sentence_transformers import SentenceTransformer
 
@@ -84,3 +87,123 @@ def test_encode_tuple_sentences(paraphrase_distilroberta_base_v1_model: Sentence
     )
     assert emb.shape == (3, 768)
     assert abs(np.sum(emb) - 32.14627) < 0.002
+
+
+@pytest.mark.parametrize("precision", ("int8", "uint8"))
+def test_encode_sentence_embedding_int_precision(
+    paraphrase_distilroberta_base_v1_model: SentenceTransformer,
+    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"],
+) -> None:
+    model = paraphrase_distilroberta_base_v1_model
+    # Single sentence
+    emb = model.encode("Hello Word, a test sentence", output_value="sentence_embedding", precision=precision)
+    assert emb.shape == (768,)
+    assert emb.dtype == np.dtype(precision)
+
+    # Single sentence as list
+    emb = model.encode(["Hello Word, a test sentence"], output_value="sentence_embedding", precision=precision)
+    assert isinstance(emb, np.ndarray)
+    assert emb.shape == (1, 768)
+    assert emb.dtype == np.dtype(precision)
+
+    # Sentence list
+    emb = model.encode(
+        [
+            "Hello Word, a test sentence",
+            "Here comes another sentence",
+            "My final sentence",
+        ],
+        output_value="sentence_embedding",
+        precision=precision,
+    )
+    assert isinstance(emb, np.ndarray)
+    assert emb.shape == (3, 768)
+    assert emb.dtype == np.dtype(precision)
+
+
+@pytest.mark.parametrize("precision", ("int8", "uint8"))
+def test_encode_token_embeddings_int_precision(
+    paraphrase_distilroberta_base_v1_model: SentenceTransformer,
+    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"],
+) -> None:
+    model = paraphrase_distilroberta_base_v1_model
+    # Single sentence
+    emb = model.encode("Hello Word, a test sentence", output_value="token_embeddings", precision=precision)
+    assert emb.shape == (8, 768)
+    assert emb.dtype == np.dtype(precision)
+
+    # Single sentence as list
+    emb = model.encode(["Hello Word, a test sentence"], output_value="token_embeddings", precision=precision)
+    assert isinstance(emb, list)
+    assert emb[0].shape == (8, 768)
+    assert emb[0].dtype == np.dtype(precision)
+
+    # Sentence list
+    emb = model.encode(
+        [
+            "Hello Word, a test sentence",
+            "Here comes another sentence",
+            "My final sentence",
+        ],
+        output_value="token_embeddings",
+        precision=precision,
+    )
+    assert isinstance(emb, list)
+    assert emb[0].shape == (8, 768)
+    assert emb[0].dtype == np.dtype(precision)
+    assert emb[1].shape == (6, 768)
+    assert emb[1].dtype == np.dtype(precision)
+    assert emb[2].shape == (5, 768)
+    assert emb[2].dtype == np.dtype(precision)
+
+
+@pytest.mark.parametrize("precision", ("int8", "uint8"))
+def test_encode_output_value_none_int_precision(
+    paraphrase_distilroberta_base_v1_model: SentenceTransformer,
+    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"],
+) -> None:
+    model = paraphrase_distilroberta_base_v1_model
+    # Single sentence
+    emb = model.encode("Hello Word, a test sentence", output_value=None, precision=precision)
+    assert isinstance(emb, dict)
+    assert emb["sentence_embedding"].shape == (768,)
+    assert emb["sentence_embedding"].dtype == np.dtype(precision)
+    assert emb["token_embeddings"].shape == (8, 768)
+    assert emb["token_embeddings"].dtype == np.dtype(precision)
+
+    # Single sentence as list
+    emb = model.encode(["Hello Word, a test sentence"], output_value=None, precision=precision)
+    assert isinstance(emb, list)
+    assert isinstance(emb[0], dict)
+    assert emb[0]["sentence_embedding"].shape == (768,)
+    assert emb[0]["sentence_embedding"].dtype == np.dtype(precision)
+    assert emb[0]["token_embeddings"].shape == (8, 768)
+    assert emb[0]["token_embeddings"].dtype == np.dtype(precision)
+
+    # Sentence list
+    emb = model.encode(
+        [
+            "Hello Word, a test sentence",
+            "Here comes another sentence",
+            "My final sentence",
+        ],
+        output_value=None,
+        precision=precision,
+    )
+    assert isinstance(emb, list)
+    assert all(isinstance(e, dict) for e in emb)
+
+    assert emb[0]["sentence_embedding"].shape == (768,)
+    assert emb[0]["sentence_embedding"].dtype == np.dtype(precision)
+    assert emb[0]["token_embeddings"].shape == (8, 768)
+    assert emb[0]["token_embeddings"].dtype == np.dtype(precision)
+
+    assert emb[1]["sentence_embedding"].shape == (768,)
+    assert emb[1]["sentence_embedding"].dtype == np.dtype(precision)
+    assert emb[1]["token_embeddings"].shape == (8, 768)
+    assert emb[1]["token_embeddings"].dtype == np.dtype(precision)
+
+    assert emb[2]["sentence_embedding"].shape == (768,)
+    assert emb[2]["sentence_embedding"].dtype == np.dtype(precision)
+    assert emb[2]["token_embeddings"].shape == (8, 768)
+    assert emb[2]["token_embeddings"].dtype == np.dtype(precision)
\ No newline at end of file
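
For context, a minimal usage sketch of the behavior this patch enables. The model is the same one the tests use; the exact token counts depend on the tokenizer, so treat the shapes as illustrative:

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/paraphrase-distilroberta-base-v1")

# Token embeddings with an integer precision: each sentence keeps its own
# (num_tokens, 768) array, now with a quantized dtype.
token_embs = model.encode(
    ["Hello Word, a test sentence", "Here comes another sentence"],
    output_value="token_embeddings",
    precision="uint8",
)
assert all(e.dtype == np.uint8 for e in token_embs)

# output_value=None returns, per input, a dict with both embedding kinds quantized.
both = model.encode("Hello Word, a test sentence", output_value=None, precision="int8")
assert both["sentence_embedding"].dtype == np.int8
assert both["token_embeddings"].dtype == np.int8
```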
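The core mechanism in quantize_embeddings is a flatten/quantize/split round trip: the ragged per-sentence arrays are concatenated so that the uint8/int8 ranges are computed over all tokens jointly, then the quantized block is split back using the remembered lengths. A self-contained sketch of just that trick (dimensions are arbitrary, and the range computation is the simplified no-calibration path):

```python
import numpy as np

# Ragged "token embedding" arrays, shaped (num_tokens_i, dim)
rng = np.random.default_rng(0)
embeddings = [rng.normal(size=(8, 4)), rng.normal(size=(6, 4)), rng.normal(size=(5, 4))]

# Flatten, remembering each sentence's token count
lengths = [e.shape[0] for e in embeddings]
flat = np.concatenate(embeddings)  # shape (19, 4)

# Quantize jointly: per-dimension ranges taken from the embeddings themselves
starts = flat.min(axis=0)
steps = (flat.max(axis=0) - starts) / 255
quantized = ((flat - starts) / steps).astype(np.uint8)

# Split back into the original ragged structure
outputs = np.split(quantized, np.cumsum(lengths)[:-1])
assert [o.shape for o in outputs] == [(8, 4), (6, 4), (5, 4)]
```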
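Calling quantize_embeddings directly, the plain 2D case behaves as before, while a ragged list of 2D arrays now comes back as a list of quantized arrays. A sketch assuming the patched module:

```python
import numpy as np
from sentence_transformers.quantization import quantize_embeddings

rng = np.random.default_rng(0)

# Plain 2D input: unchanged behavior, a single quantized 2D array is returned
embs = rng.normal(size=(3, 768)).astype(np.float32)
quantized = quantize_embeddings(embs, precision="int8")
assert quantized.shape == (3, 768) and quantized.dtype == np.int8

# Ragged list of token-embedding arrays: quantized with joint ranges,
# returned as a list with the original per-sentence shapes
ragged = [rng.normal(size=(n, 768)).astype(np.float32) for n in (8, 6, 5)]
quantized_list = quantize_embeddings(ragged, precision="uint8")
assert [a.shape for a in quantized_list] == [(8, 768), (6, 768), (5, 768)]
assert all(a.dtype == np.uint8 for a in quantized_list)
```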