From a0a41ff13b4232a75316e6875e08efc0dec480c5 Mon Sep 17 00:00:00 2001 From: fengjial Date: Wed, 13 Mar 2024 14:47:57 +0800 Subject: [PATCH] add baidu vdb as retriever --- .../retriever/baiduvdb_retriever.py | 397 ++++++++++++++++++ 1 file changed, 397 insertions(+) create mode 100644 appbuilder/core/components/retriever/baiduvdb_retriever.py diff --git a/appbuilder/core/components/retriever/baiduvdb_retriever.py b/appbuilder/core/components/retriever/baiduvdb_retriever.py new file mode 100644 index 000000000..1b55acfc4 --- /dev/null +++ b/appbuilder/core/components/retriever/baiduvdb_retriever.py @@ -0,0 +1,397 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# -*- coding: utf-8 -*- +""" +基于Baidu VDB的retriever +""" +import importlib +import os +import random +import string +import time +from typing import Dict, Any +from appbuilder.core.component import Component, Message +from appbuilder.core.components.embeddings.component import Embedding +from appbuilder.core.constants import GATEWAY_URL +from appbuilder.utils.logger_util import logger + +DEFAULT_ACCOUNT = "root" +DEFAULT_DATABASE_NAME = "AppBuilderDatabase" +DEFAULT_TABLE_NAME = "AppBuilderTable" +DEFAULT_TIMEOUT_IN_MILLS: int = 30 * 1000 + +DEFAULT_PARTITION = 1 +DEFAULT_REPLICA = 3 +DEFAULT_INDEX_TYPE = "HNSW" +DEFAULT_METRIC_TYPE = "L2" + +DEFAULT_HNSW_M = 16 +DEFAULT_HNSW_EF_CONSTRUCTION = 200 +DEFAULT_HNSW_EF = 10 + +DEFAULT_BATCH_SIZE = 1000 + +FIELD_ID: str = "id" +FIELD_TEXT: str = "text" +FIELD_VECTOR: str = "vector" +FIELD_METADATA: str = "metadata" +INDEX_VECTOR: str = "vector_idx" + +def _try_import() -> None: + try: + import pymochow + except ImportError: + raise ImportError( + "`pymochow` package not found, please run `pip install pymochow`" + ) + +class TableParams: + """Baidu VectorDB table params. + See the following documentation for details: + https://cloud.baidu.com/doc/VDB/s/mlrsob0p6 + Args: + dimension int: The dimension of vector. + replication int: The number of replicas in the table. + partition int: The number of partitions in the table. + index_type (Optional[str]): HNSW, FLAT... Default value is "HNSW" + metric_type (Optional[str]): L2, COSINE, IP. Default value is "L2" + drop_exists (Optional[bool]): Delete the existing Table. Default value is False. + vector_params (Optional[Dict]): + if HNSW set parameters: `M` and `efConstruction`, for example `{'M': 16, efConstruction: 200}` + default is HNSW + """ + + def __init__( + self, + dimension: int, + table_name: str = DEFAULT_TABLE_NAME, + replication: int = DEFAULT_REPLICA, + partition: int = DEFAULT_PARTITION, + index_type: str = DEFAULT_INDEX_TYPE, + metric_type: str = DEFAULT_METRIC_TYPE, + drop_exists: bool = False, + vector_params: Dict = None, + ): + self.dimension = dimension + self.table_name = table_name + self.replication = replication + self.partition = partition + self.index_type = index_type + self.metric_type = metric_type + self.drop_exists = drop_exists + self.vector_params = vector_params + +class BaiduVDBVectorStoreIndex: + """ + Baidu VDB向量存储检索工具 + """ + base_vdb_url: str = "/v1/bce/vdb/cluster/" + + def __init__( + self, + cluster_id, + api_key: str, + account: str = DEFAULT_ACCOUNT, + database_name: str = DEFAULT_DATABASE_NAME, + table_params: TableParams = TableParams(dimension=384), + embedding=None, + prefix="/rpc/2.0/cloud_hub" + ): + + if embedding is None: + embedding = Embedding() + + self.embedding = embedding + self.prefix = prefix + + self._init_client(cluster_id, account, api_key) + self._create_database_if_not_exists(database_name) + self._create_table(table_params) + + def _init_client(self, cluster_id, account, api_key): + """ + 创建一个vdb的client + """ + import pymochow + from pymochow.configuration import Configuration + from pymochow.auth.bce_credentials import BceCredentials + + gateway = os.getenv("GATEWAY_URL") if os.getenv("GATEWAY_URL") else GATEWAY_URL + endpoint = gateway + self.prefix + self.base_vdb_url + cluster_id + + config = Configuration( + credentials=BceCredentials(account, api_key), + endpoint=endpoint, + connection_timeout_in_mills=DEFAULT_TIMEOUT_IN_MILLS, + ) + self.vdb_client = pymochow.MochowClient(config) + + def _create_database_if_not_exists(self, database_name: str) -> None: + db_list = self.vdb_client.list_databases() + + if database_name in [db.database_name for db in db_list]: + self.database = self.vdb_client.database(database_name) + else: + self.database = self.vdb_client.create_database(database_name) + + def _create_table(self, table_params: TableParams) -> None: + import pymochow + + if table_params is None: + raise ValueError(VALUE_NONE_ERROR.format("table_params")) + + try: + self.table = self.database.describe_table(table_params.table_name) + if table_params.drop_exists: + self.database.drop_table(table_params.table_name) + # wait db release resource + time.sleep(5) + self._create_table_in_db(table_params) + except pymochow.exception.ServerError: + self._create_table_in_db(table_params) + + def _create_table_in_db( + self, + table_params: TableParams, + ) -> None: + from pymochow.model.enum import FieldType + from pymochow.model.schema import Field, Schema, SecondaryIndex, VectorIndex + from pymochow.model.table import Partition + + index_type = self._get_index_type(table_params.index_type) + metric_type = self._get_metric_type(table_params.metric_type) + vector_params = self._get_index_params(index_type, table_params) + fields = [] + fields.append( + Field( + FIELD_ID, + FieldType.UINT64, + primary_key=True, + partition_key=True, + auto_increment=True, + not_null=True, + ) + ) + fields.append(Field(FIELD_METADATA, FieldType.STRING)) + fields.append(Field(FIELD_TEXT, FieldType.STRING)) + fields.append( + Field( + FIELD_VECTOR, FieldType.FLOAT_VECTOR, dimension=table_params.dimension + ) + ) + + indexes = [] + indexes.append( + VectorIndex( + index_name=INDEX_VECTOR, + index_type=index_type, + field=FIELD_VECTOR, + metric_type=metric_type, + params=vector_params, + ) + ) + + schema = Schema(fields=fields, indexes=indexes) + self.table = self.database.create_table( + table_name=table_params.table_name, + replication=table_params.replication, + partition=Partition(partition_num=table_params.partition), + schema=Schema(fields=fields, indexes=indexes), + enable_dynamic_field=True, + ) + # need wait 10s to wait proxy sync meta + time.sleep(10) + + @staticmethod + def _get_index_params(index_type: Any, table_params: TableParams) -> None: + from pymochow.model.enum import IndexType + from pymochow.model.schema import HNSWParams + + vector_params = ( + {} if table_params.vector_params is None else table_params.vector_params + ) + + if index_type == IndexType.HNSW: + return HNSWParams( + m=vector_params.get("M", DEFAULT_HNSW_M), + efconstruction=vector_params.get( + "efConstruction", DEFAULT_HNSW_EF_CONSTRUCTION + ), + ) + return None + + @staticmethod + def _get_index_type(index_type_value: str) -> Any: + from pymochow.model.enum import IndexType + + index_type_value = index_type_value or IndexType.HNSW + try: + return IndexType(index_type_value) + except ValueError: + support_index_types = [d.value for d in IndexType.__members__.values()] + raise ValueError( + NOT_SUPPORT_INDEX_TYPE_ERROR.format( + index_type_value, support_index_types + ) + ) + + @staticmethod + def _get_metric_type(metric_type_value: str) -> Any: + from pymochow.model.enum import MetricType + + metric_type_value = metric_type_value or MetricType.L2 + try: + return MetricType(metric_type_value.upper()) + except ValueError: + support_metric_types = [d.value for d in MetricType.__members__.values()] + raise ValueError( + NOT_SUPPORT_METRIC_TYPE_ERROR.format( + metric_type_value, support_metric_types + ) + ) + + @property + def client(self) -> Any: + """Get client.""" + return self.vdb_client + + def as_retriever(self): + """ + 转化为retriever + """ + return BaiduVDBRetriever( + embedding=self.embedding, + table=self.table, + ) + + def add_segments(self, segments: Message, metadata=""): + """ + 向bes中插入数据 + 参数: + query (Message[str]): 需要插入的内容 + 返回: + """ + from pymochow.model.table import Row + + segment_vectors = self.embedding.batch(segments) + segment_vectors = segment_vectors.content + vector_dims = len(segment_vectors[0]) + segments = segments.content + + rows = [] + for segment, vector in zip(segments, segment_vectors): + row = Row(text=segment, vector=vector, metadata=metadata) + rows.append(row) + if len(rows) >= DEFAULT_BATCH_SIZE: + self.collection.upsert(rows=rows) + rows = [] + + if len(rows) > 0: + self.table.upsert(rows=rows) + + @classmethod + def from_params( + cls, + cluster_id: str, + api_key: str, + account: str = DEFAULT_ACCOUNT, + database_name: str = DEFAULT_DATABASE_NAME, + table_name: str = DEFAULT_TABLE_NAME, + drop_exists: bool = False, + **kwargs, + ): + _try_import() + dimension = kwargs.get("dimension", 384) + table_params = TableParams( + dimension=dimension, + table_name=table_name, + drop_exists=drop_exists, + ) + return cls( + cluster_id=cluster_id, + account=account, + api_key=api_key, + database_name=database_name, + table_params=table_params, + ) + + +class BaiduVDBRetriever(Component): + """ + 向量检索组件,用于检索和query相匹配的内容 + + Examples: + + .. code-block:: python + + import appbuilder + os.environ["APPBUILDER_TOKEN"] = '...' + + segments = appbuilder.Message(["文心一言大模型", "百度在线科技有限公司"]) + vector_index = appbuilder.BaiduVDBVectorStoreIndex.from_params( + self.cluster_id, + self.api_key, + ) + vector_index.add_segments(segments) + + query = appbuilder.Message("文心一言") + time.sleep(5) + retriever = vector_index.as_retriever() + res = retriever(query) + + """ + name: str = "BaiduVectorDBRetriever" + tool_desc: Dict[str, Any] = {"description": "a retriever based on Baidu VectorDB"} + + def __init__(self, embedding, table): + super().__init__() + + self.embedding = embedding + self.table = table + + def run(self, query: Message, top_k: int = 1): + """ + 根据query进行查询 + 参数: + query (Message[str]): 需要查询的内容, + top_k (bool): 查询结果中匹配度最高的top_k个结果 + 返回: + obj (Message[Dict]): 查询到的结果,包含文本和匹配得分。 + """ + from pymochow.model.table import AnnSearch, HNSWSearchParams + from pymochow.model.enum import ReadConsistency + + query_embedding = self.embedding(query) + anns = AnnSearch( + vector_field=FIELD_VECTOR, + vector_floats=query_embedding.content, + params=HNSWSearchParams(ef=10, limit=top_k), + ) + res = self.table.search(anns=anns, read_consistency=ReadConsistency.STRONG) + rows = res.rows + docs = [] + if rows is None or len(rows) == 0: + return Message(docs) + + for row in rows: + row_data = row.get("row", {}) + docs.append({ + "text": row_data.get(FIELD_TEXT), + "meta": row_data.get(FIELD_METADATA), + "score": row.get("score") + }) + + return Message(docs)