diff --git a/src/vanna/chromadb/chromadb_vector.py b/src/vanna/chromadb/chromadb_vector.py index 22758aa5..ffb93de6 100644 --- a/src/vanna/chromadb/chromadb_vector.py +++ b/src/vanna/chromadb/chromadb_vector.py @@ -1,5 +1,4 @@ import json -import uuid from typing import List import chromadb @@ -16,17 +15,14 @@ class ChromaDB_VectorStore(VannaBase): def __init__(self, config=None): VannaBase.__init__(self, config=config) + if config is None: + config = {} - if config is not None: - path = config.get("path", ".") - self.embedding_function = config.get("embedding_function", default_ef) - curr_client = config.get("client", "persistent") - self.n_results = config.get("n_results", 10) - else: - path = "." - self.embedding_function = default_ef - curr_client = "persistent" # defaults to persistent storage - self.n_results = 10 # defaults to 10 documents + path = config.get("path", ".") + self.embedding_function = config.get("embedding_function", default_ef) + curr_client = config.get("client", "persistent") + collection_metadata = config.get("collection_metadata", None) + self.n_results = config.get("n_results", 10) if curr_client == "persistent": self.chroma_client = chromadb.PersistentClient( @@ -43,13 +39,19 @@ def __init__(self, config=None): raise ValueError(f"Unsupported client was set in config: {curr_client}") self.documentation_collection = self.chroma_client.get_or_create_collection( - name="documentation", embedding_function=self.embedding_function + name="documentation", + embedding_function=self.embedding_function, + metadata=collection_metadata, ) self.ddl_collection = self.chroma_client.get_or_create_collection( - name="ddl", embedding_function=self.embedding_function + name="ddl", + embedding_function=self.embedding_function, + metadata=collection_metadata, ) self.sql_collection = self.chroma_client.get_or_create_collection( - name="sql", embedding_function=self.embedding_function + name="sql", + embedding_function=self.embedding_function, + metadata=collection_metadata, ) def generate_embedding(self, data: str, **kwargs) -> List[float]: