Skip to content

Commit

Permalink
adjust chart schema
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Dec 12, 2024
1 parent ecffaab commit 9894047
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 5 deletions.
85 changes: 84 additions & 1 deletion wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ async def _task(result: Dict[str, str]):
"encoding": {
"x": {"field": "Region", "type": "nominal", "title": "<TITLE_IN_LANGUAGE_PROVIDED_BY_USER>"},
"y": {"field": "Sales", "type": "quantitative", "title": "<TITLE_IN_LANGUAGE_PROVIDED_BY_USER>"},
"xOffset": {"field": "Product", "type": "nominal"},
"xOffset": {"field": "Product", "type": "nominal", "title": "<TITLE_IN_LANGUAGE_PROVIDED_BY_USER>"},
"color": {"field": "Product", "type": "nominal", "title": "<TITLE_IN_LANGUAGE_PROVIDED_BY_USER>"}
}
}
Expand Down Expand Up @@ -754,10 +754,93 @@ class ChartType(BaseModel):
class ChartData(BaseModel):
values: list[dict]

class ChartEncoding(BaseModel):
field: str
type: Literal["temporal", "ordinal", "quantitative", "nominal"]
title: str
stack: Optional[Literal["zero"]]

schema: str = Field(
alias="$schema", default="https://vega.github.io/schema/vega-lite/v5.json"
)
title: str
data: ChartData
mark: ChartType
encoding: dict


class LineChartSchema(ChartSchema):
class LineChartMark(BaseModel):
type: Literal["line"]

class LineChartEncoding(BaseModel):
x: ChartSchema.ChartEncoding
y: ChartSchema.ChartEncoding
color: ChartSchema.ChartEncoding

mark: LineChartMark
encoding: LineChartEncoding


class BarChartSchema(ChartSchema):
class BarChartMark(BaseModel):
type: Literal["bar"]

class BarChartEncoding(BaseModel):
x: ChartSchema.ChartEncoding
y: ChartSchema.ChartEncoding
color: ChartSchema.ChartEncoding

mark: BarChartMark
encoding: BarChartEncoding


class GroupedBarChartSchema(ChartSchema):
class GroupedBarChartMark(BaseModel):
type: Literal["bar"]

class GroupedBarChartEncoding(BaseModel):
x: ChartSchema.ChartEncoding
y: ChartSchema.ChartEncoding
xOffset: ChartSchema.ChartEncoding
color: ChartSchema.ChartEncoding

mark: GroupedBarChartMark
encoding: GroupedBarChartEncoding


class StackedBarChartSchema(ChartSchema):
class StackedBarChartMark(BaseModel):
type: Literal["bar"]

class StackedBarChartEncoding(BaseModel):
x: ChartSchema.ChartEncoding
y: ChartSchema.ChartEncoding
color: ChartSchema.ChartEncoding

mark: StackedBarChartMark
encoding: StackedBarChartEncoding


class PieChartSchema(ChartSchema):
class PieChartMark(BaseModel):
type: Literal["arc"]

class PieChartEncoding(BaseModel):
theta: ChartSchema.ChartEncoding
color: ChartSchema.ChartEncoding

mark: PieChartMark
encoding: PieChartEncoding


class AreaChartSchema(ChartSchema):
class AreaChartMark(BaseModel):
type: Literal["area"]

class AreaChartEncoding(BaseModel):
x: ChartSchema.ChartEncoding
y: ChartSchema.ChartEncoding

mark: AreaChartMark
encoding: AreaChartEncoding
16 changes: 14 additions & 2 deletions wren-ai-service/src/pipelines/generation/chart_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
from src.core.pipeline import BasicPipeline, async_validate
from src.core.provider import LLMProvider
from src.pipelines.common import (
AreaChartSchema,
BarChartSchema,
ChartDataPreprocessor,
ChartSchema,
GroupedBarChartSchema,
LineChartSchema,
PieChartSchema,
StackedBarChartSchema,
chart_generation_instructions,
)
from src.utils import async_timer, timer
Expand Down Expand Up @@ -177,7 +182,14 @@ def post_process(
## End of Pipeline
class ChartAdjustmentResults(BaseModel):
reasoning: str
chart_schema: ChartSchema
chart_schema: (
LineChartSchema
| BarChartSchema
| PieChartSchema
| GroupedBarChartSchema
| StackedBarChartSchema
| AreaChartSchema
)


CHART_ADJUSTMENT_MODEL_KWARGS = {
Expand Down
16 changes: 14 additions & 2 deletions wren-ai-service/src/pipelines/generation/chart_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
from src.core.pipeline import BasicPipeline, async_validate
from src.core.provider import LLMProvider
from src.pipelines.common import (
AreaChartSchema,
BarChartSchema,
ChartDataPreprocessor,
ChartSchema,
GroupedBarChartSchema,
LineChartSchema,
PieChartSchema,
StackedBarChartSchema,
chart_generation_instructions,
)
from src.utils import async_timer, timer
Expand Down Expand Up @@ -154,7 +159,14 @@ def post_process(
## End of Pipeline
class ChartGenerationResults(BaseModel):
reasoning: str
chart_schema: ChartSchema
chart_schema: (
LineChartSchema
| BarChartSchema
| PieChartSchema
| GroupedBarChartSchema
| StackedBarChartSchema
| AreaChartSchema
)


CHART_GENERATION_MODEL_KWARGS = {
Expand Down

0 comments on commit 9894047

Please sign in to comment.