From 3fc83c624b7fd0167a665a04b9b1b26cd067ca6c Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Sun, 20 Oct 2024 18:12:57 -0700 Subject: [PATCH] ensure original embedding config works --- src/crewai/cli/reset_memories_command.py | 4 +- src/crewai/memory/storage/rag_storage.py | 91 ++++++++++++++++++++++-- 2 files changed, 86 insertions(+), 9 deletions(-) diff --git a/src/crewai/cli/reset_memories_command.py b/src/crewai/cli/reset_memories_command.py index c4808594fa..0c83d54c0b 100644 --- a/src/crewai/cli/reset_memories_command.py +++ b/src/crewai/cli/reset_memories_command.py @@ -32,10 +32,10 @@ def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None: click.echo("Long term memory has been reset.") if short: - ShortTermMemory().reset() + ShortTermMemory(allow_reset=True).reset() click.echo("Short term memory has been reset.") if entity: - EntityMemory().reset() + EntityMemory(allow_reset=True).reset() click.echo("Entity memory has been reset.") if kickoff_outputs: TaskOutputStorageHandler().reset() diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 8d45d9f5a1..60d7a9213f 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -8,6 +8,7 @@ from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.utilities.paths import db_storage_path from chromadb.api import ClientAPI +from chromadb.api.types import validate_embedding_function @contextlib.contextmanager @@ -41,16 +42,87 @@ def __init__(self, type, allow_reset=True, embedder_config=None, crew=None): self.agents = agents self.type = type - self.embedder_config = embedder_config or self._create_embedding_function() + self.allow_reset = allow_reset self._initialize_app() + def set_embedder_config(self): + if self.embedder_config is None: + self.embedder_config = self._create_default_embedding_function() + if isinstance(self.embedder_config, dict): + provider = self.embedder_config.get("provider") + config = self.embedder_config.get("config", {}) + model_name = config.get("model") + if provider == "openai": + import chromadb.utils.embedding_functions as embedding_functions + + self.embedder_config = embedding_functions.OpenAIEmbeddingFunction( + api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), + model_name=model_name, + ) + elif provider == "azure": + from chromadb.utils.embedding_functions.openai_embedding_function import ( + OpenAIEmbeddingFunction, + ) + + self.embedder_config = OpenAIEmbeddingFunction( + api_key=config.get("api_key"), + api_base=config.get("api_base"), + api_type=config.get("api_type"), + api_version=config.get("api_version"), + model_name=model_name, + ) + elif provider == "ollama": + from chromadb.utils.embedding_functions.ollama_embedding_function import ( + OllamaEmbeddingFunction, + ) + + self.embedder_config = OllamaEmbeddingFunction( + model_name=config.get("model"), + url=config.get("url") or "http://localhost:11434", + ) + elif provider == "vertexai": + from chromadb.utils.embedding_functions.google_embedding_function import ( + GoogleVertexEmbeddingFunction, + ) + + self.embedder_config = GoogleVertexEmbeddingFunction( + model_name=model_name, + api_key=config.get("api_key"), + ) + elif provider == "google": + from chromadb.utils.embedding_functions.google_embedding_function import ( + GoogleGenerativeAiEmbeddingFunction, + ) + + self.embedder_config = GoogleGenerativeAiEmbeddingFunction( + model_name=model_name, + api_key=config.get("api_key"), + ) + elif provider == "cohere": + from chromadb.utils.embedding_functions.cohere_embedding_function import ( + CohereEmbeddingFunction, + ) + + self.embedder_config = CohereEmbeddingFunction( + model_name=model_name, + api_key=config.get("api_key"), + ) + else: + self.embedder_config = self._create_default_embedding_function() + else: + validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class + self.embedder_config = self.embedder_config + def _initialize_app(self): import chromadb + self.set_embedder_config() chroma_client = chromadb.PersistentClient( - path=f"{db_storage_path()}/{self.type}/{self.agents}" + path=f"{db_storage_path()}/{self.type}/{self.agents}", + settings=chromadb.Settings(allow_reset=self.allow_reset), ) + self.app = chroma_client try: @@ -122,11 +194,16 @@ def reset(self) -> None: if self.app: self.app.reset() except Exception as e: - raise Exception( - f"An error occurred while resetting the {self.type} memory: {e}" - ) - - def _create_embedding_function(self): + if "attempt to write a readonly database" in str(e): + print("ignoring error") + # Ignore this specific error + pass + else: + raise Exception( + f"An error occurred while resetting the {self.type} memory: {e}" + ) + + def _create_default_embedding_function(self): import chromadb.utils.embedding_functions as embedding_functions return embedding_functions.OpenAIEmbeddingFunction(