feature: support glm Cogview
Dogtiti committed Dec 28, 2024
1 parent 0c3d446 commit 0cb1868
Showing 6 changed files with 167 additions and 26 deletions.
131 changes: 112 additions & 19 deletions app/client/platforms/glm.ts
@@ -25,12 +25,103 @@ import { getMessageTextContent } from "@/app/utils";
import { RequestPayload } from "./openai";
import { fetch } from "@/app/utils/stream";

interface BasePayload {
  model: string;
}

interface ChatPayload extends BasePayload {
  messages: ChatOptions["messages"];
  stream?: boolean;
  temperature?: number;
  presence_penalty?: number;
  frequency_penalty?: number;
  top_p?: number;
}

interface ImageGenerationPayload extends BasePayload {
  prompt: string;
  size?: string;
  user_id?: string;
}

interface VideoGenerationPayload extends BasePayload {
  prompt: string;
  duration?: number;
  resolution?: string;
  user_id?: string;
}

type ModelType = "chat" | "image" | "video";

export class ChatGLMApi implements LLMApi {
  private disableListModels = true;

  private getModelType(model: string): ModelType {
    if (model.startsWith("cogview-")) return "image";
    if (model.startsWith("cogvideo-")) return "video";
    return "chat";
  }

  private getModelPath(type: ModelType): string {
    switch (type) {
      case "image":
        return ChatGLM.ImagePath;
      case "video":
        return ChatGLM.VideoPath;
      default:
        return ChatGLM.ChatPath;
    }
  }

  private createPayload(
    messages: ChatOptions["messages"],
    modelConfig: any,
    options: ChatOptions,
  ): BasePayload {
    const modelType = this.getModelType(modelConfig.model);
    const lastMessage = messages[messages.length - 1];
    const prompt =
      typeof lastMessage.content === "string"
        ? lastMessage.content
        : lastMessage.content.map((c) => c.text).join("\n");

    switch (modelType) {
      case "image":
        return {
          model: modelConfig.model,
          prompt,
          size: "1024x1024",
        } as ImageGenerationPayload;
      default:
        return {
          messages,
          stream: options.config.stream,
          model: modelConfig.model,
          temperature: modelConfig.temperature,
          presence_penalty: modelConfig.presence_penalty,
          frequency_penalty: modelConfig.frequency_penalty,
          top_p: modelConfig.top_p,
        } as ChatPayload;
    }
  }

  private parseResponse(modelType: ModelType, json: any): string {
    switch (modelType) {
      case "image": {
        const imageUrl = json.data?.[0]?.url;
        return imageUrl ? `![Generated Image](${imageUrl})` : "";
      }
      case "video": {
        const videoUrl = json.data?.[0]?.url;
        return videoUrl ? `<video controls src="${videoUrl}"></video>` : "";
      }
      default:
        return this.extractMessage(json);
    }
  }

  path(path: string): string {
    const accessStore = useAccessStore.getState();

    let baseUrl = "";

    if (accessStore.useCustomConfig) {
@@ -51,7 +51,6 @@ export class ChatGLMApi implements LLMApi {
    }

    console.log("[Proxy Endpoint] ", baseUrl, path);

    return [baseUrl, path].join("/");
  }

@@ -79,53 +169,55 @@ export class ChatGLMApi implements LLMApi {
      },
    };

    const requestPayload: RequestPayload = {
      messages,
      stream: options.config.stream,
      model: modelConfig.model,
      temperature: modelConfig.temperature,
      presence_penalty: modelConfig.presence_penalty,
      frequency_penalty: modelConfig.frequency_penalty,
      top_p: modelConfig.top_p,
    };
    const modelType = this.getModelType(modelConfig.model);
    const requestPayload = this.createPayload(messages, modelConfig, options);
    const path = this.path(this.getModelPath(modelType));

    console.log("[Request] glm payload: ", requestPayload);
    console.log(`[Request] glm ${modelType} payload: `, requestPayload);

    const shouldStream = !!options.config.stream;
    const controller = new AbortController();
    options.onController?.(controller);

    try {
      const chatPath = this.path(ChatGLM.ChatPath);
      const chatPayload = {
        method: "POST",
        body: JSON.stringify(requestPayload),
        signal: controller.signal,
        headers: getHeaders(),
      };

      // make a fetch request
      const requestTimeoutId = setTimeout(
        () => controller.abort(),
        REQUEST_TIMEOUT_MS,
      );

      if (modelType === "image" || modelType === "video") {
        const res = await fetch(path, chatPayload);
        clearTimeout(requestTimeoutId);

        const resJson = await res.json();
        console.log(`[Response] glm ${modelType}:`, resJson);
        const message = this.parseResponse(modelType, resJson);
        options.onFinish(message, res);
        return;
      }

      const shouldStream = !!options.config.stream;
      if (shouldStream) {
        const [tools, funcs] = usePluginStore
          .getState()
          .getAsTools(
            useChatStore.getState().currentSession().mask?.plugin || [],
          );
        return stream(
          chatPath,
          path,
          requestPayload,
          getHeaders(),
          tools as any,
          funcs,
          controller,
          // parseSSE
          (text: string, runTools: ChatMessageTool[]) => {
            // console.log("parseSSE", text, runTools);
            const json = JSON.parse(text);
            const choices = json.choices as Array<{
              delta: {
@@ -154,7 +246,7 @@
            }
            return choices[0]?.delta?.content;
          },
          // processToolMessage, include tool_calls message and tool call results
          // processToolMessage
          (
            requestPayload: RequestPayload,
            toolCallMessage: any,
@@ -172,7 +264,7 @@
          options,
        );
      } else {
        const res = await fetch(chatPath, chatPayload);
        const res = await fetch(path, chatPayload);
        clearTimeout(requestTimeoutId);

        const resJson = await res.json();
@@ -184,6 +276,7 @@
      options.onError?.(e as Error);
    }
  }

  async usage() {
    return {
      used: 0,
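
For reference, a minimal sketch of what the new non-streaming image branch consumes and produces. The response shape below is an assumption inferred only from the `json.data?.[0]?.url` access in `parseResponse` above, not from the ZhipuAI API documentation:

// Hypothetical body returned by api/paas/v4/images/generations
const resJson = {
  data: [{ url: "https://example.com/cogview-output.png" }],
};

// parseResponse("image", resJson) yields Markdown that the chat view renders inline:
// "![Generated Image](https://example.com/cogview-output.png)"
// If no URL is present, an empty string is passed to options.onFinish.
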
13 changes: 8 additions & 5 deletions app/components/chat.tsx
@@ -72,14 +72,16 @@ import {
  isDalle3,
  showPlugins,
  safeLocalStorage,
  getModelSizes,
  supportsCustomSize,
} from "../utils";

import { uploadImage as uploadImageRemote } from "@/app/utils/chat";

import dynamic from "next/dynamic";

import { ChatControllerPool } from "../client/controller";
import { DalleSize, DalleQuality, DalleStyle } from "../typing";
import { DalleQuality, DalleStyle, ModelSize } from "../typing";
import { Prompt, usePromptStore } from "../store/prompt";
import Locale from "../locales";

@@ -519,10 +521,11 @@ export function ChatActions(props: {
  const [showSizeSelector, setShowSizeSelector] = useState(false);
  const [showQualitySelector, setShowQualitySelector] = useState(false);
  const [showStyleSelector, setShowStyleSelector] = useState(false);
  const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
  const modelSizes = getModelSizes(currentModel);
  const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
  const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
  const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
  const currentSize =
    session.mask.modelConfig?.size ?? ("1024x1024" as ModelSize);
  const currentQuality = session.mask.modelConfig?.quality ?? "standard";
  const currentStyle = session.mask.modelConfig?.style ?? "vivid";

@@ -673,7 +676,7 @@ export function ChatActions(props: {
        />
      )}

      {isDalle3(currentModel) && (
      {supportsCustomSize(currentModel) && (
        <ChatAction
          onClick={() => setShowSizeSelector(true)}
          text={currentSize}
@@ -684,7 +687,7 @@
      {showSizeSelector && (
        <Selector
          defaultSelectedValue={currentSize}
          items={dalle3Sizes.map((m) => ({
          items={modelSizes.map((m) => ({
            title: m,
            value: m,
          }))}
11 changes: 11 additions & 0 deletions app/constant.ts
@@ -233,6 +233,8 @@ export const XAI = {
export const ChatGLM = {
  ExampleEndpoint: CHATGLM_BASE_URL,
  ChatPath: "api/paas/v4/chat/completions",
  ImagePath: "api/paas/v4/images/generations",
  VideoPath: "api/paas/v4/videos/generations",
};

export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -431,6 +433,15 @@ const chatglmModels = [
"glm-4-long",
"glm-4-flashx",
"glm-4-flash",
"glm-4v-plus",
"glm-4v",
"glm-4v-flash", // free
"cogview-3-plus",
"cogview-3",
"cogview-3-flash", // free
// 目前无法适配轮询任务
// "cogvideox",
// "cogvideox-flash", // free
];

let seq = 1000; // built-in model sequence number generator starts at 1000
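
A quick sketch of how the new path constants become request URLs: `ChatGLMApi.path()` joins the configured base URL (a custom endpoint from accessStore or the app's built-in proxy) with the path chosen by `getModelPath()`. The base URL below is illustrative only:

const baseUrl = "https://open.bigmodel.cn"; // illustrative; usually a configured proxy
[baseUrl, ChatGLM.ImagePath].join("/");
// => "https://open.bigmodel.cn/api/paas/v4/images/generations"
[baseUrl, ChatGLM.VideoPath].join("/");
// => "https://open.bigmodel.cn/api/paas/v4/videos/generations"
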
4 changes: 2 additions & 2 deletions app/store/config.ts
@@ -1,5 +1,5 @@
import { LLMModel } from "../client/api";
import { DalleSize, DalleQuality, DalleStyle } from "../typing";
import { DalleQuality, DalleStyle, ModelSize } from "../typing";
import { getClientConfig } from "../config/client";
import {
  DEFAULT_INPUT_TEMPLATE,
@@ -78,7 +78,7 @@ export const DEFAULT_CONFIG = {
    compressProviderName: "",
    enableInjectSystemPrompts: true,
    template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
    size: "1024x1024" as DalleSize,
    size: "1024x1024" as ModelSize,
    quality: "standard" as DalleQuality,
    style: "vivid" as DalleStyle,
  },
11 changes: 11 additions & 0 deletions app/typing.ts
@@ -11,3 +11,14 @@ export interface RequestMessage {
export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
export type DalleQuality = "standard" | "hd";
export type DalleStyle = "vivid" | "natural";

export type ModelSize =
  | "1024x1024"
  | "1792x1024"
  | "1024x1792"
  | "768x1344"
  | "864x1152"
  | "1344x768"
  | "1152x864"
  | "1440x720"
  | "720x1440";
23 changes: 23 additions & 0 deletions app/utils.ts
@@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant";
import { fetch as tauriStreamFetch } from "./utils/stream";
import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant";
import { getClientConfig } from "./config/client";
import { ModelSize } from "./typing";

export function trimTopic(topic: string) {
  // Fix an issue where double quotes still show in the Indonesian language
@@ -271,6 +272,28 @@ export function isDalle3(model: string) {
return "dall-e-3" === model;
}

export function getModelSizes(model: string): ModelSize[] {
  if (isDalle3(model)) {
    return ["1024x1024", "1792x1024", "1024x1792"];
  }
  if (model.toLowerCase().includes("cogview")) {
    return [
      "1024x1024",
      "768x1344",
      "864x1152",
      "1344x768",
      "1152x864",
      "1440x720",
      "720x1440",
    ];
  }
  return [];
}

export function supportsCustomSize(model: string): boolean {
  return getModelSizes(model).length > 0;
}

export function showPlugins(provider: ServiceProvider, model: string) {
  if (
    provider == ServiceProvider.OpenAI ||
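
Usage sketch for the two new helpers; the return values follow directly from the code above, and the import path assumes a caller elsewhere in `app/`:

import { getModelSizes, supportsCustomSize } from "./utils";

getModelSizes("dall-e-3");       // ["1024x1024", "1792x1024", "1024x1792"]
getModelSizes("cogview-3-plus"); // the seven CogView sizes listed above
getModelSizes("glm-4-flash");    // [] (plain chat models expose no size picker)
supportsCustomSize("cogview-3"); // true, so chat.tsx renders the size selector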
