From fd32f09b100cf5eae0b037e983d5190eb3b005ac Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Wed, 19 Jul 2023 22:19:04 -0400 Subject: [PATCH] generate followup questions --- pyproject.toml | 2 +- src/vanna/__init__.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c8d9a152..f1396518 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vanna" -version = "0.0.10" +version = "0.0.11" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ] diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index 30b6887f..dff8d61c 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -510,6 +510,41 @@ def generate_sql(question: str) -> str: return sql_answer.sql +def generate_followup_questions(question: str, df: pd.DataFrame) -> List[str]: + """ + ## Example + ```python + vn.generate_followup_questions(question="What is the average salary of employees?", df=df) + # ['What is the average salary of employees in the Sales department?', 'What is the average salary of employees in the Engineering department?', ...] + ``` + + Generate follow-up questions using the Vanna.AI API. + + Args: + question (str): The question to generate follow-up questions for. + df (pd.DataFrame): The DataFrame to generate follow-up questions for. + + Returns: + List[str] or None: The follow-up questions, or None if an error occurred. + """ + params = [DataResult( + question=question, + sql=None, + table_markdown=df.head().to_markdown(), + error=None, + correction_attempts=0, + )] + + d = __rpc_call(method="generate_followup_questions", params=params) + + if 'result' not in d: + return None + + # Load the result into a dataclass + question_string_list = QuestionStringList(**d['result']) + + return question_string_list.questions + def generate_questions() -> List[str]: """ ## Example