Skip to content

Commit

Permalink
🐛 Bug(settings): fix missing OPENAI_API_KEY for embedding model
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterKenth committed Nov 29, 2024
1 parent fae1f78 commit abbdd18
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 14 deletions.
1 change: 1 addition & 0 deletions fai-rag-app/fai-backend/fai_backend/settings/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

class SettingKey(Enum):
FIXED_PIN = 'FIXED_PIN'
OPENAI_API_KEY = 'OPENAI_API_KEY'
BREVO_API_URL = 'BREVO_API_URL'
BREVO_API_KEY = 'BREVO_API_KEY'

Expand Down
14 changes: 7 additions & 7 deletions fai-rag-app/fai-backend/fai_backend/vector/base_chromadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ class BaseChromaDB(IVector):
def __init__(self, client: ClientAPI):
self.client = client

def _get_collection(self, collection_name: str, embedding_function: EmbeddingFunction | None = None):
async def _get_collection(self, collection_name: str, embedding_function: EmbeddingFunction | None = None):
return self.client.get_collection(
name=collection_name,
embedding_function=embedding_function or EmbeddingFnFactory.create('default')
embedding_function=embedding_function or await EmbeddingFnFactory.create('default')
)

async def add(
Expand All @@ -29,7 +29,7 @@ async def add(
uris: Optional[OneOrMany[str]] = None,
embedding_function: Callable[[str], np.ndarray] | None = None,
) -> None:
collection = self._get_collection(collection_name, embedding_function)
collection = await self._get_collection(collection_name, embedding_function)

collection.add(
documents=documents,
Expand All @@ -47,7 +47,7 @@ async def update(
documents: Optional[OneOrMany[Document]] = None,
embedding_function: Callable[[str], np.ndarray] | None = None,
) -> None:
collection = self._get_collection(collection_name, embedding_function)
collection = await self._get_collection(collection_name, embedding_function)
collection.update(
documents=documents,
embeddings=embeddings,
Expand All @@ -69,7 +69,7 @@ async def query(
where: Optional[Where] = None,
embedding_function: Callable[[str], np.ndarray] | None = None,
) -> dict:
collection = self._get_collection(collection_name, embedding_function)
collection = await self._get_collection(collection_name, embedding_function)

return collection.query(
query_embeddings=query_embeddings,
Expand All @@ -84,7 +84,7 @@ async def get(
ids: Optional[OneOrMany[str]] = None,
embedding_function: Callable[[str], np.ndarray] | None = None,
):
collection = self._get_collection(
collection = await self._get_collection(
collection_name,
embedding_function
)
Expand All @@ -104,7 +104,7 @@ async def create_collection(self, collection_name: str,
embedding_function: Callable[[str], np.ndarray] | None = None):
return self.client.create_collection(
name=collection_name,
embedding_function=embedding_function or EmbeddingFnFactory.create('default')
embedding_function=embedding_function or await EmbeddingFnFactory.create('default')
)

async def delete_collection(self, collection_name: str):
Expand Down
10 changes: 6 additions & 4 deletions fai-rag-app/fai-backend/fai_backend/vector/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
import numpy as np
from chromadb.utils import embedding_functions

from fai_backend.config import settings
from fai_backend.settings.service import SettingsServiceFactory, SettingKey


class EmbeddingFnFactory:
@staticmethod
def create(embedding_model: Literal['default', 'text-embedding-3-small', 'text-embedding-3-large'] | None) -> \
Callable[[str], np.ndarray]:
async def create(
embedding_model: Literal['default', 'text-embedding-3-small', 'text-embedding-3-large'] | None
) -> Callable[[str], np.ndarray]:
settings_service = SettingsServiceFactory().get_service()
embedding_model_map = {
'default': embedding_functions.DefaultEmbeddingFunction(),
'text-embedding-3-small': embedding_functions.OpenAIEmbeddingFunction(
api_key=settings.OPENAI_API_KEY.get_secret_value(),
api_key=await settings_service.get_value(SettingKey.OPENAI_API_KEY),
model_name='text-embedding-3-small'
)
}
Expand Down
6 changes: 3 additions & 3 deletions fai-rag-app/fai-backend/fai_backend/vector/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, vector_db: IVector, collection_meta_service: CollectionServic

async def create_collection(self, collection_name: str, embedding_model: str | None = None):
await self.vector_db.create_collection(collection_name,
EmbeddingFnFactory.create(embedding_model))
await EmbeddingFnFactory.create(embedding_model))

async def delete_collection(self, collection_name: str):
await self.vector_db.delete_collection(collection_name)
Expand All @@ -32,7 +32,7 @@ async def add_to_collection(
collection_name=collection_name,
ids=ids,
documents=documents,
embedding_function=EmbeddingFnFactory.create(embedding_model)
embedding_function=await EmbeddingFnFactory.create(embedding_model)
)

async def add_documents_without_id_to_empty_collection(
Expand Down Expand Up @@ -67,7 +67,7 @@ async def query_from_collection(
collection_name=collection_name,
query_texts=query_texts,
n_results=n_results,
embedding_function=EmbeddingFnFactory.create(embedding_model),
embedding_function=await EmbeddingFnFactory.create(embedding_model),
)

async def list_collections(self):
Expand Down

0 comments on commit abbdd18

Please sign in to comment.