diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 119bd622..c6d40297 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -79,7 +79,7 @@ def generate_followup_questions(self, question: str, **kwargs) -> str: numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE) return numbers_removed.split("\n") - def generate_questions(self, **kwargs) -> list[str]: + def generate_questions(self, **kwargs) -> List[str]: """ **Example:** ```python @@ -94,7 +94,7 @@ def generate_questions(self, **kwargs) -> list[str]: # ----------------- Use Any Embeddings API ----------------- # @abstractmethod - def generate_embedding(self, data: str, **kwargs) -> list[float]: + def generate_embedding(self, data: str, **kwargs) -> List[float]: pass # ----------------- Use Any Database to Store and Retrieve Context ----------------- # diff --git a/src/vanna/chromadb/chromadb_vector.py b/src/vanna/chromadb/chromadb_vector.py index 796c08a4..6dec25a5 100644 --- a/src/vanna/chromadb/chromadb_vector.py +++ b/src/vanna/chromadb/chromadb_vector.py @@ -1,4 +1,5 @@ import json +from typing import List import uuid from abc import abstractmethod @@ -34,7 +35,7 @@ def __init__(self, config=None): name="sql", embedding_function=default_ef ) - def generate_embedding(self, data: str, **kwargs) -> list[float]: + def generate_embedding(self, data: str, **kwargs) -> List[float]: embedding = default_ef([data]) if len(embedding) == 1: return embedding[0]