From 27c162909ce568550742c40a961dc953d70bf87b Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Mon, 3 Feb 2025 20:25:37 +0800 Subject: [PATCH 1/4] return sql error --- wren-ai-service/src/core/engine.py | 10 ++++--- .../src/pipelines/generation/utils/sql.py | 10 +++---- wren-ai-service/src/web/v1/services/ask.py | 21 +++++++------- .../src/web/v1/services/ask_details.py | 4 +-- .../src/web/v1/services/sql_expansion.py | 28 +++++++++---------- 5 files changed, 37 insertions(+), 36 deletions(-) diff --git a/wren-ai-service/src/core/engine.py b/wren-ai-service/src/core/engine.py index c145c32e9..e1ded50b9 100644 --- a/wren-ai-service/src/core/engine.py +++ b/wren-ai-service/src/core/engine.py @@ -50,12 +50,14 @@ def remove_limit_statement(sql: str) -> str: return modified_sql -def add_quotes(sql: str) -> Tuple[str, bool]: +def add_quotes(sql: str) -> Tuple[str, str]: try: - quoted_sql = sqlglot.transpile(sql, read="trino", identify=True)[0] + quoted_sql = sqlglot.transpile( + sql, read="trino", identify=True, error_level=sqlglot.ErrorLevel.RAISE + )[0] except Exception as e: logger.exception(f"Error in sqlglot.transpile to {sql}: {e}") - return "", False + return "", str(e) - return quoted_sql, True + return quoted_sql, "" diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index c0b7774b8..fb1bc60e8 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -45,8 +45,8 @@ async def run( steps[-1]["cte_name"] = "" for step in steps: - step["sql"], no_error = add_quotes(step["sql"]) - if not no_error: + step["sql"], error_message = add_quotes(step["sql"]) + if not error_message: return { "results": { "description": cleaned_generation_result["description"], @@ -160,9 +160,9 @@ async def _classify_invalid_generation_results( invalid_generation_results = [] async def _task(sql: str): - quoted_sql, no_error = add_quotes(sql) + quoted_sql, error_message = add_quotes(sql) - if no_error: + if not error_message: status, _, addition = await self._engine.execute_sql( quoted_sql, session, project_id=project_id ) @@ -188,7 +188,7 @@ async def _task(sql: str): { "sql": sql, "type": "ADD_QUOTES", - "error": "add_quotes failed", + "error": error_message, } ) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index a690b968c..d7a51d05d 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -120,11 +120,6 @@ def _is_stopped(self, query_id: str): return False - def _get_failed_dry_run_results(self, invalid_generation_results: list[dict]): - return list( - filter(lambda x: x["type"] == "DRY_RUN", invalid_generation_results) - ) - @observe(name="Ask Question") @trace_metadata async def ask( @@ -146,6 +141,7 @@ async def ask( intent_reasoning = None sql_generation_reasoning = None api_results = [] + error_message = "" try: # ask status can be understanding, searching, generating, finished, failed, stopped @@ -362,11 +358,9 @@ async def ask( ) for result in sql_valid_results ][:1] - elif failed_dry_run_results := self._get_failed_dry_run_results( - text_to_sql_generation_results["post_process"][ - "invalid_generation_results" - ] - ): + elif failed_dry_run_results := text_to_sql_generation_results[ + "post_process" + ]["invalid_generation_results"]: self._ask_results[query_id] = AskResultResponse( status="correcting", ) @@ -390,6 +384,10 @@ async def ask( ) for valid_generation_result in valid_generation_results ][:1] + elif failed_dry_run_results := sql_correction_results[ + "post_process" + ]["invalid_generation_results"]: + error_message = failed_dry_run_results[0]["error"] if api_results: if not self._is_stopped(query_id): @@ -411,13 +409,14 @@ async def ask( type="TEXT_TO_SQL", error=AskError( code="NO_RELEVANT_SQL", - message="No relevant SQL", + message=error_message or "No relevant SQL", ), rephrased_question=rephrased_question, intent_reasoning=intent_reasoning, generation_reasoning=sql_generation_reasoning, ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" + results["metadata"]["error_message"] = error_message results["metadata"]["type"] = "TEXT_TO_SQL" return results diff --git a/wren-ai-service/src/web/v1/services/ask_details.py b/wren-ai-service/src/web/v1/services/ask_details.py index 4c90cf0ac..d6f167484 100644 --- a/wren-ai-service/src/web/v1/services/ask_details.py +++ b/wren-ai-service/src/web/v1/services/ask_details.py @@ -123,8 +123,8 @@ async def ask_details( ask_details_result = generation_result["post_process"]["results"] if not ask_details_result["steps"]: - quoted_sql, no_error = add_quotes(ask_details_request.sql) - sql = quoted_sql if no_error else ask_details_request.sql + quoted_sql, error_message = add_quotes(ask_details_request.sql) + sql = quoted_sql if not error_message else ask_details_request.sql sql_summary_results = await self._pipelines["sql_summary"].run( query=ask_details_request.query, diff --git a/wren-ai-service/src/web/v1/services/sql_expansion.py b/wren-ai-service/src/web/v1/services/sql_expansion.py index 06dfa5f86..43f6cd72e 100644 --- a/wren-ai-service/src/web/v1/services/sql_expansion.py +++ b/wren-ai-service/src/web/v1/services/sql_expansion.py @@ -93,11 +93,6 @@ def _is_stopped(self, query_id: str): return False - def _get_failed_dry_run_results(self, invalid_generation_results: list[dict]): - return list( - filter(lambda x: x["type"] == "DRY_RUN", invalid_generation_results) - ) - @observe(name="SQL Expansion") @trace_metadata async def sql_expansion( @@ -112,6 +107,7 @@ async def sql_expansion( "error_message": "", }, } + error_message = "" try: query_id = sql_expansion_request.query_id @@ -171,11 +167,9 @@ async def sql_expansion( ]["valid_generation_results"]: valid_generation_results += sql_valid_results - if failed_dry_run_results := self._get_failed_dry_run_results( - sql_expansion_generation_results["post_process"][ - "invalid_generation_results" - ] - ): + if failed_dry_run_results := sql_expansion_generation_results[ + "post_process" + ]["invalid_generation_results"]: sql_correction_results = await self._pipelines[ "sql_correction" ].run( @@ -183,9 +177,14 @@ async def sql_expansion( invalid_generation_results=failed_dry_run_results, project_id=sql_expansion_request.project_id, ) - valid_generation_results += sql_correction_results["post_process"][ - "valid_generation_results" - ] + if sql_correction_valid_results := sql_correction_results[ + "post_process" + ]["valid_generation_results"]: + valid_generation_results += sql_correction_valid_results + elif failed_dry_run_results := sql_correction_results[ + "post_process" + ]["invalid_generation_results"]: + error_message = failed_dry_run_results[0]["error"] valid_sql_summary_results = [] if valid_generation_results: @@ -206,10 +205,11 @@ async def sql_expansion( status="failed", error=AskError( code="NO_RELEVANT_SQL", - message="No relevant SQL", + message=error_message or "No relevant SQL", ), ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" + results["metadata"]["error_message"] = error_message return results api_results = SqlExpansionResultResponse.SqlExpansionResult( From efedd0920279062b03fc04a8854c083638a3449b Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 5 Feb 2025 13:23:33 +0800 Subject: [PATCH 2/4] add engine timeout and remove engine in sql_answer --- deployment/kustomizations/base/cm.yaml | 2 +- docker/config.example.yaml | 2 +- wren-ai-service/src/config.py | 3 +++ wren-ai-service/src/globals.py | 11 ++++++++++ .../generation/followup_sql_generation.py | 13 ++++++++++-- .../generation/relationship_recommendation.py | 8 ++++++- .../src/pipelines/generation/sql_breakdown.py | 9 ++++++-- .../pipelines/generation/sql_correction.py | 15 +++++++++++-- .../src/pipelines/generation/sql_expansion.py | 13 ++++++++++-- .../pipelines/generation/sql_generation.py | 15 +++++++++++-- .../src/pipelines/generation/utils/sql.py | 21 +++++++++++++++---- .../src/pipelines/retrieval/sql_executor.py | 16 +++++++++++++- .../tools/config/config.example.yaml | 2 +- wren-ai-service/tools/config/config.full.yaml | 2 +- 14 files changed, 112 insertions(+), 20 deletions(-) diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml index c5a681383..e02b847b4 100644 --- a/deployment/kustomizations/base/cm.yaml +++ b/deployment/kustomizations/base/cm.yaml @@ -131,7 +131,6 @@ data: llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_answer llm: litellm_llm.gpt-4o-mini-2024-07-18 - engine: wren_ui - name: sql_breakdown llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui @@ -188,6 +187,7 @@ data: --- settings: + engine_timeout: 30 column_indexing_batch_size: 50 table_retrieval_size: 10 table_column_retrieval_size: 100 diff --git a/docker/config.example.yaml b/docker/config.example.yaml index 5f0557c46..9d0e1bf41 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -83,7 +83,6 @@ pipes: llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_answer llm: litellm_llm.gpt-4o-mini-2024-07-18 - engine: wren_ui - name: sql_breakdown llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui @@ -140,6 +139,7 @@ pipes: --- settings: + engine_timeout: 30 column_indexing_batch_size: 50 table_retrieval_size: 10 table_column_retrieval_size: 100 diff --git a/wren-ai-service/src/config.py b/wren-ai-service/src/config.py index 9e0a588db..dc19b02c8 100644 --- a/wren-ai-service/src/config.py +++ b/wren-ai-service/src/config.py @@ -34,6 +34,9 @@ class Settings(BaseSettings): # generation config allow_sql_generation_reasoning: bool = Field(default=True) + # engine config + engine_timeout: float = Field(default=30.0) + # service config query_cache_ttl: int = Field(default=3600) # unit: seconds query_cache_maxsize: int = Field( diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 5555ffca5..9ed3e31c3 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -107,15 +107,18 @@ def create_service_container( ), "sql_generation": generation.SQLGeneration( **pipe_components["sql_generation"], + engine_timeout=settings.engine_timeout, ), "sql_generation_reasoning": generation.SQLGenerationReasoning( **pipe_components["sql_generation_reasoning"], ), "sql_correction": generation.SQLCorrection( **pipe_components["sql_correction"], + engine_timeout=settings.engine_timeout, ), "followup_sql_generation": generation.FollowUpSQLGeneration( **pipe_components["followup_sql_generation"], + engine_timeout=settings.engine_timeout, ), "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], @@ -128,6 +131,7 @@ def create_service_container( pipelines={ "sql_executor": retrieval.SQLExecutor( **pipe_components["sql_executor"], + engine_timeout=settings.engine_timeout, ), "chart_generation": generation.ChartGeneration( **pipe_components["chart_generation"], @@ -139,6 +143,7 @@ def create_service_container( pipelines={ "sql_executor": retrieval.SQLExecutor( **pipe_components["sql_executor"], + engine_timeout=settings.engine_timeout, ), "chart_adjustment": generation.ChartAdjustment( **pipe_components["chart_adjustment"], @@ -153,6 +158,7 @@ def create_service_container( ), "sql_answer": generation.SQLAnswer( **pipe_components["sql_answer"], + engine_timeout=settings.engine_timeout, ), }, **query_cache, @@ -161,6 +167,7 @@ def create_service_container( pipelines={ "sql_breakdown": generation.SQLBreakdown( **pipe_components["sql_breakdown"], + engine_timeout=settings.engine_timeout, ), "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], @@ -177,9 +184,11 @@ def create_service_container( ), "sql_expansion": generation.SQLExpansion( **pipe_components["sql_expansion"], + engine_timeout=settings.engine_timeout, ), "sql_correction": generation.SQLCorrection( **pipe_components["sql_correction"], + engine_timeout=settings.engine_timeout, ), "sql_summary": generation.SQLSummary( **pipe_components["sql_summary"], @@ -207,6 +216,7 @@ def create_service_container( pipelines={ "relationship_recommendation": generation.RelationshipRecommendation( **pipe_components["relationship_recommendation"], + engine_timeout=settings.engine_timeout, ) }, **query_cache, @@ -224,6 +234,7 @@ def create_service_container( ), "sql_generation": generation.SQLGeneration( **pipe_components["question_recommendation_sql_generation"], + engine_timeout=settings.engine_timeout, ), "sql_generation_reasoning": generation.SQLGenerationReasoning( **pipe_components["sql_generation_reasoning"], diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index e84f2e729..b94c9af97 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -1,6 +1,6 @@ import logging import sys -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from hamilton import base from hamilton.async_driver import AsyncDriver @@ -106,10 +106,13 @@ async def generate_sql_in_followup(prompt: dict, generator: Any) -> dict: async def post_process( generate_sql_in_followup: dict, post_processor: SQLGenPostProcessor, + engine_timeout: float, project_id: str | None = None, ) -> dict: return await post_processor.run( - generate_sql_in_followup.get("replies"), project_id=project_id + generate_sql_in_followup.get("replies"), + timeout=engine_timeout, + project_id=project_id, ) @@ -132,6 +135,7 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, + engine_timeout: Optional[float] = 30.0, **kwargs, ): self._components = { @@ -145,6 +149,10 @@ def __init__( "post_processor": SQLGenPostProcessor(engine=engine), } + self._configs = { + "engine_timeout": engine_timeout, + } + super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -176,6 +184,7 @@ async def run( "has_calculated_field": has_calculated_field, "has_metric": has_metric, **self._components, + **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py index efc5b8782..a053fed5b 100644 --- a/wren-ai-service/src/pipelines/generation/relationship_recommendation.py +++ b/wren-ai-service/src/pipelines/generation/relationship_recommendation.py @@ -1,7 +1,7 @@ import logging import sys from enum import Enum -from typing import Any +from typing import Any, Optional import orjson from hamilton import base @@ -170,6 +170,7 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, + engine_timeout: Optional[float] = 30.0, **_, ): self._components = { @@ -181,6 +182,10 @@ def __init__( "engine": engine, } + self._configs = { + "engine_timeout": engine_timeout, + } + self._final = "validated" super().__init__( @@ -200,6 +205,7 @@ async def run( "mdl": mdl, "language": language, **self._components, + **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/sql_breakdown.py b/wren-ai-service/src/pipelines/generation/sql_breakdown.py index 5f50f4036..393bbdc21 100644 --- a/wren-ai-service/src/pipelines/generation/sql_breakdown.py +++ b/wren-ai-service/src/pipelines/generation/sql_breakdown.py @@ -1,6 +1,6 @@ import logging import sys -from typing import Any +from typing import Any, Optional from hamilton import base from hamilton.async_driver import AsyncDriver @@ -135,10 +135,13 @@ async def generate_sql_details(prompt: dict, generator: Any) -> dict: async def post_process( generate_sql_details: dict, post_processor: SQLBreakdownGenPostProcessor, + engine_timeout: float, project_id: str | None = None, ) -> dict: return await post_processor.run( - generate_sql_details.get("replies"), project_id=project_id + generate_sql_details.get("replies"), + timeout=engine_timeout, + project_id=project_id, ) @@ -170,6 +173,7 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, + engine_timeout: Optional[float] = 30.0, **kwargs, ): self._components = { @@ -185,6 +189,7 @@ def __init__( self._configs = { "text_to_sql_rules": TEXT_TO_SQL_RULES, + "engine_timeout": engine_timeout, } super().__init__( diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index 602e844e6..707c2ef9a 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -1,7 +1,7 @@ import asyncio import logging import sys -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from hamilton import base from hamilton.async_driver import AsyncDriver @@ -82,9 +82,14 @@ async def generate_sql_corrections(prompts: list[dict], generator: Any) -> list[ async def post_process( generate_sql_corrections: list[dict], post_processor: SQLGenPostProcessor, + engine_timeout: float, project_id: str | None = None, ) -> list[dict]: - return await post_processor.run(generate_sql_corrections, project_id=project_id) + return await post_processor.run( + generate_sql_corrections, + timeout=engine_timeout, + project_id=project_id, + ) ## End of Pipeline @@ -106,6 +111,7 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, + engine_timeout: Optional[float] = 30.0, **kwargs, ): self._components = { @@ -119,6 +125,10 @@ def __init__( "post_processor": SQLGenPostProcessor(engine=engine), } + self._configs = { + "engine_timeout": engine_timeout, + } + super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -138,6 +148,7 @@ async def run( "documents": contexts, "project_id": project_id, **self._components, + **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/sql_expansion.py b/wren-ai-service/src/pipelines/generation/sql_expansion.py index 796605623..417de42bd 100644 --- a/wren-ai-service/src/pipelines/generation/sql_expansion.py +++ b/wren-ai-service/src/pipelines/generation/sql_expansion.py @@ -1,6 +1,6 @@ import logging import sys -from typing import Any, List +from typing import Any, List, Optional from hamilton import base from hamilton.async_driver import AsyncDriver @@ -75,10 +75,13 @@ async def generate_sql_expansion(prompt: dict, generator: Any) -> dict: async def post_process( generate_sql_expansion: dict, post_processor: SQLGenPostProcessor, + engine_timeout: float, project_id: str | None = None, ) -> dict: return await post_processor.run( - generate_sql_expansion.get("replies"), project_id=project_id + generate_sql_expansion.get("replies"), + timeout=engine_timeout, + project_id=project_id, ) @@ -105,6 +108,7 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, + engine_timeout: Optional[float] = 30.0, **kwargs, ): self._components = { @@ -118,6 +122,10 @@ def __init__( "post_processor": SQLGenPostProcessor(engine=engine), } + self._configs = { + "engine_timeout": engine_timeout, + } + super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -141,6 +149,7 @@ async def run( "project_id": project_id, "configuration": configuration, **self._components, + **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 5fb55fd09..b3c6bb0a4 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -1,6 +1,6 @@ import logging import sys -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from hamilton import base from hamilton.async_driver import AsyncDriver @@ -94,9 +94,14 @@ async def generate_sql( async def post_process( generate_sql: dict, post_processor: SQLGenPostProcessor, + engine_timeout: float, project_id: str | None = None, ) -> dict: - return await post_processor.run(generate_sql.get("replies"), project_id=project_id) + return await post_processor.run( + generate_sql.get("replies"), + timeout=engine_timeout, + project_id=project_id, + ) ## End of Pipeline @@ -118,6 +123,7 @@ def __init__( self, llm_provider: LLMProvider, engine: Engine, + engine_timeout: Optional[float] = 30.0, **kwargs, ): self._components = { @@ -131,6 +137,10 @@ def __init__( "post_processor": SQLGenPostProcessor(engine=engine), } + self._configs = { + "engine_timeout": engine_timeout, + } + super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -160,6 +170,7 @@ async def run( "has_calculated_field": has_calculated_field, "has_metric": has_metric, **self._components, + **self._configs, }, ) diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index fb1bc60e8..f052b3d80 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -29,6 +29,7 @@ async def run( self, replies: List[str], project_id: str | None = None, + timeout: Optional[float] = 30.0, ) -> Dict[str, Any]: cleaned_generation_result = orjson.loads(clean_generation_result(replies[0])) @@ -56,7 +57,11 @@ async def run( sql = self._build_cte_query(steps) - if not await self._check_if_sql_executable(sql, project_id=project_id): + if not await self._check_if_sql_executable( + sql, + project_id=project_id, + timeout=timeout, + ): return { "results": { "description": cleaned_generation_result["description"], @@ -84,12 +89,14 @@ async def _check_if_sql_executable( self, sql: str, project_id: str | None = None, + timeout: Optional[float] = 30.0, ): async with aiohttp.ClientSession() as session: status, _, addition = await self._engine.execute_sql( sql, session, project_id=project_id, + timeout=timeout, ) if not status: @@ -112,6 +119,7 @@ def __init__(self, engine: Engine): async def run( self, replies: List[str] | List[List[str]], + timeout: Optional[float] = 30.0, project_id: str | None = None, ) -> dict: try: @@ -138,7 +146,9 @@ async def run( valid_generation_results, invalid_generation_results, ) = await self._classify_invalid_generation_results( - cleaned_generation_result, project_id=project_id + cleaned_generation_result, + project_id=project_id, + timeout=timeout, ) return { @@ -154,7 +164,10 @@ async def run( } async def _classify_invalid_generation_results( - self, generation_results: list[str], project_id: str | None = None + self, + generation_results: list[str], + timeout: float, + project_id: str | None = None, ) -> List[Optional[Dict[str, str]]]: valid_generation_results = [] invalid_generation_results = [] @@ -164,7 +177,7 @@ async def _task(sql: str): if not error_message: status, _, addition = await self._engine.execute_sql( - quoted_sql, session, project_id=project_id + quoted_sql, session, project_id=project_id, timeout=timeout ) if status: diff --git a/wren-ai-service/src/pipelines/retrieval/sql_executor.py b/wren-ai-service/src/pipelines/retrieval/sql_executor.py index e0a721f84..d34bff8cd 100644 --- a/wren-ai-service/src/pipelines/retrieval/sql_executor.py +++ b/wren-ai-service/src/pipelines/retrieval/sql_executor.py @@ -27,6 +27,7 @@ async def run( sql: str, project_id: str | None = None, limit: int = 500, + timeout: float = 30.0, ): async with aiohttp.ClientSession() as session: _, data, _ = await self._engine.execute_sql( @@ -35,6 +36,7 @@ async def run( project_id=project_id, dry_run=False, limit=limit, + timeout=timeout, ) return {"results": data} @@ -45,10 +47,16 @@ async def run( async def execute_sql( sql: str, data_fetcher: DataFetcher, + engine_timeout: float, project_id: str | None = None, limit: int = 500, ) -> dict: - return await data_fetcher.run(sql=sql, project_id=project_id, limit=limit) + return await data_fetcher.run( + sql=sql, + project_id=project_id, + limit=limit, + timeout=engine_timeout, + ) ## End of Pipeline @@ -58,12 +66,17 @@ class SQLExecutor(BasicPipeline): def __init__( self, engine: Engine, + engine_timeout: Optional[float] = 30.0, **kwargs, ): self._components = { "data_fetcher": DataFetcher(engine=engine), } + self._configs = { + "engine_timeout": engine_timeout, + } + super().__init__( AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) ) @@ -80,6 +93,7 @@ async def run( "project_id": project_id, "limit": limit, **self._components, + **self._configs, }, ) diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index 093a80e6d..6f8b58fbb 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -97,7 +97,6 @@ pipes: llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_answer llm: litellm_llm.gpt-4o-mini-2024-07-18 - engine: wren_ui - name: sql_breakdown llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui @@ -156,6 +155,7 @@ pipes: settings: host: 127.0.0.1 port: 5556 + engine_timeout: 30 column_indexing_batch_size: 50 table_retrieval_size: 10 table_column_retrieval_size: 100 diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index 962c1153b..e2180886e 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -97,7 +97,6 @@ pipes: llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_answer llm: litellm_llm.gpt-4o-mini-2024-07-18 - engine: wren_ui - name: sql_breakdown llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui @@ -156,6 +155,7 @@ pipes: settings: host: 127.0.0.1 port: 5556 + engine_timeout: 30 column_indexing_batch_size: 50 table_retrieval_size: 10 table_column_retrieval_size: 100 From 66fdb9d73bb260dae597b7b4844e41e7ad6713d0 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 5 Feb 2025 14:11:16 +0800 Subject: [PATCH 3/4] skip sql correction if it's timeout --- .../src/pipelines/generation/utils/sql.py | 7 ++- wren-ai-service/src/web/v1/services/ask.py | 53 ++++++++++--------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index f052b3d80..535872498 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -188,11 +188,14 @@ async def _task(sql: str): } ) else: + error_message = addition.get("error_message", "") invalid_generation_results.append( { "sql": quoted_sql, - "type": "DRY_RUN", - "error": addition.get("error_message", ""), + "type": "TIME_OUT" + if error_message.startswith("Request timed out") + else "DRY_RUN", + "error": error_message, "correlation_id": addition.get("correlation_id", ""), } ) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index d7a51d05d..155cb67cf 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -361,32 +361,35 @@ async def ask( elif failed_dry_run_results := text_to_sql_generation_results[ "post_process" ]["invalid_generation_results"]: - self._ask_results[query_id] = AskResultResponse( - status="correcting", - ) - sql_correction_results = await self._pipelines[ - "sql_correction" - ].run( - contexts=documents, - invalid_generation_results=failed_dry_run_results, - project_id=ask_request.project_id, - ) + if failed_dry_run_results[0]["type"] != "TIME_OUT": + self._ask_results[query_id] = AskResultResponse( + status="correcting", + ) + sql_correction_results = await self._pipelines[ + "sql_correction" + ].run( + contexts=documents, + invalid_generation_results=failed_dry_run_results, + project_id=ask_request.project_id, + ) - if valid_generation_results := sql_correction_results[ - "post_process" - ]["valid_generation_results"]: - api_results = [ - AskResult( - **{ - "sql": valid_generation_result.get("sql"), - "type": "llm", - } - ) - for valid_generation_result in valid_generation_results - ][:1] - elif failed_dry_run_results := sql_correction_results[ - "post_process" - ]["invalid_generation_results"]: + if valid_generation_results := sql_correction_results[ + "post_process" + ]["valid_generation_results"]: + api_results = [ + AskResult( + **{ + "sql": valid_generation_result.get("sql"), + "type": "llm", + } + ) + for valid_generation_result in valid_generation_results + ][:1] + elif failed_dry_run_results := sql_correction_results[ + "post_process" + ]["invalid_generation_results"]: + error_message = failed_dry_run_results[0]["error"] + else: error_message = failed_dry_run_results[0]["error"] if api_results: From 5d32919220a2ddb2179aef5999ce93a19ad7cf7f Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 5 Feb 2025 14:18:15 +0800 Subject: [PATCH 4/4] skip sql correction if it's timeout --- .../src/web/v1/services/sql_expansion.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/wren-ai-service/src/web/v1/services/sql_expansion.py b/wren-ai-service/src/web/v1/services/sql_expansion.py index 43f6cd72e..f481559df 100644 --- a/wren-ai-service/src/web/v1/services/sql_expansion.py +++ b/wren-ai-service/src/web/v1/services/sql_expansion.py @@ -170,20 +170,23 @@ async def sql_expansion( if failed_dry_run_results := sql_expansion_generation_results[ "post_process" ]["invalid_generation_results"]: - sql_correction_results = await self._pipelines[ - "sql_correction" - ].run( - contexts=documents, - invalid_generation_results=failed_dry_run_results, - project_id=sql_expansion_request.project_id, - ) - if sql_correction_valid_results := sql_correction_results[ - "post_process" - ]["valid_generation_results"]: - valid_generation_results += sql_correction_valid_results - elif failed_dry_run_results := sql_correction_results[ - "post_process" - ]["invalid_generation_results"]: + if failed_dry_run_results[0]["type"] != "TIME_OUT": + sql_correction_results = await self._pipelines[ + "sql_correction" + ].run( + contexts=documents, + invalid_generation_results=failed_dry_run_results, + project_id=sql_expansion_request.project_id, + ) + if sql_correction_valid_results := sql_correction_results[ + "post_process" + ]["valid_generation_results"]: + valid_generation_results += sql_correction_valid_results + elif failed_dry_run_results := sql_correction_results[ + "post_process" + ]["invalid_generation_results"]: + error_message = failed_dry_run_results[0]["error"] + else: error_message = failed_dry_run_results[0]["error"] valid_sql_summary_results = []