Update vdb cookbook #189

Merged
merged 12 commits into from
Mar 20, 2024
8 changes: 8 additions & 0 deletions appbuilder/__init__.py
@@ -63,6 +63,10 @@ def check_version(self):
from .core.components.doc_splitter.doc_splitter import DocSplitter
from .core.components.retriever.bes.bes_retriever import BESRetriever
from .core.components.retriever.bes.bes_retriever import BESVectorStoreIndex
from .core.components.retriever.baidu_vdb.baiduvdb_retriever import BaiduVDBVectorStoreIndex
from .core.components.retriever.baidu_vdb.baiduvdb_retriever import BaiduVDBRetriever
from .core.components.retriever.baidu_vdb.baiduvdb_retriever import TableParams

from .core.components.dish_recognize.component import DishRecognition
from .core.components.translate.component import Translation
from .core.components.animal_recognize.component import AnimalRecognition
@@ -134,6 +138,10 @@ def check_version(self):
"DocSplitter",
"BESRetriever",
"BESVectorStoreIndex",
"BaiduVDBVectorStoreIndex",
"BaiduVDBRetriever",
"TableParams",

'DishRecognition',
'Translation',
'Message',
20 changes: 20 additions & 0 deletions appbuilder/core/components/retriever/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .bes import BESVectorStoreIndex
from .bes import BESRetriever

from .baidu_vdb import BaiduVDBVectorStoreIndex
from .baidu_vdb import BaiduVDBRetriever
from .baidu_vdb import TableParams
33 changes: 28 additions & 5 deletions appbuilder/core/components/retriever/baidu_vdb/README.md
@@ -19,14 +19,17 @@

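The following code example shows how to get started with BaiduVDBRetriever:

Notes:
- `you_vdb_instance_id` is the VectorDB instance ID. Replace it with your own instance ID, which you can find in the VectorDB console
- `your_api_key` is the account key you obtained for VectorDB. Replace it with the key of your own root account, which you can find in the VectorDB console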

```python
import os
import appbuilder

# Create a key on the Qianfan AppBuilder website. For the procedure, see: https://cloud.baidu.com/doc/AppBuilder/s/Olq6grrt6#1%E3%80%81%E5%88%9B%E5%BB%BA%E5%AF%86%E9%92%A5
os.environ["APPBUILDER_TOKEN"] = '...'

embedding = appbuilder.Embedding()
segments = appbuilder.Message(["文心一言大模型", "百度在线科技有限公司"])
# Initialize and build the index
vector_index = appbuilder.BaiduVDBVectorStoreIndex.from_params(
@@ -52,21 +55,41 @@ os.environ["APPBUILDER_TOKEN"] = "bce-YOURTOKEN"
```

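### Initialization parameters:
`BaiduVDBVectorStoreIndex()` instantiation parameters:
- instance_id (str, required): the Baidu vector database instance ID, obtained when the instance is created
- api_key (str, required): the password used to connect to the vector database, obtained when the instance is created
- account (str, optional): the user name used to connect to the vector database; defaults to root
- database_name (str, optional): the vector database name; defaults to AppBuilderDatabase
- table_params (TableParams, optional): VectorDB table parameters; see [VectorDB table params](https://cloud.baidu.com/doc/VDB/s/mlrsob0p6)
- embedding (Embedding, optional): an appbuilder.Embedding instance; if you already have an Embedding constructed, it can be used for incremental inserts, otherwise a new embedding is created by default

-------

- segments (Message[List[str]], required): the text segments to store
`BaiduVDBVectorStoreIndex().from_params()` constructor parameters:
- instance_id (str, required): the Baidu vector database instance ID, obtained when the instance is created
- api_key (str, required): the password used to connect to the vector database, obtained when the instance is created
- account (str, optional): the user name used to connect to the vector database; defaults to root
- embedding (obj, optional): the model used to convert text into vectors; defaults to Embedding
- database_name (str, optional): the vector database name; defaults to AppBuilderDatabase
- table_name (str, optional): the vector database table name; defaults to AppBuilderTable
- drop_exists (bool, optional): whether to clear the database's existing records; defaults to False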

-------
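
The documented `from_params()` arguments and their defaults can be modeled as a small standalone sketch. `FromParamsArgs` is a hypothetical name used for illustration only and is not part of appbuilder. It mirrors the defaults listed above (plus the `dimension` default of 384 mentioned in the method's docstring) and the type checks added in this PR, without connecting to VectorDB.

```python
from dataclasses import dataclass


@dataclass
class FromParamsArgs:
    """Hypothetical container mirroring the documented from_params() arguments."""

    instance_id: str
    api_key: str
    account: str = "root"
    database_name: str = "AppBuilderDatabase"
    table_name: str = "AppBuilderTable"
    drop_exists: bool = False
    dimension: int = 384  # default taken from the from_params() docstring

    def __post_init__(self) -> None:
        # String arguments must be str, as in the isinstance checks of the PR.
        for name in ("instance_id", "api_key", "account",
                     "database_name", "table_name"):
            value = getattr(self, name)
            if not isinstance(value, str):
                raise TypeError(
                    "{} must be a string, but got {}".format(name, type(value)))
        # drop_exists must be a real bool.
        if not isinstance(self.drop_exists, bool):
            raise TypeError(
                "drop_exists must be a boolean, but got {}".format(
                    type(self.drop_exists)))


# Only the two required arguments are given; the rest fall back to defaults.
args = FromParamsArgs(instance_id="your_instance_id", api_key="your_api_key")
```

Passing a non-string `instance_id` (for example an integer) raises `TypeError` before any connection would be attempted, which is the behavior the added checks aim for.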


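### Call parameters:

`BaiduVDBRetriever().run()` parameters:

| Parameter | Type | Required | Description | Example |
|---------|--------|--------|------------------|---------------|
| message | String | Yes | The content to retrieve | "China's 2023 GDP per capita" |
| top_k | int | No | Return the top_k most similar results | 1 |
| message | String | Yes | The content to retrieve; the argument is a Message whose content is a str with length in (0,1000) | "China's 2023 GDP per capita" |
| top_k | int | No | Return the top_k most similar results; top_k ranges over (1, number of indexed embeddings] | 1 |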
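
The constraints in the table can be sketched as a standalone check. `validate_query` is a hypothetical helper written for illustration, not an appbuilder function. It takes the message content directly rather than a `Message` object, and follows the checks added to `run()` in this PR, where a content length above 1000 is rejected and `top_k` must be a positive integer.

```python
def validate_query(content, top_k=1):
    """Hypothetical helper mirroring the argument checks in run()."""
    if not isinstance(top_k, int):
        raise TypeError(
            "Parameter `top_k` must be an int, but got {}".format(type(top_k)))
    if top_k <= 0:
        raise ValueError(
            "Parameter `top_k` must be a positive integer, but got {}".format(top_k))
    if not isinstance(content, str):
        raise ValueError(
            "Query content is not a string, got: {}".format(type(content)))
    if len(content) == 0:
        raise ValueError("Query content is empty")
    if len(content) > 1000:
        raise ValueError("Query content is too long, max length is 1000")
    return content, top_k


# A short query with the default top_k passes unchanged.
checked = validate_query("文心一言")
```

Running the checks before computing the embedding means an invalid query fails locally instead of costing a remote call.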


### Response parameters

`BaiduVDBRetriever().run()` return value:

| Parameter | Type | Description | Example |
|------|--------|-----|--------------------|
| text | string | Retrieval result | "China's GDP per capita in 2023 was 89,400 yuan" |
17 changes: 17 additions & 0 deletions appbuilder/core/components/retriever/baidu_vdb/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .baiduvdb_retriever import BaiduVDBVectorStoreIndex
from .baiduvdb_retriever import BaiduVDBRetriever
from .baiduvdb_retriever import TableParams
127 changes: 111 additions & 16 deletions appbuilder/core/components/retriever/baidu_vdb/baiduvdb_retriever.py
@@ -58,6 +58,7 @@
"Unsupported metric type: `{}`, supported metric types are {}"
)


def _try_import() -> None:
try:
import pymochow
@@ -66,6 +67,7 @@ def _try_import() -> None:
"`pymochow` package not found, please run `pip install pymochow`"
)


class TableParams:
"""Baidu VectorDB table params.
See the following documentation for details:
@@ -102,6 +104,7 @@ def __init__(
self.drop_exists = drop_exists
self.vector_params = vector_params


class BaiduVDBVectorStoreIndex:
"""
Baidu VDB vector store retrieval tool
@@ -110,19 +113,43 @@ class BaiduVDBVectorStoreIndex:

def __init__(
self,
instance_id,
instance_id: str,
api_key: str,
account: str = DEFAULT_ACCOUNT,
database_name: str = DEFAULT_DATABASE_NAME,
table_params: TableParams = TableParams(dimension=384),
embedding=None,
):
if not isinstance(instance_id, str):
raise TypeError(
"Parameter `instance_id` must be a string, but got {}".format(
type(instance_id)))
if not isinstance(api_key, str):
raise TypeError(
"Parameter `api_key` must be a string, but got {}".format(
type(api_key)))
if not isinstance(account, str):
raise TypeError(
"Parameter `account` must be a string, but got {}".format(
type(account)))
if not isinstance(database_name, str):
raise TypeError(
"Parameter `database_name` must be a string, but got {}".format(
type(database_name)))
if not isinstance(table_params, TableParams):
raise TypeError(
"Parameter `table_params` must be a TableParams, but got {}".format(
type(table_params)))
if embedding is not None and not isinstance(embedding, Embedding):
raise TypeError(
"Parameter `embedding` must be a Embedding, but got {}".format(
type(embedding)))

if embedding is None:
embedding = Embedding()

self.embedding = embedding

self._init_client(instance_id, account, api_key)
self._create_database_if_not_exists(database_name)
self._create_table(table_params)
@@ -135,13 +162,16 @@ def _init_client(self, instance_id, account, api_key):
from pymochow.configuration import Configuration
from pymochow.auth.bce_credentials import AppBuilderCredentials

gateway = os.getenv("GATEWAY_URL") if os.getenv("GATEWAY_URL") else GATEWAY_URL
gateway = os.getenv("GATEWAY_URL") if os.getenv(
"GATEWAY_URL") else GATEWAY_URL
appbuilder_token = os.getenv("APPBUILDER_TOKEN")
uri_prefix = self.vdb_uri_prefix + instance_id.encode('utf-8')

config = Configuration(
credentials=AppBuilderCredentials(account, api_key, appbuilder_token),
credentials=AppBuilderCredentials(
account, api_key, appbuilder_token),
endpoint=gateway,
uri_perfix=self.vdb_uri_prefix,
uri_prefix=uri_prefix,
connection_timeout_in_mills=DEFAULT_TIMEOUT_IN_MILLS,
)
self.vdb_client = pymochow.MochowClient(config)
@@ -196,7 +226,10 @@ def _create_table_in_db(
fields.append(Field(FIELD_TEXT, FieldType.STRING))
fields.append(
Field(
FIELD_VECTOR, FieldType.FLOAT_VECTOR, dimension=table_params.dimension
FIELD_VECTOR,
FieldType.FLOAT_VECTOR,
dimension=table_params.dimension,
not_null=True,
)
)

@@ -221,7 +254,7 @@ def _create_table_in_db(
)
# wait 10s for the proxy to sync metadata
time.sleep(10)

@staticmethod
def _get_index_params(index_type: Any, table_params: TableParams) -> None:
from pymochow.model.enum import IndexType
@@ -248,7 +281,8 @@ def _get_index_type(index_type_value: str) -> Any:
try:
return IndexType(index_type_value)
except ValueError:
support_index_types = [d.value for d in IndexType.__members__.values()]
support_index_types = [
d.value for d in IndexType.__members__.values()]
raise ValueError(
NOT_SUPPORT_INDEX_TYPE_ERROR.format(
index_type_value, support_index_types
@@ -263,7 +297,8 @@ def _get_metric_type(metric_type_value: str) -> Any:
try:
return MetricType(metric_type_value.upper())
except ValueError:
support_metric_types = [d.value for d in MetricType.__members__.values()]
support_metric_types = [
d.value for d in MetricType.__members__.values()]
raise ValueError(
NOT_SUPPORT_METRIC_TYPE_ERROR.format(
metric_type_value, support_metric_types
@@ -297,14 +332,16 @@ def add_segments(self, segments: Message, metadata=""):
segment_vectors = segment_vectors.content
vector_dims = len(segment_vectors[0])
segments = segments.content

if len(segments) == 0:
raise ValueError("add_segments: parameter `segments` has empty content")

rows = []
for segment, vector in zip(segments, segment_vectors):
row = Row(text=segment, vector=vector, metadata=metadata)
rows.append(row)
if len(rows) >= DEFAULT_BATCH_SIZE:
self.collection.upsert(rows=rows)
rows = []
self.collection.upsert(rows=rows)
rows = []

if len(rows) > 0:
self.table.upsert(rows=rows)
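
The loop in this hunk accumulates rows and writes them out in batches, then writes whatever is left over. A minimal standalone sketch of that pattern follows. The `flush` callable is a hypothetical stand-in for the table's `upsert`, and `batch_size=2` is an illustrative value only, since the value of `DEFAULT_BATCH_SIZE` is not shown in this diff.

```python
from typing import Callable, Iterable, List


def upsert_in_batches(
    rows: Iterable[dict],
    flush: Callable[[List[dict]], None],
    batch_size: int = 2,
) -> int:
    """Send rows to `flush` in chunks of at most `batch_size`; return the chunk count."""
    batch: List[dict] = []
    flushed = 0
    for row in rows:
        batch.append(row)
        if len(batch) >= batch_size:
            flush(batch)
            flushed += 1
            batch = []  # start a fresh list so flushed batches are not mutated
    if batch:
        # Write the final, partially filled batch.
        flush(batch)
        flushed += 1
    return flushed


# Record each flushed batch in a list instead of calling a real database.
sent: List[List[dict]] = []
count = upsert_in_batches(
    [{"text": str(i)} for i in range(5)], sent.append, batch_size=2)
```

With five rows and a batch size of two, the helper flushes two full batches and one final batch holding the remaining row.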
@@ -320,10 +357,47 @@ def from_params(
drop_exists: bool = False,
**kwargs,
):
"""
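Instantiate the class from parameters.

Args:
cls: the class object, i.e. the class this method belongs to.
instance_id: str, the instance ID.
api_key: str, the API key.
account: str, the account name; defaults to root.
database_name: str, the database name; defaults to AppBuilderDatabase.
table_name: str, the table name; defaults to AppBuilderTable.
drop_exists: bool, whether to drop an existing table; defaults to False.
**kwargs: other arguments; the optional `dimension` argument defaults to 384.

Returns:
A class instance holding the instance ID, account name, API key, database name, table parameters, and related attributes.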

"""
_try_import()
dimension = kwargs.get("dimension", 384)

if not isinstance(instance_id, str):
raise TypeError("instance_id must be a string. but got {}".format(
type(instance_id)))
if not isinstance(api_key, str):
raise TypeError("api_key must be a string. but got {}".format(
type(api_key)))
if not isinstance(account, str):
raise TypeError("account must be a string. but got {}".format(
type(account)))
if not isinstance(database_name, str):
raise TypeError("database_name must be a string. but got {}".format(
type(database_name)))
if not isinstance(table_name, str):
raise TypeError("table_name must be a string. but got {}".format(
type(table_name)))
if not isinstance(drop_exists, bool):
raise TypeError("drop_exists must be a boolean. but got {}".format(
type(drop_exists)))

table_params = TableParams(
dimension=dimension,
dimension=dimension,
table_name=table_name,
drop_exists=drop_exists,
)
@@ -353,15 +427,16 @@ class BaiduVDBRetriever(Component):
self.api_key,
)
vector_index.add_segments(segments)

query = appbuilder.Message("文心一言")
time.sleep(5)
retriever = vector_index.as_retriever()
res = retriever(query)

"""
name: str = "BaiduVectorDBRetriever"
tool_desc: Dict[str, Any] = {"description": "a retriever based on Baidu VectorDB"}
tool_desc: Dict[str, Any] = {
"description": "a retriever based on Baidu VectorDB"}

def __init__(self, embedding, table):
super().__init__()
@@ -381,13 +456,33 @@ def run(self, query: Message, top_k: int = 1):
from pymochow.model.table import AnnSearch, HNSWSearchParams
from pymochow.model.enum import ReadConsistency

if not isinstance(query, Message):
raise TypeError("Parameter `query` must be a Message, but got {}"
.format(type(query)))
if not isinstance(top_k, int):
raise TypeError("Parameter `top_k` must be a int, but got {}"
.format(type(top_k)))
if top_k <= 0:
raise ValueError("Parameter `top_k` must be a positive integer, but got {}"
.format(top_k))

content = query.content
if not isinstance(content, str):
raise ValueError("Parameter `query` content is not a string, got: {}"
.format(type(content)))
if len(content) == 0:
raise ValueError("Parameter `query` content is empty")
if len(content) > 1000:
raise ValueError("Parameter `query` content is too long, max length per batch size is 1000")

query_embedding = self.embedding(query)
anns = AnnSearch(
vector_field=FIELD_VECTOR,
vector_floats=query_embedding.content,
params=HNSWSearchParams(ef=10, limit=top_k),
)
res = self.table.search(anns=anns, read_consistency=ReadConsistency.STRONG)
res = self.table.search(
anns=anns, read_consistency=ReadConsistency.STRONG)
rows = res.rows
docs = []
if rows is None or len(rows) == 0:
16 changes: 16 additions & 0 deletions appbuilder/core/components/retriever/bes/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .bes_retriever import BESVectorStoreIndex
from .bes_retriever import BESRetriever