
Add Jina Embeddings #61

Closed · wants to merge 13 commits into from

Conversation

@JohannesMessner commented Nov 7, 2023

Hi Qdrant team, as promised, here comes the Jina x Fastembed integration!

The main differences from the existing FlagEmbedding class are:

  • The model files are downloaded directly from HuggingFace and leverage the HF local cache (see the sketch below this list)
  • Mean pooling is applied
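For illustration, the download path is roughly the following (a minimal sketch; the repo id and helper name are placeholders, not the exact code in this PR):

```python
from huggingface_hub import snapshot_download

def download_model(repo_id: str = "jinaai/jina-embeddings-v2-small-en") -> str:
    # snapshot_download reuses the Hugging Face local cache and only fetches
    # files that are missing, so repeated model loads stay cheap.
    return snapshot_download(repo_id=repo_id)
```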

TODO:

  • add model info
  • documentation (actually I think there is nothing to do here)
  • tests

@JohannesMessner marked this pull request as ready for review November 7, 2023 10:10
@azayarni requested a review from NirantK November 7, 2023 10:28
@NirantK (Contributor) left a comment

Hey, thanks a ton for doing this!

  1. Please avoid adding pooling to other embedding implementations like FlagEmbedding; add it to JinaEmbedding instead. FlagEmbedding is a specific embedding model, much like Jina.

  2. The intent behind FastEmbed is to continue to be light, and an important part of that is keeping the dependency graph small. Please do not add a Hugging Face Hub dependency.

The comments below are specific instances of the two points above, and they should resolve themselves once you rework the code base along those lines.

Please note that the code currently breaks, since the Hugging Face Hub dependency is not declared in pyproject.toml.

@JohannesMessner (Author)

> Hey, thanks a ton for doing this!
>
> 1. Please avoid adding pooling to other embedding implementations like `FlagEmbedding`; add it to `JinaEmbedding` instead. `FlagEmbedding` is a specific embedding model, much like Jina.
>
> 2. The intent behind FastEmbed is to continue to be light, and an important part of that is keeping the dependency graph small. Please do not add a Hugging Face Hub dependency.
>
> The comments below are specific instances of the two points above, and they should resolve themselves once you rework the code base along those lines.
>
> Please note that the code currently breaks, since the Hugging Face Hub dependency is not declared in pyproject.toml.

Thanks for your very fast review!

  • The way I am doing it right now, FlagEmbedding still doesn't use any pooling; I just chose this code structure for code sharing. But I can refactor it so that this part is completely removed from the FlagEmbedding class, if you prefer.

  • I saw huggingface_hub in the lock file, so I thought it was a transitive dependency of one of the specified ones. I'll remove it.

@JohannesMessner (Author)

I refactored the code to come closer to your suggestions. Unfortunately, this results in some code duplication between the FlagEmbedding and JinaEmbedding classes; let me know if it is OK like this!

```diff
@@ -464,7 +527,8 @@ def embed(

         if parallel is None or is_small:
             for batch in iter_batch(documents, batch_size):
-                yield from self.model.onnx_embed(batch)
+                embeddings, _ = self.model.onnx_embed(batch)
```
@NirantK (Contributor)

Hmm, I'm confused. I believe FlagEmbedding should be left untouched since all the changes are in the parent class and JinaAI Embedding class, right?

Similarly, the list_supported_models rewrite isn't needed and should be removed from all implementations now?

@JohannesMessner (Author) commented Nov 15, 2023

FlagEmbedding cannot be left entirely untouched unfortunately, unless I am missing something.

Before this PR, the EmbeddingModel.onnx_embed() method picks out the first token as a form of pooling and then applies normalization. Baked into this is the assumption that all subclasses of Embedding (that hold an EmbeddingModel instance) intend that behaviour. That assumption is broken by Jina embeddings, which require mean pooling before the normalization.
And mean pooling cannot be applied afterwards, since the existing implementation of EmbeddingModel.onnx_embed() "throws away" the token embeddings needed for it.

Therefore, the implementation of EmbeddingModel.onnx_embed() needs two small modifications:

  1. It delegates pooling and normalization to the subclasses of Embedding
  2. It returns the tokenizer's attention mask. Otherwise, without access to the attention mask, pooling schemes such as mean pooling cannot be implemented at the Embedding level.

This requires FlagEmbedding to adjust to those changes.
Just like JinaEmbedding, it now implements its own pooling scheme (just picking out the first token). The attention mask is not required for this, so it can be ignored when returned by EmbeddingModel.onnx_embed().
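To make the split concrete, here is a minimal sketch of the contract described above (the function names and signatures are illustrative, not the exact fastembed code):

```python
import numpy as np

def onnx_embed(model_output: np.ndarray, attention_mask: np.ndarray):
    # Sketch of the new contract: return raw token embeddings together with the
    # attention mask, and leave pooling + normalization to each Embedding subclass.
    return model_output, attention_mask

def flag_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # FlagEmbedding keeps its previous behaviour: take the first ([CLS]) token.
    # The attention mask is simply ignored here.
    return token_embeddings[:, 0]

def mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # JinaEmbedding averages the token embeddings, weighted by the attention mask.
    mask = np.expand_dims(attention_mask, axis=-1).astype(np.float32)
    return (token_embeddings * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1e-9, None)

def normalize(embeddings: np.ndarray) -> np.ndarray:
    # L2 normalization, applied after pooling in both cases.
    return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
```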

@JohannesMessner (Author)

As for the list_supported_models() rewrite, yes, I can remove that. But then JinaEmbedding.list_supported_models() would return a bunch of models that are actually not supported by the JinaEmbedding class.
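One possible shape for keeping the listing per class (names are hypothetical, just to illustrate the trade-off; how fastembed actually stores model metadata may differ):

```python
from typing import Dict, List

class Embedding:
    # Hypothetical per-class registry of model metadata.
    supported_models: List[Dict] = []

    @classmethod
    def list_supported_models(cls) -> List[Dict]:
        # Each subclass reports only the models it can actually load.
        return cls.supported_models

class JinaEmbedding(Embedding):
    supported_models = [
        {"model": "jinaai/jina-embeddings-v2-base-en", "dim": 768},
        {"model": "jinaai/jina-embeddings-v2-small-en", "dim": 512},
    ]
```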

@NirantK (Contributor)

I see your point. Looks like we'll have to figure out a way to handle the normalize, attention, and pooling steps separately for each embedding implementation. For the moment, what you've proposed kinda works.

Let me think about this + test your PR and then we're good to go and merge this.

@NirantK (Contributor) commented Nov 20, 2023

Two requests:

There's an error in the pooling, specifically in the mean pooling implementation. If there is a canonical implementation from Torch or Jina itself, let's reuse that here?

This is from the pytest run I did locally with Python 3.9.17 on an M2, in a fresh poetry install:

```
fastembed/embedding.py:648: in embed
    yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

model_output = array([[-0.4756149 , -0.44713712, -0.12763295, ...,  0.6697987 ,
         0.30507904,  0.28676268],
       [-0.46482596, -0.20412004, -0.27510062, ...,  0.48833442,
         0.19993246, -0.01699639]], dtype=float32)
attention_mask = array([[1, 1, 1, 1, 0, 0],
       [1, 1, 1, 1, 1, 1]])

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output
        input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float)

>       sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
E       ValueError: operands could not be broadcast together with shapes (2,512) (2,6,1)
```
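For what it's worth, the shapes in that error suggest the model output reaching mean_pooling is already pooled down to (batch, hidden) instead of the (batch, seq_len, hidden) token embeddings the implementation expects. A quick shape check with dummy values illustrates this reading:

```python
import numpy as np

attention_mask = np.array([[1, 1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1, 1]])
mask = np.expand_dims(attention_mask, axis=-1).astype(float)   # (2, 6, 1)

pooled_output = np.zeros((2, 512), dtype=np.float32)       # shape seen in the traceback
token_embeddings = np.zeros((2, 6, 512), dtype=np.float32)  # shape mean pooling expects

token_embeddings * mask   # broadcasts to (2, 6, 512) as intended
# pooled_output * mask    # raises the "operands could not be broadcast" error above
```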

I'll also figure out how you can get these tests running on GitHub itself instead of needing to run them locally. That might help with faster dev for you.

Once this is done, let's resolve merge conflicts?

@NirantK (Contributor) commented Nov 21, 2023

Added via another PR!

@NirantK closed this Nov 21, 2023