Add Jina Embeddings #61
Conversation
Hey, thanks a ton for doing this!
- Please avoid adding pooling to other embedding implementations like FlagEmbedding; add it to JinaEmbedding instead. FlagEmbedding is a specific embedding model, much like JINA.
- The intent behind FastEmbed is to continue to be light, and an important part of that is keeping the dependency graph small. Please do not add a Hugging Face Hub dependency.

The comments below are specifics of the above two ideas and should automatically get resolved if you reason through the code base as shared above.
Please note that the code breaks right now, since the Hugging Face Hub dependency is not declared in pyproject.toml.
Thanks for your very fast review!
I refactored the code to come closer to your suggestions; unfortunately this results in some code duplication between the FlagEmbedding and JinaEmbedding implementations.
@@ -464,7 +527,8 @@ def embed(
     if parallel is None or is_small:
         for batch in iter_batch(documents, batch_size):
-            yield from self.model.onnx_embed(batch)
+            embeddings, _ = self.model.onnx_embed(batch)
Hmm, I'm confused. I believe FlagEmbedding should be left untouched, since all the changes are in the parent class and the JinaAI Embedding class, right?
Similarly, the list_supported_models rewrite isn't needed and should be removed from all implementations now?
FlagEmbedding cannot be left entirely untouched, unfortunately, unless I am missing something.
Before this PR, the EmbeddingModel.onnx_embed() method picks out the first token as a form of pooling and then applies normalization. Baked into this is the assumption that all subclasses of Embedding (that hold an EmbeddingModel instance) intend that behaviour. That assumption is broken by Jina embeddings, which require mean pooling before the normalization. And mean pooling cannot be applied after the fact, since the existing implementation of EmbeddingModel.onnx_embed() "throws away" the tokens needed for it; see the sketch below.
Therefore, the implementation of EmbeddingModel.onnx_embed() needs two small modifications:
- It delegates pooling and normalization to the subclasses of Embedding.
- It returns the tokenizer's attention mask. Otherwise, without access to the attention mask, pooling schemes such as mean pooling cannot be implemented at the Embedding level.
This requires FlagEmbedding to adjust to those changes. Just like JinaEmbedding, it now implements its own pooling scheme (simply picking out the first token). The attention mask is not required for this, so it can be ignored when returned by EmbeddingModel.onnx_embed().
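To illustrate the proposed division of labour, here is a hedged sketch. Only onnx_embed is a method name taken from the discussion; the class and helper names, the shapes, and the random data standing in for the ONNX session call are all hypothetical:

```python
import numpy as np

def normalize(x: np.ndarray) -> np.ndarray:
    """L2-normalize each embedding vector along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

class EmbeddingModelSketch:
    """Stand-in for EmbeddingModel: onnx_embed hands back the raw per-token
    output plus the attention mask instead of a pooled, normalized vector."""

    def onnx_embed(self, batch):
        # Placeholder for the real ONNX session call.
        token_embeddings = np.random.rand(len(batch), 6, 512).astype(np.float32)  # (B, T, H)
        attention_mask = np.ones((len(batch), 6), dtype=np.int64)                  # (B, T)
        return token_embeddings, attention_mask

def flag_style_pooling(token_embeddings, attention_mask):
    # FlagEmbedding-style post-processing: first-token pooling, mask ignored.
    return normalize(token_embeddings[:, 0])

def jina_style_pooling(token_embeddings, attention_mask):
    # JinaEmbedding-style post-processing: mask-aware mean pooling, then normalization.
    mask = np.expand_dims(attention_mask, -1).astype(np.float32)
    pooled = (token_embeddings * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1e-9, None)
    return normalize(pooled)

model = EmbeddingModelSketch()
tokens, mask = model.onnx_embed(["hello", "world"])
print(flag_style_pooling(tokens, mask).shape)  # (2, 512)
print(jina_style_pooling(tokens, mask).shape)  # (2, 512)
```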
As for the list_supported_models() rewrite: yes, I can remove that. But then JinaEmbedding.list_supported_models() would return a bunch of models that are actually not supported by the JinaEmbedding class.
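Purely for illustration, a minimal sketch of the kind of filtering being discussed; the base-class contents, the entry format, and the model names are assumptions, not FastEmbed's actual API:

```python
# Assumed shape: each supported-model entry is a dict with a "model" name field.
class Embedding:
    @classmethod
    def list_supported_models(cls):
        return [
            {"model": "BAAI/bge-small-en", "dim": 384},
            {"model": "jinaai/jina-embeddings-v2-small-en", "dim": 512},
        ]

class JinaEmbedding(Embedding):
    # Hypothetical: restrict the inherited list to models this class can actually run.
    JINA_MODELS = {"jinaai/jina-embeddings-v2-small-en"}

    @classmethod
    def list_supported_models(cls):
        return [m for m in super().list_supported_models() if m["model"] in cls.JINA_MODELS]

print(JinaEmbedding.list_supported_models())
# -> [{'model': 'jinaai/jina-embeddings-v2-small-en', 'dim': 512}]
```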
I see your point. Looks like we have to figure out a way to handle the normalize, attention and pooling steps separately for each embedding implementation. At the moment, what you've proposed kinda works.
Let me think about this + test your PR, and then we're good to go and merge this.
Two requests: there is an error in the pooling, specifically in the mean pooling implementation. If there is a canonical implementation from Torch or Jina itself, let's re-use that here? This is from the pytest run, which I ran locally with Python 3.9.17 on an M2 in a fresh poetry install:

fastembed/embedding.py:648: in embed
    yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

model_output = array([[-0.4756149 , -0.44713712, -0.12763295, ...,  0.6697987 ,
         0.30507904,  0.28676268],
       [-0.46482596, -0.20412004, -0.27510062, ...,  0.48833442,
         0.19993246, -0.01699639]], dtype=float32)
attention_mask = array([[1, 1, 1, 1, 0, 0],
       [1, 1, 1, 1, 1, 1]])

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output
        input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float)
>       sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
E       ValueError: operands could not be broadcast together with shapes (2,512) (2,6,1)

I'll also figure out how you can get these tests running on GitHub itself instead of needing to run them locally. That might help with faster dev for you.

Once this is done, let's resolve the merge conflicts?
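For reference, here is a sketch of a mask-aware mean pooling step in the spirit of the widely used sentence-transformers recipe. The shapes in the comments are inferred from the traceback above and are assumptions, not a description of what the PR finally ships; the broadcast error indicates that model_output arrives already reduced to one vector per document (shape (2, 512)) instead of keeping the token axis that the expanded mask (2, 6, 1) expects:

```python
import numpy as np

def mean_pooling(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Mask-aware mean over the token axis.

    token_embeddings: (batch, seq_len, hidden) raw per-token model output
    attention_mask:   (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = np.expand_dims(attention_mask, axis=-1).astype(np.float32)  # (batch, seq_len, 1)
    summed = np.sum(token_embeddings * mask, axis=1)                   # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                     # avoid division by zero
    return summed / counts

# With the token axis kept, the shapes from the failing test broadcast cleanly:
tokens = np.random.rand(2, 6, 512).astype(np.float32)  # (batch=2, seq_len=6, hidden=512)
mask = np.array([[1, 1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 1, 1]])
print(mean_pooling(tokens, mask).shape)  # (2, 512)
```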
Added via another PR!
Hi Qdrant team, as promised, here comes the Jina x Fastembed integration!
The main differences to the existing FlagEmbedding class are:
TODO: