Skip to content

Commit

Permalink
update ut & readme
Browse files Browse the repository at this point in the history
  • Loading branch information
MrChengmo committed Mar 20, 2024
1 parent f6a5bea commit 9ae38bc
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
4 changes: 2 additions & 2 deletions appbuilder/core/components/retriever/baidu_vdb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ os.environ["APPBUILDER_TOKEN"] = "bce-YOURTOKEN"

| 参数名称 | 参数类型 |是否必须 | 描述 | 示例值 |
|---------|--------|--------|------------------|---------------|
| message | String || 需要检索的内容 | "中国2023人均GDP" |
| top_k | int || 返回相似度最高的top_k个内容 | 1 |
| message | String || 需要检索的内容, 类型为Message,content类型为str, 长度要求(0,1000) | "中国2023人均GDP" |
| top_k | int || 返回相似度最高的top_k个内容,top_k的数值范围(1,embedding索引数量] | 1 |


### 响应参数
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def from_params(
):
"""
从参数中实例化类。
Args:
cls: 类对象,即当前函数所属的类。
instance_id: str,实例ID。
Expand All @@ -369,13 +369,33 @@ def from_params(
table_name: str,表名,默认为AppBuilderTable。
drop_exists: bool,是否删除已存在的表,默认为False。
**kwargs: 其他参数,可选的维度参数dimension默认为384。
Returns:
类实例,包含实例ID、账户名、API密钥、数据库名、表参数等属性。
"""
_try_import()
dimension = kwargs.get("dimension", 384)

if not isinstance(instance_id, str):
raise TypeError("instance_id must be a string. but got {}".format(
type(instance_id)))
if not isinstance(api_key, str):
raise TypeError("api_key must be a string. but got {}".format(
type(api_key)))
if not isinstance(account, str):
raise TypeError("account must be a string. but got {}".format(
type(account)))
if not isinstance(database_name, str):
raise TypeError("database_name must be a string. but got {}".format(
type(database_name)))
if not isinstance(table_name, str):
raise TypeError("table_name must be a string. but got {}".format(
type(table_name)))
if not isinstance(drop_exists, bool):
raise TypeError("drop_exists must be a boolean. but got {}".format(
type(drop_exists)))

table_params = TableParams(
dimension=dimension,
table_name=table_name,
Expand Down Expand Up @@ -447,8 +467,13 @@ def run(self, query: Message, top_k: int = 1):
.format(top_k))

content = query.content
if not isinstance(content, str):
raise ValueError("Parameter `query` content is not a string, got: {}"
.format(type(content)))
if len(content) == 0:
raise ValueError("Parameter `query` content is empty")
if len(content) > 1000:
raise ValueError("Parameter `query` content is too long, max length per batch size is 1000")

query_embedding = self.embedding(query)
anns = AnnSearch(
Expand Down
20 changes: 20 additions & 0 deletions appbuilder/tests/test_vdb_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ def test_run_parameter_query(self):
retriever.run(query)
self.assertIn("Parameter `query` content is empty", str(context.exception))

def test_run_paramter_query_type(self):
query = appbuilder.Message(content=12345)

retriever = appbuilder.BaiduVDBRetriever(
embedding="abcde",
table="abcde")

with self.assertRaises(ValueError) as context:
retriever.run(query)
self.assertIn("Parameter `query` content is not a string", str(context.exception))

def test_run_parameter_query_length(self):
query = appbuilder.Message(content="a" * 1025)
retriever = appbuilder.BaiduVDBRetriever(
embedding="abcde",
table="abcde")
with self.assertRaises(ValueError) as context:
retriever.run(query)
self.assertIn("Parameter `query` content is too long", str(context.exception))

def test_run_parameter_topk_positive(self):
query = appbuilder.Message()
retriever = appbuilder.BaiduVDBRetriever(
Expand Down

0 comments on commit 9ae38bc

Please sign in to comment.