Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure original embedding config works #1476

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/crewai/cli/reset_memories_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
click.echo("Long term memory has been reset.")

if short:
ShortTermMemory().reset()
ShortTermMemory(allow_reset=True).reset()
click.echo("Short term memory has been reset.")
if entity:
EntityMemory().reset()
EntityMemory(allow_reset=True).reset()
click.echo("Entity memory has been reset.")
if kickoff_outputs:
TaskOutputStorageHandler().reset()
Expand Down
91 changes: 84 additions & 7 deletions src/crewai/memory/storage/rag_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function


@contextlib.contextmanager
Expand Down Expand Up @@ -41,16 +42,87 @@ def __init__(self, type, allow_reset=True, embedder_config=None, crew=None):
self.agents = agents

self.type = type
self.embedder_config = embedder_config or self._create_embedding_function()

self.allow_reset = allow_reset
self._initialize_app()

def _set_embedder_config(self):
if self.embedder_config is None:
self.embedder_config = self._create_default_embedding_function()
if isinstance(self.embedder_config, dict):
provider = self.embedder_config.get("provider")
config = self.embedder_config.get("config", {})
model_name = config.get("model")
if provider == "openai":
import chromadb.utils.embedding_functions as embedding_functions

self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
)
elif provider == "azure":
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)

self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type"),
api_version=config.get("api_version"),
model_name=model_name,
)
elif provider == "ollama":
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)

self.embedder_config = OllamaEmbeddingFunction(
model_name=config.get("model"),
url=config.get("url") or "http://localhost:11434",
)
elif provider == "vertexai":
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)

self.embedder_config = GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "google":
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)

self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "cohere":
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)

self.embedder_config = CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
else:
self.embedder_config = self._create_default_embedding_function()
else:
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class
self.embedder_config = self.embedder_config

def _initialize_app(self):
import chromadb

self._set_embedder_config()
chroma_client = chromadb.PersistentClient(
path=f"{db_storage_path()}/{self.type}/{self.agents}"
path=f"{db_storage_path()}/{self.type}/{self.agents}",
settings=chromadb.Settings(allow_reset=self.allow_reset),
)

self.app = chroma_client

try:
Expand Down Expand Up @@ -122,11 +194,16 @@ def reset(self) -> None:
if self.app:
self.app.reset()
except Exception as e:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)

def _create_embedding_function(self):
if "attempt to write a readonly database" in str(e):
print("ignoring error")
# Ignore this specific error
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)

def _create_default_embedding_function(self):
import chromadb.utils.embedding_functions as embedding_functions

return embedding_functions.OpenAIEmbeddingFunction(
Expand Down
Loading