Skip to content

Commit

Permalink
release: channel feature
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Dec 3, 2023
1 parent e6ca4c7 commit b2c2c2d
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 42 deletions.
2 changes: 1 addition & 1 deletion app/src/api/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export class Conversation {

if (id === -1 && this.idx === -1) {
sharingEvent.bind(({ refer, data }) => {
console.log(
console.debug(
`[conversation] load from sharing event (ref: ${refer}, length: ${data.length})`,
);

Expand Down
16 changes: 10 additions & 6 deletions app/src/components/admin/assemblies/ChannelEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,21 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
}, [edit.models]);
const enabled = useMemo(() => validator(edit), [edit]);

function close(clear?: boolean) {
if (clear) dispatch({ type: "clear" });
setEnabled(false);
}

async function post() {
const data = handler(edit);
console.debug(`[channel] preflight channel data`, data);

const resp =
id === -1 ? await createChannel(data) : await updateChannel(id, data);
toastState(toast, t, resp as ChannelCommonResponse);
toastState(toast, t, resp as ChannelCommonResponse, true);

if (resp.status) {
dispatch({ type: "clear" });
setEnabled(false);
close(true);
}
}

Expand All @@ -211,7 +215,6 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
else {
const resp = await getChannel(id);
toastState(toast, t, resp as ChannelCommonResponse);
console.log(resp);
if (resp.data) dispatch({ type: "set", value: resp.data });
}
}, [id]);
Expand Down Expand Up @@ -309,7 +312,8 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
</DropdownMenu>
<CustomAction
onPost={(model) => {
dispatch({ type: "add-model", value: model });
const models = model.split(" ");
dispatch({ type: "add-models", value: models });
}}
/>
<Button
Expand Down Expand Up @@ -406,7 +410,7 @@ function ChannelEditor({ display, id, setEnabled }: ChannelEditorProps) {
</div>
<div className={`mt-4 flex flex-row w-full h-max pr-2`}>
<div className={`grow`} />
<Button variant={`outline`} onClick={() => setEnabled(false)}>
<Button variant={`outline`} onClick={() => close()}>
{t("cancel")}
</Button>
<Button className={`ml-2`} onClick={post} disabled={!enabled}>
Expand Down
6 changes: 5 additions & 1 deletion app/src/components/admin/assemblies/ChannelTable.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge.tsx";
import { Check, Plus, RotateCw, Settings2, Trash, X } from "lucide-react";
import { Button } from "@/components/ui/button.tsx";
import OperationAction from "@/components/OperationAction.tsx";
import { useMemo, useState } from "react";
import { useEffect, useMemo, useState } from "react";
import { Channel, getChannelType, toastState } from "@/admin/channel.ts";
import { useTranslation } from "react-i18next";
import { useEffectAsync } from "@/utils/hook.ts";
Expand Down Expand Up @@ -54,6 +54,10 @@ function ChannelTable({ display, setId, setEnabled }: ChannelTableProps) {
useEffectAsync(refresh, []);
useEffectAsync(refresh, [display]);

useEffect(() => {
if (display) setId(-1);
}, [display]);

return (
display && (
<div className={`channel-table`}>
Expand Down
2 changes: 1 addition & 1 deletion app/src/conf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
} from "@/utils/env.ts";
import { getMemory } from "@/utils/memory.ts";

export const version = "3.7.0";
export const version = "3.7.1";
export const dev: boolean = getDev();
export const deploy: boolean = true;
export let rest_api: string = getRestApi(deploy);
Expand Down
8 changes: 5 additions & 3 deletions app/src/i18n.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ const resources = {
create: "创建渠道",
"search-model": "搜索模型",
"fill-template-models": "填入模板模型 ({{number}} 个)",
"add-custom-model": "添加自定义模型",
"add-custom-model": "添加自定义模型 (多个模型用空格分隔)",
"add-model": "添加模型",
"clear-models": "清空全部模型",
},
Expand Down Expand Up @@ -818,7 +818,8 @@ const resources = {
create: "Create Channel",
"search-model": "Search Model",
"fill-template-models": "Fill Template Models ({{number}})",
"add-custom-model": "Add Custom Model",
"add-custom-model":
"Add Custom Model (Multiple models are separated by spaces)",
"add-model": "Add Model",
"clear-models": "Clear All Models",
},
Expand Down Expand Up @@ -1241,7 +1242,8 @@ const resources = {
create: "Создать канал",
"search-model": "Поиск по имени модели",
"fill-template-models": "Заполнить шаблонные модели ({{number}})",
"add-custom-model": "Добавить пользовательскую модель",
"add-custom-model":
"Добавить пользовательскую модель (несколько моделей разделяются пробелами)",
"add-model": "Добавить модель",
"clear-models": "Очистить все модели",
},
Expand Down
3 changes: 1 addition & 2 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (c *Channel) GetMapper() string {

func (c *Channel) GetReflect() map[string]string {
if c.Reflect == nil {
var reflect map[string]string
reflect := make(map[string]string)
arr := strings.Split(c.GetMapper(), "\n")
for _, item := range arr {
pair := strings.Split(item, ">")
Expand Down Expand Up @@ -126,7 +126,6 @@ func (c *Channel) GetHitModels() []string {

c.HitModels = &res
}

return *c.HitModels
}

Expand Down
4 changes: 2 additions & 2 deletions channel/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (m *Manager) Load() {
// init support models
m.Models = []string{}
for _, channel := range m.GetActiveSequence() {
for _, model := range channel.GetModels() {
for _, model := range channel.GetHitModels() {
if !utils.Contains(model, m.Models) {
m.Models = append(m.Models, model)
}
Expand All @@ -46,7 +46,7 @@ func (m *Manager) Load() {
for _, model := range m.Models {
var seq Sequence
for _, channel := range m.GetActiveSequence() {
if utils.Contains(model, channel.GetModels()) {
if channel.IsHit(model) {
seq = append(seq, channel)
}
}
Expand Down
2 changes: 1 addition & 1 deletion channel/sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ func (s *Sequence) Len() int {
}

func (s *Sequence) Less(i, j int) bool {
return (*s)[i].GetPriority() < (*s)[j].GetPriority()
return (*s)[i].GetPriority() > (*s)[j].GetPriority()
}

func (s *Sequence) Swap(i, j int) {
Expand Down
4 changes: 2 additions & 2 deletions channel/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ type Channel struct {
Endpoint string `json:"endpoint" mapstructure:"endpoint"`
Mapper string `json:"mapper" mapstructure:"mapper"`
State bool `json:"state" mapstructure:"state"`
Reflect *map[string]string `json:"reflect" mapstructure:"reflect"`
HitModels *[]string `json:"hit_models" mapstructure:"hit_models"`
Reflect *map[string]string `json:"-"`
HitModels *[]string `json:"-"`
}

type Sequence []*Channel
Expand Down
72 changes: 49 additions & 23 deletions manager/transhipment.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,37 +68,58 @@ type TranshipmentStreamResponse struct {
Choices []ChoiceDelta `json:"choices"`
Usage Usage `json:"usage"`
Quota *float32 `json:"quota,omitempty"`
Error error `json:"error,omitempty"`
}

type TranshipmentErrorResponse struct {
Error TranshipmentError `json:"error"`
}

type TranshipmentError struct {
Message string `json:"message"`
Type string `json:"type"`
}

func ModelAPI(c *gin.Context) {
c.JSON(http.StatusOK, channel.ManagerInstance.GetModels())
}

func sendErrorResponse(c *gin.Context, err error, types ...string) {
var errType string
if len(types) > 0 {
errType = types[0]
} else {
errType = "chatnio_api_error"
}

c.JSON(http.StatusServiceUnavailable, TranshipmentErrorResponse{
Error: TranshipmentError{
Message: err.Error(),
Type: errType,
},
})
}

func abortWithErrorResponse(c *gin.Context, err error, types ...string) {
sendErrorResponse(c, err, types...)
c.Abort()
}

func TranshipmentAPI(c *gin.Context) {
username := utils.GetUserFromContext(c)
if username == "" {
c.AbortWithStatusJSON(403, gin.H{
"code": 403,
"message": "Access denied. Please provide correct api key.",
})
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid api key"), "authentication_error")
return
}

if utils.GetAgentFromContext(c) != "api" {
c.AbortWithStatusJSON(403, gin.H{
"code": 403,
"message": "Access denied. Please provide correct api key.",
})
abortWithErrorResponse(c, fmt.Errorf("access denied for invalid agent"), "authentication_error")
return
}

var form TranshipmentForm
if err := c.ShouldBindJSON(&form); err != nil {
c.JSON(400, gin.H{
"status": false,
"error": "invalid request body",
"reason": err.Error(),
})
abortWithErrorResponse(c, fmt.Errorf("invalid request body: %s", err.Error()), "invalid_request_error")
return
}

Expand All @@ -124,11 +145,7 @@ func TranshipmentAPI(c *gin.Context) {

check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model)
if !check {
c.JSON(http.StatusForbidden, gin.H{
"status": false,
"error": "quota exceeded",
"reason": "not enough quota to use this model",
})
sendErrorResponse(c, fmt.Errorf("quota exceeded"), "quota_exceeded_error")
return
}

Expand Down Expand Up @@ -171,6 +188,9 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP()))

sendErrorResponse(c, err)
return
}

CollectQuota(c, user, buffer, plan)
Expand All @@ -195,7 +215,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
})
}

func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool) TranshipmentStreamResponse {
func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool, err error) TranshipmentStreamResponse {
return TranshipmentStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion.chunk",
Expand All @@ -217,6 +237,7 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
TotalTokens: utils.MultiF(end, func() int { return buffer.CountToken() }, 0),
},
Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())),
Error: err,
}
}

Expand All @@ -228,27 +249,32 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
go func() {
buffer := utils.NewBuffer(form.Model, form.Messages)
err := channel.NewChatRequest(GetProps(form, buffer, plan), func(data string) error {
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil)
return nil
})

admin.AnalysisRequest(form.Model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model)
partial <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
CollectQuota(c, user, buffer, plan)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err.Error(), form.Model, c.ClientIP()))
partial <- getStreamTranshipmentForm(id, created, form, err.Error(), buffer, true, err)
close(partial)
return
}

partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil)
CollectQuota(c, user, buffer, plan)
close(partial)
return
}()

c.Stream(func(w io.Writer) bool {
if resp, ok := <-partial; ok {
if resp.Error != nil {
sendErrorResponse(c, resp.Error)
return false
}

c.Render(-1, utils.NewEvent(resp))
return true
}
Expand Down
7 changes: 7 additions & 0 deletions utils/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,10 @@ func GetError(err error) string {
}
return ""
}

func GetIndexSafe[T any](arr []T, index int) *T {
if index >= len(arr) {
return nil
}
return &arr[index]
}

0 comments on commit b2c2c2d

Please sign in to comment.