From 0cb186846a03b95dfc4dd0d3b1f25dac48ac1026 Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Fri, 27 Dec 2024 21:52:22 +0800 Subject: [PATCH] feature: support glm Cogview --- app/client/platforms/glm.ts | 131 ++++++++++++++++++++++++++++++------ app/components/chat.tsx | 13 ++-- app/constant.ts | 11 +++ app/store/config.ts | 4 +- app/typing.ts | 11 +++ app/utils.ts | 23 +++++++ 6 files changed, 167 insertions(+), 26 deletions(-) diff --git a/app/client/platforms/glm.ts b/app/client/platforms/glm.ts index a7965947fab..8d685fec5ee 100644 --- a/app/client/platforms/glm.ts +++ b/app/client/platforms/glm.ts @@ -25,12 +25,103 @@ import { getMessageTextContent } from "@/app/utils"; import { RequestPayload } from "./openai"; import { fetch } from "@/app/utils/stream"; +interface BasePayload { + model: string; +} + +interface ChatPayload extends BasePayload { + messages: ChatOptions["messages"]; + stream?: boolean; + temperature?: number; + presence_penalty?: number; + frequency_penalty?: number; + top_p?: number; +} + +interface ImageGenerationPayload extends BasePayload { + prompt: string; + size?: string; + user_id?: string; +} + +interface VideoGenerationPayload extends BasePayload { + prompt: string; + duration?: number; + resolution?: string; + user_id?: string; +} + +type ModelType = "chat" | "image" | "video"; + export class ChatGLMApi implements LLMApi { private disableListModels = true; + private getModelType(model: string): ModelType { + if (model.startsWith("cogview-")) return "image"; + if (model.startsWith("cogvideo-")) return "video"; + return "chat"; + } + + private getModelPath(type: ModelType): string { + switch (type) { + case "image": + return ChatGLM.ImagePath; + case "video": + return ChatGLM.VideoPath; + default: + return ChatGLM.ChatPath; + } + } + + private createPayload( + messages: ChatOptions["messages"], + modelConfig: any, + options: ChatOptions, + ): BasePayload { + const modelType = this.getModelType(modelConfig.model); + const lastMessage = messages[messages.length - 1]; + const prompt = + typeof lastMessage.content === "string" + ? lastMessage.content + : lastMessage.content.map((c) => c.text).join("\n"); + + switch (modelType) { + case "image": + return { + model: modelConfig.model, + prompt, + size: "1024x1024", + } as ImageGenerationPayload; + default: + return { + messages, + stream: options.config.stream, + model: modelConfig.model, + temperature: modelConfig.temperature, + presence_penalty: modelConfig.presence_penalty, + frequency_penalty: modelConfig.frequency_penalty, + top_p: modelConfig.top_p, + } as ChatPayload; + } + } + + private parseResponse(modelType: ModelType, json: any): string { + switch (modelType) { + case "image": { + const imageUrl = json.data?.[0]?.url; + return imageUrl ? `![Generated Image](${imageUrl})` : ""; + } + case "video": { + const videoUrl = json.data?.[0]?.url; + return videoUrl ? `` : ""; + } + default: + return this.extractMessage(json); + } + } + path(path: string): string { const accessStore = useAccessStore.getState(); - let baseUrl = ""; if (accessStore.useCustomConfig) { @@ -51,7 +142,6 @@ export class ChatGLMApi implements LLMApi { } console.log("[Proxy Endpoint] ", baseUrl, path); - return [baseUrl, path].join("/"); } @@ -79,24 +169,16 @@ export class ChatGLMApi implements LLMApi { }, }; - const requestPayload: RequestPayload = { - messages, - stream: options.config.stream, - model: modelConfig.model, - temperature: modelConfig.temperature, - presence_penalty: modelConfig.presence_penalty, - frequency_penalty: modelConfig.frequency_penalty, - top_p: modelConfig.top_p, - }; + const modelType = this.getModelType(modelConfig.model); + const requestPayload = this.createPayload(messages, modelConfig, options); + const path = this.path(this.getModelPath(modelType)); - console.log("[Request] glm payload: ", requestPayload); + console.log(`[Request] glm ${modelType} payload: `, requestPayload); - const shouldStream = !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); try { - const chatPath = this.path(ChatGLM.ChatPath); const chatPayload = { method: "POST", body: JSON.stringify(requestPayload), @@ -104,12 +186,23 @@ export class ChatGLMApi implements LLMApi { headers: getHeaders(), }; - // make a fetch request const requestTimeoutId = setTimeout( () => controller.abort(), REQUEST_TIMEOUT_MS, ); + if (modelType === "image" || modelType === "video") { + const res = await fetch(path, chatPayload); + clearTimeout(requestTimeoutId); + + const resJson = await res.json(); + console.log(`[Response] glm ${modelType}:`, resJson); + const message = this.parseResponse(modelType, resJson); + options.onFinish(message, res); + return; + } + + const shouldStream = !!options.config.stream; if (shouldStream) { const [tools, funcs] = usePluginStore .getState() @@ -117,7 +210,7 @@ export class ChatGLMApi implements LLMApi { useChatStore.getState().currentSession().mask?.plugin || [], ); return stream( - chatPath, + path, requestPayload, getHeaders(), tools as any, @@ -125,7 +218,6 @@ export class ChatGLMApi implements LLMApi { controller, // parseSSE (text: string, runTools: ChatMessageTool[]) => { - // console.log("parseSSE", text, runTools); const json = JSON.parse(text); const choices = json.choices as Array<{ delta: { @@ -154,7 +246,7 @@ export class ChatGLMApi implements LLMApi { } return choices[0]?.delta?.content; }, - // processToolMessage, include tool_calls message and tool call results + // processToolMessage ( requestPayload: RequestPayload, toolCallMessage: any, @@ -172,7 +264,7 @@ export class ChatGLMApi implements LLMApi { options, ); } else { - const res = await fetch(chatPath, chatPayload); + const res = await fetch(path, chatPayload); clearTimeout(requestTimeoutId); const resJson = await res.json(); @@ -184,6 +276,7 @@ export class ChatGLMApi implements LLMApi { options.onError?.(e as Error); } } + async usage() { return { used: 0, diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 51fe74fe7be..f34f7d78e09 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -72,6 +72,8 @@ import { isDalle3, showPlugins, safeLocalStorage, + getModelSizes, + supportsCustomSize, } from "../utils"; import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; @@ -79,7 +81,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; import dynamic from "next/dynamic"; import { ChatControllerPool } from "../client/controller"; -import { DalleSize, DalleQuality, DalleStyle } from "../typing"; +import { DalleQuality, DalleStyle, ModelSize } from "../typing"; import { Prompt, usePromptStore } from "../store/prompt"; import Locale from "../locales"; @@ -519,10 +521,11 @@ export function ChatActions(props: { const [showSizeSelector, setShowSizeSelector] = useState(false); const [showQualitySelector, setShowQualitySelector] = useState(false); const [showStyleSelector, setShowStyleSelector] = useState(false); - const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"]; + const modelSizes = getModelSizes(currentModel); const dalle3Qualitys: DalleQuality[] = ["standard", "hd"]; const dalle3Styles: DalleStyle[] = ["vivid", "natural"]; - const currentSize = session.mask.modelConfig?.size ?? "1024x1024"; + const currentSize = + session.mask.modelConfig?.size ?? ("1024x1024" as ModelSize); const currentQuality = session.mask.modelConfig?.quality ?? "standard"; const currentStyle = session.mask.modelConfig?.style ?? "vivid"; @@ -673,7 +676,7 @@ export function ChatActions(props: { /> )} - {isDalle3(currentModel) && ( + {supportsCustomSize(currentModel) && ( setShowSizeSelector(true)} text={currentSize} @@ -684,7 +687,7 @@ export function ChatActions(props: { {showSizeSelector && ( ({ + items={modelSizes.map((m) => ({ title: m, value: m, }))} diff --git a/app/constant.ts b/app/constant.ts index 5759411af17..c1a73bc6593 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -233,6 +233,8 @@ export const XAI = { export const ChatGLM = { ExampleEndpoint: CHATGLM_BASE_URL, ChatPath: "api/paas/v4/chat/completions", + ImagePath: "api/paas/v4/images/generations", + VideoPath: "api/paas/v4/videos/generations", }; export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang @@ -431,6 +433,15 @@ const chatglmModels = [ "glm-4-long", "glm-4-flashx", "glm-4-flash", + "glm-4v-plus", + "glm-4v", + "glm-4v-flash", // free + "cogview-3-plus", + "cogview-3", + "cogview-3-flash", // free + // 目前无法适配轮询任务 + // "cogvideox", + // "cogvideox-flash", // free ]; let seq = 1000; // 内置的模型序号生成器从1000开始 diff --git a/app/store/config.ts b/app/store/config.ts index 4256eba925d..45e21b02697 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -1,5 +1,5 @@ import { LLMModel } from "../client/api"; -import { DalleSize, DalleQuality, DalleStyle } from "../typing"; +import { DalleQuality, DalleStyle, ModelSize } from "../typing"; import { getClientConfig } from "../config/client"; import { DEFAULT_INPUT_TEMPLATE, @@ -78,7 +78,7 @@ export const DEFAULT_CONFIG = { compressProviderName: "", enableInjectSystemPrompts: true, template: config?.template ?? DEFAULT_INPUT_TEMPLATE, - size: "1024x1024" as DalleSize, + size: "1024x1024" as ModelSize, quality: "standard" as DalleQuality, style: "vivid" as DalleStyle, }, diff --git a/app/typing.ts b/app/typing.ts index 0336be75d39..ecb327936fd 100644 --- a/app/typing.ts +++ b/app/typing.ts @@ -11,3 +11,14 @@ export interface RequestMessage { export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792"; export type DalleQuality = "standard" | "hd"; export type DalleStyle = "vivid" | "natural"; + +export type ModelSize = + | "1024x1024" + | "1792x1024" + | "1024x1792" + | "768x1344" + | "864x1152" + | "1344x768" + | "1152x864" + | "1440x720" + | "720x1440"; diff --git a/app/utils.ts b/app/utils.ts index 962e68a101c..810dc7842b1 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant"; import { fetch as tauriStreamFetch } from "./utils/stream"; import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant"; import { getClientConfig } from "./config/client"; +import { ModelSize } from "./typing"; export function trimTopic(topic: string) { // Fix an issue where double quotes still show in the Indonesian language @@ -271,6 +272,28 @@ export function isDalle3(model: string) { return "dall-e-3" === model; } +export function getModelSizes(model: string): ModelSize[] { + if (isDalle3(model)) { + return ["1024x1024", "1792x1024", "1024x1792"]; + } + if (model.toLowerCase().includes("cogview")) { + return [ + "1024x1024", + "768x1344", + "864x1152", + "1344x768", + "1152x864", + "1440x720", + "720x1440", + ]; + } + return []; +} + +export function supportsCustomSize(model: string): boolean { + return getModelSizes(model).length > 0; +} + export function showPlugins(provider: ServiceProvider, model: string) { if ( provider == ServiceProvider.OpenAI ||