From 91917399c413a127fe048b5894a343018a50f98f Mon Sep 17 00:00:00 2001 From: Joseph Villarreal Lopez Date: Wed, 21 Feb 2024 02:40:23 -0600 Subject: [PATCH] Add Google Generative AI Embeddings as an option (#129) --- chains.py | 7 +++++++ docker-compose.yml | 12 ++++++++---- env.example | 10 ++++++++-- readme.md | 3 ++- requirements.txt | 1 + 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/chains.py b/chains.py index 1166adf7..9ad10f40 100644 --- a/chains.py +++ b/chains.py @@ -23,6 +23,7 @@ from typing import List, Any from utils import BaseLogger, extract_title_and_question +from langchain_google_genai import GoogleGenerativeAIEmbeddings def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}): @@ -40,6 +41,12 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config= embeddings = BedrockEmbeddings() dimension = 1536 logger.info("Embedding: Using AWS") + elif embedding_model_name == "google-genai-embedding-001": + embeddings = GoogleGenerativeAIEmbeddings( + model="models/embedding-001" + ) + dimension = 768 + logger.info("Embedding: Using Google Generative AI Embeddings") else: embeddings = SentenceTransformerEmbeddings( model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model" diff --git a/docker-compose.yml b/docker-compose.yml index d5be63a7..7faafe52 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -44,9 +44,9 @@ services: - NEO4J_dbms_security_procedures_unrestricted=apoc.* healthcheck: test: ["CMD-SHELL", "wget --no-verbose --tries=1 --spider localhost:7474 || exit 1"] - interval: 5s - timeout: 3s - retries: 5 + interval: 15s + timeout: 30s + retries: 10 networks: - net @@ -61,6 +61,7 @@ services: - NEO4J_PASSWORD=${NEO4J_PASSWORD-password} - NEO4J_USERNAME=${NEO4J_USERNAME-neo4j} - OPENAI_API_KEY=${OPENAI_API_KEY-} + - GOOGLE_API_KEY=${GOOGLE_API_KEY-} - OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434} - 
EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer} - LANGCHAIN_ENDPOINT=${LANGCHAIN_ENDPOINT-"https://api.smith.langchain.com"} @@ -101,7 +102,8 @@ services: - NEO4J_URI=${NEO4J_URI-neo4j://database:7687} - NEO4J_PASSWORD=${NEO4J_PASSWORD-password} - NEO4J_USERNAME=${NEO4J_USERNAME-neo4j} - - OPENAI_API_KEY=${OPENAI_API_KEY-} + - OPENAI_API_KEY=${OPENAI_API_KEY-} + - GOOGLE_API_KEY=${GOOGLE_API_KEY-} - OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434} - LLM=${LLM-llama2} - EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer} @@ -140,6 +142,7 @@ services: - NEO4J_PASSWORD=${NEO4J_PASSWORD-password} - NEO4J_USERNAME=${NEO4J_USERNAME-neo4j} - OPENAI_API_KEY=${OPENAI_API_KEY-} + - GOOGLE_API_KEY=${GOOGLE_API_KEY-} - OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434} - LLM=${LLM-llama2} - EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer} @@ -180,6 +183,7 @@ services: - NEO4J_PASSWORD=${NEO4J_PASSWORD-password} - NEO4J_USERNAME=${NEO4J_USERNAME-neo4j} - OPENAI_API_KEY=${OPENAI_API_KEY} + - GOOGLE_API_KEY=${GOOGLE_API_KEY} - OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434} - LLM=${LLM-llama2} - EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer} diff --git a/env.example b/env.example index 2a9457ac..88e33cc3 100644 --- a/env.example +++ b/env.example @@ -2,7 +2,7 @@ # LLM and Embedding Model #***************************************************************** LLM=llama2 #or any Ollama model tag, gpt-4, gpt-3.5, or claudev2 -EMBEDDING_MODEL=sentence_transformer #or openai, ollama, or aws +EMBEDDING_MODEL=sentence_transformer #or google-genai-embedding-001, openai, ollama, or aws #***************************************************************** # Neo4j @@ -40,4 +40,10 @@ EMBEDDING_MODEL=sentence_transformer #or openai, ollama, or aws #AWS_ACCESS_KEY_ID= #AWS_SECRET_ACCESS_KEY= -#AWS_DEFAULT_REGION=us-east-1 \ No newline at end of file +#AWS_DEFAULT_REGION=us-east-1 + 
+#***************************************************************** +# GOOGLE +#***************************************************************** +# Only required when using GoogleGenai LLM or embedding model +GOOGLE_API_KEY= diff --git a/readme.md b/readme.md index 978ef378..28fba3d4 100644 --- a/readme.md +++ b/readme.md @@ -15,11 +15,12 @@ Available variables: | NEO4J_USERNAME | neo4j | REQUIRED - Username for Neo4j database | | NEO4J_PASSWORD | password | REQUIRED - Password for Neo4j database | | LLM | llama2 | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 or claudev2 | -| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai, aws or ollama | +| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai, aws, ollama or google-genai-embedding-001 | | AWS_ACCESS_KEY_ID | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws | | AWS_SECRET_ACCESS_KEY | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws | | AWS_DEFAULT_REGION | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws | | OPENAI_API_KEY | | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 or embedding_model=openai | +| GOOGLE_API_KEY | | REQUIRED - Only if using a GoogleGenai LLM or embedding_model=google-genai-embedding-001 | | LANGCHAIN_ENDPOINT | "https://api.smith.langchain.com" | OPTIONAL - URL to Langchain Smith API | | LANGCHAIN_TRACING_V2 | false | OPTIONAL - Enable Langchain tracing v2 | | LANGCHAIN_PROJECT | | OPTIONAL - Langchain project name | diff --git a/requirements.txt b/requirements.txt index 9aa7e4c8..ad0f6f90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,4 @@ boto3 # missing from the langchain base image? langchain-openai langchain-community +langchain-google-genai