From 3e80570713292527a68dc118d85581cb9ab97b3e Mon Sep 17 00:00:00 2001 From: cjopengler Date: Tue, 26 Dec 2023 10:25:39 +0800 Subject: [PATCH 1/2] add gbi code and doc --- appbuilder/__init__.py | 3 + appbuilder/core/components/gbi/__init__.py | 13 + appbuilder/core/components/gbi/basic.py | 100 ++++ .../core/components/gbi/nl2sql/README.md | 279 +++++++++ .../core/components/gbi/nl2sql/__init__.py | 13 + .../core/components/gbi/nl2sql/component.py | 153 +++++ .../components/gbi/select_table/README.md | 120 ++++ .../components/gbi/select_table/__init__.py | 13 + .../components/gbi/select_table/component.py | 140 +++++ appbuilder/tests/test_gbi_nl2sql.py | 167 ++++++ appbuilder/tests/test_gbi_select_table.py | 111 ++++ cookbooks/gbi.ipynb | 550 ++++++++++++++++++ 12 files changed, 1662 insertions(+) create mode 100644 appbuilder/core/components/gbi/__init__.py create mode 100644 appbuilder/core/components/gbi/basic.py create mode 100644 appbuilder/core/components/gbi/nl2sql/README.md create mode 100644 appbuilder/core/components/gbi/nl2sql/__init__.py create mode 100644 appbuilder/core/components/gbi/nl2sql/component.py create mode 100644 appbuilder/core/components/gbi/select_table/README.md create mode 100644 appbuilder/core/components/gbi/select_table/__init__.py create mode 100644 appbuilder/core/components/gbi/select_table/component.py create mode 100644 appbuilder/tests/test_gbi_nl2sql.py create mode 100644 appbuilder/tests/test_gbi_select_table.py create mode 100644 cookbooks/gbi.ipynb diff --git a/appbuilder/__init__.py b/appbuilder/__init__.py index 5db1cc3eb..ab40f1c92 100644 --- a/appbuilder/__init__.py +++ b/appbuilder/__init__.py @@ -67,6 +67,9 @@ def check_version(self): from .core.components.embeddings import Embedding from .core.components.matching import Matching +from .core.components.gbi.nl2sql.component import GBINL2Sql +from .core.components.gbi.select_table.component import GBISelectTable + from appbuilder.core.message import Message from appbuilder.core.agent import AgentBase from appbuilder.core.context import UserSession diff --git a/appbuilder/core/components/gbi/__init__.py b/appbuilder/core/components/gbi/__init__.py new file mode 100644 index 000000000..c33303636 --- /dev/null +++ b/appbuilder/core/components/gbi/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/appbuilder/core/components/gbi/basic.py b/appbuilder/core/components/gbi/basic.py new file mode 100644 index 000000000..c51e716a0 --- /dev/null +++ b/appbuilder/core/components/gbi/basic.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python 3 +# -*- coding: utf-8 -*- +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""GBI nl2sql component. +""" + +from typing import Dict, List + + +class NL2SqlResult(object): + """ + gbi_nl2sql 返回的结果 + """ + + def __init__(self, llm_result: str, sql: str): + """ + 初始化 + Args: + llm_result: 大模型返回的结果 + sql: 从 llm_result 中抽取的 sql 语句 + """ + self.llm_result = llm_result + self.sql = sql + + def to_json(self) -> Dict: + """ + 转换成 字典 + Returns: + + """ + return self.__dict__ + + +class GBISessionRecord(object): + """ + gbi session record + """ + + def __init__(self, query: str, answer: NL2SqlResult): + """ + GBI Session 的记录 + Args: + query: 用户的问题 + answer: gbi_nl2sql 返回的结果 + """ + self.query = query + self.answer = answer + + def to_json(self) -> Dict: + return {"query": self.query, + "answer": self.answer.to_json()} + + +class ColumnItem(object): + """ + column item + """ + + def __init__(self, ori_value: str, column_name: str, column_value: str, table_name: str, + is_like: bool = False): + """ + 用于标识 query 中的词 应该对应到数据库中的某个列值以及列名,用于提升 sql 生成效果 + Args: + ori_value: query 中的 词语, 比如: "北京去年收入", 分词后: "北京, 去年, 收入", ori_value 是分词中某一个,比如: ori_value = "北京" + column_name: 对应数据库中的列名称, city + column_value: 对应数据库中的列值, 北京市 + table_name: 该列所属的表名称 + is_like: 与 ori_value 的匹配是包含 还是 等于,包含: True; 等于: False + """ + self.column_name = column_name + self.column_value = column_value + self.ori_value = ori_value + self.table_name = table_name + self.is_like = is_like + + def to_json(self) -> Dict: + """ + 转换成 json + Returns: + + """ + return self.__dict__ + + +SUPPORTED_MODEL_NAME = { + "ERNIE-Bot 4.0", "ERNIE-Bot-8K", "ERNIE-Bot", "ERNIE-Bot-turbo", "EB-turbo-AppBuilder" +} diff --git a/appbuilder/core/components/gbi/nl2sql/README.md b/appbuilder/core/components/gbi/nl2sql/README.md new file mode 100644 index 000000000..d605100d2 --- /dev/null +++ b/appbuilder/core/components/gbi/nl2sql/README.md @@ -0,0 +1,279 @@ +# GBI 问表 + +## 简介 +GBI 问表,根据提供的表的 schema 信息,生成对应问题的 sql 语句。 + +## 基本用法 + +### 快速开启 + + +```` +import logging +import os +import appbuilder +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import GBISessionRecord + +# 设置环境变量 +os.environ["APPBUILDER_TOKEN"] = "***" + +SUPER_MARKET_SCHEMA = """ +``` +CREATE TABLE `超市营收明细表` ( + `订单编号` varchar(32) DEFAULT NULL, + `订单日期` date DEFAULT NULL, + `邮寄方式` varchar(32) DEFAULT NULL, + `地区` varchar(32) DEFAULT NULL, + `省份` varchar(32) DEFAULT NULL, + `客户类型` varchar(32) DEFAULT NULL, + `客户名称` varchar(32) DEFAULT NULL, + `商品类别` varchar(32) DEFAULT NULL, + `制造商` varchar(32) DEFAULT NULL, + `商品名称` varchar(32) DEFAULT NULL, + `数量` int(11) DEFAULT NULL, + `销售额` int(11) DEFAULT NULL, + `利润` int(11) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +``` +""" + +table_schemas = [SUPER_MARKET_SCHEMA] +query = "列出超市中的所有数据" +msg = Message(query) +session = list() +gbi_nl2sql = appbuilder.GBINL2Sql(model_name="ERNIE-Bot 4.0", table_schemas=table_schemas) +nl2sql_result_message = gbi_nl2sql(message=msg, session=session) +print(f"sql: {nl2sql_result_message.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message.content.llm_result}") +```` + + sql: + SELECT * FROM `超市营收明细表`; + ----------------- + llm result: ```sql + SELECT * FROM `超市营收明细表`; + ``` + + +## 参数说明 + +### 初始化参数 +- model_name: 支持的模型名字 ERNIE-Bot 4.0, ERNIE-Bot-8K, ERNIE-Bot, ERNIE-Bot-turbo, EB-turbo-AppBuilder" +- table_schemas: 表的 schema,例如: + +``` +CREATE TABLE `超市营收明细表` ( + `订单编号` varchar(32) DEFAULT NULL, + `订单日期` date DEFAULT NULL, + `邮寄方式` varchar(32) DEFAULT NULL, + `地区` varchar(32) DEFAULT NULL, + `省份` varchar(32) DEFAULT NULL, + `客户类型` varchar(32) DEFAULT NULL, + `客户名称` varchar(32) DEFAULT NULL, + `商品类别` varchar(32) DEFAULT NULL, + `制造商` varchar(32) DEFAULT NULL, + `商品名称` varchar(32) DEFAULT NULL, + `数量` int(11) DEFAULT NULL, + `销售额` int(11) DEFAULT NULL, + `利润` int(11) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +``` + +- knowledge: 用于提供一些知识, 比如 {"毛利率": "毛收入-毛成本/毛成本"} +- prompt_template: prompt 模版, 必须包含的格式如下: + + ***你的描述 + {schema} + ***你的描述 + {column_instrument} + ***你的描述 + {kg} + ***你的描述 + 当前时间:{date} + ***你的描述 + {history_instrument} + ***你的描述 + 当前问题:{query} + 回答: + +### 调用参数 +- message: message.content 是 query +- session: gbi session 的历史 列表, 参考 GBISessionRecord +- column_constraint: 列选约束 参考 ColumnItem 具体定义 + +#### GBISessionRecord 初始化参数 +- query: 用户的问题 +- answer: gbi_nl2sql 返回的结果 NL2SqlResult + +#### ColumnItem 初始化参数如下 +- ori_value: query 中的 词语, 比如: "北京去年收入", 分词后: "北京, 去年, 收入", ori_value 是分词中某一个,比如: ori_value = "北京" +- column_name: 对应数据库中的列名称, city +- column_value: 对应数据库中的列值, 北京市 +- table_name: 该列所属的表名称 +- is_like: 与 ori_value 的匹配是包含 还是 等于,包含: True; 等于: False + +### 返回值 +- NL2SqlResult 的 message + +#### NL2SqlResult 初始化参数如下 +- llm_result: 大模型返回的结果 +- sql: 从 llm_result 中抽取的 sql 语句 + +## 高级用法 +### 设置 session + + +```python +session.append(GBISessionRecord(query=query, answer=nl2sql_result_message.content)) +``` + +再次问表 + + +```python +query2 = "查看商品类别是水果的所有数据" +msg2 = Message(query2) +nl2sql_result_message2 = gbi_nl2sql(message=msg2, session=session) +print(f"sql: {nl2sql_result_message2.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message2.content.llm_result}") +``` + + sql: + SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'; + ----------------- + llm result: ```sql + SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'; + ``` + + +### 增加列选优化 +实际上数据中 "商品类别" 存储的是 "新鲜水果", 那么就可以通过列选的限制来优化 sql. + + +```python +from appbuilder.core.components.gbi.basic import ColumnItem + +query2 = "查看商品类别是水果的所有数据" +msg2 = Message(query2) + +column_constraint = [ColumnItem(ori_value="水果", + column_name="商品类别", + column_value="新鲜水果", + table_name="超市营收明细表", + is_like=False)] +nl2sql_result_message2 = gbi_nl2sql(message=msg2, session=session, column_constraint=column_constraint) +print(f"sql: {nl2sql_result_message2.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message2.content.llm_result}") +``` + + sql: + SELECT * FROM `超市营收明细表` WHERE `商品类别` = '新鲜水果' + ----------------- + llm result: ```sql + SELECT * FROM `超市营收明细表` WHERE `商品类别` = '新鲜水果' + ``` + + 这个查询会返回`超市营收明细表`中所有商品类别为"新鲜水果"的数据。因为问题中没有涉及到其他特定的条件或聚合操作,所以这是一个简单的筛选查询。 + + +从上面我们看到,商品类别不在是 "水果" 而是 修订为 "新鲜水果" + +### 增加知识优化 +当计算某些特殊知识的时候,大模型是不知道的,所以需要告诉大模型具体的知识,比如: +利润率的计算方式: 利润/销售额 +可以将该知识注入。具体示例如下: + + +```python +# 注入知识 +gbi_nl2sql.knowledge["利润率"] = "计算方式: 利润/销售额" +``` + + +```python +query3 = "列出商品类别是日用品的利润率" +msg3 = Message(query3) + +nl2sql_result_message3 = gbi_nl2sql(message=msg3, session=session, column_constraint=list()) +print(f"sql: {nl2sql_result_message3.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message3.content.llm_result}") +``` + + sql: + SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率 + FROM `超市营收明细表` + WHERE 商品类别 = '日用品' + GROUP BY 商品类别 + ----------------- + llm result: ```sql + SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率 + FROM `超市营收明细表` + WHERE 商品类别 = '日用品' + GROUP BY 商品类别 + ``` + + 思考步骤: + + 1. 首先,我们需要从`超市营收明细表`中选择数据。 + 2. 根据当前问题,我们关心的是商品类别为“日用品”的数据。 + 3. 利润率是利润除以销售额,所以我们需要对利润和销售额进行聚合。 + 4. 使用`SUM`函数来计算总的利润和销售额。 + 5. 使用`GROUP BY`语句按商品类别进行分组,以确保我们计算的是日用品的总利润和总销售额。 + 6. 最后,选择商品类别并计算利润率,即利润总和除以销售额总和。 + + +## 调整 prompt 模版 +有时候,我们希望定义自己的prompt, 但是必须遵循对应的 prompt 模版的格式。 + + +问表的 prompt template 必须包含: +1. {schema} - 表的 schema 信息 +2. {instrument} - 列选限制的信息 +3. {kg} - 知识 +4. {date} - 时间 +5. {history_prompt} - 历史 +6. {query} - 当前问题 + +参考下面的示例 + + +```python +NL2SQL_PROMPT_TEMPLATE = """ + MySql 表 Schema 如下: + {schema} + 请根据用户当前问题,联系历史信息,仅编写1个sql,其中 sql 语句需要使用```sql ```这种 markdown 形式给出。 + 请参考列选信息: + {instrument} + 请参考知识: + {kg} + 当前时间:{date} + 历史信息如下: + {history_prompt} + 当前问题:"{query}" + 回答: +""" +``` + + +```python + +msg5 = Message("查看商品类别是水果的所有数据") +gbi_nl2sql5 = appbuilder.GBINL2Sql(model_name="ERNIE-Bot 4.0", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE) +nl2sql_result_message5 = gbi_nl2sql5(message=msg5, session=session) +print(f"sql: {nl2sql_result_message5.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message5.content.llm_result}") +``` + + sql: + SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'; + ----------------- + llm result: ```sql + SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'; + ``` + diff --git a/appbuilder/core/components/gbi/nl2sql/__init__.py b/appbuilder/core/components/gbi/nl2sql/__init__.py new file mode 100644 index 000000000..c33303636 --- /dev/null +++ b/appbuilder/core/components/gbi/nl2sql/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/appbuilder/core/components/gbi/nl2sql/component.py b/appbuilder/core/components/gbi/nl2sql/component.py new file mode 100644 index 000000000..9ff7327c3 --- /dev/null +++ b/appbuilder/core/components/gbi/nl2sql/component.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""GBI nl2sql component. +""" +import uuid +import json +from typing import Dict, List, Optional +from pydantic import BaseModel, Field + +from appbuilder.core.component import Component +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import GBISessionRecord +from appbuilder.core.components.gbi.basic import ColumnItem +from appbuilder.core.components.gbi.basic import NL2SqlResult +from appbuilder.core.components.gbi.basic import SUPPORTED_MODEL_NAME + + +class GBINL2Sql(Component): + """ + gib nl2sql + """ + + def __init__(self, model_name: str, table_schemas: List[str], knowledge: Dict = None, + prompt_template: str = "", + secret_key: Optional[str] = None, + gateway: str = ""): + """ + 创建 gbi nl2sql 对象 + Args: + model_name: 支持的模型名字 ERNIE-Bot 4.0, ERNIE-Bot-8K, ERNIE-Bot, ERNIE-Bot-turbo, EB-turbo-AppBuilder + table_schemas: 表的 schema 列表,例如: ``` + CREATE TABLE `mytable` ( + `d_year` COMMENT '年度,2019,2020..2022..', + `industry` COMMENT '行业', + `project_name` COMMENT '项目名称', + `customer_name` COMMENT '客户名称') + ```" + knowledge: 用于提供一些知识, 比如 {"毛利率": "毛收入-毛成本/毛成本"} + prompt_template: prompt 模版, 必须包含的格式如下: + ***你的描述 + {schema} + ***你的描述 + {column_instrument} + ***你的描述 + {kg} + ***你的描述 + 当前时间:{date} + ***你的描述 + {history_instrument} + ***你的描述 + 当前问题:{query} + 回答: + secret_key: 用户创建的 key + gateway: gateway 地址 + """ + super().__init__(secret_key=secret_key, gateway=gateway) + + if model_name not in SUPPORTED_MODEL_NAME: + raise ValueError(f"model_name 错误, 请使用 {SUPPORTED_MODEL_NAME} 中的大模型") + self.model_name = model_name + self.server_sub_path = "/v1/ai_engine/gbi/v1/gbi_nl2sql" + self.table_schemas = table_schemas + self.knowledge = knowledge or dict() + self.prompt_template = prompt_template + + def run(self, + message: Message, + session: List[GBISessionRecord], + column_constraint: List[ColumnItem] = None) -> Message[NL2SqlResult]: + """ + 执行 nl2sql + Args: + message: message.content 是 query + session: gbi session 的历史 列表 + column_constraint: 列选约束 参考 ColumnItem 具体定义 + Returns: + NL2SqlResult 的 message + """ + + query = message.content + session = session + column_constraint = column_constraint or list() + + response = self._run_nl2sql(query=query, session=session, table_schemas=self.table_schemas, + column_constraint=column_constraint, knowledge=self.knowledge, + prompt_template=self.prompt_template, + model_name=self.model_name, + timeout=60, + retry=2) + + rsp_data = response.json() + nl2sql_result = NL2SqlResult(llm_result=rsp_data["llm_result"], + sql=rsp_data["sql"]) + return Message(content=nl2sql_result) + + def _run_nl2sql(self, query: str, session: List[GBISessionRecord], table_schemas: List[str], knowledge: Dict[str, str], + prompt_template: str, + column_constraint: List[ColumnItem], + model_name: str, + timeout: float = None, retry: int = 0): + """ + 运行 + Args: + query: query + session: gbi session 的历史 列表 + table_schemas: 表的 schema 列表 + knowledge: 知识 + prompt_template: prompt 模版 + column_constraint: 列的限制 + model_name: 模型名字 + timeout: 超时时间 + retry: + + Returns: + + """ + + headers = self.auth_header() + headers["Content-Type"] = "application/json" + + if retry != self.retry.total: + self.retry.total = retry + + payload = {"query": query, + "table_schemas": table_schemas, + "session": [session_record.to_json() for session_record in session], + "column_constraint": [column_item.to_json() for column_item in column_constraint], + "model_name": model_name, + "knowledge": knowledge, + "prompt_template": prompt_template} + + server_url = self.service_url(prefix="", sub_path=self.server_sub_path) + response = self.s.post(url=server_url, headers=headers, + json=payload, timeout=timeout) + super().check_response_header(response) + data = response.json() + super().check_response_json(data) + + request_id = self.response_request_id(response) + response.request_id = request_id + return response diff --git a/appbuilder/core/components/gbi/select_table/README.md b/appbuilder/core/components/gbi/select_table/README.md new file mode 100644 index 000000000..13b03c4e0 --- /dev/null +++ b/appbuilder/core/components/gbi/select_table/README.md @@ -0,0 +1,120 @@ +# GBI 选表 + +## 简介 +GBI 选表,根据提供的表的描述信息以及 query 选择对应的表. + +## 基本用法 + +### 快速开启 + + + +```python +import logging +import os +import appbuilder +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import GBISessionRecord + +# 设置环境变量 +os.environ["APPBUILDER_TOKEN"] = "***" + +# 表的描述信息 +table_descriptions = { + "超市营收明细表": "超市营收明细表,包含超市各种信息等", + "product_sales_info": "产品销售表" +} + + +# 生成问表对象 +select_table = appbuilder.GBISelectTable(model_name="ERNIE-Bot 4.0", table_descriptions=table_descriptions) +query = "列出超市中的所有数据" +msg = Message(query) +session = list() +select_table_result_message = select_table(message=msg, session=session) +print(f"选的表是: {select_table_result_message.content}") +``` + + 选的表是: ['超市营收明细表'] + + +## 参数说明 +### 初始化参数 +- model_name: 支持的模型名字 ERNIE-Bot 4.0, ERNIE-Bot-8K, ERNIE-Bot, ERNIE-Bot-turbo, EB-turbo-AppBuilder +- table_descriptions: 表的描述是个字典,key: 是表的名字, value: 是表的描述,例如: + +``` +{ + "超市营收明细表": "超市营收明细表,包含超市各种信息等", + "product_sales_info": "产品销售表" +} +``` +- prompt_template: prompt 模版, 必须包含如下: + 1. {num} - 表的数量, 注意 {num} 有两个地方出现 + 2. {table_desc} - 表的描述 + 3. {query} - query + 参考下面的示例: + +``` +你是一个专业的业务人员,下面有{num}张表,具体表名如下: +{table_desc} +请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表, +返回多张表请用“,”隔开 +返回格式请参考如下示例: +问题:有多少个审核通过的投运单? +回答: ```DWD_MAT_OPERATION``` +请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出 +问题:{query} +回答: +``` + +### 调用参数 +- message: message.content 是用户的问题,也就是 query +- session: GBISessionRecord 列表 + +#### GBISessionRecord 初始化参数 +- query: 用户的问题 +- answer: gbi_nl2sql 返回的结果 NL2SqlResult + +### 返回值 +识别的表名的列表例如: +`["table_name"]` + +## 调整 prompt 模版 +有时候,我们希望定义自己的prompt, 选表支持 prompt 模版的定制化,但是必须遵循对应的 prompt 模版的格式。 + +### 选表 prompt 调整 +选表的 prompt template, 必须包含 +1. {num} - 表的数量, 注意 {num} 有两个地方出现 +2. {table_desc} - 表的描述 +3. {query} - query, 参考下面的示例: + + +```python +SELECT_TABLE_PROMPT_TEMPLATE = """ +你是一个专业的业务人员,下面有{num}张表,具体表名如下: +{table_desc} +请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表, +返回多张表请用“,”隔开 +返回格式请参考如下示例: +问题:有多少个审核通过的投运单? +回答: ```DWD_MAT_OPERATION``` +请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出 +问题:{query} +回答: +""" +``` + + +```python +select_table4 = appbuilder.GBISelectTable(model_name="ERNIE-Bot 4.0", + table_descriptions=table_descriptions, + prompt_template=SELECT_TABLE_PROMPT_TEMPLATE) +query4 = "列出超市中的所有数据" +msg4 = Message(query4) +select_table_result_message4 = select_table4(message=msg4, session=list()) +print(f"选的表是: {select_table_result_message4.content}") +``` + + 选的表是: ['超市营收明细表'] + diff --git a/appbuilder/core/components/gbi/select_table/__init__.py b/appbuilder/core/components/gbi/select_table/__init__.py new file mode 100644 index 000000000..c33303636 --- /dev/null +++ b/appbuilder/core/components/gbi/select_table/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/appbuilder/core/components/gbi/select_table/component.py b/appbuilder/core/components/gbi/select_table/component.py new file mode 100644 index 000000000..21bda916e --- /dev/null +++ b/appbuilder/core/components/gbi/select_table/component.py @@ -0,0 +1,140 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""GBI nl2sql component. +""" +import uuid +import json +from typing import Dict, List, Optional +from pydantic import BaseModel, Field + +from appbuilder.core.component import Component +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import GBISessionRecord +from appbuilder.core.components.gbi.basic import SUPPORTED_MODEL_NAME + + +class GBISelectTable(Component): + """ + gbi 选表 + """ + + def __init__(self, model_name: str, table_descriptions: Dict[str, str], + prompt_template: str = "", + secret_key: Optional[str] = None, + gateway: str = ""): + """ + 创建 GBI 选表对象 + Args: + model_name: 支持的模型名字 ERNIE-Bot 4.0, ERNIE-Bot-8K, ERNIE-Bot, ERNIE-Bot-turbo, EB-turbo-AppBuilder + table_descriptions: 表的描述是个字典,key: 是表的名字, value: 是表的描述,例如: + { + "超市营收明细表": "超市营收明细表,包含超市各种信息等", + "product_sales_info": "产品销售表" + } + prompt_template: rompt 模版, 必须包含如下: + 1. {num} - 表的数量, 注意 {num} 有两个地方出现 + 2. {table_desc} - 表的描述 + 3. {query} - query + 参考下面的示例: + + ``` + 你是一个专业的业务人员,下面有{num}张表,具体表名如下: + {table_desc} + 请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表, + 返回多张表请用“,”隔开 + 返回格式请参考如下示例: + 问题:有多少个审核通过的投运单? + 回答: ```DWD_MAT_OPERATION``` + 请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出 + 问题:{query} + 回答: + ``` + secret_key: + gateway: + """ + super().__init__(secret_key=secret_key, gateway=gateway) + if model_name not in SUPPORTED_MODEL_NAME: + raise ValueError(f"model_name 错误, 请使用 {SUPPORTED_MODEL_NAME} 中的大模型") + self.model_name = model_name + self.server_sub_path = "/v1/ai_engine/gbi/v1/gbi_select_table" + self.table_descriptions = table_descriptions + self.prompt_template = prompt_template + + def run(self, + message: Message, + session: List[GBISessionRecord]) -> Message[List[str]]: + """ + Args: + message: message.content 是用户的问题,也就是 query + session: GBISessionRecord 列表 + + Returns: 识别的表名的列表 ["table_name"] + """ + + + query = message.content + session = session + + response = self._run_select_table(query=query, session=session, + prompt_template=self.prompt_template, + table_descriptions=self.table_descriptions, + model_name=self.model_name, + timeout=60, + retry=2) + + rsp_data = response.json() + + return Message(content=rsp_data) + + def _run_select_table(self, query: str, session: List[GBISessionRecord], + prompt_template, + table_descriptions: Dict[str, str], + model_name: str, + timeout: float = None, retry: int = 0): + """ + 使用给定的输入并返回语音识别的结果。 + + 参数: + request (obj:`ShortSpeechRecognitionRequest`): 输入请求,这是一个必需的参数。 + timeout (float, 可选): 请求的超时时间。 + retry (int, 可选): 请求的重试次数。 + + 返回: + obj:`ShortSpeechRecognitionResponse`: 接口返回的输出消息。 + """ + + headers = self.auth_header() + headers["Content_Type"] = "application/json" + + if retry != self.retry.total: + self.retry.total = retry + + payload = {"query": query, + "table_descriptions": table_descriptions, + "session": [session_record.to_json() for session_record in session], + "model_name": model_name, + "prompt_template": prompt_template} + + server_url = self.service_url(sub_path=self.server_sub_path) + response = self.s.post(url=server_url, headers=headers, + json=payload, timeout=timeout) + super().check_response_header(response) + data = response.json() + super().check_response_json(data) + + request_id = self.response_request_id(response) + response.request_id = request_id + return response + diff --git a/appbuilder/tests/test_gbi_nl2sql.py b/appbuilder/tests/test_gbi_nl2sql.py new file mode 100644 index 000000000..89aa17d12 --- /dev/null +++ b/appbuilder/tests/test_gbi_nl2sql.py @@ -0,0 +1,167 @@ +""" +Copyright (c) 2023 Baidu, Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import os +import appbuilder +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import NL2SqlResult, GBISessionRecord +from appbuilder.core.components.gbi.basic import ColumnItem + +SUPER_MARKET_SCHEMA = """ +``` +CREATE TABLE `超市营收明细` ( + `订单编号` varchar(32) DEFAULT NULL, + `订单日期` date DEFAULT NULL, + `邮寄方式` varchar(32) DEFAULT NULL, + `地区` varchar(32) DEFAULT NULL, + `省份` varchar(32) DEFAULT NULL, + `客户类型` varchar(32) DEFAULT NULL, + `客户名称` varchar(32) DEFAULT NULL, + `商品类别` varchar(32) DEFAULT NULL, + `制造商` varchar(32) DEFAULT NULL, + `商品名称` varchar(32) DEFAULT NULL, + `数量` int(11) DEFAULT NULL, + `销售额` int(11) DEFAULT NULL, + `利润` int(11) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +``` +""" + +PRODUCT_SALES_INFO = """ +现有 mysql 表 product_sales_info, +该表的用途是: 产品收入表 +``` +CREATE TABLE `product_sales_info` ( + `年` int, + `月` int, + `产品名称` varchar, + `收入` decimal, + `非交付成本` decimal, + `含交付毛利` decimal +) +``` +""" + +PROMPT_TEMPLATE = """ + MySql 表 Schema 如下: + {schema} + 请根据用户当前问题,联系历史信息,仅编写1个sql,其中 sql 语句需要使用```sql ```这种 markdown 形式给出。 + 请参考列选信息: + {instrument} + 请参考知识: + {kg} + 当前时间:{date} + 历史信息如下:{history_prompt} + 当前问题:"{query}" + 回答: +""" + + +class TestGBINL2Sql(unittest.TestCase): + + def setUp(self): + """ + 设置环境变量及必要数据。 + """ + model_name = "ERNIE-Bot 4.0" + table_schemas = [SUPER_MARKET_SCHEMA] + self.nl2sql_node = appbuilder.GBINL2Sql(model_name=model_name, + table_schemas=table_schemas) + + def test_run_with_default_param(self): + """测试 run 方法使用有效参数""" + query = "列出商品类别是水果的所有信息" + msg = Message(query) + session = list() + result_message = self.nl2sql_node(message=msg, session=session) + print(result_message.content.sql) + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + + def test_run_with_knowledge(self): + """测试 增加 knowledge 参数""" + + self.nl2sql_node.knowledge["利润率"] = "计算方式: 利润/销售额" + query = "列出商品类别是水果的的利润率" + + msg = Message(query) + session = list() + result_message = self.nl2sql_node(message=msg, session=session) + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + self.nl2sql_node.knowledge = dict() + + def test_run_with_column_constraint(self): + """测试 增加 column constraint 参数""" + + query = "列出商品类别是水果的的利润率" + + msg = Message(query) + session = list() + column_constraint = [ColumnItem(ori_value="水果", + column_value="新鲜水果", + column_name="商品类别", + table_name="超市营收明细", + is_like=False)] + result_message = self.nl2sql_node(message=msg, session=session, column_constraint=column_constraint) + + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + self.assertIn("新鲜水果", result_message.content.sql) + + def test_run_with_prompt_template(self): + """测试 增加 prompt template 参数""" + self.nl2sql_node.prompt_template = PROMPT_TEMPLATE + query = "列出商品类别是水果的的利润率" + + msg = Message(query) + session = list() + column_constraint = [ColumnItem(ori_value="水果", + column_value="新鲜水果", + column_name="商品类别", + table_name="超市营收明细", + is_like=False)] + result_message = self.nl2sql_node(message=msg, session=session, column_constraint=column_constraint) + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + # 恢复 prompt template + self.nl2sql_node.prompt_template = "" + + def test_run_with_session(self): + """测试 增加 session 参数""" + session = list() + session_record = GBISessionRecord(query="列出商品类别是水果的的利润率", + answer=NL2SqlResult( + llm_result="根据问题分析得到 sql 如下: \n " + "```sql\nSELECT * FROM `超市营收明细` " + "WHERE `商品类别` = '水果'\n```", + sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'")) + session.append(session_record) + + query = "列出所有的商品类别" + msg = Message(query) + result_message = self.nl2sql_node(message=msg, session=session, column_constraint=list()) + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + + +if __name__ == '__main__': + unittest.main() diff --git a/appbuilder/tests/test_gbi_select_table.py b/appbuilder/tests/test_gbi_select_table.py new file mode 100644 index 000000000..ac166af71 --- /dev/null +++ b/appbuilder/tests/test_gbi_select_table.py @@ -0,0 +1,111 @@ +""" +Copyright (c) 2023 Baidu, Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import os +import appbuilder +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import GBISessionRecord + + +SUPER_MARKET_SCHEMA = """ +``` +CREATE TABLE `超市营收明细表` ( + `订单编号` varchar(32) DEFAULT NULL, + `订单日期` date DEFAULT NULL, + `邮寄方式` varchar(32) DEFAULT NULL, + `地区` varchar(32) DEFAULT NULL, + `省份` varchar(32) DEFAULT NULL, + `客户类型` varchar(32) DEFAULT NULL, + `客户名称` varchar(32) DEFAULT NULL, + `商品类别` varchar(32) DEFAULT NULL, + `制造商` varchar(32) DEFAULT NULL, + `商品名称` varchar(32) DEFAULT NULL, + `数量` int(11) DEFAULT NULL, + `销售额` int(11) DEFAULT NULL, + `利润` int(11) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +``` +""" + +PRODUCT_SALES_INFO = """ +现有 mysql 表 product_sales_info, +该表的用途是: 产品收入表 +``` +CREATE TABLE `product_sales_info` ( + `年` int, + `月` int, + `产品名称` varchar, + `收入` decimal, + `非交付成本` decimal, + `含交付毛利` decimal +) +``` +""" + +PROMPT_TEMPLATE = """ +你是一个专业的业务人员,下面有{num}张表,具体表名如下: +{table_desc} +请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表, +返回多张表请用“,”隔开 +返回格式请参考如下示例: +问题:有多少个审核通过的投运单? +回答: ```DWD_MAT_OPERATION``` +请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出 +问题:{query} +回答: +""" + + +class TestGBISelectTable(unittest.TestCase): + + def setUp(self): + """ + 设置环境变量及必要数据。 + """ + model_name = "ERNIE-Bot 4.0" + + self.select_table_node = \ + appbuilder.GBISelectTable(model_name=model_name, + table_descriptions={"超市营收明细表": "超市营收明细表,包含超市各种信息等", + "product_sales_info": "产品销售表"}) + + def test_run_with_default_param(self): + """测试 run 方法使用有效参数""" + query = "列出超市中的所有数据" + msg = Message(query) + session = list() + result_message = self.select_table_node(message=msg, session=session) + print(result_message.content) + self.assertIsNotNone(result_message) + self.assertEqual(len(result_message.content), 1) + self.assertEqual(result_message.content[0], "超市营收明细表") + + def test_run_with_prompt_template(self): + """测试 run 方法中 prompt template 模版""" + query = "列出超市中的所有数据" + msg = Message(query) + session = list() + self.select_table_node.prompt_template = PROMPT_TEMPLATE + result_message = self.select_table_node(message=msg, session=session) + + self.assertIsNotNone(result_message) + self.assertEqual(len(result_message.content), 1) + self.assertEqual(result_message.content[0], "超市营收明细表") + self.select_table_node.prompt_template = "" + + +if __name__ == '__main__': + unittest.main() diff --git a/cookbooks/gbi.ipynb b/cookbooks/gbi.ipynb new file mode 100644 index 000000000..00557cc51 --- /dev/null +++ b/cookbooks/gbi.ipynb @@ -0,0 +1,550 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f802e64d-4eaa-445d-a48a-1042a91bc394", + "metadata": { + "tags": [] + }, + "source": [ + "# GBI\n", + "\n", + "## 目标\n", + "通过 GBI sdk 接口完成选表和问表的能力。\n", + "\n", + "## 准备工作\n", + "### 平台注册\n", + "1.先在appbuilder平台注册,获取token\n", + "2.创建BES集群,详见(https://cloud.baidu.com/doc/BES/s/0jwvyk4tv)\n", + "3.安装appbuilder-sdk" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2939356f-61c2-42e9-9e0c-fc6729c193f6", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install appbuilder-sdk" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "4ccff03b-1567-4e8b-8e1f-9a5032690406", + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "\n", + "# 设置环境变量\n", + "os.environ[\"APPBUILDER_TOKEN\"] = \"***\"\n" + ] + }, + { + "cell_type": "markdown", + "id": "aeb2fa55-075f-48df-a9fb-8b40d9900684", + "metadata": {}, + "source": [ + "## 开发过程" + ] + }, + { + "cell_type": "markdown", + "id": "1c3c5cee", + "metadata": {}, + "source": [ + "### 设置表的 schema" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d7d6440c", + "metadata": {}, + "outputs": [], + "source": [ + "SUPER_MARKET_SCHEMA = \"\"\"\n", + "```\n", + "CREATE TABLE `超市营收明细表` (\n", + " `订单编号` varchar(32) DEFAULT NULL,\n", + " `订单日期` date DEFAULT NULL,\n", + " `邮寄方式` varchar(32) DEFAULT NULL,\n", + " `地区` varchar(32) DEFAULT NULL,\n", + " `省份` varchar(32) DEFAULT NULL,\n", + " `客户类型` varchar(32) DEFAULT NULL,\n", + " `客户名称` varchar(32) DEFAULT NULL,\n", + " `商品类别` varchar(32) DEFAULT NULL,\n", + " `制造商` varchar(32) DEFAULT NULL,\n", + " `商品名称` varchar(32) DEFAULT NULL,\n", + " `数量` int(11) DEFAULT NULL,\n", + " `销售额` int(11) DEFAULT NULL,\n", + " `利润` int(11) DEFAULT NULL\n", + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n", + "```\n", + "\"\"\"\n", + "\n", + "PRODUCT_SALES_INFO = \"\"\"\n", + "现有 mysql 表 product_sales_info, \n", + "该表的用途是: 产品收入表\n", + "```\n", + "CREATE TABLE `product_sales_info` (\n", + " `年` int,\n", + " `月` int,\n", + " `产品名称` varchar,\n", + " `收入` decimal,\n", + " `非交付成本` decimal,\n", + " `含交付毛利` decimal\n", + ")\n", + "```\n", + "\"\"\"\n", + "\n", + "# schema 和表名的映射\n", + "SCHEMA_MAPPING = {\n", + " \"超市营收明细表\": SUPER_MARKET_SCHEMA,\n", + " \"PRODUCT_SALES_INFO\": PRODUCT_SALES_INFO\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "463254a1", + "metadata": {}, + "source": [ + "设置表的描述用于选表" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7fefcae1", + "metadata": {}, + "outputs": [], + "source": [ + "table_descriptions = {\n", + " \"超市营收明细表\": \"超市营收明细表,包含超市各种信息等\",\n", + " \"product_sales_info\": \"产品销售表\"\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "a0aff843", + "metadata": {}, + "source": [ + "### 选表" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "41559341-fd7a-478c-a08b-1477d79e9d41", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-18T06:24:26.982459Z", + "start_time": "2023-12-18T06:23:53.771345Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n", + "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n", + "选的表是: ['超市营收明细表']\n" + ] + } + ], + "source": [ + "import appbuilder\n", + "from appbuilder.core.message import Message\n", + "from appbuilder.core.components.gbi.basic import GBISessionRecord\n", + "\n", + "# 生成问表对象\n", + "select_table = appbuilder.GBISelectTable(model_name=\"ERNIE-Bot 4.0\", table_descriptions=table_descriptions)\n", + "query = \"列出超市中的所有数据\"\n", + "msg = Message(query)\n", + "session = list()\n", + "select_table_result_message = select_table(message=msg, session=session)\n", + "print(f\"选的表是: {select_table_result_message.content}\")" + ] + }, + { + "cell_type": "markdown", + "id": "16a8aa38-7a33-4e27-bca4-00900cfe1641", + "metadata": {}, + "source": [ + "### 问表\n", + "基于上面选出的表,通过获取 shema 进行问表" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9f45ef5f-6206-4b31-83c4-3c8eb2c86925", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT * FROM `超市营收明细表`;\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT * FROM `超市营收明细表`;\n", + "```\n" + ] + } + ], + "source": [ + "table_schemas = [SCHEMA_MAPPING[table_name] for table_name in select_table_result_message.content]\n", + "gbi_nl2sql = appbuilder.GBINL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas)\n", + "nl2sql_result_message = gbi_nl2sql(message=msg, session=session)\n", + "print(f\"sql: {nl2sql_result_message.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message.content.llm_result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0409c46-e8c7-403a-a827-fcdc8e717be6", + "metadata": {}, + "source": [ + "设置 session" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a23b8cad-f426-4074-9311-c2c33aaea07b", + "metadata": {}, + "outputs": [], + "source": [ + "session.append(GBISessionRecord(query=query, answer=nl2sql_result_message.content))" + ] + }, + { + "cell_type": "markdown", + "id": "22b3d877-f61f-4958-a084-7507a3017e17", + "metadata": {}, + "source": [ + "再次问表" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2adcb091-fb53-4364-b4d8-20564439ff51", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'\n", + "```\n" + ] + } + ], + "source": [ + "query2 = \"查看商品类别是水果的所有数据\"\n", + "msg2 = Message(query2)\n", + "nl2sql_result_message2 = gbi_nl2sql(message=msg2, session=session)\n", + "print(f\"sql: {nl2sql_result_message2.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9e0609ae-f2bc-43d3-9023-14e9f8618158", + "metadata": {}, + "source": [ + "### 增加列选优化\n", + "实际上数据中 \"商品类别\" 存储的是 \"新鲜水果\", 那么就可以通过列选的限制来优化 sql." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2a7c7923-019e-4660-9e36-4431e9d2f3a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '新鲜水果'\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '新鲜水果'\n", + "```\n" + ] + } + ], + "source": [ + "from appbuilder.core.components.gbi.basic import ColumnItem\n", + "\n", + "query2 = \"查看商品类别是水果的所有数据\"\n", + "msg2 = Message(query2)\n", + "\n", + "column_constraint = [ColumnItem(ori_value=\"水果\", \n", + " column_name=\"商品类别\", \n", + " column_value=\"新鲜水果\", \n", + " table_name=\"超市营收明细表\", \n", + " is_like=False)]\n", + "nl2sql_result_message2 = gbi_nl2sql(message=msg2, session=session, column_constraint=column_constraint)\n", + "print(f\"sql: {nl2sql_result_message2.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "8385312c-aea1-42cd-b61b-a8d36f4f0665", + "metadata": {}, + "source": [ + "从上面我们看到,商品类别不在是 \"水果\" 而是 修订为 \"新鲜水果\"" + ] + }, + { + "cell_type": "markdown", + "id": "6e98c414-8b2b-4187-a270-3117a4f431ff", + "metadata": {}, + "source": [ + "### 增加知识优化\n", + "当计算某些特殊知识的时候,大模型是不知道的,所以需要告诉大模型具体的知识,比如:\n", + "利润率的计算方式: 利润/销售额\n", + "可以将该知识注入。具体示例如下:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "cade4693-29dc-431c-bf84-c6dc09104294", + "metadata": {}, + "outputs": [], + "source": [ + "# 注入知识\n", + "gbi_nl2sql.knowledge[\"利润率\"] = \"计算方式: 利润/销售额\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1dc181e8-47a1-4b82-8bb5-ce3339be53f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n", + "FROM `超市营收明细表`\n", + "WHERE 商品类别 = '日用品'\n", + "GROUP BY 商品类别\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n", + "FROM `超市营收明细表`\n", + "WHERE 商品类别 = '日用品'\n", + "GROUP BY 商品类别\n", + "```\n", + "\n", + "思考步骤:\n", + "\n", + "1. 首先,我们需要从`超市营收明细表`中选择数据。\n", + "2. 根据当前问题,我们关心的是商品类别为“日用品”的数据。\n", + "3. 利润率是利润除以销售额,所以我们需要对利润和销售额进行聚合。\n", + "4. 使用`SUM`函数来计算总的利润和销售额。\n", + "5. 使用`GROUP BY`语句按商品类别进行分组,以便计算每个商品类别的利润率。\n", + "6. 最后,使用`AS`关键字给计算出的利润率命名,使其更易读。\n" + ] + } + ], + "source": [ + "query3 = \"列出商品类别是日用品的利润率\"\n", + "msg3 = Message(query3)\n", + "\n", + "nl2sql_result_message3 = gbi_nl2sql(message=msg3, session=session, column_constraint=list())\n", + "print(f\"sql: {nl2sql_result_message3.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message3.content.llm_result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c5570cd9-dbaf-45cd-ab03-1a7f92e7d0d4", + "metadata": {}, + "source": [ + "## 调整 prompt 模版\n", + "有时候,我们希望定义自己的prompt, 选表和问表两个环节都支持 prompt 模版的定制化,但是必须遵循对应的 prompt 模版的格式。" + ] + }, + { + "cell_type": "markdown", + "id": "6e3d4967-2b4c-437d-9d72-fb1b94bdcf59", + "metadata": {}, + "source": [ + "### 选表 prompt 调整\n", + "选表的 prompt template, 必须包含 \n", + "1. {num} - 表的数量, 注意 {num} 有两个地方出现\n", + "2. {table_desc} - 表的描述\n", + "3. {query} - query, 参考下面的示例:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "2ae6ffbc-4237-4fb2-8168-480b81bfd873", + "metadata": {}, + "outputs": [], + "source": [ + "SELECT_TABLE_PROMPT_TEMPLATE = \"\"\"\n", + "你是一个专业的业务人员,下面有{num}张表,具体表名如下:\n", + "{table_desc}\n", + "请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表,\n", + "返回多张表请用“,”隔开\n", + "返回格式请参考如下示例:\n", + "问题:有多少个审核通过的投运单?\n", + "回答: ```DWD_MAT_OPERATION```\n", + "请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出\n", + "问题:{query}\n", + "回答:\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "2bbbb375-6659-4ef0-82ff-a4ace9fdd4f0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "选的表是: ['超市营收明细表']\n" + ] + } + ], + "source": [ + "select_table4 = appbuilder.GBISelectTable(model_name=\"ERNIE-Bot 4.0\", \n", + " table_descriptions=table_descriptions,\n", + " prompt_template=SELECT_TABLE_PROMPT_TEMPLATE)\n", + "query4 = \"列出超市中的所有数据\"\n", + "msg4 = Message(query4)\n", + "select_table_result_message4 = select_table4(message=msg4, session=list())\n", + "print(f\"选的表是: {select_table_result_message4.content}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4f3fd089-613b-4bdd-95ac-c87f89c0fc61", + "metadata": {}, + "source": [ + "## 问表 prompt 调整\n", + "问表的 prompt template 必须包含:\n", + "1. {schema} - 表的 schema 信息\n", + "2. {instrument} - 列选限制的信息\n", + "3. {kg} - 知识\n", + "4. {date} - 时间\n", + "5. {history_prompt} - 历史\n", + "6. {query} - 当前问题\n", + "\n", + "参考下面的示例" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "323fbe75-62ca-44ab-9ca2-9f747939a2b5", + "metadata": {}, + "outputs": [], + "source": [ + "NL2SQL_PROMPT_TEMPLATE = \"\"\"\n", + " MySql 表 Schema 如下:\n", + " {schema}\n", + " 请根据用户当前问题,联系历史信息,仅编写1个sql,其中 sql 语句需要使用```sql ```这种 markdown 形式给出。\n", + " 请参考列选信息:\n", + " {instrument}\n", + " 请参考知识:\n", + " {kg}\n", + " 当前时间:{date}\n", + " 历史信息如下:\n", + " {history_prompt}\n", + " 当前问题:\"{query}\"\n", + " 回答:\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "52436f03-e01c-456a-aaa0-5a7f1afcd9d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'\n", + "```\n" + ] + } + ], + "source": [ + "\n", + "msg5 = Message(\"查看商品类别是水果的所有数据\")\n", + "gbi_nl2sql5 = appbuilder.GBINL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE)\n", + "nl2sql_result_message5 = gbi_nl2sql5(message=msg5, session=session)\n", + "print(f\"sql: {nl2sql_result_message5.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message5.content.llm_result}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bbdd2f66-e4a8-4001-bc4c-6cd245141deb", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 3fab43c826d0e8057b7ec119846bd549af16f85a Mon Sep 17 00:00:00 2001 From: cjopengler Date: Tue, 26 Dec 2023 18:36:29 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20gbi=20=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E6=A1=86=E6=9E=B6=E4=BB=A5=E5=8F=8A=E6=96=87=E6=97=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- appbuilder/__init__.py | 4 +- appbuilder/core/components/gbi/basic.py | 76 +++-------- .../core/components/gbi/nl2sql/README.md | 81 +++++------ .../core/components/gbi/nl2sql/component.py | 58 ++++---- .../components/gbi/select_table/README.md | 40 +++--- .../components/gbi/select_table/component.py | 42 +++--- appbuilder/tests/test_gbi_nl2sql.py | 40 +++--- appbuilder/tests/test_gbi_select_table.py | 24 ++-- cookbooks/gbi.ipynb | 127 +++++++----------- 9 files changed, 206 insertions(+), 286 deletions(-) diff --git a/appbuilder/__init__.py b/appbuilder/__init__.py index ab40f1c92..20a2db9c2 100644 --- a/appbuilder/__init__.py +++ b/appbuilder/__init__.py @@ -67,8 +67,8 @@ def check_version(self): from .core.components.embeddings import Embedding from .core.components.matching import Matching -from .core.components.gbi.nl2sql.component import GBINL2Sql -from .core.components.gbi.select_table.component import GBISelectTable +from .core.components.gbi.nl2sql.component import NL2Sql +from .core.components.gbi.select_table.component import SelectTable from appbuilder.core.message import Message from appbuilder.core.agent import AgentBase diff --git a/appbuilder/core/components/gbi/basic.py b/appbuilder/core/components/gbi/basic.py index c51e716a0..28fa34b81 100644 --- a/appbuilder/core/components/gbi/basic.py +++ b/appbuilder/core/components/gbi/basic.py @@ -17,82 +17,36 @@ r"""GBI nl2sql component. """ +from pydantic import BaseModel, Field from typing import Dict, List -class NL2SqlResult(object): +class NL2SqlResult(BaseModel): """ gbi_nl2sql 返回的结果 """ - def __init__(self, llm_result: str, sql: str): - """ - 初始化 - Args: - llm_result: 大模型返回的结果 - sql: 从 llm_result 中抽取的 sql 语句 - """ - self.llm_result = llm_result - self.sql = sql + llm_result: str = Field(..., description="大模型返回的结果") + sql: str = Field(..., description="从大模型中抽取的 sql 语句") - def to_json(self) -> Dict: - """ - 转换成 字典 - Returns: - - """ - return self.__dict__ - - -class GBISessionRecord(object): +class SessionRecord(BaseModel): """ gbi session record """ + query: str = Field(..., description="用户的问题") + answer: NL2SqlResult = Field(..., description="nl2sql 返回的结果") - def __init__(self, query: str, answer: NL2SqlResult): - """ - GBI Session 的记录 - Args: - query: 用户的问题 - answer: gbi_nl2sql 返回的结果 - """ - self.query = query - self.answer = answer - - def to_json(self) -> Dict: - return {"query": self.query, - "answer": self.answer.to_json()} - - -class ColumnItem(object): +class ColumnItem(BaseModel): """ - column item + 列信息 """ + ori_value: str = Field(..., description="query 中的 词语, 比如: 北京去年收入, " + "分词后: 北京, 去年, 收入, ori_value 是分词中某一个,比如: ori_value = 北京") + column_name: str = Field(..., description="对应数据库中的列名称, 比如: city") + column_value: str = Field(..., description="对应数据库中的列值, 比如: 北京市") - def __init__(self, ori_value: str, column_name: str, column_value: str, table_name: str, - is_like: bool = False): - """ - 用于标识 query 中的词 应该对应到数据库中的某个列值以及列名,用于提升 sql 生成效果 - Args: - ori_value: query 中的 词语, 比如: "北京去年收入", 分词后: "北京, 去年, 收入", ori_value 是分词中某一个,比如: ori_value = "北京" - column_name: 对应数据库中的列名称, city - column_value: 对应数据库中的列值, 北京市 - table_name: 该列所属的表名称 - is_like: 与 ori_value 的匹配是包含 还是 等于,包含: True; 等于: False - """ - self.column_name = column_name - self.column_value = column_value - self.ori_value = ori_value - self.table_name = table_name - self.is_like = is_like - - def to_json(self) -> Dict: - """ - 转换成 json - Returns: - - """ - return self.__dict__ + table_name: str = Field(..., description="该列所在表的名字") + is_like: bool = Field(default=False, description="与 ori_value 的匹配是包含 还是 等于,包含: True; 等于: False") SUPPORTED_MODEL_NAME = { diff --git a/appbuilder/core/components/gbi/nl2sql/README.md b/appbuilder/core/components/gbi/nl2sql/README.md index d605100d2..3de0f994c 100644 --- a/appbuilder/core/components/gbi/nl2sql/README.md +++ b/appbuilder/core/components/gbi/nl2sql/README.md @@ -1,26 +1,25 @@ # GBI 问表 ## 简介 -GBI 问表,根据提供的表的 schema 信息,生成对应问题的 sql 语句。 +GBI 问表,根据提供的 mysql 表的 schema 信息,生成对应问题的 sql 语句。 ## 基本用法 +这里是一个示例,展示如何基于 mysql 表的 schema, 根据问题生成 sql 语句。 -### 快速开启 - -```` +````python import logging import os import appbuilder from appbuilder.core.message import Message -from appbuilder.core.components.gbi.basic import GBISessionRecord +from appbuilder.core.components.gbi.basic import SessionRecord # 设置环境变量 os.environ["APPBUILDER_TOKEN"] = "***" SUPER_MARKET_SCHEMA = """ ``` -CREATE TABLE `超市营收明细表` ( +CREATE TABLE `supper_market_info` ( `订单编号` varchar(32) DEFAULT NULL, `订单日期` date DEFAULT NULL, `邮寄方式` varchar(32) DEFAULT NULL, @@ -39,21 +38,20 @@ CREATE TABLE `超市营收明细表` ( """ table_schemas = [SUPER_MARKET_SCHEMA] +gbi_nl2sql = appbuilder.NL2Sql(model_name="ERNIE-Bot 4.0", table_schemas=table_schemas) query = "列出超市中的所有数据" -msg = Message(query) -session = list() -gbi_nl2sql = appbuilder.GBINL2Sql(model_name="ERNIE-Bot 4.0", table_schemas=table_schemas) -nl2sql_result_message = gbi_nl2sql(message=msg, session=session) +nl2sql_result_message = gbi_nl2sql(Message({"query": query})) + print(f"sql: {nl2sql_result_message.content.sql}") print("-----------------") print(f"llm result: {nl2sql_result_message.content.llm_result}") ```` sql: - SELECT * FROM `超市营收明细表`; + SELECT * FROM supper_market_info; ----------------- llm result: ```sql - SELECT * FROM `超市营收明细表`; + SELECT * FROM supper_market_info; ``` @@ -64,7 +62,7 @@ print(f"llm result: {nl2sql_result_message.content.llm_result}") - table_schemas: 表的 schema,例如: ``` -CREATE TABLE `超市营收明细表` ( +CREATE TABLE `supper_market_info` ( `订单编号` varchar(32) DEFAULT NULL, `订单日期` date DEFAULT NULL, `邮寄方式` varchar(32) DEFAULT NULL, @@ -83,7 +81,6 @@ CREATE TABLE `超市营收明细表` ( - knowledge: 用于提供一些知识, 比如 {"毛利率": "毛收入-毛成本/毛成本"} - prompt_template: prompt 模版, 必须包含的格式如下: - ***你的描述 {schema} ***你的描述 @@ -99,9 +96,10 @@ CREATE TABLE `超市营收明细表` ( 回答: ### 调用参数 -- message: message.content 是 query -- session: gbi session 的历史 列表, 参考 GBISessionRecord -- column_constraint: 列选约束 参考 ColumnItem 具体定义 +- message: message.content 是 字典,包含: query, session, column_constraint 三个key + * query: 用户的问题 + * session: gbi session 的历史 列表, 参考 GBISessionRecord + * column_constraint: 列选约束 参考 ColumnItem 具体定义 #### GBISessionRecord 初始化参数 - query: 用户的问题 @@ -126,26 +124,26 @@ CREATE TABLE `超市营收明细表` ( ```python -session.append(GBISessionRecord(query=query, answer=nl2sql_result_message.content)) +session = list() +session.append(SessionRecord(query=query, answer=nl2sql_result_message.content)) ``` 再次问表 ```python -query2 = "查看商品类别是水果的所有数据" -msg2 = Message(query2) -nl2sql_result_message2 = gbi_nl2sql(message=msg2, session=session) +nl2sql_result_message2 = gbi_nl2sql(Message({"query": "查看商品类别是水果的所有数据", + "session": session})) print(f"sql: {nl2sql_result_message2.content.sql}") print("-----------------") print(f"llm result: {nl2sql_result_message2.content.llm_result}") ``` sql: - SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'; + SELECT * FROM supper_market_info WHERE 商品类别 = '水果'; ----------------- llm result: ```sql - SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'; + SELECT * FROM supper_market_info WHERE 商品类别 = '水果'; ``` @@ -156,28 +154,25 @@ print(f"llm result: {nl2sql_result_message2.content.llm_result}") ```python from appbuilder.core.components.gbi.basic import ColumnItem -query2 = "查看商品类别是水果的所有数据" -msg2 = Message(query2) - column_constraint = [ColumnItem(ori_value="水果", column_name="商品类别", column_value="新鲜水果", table_name="超市营收明细表", is_like=False)] -nl2sql_result_message2 = gbi_nl2sql(message=msg2, session=session, column_constraint=column_constraint) +nl2sql_result_message2 = gbi_nl2sql(Message({"query": "查看商品类别是水果的所有数据", + "column_constraint": column_constraint})) + print(f"sql: {nl2sql_result_message2.content.sql}") print("-----------------") print(f"llm result: {nl2sql_result_message2.content.llm_result}") ``` sql: - SELECT * FROM `超市营收明细表` WHERE `商品类别` = '新鲜水果' + SELECT * FROM supper_market_info WHERE 商品类别='新鲜水果' ----------------- llm result: ```sql - SELECT * FROM `超市营收明细表` WHERE `商品类别` = '新鲜水果' + SELECT * FROM supper_market_info WHERE 商品类别='新鲜水果' ``` - - 这个查询会返回`超市营收明细表`中所有商品类别为"新鲜水果"的数据。因为问题中没有涉及到其他特定的条件或聚合操作,所以这是一个简单的筛选查询。 从上面我们看到,商品类别不在是 "水果" 而是 修订为 "新鲜水果" @@ -198,7 +193,7 @@ gbi_nl2sql.knowledge["利润率"] = "计算方式: 利润/销售额" query3 = "列出商品类别是日用品的利润率" msg3 = Message(query3) -nl2sql_result_message3 = gbi_nl2sql(message=msg3, session=session, column_constraint=list()) +nl2sql_result_message3 = gbi_nl2sql(Message({"query": "列出商品类别是日用品的利润率"})) print(f"sql: {nl2sql_result_message3.content.sql}") print("-----------------") print(f"llm result: {nl2sql_result_message3.content.llm_result}") @@ -206,25 +201,16 @@ print(f"llm result: {nl2sql_result_message3.content.llm_result}") sql: SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率 - FROM `超市营收明细表` + FROM supper_market_info WHERE 商品类别 = '日用品' GROUP BY 商品类别 ----------------- llm result: ```sql SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率 - FROM `超市营收明细表` + FROM supper_market_info WHERE 商品类别 = '日用品' GROUP BY 商品类别 ``` - - 思考步骤: - - 1. 首先,我们需要从`超市营收明细表`中选择数据。 - 2. 根据当前问题,我们关心的是商品类别为“日用品”的数据。 - 3. 利润率是利润除以销售额,所以我们需要对利润和销售额进行聚合。 - 4. 使用`SUM`函数来计算总的利润和销售额。 - 5. 使用`GROUP BY`语句按商品类别进行分组,以确保我们计算的是日用品的总利润和总销售额。 - 6. 最后,选择商品类别并计算利润率,即利润总和除以销售额总和。 ## 调整 prompt 模版 @@ -261,19 +247,18 @@ NL2SQL_PROMPT_TEMPLATE = """ ```python +gbi_nl2sql5 = appbuilder.NL2Sql(model_name="ERNIE-Bot 4.0", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE) +nl2sql_result_message5 = gbi_nl2sql5(Message({"query": "查看商品类别是水果的所有数据"})) -msg5 = Message("查看商品类别是水果的所有数据") -gbi_nl2sql5 = appbuilder.GBINL2Sql(model_name="ERNIE-Bot 4.0", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE) -nl2sql_result_message5 = gbi_nl2sql5(message=msg5, session=session) print(f"sql: {nl2sql_result_message5.content.sql}") print("-----------------") print(f"llm result: {nl2sql_result_message5.content.llm_result}") ``` sql: - SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'; + SELECT * FROM supper_market_info WHERE 商品类别 = '水果' ----------------- llm result: ```sql - SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'; + SELECT * FROM supper_market_info WHERE 商品类别 = '水果' ``` diff --git a/appbuilder/core/components/gbi/nl2sql/component.py b/appbuilder/core/components/gbi/nl2sql/component.py index 9ff7327c3..940119f75 100644 --- a/appbuilder/core/components/gbi/nl2sql/component.py +++ b/appbuilder/core/components/gbi/nl2sql/component.py @@ -19,23 +19,32 @@ from typing import Dict, List, Optional from pydantic import BaseModel, Field -from appbuilder.core.component import Component +from appbuilder.core.component import Component, ComponentArguments from appbuilder.core.message import Message -from appbuilder.core.components.gbi.basic import GBISessionRecord +from appbuilder.core.components.gbi.basic import SessionRecord from appbuilder.core.components.gbi.basic import ColumnItem from appbuilder.core.components.gbi.basic import NL2SqlResult from appbuilder.core.components.gbi.basic import SUPPORTED_MODEL_NAME -class GBINL2Sql(Component): + +class NL2SqlArgs(ComponentArguments): + """ + nl2sql 的参数 + """ + query: str = Field(..., description="用户的 query 输入") + session: List[SessionRecord] = Field(default=list(), description="gbi session 的历史 列表") + column_constraint: List[ColumnItem] = Field(default=list(), description="列选的限制条件") + + +class NL2Sql(Component): """ gib nl2sql """ + meta = NL2SqlArgs def __init__(self, model_name: str, table_schemas: List[str], knowledge: Dict = None, - prompt_template: str = "", - secret_key: Optional[str] = None, - gateway: str = ""): + prompt_template: str = ""): """ 创建 gbi nl2sql 对象 Args: @@ -62,10 +71,8 @@ def __init__(self, model_name: str, table_schemas: List[str], knowledge: Dict = ***你的描述 当前问题:{query} 回答: - secret_key: 用户创建的 key - gateway: gateway 地址 """ - super().__init__(secret_key=secret_key, gateway=gateway) + super().__init__(meta=NL2SqlArgs) if model_name not in SUPPORTED_MODEL_NAME: raise ValueError(f"model_name 错误, 请使用 {SUPPORTED_MODEL_NAME} 中的大模型") @@ -76,36 +83,37 @@ def __init__(self, model_name: str, table_schemas: List[str], knowledge: Dict = self.prompt_template = prompt_template def run(self, - message: Message, - session: List[GBISessionRecord], - column_constraint: List[ColumnItem] = None) -> Message[NL2SqlResult]: + message: Message, timeout: float = 60, retry: int = 0) -> Message[NL2SqlResult]: """ 执行 nl2sql Args: - message: message.content 是 query - session: gbi session 的历史 列表 - column_constraint: 列选约束 参考 ColumnItem 具体定义 + message: message.content 是字典包含, key 如下: + 1. query: 用户问题 + 2. session: gbi session 的历史 列表, 参考 SessionRecord + 3. column_constraint: 列选约束 参考 ColumnItem 具体定义 Returns: NL2SqlResult 的 message """ - query = message.content - session = session - column_constraint = column_constraint or list() - response = self._run_nl2sql(query=query, session=session, table_schemas=self.table_schemas, - column_constraint=column_constraint, knowledge=self.knowledge, + try: + inputs = self.meta(**message.content) + except ValidationError as e: + raise ValueError(e) + + response = self._run_nl2sql(query=inputs.query, session=inputs.session, table_schemas=self.table_schemas, + column_constraint=inputs.column_constraint, knowledge=self.knowledge, prompt_template=self.prompt_template, model_name=self.model_name, - timeout=60, - retry=2) + timeout=timeout, + retry=retry) rsp_data = response.json() nl2sql_result = NL2SqlResult(llm_result=rsp_data["llm_result"], sql=rsp_data["sql"]) return Message(content=nl2sql_result) - def _run_nl2sql(self, query: str, session: List[GBISessionRecord], table_schemas: List[str], knowledge: Dict[str, str], + def _run_nl2sql(self, query: str, session: List[SessionRecord], table_schemas: List[str], knowledge: Dict[str, str], prompt_template: str, column_constraint: List[ColumnItem], model_name: str, @@ -135,8 +143,8 @@ def _run_nl2sql(self, query: str, session: List[GBISessionRecord], table_schemas payload = {"query": query, "table_schemas": table_schemas, - "session": [session_record.to_json() for session_record in session], - "column_constraint": [column_item.to_json() for column_item in column_constraint], + "session": [session_record.dict() for session_record in session], + "column_constraint": [column_item.dict() for column_item in column_constraint], "model_name": model_name, "knowledge": knowledge, "prompt_template": prompt_template} diff --git a/appbuilder/core/components/gbi/select_table/README.md b/appbuilder/core/components/gbi/select_table/README.md index 13b03c4e0..391fe8eeb 100644 --- a/appbuilder/core/components/gbi/select_table/README.md +++ b/appbuilder/core/components/gbi/select_table/README.md @@ -1,12 +1,11 @@ # GBI 选表 ## 简介 -GBI 选表,根据提供的表的描述信息以及 query 选择对应的表. +GBI 选表,根据提供的多个 MySql 表名 以及 表名对应的描述信息,通过 query 选择一个或多个最合适的表来回答该 query. +一般的试用场景是,当有数据库有多个表的时候,但是实际只有1个表能回答该 query,那么,通过该能力将该表选择出来,用于后面的 问表 环节。 ## 基本用法 - -### 快速开启 - +下面是根据提供的表的描述信息以及 query 选择对应的表的示例。 ```python @@ -14,28 +13,26 @@ import logging import os import appbuilder from appbuilder.core.message import Message -from appbuilder.core.components.gbi.basic import GBISessionRecord +from appbuilder.core.components.gbi.basic import SessionRecord # 设置环境变量 os.environ["APPBUILDER_TOKEN"] = "***" -# 表的描述信息 +# 表的描述信息, key: 表名; value: 是表的描述 table_descriptions = { - "超市营收明细表": "超市营收明细表,包含超市各种信息等", + "supper_market_info": "超市营收明细表,包含超市各种信息等", "product_sales_info": "产品销售表" } # 生成问表对象 -select_table = appbuilder.GBISelectTable(model_name="ERNIE-Bot 4.0", table_descriptions=table_descriptions) -query = "列出超市中的所有数据" -msg = Message(query) -session = list() -select_table_result_message = select_table(message=msg, session=session) +select_table = appbuilder.SelectTable(model_name="ERNIE-Bot 4.0", table_descriptions=table_descriptions) +select_table_result_message = select_table(Message({"query": "列出超市中的所有数据"})) + print(f"选的表是: {select_table_result_message.content}") ``` - 选的表是: ['超市营收明细表'] + 选的表是: ['supper_market_info'] ## 参数说明 @@ -45,7 +42,7 @@ print(f"选的表是: {select_table_result_message.content}") ``` { - "超市营收明细表": "超市营收明细表,包含超市各种信息等", + "supper_market_info": "超市营收明细表,包含超市各种信息等", "product_sales_info": "产品销售表" } ``` @@ -69,8 +66,9 @@ print(f"选的表是: {select_table_result_message.content}") ``` ### 调用参数 -- message: message.content 是用户的问题,也就是 query -- session: GBISessionRecord 列表 +- message: message.content 是用户的问题,包含的key: query, session + * query: 用户提出的问题 + * session: GBISessionRecord 列表 #### GBISessionRecord 初始化参数 - query: 用户的问题 @@ -107,14 +105,14 @@ SELECT_TABLE_PROMPT_TEMPLATE = """ ```python -select_table4 = appbuilder.GBISelectTable(model_name="ERNIE-Bot 4.0", +select_table4 = appbuilder.SelectTable(model_name="ERNIE-Bot 4.0", table_descriptions=table_descriptions, prompt_template=SELECT_TABLE_PROMPT_TEMPLATE) -query4 = "列出超市中的所有数据" -msg4 = Message(query4) -select_table_result_message4 = select_table4(message=msg4, session=list()) + +select_table_result_message4 = select_table4(Message({"query": "列出超市中的所有数据"})) + print(f"选的表是: {select_table_result_message4.content}") ``` - 选的表是: ['超市营收明细表'] + 选的表是: ['supper_market_info'] diff --git a/appbuilder/core/components/gbi/select_table/component.py b/appbuilder/core/components/gbi/select_table/component.py index 21bda916e..d66546da5 100644 --- a/appbuilder/core/components/gbi/select_table/component.py +++ b/appbuilder/core/components/gbi/select_table/component.py @@ -19,21 +19,26 @@ from typing import Dict, List, Optional from pydantic import BaseModel, Field -from appbuilder.core.component import Component +from appbuilder.core.component import Component, ComponentArguments from appbuilder.core.message import Message -from appbuilder.core.components.gbi.basic import GBISessionRecord +from appbuilder.core.components.gbi.basic import SessionRecord from appbuilder.core.components.gbi.basic import SUPPORTED_MODEL_NAME +class SelectTableArgs(ComponentArguments): + """ + 选表的参数 + """ + query: str = Field(..., description="用户的 query 输入") + session: List[SessionRecord] = Field(default=list(), description="gbi session 的历史 列表") -class GBISelectTable(Component): + +class SelectTable(Component): """ gbi 选表 """ def __init__(self, model_name: str, table_descriptions: Dict[str, str], - prompt_template: str = "", - secret_key: Optional[str] = None, - gateway: str = ""): + prompt_template: str = ""): """ 创建 GBI 选表对象 Args: @@ -64,7 +69,7 @@ def __init__(self, model_name: str, table_descriptions: Dict[str, str], secret_key: gateway: """ - super().__init__(secret_key=secret_key, gateway=gateway) + super().__init__(meta=SelectTableArgs) if model_name not in SUPPORTED_MODEL_NAME: raise ValueError(f"model_name 错误, 请使用 {SUPPORTED_MODEL_NAME} 中的大模型") self.model_name = model_name @@ -73,32 +78,33 @@ def __init__(self, model_name: str, table_descriptions: Dict[str, str], self.prompt_template = prompt_template def run(self, - message: Message, - session: List[GBISessionRecord]) -> Message[List[str]]: + message: Message, timeout: int = 60,retry: int = 0) -> Message[List[str]]: """ Args: - message: message.content 是用户的问题,也就是 query - session: GBISessionRecord 列表 + message: message.content 字典包含 key: + 1. query - 用户的问题输入 + 2. session - 对话历史, 可选 Returns: 识别的表名的列表 ["table_name"] """ + try: + inputs = self.meta(**message.content) + except ValidationError as e: + raise ValueError(e) - query = message.content - session = session - - response = self._run_select_table(query=query, session=session, + response = self._run_select_table(query=inputs.query, session=inputs.session, prompt_template=self.prompt_template, table_descriptions=self.table_descriptions, model_name=self.model_name, - timeout=60, - retry=2) + timeout=timeout, + retry=retry) rsp_data = response.json() return Message(content=rsp_data) - def _run_select_table(self, query: str, session: List[GBISessionRecord], + def _run_select_table(self, query: str, session: List[SessionRecord], prompt_template, table_descriptions: Dict[str, str], model_name: str, diff --git a/appbuilder/tests/test_gbi_nl2sql.py b/appbuilder/tests/test_gbi_nl2sql.py index 89aa17d12..5617f1578 100644 --- a/appbuilder/tests/test_gbi_nl2sql.py +++ b/appbuilder/tests/test_gbi_nl2sql.py @@ -17,12 +17,12 @@ import os import appbuilder from appbuilder.core.message import Message -from appbuilder.core.components.gbi.basic import NL2SqlResult, GBISessionRecord +from appbuilder.core.components.gbi.basic import NL2SqlResult, SessionRecord from appbuilder.core.components.gbi.basic import ColumnItem SUPER_MARKET_SCHEMA = """ ``` -CREATE TABLE `超市营收明细` ( +CREATE TABLE `supper_market_info` ( `订单编号` varchar(32) DEFAULT NULL, `订单日期` date DEFAULT NULL, `邮寄方式` varchar(32) DEFAULT NULL, @@ -78,15 +78,14 @@ def setUp(self): """ model_name = "ERNIE-Bot 4.0" table_schemas = [SUPER_MARKET_SCHEMA] - self.nl2sql_node = appbuilder.GBINL2Sql(model_name=model_name, - table_schemas=table_schemas) + self.nl2sql_node = appbuilder.NL2Sql(model_name=model_name, + table_schemas=table_schemas) def test_run_with_default_param(self): """测试 run 方法使用有效参数""" query = "列出商品类别是水果的所有信息" - msg = Message(query) - session = list() - result_message = self.nl2sql_node(message=msg, session=session) + msg = Message({"query": query}) + result_message = self.nl2sql_node(msg) print(result_message.content.sql) self.assertIsNotNone(result_message) self.assertTrue(result_message.content.sql != "") @@ -98,9 +97,8 @@ def test_run_with_knowledge(self): self.nl2sql_node.knowledge["利润率"] = "计算方式: 利润/销售额" query = "列出商品类别是水果的的利润率" - msg = Message(query) - session = list() - result_message = self.nl2sql_node(message=msg, session=session) + msg = Message({"query": query}) + result_message = self.nl2sql_node(msg) self.assertIsNotNone(result_message) self.assertTrue(result_message.content.sql != "") self.assertTrue(result_message.content.llm_result != "") @@ -110,15 +108,15 @@ def test_run_with_column_constraint(self): """测试 增加 column constraint 参数""" query = "列出商品类别是水果的的利润率" - - msg = Message(query) - session = list() column_constraint = [ColumnItem(ori_value="水果", column_value="新鲜水果", column_name="商品类别", table_name="超市营收明细", is_like=False)] - result_message = self.nl2sql_node(message=msg, session=session, column_constraint=column_constraint) + + msg = Message({"query": query, "column_constraint": column_constraint}) + + result_message = self.nl2sql_node(msg) self.assertIsNotNone(result_message) self.assertTrue(result_message.content.sql != "") @@ -129,15 +127,13 @@ def test_run_with_prompt_template(self): """测试 增加 prompt template 参数""" self.nl2sql_node.prompt_template = PROMPT_TEMPLATE query = "列出商品类别是水果的的利润率" - - msg = Message(query) - session = list() column_constraint = [ColumnItem(ori_value="水果", column_value="新鲜水果", column_name="商品类别", table_name="超市营收明细", is_like=False)] - result_message = self.nl2sql_node(message=msg, session=session, column_constraint=column_constraint) + msg = Message({"query": query, "column_constraint": column_constraint}) + result_message = self.nl2sql_node(msg) self.assertIsNotNone(result_message) self.assertTrue(result_message.content.sql != "") self.assertTrue(result_message.content.llm_result != "") @@ -147,8 +143,8 @@ def test_run_with_prompt_template(self): def test_run_with_session(self): """测试 增加 session 参数""" session = list() - session_record = GBISessionRecord(query="列出商品类别是水果的的利润率", - answer=NL2SqlResult( + session_record = SessionRecord(query="列出商品类别是水果的的利润率", + answer=NL2SqlResult( llm_result="根据问题分析得到 sql 如下: \n " "```sql\nSELECT * FROM `超市营收明细` " "WHERE `商品类别` = '水果'\n```", @@ -156,8 +152,8 @@ def test_run_with_session(self): session.append(session_record) query = "列出所有的商品类别" - msg = Message(query) - result_message = self.nl2sql_node(message=msg, session=session, column_constraint=list()) + msg = Message({"query": query, "session": session}) + result_message = self.nl2sql_node(msg) self.assertIsNotNone(result_message) self.assertTrue(result_message.content.sql != "") self.assertTrue(result_message.content.llm_result != "") diff --git a/appbuilder/tests/test_gbi_select_table.py b/appbuilder/tests/test_gbi_select_table.py index ac166af71..324850583 100644 --- a/appbuilder/tests/test_gbi_select_table.py +++ b/appbuilder/tests/test_gbi_select_table.py @@ -17,12 +17,12 @@ import os import appbuilder from appbuilder.core.message import Message -from appbuilder.core.components.gbi.basic import GBISessionRecord +from appbuilder.core.components.gbi.basic import SessionRecord SUPER_MARKET_SCHEMA = """ ``` -CREATE TABLE `超市营收明细表` ( +CREATE TABLE `supper_market_info` ( `订单编号` varchar(32) DEFAULT NULL, `订单日期` date DEFAULT NULL, `邮寄方式` varchar(32) DEFAULT NULL, @@ -68,7 +68,6 @@ 回答: """ - class TestGBISelectTable(unittest.TestCase): def setUp(self): @@ -78,32 +77,31 @@ def setUp(self): model_name = "ERNIE-Bot 4.0" self.select_table_node = \ - appbuilder.GBISelectTable(model_name=model_name, - table_descriptions={"超市营收明细表": "超市营收明细表,包含超市各种信息等", + appbuilder.SelectTable(model_name=model_name, + table_descriptions={"supper_market_info": "超市营收明细表,包含超市各种信息等", "product_sales_info": "产品销售表"}) def test_run_with_default_param(self): """测试 run 方法使用有效参数""" query = "列出超市中的所有数据" - msg = Message(query) - session = list() - result_message = self.select_table_node(message=msg, session=session) + msg = Message({"query": query}) + result_message = self.select_table_node(message=msg) print(result_message.content) self.assertIsNotNone(result_message) self.assertEqual(len(result_message.content), 1) - self.assertEqual(result_message.content[0], "超市营收明细表") + self.assertEqual(result_message.content[0], "supper_market_info") def test_run_with_prompt_template(self): """测试 run 方法中 prompt template 模版""" query = "列出超市中的所有数据" - msg = Message(query) - session = list() + msg = Message({"query": query}) + result_message = self.select_table_node(message=msg) self.select_table_node.prompt_template = PROMPT_TEMPLATE - result_message = self.select_table_node(message=msg, session=session) + result_message = self.select_table_node(msg) self.assertIsNotNone(result_message) self.assertEqual(len(result_message.content), 1) - self.assertEqual(result_message.content[0], "超市营收明细表") + self.assertEqual(result_message.content[0], "supper_market_info") self.select_table_node.prompt_template = "" diff --git a/cookbooks/gbi.ipynb b/cookbooks/gbi.ipynb index 00557cc51..f7b4809fb 100644 --- a/cookbooks/gbi.ipynb +++ b/cookbooks/gbi.ipynb @@ -15,8 +15,7 @@ "## 准备工作\n", "### 平台注册\n", "1.先在appbuilder平台注册,获取token\n", - "2.创建BES集群,详见(https://cloud.baidu.com/doc/BES/s/0jwvyk4tv)\n", - "3.安装appbuilder-sdk" + "2.安装appbuilder-sdk" ] }, { @@ -31,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "id": "4ccff03b-1567-4e8b-8e1f-9a5032690406", "metadata": {}, "outputs": [], @@ -61,14 +60,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 17, "id": "d7d6440c", "metadata": {}, "outputs": [], "source": [ "SUPER_MARKET_SCHEMA = \"\"\"\n", "```\n", - "CREATE TABLE `超市营收明细表` (\n", + "CREATE TABLE `supper_market_info` (\n", " `订单编号` varchar(32) DEFAULT NULL,\n", " `订单日期` date DEFAULT NULL,\n", " `邮寄方式` varchar(32) DEFAULT NULL,\n", @@ -103,7 +102,7 @@ "\n", "# schema 和表名的映射\n", "SCHEMA_MAPPING = {\n", - " \"超市营收明细表\": SUPER_MARKET_SCHEMA,\n", + " \"supper_market_info\": SUPER_MARKET_SCHEMA,\n", " \"PRODUCT_SALES_INFO\": PRODUCT_SALES_INFO\n", "}" ] @@ -118,13 +117,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 18, "id": "7fefcae1", "metadata": {}, "outputs": [], "source": [ "table_descriptions = {\n", - " \"超市营收明细表\": \"超市营收明细表,包含超市各种信息等\",\n", + " \"supper_market_info\": \"超市营收明细表,包含超市各种信息等\",\n", " \"product_sales_info\": \"产品销售表\"\n", "}" ] @@ -139,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "id": "41559341-fd7a-478c-a08b-1477d79e9d41", "metadata": { "ExecuteTime": { @@ -152,23 +151,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n", - "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n", - "选的表是: ['超市营收明细表']\n" + "选的表是: ['supper_market_info']\n" ] } ], "source": [ "import appbuilder\n", "from appbuilder.core.message import Message\n", - "from appbuilder.core.components.gbi.basic import GBISessionRecord\n", + "from appbuilder.core.components.gbi.basic import SessionRecord\n", "\n", "# 生成问表对象\n", - "select_table = appbuilder.GBISelectTable(model_name=\"ERNIE-Bot 4.0\", table_descriptions=table_descriptions)\n", + "select_table = appbuilder.SelectTable(model_name=\"ERNIE-Bot 4.0\", table_descriptions=table_descriptions)\n", "query = \"列出超市中的所有数据\"\n", - "msg = Message(query)\n", - "session = list()\n", - "select_table_result_message = select_table(message=msg, session=session)\n", + "msg = Message({\"query\": query})\n", + "select_table_result_message = select_table(msg)\n", "print(f\"选的表是: {select_table_result_message.content}\")" ] }, @@ -183,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 20, "id": "9f45ef5f-6206-4b31-83c4-3c8eb2c86925", "metadata": {}, "outputs": [ @@ -192,18 +188,18 @@ "output_type": "stream", "text": [ "sql: \n", - "SELECT * FROM `超市营收明细表`;\n", + "SELECT * FROM supper_market_info;\n", "-----------------\n", "llm result: ```sql\n", - "SELECT * FROM `超市营收明细表`;\n", + "SELECT * FROM supper_market_info;\n", "```\n" ] } ], "source": [ "table_schemas = [SCHEMA_MAPPING[table_name] for table_name in select_table_result_message.content]\n", - "gbi_nl2sql = appbuilder.GBINL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas)\n", - "nl2sql_result_message = gbi_nl2sql(message=msg, session=session)\n", + "gbi_nl2sql = appbuilder.NL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas)\n", + "nl2sql_result_message = gbi_nl2sql(Message({\"query\": \"列出超市中的所有数据\"}))\n", "print(f\"sql: {nl2sql_result_message.content.sql}\")\n", "print(\"-----------------\")\n", "print(f\"llm result: {nl2sql_result_message.content.llm_result}\")" @@ -219,12 +215,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 21, "id": "a23b8cad-f426-4074-9311-c2c33aaea07b", "metadata": {}, "outputs": [], "source": [ - "session.append(GBISessionRecord(query=query, answer=nl2sql_result_message.content))" + "session = list()\n", + "session.append(SessionRecord(query=query, answer=nl2sql_result_message.content))" ] }, { @@ -237,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 22, "id": "2adcb091-fb53-4364-b4d8-20564439ff51", "metadata": {}, "outputs": [ @@ -246,18 +243,17 @@ "output_type": "stream", "text": [ "sql: \n", - "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'\n", + "SELECT * FROM supper_market_info WHERE 商品类别 = '水果';\n", "-----------------\n", "llm result: ```sql\n", - "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'\n", + "SELECT * FROM supper_market_info WHERE 商品类别 = '水果';\n", "```\n" ] } ], "source": [ - "query2 = \"查看商品类别是水果的所有数据\"\n", - "msg2 = Message(query2)\n", - "nl2sql_result_message2 = gbi_nl2sql(message=msg2, session=session)\n", + "nl2sql_result_message2 = gbi_nl2sql(Message({\"query\": \"查看商品类别是水果的所有数据\", \n", + " \"session\": session}))\n", "print(f\"sql: {nl2sql_result_message2.content.sql}\")\n", "print(\"-----------------\")\n", "print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")" @@ -274,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 23, "id": "2a7c7923-019e-4660-9e36-4431e9d2f3a6", "metadata": {}, "outputs": [ @@ -283,10 +279,10 @@ "output_type": "stream", "text": [ "sql: \n", - "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '新鲜水果'\n", + "SELECT * FROM supper_market_info WHERE 商品类别='新鲜水果'\n", "-----------------\n", "llm result: ```sql\n", - "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '新鲜水果'\n", + "SELECT * FROM supper_market_info WHERE 商品类别='新鲜水果'\n", "```\n" ] } @@ -294,15 +290,16 @@ "source": [ "from appbuilder.core.components.gbi.basic import ColumnItem\n", "\n", - "query2 = \"查看商品类别是水果的所有数据\"\n", - "msg2 = Message(query2)\n", "\n", "column_constraint = [ColumnItem(ori_value=\"水果\", \n", " column_name=\"商品类别\", \n", " column_value=\"新鲜水果\", \n", " table_name=\"超市营收明细表\", \n", " is_like=False)]\n", - "nl2sql_result_message2 = gbi_nl2sql(message=msg2, session=session, column_constraint=column_constraint)\n", + "\n", + "nl2sql_result_message2 = gbi_nl2sql(Message({\"query\": \"查看商品类别是水果的所有数据\",\n", + " \"column_constraint\": column_constraint}))\n", + "\n", "print(f\"sql: {nl2sql_result_message2.content.sql}\")\n", "print(\"-----------------\")\n", "print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")" @@ -329,7 +326,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 24, "id": "cade4693-29dc-431c-bf84-c6dc09104294", "metadata": {}, "outputs": [], @@ -340,7 +337,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 25, "id": "1dc181e8-47a1-4b82-8bb5-ce3339be53f6", "metadata": {}, "outputs": [ @@ -350,33 +347,21 @@ "text": [ "sql: \n", "SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n", - "FROM `超市营收明细表`\n", + "FROM supper_market_info\n", "WHERE 商品类别 = '日用品'\n", "GROUP BY 商品类别\n", "-----------------\n", "llm result: ```sql\n", "SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n", - "FROM `超市营收明细表`\n", + "FROM supper_market_info\n", "WHERE 商品类别 = '日用品'\n", "GROUP BY 商品类别\n", - "```\n", - "\n", - "思考步骤:\n", - "\n", - "1. 首先,我们需要从`超市营收明细表`中选择数据。\n", - "2. 根据当前问题,我们关心的是商品类别为“日用品”的数据。\n", - "3. 利润率是利润除以销售额,所以我们需要对利润和销售额进行聚合。\n", - "4. 使用`SUM`函数来计算总的利润和销售额。\n", - "5. 使用`GROUP BY`语句按商品类别进行分组,以便计算每个商品类别的利润率。\n", - "6. 最后,使用`AS`关键字给计算出的利润率命名,使其更易读。\n" + "```\n" ] } ], "source": [ - "query3 = \"列出商品类别是日用品的利润率\"\n", - "msg3 = Message(query3)\n", - "\n", - "nl2sql_result_message3 = gbi_nl2sql(message=msg3, session=session, column_constraint=list())\n", + "nl2sql_result_message3 = gbi_nl2sql(Message({\"query\": \"列出商品类别是日用品的利润率\"}))\n", "print(f\"sql: {nl2sql_result_message3.content.sql}\")\n", "print(\"-----------------\")\n", "print(f\"llm result: {nl2sql_result_message3.content.llm_result}\")" @@ -405,7 +390,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 26, "id": "2ae6ffbc-4237-4fb2-8168-480b81bfd873", "metadata": {}, "outputs": [], @@ -426,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 27, "id": "2bbbb375-6659-4ef0-82ff-a4ace9fdd4f0", "metadata": {}, "outputs": [ @@ -434,17 +419,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "选的表是: ['超市营收明细表']\n" + "选的表是: ['supper_market_info']\n" ] } ], "source": [ - "select_table4 = appbuilder.GBISelectTable(model_name=\"ERNIE-Bot 4.0\", \n", + "select_table4 = appbuilder.SelectTable(model_name=\"ERNIE-Bot 4.0\", \n", " table_descriptions=table_descriptions,\n", " prompt_template=SELECT_TABLE_PROMPT_TEMPLATE)\n", - "query4 = \"列出超市中的所有数据\"\n", - "msg4 = Message(query4)\n", - "select_table_result_message4 = select_table4(message=msg4, session=list())\n", + "\n", + "select_table_result_message4 = select_table4(Message({\"query\":\"列出超市中的所有数据\"}))\n", "print(f\"选的表是: {select_table_result_message4.content}\")" ] }, @@ -467,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 28, "id": "323fbe75-62ca-44ab-9ca2-9f747939a2b5", "metadata": {}, "outputs": [], @@ -490,7 +474,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 29, "id": "52436f03-e01c-456a-aaa0-5a7f1afcd9d2", "metadata": {}, "outputs": [ @@ -499,31 +483,22 @@ "output_type": "stream", "text": [ "sql: \n", - "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'\n", + "SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n", "-----------------\n", "llm result: ```sql\n", - "SELECT * FROM `超市营收明细表` WHERE `商品类别` = '水果'\n", + "SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n", "```\n" ] } ], "source": [ "\n", - "msg5 = Message(\"查看商品类别是水果的所有数据\")\n", - "gbi_nl2sql5 = appbuilder.GBINL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE)\n", - "nl2sql_result_message5 = gbi_nl2sql5(message=msg5, session=session)\n", + "gbi_nl2sql5 = appbuilder.NL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE)\n", + "nl2sql_result_message5 = gbi_nl2sql5(Message({\"query\": \"查看商品类别是水果的所有数据\"}))\n", "print(f\"sql: {nl2sql_result_message5.content.sql}\")\n", "print(\"-----------------\")\n", "print(f\"llm result: {nl2sql_result_message5.content.llm_result}\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bbdd2f66-e4a8-4001-bc4c-6cd245141deb", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {