diff --git a/vizro-ai/changelog.d/20240613_020634_lingyi_zhang_dashboard_generator.md b/vizro-ai/changelog.d/20240613_020634_lingyi_zhang_dashboard_generator.md new file mode 100644 index 000000000..f1f65e73c --- /dev/null +++ b/vizro-ai/changelog.d/20240613_020634_lingyi_zhang_dashboard_generator.md @@ -0,0 +1,48 @@ + + + + + + + + + diff --git a/vizro-ai/examples/example_dashboard.ipynb b/vizro-ai/examples/example_dashboard.ipynb new file mode 100644 index 000000000..c918bfca5 --- /dev/null +++ b/vizro-ai/examples/example_dashboard.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "53e857ce-22bc-49de-9adc-9a2e7c9829cf", + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "load_dotenv()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a25acdd-20c3-4762-b97f-254de1586aeb", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import vizro.plotly.express as px\n", + "\n", + "from vizro import Vizro\n", + "from vizro_ai import VizroAI\n", + "\n", + "# vizro_ai = VizroAI(model=\"gpt-4-turbo\")\n", + "vizro_ai = VizroAI(model=\"gpt-4o\")\n", + "# vizro_ai = VizroAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5e24f1b-e698-40e5-be00-c3a59c53ec65", + "metadata": {}, + "outputs": [], + "source": [ + "df1 = px.data.gapminder()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "449da2ee-c754-420a-ba2e-c9b0ef62d934", + "metadata": {}, + "outputs": [], + "source": [ + "df2 = px.data.stocks()" + ] + }, + { + "cell_type": "markdown", + "id": "ec46d4d1-d20b-4351-831d-d3d8ddc5cb70", + "metadata": {}, + "source": [ + "# Example: Simple dashboard request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "820a5d0f-a31e-4bbd-a924-9629631cc291", + "metadata": {}, + "outputs": [], + "source": [ + "user_question_2_data = \"\"\"\n", + "I need a page with 1 table.\n", + "The table shows the tech companies stock data.\n", + "\n", + "I need a second page showing 2 cards and one chart.\n", + "The first card says 'The Gapminder dataset provides historical data on countries' development indicators.'\n", + "The chart is a scatter plot showing life expectancy vs. GDP per capita by country. Life expectancy on the y axis, GDP per capita on the x axis, and colored by continent.\n", + "The second card says 'Data spans from 1952 to 2007 across various countries'\n", + "The layout uses a grid of 3 columns and 2 rows.\n", + "\n", + "Row 1: The first row has three columns:\n", + "The first column is occupied by the first card.\n", + "The second and third columns are spanned by the chart.\n", + "\n", + "Row 2: The second row mirrors the layout of the first row with respect to chart, but the first column is occupied by the second card.\n", + "\n", + "Add a filter to filter the scatter plot by continent.\n", + "Add a second filter to filter the chart by year.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d71e089-8c94-4d12-87bd-d803552acb32", + "metadata": {}, + "outputs": [], + "source": [ + "dashboard = vizro_ai.dashboard([df1, df2], user_question_2_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14477c56-54e9-43a5-9136-25bc950fdf3a", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro().build(dashboard).run()" + ] + }, + { + "cell_type": "markdown", + "id": "747964b9-fd05-4c5a-a73a-79dae82320b3", + "metadata": {}, + "source": [ + "# Example: 5-page dashboard request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "967ff6a4-f138-4643-b993-a72e5cc26de2", + "metadata": {}, + "outputs": [], + "source": [ + "df3 = px.data.tips()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb9347f8", + "metadata": {}, + "outputs": [], + "source": [ + "user_question_3_data = \"\"\"\n", + "\n", + "I need a page with 1 table and 1 line chart. \n", + "The chart shows the stock price trends of GOOG and AAPL.\n", + "The table shows the stock prices data details.\n", + "\n", + "\n", + "I need a second page showing 1 card and 1 chart.\n", + "The card says 'The Gapminder dataset provides historical data on countries' development indicators.'\n", + "The chart is a scatter plot showing GDP per capita vs. life expectancy. GDP per capita on the x axis, life expectancy on the y axis, and colored by continent.\n", + "Layout the card on the left and the chart on the right. The card takes 1/3 of the whole space on the left.\n", + "The chart takes 2/3 of the whole space and is on the right.\n", + "Add a filter to filter the scatter plot by continent.\n", + "Add a second filter to filter the chart by year.\n", + "\n", + "\n", + "This page displays the tips dataset. use two different charts to show data\n", + "distributions. one chart should be a bar chart and the other should be a scatter plot.\n", + "first chart is on the left and the second chart is on the right.\n", + "Add a filter to filter data in the scatter plot by smoker.\n", + "\n", + "\n", + "Create 3 cards on this page:\n", + "1. The first card on top says \"This page combines data from various sources including tips, stock prices, and global indicators.\"\n", + "2. The second card says \"Insights from Gapminder dataset.\"\n", + "3. The third card says \"Stock price trends over time.\"\n", + "\n", + "Layout these 3 cards in this way:\n", + "create a grid with 3 columns and 2 rows.\n", + "Row 1: The first row has three columns:\n", + "- The first column is empty.\n", + "- The second and third columns span the area for card 1.\n", + "\n", + "Row 2: The second row also has three columns:\n", + "- The first column is empty.\n", + "- The second column is occupied by the area for card 2.\n", + "- The third column is occupied by the area for card 3.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0a0cdfa", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro._reset()\n", + "dashboard = vizro_ai.dashboard([df1, df2, df3], user_question_3_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3167e996", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro().build(dashboard).run()" + ] + }, + { + "cell_type": "markdown", + "id": "bbf5c920-0432-4415-996f-1acb9d7b6b8a", + "metadata": {}, + "source": [ + "# Example: Request with unsupported features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12d5976e", + "metadata": {}, + "outputs": [], + "source": [ + "user_question_2_data = \"\"\"\n", + "\n", + "I need a page showing 2 cards, one chart, and 1 button.\n", + "The first card says 'The Tips dataset provides insights into customer tipping behavior.'\n", + "The chart is a bar chart showing the total bill amount by day. Day on the x axis, total bill amount on the y axis, and colored by time of day.\n", + "The second card says 'Data collected from various days and times.'\n", + "Layout the two cards on the left and the chart on the right. Two cards take 1/3 of the whole space on the left in total.\n", + "The first card is on top of the second card vertically.\n", + "The chart takes 2/3 of the whole space and is on the right.\n", + "The button would trigger a download action to download the Tips dataset.\n", + "Add a filter to filter the bar chart by `size`.\n", + "Make another tab on this page,\n", + "In this tab, create a card saying \"Tipping patterns and trends.\"\n", + "Group all the above content into the first NavLink.\n", + "\n", + "\n", + "Create two pages:\n", + "1. The first page has a card saying \"Analyzing global development trends.\"\n", + "2. The second page has a scatter plot showing GDP per capita vs. life expectancy. GDP per capita on the x axis, life expectancy on the y axis, and colored by continent.\n", + "Add a parameter to control the title of the scatter plot, with title options \"Economic Growth vs. Health\" and \"Development Indicators.\"\n", + "Also create a button and a spinning circle on the right-hand side of the page.\n", + "\n", + "\n", + "Create one page:\n", + "1. The first page has a card saying \"Stock price trends over time.\"\n", + "Create a button and a spinning circle on the right-hand side of the page.\n", + "\n", + "For hosting the dashboard on AWS, which service should I use?\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b4838d1", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro._reset()\n", + "dashboard = vizro_ai.dashboard([df3, df2, df1], user_question_2_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f055bec1", + "metadata": {}, + "outputs": [], + "source": [ + "Vizro().build(dashboard).run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/vizro-ai/examples/example_dashboard.py b/vizro-ai/examples/example_dashboard.py new file mode 100644 index 000000000..b6dc55839 --- /dev/null +++ b/vizro-ai/examples/example_dashboard.py @@ -0,0 +1,33 @@ +"""Example of creating a dashboard using VizroAI.""" + +import vizro.plotly.express as px +from dotenv import load_dotenv +from vizro import Vizro +from vizro_ai import VizroAI + +load_dotenv() + +vizro_ai = VizroAI(model="gpt-4o") +# vizro_ai = VizroAI() + +gapminder_data = px.data.gapminder() +tips_data = px.data.tips() + +dfs = [gapminder_data, tips_data] +input_text = ( + "Create a dashboard that displays the Gapminder dataset and the tips dataset. " + "page1 displays the Gapminder dataset. create a bar chart for average GDP per capita of each continent. " + "add a filter to filter by continent. " + "Use a card to explain what Gapminder dataset is about. " + "The card should only take 1/6 of the whole page. " + "The rest of the page should be the graph or table. Don't create empty space." + "page2 displays the tips dataset. use two different charts to help me understand the data " + "distributions. one chart should be a bar chart and the other should be a scatter plot. " + "first chart is on the left and the second chart is on the right. " + "add a filter to filter data in the scatter plot by smoker." +) + +dashboard = vizro_ai.dashboard(dfs=dfs, user_input=input_text) + +if __name__ == "__main__": + Vizro().build(dashboard).run() diff --git a/vizro-ai/hatch.toml b/vizro-ai/hatch.toml index 66b69e519..258e6c479 100644 --- a/vizro-ai/hatch.toml +++ b/vizro-ai/hatch.toml @@ -25,6 +25,7 @@ VIZRO_AI_LOG_LEVEL = "DEBUG" [envs.default.scripts] example = "cd examples; python example.py" +example-create-dashboard = "cd examples; python example_dashboard.py" lint = "hatch run lint:lint {args:--all-files}" prep-release = [ "hatch version release", diff --git a/vizro-ai/pyproject.toml b/vizro-ai/pyproject.toml index 2136e2b88..d6c027b84 100644 --- a/vizro-ai/pyproject.toml +++ b/vizro-ai/pyproject.toml @@ -17,8 +17,9 @@ dependencies = [ "pandas", "tabulate", "openai>=1.0.0", - "langchain>=0.1.0, <0.3.0", # TODO update all LLMChain class and remove upper bound + "langchain>=0.1.0, <0.3.0", # TODO update all LLMChain class, update to pydantic v2 and remove upper bound "langchain-openai", + "langgraph>=0.1.2", "python-dotenv>=1.0.0", # TODO decide env var management to see if we need this "vizro>=0.1.4", # TODO set upper bound later "ipython>=8.10.0", # not directly required, pinned by Snyk to avoid a vulnerability: https://app.snyk.io/vuln/SNYK-PYTHON-IPYTHON-3318382 diff --git a/vizro-ai/snyk/requirements.txt b/vizro-ai/snyk/requirements.txt index b2588a330..19da50b49 100644 --- a/vizro-ai/snyk/requirements.txt +++ b/vizro-ai/snyk/requirements.txt @@ -3,6 +3,7 @@ tabulate openai>=1.0.0 langchain>=0.1.0, <0.3.0 langchain-openai +langgraph>=0.1.2 python-dotenv>=1.0.0 vizro>=0.1.4 ipython>=8.10.0 diff --git a/vizro-ai/src/vizro_ai/_llm_models.py b/vizro-ai/src/vizro_ai/_llm_models.py index 9014ad17c..b9a955a8b 100644 --- a/vizro-ai/src/vizro_ai/_llm_models.py +++ b/vizro-ai/src/vizro_ai/_llm_models.py @@ -1,3 +1,4 @@ +from contextlib import suppress from typing import Dict, Optional, Union from langchain_core.language_models.chat_models import BaseChatModel @@ -17,7 +18,7 @@ "gpt-3.5-turbo", "gpt-4o-2024-05-13", "gpt-4o", - ] + ], } DEFAULT_WRAPPER_MAP: Dict[str, BaseChatModel] = {"OpenAI": ChatOpenAI} @@ -49,6 +50,8 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo if isinstance(model, str): if any(model in model_list for model_list in SUPPORTED_MODELS.values()): vendor = model_to_vendor[model] + if DEFAULT_WRAPPER_MAP.get(vendor) is None: + raise ValueError(f"Additional library to support {vendor} models is not installed.") return DEFAULT_WRAPPER_MAP.get(vendor)(model_name=model, temperature=DEFAULT_TEMPERATURE) raise ValueError( @@ -56,6 +59,19 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo ) +def _get_model_name(model: BaseChatModel) -> str: + methods = [ + lambda: model.model_name, # OpenAI models + lambda: model.model, # Anthropic models + ] + + for method in methods: + with suppress(AttributeError): + return method() + + raise ValueError("Model name could not be retrieved") + + if __name__ == "__main__": llm_chat_openai = _get_llm_model(model="gpt-3.5-turbo") print(repr(llm_chat_openai)) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/_vizro_ai.py b/vizro-ai/src/vizro_ai/_vizro_ai.py index c0866aa41..61e2dc48b 100644 --- a/vizro-ai/src/vizro_ai/_vizro_ai.py +++ b/vizro-ai/src/vizro_ai/_vizro_ai.py @@ -1,11 +1,15 @@ import logging -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import pandas as pd import plotly.graph_objects as go +import vizro.models as vm +from langchain_core.messages import HumanMessage from langchain_openai import ChatOpenAI -from vizro_ai._llm_models import _get_llm_model +from vizro_ai._llm_models import _get_llm_model, _get_model_name +from vizro_ai.dashboard._graph.dashboard_creation import _create_and_compile_graph +from vizro_ai.dashboard.utils import DashboardOutputs, _register_data from vizro_ai.plot.components import GetCodeExplanation, GetDebugger from vizro_ai.plot.task_pipeline._pipeline_manager import PipelineManager from vizro_ai.utils.helper import ( @@ -36,8 +40,9 @@ def __init__(self, model: Optional[Union[ChatOpenAI, str]] = None): self.components_instances = {} # TODO add pending URL link to docs + model_name = _get_model_name(self.model) logger.info( - f"You have selected {self.model.model_name}," + f"You have selected {model_name}," f"Engaging with LLMs (Large Language Models) carries certain risks. " f"Users are advised to become familiar with these risks to make informed decisions, " f"and visit this page for detailed information: " @@ -154,3 +159,44 @@ def plot( # pylint: disable=too-many-arguments # noqa: PLR0913 ) return vizro_plot if return_elements else vizro_plot.figure + + def dashboard( + self, + dfs: List[pd.DataFrame], + user_input: str, + return_elements: bool = False, + ) -> Union[DashboardOutputs, vm.Dashboard]: + """Creates a Vizro dashboard using english descriptions. + + Args: + dfs: The dataframes to be analyzed. + user_input: User questions or descriptions of the desired visual. + return_elements: Flag to return DashboardOutputs dataclass that includes all possible elements generated. + + Returns: + vm.Dashboard or DashboardOutputs dataclass. + + """ + runnable = _create_and_compile_graph() + + config = {"configurable": {"model": self.model}} + message_res = runnable.invoke( + { + "dfs": dfs, + "all_df_metadata": {}, + "dashboard_plan": None, + "pages": [], + "dashboard": None, + "messages": [HumanMessage(content=user_input)], + }, + config=config, + ) + dashboard = message_res["dashboard"] + _register_data(all_df_metadata=message_res["all_df_metadata"]) + + if return_elements: + # code = _dashboard_code(dashboard) # TODO: `_dashboard_code` to be implemented + dashboard_output = DashboardOutputs(dashboard=dashboard) + return dashboard_output + else: + return dashboard diff --git a/vizro-ai/src/vizro_ai/dashboard/_graph/dashboard_creation.py b/vizro-ai/src/vizro_ai/dashboard/_graph/dashboard_creation.py new file mode 100644 index 000000000..e11caf5da --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_graph/dashboard_creation.py @@ -0,0 +1,190 @@ +"""Code generation graph for dashboard generation.""" + +import logging +import operator +from typing import Annotated, Dict, List, Optional + +import pandas as pd +import vizro.models as vm +from langchain_core.messages import BaseMessage +from langchain_core.runnables import RunnableConfig +from langgraph.constants import END, Send +from langgraph.graph import StateGraph +from tqdm.auto import tqdm +from vizro_ai.dashboard._pydantic_output import _get_pydantic_model +from vizro_ai.dashboard._response_models.dashboard import DashboardPlan +from vizro_ai.dashboard._response_models.df_info import DfInfo, _create_df_info_content, _get_df_info +from vizro_ai.dashboard._response_models.page import PagePlan +from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata, _execute_step +from vizro_ai.utils.helper import DebugFailure + +try: + from pydantic.v1 import BaseModel, ValidationError +except ImportError: # pragma: no cov + from pydantic import BaseModel, ValidationError + + +logger = logging.getLogger(__name__) + + +Messages = List[BaseMessage] +"""List of messages.""" + + +class GraphState(BaseModel): + """Represents the state of the dashboard graph. + + Attributes + messages: With user question, error messages, reasoning + dfs: Dataframes + all_df_metadata: Cleaned dataframe names and their metadata + dashboard_plan: Plan for the dashboard + pages: Vizro pages + dashboard: Vizro dashboard + + """ + + messages: List[BaseMessage] + dfs: List[pd.DataFrame] + all_df_metadata: AllDfMetadata + dashboard_plan: Optional[DashboardPlan] = None + pages: Annotated[List, operator.add] + dashboard: Optional[vm.Dashboard] = None + + class Config: + """Pydantic configuration.""" + + arbitrary_types_allowed = True + + +def _store_df_info(state: GraphState, config: RunnableConfig) -> Dict[str, AllDfMetadata]: + """Store information about the dataframes.""" + dfs = state.dfs + all_df_metadata = state.all_df_metadata + query = state.messages[0].content + current_df_names = [] + with tqdm(total=len(dfs), desc="Store df info") as pbar: + for df in dfs: + df_schema, df_sample = _get_df_info(df) + df_info = _create_df_info_content( + df_schema=df_schema, df_sample=df_sample, current_df_names=current_df_names + ) + + llm = config["configurable"].get("model", None) + try: + df_name = _get_pydantic_model( + query=query, + llm_model=llm, + response_model=DfInfo, + df_info=df_info, + ).dataset + except DebugFailure as e: + logger.warning(f"Failed in name generation {e}") + df_name = f"df_{len(current_df_names)}" + + current_df_names.append(df_name) + + pbar.write(f"df_name: {df_name}") + pbar.update(1) + all_df_metadata.all_df_metadata[df_name] = DfMetadata(df_schema=df_schema, df=df, df_sample=df_sample) + + return {"all_df_metadata": all_df_metadata} + + +def _dashboard_plan(state: GraphState, config: RunnableConfig) -> Dict[str, DashboardPlan]: + """Generate a dashboard plan.""" + node_desc = "Generate dashboard plan" + pbar = tqdm(total=2, desc=node_desc) + query = state.messages[0].content + all_df_metadata = state.all_df_metadata + + llm = config["configurable"].get("model", None) + + _execute_step( + pbar, + node_desc + " --> in progress \n(this step could take longer when more complex requirements are given)", + None, + ) + try: + dashboard_plan = _get_pydantic_model( + query=query, + llm_model=llm, + response_model=DashboardPlan, + df_info=all_df_metadata.get_schemas_and_samples(), + ) + except (DebugFailure, ValidationError) as e: + raise ValueError( + f""" + Failed to create a valid dashboard plan. Try rephrase the prompt or select a different + model. Error details: + {e} + """ + ) + + _execute_step(pbar, node_desc + " --> done", None) + pbar.close() + + return {"dashboard_plan": dashboard_plan} + + +class BuildPageState(BaseModel): + """Represents the state of building the page. + + Attributes + all_df_metadata: Cleaned dataframe names and their metadata + page_plan: Plan for the dashboard page + + """ + + all_df_metadata: AllDfMetadata + page_plan: Optional[PagePlan] = None + + +def _build_page(state: BuildPageState, config: RunnableConfig) -> Dict[str, List[vm.Page]]: + """Build a page.""" + all_df_metadata = state["all_df_metadata"] + page_plan = state["page_plan"] + + llm = config["configurable"].get("model", None) + page = page_plan.create(model=llm, all_df_metadata=all_df_metadata) + + return {"pages": [page]} + + +def _continue_to_pages(state: GraphState) -> List[Send]: + """Map-reduce logic to build pages in parallel.""" + all_df_metadata = state.all_df_metadata + return [ + Send(node="_build_page", arg={"page_plan": v, "all_df_metadata": all_df_metadata}) + for v in state.dashboard_plan.pages + ] + + +def _build_dashboard(state: GraphState) -> Dict[str, vm.Dashboard]: + """Build a dashboard.""" + dashboard_plan = state.dashboard_plan + pages = state.pages + + dashboard = vm.Dashboard(title=dashboard_plan.title, pages=pages) + + return {"dashboard": dashboard} + + +def _create_and_compile_graph(): + graph = StateGraph(GraphState) + + graph.add_node("_store_df_info", _store_df_info) + graph.add_node("_dashboard_plan", _dashboard_plan) + graph.add_node("_build_page", _build_page) + graph.add_node("_build_dashboard", _build_dashboard) + + graph.add_edge("_store_df_info", "_dashboard_plan") + graph.add_conditional_edges("_dashboard_plan", _continue_to_pages) + graph.add_edge("_build_page", "_build_dashboard") + graph.add_edge("_build_dashboard", END) + + graph.set_entry_point("_store_df_info") + + runnable = graph.compile() + + return runnable diff --git a/vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py b/vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py new file mode 100644 index 000000000..c1511b8e3 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_pydantic_output.py @@ -0,0 +1,98 @@ +"""Contains the _get_pydantic_model for the Vizro AI dashboard.""" + +# ruff: noqa: F821 + +import logging + +try: + from pydantic.v1 import BaseModel, ValidationError +except ImportError: # pragma: no cov + from pydantic import BaseModel, ValidationError + +from typing import Any, Optional + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import HumanMessage +from langchain_core.prompts import ChatPromptTemplate + +logger = logging.getLogger(__name__) + +BASE_PROMPT = """ +You are a front-end developer with expertise in Plotly, Dash, and the visualization library named Vizro. +Your goal is to summarize the given specifications into the given Pydantic schema. +IMPORTANT: Please always output your response by using a tool. + +This is the task context: +{df_info} + +Additional information: +{additional_info} + +Here is the user request: +""" + + +def _create_prompt_template(additional_info: str) -> ChatPromptTemplate: + """Create the ChatPromptTemplate from the base prompt and additional info.""" + return ChatPromptTemplate.from_messages( + [ + ("system", BASE_PROMPT.format(df_info="{df_info}", additional_info=additional_info)), + ("placeholder", "{message}"), + ] + ) + + +SINGLE_MODEL_PROMPT = _create_prompt_template("") +MODEL_REPROMPT = _create_prompt_template("Pay special attention to the following error: {validation_error}") + + +def _create_prompt(retry: bool = False) -> ChatPromptTemplate: + """Create the prompt message for the LLM model.""" + return MODEL_REPROMPT if retry else SINGLE_MODEL_PROMPT + + +def _create_message_content( + query: str, df_info: Any, validation_error: Optional[str] = None, retry: bool = False +) -> dict: + """Create the message content for the LLM model.""" + message_content = {"message": [HumanMessage(content=query)], "df_info": df_info} + + if retry: + message_content["validation_error"] = validation_error + + return message_content + + +def _get_pydantic_model( + query: str, + llm_model: BaseChatModel, + response_model: BaseModel, + df_info: Optional[Any] = None, + max_retry: int = 2, +) -> BaseModel: + """Get the pydantic output from the LLM model with retry logic.""" + for attempt in range(max_retry): + attempt_is_retry = attempt > 0 + prompt = _create_prompt(retry=attempt_is_retry) + message_content = _create_message_content( + query, df_info, str(last_validation_error) if attempt_is_retry else None, retry=attempt_is_retry + ) + pydantic_llm = prompt | llm_model.with_structured_output(response_model) + try: + res = pydantic_llm.invoke(message_content) + except ValidationError as validation_error: + last_validation_error = validation_error + else: + return res + + raise last_validation_error + + +if __name__ == "__main__": + import vizro.models as vm + from vizro_ai._llm_models import _get_llm_model + + model = _get_llm_model() + component_description = "Create a card with the following content: 'Hello, world!'" + res = _get_pydantic_model(query=component_description, llm_model=model, response_model=vm.Card) + print(res) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py new file mode 100644 index 000000000..feb0dfde8 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/components.py @@ -0,0 +1,98 @@ +"""Component plan model.""" + +import logging +from typing import Union + +import vizro.models as vm + +try: + from pydantic.v1 import BaseModel, Field +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field +from vizro.tables import dash_ag_grid +from vizro_ai.dashboard._pydantic_output import _get_pydantic_model +from vizro_ai.dashboard._response_models.types import ComponentType +from vizro_ai.utils.helper import DebugFailure + +logger = logging.getLogger(__name__) + + +class ComponentPlan(BaseModel): + """Component plan model.""" + + component_type: ComponentType + component_description: str = Field( + ..., + description=""" + Description of the component. Include everything that relates to this component. + Be as specific and detailed as possible. + Keep the original relevant description AS IS. Keep any links exactly as provided. + Remember: Accuracy and completeness are key. Do not omit any relevant information provided about the component. + """, + ) + component_id: str = Field( + pattern=r"^[a-z]+(_[a-z]+)?$", description="Small snake case description of this component." + ) + df_name: str = Field( + ..., + description=""" + The name of the dataframe that this component will use. If no dataframe is + used, please specify that as N/A. + """, + ) + + def create(self, model, all_df_metadata) -> Union[vm.Card, vm.AgGrid, vm.Figure]: + """Create the component.""" + from vizro_ai import VizroAI + + vizro_ai = VizroAI(model=model) + + try: + if self.component_type == "Graph": + return vm.Graph( + id=self.component_id, + figure=vizro_ai.plot( + df=all_df_metadata.get_df(self.df_name), user_input=self.component_description + ), + ) + elif self.component_type == "AgGrid": + return vm.AgGrid(id=self.component_id, figure=dash_ag_grid(data_frame=self.df_name)) + elif self.component_type == "Card": + card_prompt = f""" + The Card uses the dcc.Markdown component from Dash as its underlying text component. + Create a card based on the card description: {self.component_description}. + """ + result_proxy = _get_pydantic_model(query=card_prompt, llm_model=model, response_model=vm.Card) + proxy_dict = result_proxy.dict() + proxy_dict["id"] = self.component_id + return vm.Card.parse_obj(proxy_dict) + + except DebugFailure as e: + logger.warning( + f""" +[FALLBACK] Failed to build `Component`: {self.component_id}. +Reason: {e} +Relevant prompt: {self.component_description} +""" + ) + return vm.Card(id=self.component_id, text=f"Failed to build component: {self.component_id}") + + +if __name__ == "__main__": + from dotenv import load_dotenv + from vizro_ai._llm_models import _get_llm_model + from vizro_ai.dashboard.utils import AllDfMetadata + + load_dotenv() + + model = _get_llm_model() + + all_df_metadata = AllDfMetadata({}) + component_plan = ComponentPlan( + component_type="Card", + component_description="Create a card says 'this is worldwide GDP'.", + component_id="gdp_card", + df_name="N/A", + ) + component = component_plan.create(model, all_df_metadata) + print(component.__repr__()) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py new file mode 100644 index 000000000..43bdcf62f --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/controls.py @@ -0,0 +1,170 @@ +"""Controls plan model.""" + +import logging +from typing import List, Optional + +import pandas as pd +import vizro.models as vm + +try: + from pydantic.v1 import BaseModel, Field, ValidationError, create_model, root_validator, validator +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field, ValidationError, create_model, root_validator, validator +from vizro_ai.dashboard._pydantic_output import _get_pydantic_model +from vizro_ai.dashboard._response_models.types import ControlType + +logger = logging.getLogger(__name__) + + +def _create_filter_proxy(df_cols, df_schema, controllable_components) -> BaseModel: + """Create a filter proxy model.""" + + def validate_targets(v): + """Validate the targets.""" + if v not in controllable_components: + raise ValueError(f"targets must be one of {controllable_components}") + return v + + def validate_targets_not_empty(v): + """Validate the targets not empty.""" + if not controllable_components: + raise ValueError( + """ + This might be due to the filter target is not found in the controllable components. + returning default values. + """ + ) + return v + + def validate_column(v): + """Validate the column.""" + if v not in df_cols: + raise ValueError(f"column must be one of {df_cols}") + return v + + @root_validator(allow_reuse=True) + def validate_date_picker_column(cls, values): + """Validate the column for date picker.""" + column = values.get("column") + selector = values.get("selector") + if selector and selector.type == "date_picker": + if not pd.api.types.is_datetime64_any_dtype(df_schema[column]): + raise ValueError( + f""" + The column '{column}' is not of datetime type. Selector type 'date_picker' is + not allowed. Use 'dropdown' instead. + """ + ) + return values + + return create_model( + "FilterProxy", + targets=( + List[str], + Field( + ..., + description=f""" + Target component to be affected by filter. + Must be one of {controllable_components}. ALWAYS REQUIRED. + """, + ), + ), + column=(str, Field(..., description="Column name of DataFrame to filter. ALWAYS REQUIRED.")), + __validators__={ + "validator1": validator("targets", pre=True, each_item=True, allow_reuse=True)(validate_targets), + "validator2": validator("column", allow_reuse=True)(validate_column), + "validator3": validator("targets", pre=True, allow_reuse=True)(validate_targets_not_empty), + "validator4": validate_date_picker_column, + }, + __base__=vm.Filter, + ) + + +def _create_filter(filter_prompt, model, df_cols, df_schema, controllable_components) -> vm.Filter: + result_proxy = _create_filter_proxy( + df_cols=df_cols, df_schema=df_schema, controllable_components=controllable_components + ) + proxy = _get_pydantic_model(query=filter_prompt, llm_model=model, response_model=result_proxy, df_info=df_schema) + return vm.Filter.parse_obj(proxy.dict(exclude_unset=True)) + + +class ControlPlan(BaseModel): + """Control plan model.""" + + control_type: ControlType + control_description: str = Field( + ..., + description=""" + Description of the control. Include everything that seems to relate to this control. + Be as detailed as possible. Keep the original relevant description AS IS. If this control is used + to control a specific component, include the relevant component details. + """, + ) + df_name: str = Field( + ..., + description=""" + The name of the dataframe that the target component will use. + If the dataframe is not used, please specify that. + """, + ) + + def create(self, model, controllable_components, all_df_metadata) -> Optional[vm.Filter]: + """Create the control.""" + filter_prompt = f""" + Create a filter from the following instructions: <{self.control_description}>. Do not make up + things that are optional and DO NOT configure actions, action triggers or action chains. + If no options are specified, leave them out. + """ + try: + _df_schema = all_df_metadata.get_df_schema(self.df_name) + _df_cols = list(_df_schema.keys()) + except KeyError: + logger.warning(f"Dataframe {self.df_name} not found in metadata, returning default values.") + return None + + try: + if self.control_type == "Filter": + res = _create_filter( + filter_prompt=filter_prompt, + model=model, + df_cols=_df_cols, + df_schema=_df_schema, + controllable_components=controllable_components, + ) + return res + + except ValidationError as e: + logger.warning( + f""" +[FALLBACK] Build failed for `Control`, returning default values. Try rephrase the prompt or select a different model. +Error details: {e} +Relevant prompt: {self.control_description} +""" + ) + return None + + +if __name__ == "__main__": + import pandas as pd + from dotenv import load_dotenv + from vizro_ai._llm_models import _get_llm_model + from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata + + load_dotenv() + + model = _get_llm_model() + + all_df_metadata = AllDfMetadata({}) + all_df_metadata.all_df_metadata["gdp_chart"] = DfMetadata( + df_schema={"a": "int64", "b": "int64"}, + df=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), + df_sample=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), + ) + control_plan = ControlPlan( + control_type="Filter", + control_description="Create a filter that filters the data by column 'a'.", + df_name="gdp_chart", + ) + control = control_plan.create( + model, ["gdp_chart"], all_df_metadata + ) # error: Target gdp_chart not found in model_manager. diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/dashboard.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/dashboard.py new file mode 100644 index 000000000..a96550b61 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/dashboard.py @@ -0,0 +1,25 @@ +"""Dashboard plan model.""" + +import logging +from typing import List + +try: + from pydantic.v1 import BaseModel, Field +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field +from vizro_ai.dashboard._response_models.page import PagePlan + +logger = logging.getLogger(__name__) + + +class DashboardPlan(BaseModel): + """Dashboard plan model.""" + + title: str = Field( + ..., + description=""" + Title of the dashboard. If no description is provided, + make a short and concise title from the content of the pages. + """, + ) + pages: List[PagePlan] diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/df_info.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/df_info.py new file mode 100644 index 000000000..0ea59395b --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/df_info.py @@ -0,0 +1,46 @@ +"""Data Summary Node.""" + +from typing import Dict, List, Tuple + +import pandas as pd + +try: + from pydantic.v1 import BaseModel, Field +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field + + +DF_SUMMARY_PROMPT = """ +Inspect the provided data and give a short unique name to the dataset. \n +dataframe sample: \n ------- \n {df_sample} \n ------- \n +Here is the data schema: \n ------- \n {df_schema} \n ------- \n +AVOID the following names: \n ------- \n {current_df_names} \n ------- \n +Provide descriptive name mainly based on the data context above. +User request content is just for context. +""" + + +class DfInfo(BaseModel): + """Data Info output.""" + + dataset: str = Field(pattern=r"^[a-z]+(_[a-z]+)?$", description="Small snake case name of the dataset.") + + +def _get_df_info(df: pd.DataFrame) -> Tuple[Dict[str, str], pd.DataFrame]: + """Get the dataframe schema and sample.""" + formatted_pairs = dict(df.dtypes.astype(str)) + df_sample = df.sample(5, replace=True, random_state=19) + return formatted_pairs, df_sample + + +def _create_df_info_content(df_schema: Dict[str, str], df_sample: pd.DataFrame, current_df_names: List[str]) -> dict: + """Create the message content for the dataframe summarization.""" + return DF_SUMMARY_PROMPT.format(df_sample=df_sample, df_schema=df_schema, current_df_names=current_df_names) + + +if __name__ == "__main__": + df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}) + df_schema, df_sample = _get_df_info(df) + current_df_names = ["df1", "df2"] + print(_create_df_info_content(df_schema, df_sample, current_df_names)) # noqa: T201 + print(DfInfo(dataset="test").dict()) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py new file mode 100644 index 000000000..dbec8b3d6 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/layout.py @@ -0,0 +1,92 @@ +"""Layout plan model.""" + +import logging +from typing import List, Optional + +import vizro.models as vm + +try: + from pydantic.v1 import BaseModel, Field, ValidationError +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field, ValidationError + +logger = logging.getLogger(__name__) + + +def _convert_to_grid(layout_grid_template_areas: List[str], component_ids: List[str]) -> List[List[int]]: + component_map = {component: index for index, component in enumerate(component_ids)} + grid = [] + + for row in layout_grid_template_areas: + grid_row = [] + for cell in row.split(): + if cell == ".": + grid_row.append(-1) + else: + try: + grid_row.append(component_map[cell]) + except KeyError: + logger.warning( + f""" +[FALLBACK] Component {cell} not found in component_ids: {component_ids}. +Returning default values. +""" + ) + return [] + grid.append(grid_row) + + return grid + + +class LayoutPlan(BaseModel): + """Layout plan model, which only applies to Vizro Components(Graph, AgGrid, Card).""" + + layout_grid_template_areas: List[str] = Field( + [], + description=""" + Generate grid template areas for the layout adhering to the grid-template-areas CSS property syntax. + If no layout requested, return an empty list. + If requested, represent each component by 'component_id'. + IMPORTANT: Ensure that the `component_id` matches the `component_id` in the ComponentPlan. + If a grid area is empty, use a dot ('.') to represent it. + Ensure that each row of the grid layout is represented by a string, with each grid area separated by a space. + Return the grid template areas as a list of strings, where each string corresponds to a row in the grid. + No more than 600 characters in total. + """, + ) + + def create(self, component_ids: List[str]) -> Optional[vm.Layout]: + """Create the layout.""" + if not self.layout_grid_template_areas: + return None + + try: + grid = _convert_to_grid( + layout_grid_template_areas=self.layout_grid_template_areas, component_ids=component_ids + ) + actual = vm.Layout(grid=grid) + except ValidationError as e: + logger.warning( + f""" +[FALLBACK] Build failed for `Layout`, returning default values. Try rephrase the prompt or select a different model. +Error details: {e} +Relevant layout_grid_template_areas: +{self.layout_grid_template_areas} +""" + ) + if grid: + logger.warning(f"Calculated grid which caused the error: {grid}") + actual = None + + return actual + + +if __name__ == "__main__": + from vizro_ai._llm_models import _get_llm_model + + model = _get_llm_model() + layout_plan = LayoutPlan( + layout_grid_template_areas=["graph1 card2 card2", "graph1 . card1"], + ) + layout = layout_plan.create(component_ids=["graph1", "card1", "card2"]) + print(layout) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py new file mode 100644 index 000000000..de37b7db1 --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/page.py @@ -0,0 +1,223 @@ +"""Page plan model.""" + +import logging +from collections import Counter +from typing import List, Union + +try: + from pydantic.v1 import BaseModel, Field, PrivateAttr, ValidationError, root_validator, validator +except ImportError: # pragma: no cov + from pydantic import BaseModel, Field, PrivateAttr, ValidationError, root_validator, validator +import vizro.models as vm +from tqdm.auto import tqdm +from vizro_ai.dashboard._response_models.components import ComponentPlan +from vizro_ai.dashboard._response_models.controls import ControlPlan +from vizro_ai.dashboard._response_models.layout import LayoutPlan +from vizro_ai.dashboard.utils import _execute_step + +logger = logging.getLogger(__name__) + + +class PagePlan(BaseModel): + """Page plan model.""" + + title: str = Field( + ..., + description=""" + Title of the page. If no description is provided, + make a concise and descriptive title from the components. + """, + ) + components_plan: List[ComponentPlan] = Field( + ..., description="List of components. Must contain at least one component." + ) + controls_plan: List[ControlPlan] = Field([], description="Controls of the page.") + layout_plan: LayoutPlan = Field(None, description="Layout of components on the page.") + unsupported_specs: List[str] = Field( + [], + description=""" + List of unsupported specs. If there are any unsupported specs, + list them here. If not, leave this as an empty list. + """, + ) + + _components: List[Union[vm.Card, vm.AgGrid, vm.Figure]] = PrivateAttr() + _controls: List[vm.Filter] = PrivateAttr() + _layout: vm.Layout = PrivateAttr() + + @validator("components_plan") + def _check_components_plan(cls, v): + if not v: + raise ValueError("A page must contain at least one component.") + return v + + @validator("unsupported_specs") + def _check_unsupported_specs(cls, v, values): + title = values.get("title", "Unknown Title") + if v: + logger.warning(f"\n ------- \n Unsupported specs on page <{title}>: \n {v}") + return [] + + @root_validator(allow_reuse=True) + def validate_component_id_unique(cls, values): + """Validate the component id is unique.""" + components = values.get("components_plan", []) + component_ids = [comp.component_id for comp in components] + duplicates = [id for id, count in Counter(component_ids).items() if count > 1] + if duplicates: + raise ValidationError(f"Component ids must be unique. Duplicated component ids: {duplicates}") + return values + + def __init__(self, **data): + """Initialize the page plan.""" + super().__init__(**data) + self._components = None + self._controls = None + self._layout = None + + def _get_components(self, model, all_df_metadata): + if self._components is None: + self._components = self._build_components(model=model, all_df_metadata=all_df_metadata) + return self._components + + def _build_components(self, model, all_df_metadata): + components = [] + component_log = tqdm(total=0, bar_format="{desc}", leave=False) + with tqdm( + total=len(self.components_plan), + desc=f"Currently Building ... [Page] <{self.title}> components", + leave=False, + ) as pbar: + for component_plan in self.components_plan: + component_log.set_description_str(f"[Page] <{self.title}>: [Component] {component_plan.component_id}") + pbar.update(1) + components.append(component_plan.create(model=model, all_df_metadata=all_df_metadata)) + component_log.close() + return components + + def _get_layout(self, model, all_df_metadata): + if self._layout is None: + self._layout = self._build_layout(model, all_df_metadata) + return self._layout + + def _build_layout(self, model, all_df_metadata): + if self.layout_plan is None: + return None + return self.layout_plan.create( + component_ids=self._get_component_ids(model=model, all_df_metadata=all_df_metadata), + ) + + def _get_controls(self, model, all_df_metadata): + if self._controls is None: + self._controls = self._build_controls(model=model, all_df_metadata=all_df_metadata) + return self._controls + + def _controllable_components(self, model, all_df_metadata): + return [ + comp.id + for comp in self._get_components(model=model, all_df_metadata=all_df_metadata) + if isinstance(comp, (vm.Graph, vm.AgGrid)) + ] + + def _get_component_ids(self, model, all_df_metadata): + return [comp.id for comp in self._get_components(model=model, all_df_metadata=all_df_metadata)] + + def _build_controls(self, model, all_df_metadata): + controls = [] + with tqdm( + total=len(self.controls_plan), + desc=f"Currently Building ... [Page] <{self.title}> controls", + leave=False, + ) as pbar: + for control_plan in self.controls_plan: + pbar.update(1) + control = control_plan.create( + model=model, + controllable_components=self._controllable_components(model=model, all_df_metadata=all_df_metadata), + all_df_metadata=all_df_metadata, + ) + if control: + controls.append(control) + + return controls + + def create(self, model, all_df_metadata) -> Union[vm.Page, None]: + """Create the page.""" + page_desc = f"Building page: {self.title}" + logger.info(page_desc) + pbar = tqdm(total=5, desc=page_desc) + + title = _execute_step(pbar, page_desc + " --> add title", self.title) + components = _execute_step( + pbar, page_desc + " --> add components", self._get_components(model=model, all_df_metadata=all_df_metadata) + ) + controls = _execute_step( + pbar, page_desc + " --> add controls", self._get_controls(model=model, all_df_metadata=all_df_metadata) + ) + layout = _execute_step( + pbar, page_desc + " --> add layout", self._get_layout(model=model, all_df_metadata=all_df_metadata) + ) + + try: + page = vm.Page(title=title, components=components, controls=controls, layout=layout) + except Exception as e: + # TODO: This Exception might be redundant. Check if it can be removed. + if any("Number of page and grid components need to be the same" in error["msg"] for error in e.errors()): + logger.warning( + """ +[FALLBACK] Number of page and grid components provided are not the same. +Build page without layout. +""" + ) + page = vm.Page(title=title, components=components, controls=controls, layout=None) + else: + logger.warning(f"[FALLBACK] Failed to build page: {self.title}. Reason: {e}") + page = None + _execute_step(pbar, page_desc + " --> done", None) + pbar.close() + return page + + +if __name__ == "__main__": + import pandas as pd + from dotenv import load_dotenv + from vizro_ai._llm_models import _get_llm_model + from vizro_ai.dashboard.utils import AllDfMetadata, DfMetadata + + load_dotenv() + + model = _get_llm_model() + + all_df_metadata = AllDfMetadata( + all_df_metadata={ + "gdp_chart": DfMetadata( + df_schema={"a": "int64", "b": "int64"}, + df=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), + df_sample=pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [4, 5, 6, 7, 8]}), + ) + } + ) + page_plan = PagePlan( + title="Worldwide GDP", + components_plan=[ + ComponentPlan( + component_type="Card", + component_description="Create a card says 'this is worldwide GDP'.", + component_id="gdp_card", + df_name="N/A", + ) + ], + controls_plan=[ + ControlPlan( + control_type="Filter", + control_description="Create a filter that filters the data by column 'a'.", + df_name="gdp_chart", + ) + ], + layout_plan=LayoutPlan( + layout_grid_template_areas=[], + ), + unsupported_specs=[], + ) + page = page_plan.create(model, all_df_metadata) + print(page) # noqa: T201 diff --git a/vizro-ai/src/vizro_ai/dashboard/_response_models/types.py b/vizro-ai/src/vizro_ai/dashboard/_response_models/types.py new file mode 100644 index 000000000..56cc2023f --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/_response_models/types.py @@ -0,0 +1,13 @@ +"""Types for response models.""" + +from typing import Literal + +# TODO make available in documentation + +# Complete list: ["AgGrid", "Button", "Card", "Container", "Graph", "Table", "Tabs"] +ComponentType = Literal["AgGrid", "Card", "Graph"] +"""Component types currently supported by Vizro-AI.""" + +# Complete list: ["Filter", "Parameter"] +ControlType = Literal["Filter"] +"""Control types currently supported by Vizro-AI.""" diff --git a/vizro-ai/src/vizro_ai/dashboard/utils.py b/vizro-ai/src/vizro_ai/dashboard/utils.py new file mode 100644 index 000000000..7276a57bb --- /dev/null +++ b/vizro-ai/src/vizro_ai/dashboard/utils.py @@ -0,0 +1,63 @@ +"""Helper Functions For Vizro AI dashboard.""" + +from dataclasses import dataclass, field +from typing import Any, Dict + +import pandas as pd +import tqdm.std as tsd +import vizro.models as vm + + +@dataclass +class DfMetadata: + """Dataclass containing metadata content for a dataframe.""" + + df_schema: Dict[str, str] + df: pd.DataFrame + df_sample: pd.DataFrame + + +@dataclass +class AllDfMetadata: + """Dataclass containing metadata for all dataframes.""" + + all_df_metadata: Dict[str, DfMetadata] = field(default_factory=dict) + + def get_schemas_and_samples(self) -> Dict[str, Dict[str, str]]: + """Retrieve only the df_schema and df_sample for all datasets.""" + return { + name: {"df_schema": metadata.df_schema, "df_sample": metadata.df_sample} + for name, metadata in self.all_df_metadata.items() + } + + def get_df(self, name: str) -> pd.DataFrame: + """Retrieve the dataframe by name.""" + try: + return self.all_df_metadata[name].df + except KeyError: + raise KeyError("Dataframe not found in metadata. Please ensure that the correct dataframe is provided.") + + def get_df_schema(self, name: str) -> Dict[str, str]: + """Retrieve the schema of the dataframe by name.""" + return self.all_df_metadata[name].df_schema + + +@dataclass +class DashboardOutputs: + """Dataclass containing all possible `VizroAI.dashboard()` output.""" + + dashboard: vm.Dashboard + + +def _execute_step(pbar: tsd.tqdm, description: str, value: Any) -> Any: + pbar.update(1) + pbar.set_description_str(description) + return value + + +def _register_data(all_df_metadata: AllDfMetadata) -> vm.Dashboard: + """Register the dashboard data in data manager.""" + from vizro.managers import data_manager + + for name, metadata in all_df_metadata.all_df_metadata.items(): + data_manager[name] = metadata.df diff --git a/vizro-ai/src/vizro_ai/py.typed b/vizro-ai/src/vizro_ai/py.typed new file mode 100644 index 000000000..512ec7cb8 --- /dev/null +++ b/vizro-ai/src/vizro_ai/py.typed @@ -0,0 +1 @@ + # Marker file for PEP 561