-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
397 additions
and
0 deletions.
There are no files selected for viewing
397 changes: 397 additions & 0 deletions
397
appbuilder/core/components/retriever/baiduvdb_retriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,397 @@ | ||
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
# -*- coding: utf-8 -*- | ||
""" | ||
基于Baidu VDB的retriever | ||
""" | ||
import importlib | ||
import os | ||
import random | ||
import string | ||
import time | ||
from typing import Dict, Any | ||
from appbuilder.core.component import Component, Message | ||
from appbuilder.core.components.embeddings.component import Embedding | ||
from appbuilder.core.constants import GATEWAY_URL | ||
from appbuilder.utils.logger_util import logger | ||
|
||
DEFAULT_ACCOUNT = "root" | ||
DEFAULT_DATABASE_NAME = "AppBuilderDatabase" | ||
DEFAULT_TABLE_NAME = "AppBuilderTable" | ||
DEFAULT_TIMEOUT_IN_MILLS: int = 30 * 1000 | ||
|
||
DEFAULT_PARTITION = 1 | ||
DEFAULT_REPLICA = 3 | ||
DEFAULT_INDEX_TYPE = "HNSW" | ||
DEFAULT_METRIC_TYPE = "L2" | ||
|
||
DEFAULT_HNSW_M = 16 | ||
DEFAULT_HNSW_EF_CONSTRUCTION = 200 | ||
DEFAULT_HNSW_EF = 10 | ||
|
||
DEFAULT_BATCH_SIZE = 1000 | ||
|
||
FIELD_ID: str = "id" | ||
FIELD_TEXT: str = "text" | ||
FIELD_VECTOR: str = "vector" | ||
FIELD_METADATA: str = "metadata" | ||
INDEX_VECTOR: str = "vector_idx" | ||
|
||
def _try_import() -> None: | ||
try: | ||
import pymochow | ||
except ImportError: | ||
raise ImportError( | ||
"`pymochow` package not found, please run `pip install pymochow`" | ||
) | ||
|
||
class TableParams: | ||
"""Baidu VectorDB table params. | ||
See the following documentation for details: | ||
https://cloud.baidu.com/doc/VDB/s/mlrsob0p6 | ||
Args: | ||
dimension int: The dimension of vector. | ||
replication int: The number of replicas in the table. | ||
partition int: The number of partitions in the table. | ||
index_type (Optional[str]): HNSW, FLAT... Default value is "HNSW" | ||
metric_type (Optional[str]): L2, COSINE, IP. Default value is "L2" | ||
drop_exists (Optional[bool]): Delete the existing Table. Default value is False. | ||
vector_params (Optional[Dict]): | ||
if HNSW set parameters: `M` and `efConstruction`, for example `{'M': 16, efConstruction: 200}` | ||
default is HNSW | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dimension: int, | ||
table_name: str = DEFAULT_TABLE_NAME, | ||
replication: int = DEFAULT_REPLICA, | ||
partition: int = DEFAULT_PARTITION, | ||
index_type: str = DEFAULT_INDEX_TYPE, | ||
metric_type: str = DEFAULT_METRIC_TYPE, | ||
drop_exists: bool = False, | ||
vector_params: Dict = None, | ||
): | ||
self.dimension = dimension | ||
self.table_name = table_name | ||
self.replication = replication | ||
self.partition = partition | ||
self.index_type = index_type | ||
self.metric_type = metric_type | ||
self.drop_exists = drop_exists | ||
self.vector_params = vector_params | ||
|
||
class BaiduVDBVectorStoreIndex: | ||
""" | ||
Baidu VDB向量存储检索工具 | ||
""" | ||
base_vdb_url: str = "/v1/bce/vdb/cluster/" | ||
|
||
def __init__( | ||
self, | ||
cluster_id, | ||
api_key: str, | ||
account: str = DEFAULT_ACCOUNT, | ||
database_name: str = DEFAULT_DATABASE_NAME, | ||
table_params: TableParams = TableParams(dimension=384), | ||
embedding=None, | ||
prefix="/rpc/2.0/cloud_hub" | ||
): | ||
|
||
if embedding is None: | ||
embedding = Embedding() | ||
|
||
self.embedding = embedding | ||
self.prefix = prefix | ||
|
||
self._init_client(cluster_id, account, api_key) | ||
self._create_database_if_not_exists(database_name) | ||
self._create_table(table_params) | ||
|
||
def _init_client(self, cluster_id, account, api_key): | ||
""" | ||
创建一个vdb的client | ||
""" | ||
import pymochow | ||
from pymochow.configuration import Configuration | ||
from pymochow.auth.bce_credentials import BceCredentials | ||
|
||
gateway = os.getenv("GATEWAY_URL") if os.getenv("GATEWAY_URL") else GATEWAY_URL | ||
endpoint = gateway + self.prefix + self.base_vdb_url + cluster_id | ||
|
||
config = Configuration( | ||
credentials=BceCredentials(account, api_key), | ||
endpoint=endpoint, | ||
connection_timeout_in_mills=DEFAULT_TIMEOUT_IN_MILLS, | ||
) | ||
self.vdb_client = pymochow.MochowClient(config) | ||
|
||
def _create_database_if_not_exists(self, database_name: str) -> None: | ||
db_list = self.vdb_client.list_databases() | ||
|
||
if database_name in [db.database_name for db in db_list]: | ||
self.database = self.vdb_client.database(database_name) | ||
else: | ||
self.database = self.vdb_client.create_database(database_name) | ||
|
||
def _create_table(self, table_params: TableParams) -> None: | ||
import pymochow | ||
|
||
if table_params is None: | ||
raise ValueError(VALUE_NONE_ERROR.format("table_params")) | ||
|
||
try: | ||
self.table = self.database.describe_table(table_params.table_name) | ||
if table_params.drop_exists: | ||
self.database.drop_table(table_params.table_name) | ||
# wait db release resource | ||
time.sleep(5) | ||
self._create_table_in_db(table_params) | ||
except pymochow.exception.ServerError: | ||
self._create_table_in_db(table_params) | ||
|
||
def _create_table_in_db( | ||
self, | ||
table_params: TableParams, | ||
) -> None: | ||
from pymochow.model.enum import FieldType | ||
from pymochow.model.schema import Field, Schema, SecondaryIndex, VectorIndex | ||
from pymochow.model.table import Partition | ||
|
||
index_type = self._get_index_type(table_params.index_type) | ||
metric_type = self._get_metric_type(table_params.metric_type) | ||
vector_params = self._get_index_params(index_type, table_params) | ||
fields = [] | ||
fields.append( | ||
Field( | ||
FIELD_ID, | ||
FieldType.UINT64, | ||
primary_key=True, | ||
partition_key=True, | ||
auto_increment=True, | ||
not_null=True, | ||
) | ||
) | ||
fields.append(Field(FIELD_METADATA, FieldType.STRING)) | ||
fields.append(Field(FIELD_TEXT, FieldType.STRING)) | ||
fields.append( | ||
Field( | ||
FIELD_VECTOR, FieldType.FLOAT_VECTOR, dimension=table_params.dimension | ||
) | ||
) | ||
|
||
indexes = [] | ||
indexes.append( | ||
VectorIndex( | ||
index_name=INDEX_VECTOR, | ||
index_type=index_type, | ||
field=FIELD_VECTOR, | ||
metric_type=metric_type, | ||
params=vector_params, | ||
) | ||
) | ||
|
||
schema = Schema(fields=fields, indexes=indexes) | ||
self.table = self.database.create_table( | ||
table_name=table_params.table_name, | ||
replication=table_params.replication, | ||
partition=Partition(partition_num=table_params.partition), | ||
schema=Schema(fields=fields, indexes=indexes), | ||
enable_dynamic_field=True, | ||
) | ||
# need wait 10s to wait proxy sync meta | ||
time.sleep(10) | ||
|
||
@staticmethod | ||
def _get_index_params(index_type: Any, table_params: TableParams) -> None: | ||
from pymochow.model.enum import IndexType | ||
from pymochow.model.schema import HNSWParams | ||
|
||
vector_params = ( | ||
{} if table_params.vector_params is None else table_params.vector_params | ||
) | ||
|
||
if index_type == IndexType.HNSW: | ||
return HNSWParams( | ||
m=vector_params.get("M", DEFAULT_HNSW_M), | ||
efconstruction=vector_params.get( | ||
"efConstruction", DEFAULT_HNSW_EF_CONSTRUCTION | ||
), | ||
) | ||
return None | ||
|
||
@staticmethod | ||
def _get_index_type(index_type_value: str) -> Any: | ||
from pymochow.model.enum import IndexType | ||
|
||
index_type_value = index_type_value or IndexType.HNSW | ||
try: | ||
return IndexType(index_type_value) | ||
except ValueError: | ||
support_index_types = [d.value for d in IndexType.__members__.values()] | ||
raise ValueError( | ||
NOT_SUPPORT_INDEX_TYPE_ERROR.format( | ||
index_type_value, support_index_types | ||
) | ||
) | ||
|
||
@staticmethod | ||
def _get_metric_type(metric_type_value: str) -> Any: | ||
from pymochow.model.enum import MetricType | ||
|
||
metric_type_value = metric_type_value or MetricType.L2 | ||
try: | ||
return MetricType(metric_type_value.upper()) | ||
except ValueError: | ||
support_metric_types = [d.value for d in MetricType.__members__.values()] | ||
raise ValueError( | ||
NOT_SUPPORT_METRIC_TYPE_ERROR.format( | ||
metric_type_value, support_metric_types | ||
) | ||
) | ||
|
||
@property | ||
def client(self) -> Any: | ||
"""Get client.""" | ||
return self.vdb_client | ||
|
||
def as_retriever(self): | ||
""" | ||
转化为retriever | ||
""" | ||
return BaiduVDBRetriever( | ||
embedding=self.embedding, | ||
table=self.table, | ||
) | ||
|
||
def add_segments(self, segments: Message, metadata=""): | ||
""" | ||
向bes中插入数据 | ||
参数: | ||
query (Message[str]): 需要插入的内容 | ||
返回: | ||
""" | ||
from pymochow.model.table import Row | ||
|
||
segment_vectors = self.embedding.batch(segments) | ||
segment_vectors = segment_vectors.content | ||
vector_dims = len(segment_vectors[0]) | ||
segments = segments.content | ||
|
||
rows = [] | ||
for segment, vector in zip(segments, segment_vectors): | ||
row = Row(text=segment, vector=vector, metadata=metadata) | ||
rows.append(row) | ||
if len(rows) >= DEFAULT_BATCH_SIZE: | ||
self.collection.upsert(rows=rows) | ||
rows = [] | ||
|
||
if len(rows) > 0: | ||
self.table.upsert(rows=rows) | ||
|
||
@classmethod | ||
def from_params( | ||
cls, | ||
cluster_id: str, | ||
api_key: str, | ||
account: str = DEFAULT_ACCOUNT, | ||
database_name: str = DEFAULT_DATABASE_NAME, | ||
table_name: str = DEFAULT_TABLE_NAME, | ||
drop_exists: bool = False, | ||
**kwargs, | ||
): | ||
_try_import() | ||
dimension = kwargs.get("dimension", 384) | ||
table_params = TableParams( | ||
dimension=dimension, | ||
table_name=table_name, | ||
drop_exists=drop_exists, | ||
) | ||
return cls( | ||
cluster_id=cluster_id, | ||
account=account, | ||
api_key=api_key, | ||
database_name=database_name, | ||
table_params=table_params, | ||
) | ||
|
||
|
||
class BaiduVDBRetriever(Component): | ||
""" | ||
向量检索组件,用于检索和query相匹配的内容 | ||
Examples: | ||
.. code-block:: python | ||
import appbuilder | ||
os.environ["APPBUILDER_TOKEN"] = '...' | ||
segments = appbuilder.Message(["文心一言大模型", "百度在线科技有限公司"]) | ||
vector_index = appbuilder.BaiduVDBVectorStoreIndex.from_params( | ||
self.cluster_id, | ||
self.api_key, | ||
) | ||
vector_index.add_segments(segments) | ||
query = appbuilder.Message("文心一言") | ||
time.sleep(5) | ||
retriever = vector_index.as_retriever() | ||
res = retriever(query) | ||
""" | ||
name: str = "BaiduVectorDBRetriever" | ||
tool_desc: Dict[str, Any] = {"description": "a retriever based on Baidu VectorDB"} | ||
|
||
def __init__(self, embedding, table): | ||
super().__init__() | ||
|
||
self.embedding = embedding | ||
self.table = table | ||
|
||
def run(self, query: Message, top_k: int = 1): | ||
""" | ||
根据query进行查询 | ||
参数: | ||
query (Message[str]): 需要查询的内容, | ||
top_k (bool): 查询结果中匹配度最高的top_k个结果 | ||
返回: | ||
obj (Message[Dict]): 查询到的结果,包含文本和匹配得分。 | ||
""" | ||
from pymochow.model.table import AnnSearch, HNSWSearchParams | ||
from pymochow.model.enum import ReadConsistency | ||
|
||
query_embedding = self.embedding(query) | ||
anns = AnnSearch( | ||
vector_field=FIELD_VECTOR, | ||
vector_floats=query_embedding.content, | ||
params=HNSWSearchParams(ef=10, limit=top_k), | ||
) | ||
res = self.table.search(anns=anns, read_consistency=ReadConsistency.STRONG) | ||
rows = res.rows | ||
docs = [] | ||
if rows is None or len(rows) == 0: | ||
return Message(docs) | ||
|
||
for row in rows: | ||
row_data = row.get("row", {}) | ||
docs.append({ | ||
"text": row_data.get(FIELD_TEXT), | ||
"meta": row_data.get(FIELD_METADATA), | ||
"score": row.get("score") | ||
}) | ||
|
||
return Message(docs) |