
feat(wren-ai-service): Add LLM-based evaluation metrics for SQL generation #1303

Merged — 11 commits, Feb 21, 2025
27 changes: 23 additions & 4 deletions wren-ai-service/eval/evaluation.py
@@ -15,11 +15,26 @@
import eval.pipelines as pipelines
import src.providers as provider
from eval import EvalSettings
from eval.utils import engine_config, parse_toml, trace_metadata
from eval.utils import parse_toml, trace_metadata
from src import utils


def formatter(prediction: dict, meta: dict) -> dict:
"""
Formats the prediction result to be used as evaluation input.

This function takes a prediction dictionary and a meta dictionary,
processes them to extract relevant information, and returns a formatted
dictionary that serves as input for evaluation. It includes details such
as input, actual and expected outputs, context, and additional metadata.

Args:
prediction (dict): A dictionary containing prediction details.
meta (dict): A dictionary containing metadata information.

Returns:
dict: A formatted dictionary containing evaluation input data.
"""
retrieval_context = [str(context) for context in prediction["retrieval_context"]]
context = [str(context) for context in prediction["context"]]
enable_spider_metrics = "spider" in meta.get("evaluation_dataset", "").lower()
@@ -33,6 +48,7 @@ def formatter(prediction: dict, meta: dict) -> dict:
"expected_output": prediction["expected_output"],
"retrieval_context": retrieval_context,
"context": context,
"reasoning": prediction.get("reasoning", ""),
"additional_metadata": {
"trace_id": prediction["trace_id"],
"trace_url": prediction["trace_url"],
@@ -82,7 +98,9 @@ def eval(self, meta: dict, predictions: list) -> None:

try:
test_case = LLMTestCase(**formatter(prediction, meta))
result = evaluate([test_case], self._metrics, ignore_errors=True).test_results[0]
result = evaluate(
[test_case], self._metrics, ignore_errors=True
).test_results[0]
self._score_metrics(test_case, result)
[metric.collect(test_case, result) for metric in self._post_metrics]
except Exception:
@@ -152,8 +170,9 @@ def _average_score(self, meta: dict) -> None:
predictions = predicted_file["predictions"]

dataset = parse_toml(meta["evaluation_dataset"])
engine_info = engine_config(dataset["mdl"], pipe_components)
metrics = pipelines.metrics_initiator(meta["pipeline"], engine_info, args.semantics)
metrics = pipelines.metrics_initiator(
meta["pipeline"], dataset, pipe_components, args.semantics
)

evaluator = Evaluator(**metrics)
evaluator.eval(meta, predictions)
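
For orientation, here is a minimal sketch of the data the updated formatter() consumes. The field names follow the hunks shown above; the concrete prediction/meta values are hypothetical, and hunks collapsed in this view may require additional keys.

# Hypothetical inputs (illustration only; values are made up).
prediction = {
    "trace_id": "trace-123",
    "trace_url": "https://cloud.langfuse.com/trace/trace-123",
    "input": "How many customers signed up in 2024?",
    "actual_output": "SELECT COUNT(*) FROM customers WHERE signup_year = 2024",
    "expected_output": "SELECT COUNT(*) FROM customers WHERE signup_year = 2024",
    "reasoning": "Count the rows in customers restricted to signup_year 2024.",
    "retrieval_context": ["customers(id, name, signup_year)"],
    "context": ["customers"],
}
meta = {"evaluation_dataset": "spider_car_1_eval_dataset.toml", "pipeline": "ask"}

# "reasoning" is forwarded with a default of "", so prediction files produced
# before this change still evaluate without modification.
test_case = LLMTestCase(**formatter(prediction, meta))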
8 changes: 8 additions & 0 deletions wren-ai-service/eval/metrics/__init__.py
@@ -4,6 +4,11 @@
from .context_recall import ContextualRecallMetric
from .context_relevancy import ContextualRelevancyMetric
from .faithfulness import FaithfulnessMetric
from .llm import (
QuestionToReasoningJudge,
ReasoningToSqlJudge,
SqlSemanticsJudge,
)
from .spider.exact_match import ExactMatchAccuracy
from .spider.exec_match import ExecutionAccuracy

@@ -17,4 +22,7 @@
"FaithfulnessMetric",
"ExactMatchAccuracy",
"ExecutionAccuracy",
"QuestionToReasoningJudge",
"ReasoningToSqlJudge",
"SqlSemanticsJudge",
]
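
With the re-exports above, the new judges are importable from the package root alongside the existing metrics:

from eval.metrics import (
    QuestionToReasoningJudge,
    ReasoningToSqlJudge,
    SqlSemanticsJudge,
)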
173 changes: 173 additions & 0 deletions wren-ai-service/eval/metrics/llm/__init__.py
@@ -0,0 +1,173 @@
import asyncio

from deepeval.metrics import BaseMetric
from deepeval.test_case import LLMTestCase
from haystack.components.builders.prompt_builder import PromptBuilder
from pydantic import BaseModel

from src.providers import LLMProvider


class EvalResult(BaseModel):
score: float
reason: str


_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "eval_result",
"schema": EvalResult.model_json_schema(),
},
}
}


def format(response: dict) -> EvalResult:
reply = response.get("replies", [])[0]
return EvalResult.model_validate_json(reply)


class QuestionToReasoningJudge(BaseMetric):
_system_prompt = """
You are an expert evaluator. Your task is to analyze the reasoning provided for a given question and determine if it makes sense.
Provide a score in the range 0.0~1.0 and a detailed explanation for your evaluation.
"""
_test_case_prompt = """
Question:
{{ question }}

Reasoning:
{{ reasoning }}
"""

def __init__(self, llm_provider: LLMProvider, **_):
self.threshold = 0
self.score = 0
self.llm_provider = llm_provider
self.llm = llm_provider.get_generator(
system_prompt=self._system_prompt,
generation_kwargs=_MODEL_KWARGS,
)
self.prompt_builder = PromptBuilder(template=self._test_case_prompt)

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
prompt = self.prompt_builder.run(
question=test_case.input,
reasoning=test_case.reasoning,
)
response = await self.llm(prompt.get("prompt"))
result = format(response)

self.score = result.score
self.reason = result.reason

self.success = self.score >= self.threshold
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "QuestionToReasoningJudge"


class ReasoningToSqlJudge(BaseMetric):
_system_prompt = """
You are an expert evaluator. Your task is to analyze the reasoning provided for a given SQL query and determine if it makes sense.
Provide a score in the range 0.0~1.0 and a detailed explanation for your evaluation.
"""
_test_case_prompt = """
Actual Output:
{{ actual_output }}

Reasoning:
{{ reasoning }}
"""

def __init__(self, llm_provider: LLMProvider, **_):
self.threshold = 0
self.score = 0
self.llm_provider = llm_provider
self.llm = llm_provider.get_generator(
system_prompt=self._system_prompt,
generation_kwargs=_MODEL_KWARGS,
)
self.prompt_builder = PromptBuilder(template=self._test_case_prompt)

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
prompt = self.prompt_builder.run(
actual_output=test_case.actual_output,
reasoning=test_case.reasoning,
)
response = await self.llm(prompt.get("prompt"))
result = format(response)

self.score = result.score
self.reason = result.reason

self.success = self.score >= self.threshold
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "ReasoningToSqlJudge"


class SqlSemanticsJudge(BaseMetric):
_system_prompt = """
You are an expert evaluator. Your task is to analyze the actual SQL query and the expected SQL query and determine if they are semantically equivalent.
Provide a score in the range 0.0~1.0 and a detailed explanation for your evaluation.
"""
_test_case_prompt = """
Actual SQL:
{{ actual_sql }}

Expected SQL:
{{ expected_sql }}
"""

def __init__(self, llm_provider: LLMProvider, **_):
self.threshold = 0
self.score = 0
self.llm_provider = llm_provider
self.llm = llm_provider.get_generator(
system_prompt=self._system_prompt,
generation_kwargs=_MODEL_KWARGS,
)
self.prompt_builder = PromptBuilder(template=self._test_case_prompt)

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
prompt = self.prompt_builder.run(
actual_sql=test_case.actual_output,
expected_sql=test_case.expected_output,
)
response = await self.llm(prompt.get("prompt"))
result = format(response)

self.score = result.score
self.reason = result.reason

self.success = self.score >= self.threshold
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "SqlSemanticsJudge"
36 changes: 31 additions & 5 deletions wren-ai-service/eval/pipelines.py
@@ -11,6 +11,8 @@
from langfuse.decorators import langfuse_context, observe
from tqdm.asyncio import tqdm_asyncio

from src.core.pipeline import PipelineComponent

sys.path.append(f"{Path().parent.resolve()}")

from eval import EvalSettings
@@ -23,6 +25,9 @@
ExactMatchAccuracy,
ExecutionAccuracy,
FaithfulnessMetric,
QuestionToReasoningJudge,
ReasoningToSqlJudge,
SqlSemanticsJudge,
)
from eval.utils import (
engine_config,
@@ -290,7 +295,11 @@ async def __call__(self, query: str, **_):
]

@staticmethod
def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
def metrics(
engine_info: dict,
enable_semantics_comparison: bool,
component: PipelineComponent,
) -> dict:
return {
"metrics": [
AccuracyMetric(
@@ -302,6 +311,9 @@ def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
# this is for spider dataset, rn we temporarily disable it
ExactMatchAccuracy(),
ExecutionAccuracy(),
QuestionToReasoningJudge(**component),
ReasoningToSqlJudge(**component),
SqlSemanticsJudge(**component),
],
"post_metrics": [],
}
@@ -402,7 +414,11 @@ async def __call__(self, query: str, **_):
]

@staticmethod
def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
def metrics(
engine_info: dict,
enable_semantics_comparison: bool,
component: PipelineComponent,
) -> dict:
return {
"metrics": [
AccuracyMetric(
@@ -417,6 +433,9 @@ def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
# this is for spider dataset, rn we temporarily disable it
ExactMatchAccuracy(),
ExecutionAccuracy(),
QuestionToReasoningJudge(**component),
ReasoningToSqlJudge(**component),
SqlSemanticsJudge(**component),
],
"post_metrics": [],
}
@@ -449,13 +468,20 @@ def init(

def metrics_initiator(
pipeline: str,
engine_info: dict,
dataset: dict,
pipe_components: dict[str, PipelineComponent],
enable_semantics_comparison: bool = True,
) -> dict:
engine_info = engine_config(dataset["mdl"], pipe_components)
component = pipe_components["evaluation"]
match pipeline:
case "retrieval":
return RetrievalPipeline.metrics(engine_info)
case "generation":
return GenerationPipeline.metrics(engine_info, enable_semantics_comparison)
return GenerationPipeline.metrics(
engine_info, enable_semantics_comparison, component
)
case "ask":
return AskPipeline.metrics(engine_info, enable_semantics_comparison)
return AskPipeline.metrics(
engine_info, enable_semantics_comparison, component
)
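
The call-site change in evaluation.py earlier in this diff mirrors this new signature: metrics_initiator now builds engine_info internally from the dataset's mdl and pulls the judge LLM from the pipe_components entry named "evaluation" (the pipe added to the config files below). A sketch of the updated call, with the surrounding harness variables assumed:

dataset = parse_toml(meta["evaluation_dataset"])      # TOML dict; must contain "mdl"
metrics = pipelines.metrics_initiator(
    pipeline=meta["pipeline"],                        # "retrieval" | "generation" | "ask"
    dataset=dataset,
    pipe_components=pipe_components,                  # must include an "evaluation" component
    enable_semantics_comparison=args.semantics,
)

evaluator = Evaluator(**metrics)
evaluator.eval(meta, predictions)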
2 changes: 2 additions & 0 deletions wren-ai-service/tools/config/config.example.yaml
@@ -145,6 +145,8 @@ pipes:
- name: sql_regeneration
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
- name: evaluation
llm: litellm_llm.gpt-4o-mini-2024-07-18

---
settings:
2 changes: 2 additions & 0 deletions wren-ai-service/tools/config/config.full.yaml
@@ -145,6 +145,8 @@ pipes:
- name: sql_regeneration
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
- name: evaluation
llm: litellm_llm.gpt-4o-mini-2024-07-18

---
settings: