
Commit

✨ feat: support qwen-vl and tool call for qwen (lobehub#3114)
* feat: support qwen-vl and tool call for qwen

* fix: make transformResponseToStream a util for testability

* test: append unit test for non-streaming response

* test: update unit-test against LobeQwenAI models
Mingholy authored Jul 17, 2024
1 parent 091898b commit 5216a85
Showing 7 changed files with 763 additions and 87 deletions.
22 changes: 22 additions & 0 deletions src/config/modelProviders/qwen.ts
@@ -7,13 +7,15 @@ const Qwen: ModelProviderCard = {
description: 'Tongyi Qianwen ultra-large-scale language model, supporting input in Chinese, English, and other languages',
displayName: 'Qwen Turbo',
enabled: true,
functionCall: true,
id: 'qwen-turbo',
tokens: 8000,
},
{
description: 'Enhanced edition of the Tongyi Qianwen ultra-large-scale language model, supporting input in Chinese, English, and other languages',
displayName: 'Qwen Plus',
enabled: true,
functionCall: true,
id: 'qwen-plus',
tokens: 32_000,
},
@@ -22,13 +24,15 @@
'Hundred-billion-parameter Tongyi Qianwen ultra-large-scale language model, supporting input in Chinese, English, and other languages; the API model behind the current Tongyi Qianwen 2.5 product release',
displayName: 'Qwen Max',
enabled: true,
functionCall: true,
id: 'qwen-max',
tokens: 8000,
},
{
description:
'Hundred-billion-parameter Tongyi Qianwen ultra-large-scale language model, supporting input in Chinese, English, and other languages, with an extended context window',
displayName: 'Qwen Max LongContext',
functionCall: true,
id: 'qwen-max-longcontext',
tokens: 30_000,
},
@@ -50,6 +54,24 @@
id: 'qwen2-72b-instruct',
tokens: 131_072,
},
{
description:
'Enhanced edition of the Tongyi Qianwen large-scale vision-language model. Greatly improves detail recognition and text recognition, and supports images above one-megapixel resolution at any aspect ratio.',
displayName: 'Qwen VL Plus',
enabled: true,
id: 'qwen-vl-plus',
tokens: 6144,
vision: true,
},
{
description:
'Tongyi Qianwen ultra-large-scale vision-language model. Compared with the enhanced edition, it further improves visual reasoning and instruction following, offering a higher level of visual perception and cognition.',
displayName: 'Qwen VL Max',
enabled: true,
id: 'qwen-vl-max',
tokens: 6144,
vision: true,
},
],
checkModel: 'qwen-turbo',
disableBrowserRequest: true,
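The entries above hinge on two capability flags: `functionCall` marks models that can emit tool calls, and `vision` marks models that accept image input. The sketch below shows the per-model card shape these fields imply; the field names mirror the diff, but the interface itself is an assumption for illustration, not copied from the repo's `ModelProviderCard` type.

```ts
// Hypothetical sketch of a chat-model card entry, inferred from the fields
// used in the diff above; the repo's real type may differ.
interface ChatModelCard {
  description?: string;
  displayName?: string;
  enabled?: boolean; // shown by default in the model picker (assumed)
  functionCall?: boolean; // supports tool / function calling
  id: string; // provider-side model identifier, e.g. 'qwen-vl-plus'
  tokens?: number; // context window size in tokens
  vision?: boolean; // accepts image input
}
```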
118 changes: 114 additions & 4 deletions src/libs/agent-runtime/qwen/index.test.ts
@@ -2,6 +2,7 @@
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import Qwen from '@/config/modelProviders/qwen';
import { LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
import { ModelProvider } from '@/libs/agent-runtime';
import { AgentRuntimeErrorType } from '@/libs/agent-runtime';
@@ -17,7 +18,7 @@ const invalidErrorType = AgentRuntimeErrorType.InvalidProviderAPIKey;
// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

let instance: LobeOpenAICompatibleRuntime;
let instance: LobeQwenAI;

beforeEach(() => {
instance = new LobeQwenAI({ apiKey: 'test' });
@@ -41,7 +42,116 @@ describe('LobeQwenAI', () => {
});
});

describe('models', () => {
it('should correctly list available models', async () => {
const instance = new LobeQwenAI({ apiKey: 'test_api_key' });
vi.spyOn(instance, 'models').mockResolvedValue(Qwen.chatModels);

const models = await instance.models();
expect(models).toEqual(Qwen.chatModels);
});
});

describe('chat', () => {
describe('Params', () => {
it('should call llms with proper options', async () => {
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);

(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
top_p: 0.7,
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
stream: true,
top_p: 0.7,
result_format: 'message',
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

it('should call vlms with proper options', async () => {
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);

(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-vl-plus',
temperature: 0.6,
top_p: 0.7,
});

// Assert
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
{
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-vl-plus',
stream: true,
},
{ headers: { Accept: '*/*' } },
);
expect(result).toBeInstanceOf(Response);
});

it('should transform non-streaming response to stream correctly', async () => {
const mockResponse: OpenAI.ChatCompletion = {
id: 'chatcmpl-fc539f49-51a8-94be-8061',
object: 'chat.completion',
created: 1719901794,
model: 'qwen-turbo',
choices: [
{
index: 0,
message: { role: 'assistant', content: 'Hello' },
finish_reason: 'stop',
logprobs: null,
},
],
};
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
mockResponse as any,
);

const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'qwen-turbo',
temperature: 0.6,
stream: false,
});

const decoder = new TextDecoder();

const reader = result.body!.getReader();
expect(decoder.decode((await reader.read()).value)).toContain(
'id: chatcmpl-fc539f49-51a8-94be-8061\n',
);
expect(decoder.decode((await reader.read()).value)).toContain('event: text\n');
expect(decoder.decode((await reader.read()).value)).toContain('data: "Hello"\n\n');

expect(decoder.decode((await reader.read()).value)).toContain(
'id: chatcmpl-fc539f49-51a8-94be-8061\n',
);
expect(decoder.decode((await reader.read()).value)).toContain('event: stop\n');
expect(decoder.decode((await reader.read()).value)).toContain('');

expect((await reader.read()).done).toBe(true);
});
});

describe('Error', () => {
it('should return QwenBizError with an openai error response when OpenAI.APIError is thrown', async () => {
// Arrange
@@ -129,8 +239,7 @@ describe('LobeQwenAI', () => {

instance = new LobeQwenAI({
apiKey: 'test',

baseURL: 'https://api.abc.com/v1',
baseURL: defaultBaseURL,
});

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
@@ -144,7 +253,8 @@ });
});
} catch (e) {
expect(e).toEqual({
endpoint: 'https://api.***.com/v1',
/* Desensitizing is unnecessary for a public-accessible gateway endpoint. */
endpoint: defaultBaseURL,
error: {
cause: { message: 'api is undefined' },
stack: 'abc',
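The non-streaming test above expects the completion to be replayed as protocol frames: an `id:` / `event: text` / `data:` group for the content, then an `event: stop` group. That happens in two steps: `transformResponseToStream` turns the single completion into a chunk stream, and `QwenAIStream` serializes those chunks into frames. Below is a minimal illustrative sketch of the first step, written under assumed semantics rather than copied from the project's util in `../utils/openaiCompatibleFactory`.

```ts
import OpenAI from 'openai';

// Illustrative only: re-emit a non-streaming ChatCompletion as a stream of
// chunk-shaped objects, so the same SSE transformer can serve both paths.
const toChunkStream = (response: OpenAI.ChatCompletion) =>
  new ReadableStream<OpenAI.ChatCompletionChunk>({
    start(controller) {
      const base = {
        created: response.created,
        id: response.id,
        model: response.model,
        object: 'chat.completion.chunk' as const,
      };
      // First frame: the full message content as a single delta ("event: text").
      controller.enqueue({
        ...base,
        choices: response.choices.map((choice) => ({
          delta: { content: choice.message.content, role: choice.message.role },
          finish_reason: null,
          index: choice.index,
        })),
      });
      // Final frame: an empty delta carrying the finish reason ("event: stop").
      controller.enqueue({
        ...base,
        choices: response.choices.map((choice) => ({
          delta: {},
          finish_reason: choice.finish_reason,
          index: choice.index,
        })),
      });
      controller.close();
    },
  });
```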
156 changes: 128 additions & 28 deletions src/libs/agent-runtime/qwen/index.ts
@@ -1,28 +1,128 @@
import OpenAI from 'openai';

import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';

export const LobeQwenAI = LobeOpenAICompatibleFactory({
baseURL: 'https://dashscope.aliyuncs.com/compatible-mode/v1',
chatCompletion: {
handlePayload: (payload) => {
const top_p = payload.top_p;
return {
...payload,
stream: payload.stream ?? true,
top_p: top_p && top_p >= 1 ? 0.9999 : top_p,
} as OpenAI.ChatCompletionCreateParamsStreaming;
},
},
constructorOptions: {
defaultHeaders: {
'Content-Type': 'application/json',
},
},
debug: {
chatCompletion: () => process.env.DEBUG_QWEN_CHAT_COMPLETION === '1',
},

provider: ModelProvider.Qwen,
});
import { omit } from 'lodash-es';
import OpenAI, { ClientOptions } from 'openai';

import Qwen from '@/config/modelProviders/qwen';

import { LobeOpenAICompatibleRuntime, LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { handleOpenAIError } from '../utils/handleOpenAIError';
import { transformResponseToStream } from '../utils/openaiCompatibleFactory';
import { StreamingResponse } from '../utils/response';
import { QwenAIStream } from '../utils/streams';

const DEFAULT_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1';

/**
* Use the DashScope OpenAI-compatible mode for now.
* DashScope's OpenAI [compatible mode](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api) currently supports base64 image input for vision models, e.g. qwen-vl-plus.
* You can supply image input in either of two ways:
* 1. Use qwen-vl-* out of the box with base64 image_url input;
* or
* 2. Set the S3-* environment variables properly to store all uploaded files.
*/
export class LobeQwenAI extends LobeOpenAICompatibleRuntime implements LobeRuntimeAI {
client: OpenAI;
baseURL: string;

constructor({
apiKey,
baseURL = DEFAULT_BASE_URL,
...res
}: ClientOptions & Record<string, any> = {}) {
super();
if (!apiKey) throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidProviderAPIKey);
this.client = new OpenAI({ apiKey, baseURL, ...res });
this.baseURL = this.client.baseURL;
}

async models() {
return Qwen.chatModels;
}

async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) {
try {
const params = this.buildCompletionParamsByModel(payload);

const response = await this.client.chat.completions.create(
params as OpenAI.ChatCompletionCreateParamsStreaming & { result_format: string },
{
headers: { Accept: '*/*' },
signal: options?.signal,
},
);

if (params.stream) {
const [prod, debug] = response.tee();

if (process.env.DEBUG_QWEN_CHAT_COMPLETION === '1') {
debugStream(debug.toReadableStream()).catch(console.error);
}

return StreamingResponse(QwenAIStream(prod, options?.callback), {
headers: options?.headers,
});
}

const stream = transformResponseToStream(response as unknown as OpenAI.ChatCompletion);

return StreamingResponse(QwenAIStream(stream, options?.callback), {
headers: options?.headers,
});
} catch (error) {
if ('status' in (error as any)) {
switch ((error as Response).status) {
case 401: {
throw AgentRuntimeError.chat({
endpoint: this.baseURL,
error: error as any,
errorType: AgentRuntimeErrorType.InvalidProviderAPIKey,
provider: ModelProvider.Qwen,
});
}

default: {
break;
}
}
}
const { errorResult, RuntimeError } = handleOpenAIError(error);
const errorType = RuntimeError || AgentRuntimeErrorType.ProviderBizError;

throw AgentRuntimeError.chat({
endpoint: this.baseURL,
error: errorResult,
errorType,
provider: ModelProvider.Qwen,
});
}
}

private buildCompletionParamsByModel(payload: ChatStreamPayload) {
const { model, top_p, stream, messages, tools } = payload;
const isVisionModel = model.startsWith('qwen-vl');

const params = {
...payload,
messages,
result_format: 'message',
stream: !!tools?.length ? false : stream ?? true,
top_p: top_p && top_p >= 1 ? 0.999 : top_p,
};

/* Qwen-vl models do not currently support the parameters below. */
/* Notice: `top_p` has a significant impact on the result; the default 1 (or 0.999) is not a proper choice. */
return isVisionModel
? omit(
params,
'presence_penalty',
'frequency_penalty',
'temperature',
'result_format',
'top_p',
)
: params;
}
}
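Putting `buildCompletionParamsByModel` together: a non-empty `tools` list forces a non-streaming upstream call whose result is re-streamed, `top_p` values at or above 1 are clamped to 0.999, and qwen-vl models receive a stripped-down payload. The usage sketch below is hedged accordingly; the content-part message shape and the tool definition follow the OpenAI-compatible convention and are assumptions, not copied from `ChatStreamPayload`.

```ts
import { LobeQwenAI } from '@/libs/agent-runtime/qwen';

const runtime = new LobeQwenAI({ apiKey: process.env.DASHSCOPE_API_KEY });

async function demo() {
  // Vision: qwen-vl-* accepts OpenAI-style image_url content parts, including
  // base64 data URLs (truncated placeholder below).
  const visionResponse = await runtime.chat({
    messages: [
      {
        content: [
          { text: 'What is in this picture?', type: 'text' },
          { image_url: { url: 'data:image/png;base64,iVBORw0KGgo...' }, type: 'image_url' },
        ],
        role: 'user',
      },
    ],
    model: 'qwen-vl-plus',
    temperature: 0.6, // omitted from the upstream request for qwen-vl models (see above)
  });

  // Tools: a non-empty `tools` array sets stream: false upstream; the full
  // completion is then re-emitted as a streamed Response.
  const toolResponse = await runtime.chat({
    messages: [{ content: 'What is the weather in Hangzhou?', role: 'user' }],
    model: 'qwen-turbo',
    tools: [
      {
        function: {
          description: 'Hypothetical weather lookup, for illustration only.',
          name: 'get_weather',
          parameters: { properties: { city: { type: 'string' } }, type: 'object' },
        },
        type: 'function',
      },
    ],
  });

  return [visionResponse, toolResponse]; // both are standard Response objects
}
```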