Skip to content

Commit

Permalink
add baidu vdb as retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
fengjial committed Mar 13, 2024
1 parent 6e3cbcc commit a0a41ff
Showing 1 changed file with 397 additions and 0 deletions.
397 changes: 397 additions & 0 deletions appbuilder/core/components/retriever/baiduvdb_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,397 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# -*- coding: utf-8 -*-
"""
基于Baidu VDB的retriever
"""
import importlib
import os
import random
import string
import time
from typing import Dict, Any
from appbuilder.core.component import Component, Message
from appbuilder.core.components.embeddings.component import Embedding
from appbuilder.core.constants import GATEWAY_URL
from appbuilder.utils.logger_util import logger

# --- Connection defaults -------------------------------------------------
# Account name paired with the API key when building the VDB credentials.
DEFAULT_ACCOUNT = "root"
# Database / table used when the caller does not name one explicitly.
DEFAULT_DATABASE_NAME = "AppBuilderDatabase"
DEFAULT_TABLE_NAME = "AppBuilderTable"
# Passed as `connection_timeout_in_mills` to the client configuration
# (30 * 1000 ms = 30 seconds).
DEFAULT_TIMEOUT_IN_MILLS: int = 30 * 1000

# --- Table layout defaults (used as TableParams defaults) ----------------
DEFAULT_PARTITION = 1
DEFAULT_REPLICA = 3
DEFAULT_INDEX_TYPE = "HNSW"
DEFAULT_METRIC_TYPE = "L2"

# --- HNSW index defaults -------------------------------------------------
# Fallbacks for the `M` and `efConstruction` keys of TableParams.vector_params.
DEFAULT_HNSW_M = 16
DEFAULT_HNSW_EF_CONSTRUCTION = 200
# NOTE(review): not referenced elsewhere in this file as far as visible;
# presumably the search-time `ef` default -- confirm before relying on it.
DEFAULT_HNSW_EF = 10

# Maximum number of rows sent to the table in a single upsert call.
DEFAULT_BATCH_SIZE = 1000

# --- Schema names --------------------------------------------------------
# Column names of the table schema and the name of its vector index.
FIELD_ID: str = "id"
FIELD_TEXT: str = "text"
FIELD_VECTOR: str = "vector"
FIELD_METADATA: str = "metadata"
INDEX_VECTOR: str = "vector_idx"

def _try_import() -> None:
    """Check that the optional `pymochow` SDK can be imported.

    Raises:
        ImportError: If `pymochow` cannot be imported; the message tells the
            user which package to install.
    """
    try:
        importlib.import_module("pymochow")
    except ImportError:
        raise ImportError(
            "`pymochow` package not found, please run `pip install pymochow`"
        )

class TableParams:
    """Settings describing a Baidu VectorDB table and its vector index.

    Reference: https://cloud.baidu.com/doc/VDB/s/mlrsob0p6

    Args:
        dimension: Dimension of the stored vectors.
        table_name: Name of the table. Defaults to ``DEFAULT_TABLE_NAME``.
        replication: Number of replicas in the table.
        partition: Number of partitions in the table.
        index_type: Vector index type, e.g. ``"HNSW"`` or ``"FLAT"``.
            Defaults to ``"HNSW"``.
        metric_type: Distance metric: ``"L2"``, ``"COSINE"`` or ``"IP"``.
            Defaults to ``"L2"``.
        drop_exists: Whether an already existing table should be deleted.
            Defaults to ``False``.
        vector_params: Optional index parameters. For an HNSW index the keys
            ``M`` and ``efConstruction`` are recognised, for example
            ``{"M": 16, "efConstruction": 200}``.
    """

    def __init__(
        self,
        dimension: int,
        table_name: str = DEFAULT_TABLE_NAME,
        replication: int = DEFAULT_REPLICA,
        partition: int = DEFAULT_PARTITION,
        index_type: str = DEFAULT_INDEX_TYPE,
        metric_type: str = DEFAULT_METRIC_TYPE,
        drop_exists: bool = False,
        vector_params: Dict = None,
    ):
        # Table identity and lifecycle.
        self.table_name = table_name
        self.drop_exists = drop_exists

        # Physical layout.
        self.replication, self.partition = replication, partition

        # Vector column and index configuration.
        self.dimension = dimension
        self.index_type, self.metric_type = index_type, metric_type
        self.vector_params = vector_params

class BaiduVDBVectorStoreIndex:
    """Vector store index backed by Baidu VectorDB (VDB).

    Constructing an instance creates a VDB client, makes sure the requested
    database exists, and opens or creates the table described by
    ``table_params``. Text segments are embedded and written with
    :meth:`add_segments`; :meth:`as_retriever` returns a retriever bound to
    the same table and embedding component.
    """

    base_vdb_url: str = "/v1/bce/vdb/cluster/"

    def __init__(
        self,
        cluster_id,
        api_key: str,
        account: str = DEFAULT_ACCOUNT,
        database_name: str = DEFAULT_DATABASE_NAME,
        # NOTE: this default instance is created once and shared between
        # calls; that is safe only because this class never mutates it.
        table_params: TableParams = TableParams(dimension=384),
        embedding=None,
        prefix="/rpc/2.0/cloud_hub"
    ):
        """Connect to the cluster and prepare the database and table.

        Args:
            cluster_id: Identifier of the VDB cluster; appended to the
                endpoint URL.
            api_key: API key used together with ``account`` as credentials.
            account: Account name. Defaults to ``DEFAULT_ACCOUNT``.
            database_name: Database to use; created when it does not exist.
            table_params: Description of the table to open or create.
            embedding: Embedding component used to vectorise text. A default
                ``Embedding()`` is created when ``None``.
            prefix: URL path inserted between the gateway and the VDB path.

        Raises:
            ImportError: If the `pymochow` package is not installed.
            ValueError: If ``table_params`` is ``None`` or names an
                unsupported index or metric type.
        """
        # Fail early with an actionable message, as `from_params` does.
        _try_import()

        if embedding is None:
            embedding = Embedding()

        self.embedding = embedding
        self.prefix = prefix

        self._init_client(cluster_id, account, api_key)
        self._create_database_if_not_exists(database_name)
        self._create_table(table_params)

    def _init_client(self, cluster_id, account, api_key):
        """Create the VDB client and store it on ``self.vdb_client``."""
        import pymochow
        from pymochow.configuration import Configuration
        from pymochow.auth.bce_credentials import BceCredentials

        # The GATEWAY_URL environment variable overrides the packaged default.
        gateway = os.getenv("GATEWAY_URL") or GATEWAY_URL
        endpoint = gateway + self.prefix + self.base_vdb_url + cluster_id

        config = Configuration(
            credentials=BceCredentials(account, api_key),
            endpoint=endpoint,
            connection_timeout_in_mills=DEFAULT_TIMEOUT_IN_MILLS,
        )
        self.vdb_client = pymochow.MochowClient(config)

    def _create_database_if_not_exists(self, database_name: str) -> None:
        """Open ``database_name``, creating it first when it is missing."""
        db_list = self.vdb_client.list_databases()

        if any(db.database_name == database_name for db in db_list):
            self.database = self.vdb_client.database(database_name)
        else:
            self.database = self.vdb_client.create_database(database_name)

    def _create_table(self, table_params: TableParams) -> None:
        """Open the table, or (re)create it, and store it on ``self.table``.

        An existing table is reused unless ``table_params.drop_exists`` is
        set, in which case it is dropped and created again.

        Raises:
            ValueError: If ``table_params`` is ``None``.
        """
        import pymochow

        if table_params is None:
            raise ValueError("Parameter `table_params` can not be None.")

        try:
            self.table = self.database.describe_table(table_params.table_name)
            if table_params.drop_exists:
                self.database.drop_table(table_params.table_name)
                # wait db release resource
                time.sleep(5)
                self._create_table_in_db(table_params)
        except pymochow.exception.ServerError:
            # NOTE(review): any server error is handled as "table does not
            # exist"; presumably describe_table raises it for a missing
            # table -- confirm against the SDK.
            self._create_table_in_db(table_params)

    def _create_table_in_db(
        self,
        table_params: TableParams,
    ) -> None:
        """Create the table with the id/metadata/text/vector schema."""
        from pymochow.model.enum import FieldType
        from pymochow.model.schema import Field, Schema, SecondaryIndex, VectorIndex
        from pymochow.model.table import Partition

        index_type = self._get_index_type(table_params.index_type)
        metric_type = self._get_metric_type(table_params.metric_type)
        vector_params = self._get_index_params(index_type, table_params)

        fields = [
            Field(
                FIELD_ID,
                FieldType.UINT64,
                primary_key=True,
                partition_key=True,
                auto_increment=True,
                not_null=True,
            ),
            Field(FIELD_METADATA, FieldType.STRING),
            Field(FIELD_TEXT, FieldType.STRING),
            Field(
                FIELD_VECTOR, FieldType.FLOAT_VECTOR, dimension=table_params.dimension
            ),
        ]
        indexes = [
            VectorIndex(
                index_name=INDEX_VECTOR,
                index_type=index_type,
                field=FIELD_VECTOR,
                metric_type=metric_type,
                params=vector_params,
            )
        ]

        schema = Schema(fields=fields, indexes=indexes)
        self.table = self.database.create_table(
            table_name=table_params.table_name,
            replication=table_params.replication,
            partition=Partition(partition_num=table_params.partition),
            schema=schema,
            enable_dynamic_field=True,
        )
        # need wait 10s to wait proxy sync meta
        time.sleep(10)

    @staticmethod
    def _get_index_params(index_type: Any, table_params: TableParams) -> Any:
        """Build the index parameter object for ``index_type``.

        Returns:
            An ``HNSWParams`` instance for an HNSW index (missing keys fall
            back to the module defaults), otherwise ``None``.
        """
        from pymochow.model.enum import IndexType
        from pymochow.model.schema import HNSWParams

        vector_params = (
            {} if table_params.vector_params is None else table_params.vector_params
        )

        if index_type == IndexType.HNSW:
            return HNSWParams(
                m=vector_params.get("M", DEFAULT_HNSW_M),
                efconstruction=vector_params.get(
                    "efConstruction", DEFAULT_HNSW_EF_CONSTRUCTION
                ),
            )
        return None

    @staticmethod
    def _get_index_type(index_type_value: str) -> Any:
        """Convert an index type name to the SDK ``IndexType`` enum.

        An empty or ``None`` value selects HNSW.

        Raises:
            ValueError: If the name is not a supported index type.
        """
        from pymochow.model.enum import IndexType

        index_type_value = index_type_value or IndexType.HNSW
        try:
            return IndexType(index_type_value)
        except ValueError as err:
            support_index_types = [d.value for d in IndexType.__members__.values()]
            raise ValueError(
                f"Unsupported index type: `{index_type_value}`, "
                f"supported index types are {support_index_types}"
            ) from err

    @staticmethod
    def _get_metric_type(metric_type_value: str) -> Any:
        """Convert a metric type name to the SDK ``MetricType`` enum.

        The name is matched case-insensitively; an empty or ``None`` value
        selects L2.

        Raises:
            ValueError: If the name is not a supported metric type.
        """
        from pymochow.model.enum import MetricType

        if not metric_type_value:
            return MetricType.L2
        # An enum member has no need for (and may not support) `.upper()`.
        if isinstance(metric_type_value, MetricType):
            return metric_type_value
        try:
            return MetricType(metric_type_value.upper())
        except ValueError as err:
            support_metric_types = [d.value for d in MetricType.__members__.values()]
            raise ValueError(
                f"Unsupported metric type: `{metric_type_value}`, "
                f"supported metric types are {support_metric_types}"
            ) from err

    @property
    def client(self) -> Any:
        """Get client."""
        return self.vdb_client

    def as_retriever(self):
        """Return a ``BaiduVDBRetriever`` bound to this index's table."""
        return BaiduVDBRetriever(
            embedding=self.embedding,
            table=self.table,
        )

    def add_segments(self, segments: Message, metadata=""):
        """Embed text segments and upsert them into the VDB table.

        Rows are written in batches of at most ``DEFAULT_BATCH_SIZE``.

        Args:
            segments (Message[List[str]]): Text segments to insert.
            metadata: Value stored in the metadata column of every row.

        Returns:
            None
        """
        from pymochow.model.table import Row

        texts = segments.content
        if not texts:
            # Nothing to embed or insert.
            return

        vectors = self.embedding.batch(segments).content

        rows = []
        for text, vector in zip(texts, vectors):
            rows.append(Row(text=text, vector=vector, metadata=metadata))
            if len(rows) >= DEFAULT_BATCH_SIZE:
                self.table.upsert(rows=rows)
                rows = []

        if rows:
            self.table.upsert(rows=rows)

    @classmethod
    def from_params(
        cls,
        cluster_id: str,
        api_key: str,
        account: str = DEFAULT_ACCOUNT,
        database_name: str = DEFAULT_DATABASE_NAME,
        table_name: str = DEFAULT_TABLE_NAME,
        drop_exists: bool = False,
        **kwargs,
    ):
        """Build an index from plain parameters.

        Args:
            cluster_id: Identifier of the VDB cluster.
            api_key: API key used as credential.
            account: Account name. Defaults to ``DEFAULT_ACCOUNT``.
            database_name: Database to use.
            table_name: Table to open or create.
            drop_exists: Whether an existing table is dropped and recreated.
            **kwargs: Only ``dimension`` (vector dimension, default 384) is
                read; other keys are ignored.

        Raises:
            ImportError: If the `pymochow` package is not installed.
        """
        _try_import()
        dimension = kwargs.get("dimension", 384)
        table_params = TableParams(
            dimension=dimension,
            table_name=table_name,
            drop_exists=drop_exists,
        )
        return cls(
            cluster_id=cluster_id,
            account=account,
            api_key=api_key,
            database_name=database_name,
            table_params=table_params,
        )


class BaiduVDBRetriever(Component):
    """Vector retrieval component that returns stored text matching a query.

    Examples:

    .. code-block:: python

        import appbuilder
        os.environ["APPBUILDER_TOKEN"] = '...'
        segments = appbuilder.Message(["ERNIE Bot large model", "Baidu Online Technology Co., Ltd."])
        vector_index = appbuilder.BaiduVDBVectorStoreIndex.from_params(
            cluster_id,
            api_key,
        )
        vector_index.add_segments(segments)
        query = appbuilder.Message("ERNIE Bot")
        time.sleep(5)
        retriever = vector_index.as_retriever()
        res = retriever(query)
    """

    name: str = "BaiduVectorDBRetriever"
    tool_desc: Dict[str, Any] = {"description": "a retriever based on Baidu VectorDB"}

    def __init__(self, embedding, table):
        """Store the embedding component and the VDB table to search."""
        super().__init__()

        self.embedding = embedding
        self.table = table

    def run(self, query: Message, top_k: int = 1):
        """Search the table for the rows closest to ``query``.

        Args:
            query (Message[str]): Text to search for.
            top_k (int): Maximum number of best-matching rows requested.

        Returns:
            Message[List[Dict]]: One dict per hit with the keys ``text``,
            ``meta`` and ``score``; an empty list when nothing is found.
        """
        from pymochow.model.table import AnnSearch, HNSWSearchParams
        from pymochow.model.enum import ReadConsistency

        query_vector = self.embedding(query).content
        ann_request = AnnSearch(
            vector_field=FIELD_VECTOR,
            vector_floats=query_vector,
            params=HNSWSearchParams(ef=10, limit=top_k),
        )
        response = self.table.search(
            anns=ann_request, read_consistency=ReadConsistency.STRONG
        )

        hits = response.rows
        if hits is None or len(hits) == 0:
            return Message([])

        def _to_doc(hit):
            # Each hit carries the stored columns under "row" plus a "score".
            stored = hit.get("row", {})
            return {
                "text": stored.get(FIELD_TEXT),
                "meta": stored.get(FIELD_METADATA),
                "score": hit.get("score"),
            }

        return Message([_to_doc(hit) for hit in hits])

0 comments on commit a0a41ff

Please sign in to comment.