Skip to content

Commit

Permalink
JS Reranker (#539)
Browse files Browse the repository at this point in the history
* JS Reranker

* JS Reranker

---------

Co-authored-by: dailin01 <[email protected]>
  • Loading branch information
wangting829 and dailin01 authored May 23, 2024
1 parent 64fe7e6 commit c552fcc
Show file tree
Hide file tree
Showing 12 changed files with 257 additions and 7 deletions.
27 changes: 27 additions & 0 deletions javascript/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,30 @@ async function yiYanChatFileMain() {
}
yiYanChatFileMain();
```

### Reranker 重排序

跨语种语义表征算法模型,擅长优化语义搜索结果和语义相关顺序精排,支持中英日韩四门语言。

```ts
// node环境
import {Reranker} from "@baiducloud/qianfan";
// 直接读取 env
const client = new Reranker();

// 手动传 AK/SK
// const client = new Reranker({ QIANFAN_AK: '***', QIANFAN_SK: '***'});

// 浏览器环境,必须传入QIANFAN_BASE_URL,(proxy启动后地址), QIANFAN_CONSOLE_API_BASE_URL不传时,只能使用预置模型,传入后可以使用动态模型
import {Reranker} from "@baiducloud/qianfan";
const client = new Reranker({QIANFAN_BASE_URL: 'http://172.18.184.85:8002', QIANFAN_CONSOLE_API_BASE_URL: 'http://172.18.184.85:8003'});

async function main() {
const resp = await client.reranker({
query: '上海天气',
documents: ['上海气候', '北京美食'],
});
}

main();
```
2 changes: 1 addition & 1 deletion javascript/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@baiducloud/qianfan",
"version": " 0.0.11-alpha.0",
"version": " 0.0.11-alpha.1",
"publishConfig": {
"access": "public",
"registry": "https://registry.npmjs.org/"
Expand Down
4 changes: 2 additions & 2 deletions javascript/src/DynamicModelEndpoint/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {Mutex} from 'async-mutex';
import Fetch from '../Fetch/fetch';
import HttpClient from '../HttpClient';
import {SERVER_LIST_API, DEFAULT_HEADERS} from '../constant';
import {SERVER_LIST_API, DEFAULT_HEADERS, DYNAMIC_INVALID} from '../constant';
import {getTypeMap, typeModelEndpointMap} from './utils';
import {getPath} from '../utils';

Expand Down Expand Up @@ -41,7 +41,7 @@ class DynamicModelEndpoint {
const mutex = new Mutex();
const release = await mutex.acquire(); // 等待获取互斥锁
try {
if (this.isDynamicMapExpired()) {
if (!DYNAMIC_INVALID.includes(type) && this.isDynamicMapExpired()) {
await this.updateDynamicModelEndpoint(type); // 等待动态更新完成
this.dynamicMapExpireAt = Date.now() / 1000 + this.DYNAMIC_MAP_REFRESH_INTERVAL;
}
Expand Down
6 changes: 6 additions & 0 deletions javascript/src/DynamicModelEndpoint/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,19 @@ const pluginEndpoints = new Map<string, string>([
['ebpluginv2', 'erniebot/plugin'],
]);

// Reranker models: maps each preset reranker model name to its endpoint name.
const rerankerEndpoints = new Map<string, string>([
['bce-reranker-base_v1', 'bce_reranker_base'],
]);

// Register each model type's endpoint map in the main map, keyed by ModelType.
typeModelEndpointMap.set(ModelType.CHAT, chatModelEndpoints);
typeModelEndpointMap.set(ModelType.COMPLETIONS, completionsModelEndpoints);
typeModelEndpointMap.set(ModelType.EMBEDDINGS, embeddingEndpoints);
typeModelEndpointMap.set(ModelType.TEXT_2_IMAGE, text2imageEndpoints);
typeModelEndpointMap.set(ModelType.IMAGE_2_TEXT, image2textEndpoints);
typeModelEndpointMap.set(ModelType.PLUGIN, pluginEndpoints);
typeModelEndpointMap.set(ModelType.RERANKER, rerankerEndpoints);

export {typeModelEndpointMap};
// 检查CHAT是否有数据的函数
Expand Down
40 changes: 40 additions & 0 deletions javascript/src/Reranker/__tests__/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/* eslint-disable max-len */
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import {Reranker, setEnvVariable} from '../../index';

// 设置环境变量
setEnvVariable('QIANFAN_BASE_URL', 'http://127.0.0.1:8866');
setEnvVariable('QIANFAN_CONSOLE_API_BASE_URL', 'http://127.0.0.1:8866');
setEnvVariable('QIANFAN_ACCESS_KEY', '123');
setEnvVariable('QIANFAN_SECRET_KEY', '456');

// Smoke test for the Reranker client against the endpoint configured via the
// environment variables set at the top of this file.
describe('Reranker functionality', () => {
    // Explicitly typed: an un-annotated `let` captured by the closures below
    // would otherwise be an implicit `any` under `strict`.
    let client: Reranker;

    beforeEach(() => {
        // Fresh client per test, constructed without arguments so it relies on
        // the environment configuration.
        client = new Reranker();
        jest.clearAllMocks();
    });

    it('should reorder documents by query', async () => {
        const resp = await client.reranker({
            query: '上海天气',
            documents: ['上海气候', '北京美食'],
        });
        console.log(resp);
        // Only checks that a response came back; its shape is not asserted here.
        expect(resp).toBeDefined();
    });
});
39 changes: 39 additions & 0 deletions javascript/src/Reranker/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import {BaseClient} from '../Base';
import {modelInfoMap} from './utilts';
import {getPathAndBody, getUpperCaseModelAndModelMap} from '../utils';
import {RerankerBody, RerankerResp} from '../interface';
import {ModelType} from '../enum';

/**
 * Client for the reranker service: scores a list of documents against a query.
 */
class Reranker extends BaseClient {
    /**
     * Sends a rerank request for the given model.
     *
     * @param body - Request payload (query, documents and optional fields).
     * @param model - Model name; defaults to 'bce-reranker-base_v1'.
     * @returns The response from `sendRequest`, cast to `RerankerResp`.
     */
    public async reranker(body: RerankerBody, model = 'bce-reranker-base_v1'): Promise<RerankerResp> {
        // Model name and model-info map in their upper-cased lookup form.
        const upperCased = getUpperCaseModelAndModelMap(model, modelInfoMap);
        const prepared = getPathAndBody({
            model: upperCased.modelUppercase,
            modelInfoMap: upperCased.modelInfoMapUppercase,
            baseUrl: this.qianfanBaseUrl,
            body,
            endpoint: this.Endpoint,
            type: ModelType.RERANKER,
        });
        // The model name is forwarded exactly as the caller supplied it.
        const response = await this.sendRequest(
            ModelType.RERANKER,
            model,
            prepared.AKPath,
            prepared.requestBody
        );
        return response as RerankerResp;
    }
}

export default Reranker;
32 changes: 32 additions & 0 deletions javascript/src/Reranker/utilts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import {QfLLMInfoMap} from '../interface';

/**
* Names of the supported reranker models.
*/
export type RerankerModel =
| 'bce-reranker-base_v1';

/**
* Per-model request metadata for the reranker: the endpoint path plus the
* required and optional request-body keys.
* NOTE(review): presumably consumed by getPathAndBody to build the request
* path and body — confirm against that helper.
*/
export const modelInfoMap: QfLLMInfoMap = {
'bce-reranker-base_v1': {
endpoint: '/reranker/bce_reranker_base',
required_keys: ['query', 'documents'],
optional_keys: [
'user_id',
'top_n',
],
},
};
4 changes: 3 additions & 1 deletion javascript/src/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ export const RETRY_CODE = [
APIErrorCode.QPSLimitReached,
];

export const SERVER_LIST_API = '/wenxinworkshop/service/list';
export const SERVER_LIST_API = '/wenxinworkshop/service/list';

// Model types for which DynamicModelEndpoint skips the dynamic endpoint-map refresh.
export const DYNAMIC_INVALID = ['reranker'];
1 change: 1 addition & 0 deletions javascript/src/enum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export enum ModelType {
TEXT_2_IMAGE = 'text2image',
IMAGE_2_TEXT = 'image2text',
PLUGIN = 'plugin',
RERANKER = 'reranker'
}

export enum APIErrorCode {
Expand Down
3 changes: 2 additions & 1 deletion javascript/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Completions from './Completions';
import Embedding from './Embedding';
import Plugin from './Plugin';
import {Text2Image, Image2Text} from './Images';
import Reranker from './Reranker';
import {setEnvVariable} from './utils';

export {ChatCompletion, Completions, Embedding, Plugin, Text2Image, Image2Text, setEnvVariable};
export {ChatCompletion, Completions, Embedding, Plugin, Text2Image, Image2Text, Reranker, setEnvVariable};
67 changes: 65 additions & 2 deletions javascript/src/interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,71 @@ export interface Image2TextBody extends baseReq{
stop?: string[];
}

/**
* Request body for a reranker call.
*/
export interface RerankerBody extends baseReq{
/**
* Query text.
* At most 1600 characters; truncated if it exceeds 400 tokens.
*/
query: string;
/**
* Texts to be reranked.
* (1) The list must not be empty, and no member may be an empty string.
* (2) At most 64 texts.
* (3) Each document is at most 4096 characters; truncated if it exceeds 1024 tokens.
*/
documents: string[];
/**
* Number of most relevant texts to return.
* Defaults to the number of documents.
*/
top_n?: number;
}

/**
* One entry of the reranking result list.
*/
type RerankerData = {
/**
* Text content.
*/
document: string;
/**
* Similarity score.
*/
relevance_score: number;
/**
* Sequence number.
*/
index: number;
}
/**
* Token usage counters.
*/
type UsageType = {
/**
* Number of tokens in the prompt (including historical Q&A).
*/
prompt_tokens: number;
/**
* Total number of tokens.
*/
total_tokens: number;
}

/**
* Response of a reranker call.
*/
export interface RerankerResp {
/**
* ID of this conversation round.
*/
id: string;
/**
* Response type; fixed value "rerank_list".
*/
object: string;
/**
* Timestamp.
*/
created: number;
/**
* Reranking results, sorted by similarity score in descending order.
*/
results: RerankerData[];
/**
* Token usage for this request.
*/
usage: UsageType;
}

export type ReqBody = ChatBody | CompletionBody | EmbeddingBody | PluginsBody | Text2ImageBody;
export type Resp = RespBase | ChatResp | EmbeddingResp | PluginsResp | Text2ImageResp;
export type ReqBody = ChatBody | CompletionBody | EmbeddingBody | PluginsBody | Text2ImageBody | RerankerBody;
export type Resp = RespBase | ChatResp | EmbeddingResp | PluginsResp | Text2ImageResp | RerankerResp;
export type AsyncIterableType = AsyncIterable<ChatResp | RespBase | PluginsResp>;
39 changes: 39 additions & 0 deletions javascript/src/test/reranker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import {Reranker, setEnvVariable} from '../index';

// Option A: set the credentials through setEnvVariable before creating the client
// setEnvVariable('QIANFAN_AK','***');
// setEnvVariable('QIANFAN_SK','***');

// Option B (active): create the client without arguments and read the configuration from env
const client = new Reranker();

// Option C: pass AK/SK explicitly
// const client = new Reranker({ QIANFAN_AK: '***', QIANFAN_SK: '***'});
// Option D: pass ACCESS_KEY/SECRET_KEY explicitly
// const client = new Reranker({ QIANFAN_ACCESS_KEY: '***', QIANFAN_SECRET_KEY: '***' });

/**
 * Sends one sample rerank request and prints the response.
 */
async function main(): Promise<void> {
    const resp = await client.reranker({
        query: '上海天气',
        documents: ['上海气候', '北京美食'],
    });
    console.log('返回结果');
    console.log(resp);
}

// Handle rejection explicitly instead of leaving a floating promise: report
// the error and keep a non-zero exit code for the sample run.
main().catch((err: unknown) => {
    console.error(err);
    process.exitCode = 1;
});

0 comments on commit c552fcc

Please sign in to comment.