diff --git a/appbuilder/__init__.py b/appbuilder/__init__.py index 146f4b289..b451c6d20 100644 --- a/appbuilder/__init__.py +++ b/appbuilder/__init__.py @@ -67,6 +67,9 @@ def check_version(self): from .core.components.embeddings import Embedding from .core.components.matching import Matching +from .core.components.gbi.nl2sql.component import NL2Sql +from .core.components.gbi.select_table.component import SelectTable + from appbuilder.core.message import Message from appbuilder.core.agent import AgentBase from appbuilder.core.context import UserSession diff --git a/appbuilder/core/components/gbi/__init__.py b/appbuilder/core/components/gbi/__init__.py new file mode 100644 index 000000000..c33303636 --- /dev/null +++ b/appbuilder/core/components/gbi/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/appbuilder/core/components/gbi/basic.py b/appbuilder/core/components/gbi/basic.py new file mode 100644 index 000000000..28fa34b81 --- /dev/null +++ b/appbuilder/core/components/gbi/basic.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python 3 +# -*- coding: utf-8 -*- +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""GBI nl2sql component. +""" + +from pydantic import BaseModel, Field +from typing import Dict, List + + +class NL2SqlResult(BaseModel): + """ + gbi_nl2sql 返回的结果 + """ + + llm_result: str = Field(..., description="大模型返回的结果") + sql: str = Field(..., description="从大模型中抽取的 sql 语句") + +class SessionRecord(BaseModel): + """ + gbi session record + """ + query: str = Field(..., description="用户的问题") + answer: NL2SqlResult = Field(..., description="nl2sql 返回的结果") + +class ColumnItem(BaseModel): + """ + 列信息 + """ + ori_value: str = Field(..., description="query 中的 词语, 比如: 北京去年收入, " + "分词后: 北京, 去年, 收入, ori_value 是分词中某一个,比如: ori_value = 北京") + column_name: str = Field(..., description="对应数据库中的列名称, 比如: city") + column_value: str = Field(..., description="对应数据库中的列值, 比如: 北京市") + + table_name: str = Field(..., description="该列所在表的名字") + is_like: bool = Field(default=False, description="与 ori_value 的匹配是包含 还是 等于,包含: True; 等于: False") + + +SUPPORTED_MODEL_NAME = { + "ERNIE-Bot 4.0", "ERNIE-Bot-8K", "ERNIE-Bot", "ERNIE-Bot-turbo", "EB-turbo-AppBuilder" +} diff --git a/appbuilder/core/components/gbi/nl2sql/README.md b/appbuilder/core/components/gbi/nl2sql/README.md new file mode 100644 index 000000000..3de0f994c --- /dev/null +++ b/appbuilder/core/components/gbi/nl2sql/README.md @@ -0,0 +1,264 @@ +# GBI 问表 + +## 简介 +GBI 问表,根据提供的 mysql 表的 schema 信息,生成对应问题的 sql 语句。 + +## 基本用法 +这里是一个示例,展示如何基于 mysql 表的 schema, 根据问题生成 sql 语句。 + + +````python +import logging +import os +import appbuilder +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import SessionRecord + +# 设置环境变量 +os.environ["APPBUILDER_TOKEN"] = "***" + +SUPER_MARKET_SCHEMA = """ +``` +CREATE TABLE `supper_market_info` ( + `订单编号` varchar(32) DEFAULT NULL, + `订单日期` date DEFAULT NULL, + `邮寄方式` varchar(32) DEFAULT NULL, + `地区` varchar(32) DEFAULT NULL, + `省份` varchar(32) DEFAULT NULL, + `客户类型` varchar(32) DEFAULT NULL, + `客户名称` varchar(32) DEFAULT NULL, + `商品类别` varchar(32) DEFAULT NULL, + `制造商` varchar(32) DEFAULT NULL, + `商品名称` varchar(32) DEFAULT NULL, + `数量` int(11) DEFAULT NULL, + `销售额` int(11) DEFAULT NULL, + `利润` int(11) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +``` +""" + +table_schemas = [SUPER_MARKET_SCHEMA] +gbi_nl2sql = appbuilder.NL2Sql(model_name="ERNIE-Bot 4.0", table_schemas=table_schemas) +query = "列出超市中的所有数据" +nl2sql_result_message = gbi_nl2sql(Message({"query": query})) + +print(f"sql: {nl2sql_result_message.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message.content.llm_result}") +```` + + sql: + SELECT * FROM supper_market_info; + ----------------- + llm result: ```sql + SELECT * FROM supper_market_info; + ``` + + +## 参数说明 + +### 初始化参数 +- model_name: 支持的模型名字 ERNIE-Bot 4.0, ERNIE-Bot-8K, ERNIE-Bot, ERNIE-Bot-turbo, EB-turbo-AppBuilder" +- table_schemas: 表的 schema,例如: + +``` +CREATE TABLE `supper_market_info` ( + `订单编号` varchar(32) DEFAULT NULL, + `订单日期` date DEFAULT NULL, + `邮寄方式` varchar(32) DEFAULT NULL, + `地区` varchar(32) DEFAULT NULL, + `省份` varchar(32) DEFAULT NULL, + `客户类型` varchar(32) DEFAULT NULL, + `客户名称` varchar(32) DEFAULT NULL, + `商品类别` varchar(32) DEFAULT NULL, + `制造商` varchar(32) DEFAULT NULL, + `商品名称` varchar(32) DEFAULT NULL, + `数量` int(11) DEFAULT NULL, + `销售额` int(11) DEFAULT NULL, + `利润` int(11) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +``` + +- knowledge: 用于提供一些知识, 比如 {"毛利率": "毛收入-毛成本/毛成本"} +- prompt_template: prompt 模版, 必须包含的格式如下: + ***你的描述 + {schema} + ***你的描述 + {column_instrument} + ***你的描述 + {kg} + ***你的描述 + 当前时间:{date} + ***你的描述 + {history_instrument} + ***你的描述 + 当前问题:{query} + 回答: + +### 调用参数 +- message: message.content 是 字典,包含: query, session, column_constraint 三个key + * query: 用户的问题 + * session: gbi session 的历史 列表, 参考 GBISessionRecord + * column_constraint: 列选约束 参考 ColumnItem 具体定义 + +#### GBISessionRecord 初始化参数 +- query: 用户的问题 +- answer: gbi_nl2sql 返回的结果 NL2SqlResult + +#### ColumnItem 初始化参数如下 +- ori_value: query 中的 词语, 比如: "北京去年收入", 分词后: "北京, 去年, 收入", ori_value 是分词中某一个,比如: ori_value = "北京" +- column_name: 对应数据库中的列名称, city +- column_value: 对应数据库中的列值, 北京市 +- table_name: 该列所属的表名称 +- is_like: 与 ori_value 的匹配是包含 还是 等于,包含: True; 等于: False + +### 返回值 +- NL2SqlResult 的 message + +#### NL2SqlResult 初始化参数如下 +- llm_result: 大模型返回的结果 +- sql: 从 llm_result 中抽取的 sql 语句 + +## 高级用法 +### 设置 session + + +```python +session = list() +session.append(SessionRecord(query=query, answer=nl2sql_result_message.content)) +``` + +再次问表 + + +```python +nl2sql_result_message2 = gbi_nl2sql(Message({"query": "查看商品类别是水果的所有数据", + "session": session})) +print(f"sql: {nl2sql_result_message2.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message2.content.llm_result}") +``` + + sql: + SELECT * FROM supper_market_info WHERE 商品类别 = '水果'; + ----------------- + llm result: ```sql + SELECT * FROM supper_market_info WHERE 商品类别 = '水果'; + ``` + + +### 增加列选优化 +实际上数据中 "商品类别" 存储的是 "新鲜水果", 那么就可以通过列选的限制来优化 sql. + + +```python +from appbuilder.core.components.gbi.basic import ColumnItem + +column_constraint = [ColumnItem(ori_value="水果", + column_name="商品类别", + column_value="新鲜水果", + table_name="超市营收明细表", + is_like=False)] +nl2sql_result_message2 = gbi_nl2sql(Message({"query": "查看商品类别是水果的所有数据", + "column_constraint": column_constraint})) + +print(f"sql: {nl2sql_result_message2.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message2.content.llm_result}") +``` + + sql: + SELECT * FROM supper_market_info WHERE 商品类别='新鲜水果' + ----------------- + llm result: ```sql + SELECT * FROM supper_market_info WHERE 商品类别='新鲜水果' + ``` + + +从上面我们看到,商品类别不在是 "水果" 而是 修订为 "新鲜水果" + +### 增加知识优化 +当计算某些特殊知识的时候,大模型是不知道的,所以需要告诉大模型具体的知识,比如: +利润率的计算方式: 利润/销售额 +可以将该知识注入。具体示例如下: + + +```python +# 注入知识 +gbi_nl2sql.knowledge["利润率"] = "计算方式: 利润/销售额" +``` + + +```python +query3 = "列出商品类别是日用品的利润率" +msg3 = Message(query3) + +nl2sql_result_message3 = gbi_nl2sql(Message({"query": "列出商品类别是日用品的利润率"})) +print(f"sql: {nl2sql_result_message3.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message3.content.llm_result}") +``` + + sql: + SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率 + FROM supper_market_info + WHERE 商品类别 = '日用品' + GROUP BY 商品类别 + ----------------- + llm result: ```sql + SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率 + FROM supper_market_info + WHERE 商品类别 = '日用品' + GROUP BY 商品类别 + ``` + + +## 调整 prompt 模版 +有时候,我们希望定义自己的prompt, 但是必须遵循对应的 prompt 模版的格式。 + + +问表的 prompt template 必须包含: +1. {schema} - 表的 schema 信息 +2. {instrument} - 列选限制的信息 +3. {kg} - 知识 +4. {date} - 时间 +5. {history_prompt} - 历史 +6. {query} - 当前问题 + +参考下面的示例 + + +```python +NL2SQL_PROMPT_TEMPLATE = """ + MySql 表 Schema 如下: + {schema} + 请根据用户当前问题,联系历史信息,仅编写1个sql,其中 sql 语句需要使用```sql ```这种 markdown 形式给出。 + 请参考列选信息: + {instrument} + 请参考知识: + {kg} + 当前时间:{date} + 历史信息如下: + {history_prompt} + 当前问题:"{query}" + 回答: +""" +``` + + +```python +gbi_nl2sql5 = appbuilder.NL2Sql(model_name="ERNIE-Bot 4.0", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE) +nl2sql_result_message5 = gbi_nl2sql5(Message({"query": "查看商品类别是水果的所有数据"})) + +print(f"sql: {nl2sql_result_message5.content.sql}") +print("-----------------") +print(f"llm result: {nl2sql_result_message5.content.llm_result}") +``` + + sql: + SELECT * FROM supper_market_info WHERE 商品类别 = '水果' + ----------------- + llm result: ```sql + SELECT * FROM supper_market_info WHERE 商品类别 = '水果' + ``` + diff --git a/appbuilder/core/components/gbi/nl2sql/__init__.py b/appbuilder/core/components/gbi/nl2sql/__init__.py new file mode 100644 index 000000000..c33303636 --- /dev/null +++ b/appbuilder/core/components/gbi/nl2sql/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/appbuilder/core/components/gbi/nl2sql/component.py b/appbuilder/core/components/gbi/nl2sql/component.py new file mode 100644 index 000000000..940119f75 --- /dev/null +++ b/appbuilder/core/components/gbi/nl2sql/component.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""GBI nl2sql component. +""" +import uuid +import json +from typing import Dict, List, Optional +from pydantic import BaseModel, Field + +from appbuilder.core.component import Component, ComponentArguments +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import SessionRecord +from appbuilder.core.components.gbi.basic import ColumnItem +from appbuilder.core.components.gbi.basic import NL2SqlResult +from appbuilder.core.components.gbi.basic import SUPPORTED_MODEL_NAME + + + +class NL2SqlArgs(ComponentArguments): + """ + nl2sql 的参数 + """ + query: str = Field(..., description="用户的 query 输入") + session: List[SessionRecord] = Field(default=list(), description="gbi session 的历史 列表") + column_constraint: List[ColumnItem] = Field(default=list(), description="列选的限制条件") + + +class NL2Sql(Component): + """ + gib nl2sql + """ + meta = NL2SqlArgs + + def __init__(self, model_name: str, table_schemas: List[str], knowledge: Dict = None, + prompt_template: str = ""): + """ + 创建 gbi nl2sql 对象 + Args: + model_name: 支持的模型名字 ERNIE-Bot 4.0, ERNIE-Bot-8K, ERNIE-Bot, ERNIE-Bot-turbo, EB-turbo-AppBuilder + table_schemas: 表的 schema 列表,例如: ``` + CREATE TABLE `mytable` ( + `d_year` COMMENT '年度,2019,2020..2022..', + `industry` COMMENT '行业', + `project_name` COMMENT '项目名称', + `customer_name` COMMENT '客户名称') + ```" + knowledge: 用于提供一些知识, 比如 {"毛利率": "毛收入-毛成本/毛成本"} + prompt_template: prompt 模版, 必须包含的格式如下: + ***你的描述 + {schema} + ***你的描述 + {column_instrument} + ***你的描述 + {kg} + ***你的描述 + 当前时间:{date} + ***你的描述 + {history_instrument} + ***你的描述 + 当前问题:{query} + 回答: + """ + super().__init__(meta=NL2SqlArgs) + + if model_name not in SUPPORTED_MODEL_NAME: + raise ValueError(f"model_name 错误, 请使用 {SUPPORTED_MODEL_NAME} 中的大模型") + self.model_name = model_name + self.server_sub_path = "/v1/ai_engine/gbi/v1/gbi_nl2sql" + self.table_schemas = table_schemas + self.knowledge = knowledge or dict() + self.prompt_template = prompt_template + + def run(self, + message: Message, timeout: float = 60, retry: int = 0) -> Message[NL2SqlResult]: + """ + 执行 nl2sql + Args: + message: message.content 是字典包含, key 如下: + 1. query: 用户问题 + 2. session: gbi session 的历史 列表, 参考 SessionRecord + 3. column_constraint: 列选约束 参考 ColumnItem 具体定义 + Returns: + NL2SqlResult 的 message + """ + + + try: + inputs = self.meta(**message.content) + except ValidationError as e: + raise ValueError(e) + + response = self._run_nl2sql(query=inputs.query, session=inputs.session, table_schemas=self.table_schemas, + column_constraint=inputs.column_constraint, knowledge=self.knowledge, + prompt_template=self.prompt_template, + model_name=self.model_name, + timeout=timeout, + retry=retry) + + rsp_data = response.json() + nl2sql_result = NL2SqlResult(llm_result=rsp_data["llm_result"], + sql=rsp_data["sql"]) + return Message(content=nl2sql_result) + + def _run_nl2sql(self, query: str, session: List[SessionRecord], table_schemas: List[str], knowledge: Dict[str, str], + prompt_template: str, + column_constraint: List[ColumnItem], + model_name: str, + timeout: float = None, retry: int = 0): + """ + 运行 + Args: + query: query + session: gbi session 的历史 列表 + table_schemas: 表的 schema 列表 + knowledge: 知识 + prompt_template: prompt 模版 + column_constraint: 列的限制 + model_name: 模型名字 + timeout: 超时时间 + retry: + + Returns: + + """ + + headers = self.auth_header() + headers["Content-Type"] = "application/json" + + if retry != self.retry.total: + self.retry.total = retry + + payload = {"query": query, + "table_schemas": table_schemas, + "session": [session_record.dict() for session_record in session], + "column_constraint": [column_item.dict() for column_item in column_constraint], + "model_name": model_name, + "knowledge": knowledge, + "prompt_template": prompt_template} + + server_url = self.service_url(prefix="", sub_path=self.server_sub_path) + response = self.s.post(url=server_url, headers=headers, + json=payload, timeout=timeout) + super().check_response_header(response) + data = response.json() + super().check_response_json(data) + + request_id = self.response_request_id(response) + response.request_id = request_id + return response diff --git a/appbuilder/core/components/gbi/select_table/README.md b/appbuilder/core/components/gbi/select_table/README.md new file mode 100644 index 000000000..391fe8eeb --- /dev/null +++ b/appbuilder/core/components/gbi/select_table/README.md @@ -0,0 +1,118 @@ +# GBI 选表 + +## 简介 +GBI 选表,根据提供的多个 MySql 表名 以及 表名对应的描述信息,通过 query 选择一个或多个最合适的表来回答该 query. +一般的试用场景是,当有数据库有多个表的时候,但是实际只有1个表能回答该 query,那么,通过该能力将该表选择出来,用于后面的 问表 环节。 + +## 基本用法 +下面是根据提供的表的描述信息以及 query 选择对应的表的示例。 + + +```python +import logging +import os +import appbuilder +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import SessionRecord + +# 设置环境变量 +os.environ["APPBUILDER_TOKEN"] = "***" + +# 表的描述信息, key: 表名; value: 是表的描述 +table_descriptions = { + "supper_market_info": "超市营收明细表,包含超市各种信息等", + "product_sales_info": "产品销售表" +} + + +# 生成问表对象 +select_table = appbuilder.SelectTable(model_name="ERNIE-Bot 4.0", table_descriptions=table_descriptions) +select_table_result_message = select_table(Message({"query": "列出超市中的所有数据"})) + +print(f"选的表是: {select_table_result_message.content}") +``` + + 选的表是: ['supper_market_info'] + + +## 参数说明 +### 初始化参数 +- model_name: 支持的模型名字 ERNIE-Bot 4.0, ERNIE-Bot-8K, ERNIE-Bot, ERNIE-Bot-turbo, EB-turbo-AppBuilder +- table_descriptions: 表的描述是个字典,key: 是表的名字, value: 是表的描述,例如: + +``` +{ + "supper_market_info": "超市营收明细表,包含超市各种信息等", + "product_sales_info": "产品销售表" +} +``` +- prompt_template: prompt 模版, 必须包含如下: + 1. {num} - 表的数量, 注意 {num} 有两个地方出现 + 2. {table_desc} - 表的描述 + 3. {query} - query + 参考下面的示例: + +``` +你是一个专业的业务人员,下面有{num}张表,具体表名如下: +{table_desc} +请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表, +返回多张表请用“,”隔开 +返回格式请参考如下示例: +问题:有多少个审核通过的投运单? +回答: ```DWD_MAT_OPERATION``` +请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出 +问题:{query} +回答: +``` + +### 调用参数 +- message: message.content 是用户的问题,包含的key: query, session + * query: 用户提出的问题 + * session: GBISessionRecord 列表 + +#### GBISessionRecord 初始化参数 +- query: 用户的问题 +- answer: gbi_nl2sql 返回的结果 NL2SqlResult + +### 返回值 +识别的表名的列表例如: +`["table_name"]` + +## 调整 prompt 模版 +有时候,我们希望定义自己的prompt, 选表支持 prompt 模版的定制化,但是必须遵循对应的 prompt 模版的格式。 + +### 选表 prompt 调整 +选表的 prompt template, 必须包含 +1. {num} - 表的数量, 注意 {num} 有两个地方出现 +2. {table_desc} - 表的描述 +3. {query} - query, 参考下面的示例: + + +```python +SELECT_TABLE_PROMPT_TEMPLATE = """ +你是一个专业的业务人员,下面有{num}张表,具体表名如下: +{table_desc} +请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表, +返回多张表请用“,”隔开 +返回格式请参考如下示例: +问题:有多少个审核通过的投运单? +回答: ```DWD_MAT_OPERATION``` +请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出 +问题:{query} +回答: +""" +``` + + +```python +select_table4 = appbuilder.SelectTable(model_name="ERNIE-Bot 4.0", + table_descriptions=table_descriptions, + prompt_template=SELECT_TABLE_PROMPT_TEMPLATE) + +select_table_result_message4 = select_table4(Message({"query": "列出超市中的所有数据"})) + +print(f"选的表是: {select_table_result_message4.content}") +``` + + 选的表是: ['supper_market_info'] + diff --git a/appbuilder/core/components/gbi/select_table/__init__.py b/appbuilder/core/components/gbi/select_table/__init__.py new file mode 100644 index 000000000..c33303636 --- /dev/null +++ b/appbuilder/core/components/gbi/select_table/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/appbuilder/core/components/gbi/select_table/component.py b/appbuilder/core/components/gbi/select_table/component.py new file mode 100644 index 000000000..d66546da5 --- /dev/null +++ b/appbuilder/core/components/gbi/select_table/component.py @@ -0,0 +1,146 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""GBI nl2sql component. +""" +import uuid +import json +from typing import Dict, List, Optional +from pydantic import BaseModel, Field + +from appbuilder.core.component import Component, ComponentArguments +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import SessionRecord +from appbuilder.core.components.gbi.basic import SUPPORTED_MODEL_NAME + +class SelectTableArgs(ComponentArguments): + """ + 选表的参数 + """ + query: str = Field(..., description="用户的 query 输入") + session: List[SessionRecord] = Field(default=list(), description="gbi session 的历史 列表") + + +class SelectTable(Component): + """ + gbi 选表 + """ + + def __init__(self, model_name: str, table_descriptions: Dict[str, str], + prompt_template: str = ""): + """ + 创建 GBI 选表对象 + Args: + model_name: 支持的模型名字 ERNIE-Bot 4.0, ERNIE-Bot-8K, ERNIE-Bot, ERNIE-Bot-turbo, EB-turbo-AppBuilder + table_descriptions: 表的描述是个字典,key: 是表的名字, value: 是表的描述,例如: + { + "超市营收明细表": "超市营收明细表,包含超市各种信息等", + "product_sales_info": "产品销售表" + } + prompt_template: rompt 模版, 必须包含如下: + 1. {num} - 表的数量, 注意 {num} 有两个地方出现 + 2. {table_desc} - 表的描述 + 3. {query} - query + 参考下面的示例: + + ``` + 你是一个专业的业务人员,下面有{num}张表,具体表名如下: + {table_desc} + 请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表, + 返回多张表请用“,”隔开 + 返回格式请参考如下示例: + 问题:有多少个审核通过的投运单? + 回答: ```DWD_MAT_OPERATION``` + 请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出 + 问题:{query} + 回答: + ``` + secret_key: + gateway: + """ + super().__init__(meta=SelectTableArgs) + if model_name not in SUPPORTED_MODEL_NAME: + raise ValueError(f"model_name 错误, 请使用 {SUPPORTED_MODEL_NAME} 中的大模型") + self.model_name = model_name + self.server_sub_path = "/v1/ai_engine/gbi/v1/gbi_select_table" + self.table_descriptions = table_descriptions + self.prompt_template = prompt_template + + def run(self, + message: Message, timeout: int = 60,retry: int = 0) -> Message[List[str]]: + """ + Args: + message: message.content 字典包含 key: + 1. query - 用户的问题输入 + 2. session - 对话历史, 可选 + + Returns: 识别的表名的列表 ["table_name"] + """ + + try: + inputs = self.meta(**message.content) + except ValidationError as e: + raise ValueError(e) + + response = self._run_select_table(query=inputs.query, session=inputs.session, + prompt_template=self.prompt_template, + table_descriptions=self.table_descriptions, + model_name=self.model_name, + timeout=timeout, + retry=retry) + + rsp_data = response.json() + + return Message(content=rsp_data) + + def _run_select_table(self, query: str, session: List[SessionRecord], + prompt_template, + table_descriptions: Dict[str, str], + model_name: str, + timeout: float = None, retry: int = 0): + """ + 使用给定的输入并返回语音识别的结果。 + + 参数: + request (obj:`ShortSpeechRecognitionRequest`): 输入请求,这是一个必需的参数。 + timeout (float, 可选): 请求的超时时间。 + retry (int, 可选): 请求的重试次数。 + + 返回: + obj:`ShortSpeechRecognitionResponse`: 接口返回的输出消息。 + """ + + headers = self.auth_header() + headers["Content_Type"] = "application/json" + + if retry != self.retry.total: + self.retry.total = retry + + payload = {"query": query, + "table_descriptions": table_descriptions, + "session": [session_record.to_json() for session_record in session], + "model_name": model_name, + "prompt_template": prompt_template} + + server_url = self.service_url(sub_path=self.server_sub_path) + response = self.s.post(url=server_url, headers=headers, + json=payload, timeout=timeout) + super().check_response_header(response) + data = response.json() + super().check_response_json(data) + + request_id = self.response_request_id(response) + response.request_id = request_id + return response + diff --git a/appbuilder/tests/test_gbi_nl2sql.py b/appbuilder/tests/test_gbi_nl2sql.py new file mode 100644 index 000000000..5617f1578 --- /dev/null +++ b/appbuilder/tests/test_gbi_nl2sql.py @@ -0,0 +1,163 @@ +""" +Copyright (c) 2023 Baidu, Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import os +import appbuilder +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import NL2SqlResult, SessionRecord +from appbuilder.core.components.gbi.basic import ColumnItem + +SUPER_MARKET_SCHEMA = """ +``` +CREATE TABLE `supper_market_info` ( + `订单编号` varchar(32) DEFAULT NULL, + `订单日期` date DEFAULT NULL, + `邮寄方式` varchar(32) DEFAULT NULL, + `地区` varchar(32) DEFAULT NULL, + `省份` varchar(32) DEFAULT NULL, + `客户类型` varchar(32) DEFAULT NULL, + `客户名称` varchar(32) DEFAULT NULL, + `商品类别` varchar(32) DEFAULT NULL, + `制造商` varchar(32) DEFAULT NULL, + `商品名称` varchar(32) DEFAULT NULL, + `数量` int(11) DEFAULT NULL, + `销售额` int(11) DEFAULT NULL, + `利润` int(11) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +``` +""" + +PRODUCT_SALES_INFO = """ +现有 mysql 表 product_sales_info, +该表的用途是: 产品收入表 +``` +CREATE TABLE `product_sales_info` ( + `年` int, + `月` int, + `产品名称` varchar, + `收入` decimal, + `非交付成本` decimal, + `含交付毛利` decimal +) +``` +""" + +PROMPT_TEMPLATE = """ + MySql 表 Schema 如下: + {schema} + 请根据用户当前问题,联系历史信息,仅编写1个sql,其中 sql 语句需要使用```sql ```这种 markdown 形式给出。 + 请参考列选信息: + {instrument} + 请参考知识: + {kg} + 当前时间:{date} + 历史信息如下:{history_prompt} + 当前问题:"{query}" + 回答: +""" + + +class TestGBINL2Sql(unittest.TestCase): + + def setUp(self): + """ + 设置环境变量及必要数据。 + """ + model_name = "ERNIE-Bot 4.0" + table_schemas = [SUPER_MARKET_SCHEMA] + self.nl2sql_node = appbuilder.NL2Sql(model_name=model_name, + table_schemas=table_schemas) + + def test_run_with_default_param(self): + """测试 run 方法使用有效参数""" + query = "列出商品类别是水果的所有信息" + msg = Message({"query": query}) + result_message = self.nl2sql_node(msg) + print(result_message.content.sql) + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + + def test_run_with_knowledge(self): + """测试 增加 knowledge 参数""" + + self.nl2sql_node.knowledge["利润率"] = "计算方式: 利润/销售额" + query = "列出商品类别是水果的的利润率" + + msg = Message({"query": query}) + result_message = self.nl2sql_node(msg) + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + self.nl2sql_node.knowledge = dict() + + def test_run_with_column_constraint(self): + """测试 增加 column constraint 参数""" + + query = "列出商品类别是水果的的利润率" + column_constraint = [ColumnItem(ori_value="水果", + column_value="新鲜水果", + column_name="商品类别", + table_name="超市营收明细", + is_like=False)] + + msg = Message({"query": query, "column_constraint": column_constraint}) + + result_message = self.nl2sql_node(msg) + + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + self.assertIn("新鲜水果", result_message.content.sql) + + def test_run_with_prompt_template(self): + """测试 增加 prompt template 参数""" + self.nl2sql_node.prompt_template = PROMPT_TEMPLATE + query = "列出商品类别是水果的的利润率" + column_constraint = [ColumnItem(ori_value="水果", + column_value="新鲜水果", + column_name="商品类别", + table_name="超市营收明细", + is_like=False)] + msg = Message({"query": query, "column_constraint": column_constraint}) + result_message = self.nl2sql_node(msg) + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + # 恢复 prompt template + self.nl2sql_node.prompt_template = "" + + def test_run_with_session(self): + """测试 增加 session 参数""" + session = list() + session_record = SessionRecord(query="列出商品类别是水果的的利润率", + answer=NL2SqlResult( + llm_result="根据问题分析得到 sql 如下: \n " + "```sql\nSELECT * FROM `超市营收明细` " + "WHERE `商品类别` = '水果'\n```", + sql="SELECT * FROM `超市营收明细` WHERE `商品类别` = '水果'")) + session.append(session_record) + + query = "列出所有的商品类别" + msg = Message({"query": query, "session": session}) + result_message = self.nl2sql_node(msg) + self.assertIsNotNone(result_message) + self.assertTrue(result_message.content.sql != "") + self.assertTrue(result_message.content.llm_result != "") + + +if __name__ == '__main__': + unittest.main() diff --git a/appbuilder/tests/test_gbi_select_table.py b/appbuilder/tests/test_gbi_select_table.py new file mode 100644 index 000000000..324850583 --- /dev/null +++ b/appbuilder/tests/test_gbi_select_table.py @@ -0,0 +1,109 @@ +""" +Copyright (c) 2023 Baidu, Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import os +import appbuilder +from appbuilder.core.message import Message +from appbuilder.core.components.gbi.basic import SessionRecord + + +SUPER_MARKET_SCHEMA = """ +``` +CREATE TABLE `supper_market_info` ( + `订单编号` varchar(32) DEFAULT NULL, + `订单日期` date DEFAULT NULL, + `邮寄方式` varchar(32) DEFAULT NULL, + `地区` varchar(32) DEFAULT NULL, + `省份` varchar(32) DEFAULT NULL, + `客户类型` varchar(32) DEFAULT NULL, + `客户名称` varchar(32) DEFAULT NULL, + `商品类别` varchar(32) DEFAULT NULL, + `制造商` varchar(32) DEFAULT NULL, + `商品名称` varchar(32) DEFAULT NULL, + `数量` int(11) DEFAULT NULL, + `销售额` int(11) DEFAULT NULL, + `利润` int(11) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 +``` +""" + +PRODUCT_SALES_INFO = """ +现有 mysql 表 product_sales_info, +该表的用途是: 产品收入表 +``` +CREATE TABLE `product_sales_info` ( + `年` int, + `月` int, + `产品名称` varchar, + `收入` decimal, + `非交付成本` decimal, + `含交付毛利` decimal +) +``` +""" + +PROMPT_TEMPLATE = """ +你是一个专业的业务人员,下面有{num}张表,具体表名如下: +{table_desc} +请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表, +返回多张表请用“,”隔开 +返回格式请参考如下示例: +问题:有多少个审核通过的投运单? +回答: ```DWD_MAT_OPERATION``` +请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出 +问题:{query} +回答: +""" + +class TestGBISelectTable(unittest.TestCase): + + def setUp(self): + """ + 设置环境变量及必要数据。 + """ + model_name = "ERNIE-Bot 4.0" + + self.select_table_node = \ + appbuilder.SelectTable(model_name=model_name, + table_descriptions={"supper_market_info": "超市营收明细表,包含超市各种信息等", + "product_sales_info": "产品销售表"}) + + def test_run_with_default_param(self): + """测试 run 方法使用有效参数""" + query = "列出超市中的所有数据" + msg = Message({"query": query}) + result_message = self.select_table_node(message=msg) + print(result_message.content) + self.assertIsNotNone(result_message) + self.assertEqual(len(result_message.content), 1) + self.assertEqual(result_message.content[0], "supper_market_info") + + def test_run_with_prompt_template(self): + """测试 run 方法中 prompt template 模版""" + query = "列出超市中的所有数据" + msg = Message({"query": query}) + result_message = self.select_table_node(message=msg) + self.select_table_node.prompt_template = PROMPT_TEMPLATE + result_message = self.select_table_node(msg) + + self.assertIsNotNone(result_message) + self.assertEqual(len(result_message.content), 1) + self.assertEqual(result_message.content[0], "supper_market_info") + self.select_table_node.prompt_template = "" + + +if __name__ == '__main__': + unittest.main() diff --git a/cookbooks/gbi.ipynb b/cookbooks/gbi.ipynb new file mode 100644 index 000000000..f7b4809fb --- /dev/null +++ b/cookbooks/gbi.ipynb @@ -0,0 +1,525 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f802e64d-4eaa-445d-a48a-1042a91bc394", + "metadata": { + "tags": [] + }, + "source": [ + "# GBI\n", + "\n", + "## 目标\n", + "通过 GBI sdk 接口完成选表和问表的能力。\n", + "\n", + "## 准备工作\n", + "### 平台注册\n", + "1.先在appbuilder平台注册,获取token\n", + "2.安装appbuilder-sdk" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2939356f-61c2-42e9-9e0c-fc6729c193f6", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install appbuilder-sdk" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "4ccff03b-1567-4e8b-8e1f-9a5032690406", + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "\n", + "# 设置环境变量\n", + "os.environ[\"APPBUILDER_TOKEN\"] = \"***\"\n" + ] + }, + { + "cell_type": "markdown", + "id": "aeb2fa55-075f-48df-a9fb-8b40d9900684", + "metadata": {}, + "source": [ + "## 开发过程" + ] + }, + { + "cell_type": "markdown", + "id": "1c3c5cee", + "metadata": {}, + "source": [ + "### 设置表的 schema" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d7d6440c", + "metadata": {}, + "outputs": [], + "source": [ + "SUPER_MARKET_SCHEMA = \"\"\"\n", + "```\n", + "CREATE TABLE `supper_market_info` (\n", + " `订单编号` varchar(32) DEFAULT NULL,\n", + " `订单日期` date DEFAULT NULL,\n", + " `邮寄方式` varchar(32) DEFAULT NULL,\n", + " `地区` varchar(32) DEFAULT NULL,\n", + " `省份` varchar(32) DEFAULT NULL,\n", + " `客户类型` varchar(32) DEFAULT NULL,\n", + " `客户名称` varchar(32) DEFAULT NULL,\n", + " `商品类别` varchar(32) DEFAULT NULL,\n", + " `制造商` varchar(32) DEFAULT NULL,\n", + " `商品名称` varchar(32) DEFAULT NULL,\n", + " `数量` int(11) DEFAULT NULL,\n", + " `销售额` int(11) DEFAULT NULL,\n", + " `利润` int(11) DEFAULT NULL\n", + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n", + "```\n", + "\"\"\"\n", + "\n", + "PRODUCT_SALES_INFO = \"\"\"\n", + "现有 mysql 表 product_sales_info, \n", + "该表的用途是: 产品收入表\n", + "```\n", + "CREATE TABLE `product_sales_info` (\n", + " `年` int,\n", + " `月` int,\n", + " `产品名称` varchar,\n", + " `收入` decimal,\n", + " `非交付成本` decimal,\n", + " `含交付毛利` decimal\n", + ")\n", + "```\n", + "\"\"\"\n", + "\n", + "# schema 和表名的映射\n", + "SCHEMA_MAPPING = {\n", + " \"supper_market_info\": SUPER_MARKET_SCHEMA,\n", + " \"PRODUCT_SALES_INFO\": PRODUCT_SALES_INFO\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "463254a1", + "metadata": {}, + "source": [ + "设置表的描述用于选表" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "7fefcae1", + "metadata": {}, + "outputs": [], + "source": [ + "table_descriptions = {\n", + " \"supper_market_info\": \"超市营收明细表,包含超市各种信息等\",\n", + " \"product_sales_info\": \"产品销售表\"\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "a0aff843", + "metadata": {}, + "source": [ + "### 选表" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "41559341-fd7a-478c-a08b-1477d79e9d41", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-18T06:24:26.982459Z", + "start_time": "2023-12-18T06:23:53.771345Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "选的表是: ['supper_market_info']\n" + ] + } + ], + "source": [ + "import appbuilder\n", + "from appbuilder.core.message import Message\n", + "from appbuilder.core.components.gbi.basic import SessionRecord\n", + "\n", + "# 生成问表对象\n", + "select_table = appbuilder.SelectTable(model_name=\"ERNIE-Bot 4.0\", table_descriptions=table_descriptions)\n", + "query = \"列出超市中的所有数据\"\n", + "msg = Message({\"query\": query})\n", + "select_table_result_message = select_table(msg)\n", + "print(f\"选的表是: {select_table_result_message.content}\")" + ] + }, + { + "cell_type": "markdown", + "id": "16a8aa38-7a33-4e27-bca4-00900cfe1641", + "metadata": {}, + "source": [ + "### 问表\n", + "基于上面选出的表,通过获取 shema 进行问表" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "9f45ef5f-6206-4b31-83c4-3c8eb2c86925", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT * FROM supper_market_info;\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT * FROM supper_market_info;\n", + "```\n" + ] + } + ], + "source": [ + "table_schemas = [SCHEMA_MAPPING[table_name] for table_name in select_table_result_message.content]\n", + "gbi_nl2sql = appbuilder.NL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas)\n", + "nl2sql_result_message = gbi_nl2sql(Message({\"query\": \"列出超市中的所有数据\"}))\n", + "print(f\"sql: {nl2sql_result_message.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message.content.llm_result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0409c46-e8c7-403a-a827-fcdc8e717be6", + "metadata": {}, + "source": [ + "设置 session" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "a23b8cad-f426-4074-9311-c2c33aaea07b", + "metadata": {}, + "outputs": [], + "source": [ + "session = list()\n", + "session.append(SessionRecord(query=query, answer=nl2sql_result_message.content))" + ] + }, + { + "cell_type": "markdown", + "id": "22b3d877-f61f-4958-a084-7507a3017e17", + "metadata": {}, + "source": [ + "再次问表" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "2adcb091-fb53-4364-b4d8-20564439ff51", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT * FROM supper_market_info WHERE 商品类别 = '水果';\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT * FROM supper_market_info WHERE 商品类别 = '水果';\n", + "```\n" + ] + } + ], + "source": [ + "nl2sql_result_message2 = gbi_nl2sql(Message({\"query\": \"查看商品类别是水果的所有数据\", \n", + " \"session\": session}))\n", + "print(f\"sql: {nl2sql_result_message2.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9e0609ae-f2bc-43d3-9023-14e9f8618158", + "metadata": {}, + "source": [ + "### 增加列选优化\n", + "实际上数据中 \"商品类别\" 存储的是 \"新鲜水果\", 那么就可以通过列选的限制来优化 sql." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "2a7c7923-019e-4660-9e36-4431e9d2f3a6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT * FROM supper_market_info WHERE 商品类别='新鲜水果'\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT * FROM supper_market_info WHERE 商品类别='新鲜水果'\n", + "```\n" + ] + } + ], + "source": [ + "from appbuilder.core.components.gbi.basic import ColumnItem\n", + "\n", + "\n", + "column_constraint = [ColumnItem(ori_value=\"水果\", \n", + " column_name=\"商品类别\", \n", + " column_value=\"新鲜水果\", \n", + " table_name=\"超市营收明细表\", \n", + " is_like=False)]\n", + "\n", + "nl2sql_result_message2 = gbi_nl2sql(Message({\"query\": \"查看商品类别是水果的所有数据\",\n", + " \"column_constraint\": column_constraint}))\n", + "\n", + "print(f\"sql: {nl2sql_result_message2.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "8385312c-aea1-42cd-b61b-a8d36f4f0665", + "metadata": {}, + "source": [ + "从上面我们看到,商品类别不在是 \"水果\" 而是 修订为 \"新鲜水果\"" + ] + }, + { + "cell_type": "markdown", + "id": "6e98c414-8b2b-4187-a270-3117a4f431ff", + "metadata": {}, + "source": [ + "### 增加知识优化\n", + "当计算某些特殊知识的时候,大模型是不知道的,所以需要告诉大模型具体的知识,比如:\n", + "利润率的计算方式: 利润/销售额\n", + "可以将该知识注入。具体示例如下:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "cade4693-29dc-431c-bf84-c6dc09104294", + "metadata": {}, + "outputs": [], + "source": [ + "# 注入知识\n", + "gbi_nl2sql.knowledge[\"利润率\"] = \"计算方式: 利润/销售额\"" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "1dc181e8-47a1-4b82-8bb5-ce3339be53f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n", + "FROM supper_market_info\n", + "WHERE 商品类别 = '日用品'\n", + "GROUP BY 商品类别\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n", + "FROM supper_market_info\n", + "WHERE 商品类别 = '日用品'\n", + "GROUP BY 商品类别\n", + "```\n" + ] + } + ], + "source": [ + "nl2sql_result_message3 = gbi_nl2sql(Message({\"query\": \"列出商品类别是日用品的利润率\"}))\n", + "print(f\"sql: {nl2sql_result_message3.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message3.content.llm_result}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c5570cd9-dbaf-45cd-ab03-1a7f92e7d0d4", + "metadata": {}, + "source": [ + "## 调整 prompt 模版\n", + "有时候,我们希望定义自己的prompt, 选表和问表两个环节都支持 prompt 模版的定制化,但是必须遵循对应的 prompt 模版的格式。" + ] + }, + { + "cell_type": "markdown", + "id": "6e3d4967-2b4c-437d-9d72-fb1b94bdcf59", + "metadata": {}, + "source": [ + "### 选表 prompt 调整\n", + "选表的 prompt template, 必须包含 \n", + "1. {num} - 表的数量, 注意 {num} 有两个地方出现\n", + "2. {table_desc} - 表的描述\n", + "3. {query} - query, 参考下面的示例:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2ae6ffbc-4237-4fb2-8168-480b81bfd873", + "metadata": {}, + "outputs": [], + "source": [ + "SELECT_TABLE_PROMPT_TEMPLATE = \"\"\"\n", + "你是一个专业的业务人员,下面有{num}张表,具体表名如下:\n", + "{table_desc}\n", + "请根据问题帮我选择上述1-{num}种的其中相关表并返回,可以为多表,也可以为单表,\n", + "返回多张表请用“,”隔开\n", + "返回格式请参考如下示例:\n", + "问题:有多少个审核通过的投运单?\n", + "回答: ```DWD_MAT_OPERATION```\n", + "请严格参考示例只不要返回无关内容,直接给出最终答案后面的内容,分析步骤不要输出\n", + "问题:{query}\n", + "回答:\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "2bbbb375-6659-4ef0-82ff-a4ace9fdd4f0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "选的表是: ['supper_market_info']\n" + ] + } + ], + "source": [ + "select_table4 = appbuilder.SelectTable(model_name=\"ERNIE-Bot 4.0\", \n", + " table_descriptions=table_descriptions,\n", + " prompt_template=SELECT_TABLE_PROMPT_TEMPLATE)\n", + "\n", + "select_table_result_message4 = select_table4(Message({\"query\":\"列出超市中的所有数据\"}))\n", + "print(f\"选的表是: {select_table_result_message4.content}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4f3fd089-613b-4bdd-95ac-c87f89c0fc61", + "metadata": {}, + "source": [ + "## 问表 prompt 调整\n", + "问表的 prompt template 必须包含:\n", + "1. {schema} - 表的 schema 信息\n", + "2. {instrument} - 列选限制的信息\n", + "3. {kg} - 知识\n", + "4. {date} - 时间\n", + "5. {history_prompt} - 历史\n", + "6. {query} - 当前问题\n", + "\n", + "参考下面的示例" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "323fbe75-62ca-44ab-9ca2-9f747939a2b5", + "metadata": {}, + "outputs": [], + "source": [ + "NL2SQL_PROMPT_TEMPLATE = \"\"\"\n", + " MySql 表 Schema 如下:\n", + " {schema}\n", + " 请根据用户当前问题,联系历史信息,仅编写1个sql,其中 sql 语句需要使用```sql ```这种 markdown 形式给出。\n", + " 请参考列选信息:\n", + " {instrument}\n", + " 请参考知识:\n", + " {kg}\n", + " 当前时间:{date}\n", + " 历史信息如下:\n", + " {history_prompt}\n", + " 当前问题:\"{query}\"\n", + " 回答:\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "52436f03-e01c-456a-aaa0-5a7f1afcd9d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sql: \n", + "SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n", + "-----------------\n", + "llm result: ```sql\n", + "SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n", + "```\n" + ] + } + ], + "source": [ + "\n", + "gbi_nl2sql5 = appbuilder.NL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE)\n", + "nl2sql_result_message5 = gbi_nl2sql5(Message({\"query\": \"查看商品类别是水果的所有数据\"}))\n", + "print(f\"sql: {nl2sql_result_message5.content.sql}\")\n", + "print(\"-----------------\")\n", + "print(f\"llm result: {nl2sql_result_message5.content.llm_result}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}