From 6834c12ec413b4eeec8353fc5e2e67873482087e Mon Sep 17 00:00:00 2001 From: MiaowFISH Date: Sun, 5 Jan 2025 11:56:54 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DGemini=E9=80=82=E9=85=8D?= =?UTF-8?q?=E5=99=A8=E3=80=81=E5=85=81=E8=AE=B8=E4=B8=B4=E6=97=B6=E7=A6=81?= =?UTF-8?q?=E7=94=A8=E9=80=82=E9=85=8D=E5=99=A8=EF=BC=8C=E8=B7=B3=E8=BF=87?= =?UTF-8?q?=E7=A9=BA=E8=BF=87=E6=BB=A4=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/adapters/config.ts | 14 +++++++++++--- src/adapters/gemini.ts | 8 ++++++++ src/adapters/index.ts | 1 + src/config.ts | 2 +- src/utils/toolkit.ts | 4 +++- 5 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/adapters/config.ts b/src/adapters/config.ts index bfaf613..53d2c89 100644 --- a/src/adapters/config.ts +++ b/src/adapters/config.ts @@ -1,12 +1,14 @@ import { Schema } from "koishi"; export interface LLM { + Enabled?: boolean; APIType: "OpenAI" | "Cloudflare" | "Ollama" | "Custom URL" | "Gemini"; BaseURL: string; UID?: string; APIKey: string; AIModel: string; Ability?: Array<"原生工具调用" | "识图功能" | "结构化输出">; + NUMA?: boolean; NumCtx?: number; NumBatch?: number; @@ -26,12 +28,13 @@ export interface Config { export const API: Schema = Schema.intersect([ Schema.object({ + Enabled: Schema.boolean().default(true).description("是否启用"), APIType: Schema.union(["OpenAI", "Cloudflare", "Ollama", "Custom URL", "Gemini"]) .default("OpenAI") .description("API 类型"), - BaseURL: Schema.string() - .default("https://api.openai.com") - .description("API 基础 URL, 设置为\"Custom URL\"需要填写完整的 URL"), + // BaseURL: Schema.string() + // .default("https://api.openai.com") + // .description("API 基础 URL, 设置为\"Custom URL\"需要填写完整的 URL"), APIKey: Schema.string().required().description("你的 API 令牌"), AIModel: Schema.string() .description("模型 ID"), @@ -44,16 +47,20 @@ export const API: Schema = Schema.intersect([ Schema.union([ Schema.object({ APIType: Schema.const("OpenAI"), + BaseURL: Schema.string().default("https://api.openai.com"), }), Schema.object({ APIType: Schema.const("Cloudflare"), + BaseURL: Schema.string().default("https://api.cloudflare.com/client/v4"), UID: Schema.string().required().description("Cloudflare UID"), }), Schema.object({ APIType: Schema.const("Custom URL"), + BaseURL: Schema.string().required().description("自定义 URL"), }), Schema.object({ APIType: Schema.const("Ollama"), + BaseURL: Schema.string().default("http://127.0.0.1:11434"), NUMA: Schema.boolean() .default(false) .description("是否使用 NUMA"), @@ -100,6 +107,7 @@ export const API: Schema = Schema.intersect([ }), Schema.object({ APIType: Schema.const("Gemini"), + BaseURL: Schema.string().default("https://generativelanguage.googleapis.com"), }), ]), ]); diff --git a/src/adapters/gemini.ts b/src/adapters/gemini.ts index f602ab1..9adac7d 100644 --- a/src/adapters/gemini.ts +++ b/src/adapters/gemini.ts @@ -61,7 +61,13 @@ export class GeminiAdapter extends BaseAdapter { } async chat(messages: Message[], toolsSchema?: ToolSchema[], debug = false): Promise { + const system = messages.find((message) => message.role === "system"); + if (system) { + messages = messages.filter((message) => message.role !== "system"); + + } const requestBody = { + system_instruction: convert(system), contents: messages.map(convert), generationConfig: { stopSequences: this.parameters?.Stop, @@ -98,6 +104,8 @@ export class GeminiAdapter extends BaseAdapter { } function convert(message: Message): Content { + // @ts-ignore + message.role = message.role == "assistant" ? "model" : message.role; if (typeof message.content === "string") { return { role: message.role, diff --git a/src/adapters/index.ts b/src/adapters/index.ts index 62c5bea..01ae6c1 100644 --- a/src/adapters/index.ts +++ b/src/adapters/index.ts @@ -34,6 +34,7 @@ export class AdapterSwitcher { ) { this.adapters = []; for (const adapter of adapterConfig) { + if (!adapter.Enabled) continue; this.adapters.push(getAdapter(adapter, parameters)); } } diff --git a/src/config.ts b/src/config.ts index 7b77c27..26dd84f 100644 --- a/src/config.ts +++ b/src/config.ts @@ -144,7 +144,7 @@ export const Config: Schema = Schema.object({ .description("立即回复 @ 消息的概率"), Filter: Schema.array(Schema.string()) .default(["你是", "You are", "吧", "呢"]) - .description("过滤的词汇(防止被调皮群友/机器人自己搞傻)"), + .description("过滤的词汇(防止被调皮群友/机器人自己搞傻)可以使用正则表达式"), }).description("记忆槽位设置"), API: AdapterConfig, diff --git a/src/utils/toolkit.ts b/src/utils/toolkit.ts index 4de8ef3..4000c67 100644 --- a/src/utils/toolkit.ts +++ b/src/utils/toolkit.ts @@ -4,6 +4,7 @@ import { Element, Session } from "koishi"; import { Mutex } from 'async-mutex'; import { Config } from "../config"; +import { isEmpty } from "./string"; export function isChannelAllowed(slotContains: string[], channelId: string): boolean { @@ -36,6 +37,7 @@ export function containsFilter(content: string, FilterList: string[]): boolean { //return re.test(content); for (const filter of FilterList) { + if (isEmpty(filter)) continue; let regex = new RegExp(filter, "gi"); if (regex.test(content)) return true; @@ -273,7 +275,7 @@ export function downloadFile(url, filePath, debug) { }; /** - * + * * @param date * @returns 2024年12月3日星期二17:34:00 */