diff --git a/src/vanna/openai/openai_embeddings.py b/src/vanna/openai/openai_embeddings.py index 9d3fcd65..d1063e4a 100644 --- a/src/vanna/openai/openai_embeddings.py +++ b/src/vanna/openai/openai_embeddings.py @@ -6,8 +6,15 @@ class OpenAI_Embeddings(VannaBase): - def __init__(self, config=None): + def __init__(self, client=None, config=None): VannaBase.__init__(self, config=config) + + if client is not None: + self.client = client + return + + if self.client is not None: + return self.client = OpenAI()