From 989404753f4192284ec8c7d12ce5edfcf894c6b9 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 12 Dec 2024 22:08:13 +0800 Subject: [PATCH] adjust chart schema --- wren-ai-service/src/pipelines/common.py | 85 ++++++++++++++++++- .../pipelines/generation/chart_adjustment.py | 16 +++- .../pipelines/generation/chart_generation.py | 16 +++- 3 files changed, 112 insertions(+), 5 deletions(-) diff --git a/wren-ai-service/src/pipelines/common.py b/wren-ai-service/src/pipelines/common.py index a563bca6d..f834845ac 100644 --- a/wren-ai-service/src/pipelines/common.py +++ b/wren-ai-service/src/pipelines/common.py @@ -677,7 +677,7 @@ async def _task(result: Dict[str, str]): "encoding": { "x": {"field": "Region", "type": "nominal", "title": ""}, "y": {"field": "Sales", "type": "quantitative", "title": ""}, - "xOffset": {"field": "Product", "type": "nominal"}, + "xOffset": {"field": "Product", "type": "nominal", "title": ""}, "color": {"field": "Product", "type": "nominal", "title": ""} } } @@ -754,6 +754,12 @@ class ChartType(BaseModel): class ChartData(BaseModel): values: list[dict] + class ChartEncoding(BaseModel): + field: str + type: Literal["temporal", "ordinal", "quantitative", "nominal"] + title: str + stack: Optional[Literal["zero"]] + schema: str = Field( alias="$schema", default="https://vega.github.io/schema/vega-lite/v5.json" ) @@ -761,3 +767,80 @@ class ChartData(BaseModel): data: ChartData mark: ChartType encoding: dict + + +class LineChartSchema(ChartSchema): + class LineChartMark(BaseModel): + type: Literal["line"] + + class LineChartEncoding(BaseModel): + x: ChartSchema.ChartEncoding + y: ChartSchema.ChartEncoding + color: ChartSchema.ChartEncoding + + mark: LineChartMark + encoding: LineChartEncoding + + +class BarChartSchema(ChartSchema): + class BarChartMark(BaseModel): + type: Literal["bar"] + + class BarChartEncoding(BaseModel): + x: ChartSchema.ChartEncoding + y: ChartSchema.ChartEncoding + color: ChartSchema.ChartEncoding + + mark: BarChartMark + encoding: BarChartEncoding + + +class GroupedBarChartSchema(ChartSchema): + class GroupedBarChartMark(BaseModel): + type: Literal["bar"] + + class GroupedBarChartEncoding(BaseModel): + x: ChartSchema.ChartEncoding + y: ChartSchema.ChartEncoding + xOffset: ChartSchema.ChartEncoding + color: ChartSchema.ChartEncoding + + mark: GroupedBarChartMark + encoding: GroupedBarChartEncoding + + +class StackedBarChartSchema(ChartSchema): + class StackedBarChartMark(BaseModel): + type: Literal["bar"] + + class StackedBarChartEncoding(BaseModel): + x: ChartSchema.ChartEncoding + y: ChartSchema.ChartEncoding + color: ChartSchema.ChartEncoding + + mark: StackedBarChartMark + encoding: StackedBarChartEncoding + + +class PieChartSchema(ChartSchema): + class PieChartMark(BaseModel): + type: Literal["arc"] + + class PieChartEncoding(BaseModel): + theta: ChartSchema.ChartEncoding + color: ChartSchema.ChartEncoding + + mark: PieChartMark + encoding: PieChartEncoding + + +class AreaChartSchema(ChartSchema): + class AreaChartMark(BaseModel): + type: Literal["area"] + + class AreaChartEncoding(BaseModel): + x: ChartSchema.ChartEncoding + y: ChartSchema.ChartEncoding + + mark: AreaChartMark + encoding: AreaChartEncoding diff --git a/wren-ai-service/src/pipelines/generation/chart_adjustment.py b/wren-ai-service/src/pipelines/generation/chart_adjustment.py index be64fd3ff..1e9321d12 100644 --- a/wren-ai-service/src/pipelines/generation/chart_adjustment.py +++ b/wren-ai-service/src/pipelines/generation/chart_adjustment.py @@ -17,8 +17,13 @@ from src.core.pipeline import BasicPipeline, async_validate from src.core.provider import LLMProvider from src.pipelines.common import ( + AreaChartSchema, + BarChartSchema, ChartDataPreprocessor, - ChartSchema, + GroupedBarChartSchema, + LineChartSchema, + PieChartSchema, + StackedBarChartSchema, chart_generation_instructions, ) from src.utils import async_timer, timer @@ -177,7 +182,14 @@ def post_process( ## End of Pipeline class ChartAdjustmentResults(BaseModel): reasoning: str - chart_schema: ChartSchema + chart_schema: ( + LineChartSchema + | BarChartSchema + | PieChartSchema + | GroupedBarChartSchema + | StackedBarChartSchema + | AreaChartSchema + ) CHART_ADJUSTMENT_MODEL_KWARGS = { diff --git a/wren-ai-service/src/pipelines/generation/chart_generation.py b/wren-ai-service/src/pipelines/generation/chart_generation.py index ec91bcb85..c518dd094 100644 --- a/wren-ai-service/src/pipelines/generation/chart_generation.py +++ b/wren-ai-service/src/pipelines/generation/chart_generation.py @@ -17,8 +17,13 @@ from src.core.pipeline import BasicPipeline, async_validate from src.core.provider import LLMProvider from src.pipelines.common import ( + AreaChartSchema, + BarChartSchema, ChartDataPreprocessor, - ChartSchema, + GroupedBarChartSchema, + LineChartSchema, + PieChartSchema, + StackedBarChartSchema, chart_generation_instructions, ) from src.utils import async_timer, timer @@ -154,7 +159,14 @@ def post_process( ## End of Pipeline class ChartGenerationResults(BaseModel): reasoning: str - chart_schema: ChartSchema + chart_schema: ( + LineChartSchema + | BarChartSchema + | PieChartSchema + | GroupedBarChartSchema + | StackedBarChartSchema + | AreaChartSchema + ) CHART_GENERATION_MODEL_KWARGS = {