Skip to content

Commit

Permalink
ensure original embedding config works
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzejay committed Oct 21, 2024
1 parent 40f81ae commit 3fc83c6
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/crewai/cli/reset_memories_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
click.echo("Long term memory has been reset.")

if short:
ShortTermMemory().reset()
ShortTermMemory(allow_reset=True).reset()
click.echo("Short term memory has been reset.")
if entity:
EntityMemory().reset()
EntityMemory(allow_reset=True).reset()
click.echo("Entity memory has been reset.")
if kickoff_outputs:
TaskOutputStorageHandler().reset()
Expand Down
91 changes: 84 additions & 7 deletions src/crewai/memory/storage/rag_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function


@contextlib.contextmanager
Expand Down Expand Up @@ -41,16 +42,87 @@ def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
self.agents = agents

self.type = type
self.embedder_config = embedder_config or self._create_embedding_function()

self.allow_reset = allow_reset
self._initialize_app()

def set_embedder_config(self):
if self.embedder_config is None:
self.embedder_config = self._create_default_embedding_function()
if isinstance(self.embedder_config, dict):
provider = self.embedder_config.get("provider")
config = self.embedder_config.get("config", {})
model_name = config.get("model")
if provider == "openai":
import chromadb.utils.embedding_functions as embedding_functions

self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
)
elif provider == "azure":
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)

self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type"),
api_version=config.get("api_version"),
model_name=model_name,
)
elif provider == "ollama":
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)

self.embedder_config = OllamaEmbeddingFunction(
model_name=config.get("model"),
url=config.get("url") or "http://localhost:11434",
)
elif provider == "vertexai":
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)

self.embedder_config = GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "google":
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)

self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "cohere":
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)

self.embedder_config = CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
else:
self.embedder_config = self._create_default_embedding_function()
else:
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class
self.embedder_config = self.embedder_config

def _initialize_app(self):
import chromadb

self.set_embedder_config()
chroma_client = chromadb.PersistentClient(
path=f"{db_storage_path()}/{self.type}/{self.agents}"
path=f"{db_storage_path()}/{self.type}/{self.agents}",
settings=chromadb.Settings(allow_reset=self.allow_reset),
)

self.app = chroma_client

try:
Expand Down Expand Up @@ -122,11 +194,16 @@ def reset(self) -> None:
if self.app:
self.app.reset()
except Exception as e:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)

def _create_embedding_function(self):
if "attempt to write a readonly database" in str(e):
print("ignoring error")
# Ignore this specific error
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)

def _create_default_embedding_function(self):
import chromadb.utils.embedding_functions as embedding_functions

return embedding_functions.OpenAIEmbeddingFunction(
Expand Down

0 comments on commit 3fc83c6

Please sign in to comment.