From 0cb186846a03b95dfc4dd0d3b1f25dac48ac1026 Mon Sep 17 00:00:00 2001
From: Dogtiti <499960698@qq.com>
Date: Fri, 27 Dec 2024 21:52:22 +0800
Subject: [PATCH] feature: support GLM CogView
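
Route cogview-* models to a dedicated image generation endpoint, render
the returned image inline in the chat, and let the size selector offer
per-model size options (DALL·E 3 or CogView). CogVideoX models are wired
up but left disabled, since they would require polling an async task.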
---
app/client/platforms/glm.ts | 131 ++++++++++++++++++++++++++++++------
app/components/chat.tsx | 13 ++--
app/constant.ts | 11 +++
app/store/config.ts | 4 +-
app/typing.ts | 11 +++
app/utils.ts | 23 +++++++
6 files changed, 167 insertions(+), 26 deletions(-)
diff --git a/app/client/platforms/glm.ts b/app/client/platforms/glm.ts
index a7965947fab..8d685fec5ee 100644
--- a/app/client/platforms/glm.ts
+++ b/app/client/platforms/glm.ts
@@ -25,12 +25,103 @@ import { getMessageTextContent } from "@/app/utils";
import { RequestPayload } from "./openai";
import { fetch } from "@/app/utils/stream";
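+
+// Request payloads differ by model family: chat completions, CogView
+// image generation, and CogVideoX video generation each hit a distinct
+// endpoint with a distinct body shape.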
+interface BasePayload {
+ model: string;
+}
+
+interface ChatPayload extends BasePayload {
+ messages: ChatOptions["messages"];
+ stream?: boolean;
+ temperature?: number;
+ presence_penalty?: number;
+ frequency_penalty?: number;
+ top_p?: number;
+}
+
+interface ImageGenerationPayload extends BasePayload {
+ prompt: string;
+ size?: string;
+ user_id?: string;
+}
+
+interface VideoGenerationPayload extends BasePayload {
+ prompt: string;
+ duration?: number;
+ resolution?: string;
+ user_id?: string;
+}
+
+type ModelType = "chat" | "image" | "video";
+
export class ChatGLMApi implements LLMApi {
private disableListModels = true;
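+
+  // cogview-* models generate images, cogvideo-* models generate videos,
+  // everything else goes through the chat completions endpoint.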
+ private getModelType(model: string): ModelType {
+ if (model.startsWith("cogview-")) return "image";
+ if (model.startsWith("cogvideo-")) return "video";
+ return "chat";
+ }
+
+ private getModelPath(type: ModelType): string {
+ switch (type) {
+ case "image":
+ return ChatGLM.ImagePath;
+ case "video":
+ return ChatGLM.VideoPath;
+ default:
+ return ChatGLM.ChatPath;
+ }
+ }
+
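+  // Build the request body for the detected model type; image generation
+  // uses only the text of the last message as its prompt.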
+ private createPayload(
+ messages: ChatOptions["messages"],
+ modelConfig: any,
+ options: ChatOptions,
+ ): BasePayload {
+ const modelType = this.getModelType(modelConfig.model);
+ const lastMessage = messages[messages.length - 1];
+ const prompt =
+ typeof lastMessage.content === "string"
+ ? lastMessage.content
+ : lastMessage.content.map((c) => c.text).join("\n");
+
+ switch (modelType) {
+ case "image":
+ return {
+ model: modelConfig.model,
+ prompt,
+ size: "1024x1024",
+ } as ImageGenerationPayload;
+ default:
+ return {
+ messages,
+ stream: options.config.stream,
+ model: modelConfig.model,
+ temperature: modelConfig.temperature,
+ presence_penalty: modelConfig.presence_penalty,
+ frequency_penalty: modelConfig.frequency_penalty,
+ top_p: modelConfig.top_p,
+ } as ChatPayload;
+ }
+ }
+
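+  // Turn a non-streaming image/video response into chat-renderable content:
+  // markdown for images, an HTML <video> tag for videos.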
+ private parseResponse(modelType: ModelType, json: any): string {
+ switch (modelType) {
+ case "image": {
+ const imageUrl = json.data?.[0]?.url;
+      return imageUrl ? `![Generated Image](${imageUrl})` : "";
+ }
+ case "video": {
+ const videoUrl = json.data?.[0]?.url;
+      return videoUrl ? `<video controls src="${videoUrl}"></video>` : "";
+ }
+ default:
+ return this.extractMessage(json);
+ }
+ }
+
path(path: string): string {
const accessStore = useAccessStore.getState();
-
let baseUrl = "";
if (accessStore.useCustomConfig) {
@@ -51,7 +142,6 @@ export class ChatGLMApi implements LLMApi {
}
console.log("[Proxy Endpoint] ", baseUrl, path);
-
return [baseUrl, path].join("/");
}
@@ -79,24 +169,16 @@ export class ChatGLMApi implements LLMApi {
},
};
- const requestPayload: RequestPayload = {
- messages,
- stream: options.config.stream,
- model: modelConfig.model,
- temperature: modelConfig.temperature,
- presence_penalty: modelConfig.presence_penalty,
- frequency_penalty: modelConfig.frequency_penalty,
- top_p: modelConfig.top_p,
- };
+ const modelType = this.getModelType(modelConfig.model);
+ const requestPayload = this.createPayload(messages, modelConfig, options);
+ const path = this.path(this.getModelPath(modelType));
- console.log("[Request] glm payload: ", requestPayload);
+ console.log(`[Request] glm ${modelType} payload: `, requestPayload);
- const shouldStream = !!options.config.stream;
const controller = new AbortController();
options.onController?.(controller);
try {
- const chatPath = this.path(ChatGLM.ChatPath);
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
@@ -104,12 +186,23 @@ export class ChatGLMApi implements LLMApi {
headers: getHeaders(),
};
- // make a fetch request
const requestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);
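+      // Image and video generation are single non-streaming requests:
+      // parse the JSON result into renderable content and finish.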
+ if (modelType === "image" || modelType === "video") {
+ const res = await fetch(path, chatPayload);
+ clearTimeout(requestTimeoutId);
+
+ const resJson = await res.json();
+ console.log(`[Response] glm ${modelType}:`, resJson);
+ const message = this.parseResponse(modelType, resJson);
+ options.onFinish(message, res);
+ return;
+ }
+
+ const shouldStream = !!options.config.stream;
if (shouldStream) {
const [tools, funcs] = usePluginStore
.getState()
@@ -117,7 +210,7 @@ export class ChatGLMApi implements LLMApi {
useChatStore.getState().currentSession().mask?.plugin || [],
);
return stream(
- chatPath,
+ path,
requestPayload,
getHeaders(),
tools as any,
@@ -125,7 +218,6 @@ export class ChatGLMApi implements LLMApi {
controller,
// parseSSE
(text: string, runTools: ChatMessageTool[]) => {
- // console.log("parseSSE", text, runTools);
const json = JSON.parse(text);
const choices = json.choices as Array<{
delta: {
@@ -154,7 +246,7 @@ export class ChatGLMApi implements LLMApi {
}
return choices[0]?.delta?.content;
},
- // processToolMessage, include tool_calls message and tool call results
+ // processToolMessage
(
requestPayload: RequestPayload,
toolCallMessage: any,
@@ -172,7 +264,7 @@ export class ChatGLMApi implements LLMApi {
options,
);
} else {
- const res = await fetch(chatPath, chatPayload);
+ const res = await fetch(path, chatPayload);
clearTimeout(requestTimeoutId);
const resJson = await res.json();
@@ -184,6 +276,7 @@ export class ChatGLMApi implements LLMApi {
options.onError?.(e as Error);
}
}
+
async usage() {
return {
used: 0,
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index 51fe74fe7be..f34f7d78e09 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -72,6 +72,8 @@ import {
isDalle3,
showPlugins,
safeLocalStorage,
+ getModelSizes,
+ supportsCustomSize,
} from "../utils";
import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
@@ -79,7 +81,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
import dynamic from "next/dynamic";
import { ChatControllerPool } from "../client/controller";
-import { DalleSize, DalleQuality, DalleStyle } from "../typing";
+import { DalleQuality, DalleStyle, ModelSize } from "../typing";
import { Prompt, usePromptStore } from "../store/prompt";
import Locale from "../locales";
@@ -519,10 +521,11 @@ export function ChatActions(props: {
const [showSizeSelector, setShowSizeSelector] = useState(false);
const [showQualitySelector, setShowQualitySelector] = useState(false);
const [showStyleSelector, setShowStyleSelector] = useState(false);
- const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
+ const modelSizes = getModelSizes(currentModel);
const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
- const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
+ const currentSize =
+ session.mask.modelConfig?.size ?? ("1024x1024" as ModelSize);
const currentQuality = session.mask.modelConfig?.quality ?? "standard";
const currentStyle = session.mask.modelConfig?.style ?? "vivid";
@@ -673,7 +676,7 @@ export function ChatActions(props: {
/>
)}
- {isDalle3(currentModel) && (
+ {supportsCustomSize(currentModel) && (
          <ChatAction
            onClick={() => setShowSizeSelector(true)}
text={currentSize}
@@ -684,7 +687,7 @@ export function ChatActions(props: {
{showSizeSelector && (
            <Selector
              defaultSelectedValue={currentSize}
-             items={dalle3Sizes.map((m) => ({
+ items={modelSizes.map((m) => ({
title: m,
value: m,
}))}
diff --git a/app/constant.ts b/app/constant.ts
index 5759411af17..c1a73bc6593 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -233,6 +233,8 @@ export const XAI = {
export const ChatGLM = {
ExampleEndpoint: CHATGLM_BASE_URL,
ChatPath: "api/paas/v4/chat/completions",
+ ImagePath: "api/paas/v4/images/generations",
+ VideoPath: "api/paas/v4/videos/generations",
};
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -431,6 +433,15 @@ const chatglmModels = [
"glm-4-long",
"glm-4-flashx",
"glm-4-flash",
+ "glm-4v-plus",
+ "glm-4v",
+ "glm-4v-flash", // free
+ "cogview-3-plus",
+ "cogview-3",
+ "cogview-3-flash", // free
+  // not supported yet: these models require polling an async task
+ // "cogvideox",
+ // "cogvideox-flash", // free
];
let seq = 1000; // built-in model sequence numbers start at 1000
diff --git a/app/store/config.ts b/app/store/config.ts
index 4256eba925d..45e21b02697 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -1,5 +1,5 @@
import { LLMModel } from "../client/api";
-import { DalleSize, DalleQuality, DalleStyle } from "../typing";
+import { DalleQuality, DalleStyle, ModelSize } from "../typing";
import { getClientConfig } from "../config/client";
import {
DEFAULT_INPUT_TEMPLATE,
@@ -78,7 +78,7 @@ export const DEFAULT_CONFIG = {
compressProviderName: "",
enableInjectSystemPrompts: true,
template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
- size: "1024x1024" as DalleSize,
+ size: "1024x1024" as ModelSize,
quality: "standard" as DalleQuality,
style: "vivid" as DalleStyle,
},
diff --git a/app/typing.ts b/app/typing.ts
index 0336be75d39..ecb327936fd 100644
--- a/app/typing.ts
+++ b/app/typing.ts
@@ -11,3 +11,14 @@ export interface RequestMessage {
export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
export type DalleQuality = "standard" | "hd";
export type DalleStyle = "vivid" | "natural";
+
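+// Union of the image sizes selectable across providers: the DALL·E 3
+// sizes plus the CogView resolutions.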
+export type ModelSize =
+ | "1024x1024"
+ | "1792x1024"
+ | "1024x1792"
+ | "768x1344"
+ | "864x1152"
+ | "1344x768"
+ | "1152x864"
+ | "1440x720"
+ | "720x1440";
diff --git a/app/utils.ts b/app/utils.ts
index 962e68a101c..810dc7842b1 100644
--- a/app/utils.ts
+++ b/app/utils.ts
@@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant";
import { fetch as tauriStreamFetch } from "./utils/stream";
import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant";
import { getClientConfig } from "./config/client";
+import { ModelSize } from "./typing";
export function trimTopic(topic: string) {
// Fix an issue where double quotes still show in the Indonesian language
@@ -271,6 +272,28 @@ export function isDalle3(model: string) {
return "dall-e-3" === model;
}
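+
+// Image sizes offered by the chat UI's size selector for a given model.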
+export function getModelSizes(model: string): ModelSize[] {
+ if (isDalle3(model)) {
+ return ["1024x1024", "1792x1024", "1024x1792"];
+ }
+ if (model.toLowerCase().includes("cogview")) {
+ return [
+ "1024x1024",
+ "768x1344",
+ "864x1152",
+ "1344x768",
+ "1152x864",
+ "1440x720",
+ "720x1440",
+ ];
+ }
+ return [];
+}
+
+export function supportsCustomSize(model: string): boolean {
+ return getModelSizes(model).length > 0;
+}
+
export function showPlugins(provider: ServiceProvider, model: string) {
if (
provider == ServiceProvider.OpenAI ||