Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(wren-ai-service): minor-updates #1265

Merged
merged 5 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deployment/kustomizations/base/cm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ data:
llm: litellm_llm.gpt-4o-mini-2024-07-18
- name: sql_answer
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
- name: sql_breakdown
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
Expand Down Expand Up @@ -188,6 +187,7 @@ data:

---
settings:
engine_timeout: 30
column_indexing_batch_size: 50
table_retrieval_size: 10
table_column_retrieval_size: 100
Expand Down
2 changes: 1 addition & 1 deletion docker/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ pipes:
llm: litellm_llm.gpt-4o-mini-2024-07-18
- name: sql_answer
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
- name: sql_breakdown
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
Expand Down Expand Up @@ -140,6 +139,7 @@ pipes:

---
settings:
engine_timeout: 30
column_indexing_batch_size: 50
table_retrieval_size: 10
table_column_retrieval_size: 100
Expand Down
3 changes: 3 additions & 0 deletions wren-ai-service/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class Settings(BaseSettings):
# generation config
allow_sql_generation_reasoning: bool = Field(default=True)

# engine config
engine_timeout: float = Field(default=30.0)

# service config
query_cache_ttl: int = Field(default=3600) # unit: seconds
query_cache_maxsize: int = Field(
Expand Down
10 changes: 6 additions & 4 deletions wren-ai-service/src/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ def remove_limit_statement(sql: str) -> str:
return modified_sql


def add_quotes(sql: str) -> Tuple[str, bool]:
def add_quotes(sql: str) -> Tuple[str, str]:
try:
quoted_sql = sqlglot.transpile(sql, read="trino", identify=True)[0]
quoted_sql = sqlglot.transpile(
sql, read="trino", identify=True, error_level=sqlglot.ErrorLevel.RAISE
)[0]
except Exception as e:
logger.exception(f"Error in sqlglot.transpile to {sql}: {e}")

return "", False
return "", str(e)

return quoted_sql, True
return quoted_sql, ""
11 changes: 11 additions & 0 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,18 @@ def create_service_container(
),
"sql_generation": generation.SQLGeneration(
**pipe_components["sql_generation"],
engine_timeout=settings.engine_timeout,
),
"sql_generation_reasoning": generation.SQLGenerationReasoning(
**pipe_components["sql_generation_reasoning"],
),
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
engine_timeout=settings.engine_timeout,
),
"followup_sql_generation": generation.FollowUpSQLGeneration(
**pipe_components["followup_sql_generation"],
engine_timeout=settings.engine_timeout,
),
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
Expand All @@ -128,6 +131,7 @@ def create_service_container(
pipelines={
"sql_executor": retrieval.SQLExecutor(
**pipe_components["sql_executor"],
engine_timeout=settings.engine_timeout,
),
"chart_generation": generation.ChartGeneration(
**pipe_components["chart_generation"],
Expand All @@ -139,6 +143,7 @@ def create_service_container(
pipelines={
"sql_executor": retrieval.SQLExecutor(
**pipe_components["sql_executor"],
engine_timeout=settings.engine_timeout,
),
"chart_adjustment": generation.ChartAdjustment(
**pipe_components["chart_adjustment"],
Expand All @@ -153,6 +158,7 @@ def create_service_container(
),
"sql_answer": generation.SQLAnswer(
**pipe_components["sql_answer"],
engine_timeout=settings.engine_timeout,
),
},
**query_cache,
Expand All @@ -161,6 +167,7 @@ def create_service_container(
pipelines={
"sql_breakdown": generation.SQLBreakdown(
**pipe_components["sql_breakdown"],
engine_timeout=settings.engine_timeout,
),
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
Expand All @@ -177,9 +184,11 @@ def create_service_container(
),
"sql_expansion": generation.SQLExpansion(
**pipe_components["sql_expansion"],
engine_timeout=settings.engine_timeout,
),
"sql_correction": generation.SQLCorrection(
**pipe_components["sql_correction"],
engine_timeout=settings.engine_timeout,
),
"sql_summary": generation.SQLSummary(
**pipe_components["sql_summary"],
Expand Down Expand Up @@ -207,6 +216,7 @@ def create_service_container(
pipelines={
"relationship_recommendation": generation.RelationshipRecommendation(
**pipe_components["relationship_recommendation"],
engine_timeout=settings.engine_timeout,
)
},
**query_cache,
Expand All @@ -224,6 +234,7 @@ def create_service_container(
),
"sql_generation": generation.SQLGeneration(
**pipe_components["question_recommendation_sql_generation"],
engine_timeout=settings.engine_timeout,
),
"sql_generation_reasoning": generation.SQLGenerationReasoning(
**pipe_components["sql_generation_reasoning"],
Expand Down
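13 changes: 11 additions & 2 deletions wren-ai-service/src/pipelines/generation/followup_sql_generation.py
Original file line number Diff line number Diff line change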
@@ -1,6 +1,6 @@
import logging
import sys
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from hamilton import base
from hamilton.async_driver import AsyncDriver
Expand Down Expand Up @@ -106,10 +106,13 @@ async def generate_sql_in_followup(prompt: dict, generator: Any) -> dict:
async def post_process(
generate_sql_in_followup: dict,
post_processor: SQLGenPostProcessor,
engine_timeout: float,
project_id: str | None = None,
) -> dict:
return await post_processor.run(
generate_sql_in_followup.get("replies"), project_id=project_id
generate_sql_in_followup.get("replies"),
timeout=engine_timeout,
project_id=project_id,
)


Expand All @@ -132,6 +135,7 @@ def __init__(
self,
llm_provider: LLMProvider,
engine: Engine,
engine_timeout: Optional[float] = 30.0,
**kwargs,
):
self._components = {
Expand All @@ -145,6 +149,10 @@ def __init__(
"post_processor": SQLGenPostProcessor(engine=engine),
}

self._configs = {
"engine_timeout": engine_timeout,
}

super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)
Expand Down Expand Up @@ -176,6 +184,7 @@ async def run(
"has_calculated_field": has_calculated_field,
"has_metric": has_metric,
**self._components,
**self._configs,
},
)

Expand Down
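8 changes: 7 additions & 1 deletion wren-ai-service/src/pipelines/generation/relationship_recommendation.py
Original file line number Diff line number Diff line change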
@@ -1,7 +1,7 @@
import logging
import sys
from enum import Enum
from typing import Any
from typing import Any, Optional

import orjson
from hamilton import base
Expand Down Expand Up @@ -170,6 +170,7 @@ def __init__(
self,
llm_provider: LLMProvider,
engine: Engine,
engine_timeout: Optional[float] = 30.0,
**_,
):
self._components = {
Expand All @@ -181,6 +182,10 @@ def __init__(
"engine": engine,
}

self._configs = {
"engine_timeout": engine_timeout,
}

self._final = "validated"

super().__init__(
Expand All @@ -200,6 +205,7 @@ async def run(
"mdl": mdl,
"language": language,
**self._components,
**self._configs,
},
)

Expand Down
9 changes: 7 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_breakdown.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import sys
from typing import Any
from typing import Any, Optional

from hamilton import base
from hamilton.async_driver import AsyncDriver
Expand Down Expand Up @@ -135,10 +135,13 @@ async def generate_sql_details(prompt: dict, generator: Any) -> dict:
async def post_process(
generate_sql_details: dict,
post_processor: SQLBreakdownGenPostProcessor,
engine_timeout: float,
project_id: str | None = None,
) -> dict:
return await post_processor.run(
generate_sql_details.get("replies"), project_id=project_id
generate_sql_details.get("replies"),
timeout=engine_timeout,
project_id=project_id,
)


Expand Down Expand Up @@ -170,6 +173,7 @@ def __init__(
self,
llm_provider: LLMProvider,
engine: Engine,
engine_timeout: Optional[float] = 30.0,
**kwargs,
):
self._components = {
Expand All @@ -185,6 +189,7 @@ def __init__(

self._configs = {
"text_to_sql_rules": TEXT_TO_SQL_RULES,
"engine_timeout": engine_timeout,
}

super().__init__(
Expand Down
15 changes: 13 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
import sys
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from hamilton import base
from hamilton.async_driver import AsyncDriver
Expand Down Expand Up @@ -82,9 +82,14 @@ async def generate_sql_corrections(prompts: list[dict], generator: Any) -> list[
async def post_process(
generate_sql_corrections: list[dict],
post_processor: SQLGenPostProcessor,
engine_timeout: float,
project_id: str | None = None,
) -> list[dict]:
return await post_processor.run(generate_sql_corrections, project_id=project_id)
return await post_processor.run(
generate_sql_corrections,
timeout=engine_timeout,
project_id=project_id,
)


## End of Pipeline
Expand All @@ -106,6 +111,7 @@ def __init__(
self,
llm_provider: LLMProvider,
engine: Engine,
engine_timeout: Optional[float] = 30.0,
**kwargs,
):
self._components = {
Expand All @@ -119,6 +125,10 @@ def __init__(
"post_processor": SQLGenPostProcessor(engine=engine),
}

self._configs = {
"engine_timeout": engine_timeout,
}

super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)
Expand All @@ -138,6 +148,7 @@ async def run(
"documents": contexts,
"project_id": project_id,
**self._components,
**self._configs,
},
)

Expand Down
13 changes: 11 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import sys
from typing import Any, List
from typing import Any, List, Optional

from hamilton import base
from hamilton.async_driver import AsyncDriver
Expand Down Expand Up @@ -75,10 +75,13 @@ async def generate_sql_expansion(prompt: dict, generator: Any) -> dict:
async def post_process(
generate_sql_expansion: dict,
post_processor: SQLGenPostProcessor,
engine_timeout: float,
project_id: str | None = None,
) -> dict:
return await post_processor.run(
generate_sql_expansion.get("replies"), project_id=project_id
generate_sql_expansion.get("replies"),
timeout=engine_timeout,
project_id=project_id,
)


Expand All @@ -105,6 +108,7 @@ def __init__(
self,
llm_provider: LLMProvider,
engine: Engine,
engine_timeout: Optional[float] = 30.0,
**kwargs,
):
self._components = {
Expand All @@ -118,6 +122,10 @@ def __init__(
"post_processor": SQLGenPostProcessor(engine=engine),
}

self._configs = {
"engine_timeout": engine_timeout,
}

super().__init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)
Expand All @@ -141,6 +149,7 @@ async def run(
"project_id": project_id,
"configuration": configuration,
**self._components,
**self._configs,
},
)

Expand Down
Loading
Loading