Add support for Text Embeddings Inference (TEI). #82

Open · wants to merge 3 commits into `main`
27 changes: 27 additions & 0 deletions doc/use_tei_for_embeddings.md
@@ -0,0 +1,27 @@
# Using Text Embeddings Inference (TEI) for embeddings

TEI is an optimized toolkit for deploying and serving text embedding and sequence classification models.

For more information, see the [TEI documentation](https://huggingface.co/docs/text-embeddings-inference/index).

**Current limitation:** the `chunk_size` option must be 510 or lower; I was not able to configure a higher value. This is presumably the underlying model's 512-token sequence limit minus the two special tokens, which leaves 510.

## How to use

Run the Docker image that matches your [compatible hardware](https://huggingface.co/docs/text-embeddings-inference/supported_models), for example:

```shell
docker run --gpus all \
    -e HUGGING_FACE_HUB_TOKEN=<your-hf-token> \
    -p 8080:80 \
    -v <your-tei-local-data>:/data \
    ghcr.io/huggingface/text-embeddings-inference:turing-1.1 \
    --model-id dariolopez/roberta-base-bne-finetuned-msmarco-qa-es-mnrl-mn \
    --max-client-batch-size 64
```
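Once the container is up, you can sanity-check it from the host. The `/embed` route and the `inputs` payload below follow TEI's documented API; the sentence is just an arbitrary example:

```shell
# Ask the local TEI service for the embedding of a single sentence
curl http://localhost:8080/embed \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "¿Qué es el Boletín Oficial del Estado?"}'
```

The response should be a JSON array containing one embedding vector.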

The previous command starts a service with the model `dariolopez/roberta-base-bne-finetuned-msmarco-qa-es-mnrl-mn` loaded and ready to generate embeddings.

In justicio's configuration, limit `chunk_size` to 510 and set `embeddings_model_name` to the URL where the TEI service is listening, e.g. *http://localhost:8080*.
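As a hypothetical sketch, only these two keys change in the config file (`embeddings_model_name` is the key read by `src/initialize.py`; the exact file location depends on your deployment):

```yaml
# Sketch of the two config keys affected by TEI support
embeddings_model_name: http://localhost:8080  # URL of the TEI service instead of a model name
chunk_size: 510  # TEI limitation described above
```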

You will also need to provide the HF token through the `HUGGINGFACEHUB_API_TOKEN` environment variable wherever justicio runs, for example when launching an ETL job:

```shell
HUGGINGFACEHUB_API_TOKEN=<your-hf-token> python -m src.etls.boja.load dates 2024/01/01 2024/01/31
```

Embeddings will be generated by TEI and stored in the configured vector database (only tested with Qdrant).
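If you want to try the TEI path in isolation, here is a minimal sketch of what `_init_embeddings` in this PR does when the configured name is a URL. It assumes `langchain==0.0.354` as pinned in `requirements.txt`, the service started by the `docker run` above, and `HUGGINGFACEHUB_API_TOKEN` set as described:

```python
from langchain.embeddings import HuggingFaceHubEmbeddings

# Point the Hub embeddings client at the local TEI endpoint
embeddings = HuggingFaceHubEmbeddings(model="http://localhost:8080")

vector = embeddings.embed_query("¿Qué es el Boletín Oficial del Estado?")
print(len(vector))  # embedding dimensionality, e.g. 768 for a roberta-base model
```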
4 changes: 2 additions & 2 deletions requirements.txt
@@ -13,9 +13,9 @@ retry==0.9.2
typer==0.9.0
schedule==1.2.1

-langchain==0.0.305
+langchain==0.0.354
# langchainplus-sdk==0.0.20
-langsmith==0.0.41
+langsmith==0.0.92
qdrant-client==1.5.4
supabase==1.0.2
pinecone-client==2.2.2
28 changes: 14 additions & 14 deletions src/initialize.py
@@ -6,7 +6,7 @@
import yaml
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
@@ -82,10 +82,7 @@ def _init_vector_store_pinecone(config_loader):
)
index_name = config_loader["vector_store_index_name"]
index = pinecone.Index(index_name)
-    embeddings = HuggingFaceEmbeddings(
-        model_name=config_loader["embeddings_model_name"],
-        model_kwargs={"device": "cpu"},
-    )
+    embeddings = _init_embeddings(config_loader=config_loader)
vector_store = Pinecone(index, embeddings.embed_query, "text")
logger.info(pinecone.describe_index(index_name))
logger.info(index.describe_index_stats())
@@ -103,10 +100,7 @@ def _init_vector_store_supabase(config_loader):
supabase_key=os.environ.get("SUPABASE_API_KEY"),
options=ClientOptions(postgrest_client_timeout=60),
)
-    embeddings = HuggingFaceEmbeddings(
-        model_name=config_loader["embeddings_model_name"],
-        model_kwargs={"device": "cpu"},
-    )
+    embeddings = _init_embeddings(config_loader)
vector_store = StandardSupabaseVectorStore(
client=supabase_client,
embedding=embeddings,
@@ -116,7 +110,6 @@
logger.info("Initialized vector store")
return vector_store


def _init_vector_stores_qdrant(config_loader):
logger = lg.getLogger(_init_vector_stores_qdrant.__name__)
logger.info("Initializing vector stores")
@@ -125,10 +118,7 @@ def _init_vector_stores_qdrant(config_loader):
api_key=os.environ["QDRANT_API_KEY"],
prefer_grpc=True,
)
-    embeddings = HuggingFaceEmbeddings(
-        model_name=config_loader["embeddings_model_name"],
-        model_kwargs={"device": "cpu"},
-    )
+    embeddings = _init_embeddings(config_loader)
vector_stores = {}
for collection_name in config_loader["collections"]:
if not _exists_collection(qdrant_client, collection_name):
@@ -145,6 +135,16 @@
logger.info("Initialized vector store for collection [%s]", collection_name)
return vector_stores

+def _init_embeddings(config_loader):
+    # A URL means an external TEI service; anything else is treated as a
+    # local sentence-transformers model name, loaded on CPU as before.
+    model: str = config_loader["embeddings_model_name"]
+    if model.startswith("http"):
+        return HuggingFaceHubEmbeddings(model=model)
+    else:
+        return HuggingFaceEmbeddings(
+            model_name=model,
+            model_kwargs={"device": "cpu"},
+        )


def _init_openai_client():
logger = lg.getLogger(_init_openai_client.__name__)