Add Jina Embeddings #61

Closed
wants to merge 13 commits

212 changes: 191 additions & 21 deletions in fastembed/embedding.py
@@ -74,14 +74,15 @@ def load_tokenizer(cls, model_dir: Path, max_length: int = 512) -> Tokenizer:
        return tokenizer

    def __init__(
        self,
        path: Path,
        model_name: str,
        max_length: int = 512,
        max_threads: int = None,
    ):
        self.path = path
        self.model_name = model_name

        model_path = self.path / "model.onnx"
        optimized_model_path = self.path / "model_optimized.onnx"
@@ -110,7 +111,7 @@ def __init__(
        self.tokenizer = self.load_tokenizer(self.path, max_length=max_length)
        self.model = ort.InferenceSession(str(model_path), providers=onnx_providers, sess_options=so)

-    def onnx_embed(self, documents: List[str]) -> np.ndarray:
+    def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]:
        encoded = self.tokenizer.encode_batch(documents)
        input_ids = np.array([e.ids for e in encoded])
        attention_mask = np.array([e.attention_mask for e in encoded])
@@ -126,22 +127,29 @@ def onnx_embed(self, documents: List[str]) -> np.ndarray:
        )

        model_output = self.model.run(None, onnx_input)
-        last_hidden_state = model_output[0][:, 0]
-        embeddings = normalize(last_hidden_state).astype(np.float32)
-        return embeddings
+        embeddings = model_output[0]
+        return embeddings, attention_mask
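With this change, onnx_embed no longer pools or normalizes; it hands back the raw token embeddings together with the attention mask. A minimal sketch of the new calling contract (the `model` variable and the shapes are assumptions for illustration, not part of the diff):

    # Assumed: `model` is an EmbeddingModel instance loaded as in __init__ above.
    embeddings, attention_mask = model.onnx_embed(["hello world"])
    # embeddings: float array of shape (batch, tokens, dim) -- unpooled, unnormalized
    # attention_mask: int array of shape (batch, tokens), 1 for real tokens, 0 for padding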


class EmbeddingWorker(Worker):
    def __init__(
        self,
        path: Path,
        model_name: str,
        max_length: int = 512,
    ):
-        self.model = EmbeddingModel(path=path, model_name=model_name, max_length=max_length, max_threads=1)
+        self.model = EmbeddingModel(
+            path=path, model_name=model_name, max_length=max_length, max_threads=1,
+        )

    @classmethod
-    def start(cls, path: Path, model_name: str, max_length: int = 512, **kwargs: Any) -> "EmbeddingWorker":
+    def start(
+        cls,
+        path: Path,
+        model_name: str,
+        max_length: int = 512,
+        **kwargs: Any,
+    ) -> "EmbeddingWorker":
        return cls(
            path=path,
            model_name=model_name,
@@ -150,8 +158,8 @@ def start(cls, path: Path, model_name: str, max_length: int = 512, **kwargs: Any

    def process(self, items: Iterable[Tuple[int, Any]]) -> Iterable[Tuple[int, Any]]:
        for idx, batch in items:
-            embeddings = self.model.onnx_embed(batch)
-            yield idx, embeddings
+            embeddings, attn_mask = self.model.onnx_embed(batch)
+            yield idx, (embeddings, attn_mask)


class Embedding(ABC):
@@ -226,6 +234,18 @@ def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]
"description": "Multilingual model, e5-large. Recommend using this model for non-English languages",
"size_in_GB": 2.24
},
{
"model": "jinaai/jina-embeddings-v2-base-en",
"dim": 768,
"description": " English embedding model supporting 8192 sequence length",
"size_in_GB": 0.55
},
{
"model": "jinaai/jina-embeddings-v2-small-en",
"dim": 512,
"description": " English embedding model supporting 8192 sequence length",
"size_in_GB": 0.13
}
]

    @classmethod
@@ -282,6 +302,27 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
        progress_bar.close()
        return output_path

    @classmethod
    def download_files_from_huggingface(cls, repo_id: str, cache_dir: Optional[str] = None) -> str:
        """
        Downloads a model from HuggingFace Hub.

        Args:
            repo_id (str): The HF hub id (name) of the model to retrieve.
            cache_dir (Optional[str]): The path to the cache directory.

        Returns:
            str: The path to the model directory.
        """
        from huggingface_hub import snapshot_download

        return snapshot_download(
            repo_id=repo_id, ignore_patterns=["model.safetensors", "pytorch_model.bin"], cache_dir=cache_dir
        )
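For illustration, a hypothetical call to the new helper (the repo id is taken from the supported-models list above; the returned local path depends on your HF cache layout):

    # Hedged usage sketch -- fetches the repo while skipping the PyTorch/safetensors weights:
    model_dir = Embedding.download_files_from_huggingface(
        repo_id="jinaai/jina-embeddings-v2-small-en",
        cache_dir="local_cache",
    )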

    @classmethod
    def decompress_to_cache(cls, targz_path: str, cache_dir: str):
        """
@@ -317,7 +358,7 @@ def decompress_to_cache(cls, targz_path: str, cache_dir: str):

        return cache_dir

-    def retrieve_model(self, model_name: str, cache_dir: str) -> Path:
+    def retrieve_model_gcs(self, model_name: str, cache_dir: str) -> Path:
        """
        Retrieves a model from Google Cloud Storage.

@@ -361,6 +402,27 @@ def retrieve_model(self, model_name: str, cache_dir: str) -> Path:

        return model_dir

    def retrieve_model_hf(self, model_name: str, cache_dir: str) -> Path:
        """
        Retrieves a model from HuggingFace Hub.

        Args:
            model_name (str): The name of the model to retrieve.
            cache_dir (str): The path to the cache directory.

        Raises:
            AssertionError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.

        Returns:
            Path: The path to the model directory.
        """
        assert (
            "/" in model_name
        ), "model_name must be in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-small-en"

        return Path(self.download_files_from_huggingface(repo_id=model_name, cache_dir=cache_dir))

    def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
        """
        Embeds a list of text passages into a list of embeddings.
@@ -401,6 +463,7 @@ class FlagEmbedding(Embedding):
        Embedding (_type_): _description_
    """


    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
@@ -425,7 +488,7 @@ def __init__(
            cache_dir.mkdir(parents=True, exist_ok=True)

        self._cache_dir = cache_dir
-        self._model_dir = self.retrieve_model(model_name, cache_dir)
+        self._model_dir = self.retrieve_model_gcs(model_name, cache_dir)
        self._max_length = max_length

        self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length,
@@ -464,7 +527,8 @@ def embed(

        if parallel is None or is_small:
            for batch in iter_batch(documents, batch_size):
-                yield from self.model.onnx_embed(batch)
+                embeddings, _ = self.model.onnx_embed(batch)
Contributor:

Hmm, I'm confused. I believe FlagEmbedding should be left untouched, since all the changes are in the parent class and the new JinaEmbedding class, right?

Similarly, the list_supported_models rewrite isn't needed and should be removed from all implementations now?

Author (@JohannesMessner), Nov 15, 2023:

FlagEmbedding cannot be left entirely untouched, unfortunately, unless I am missing something.

Before this PR, the EmbeddingModel.onnx_embed() method picks out the first token as a form of pooling and then applies normalization. Baked into this is the assumption that all subclasses of Embedding (that hold an EmbeddingModel instance) intend that behaviour. That assumption is broken by Jina embeddings, which require mean pooling before the normalization. And mean pooling cannot be applied after the fact, since the existing implementation of EmbeddingModel.onnx_embed() "throws away" the token embeddings needed for it.

Therefore, the implementation of EmbeddingModel.onnx_embed() needs two small modifications:

  1. It delegates pooling and normalization to the subclasses of Embedding.
  2. It returns the tokenizer's attention mask. Without access to the attention mask, pooling schemes such as mean pooling cannot be implemented at the Embedding level.

This requires FlagEmbedding to adjust to those changes.
Just like JinaEmbedding, it now implements its own pooling scheme (simply picking out the first token). The attention mask is not required for this, so it can be ignored when returned by EmbeddingModel.onnx_embed().
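For concreteness, a minimal NumPy sketch of the two pooling schemes being contrasted here (function names are illustrative; the PR's actual mean pooling lives in JinaEmbedding.mean_pooling below):

    import numpy as np

    def cls_pooling(token_embeddings: np.ndarray) -> np.ndarray:
        # FlagEmbedding-style pooling: keep the first ([CLS]) token,
        # (batch, tokens, dim) -> (batch, dim); no attention mask needed.
        return token_embeddings[:, 0]

    def mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
        # Jina-style pooling: average over real tokens only; padding is masked out.
        mask = np.expand_dims(attention_mask, axis=-1).astype(float)
        summed = np.sum(token_embeddings * mask, axis=1)
        counts = np.clip(np.sum(mask, axis=1), a_min=1e-9, a_max=None)
        return summed / counts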

Author:

As for the list_supported_models() rewrite, yes, I can remove that. But then JinaEmbedding.list_supported_models() would return a bunch of models that are actually not supported by the JinaEmbedding class.

Contributor:

I see your point. Looks like we have to figure out a way to handle the normalize, attention and pooling steps separately for each embedding implementation. At the moment, what you've proposed kinda works.

Let me think about this + test your PR and then we're good to go and merge this.

+                yield from normalize(embeddings[:, 0]).astype(np.float32)
        else:
            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            params = {
@@ -474,7 +538,16 @@ def embed(
            }
            pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
-                yield from batch
+                embeddings, _ = batch
+                yield from normalize(embeddings[:, 0]).astype(np.float32)

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
        """
        Lists the supported models.
        """
        # jina models are not supported by this class
        return [model for model in super().list_supported_models() if not model['model'].startswith('jinaai')]


class DefaultEmbedding(FlagEmbedding):
@@ -505,3 +578,100 @@ def embed(self, texts, batch_size: int = 256, parallel: int = None):
        # Use your OpenAI model to embed the texts
        # return self.model.embed(texts)
        raise NotImplementedError


class JinaEmbedding(Embedding):
    def __init__(
        self,
        model_name: str = "jinaai/jina-embeddings-v2-base-en",
        max_length: int = 512,
        cache_dir: str = None,
        threads: int = None,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
            cache_dir (str, optional): The path to the cache directory. Defaults to `local_cache` in the current directory.
            threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.

        Raises:
            AssertionError: If the model_name is not in the format <org>/<model> e.g. jinaai/jina-embeddings-v2-base-en.
        """
        self.model_name = model_name

        if cache_dir is None:
            cache_dir = Path(".").resolve() / "local_cache"
            cache_dir.mkdir(parents=True, exist_ok=True)

        self._cache_dir = cache_dir
        self._model_dir = self.retrieve_model_hf(model_name, cache_dir)
        self._max_length = max_length

        self.model = EmbeddingModel(self._model_dir, self.model_name, max_length=max_length,
                                    max_threads=threads)

    def embed(
        self, documents: Union[str, Iterable[str]], batch_size: int = 256, parallel: int = None
    ) -> Iterable[np.ndarray]:
        """
        Encode a list of documents into a list of embeddings.
        We use mean pooling with attention so that the model can handle variable-length inputs.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        is_small = False

        if isinstance(documents, str):
            documents = [documents]
            is_small = True

        if isinstance(documents, list):
            if len(documents) < batch_size:
                is_small = True

        if parallel == 0:
            parallel = os.cpu_count()

        if parallel is None or is_small:
            for batch in iter_batch(documents, batch_size):
                embeddings, attn_mask = self.model.onnx_embed(batch)
                yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
        else:
            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            params = {
                "path": self._model_dir,
                "model_name": self.model_name,
                "max_length": self._max_length,
            }
            pool = ParallelWorkerPool(parallel, EmbeddingWorker, start_method=start_method)
            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
                embeddings, attn_mask = batch
                yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

    @classmethod
    def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
        """
        Lists the supported models.
        """
        # only jina models are supported by this class
        return [model for model in Embedding.list_supported_models() if model['model'].startswith('jinaai')]

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output
        # Expand the mask to the embedding dimension: (batch, tokens) -> (batch, tokens, 1)
        input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float)

        # Sum embeddings over real (non-padding) tokens only
        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
        # Clip the token count to avoid division by zero on fully masked rows
        mask_sum = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)

        return sum_embeddings / mask_sum
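Taken together, a hedged end-to-end sketch of how the new class would be used (the import path follows this diff; the output shape assumes the dims listed in list_supported_models above):

    from fastembed.embedding import JinaEmbedding  # import path assumed from this diff

    embedder = JinaEmbedding(model_name="jinaai/jina-embeddings-v2-small-en")
    vectors = list(embedder.embed(["Hello, world!", "FastEmbed is fast."]))
    print(len(vectors), vectors[0].shape)  # expected: 2 (512,)

    # The override should list only the jinaai entries:
    print([m["model"] for m in JinaEmbedding.list_supported_models()])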