Skip to content

Commit

Permalink
Add Google Generative AI Embeddings as an option (docker#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lapeyus authored Feb 21, 2024
1 parent c44aa64 commit 9191739
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 7 deletions.
7 changes: 7 additions & 0 deletions chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from typing import List, Any
from utils import BaseLogger, extract_title_and_question
from langchain_google_genai import GoogleGenerativeAIEmbeddings


def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
Expand All @@ -40,6 +41,12 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=
embeddings = BedrockEmbeddings()
dimension = 1536
logger.info("Embedding: Using AWS")
elif embedding_model_name == "google-genai-embedding-001":
embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001"
)
dimension = 768
logger.info("Embedding: Using Google Generative AI Embeddings")
else:
embeddings = SentenceTransformerEmbeddings(
model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
Expand Down
12 changes: 8 additions & 4 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ services:
- NEO4J_dbms_security_procedures_unrestricted=apoc.*
healthcheck:
test: ["CMD-SHELL", "wget --no-verbose --tries=1 --spider localhost:7474 || exit 1"]
interval: 5s
timeout: 3s
retries: 5
interval: 15s
timeout: 30s
retries: 10
networks:
- net

Expand All @@ -61,6 +61,7 @@ services:
- NEO4J_PASSWORD=${NEO4J_PASSWORD-password}
- NEO4J_USERNAME=${NEO4J_USERNAME-neo4j}
- OPENAI_API_KEY=${OPENAI_API_KEY-}
- GOOGLE_API_KEY=${GOOGLE_API_KEY-}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434}
- EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer}
- LANGCHAIN_ENDPOINT=${LANGCHAIN_ENDPOINT-"https://api.smith.langchain.com"}
Expand Down Expand Up @@ -101,7 +102,8 @@ services:
- NEO4J_URI=${NEO4J_URI-neo4j://database:7687}
- NEO4J_PASSWORD=${NEO4J_PASSWORD-password}
- NEO4J_USERNAME=${NEO4J_USERNAME-neo4j}
- OPENAI_API_KEY=${OPENAI_API_KEY-}
- OPENAI_API_KEY=${OPENAI_API_KEY-}
- GOOGLE_API_KEY=${GOOGLE_API_KEY-}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434}
- LLM=${LLM-llama2}
- EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer}
Expand Down Expand Up @@ -140,6 +142,7 @@ services:
- NEO4J_PASSWORD=${NEO4J_PASSWORD-password}
- NEO4J_USERNAME=${NEO4J_USERNAME-neo4j}
- OPENAI_API_KEY=${OPENAI_API_KEY-}
- GOOGLE_API_KEY=${GOOGLE_API_KEY-}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434}
- LLM=${LLM-llama2}
- EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer}
Expand Down Expand Up @@ -180,6 +183,7 @@ services:
- NEO4J_PASSWORD=${NEO4J_PASSWORD-password}
- NEO4J_USERNAME=${NEO4J_USERNAME-neo4j}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- GOOGLE_API_KEY=${GOOGLE_API_KEY}
- OLLAMA_BASE_URL=${OLLAMA_BASE_URL-http://host.docker.internal:11434}
- LLM=${LLM-llama2}
- EMBEDDING_MODEL=${EMBEDDING_MODEL-sentence_transformer}
Expand Down
10 changes: 8 additions & 2 deletions env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# LLM and Embedding Model
#*****************************************************************
LLM=llama2 #or any Ollama model tag, gpt-4, gpt-3.5, or claudev2
EMBEDDING_MODEL=sentence_transformer #or openai, ollama, or aws
EMBEDDING_MODEL=sentence_transformer #or google-genai-embedding-001 openai, ollama, or aws

#*****************************************************************
# Neo4j
Expand Down Expand Up @@ -40,4 +40,10 @@ EMBEDDING_MODEL=sentence_transformer #or openai, ollama, or aws

#AWS_ACCESS_KEY_ID=
#AWS_SECRET_ACCESS_KEY=
#AWS_DEFAULT_REGION=us-east-1
#AWS_DEFAULT_REGION=us-east-1

#*****************************************************************
# GOOGLE
#*****************************************************************
# Only required when using GoogleGenai LLM or embedding model
GOOGLE_API_KEY=
3 changes: 2 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ Available variables:
| NEO4J_USERNAME | neo4j | REQUIRED - Username for Neo4j database |
| NEO4J_PASSWORD | password | REQUIRED - Password for Neo4j database |
| LLM | llama2 | REQUIRED - Can be any Ollama model tag, or gpt-4 or gpt-3.5 or claudev2 |
| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai, aws or ollama |
| EMBEDDING_MODEL | sentence_transformer | REQUIRED - Can be sentence_transformer, openai, aws, ollama or google-genai-embedding-001|
| AWS_ACCESS_KEY_ID | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| AWS_SECRET_ACCESS_KEY | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| AWS_DEFAULT_REGION | | REQUIRED - Only if LLM=claudev2 or embedding_model=aws |
| OPENAI_API_KEY | | REQUIRED - Only if LLM=gpt-4 or LLM=gpt-3.5 or embedding_model=openai |
| GOOGLE_API_KEY | | REQUIRED - Only required when using GoogleGenai LLM or embedding model google-genai-embedding-001|
| LANGCHAIN_ENDPOINT | "https://api.smith.langchain.com" | OPTIONAL - URL to Langchain Smith API |
| LANGCHAIN_TRACING_V2 | false | OPTIONAL - Enable Langchain tracing v2 |
| LANGCHAIN_PROJECT | | OPTIONAL - Langchain project name |
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ boto3
# missing from the langchain base image?
langchain-openai
langchain-community
langchain-google-genai

0 comments on commit 9191739

Please sign in to comment.