diff --git a/service/aiproxy/common/gin.go b/service/aiproxy/common/gin.go
index eb2c95f0123..e62b4eed63c 100644
--- a/service/aiproxy/common/gin.go
+++ b/service/aiproxy/common/gin.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 	json "github.com/json-iterator/go"
@@ -40,6 +41,12 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) {
 }
 
 func GetRequestBody(req *http.Request) ([]byte, error) {
+	contentType := req.Header.Get("Content-Type")
+	if contentType == "application/x-www-form-urlencoded" ||
+		strings.HasPrefix(contentType, "multipart/form-data") {
+		return nil, nil
+	}
+
 	requestBody := req.Context().Value(RequestBodyKey{})
 	if requestBody != nil {
 		return requestBody.([]byte), nil
diff --git a/service/aiproxy/controller/channel-test.go b/service/aiproxy/controller/channel-test.go
index 4e4a6cfe176..450145a2cc9 100644
--- a/service/aiproxy/controller/channel-test.go
+++ b/service/aiproxy/controller/channel-test.go
@@ -69,8 +69,8 @@ func testSingleModel(channel *model.Channel, modelName string) (*model.ChannelTe
 	return channel.UpdateModelTest(
 		meta.RequestAt,
-		meta.OriginModelName,
-		meta.ActualModelName,
+		meta.OriginModel,
+		meta.ActualModel,
 		meta.Mode,
 		time.Since(meta.RequestAt).Seconds(),
 		success,
diff --git a/service/aiproxy/controller/relay.go b/service/aiproxy/controller/relay.go
index 3edac7b44a8..edc3bef017a 100644
--- a/service/aiproxy/controller/relay.go
+++ b/service/aiproxy/controller/relay.go
@@ -57,7 +57,7 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
 	if err == nil {
 		if err := monitor.AddRequest(
 			c.Request.Context(),
-			meta.OriginModelName,
+			meta.OriginModel,
 			int64(meta.Channel.ID),
 			false,
 		); err != nil {
@@ -68,7 +68,7 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
 	if shouldRetry(c, err.StatusCode) {
 		if err := monitor.AddRequest(
 			c.Request.Context(),
-			meta.OriginModelName,
+			meta.OriginModel,
 			int64(meta.Channel.ID),
 			true,
 		); err != nil {
diff --git a/service/aiproxy/middleware/auth.go b/service/aiproxy/middleware/auth.go
index 3bcde94e124..b52acd57ebe 100644
--- a/service/aiproxy/middleware/auth.go
+++ b/service/aiproxy/middleware/auth.go
@@ -102,8 +102,8 @@ func SetLogFieldsFromMeta(m *meta.Meta, fields logrus.Fields) {
 	SetLogRequestIDField(fields, m.RequestID)
 
 	SetLogModeField(fields, m.Mode)
-	SetLogModelFields(fields, m.OriginModelName)
-	SetLogActualModelFields(fields, m.ActualModelName)
+	SetLogModelFields(fields, m.OriginModel)
+	SetLogActualModelFields(fields, m.ActualModel)
 
 	if m.IsChannelTest {
 		SetLogIsChannelTestField(fields, true)
diff --git a/service/aiproxy/relay/adaptor/ali/adaptor.go b/service/aiproxy/relay/adaptor/ali/adaptor.go
index a8837295c9c..76ebfa8fa68 100644
--- a/service/aiproxy/relay/adaptor/ali/adaptor.go
+++ b/service/aiproxy/relay/adaptor/ali/adaptor.go
@@ -76,7 +76,7 @@ func (a *Adaptor) DoRequest(meta *meta.Meta, _ *gin.Context, req *http.Request)
 	case relaymode.AudioTranscription:
 		return STTDoRequest(meta, req)
 	case relaymode.ChatCompletions:
-		if meta.IsChannelTest && strings.Contains(meta.ActualModelName, "-ocr") {
+		if meta.IsChannelTest && strings.Contains(meta.ActualModel, "-ocr") {
 			return &http.Response{
 				StatusCode: http.StatusOK,
 				Body:       io.NopCloser(bytes.NewReader(nil)),
@@ -95,7 +95,7 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons
 	case relaymode.ImagesGenerations:
 		usage, err = ImageHandler(meta, c, resp)
 	case relaymode.ChatCompletions:
-		if meta.IsChannelTest && strings.Contains(meta.ActualModelName, "-ocr") {
+		if meta.IsChannelTest && strings.Contains(meta.ActualModel, "-ocr") {
 			return nil, nil
 		}
 		usage, err = openai.DoResponse(meta, c, resp)
diff --git a/service/aiproxy/relay/adaptor/ali/embeddings.go b/service/aiproxy/relay/adaptor/ali/embeddings.go
index c6863d536b4..66f4569cc18 100644
--- a/service/aiproxy/relay/adaptor/ali/embeddings.go
+++ b/service/aiproxy/relay/adaptor/ali/embeddings.go
@@ -21,7 +21,7 @@ func ConvertEmbeddingsRequest(meta *meta.Meta, req *http.Request) (string, http.
 	if err != nil {
 		return "", nil, nil, err
 	}
-	reqMap["model"] = meta.ActualModelName
+	reqMap["model"] = meta.ActualModel
 	input, ok := reqMap["input"]
 	if !ok {
 		return "", nil, nil, errors.New("input is required")
@@ -56,7 +56,7 @@ func embeddingResponse2OpenAI(meta *meta.Meta, response *EmbeddingResponse) *ope
 	openAIEmbeddingResponse := openai.EmbeddingResponse{
 		Object: "list",
 		Data:   make([]*openai.EmbeddingResponseItem, 0, 1),
-		Model:  meta.OriginModelName,
+		Model:  meta.OriginModel,
 		Usage:  response.Usage,
 	}
diff --git a/service/aiproxy/relay/adaptor/ali/image.go b/service/aiproxy/relay/adaptor/ali/image.go
index 21d1574851c..4a8c12e1da2 100644
--- a/service/aiproxy/relay/adaptor/ali/image.go
+++ b/service/aiproxy/relay/adaptor/ali/image.go
@@ -27,7 +27,7 @@ func ConvertImageRequest(meta *meta.Meta, req *http.Request) (string, http.Heade
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 
 	var imageRequest ImageRequest
 	imageRequest.Input.Prompt = request.Prompt
diff --git a/service/aiproxy/relay/adaptor/ali/rerank.go b/service/aiproxy/relay/adaptor/ali/rerank.go
index ef6d5fb6fab..ab739bc8ac9 100644
--- a/service/aiproxy/relay/adaptor/ali/rerank.go
+++ b/service/aiproxy/relay/adaptor/ali/rerank.go
@@ -32,7 +32,7 @@ func ConvertRerankRequest(meta *meta.Meta, req *http.Request) (string, http.Head
 	if err != nil {
 		return "", nil, nil, err
 	}
-	reqMap["model"] = meta.ActualModelName
+	reqMap["model"] = meta.ActualModel
 	reqMap["input"] = map[string]any{
 		"query":     reqMap["query"],
 		"documents": reqMap["documents"],
diff --git a/service/aiproxy/relay/adaptor/ali/stt-realtime.go b/service/aiproxy/relay/adaptor/ali/stt-realtime.go
index 1d45920f7b5..46ca83892ae 100644
--- a/service/aiproxy/relay/adaptor/ali/stt-realtime.go
+++ b/service/aiproxy/relay/adaptor/ali/stt-realtime.go
@@ -89,7 +89,7 @@ func ConvertSTTRequest(meta *meta.Meta, request *http.Request) (string, http.Hea
 			TaskID: uuid.New().String(),
 		},
 		Payload: STTPayload{
-			Model:     meta.ActualModelName,
+			Model:     meta.ActualModel,
 			Task:      "asr",
 			TaskGroup: "audio",
 			Function:  "recognition",
diff --git a/service/aiproxy/relay/adaptor/ali/tts.go b/service/aiproxy/relay/adaptor/ali/tts.go
index 3c685a081c6..642e76fe8ab 100644
--- a/service/aiproxy/relay/adaptor/ali/tts.go
+++ b/service/aiproxy/relay/adaptor/ali/tts.go
@@ -107,7 +107,7 @@ func ConvertTTSRequest(meta *meta.Meta, req *http.Request) (string, http.Header,
 	if ok {
 		sampleRate = int(sampleRateI)
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 
 	if strings.HasPrefix(request.Model, "sambert-v") {
 		voice := request.Voice
diff --git a/service/aiproxy/relay/adaptor/anthropic/adaptor.go b/service/aiproxy/relay/adaptor/anthropic/adaptor.go
index 73601212833..e4f035793a5 100644
--- a/service/aiproxy/relay/adaptor/anthropic/adaptor.go
+++ b/service/aiproxy/relay/adaptor/anthropic/adaptor.go
@@ -37,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, c *gin.Context, req *http.
 	// https://x.com/alexalbert__/status/1812921642143900036
 	// claude-3-5-sonnet can support 8k context
-	if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
+	if strings.HasPrefix(meta.ActualModel, "claude-3-5-sonnet") {
 		req.Header.Set("Anthropic-Beta", "max-tokens-3-5-sonnet-2024-07-15")
 	}
diff --git a/service/aiproxy/relay/adaptor/anthropic/main.go b/service/aiproxy/relay/adaptor/anthropic/main.go
index 169a6e0df0d..50fe177a9ab 100644
--- a/service/aiproxy/relay/adaptor/anthropic/main.go
+++ b/service/aiproxy/relay/adaptor/anthropic/main.go
@@ -44,7 +44,7 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (*Request, error) {
 	if err != nil {
 		return nil, err
 	}
-	textRequest.Model = meta.ActualModelName
+	textRequest.Model = meta.ActualModel
 	meta.Set("stream", textRequest.Stream)
 
 	claudeTools := make([]Tool, 0, len(textRequest.Tools))
@@ -372,7 +372,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Error
 		}, nil
 	}
 	fullTextResponse := ResponseClaude2OpenAI(&claudeResponse)
-	fullTextResponse.Model = meta.OriginModelName
+	fullTextResponse.Model = meta.OriginModel
 	usage := model.Usage{
 		PromptTokens:     claudeResponse.Usage.InputTokens,
 		CompletionTokens: claudeResponse.Usage.OutputTokens,
diff --git a/service/aiproxy/relay/adaptor/aws/adaptor.go b/service/aiproxy/relay/adaptor/aws/adaptor.go
index cb79e41cab1..2c53495b039 100644
--- a/service/aiproxy/relay/adaptor/aws/adaptor.go
+++ b/service/aiproxy/relay/adaptor/aws/adaptor.go
@@ -18,7 +18,7 @@ var _ adaptor.Adaptor = new(Adaptor)
 
 type Adaptor struct{}
 
 func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
-	adaptor := GetAdaptor(meta.ActualModelName)
+	adaptor := GetAdaptor(meta.ActualModel)
 	if adaptor == nil {
 		return "", nil, nil, errors.New("adaptor not found")
 	}
diff --git a/service/aiproxy/relay/adaptor/aws/claude/main.go b/service/aiproxy/relay/adaptor/aws/claude/main.go
index 7c2106981df..7a9e19f1d61 100644
--- a/service/aiproxy/relay/adaptor/aws/claude/main.go
+++ b/service/aiproxy/relay/adaptor/aws/claude/main.go
@@ -93,7 +93,7 @@ func awsModelID(requestModel string) (string, error) {
 }
 
 func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
-	awsModelID, err := awsModelID(meta.ActualModelName)
+	awsModelID, err := awsModelID(meta.ActualModel)
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
@@ -138,7 +138,7 @@ func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode,
 	}
 
 	openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
-	openaiResp.Model = meta.OriginModelName
+	openaiResp.Model = meta.OriginModel
 	usage := relaymodel.Usage{
 		PromptTokens:     claudeResponse.Usage.InputTokens,
 		CompletionTokens: claudeResponse.Usage.OutputTokens,
@@ -153,8 +153,8 @@ func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode,
 func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 	log := middleware.GetLogger(c)
 	createdTime := time.Now().Unix()
-	originModelName := meta.OriginModelName
-	awsModelID, err := awsModelID(meta.ActualModelName)
+	originModelName := meta.OriginModel
+	awsModelID, err := awsModelID(meta.ActualModel)
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
diff --git a/service/aiproxy/relay/adaptor/aws/llama3/adapter.go b/service/aiproxy/relay/adaptor/aws/llama3/adapter.go
index bca6e0d8561..ecfd2ee18cb 100644
--- a/service/aiproxy/relay/adaptor/aws/llama3/adapter.go
+++ b/service/aiproxy/relay/adaptor/aws/llama3/adapter.go
@@ -24,7 +24,7 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, ht
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	meta.Set("stream", request.Stream)
 	llamaReq := ConvertRequest(request)
 	meta.Set(ConvertedRequest, llamaReq)
diff --git a/service/aiproxy/relay/adaptor/aws/llama3/main.go b/service/aiproxy/relay/adaptor/aws/llama3/main.go
index 436aac4fdfb..d353096178f 100644
--- a/service/aiproxy/relay/adaptor/aws/llama3/main.go
+++ b/service/aiproxy/relay/adaptor/aws/llama3/main.go
@@ -94,7 +94,7 @@ func ConvertRequest(textRequest *relaymodel.GeneralOpenAIRequest) *Request {
 }
 
 func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
-	awsModelID, err := awsModelID(meta.ActualModelName)
+	awsModelID, err := awsModelID(meta.ActualModel)
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
@@ -132,7 +132,7 @@ func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode,
 	}
 
 	openaiResp := ResponseLlama2OpenAI(&llamaResponse)
-	openaiResp.Model = meta.OriginModelName
+	openaiResp.Model = meta.OriginModel
 	usage := relaymodel.Usage{
 		PromptTokens:     llamaResponse.PromptTokenCount,
 		CompletionTokens: llamaResponse.GenerationTokenCount,
@@ -171,7 +171,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
 	log := middleware.GetLogger(c)
 	createdTime := time.Now().Unix()
 
-	awsModelID, err := awsModelID(meta.ActualModelName)
+	awsModelID, err := awsModelID(meta.ActualModel)
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
@@ -231,7 +231,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
 			}
 			response := StreamResponseLlama2OpenAI(&llamaResp)
 			response.ID = "chatcmpl-" + random.GetUUID()
-			response.Model = meta.OriginModelName
+			response.Model = meta.OriginModel
 			response.Created = createdTime
 			err = render.ObjectData(c, response)
 			if err != nil {
diff --git a/service/aiproxy/relay/adaptor/azure/constants.go b/service/aiproxy/relay/adaptor/azure/constants.go
index 19bc3de87bf..257f94a9411 100644
--- a/service/aiproxy/relay/adaptor/azure/constants.go
+++ b/service/aiproxy/relay/adaptor/azure/constants.go
@@ -20,7 +20,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	if err != nil {
 		return "", err
 	}
-	model := strings.ReplaceAll(meta.ActualModelName, ".", "")
+	model := strings.ReplaceAll(meta.ActualModel, ".", "")
 	switch meta.Mode {
 	case relaymode.ImagesGenerations:
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
diff --git a/service/aiproxy/relay/adaptor/baidu/adaptor.go b/service/aiproxy/relay/adaptor/baidu/adaptor.go
index a68633c6363..0605eef7cc7 100644
--- a/service/aiproxy/relay/adaptor/baidu/adaptor.go
+++ b/service/aiproxy/relay/adaptor/baidu/adaptor.go
@@ -64,9 +64,9 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		pathSuffix = "text2image"
 	}
 
-	modelEndpoint, ok := modelEndpointMap[meta.ActualModelName]
+	modelEndpoint, ok := modelEndpointMap[meta.ActualModel]
 	if !ok {
-		modelEndpoint = strings.ToLower(meta.ActualModelName)
+		modelEndpoint = strings.ToLower(meta.ActualModel)
 	}
 
 	// Construct full URL
diff --git a/service/aiproxy/relay/adaptor/baidu/embeddings.go b/service/aiproxy/relay/adaptor/baidu/embeddings.go
index 9071c871499..c07582d47ee 100644
--- a/service/aiproxy/relay/adaptor/baidu/embeddings.go
+++ b/service/aiproxy/relay/adaptor/baidu/embeddings.go
@@ -41,7 +41,7 @@ func EmbeddingsHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*r
 	if err != nil {
 		return &baiduResponse.Usage, openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
-	respMap["model"] = meta.OriginModelName
+	respMap["model"] = meta.OriginModel
 	respMap["object"] = "list"
 
 	data, err := json.Marshal(respMap)
diff --git a/service/aiproxy/relay/adaptor/baidu/main.go b/service/aiproxy/relay/adaptor/baidu/main.go
index 46d691f9298..34257469170 100644
--- a/service/aiproxy/relay/adaptor/baidu/main.go
+++ b/service/aiproxy/relay/adaptor/baidu/main.go
@@ -46,7 +46,7 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	baiduRequest := ChatRequest{
 		Messages:    request.Messages,
 		Temperature: request.Temperature,
@@ -117,7 +117,7 @@ func streamResponseBaidu2OpenAI(meta *meta.Meta, baiduResponse *ChatStreamRespon
 		ID:      baiduResponse.ID,
 		Object:  "chat.completion.chunk",
 		Created: baiduResponse.Created,
-		Model:   meta.OriginModelName,
+		Model:   meta.OriginModel,
 		Choices: []*openai.ChatCompletionsStreamResponseChoice{&choice},
 		Usage:   baiduResponse.Usage,
 	}
@@ -185,7 +185,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage
 		return nil, openai.ErrorWrapperWithMessage(baiduResponse.Error.ErrorMsg, "baidu_error_"+strconv.Itoa(baiduResponse.Error.ErrorCode), http.StatusInternalServerError)
 	}
 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
-	fullTextResponse.Model = meta.OriginModelName
+	fullTextResponse.Model = meta.OriginModel
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
 		return nil, openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
diff --git a/service/aiproxy/relay/adaptor/baiduv2/adaptor.go b/service/aiproxy/relay/adaptor/baiduv2/adaptor.go
index 9f844488d9d..d0fb3ba0006 100644
--- a/service/aiproxy/relay/adaptor/baiduv2/adaptor.go
+++ b/service/aiproxy/relay/adaptor/baiduv2/adaptor.go
@@ -61,11 +61,11 @@ func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, _ *gin.Context, req *http.
 func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
 	switch meta.Mode {
 	case relaymode.ChatCompletions:
-		actModel := meta.ActualModelName
+		actModel := meta.ActualModel
 		v2Model := toV2ModelName(actModel)
 		if v2Model != actModel {
-			meta.ActualModelName = v2Model
-			defer func() { meta.ActualModelName = actModel }()
+			meta.ActualModel = v2Model
+			defer func() { meta.ActualModel = actModel }()
 		}
 		return openai.ConvertRequest(meta, req)
 	default:
diff --git a/service/aiproxy/relay/adaptor/cloudflare/adaptor.go b/service/aiproxy/relay/adaptor/cloudflare/adaptor.go
index cf7e84b0704..7680155fa76 100644
--- a/service/aiproxy/relay/adaptor/cloudflare/adaptor.go
+++ b/service/aiproxy/relay/adaptor/cloudflare/adaptor.go
@@ -43,9 +43,9 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		return urlPrefix + "/v1/embeddings", nil
 	default:
 		if isAIGateWay {
-			return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
+			return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModel), nil
 		}
-		return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
+		return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModel), nil
 	}
 }
diff --git a/service/aiproxy/relay/adaptor/cohere/adaptor.go b/service/aiproxy/relay/adaptor/cohere/adaptor.go
index 7b1ddd3f967..df545466358 100644
--- a/service/aiproxy/relay/adaptor/cohere/adaptor.go
+++ b/service/aiproxy/relay/adaptor/cohere/adaptor.go
@@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, ht
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	requestBody := ConvertRequest(request)
 	if requestBody == nil {
 		return "", nil, nil, errors.New("request body is nil")
@@ -62,7 +62,7 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons
 		if utils.IsStreamResponse(resp) {
 			err, usage = StreamHandler(c, resp)
 		} else {
-			err, usage = Handler(c, resp, meta.InputTokens, meta.ActualModelName)
+			err, usage = Handler(c, resp, meta.InputTokens, meta.ActualModel)
 		}
 	}
 	return
diff --git a/service/aiproxy/relay/adaptor/coze/adaptor.go b/service/aiproxy/relay/adaptor/coze/adaptor.go
index b1bf41a138d..528c5e62522 100644
--- a/service/aiproxy/relay/adaptor/coze/adaptor.go
+++ b/service/aiproxy/relay/adaptor/coze/adaptor.go
@@ -51,11 +51,11 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, ht
 		return "", nil, nil, err
 	}
 	request.User = userID
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	cozeRequest := Request{
 		Stream: request.Stream,
 		User:   request.User,
-		BotID:  strings.TrimPrefix(meta.ActualModelName, "bot-"),
+		BotID:  strings.TrimPrefix(meta.ActualModel, "bot-"),
 	}
 	for i, message := range request.Messages {
 		if i == len(request.Messages)-1 {
@@ -84,10 +84,10 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons
 	if utils.IsStreamResponse(resp) {
 		err, responseText = StreamHandler(c, resp)
 	} else {
-		err, responseText = Handler(c, resp, meta.InputTokens, meta.ActualModelName)
+		err, responseText = Handler(c, resp, meta.InputTokens, meta.ActualModel)
 	}
 	if responseText != nil {
-		usage = openai.ResponseText2Usage(*responseText, meta.ActualModelName, meta.InputTokens)
+		usage = openai.ResponseText2Usage(*responseText, meta.ActualModel, meta.InputTokens)
 	} else {
 		usage = &relaymodel.Usage{}
 	}
diff --git a/service/aiproxy/relay/adaptor/doubao/main.go b/service/aiproxy/relay/adaptor/doubao/main.go
index 9d4b7eb42da..fca376cd62f 100644
--- a/service/aiproxy/relay/adaptor/doubao/main.go
+++ b/service/aiproxy/relay/adaptor/doubao/main.go
@@ -17,7 +17,7 @@ func GetRequestURL(meta *meta.Meta) (string, error) {
 	}
 	switch meta.Mode {
 	case relaymode.ChatCompletions:
-		if strings.HasPrefix(meta.ActualModelName, "bot-") {
+		if strings.HasPrefix(meta.ActualModel, "bot-") {
 			return u + "/api/v3/bots/chat/completions", nil
 		}
 		return u + "/api/v3/chat/completions", nil
diff --git a/service/aiproxy/relay/adaptor/gemini/adaptor.go b/service/aiproxy/relay/adaptor/gemini/adaptor.go
index 556eaf28dac..d5cb42928f4 100644
--- a/service/aiproxy/relay/adaptor/gemini/adaptor.go
+++ b/service/aiproxy/relay/adaptor/gemini/adaptor.go
@@ -27,10 +27,10 @@ func getRequestURL(meta *meta.Meta, action string) string {
 		u = baseURL
 	}
 	version := "v1beta"
-	if _, ok := v1ModelMap[meta.ActualModelName]; ok {
+	if _, ok := v1ModelMap[meta.ActualModel]; ok {
 		version = "v1"
 	}
-	return fmt.Sprintf("%s/%s/models/%s:%s", u, version, meta.ActualModelName, action)
+	return fmt.Sprintf("%s/%s/models/%s:%s", u, version, meta.ActualModel, action)
 }
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
diff --git a/service/aiproxy/relay/adaptor/gemini/embeddings.go b/service/aiproxy/relay/adaptor/gemini/embeddings.go
index 54ff391e4bd..403284e1db3 100644
--- a/service/aiproxy/relay/adaptor/gemini/embeddings.go
+++ b/service/aiproxy/relay/adaptor/gemini/embeddings.go
@@ -18,7 +18,7 @@ func ConvertEmbeddingRequest(meta *meta.Meta, req *http.Request) (string, http.H
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	inputs := request.ParseInput()
 	requests := make([]EmbeddingRequest, len(inputs))
diff --git a/service/aiproxy/relay/adaptor/gemini/main.go b/service/aiproxy/relay/adaptor/gemini/main.go
index 77d9968da39..6922660c7f2 100644
--- a/service/aiproxy/relay/adaptor/gemini/main.go
+++ b/service/aiproxy/relay/adaptor/gemini/main.go
@@ -190,7 +190,7 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io
 		return "", nil, nil, err
 	}
 
-	textRequest.Model = meta.ActualModelName
+	textRequest.Model = meta.ActualModel
 	meta.Set("stream", textRequest.Stream)
 
 	systemContent, contents, err := buildContents(req.Context(), textRequest)
@@ -317,7 +317,7 @@ func getToolCalls(candidate *ChatCandidate) []*model.Tool {
 func responseGeminiChat2OpenAI(meta *meta.Meta, response *ChatResponse) *openai.TextResponse {
 	fullTextResponse := openai.TextResponse{
 		ID:      "chatcmpl-" + random.GetUUID(),
-		Model:   meta.OriginModelName,
+		Model:   meta.OriginModel,
 		Object:  "chat.completion",
 		Created: time.Now().Unix(),
 		Choices: make([]*openai.TextResponseChoice, 0, len(response.Candidates)),
@@ -356,7 +356,7 @@ func streamResponseGeminiChat2OpenAI(meta *meta.Meta, geminiResponse *ChatRespon
 	response := &openai.ChatCompletionsStreamResponse{
 		ID:      "chatcmpl-" + random.GetUUID(),
 		Created: time.Now().Unix(),
-		Model:   meta.OriginModelName,
+		Model:   meta.OriginModel,
 		Object:  "chat.completion.chunk",
 		Choices: make([]*openai.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates)),
 	}
@@ -444,7 +444,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model
 	tokenCount, err := CountTokens(c.Request.Context(), meta, respContent)
 	if err != nil {
 		log.Error("count tokens failed: " + err.Error())
-		usage.CompletionTokens = openai.CountTokenText(responseText.String(), meta.ActualModelName)
+		usage.CompletionTokens = openai.CountTokenText(responseText.String(), meta.ActualModel)
 	} else {
 		usage.CompletionTokens = tokenCount
 	}
@@ -466,7 +466,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage
 		return nil, openai.ErrorWrapperWithMessage("No candidates returned", "gemini_error", resp.StatusCode)
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(meta, &geminiResponse)
-	fullTextResponse.Model = meta.OriginModelName
+	fullTextResponse.Model = meta.OriginModel
 	respContent := []*ChatContent{}
 	for _, candidate := range geminiResponse.Candidates {
 		respContent = append(respContent, &candidate.Content)
@@ -478,7 +478,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage
 	tokenCount, err := CountTokens(c.Request.Context(), meta, respContent)
 	if err != nil {
 		log.Error("count tokens failed: " + err.Error())
-		usage.CompletionTokens = openai.CountTokenText(geminiResponse.GetResponseText(), meta.ActualModelName)
+		usage.CompletionTokens = openai.CountTokenText(geminiResponse.GetResponseText(), meta.ActualModel)
 	} else {
 		usage.CompletionTokens = tokenCount
 	}
diff --git a/service/aiproxy/relay/adaptor/minimax/tts.go b/service/aiproxy/relay/adaptor/minimax/tts.go
index 2dd69d9ebb7..8663758b387 100644
--- a/service/aiproxy/relay/adaptor/minimax/tts.go
+++ b/service/aiproxy/relay/adaptor/minimax/tts.go
@@ -25,7 +25,7 @@ func ConvertTTSRequest(meta *meta.Meta, req *http.Request) (string, http.Header,
 		return "", nil, nil, err
 	}
 
-	reqMap["model"] = meta.ActualModelName
+	reqMap["model"] = meta.ActualModel
 	reqMap["text"] = reqMap["input"]
 	delete(reqMap, "input")
diff --git a/service/aiproxy/relay/adaptor/ollama/main.go b/service/aiproxy/relay/adaptor/ollama/main.go
index 0b16a837d11..f49c868d3f7 100644
--- a/service/aiproxy/relay/adaptor/ollama/main.go
+++ b/service/aiproxy/relay/adaptor/ollama/main.go
@@ -29,7 +29,7 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 
 	ollamaRequest := ChatRequest{
 		Model: request.Model,
@@ -180,7 +180,7 @@ func ConvertEmbeddingRequest(meta *meta.Meta, req *http.Request) (string, http.H
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	data, err := json.Marshal(&EmbeddingRequest{
 		Model: request.Model,
 		Input: request.ParseInput(),
diff --git a/service/aiproxy/relay/adaptor/openai/adaptor.go b/service/aiproxy/relay/adaptor/openai/adaptor.go
index cb2fbee34b7..1d0155ab393 100644
--- a/service/aiproxy/relay/adaptor/openai/adaptor.go
+++ b/service/aiproxy/relay/adaptor/openai/adaptor.go
@@ -141,7 +141,7 @@ func ConvertTextRequest(meta *meta.Meta, req *http.Request) (string, http.Header
 		}
 	}
 
-	reqMap["model"] = meta.ActualModelName
+	reqMap["model"] = meta.ActualModel
 	jsonData, err := json.Marshal(reqMap)
 	if err != nil {
 		return "", nil, nil, err
diff --git a/service/aiproxy/relay/adaptor/openai/embeddings.go b/service/aiproxy/relay/adaptor/openai/embeddings.go
index b368fe87597..f1851aa4a8c 100644
--- a/service/aiproxy/relay/adaptor/openai/embeddings.go
+++ b/service/aiproxy/relay/adaptor/openai/embeddings.go
@@ -20,7 +20,7 @@ func ConvertEmbeddingsRequest(meta *meta.Meta, req *http.Request) (string, http.
return "", nil, nil, err } - reqMap["model"] = meta.ActualModelName + reqMap["model"] = meta.ActualModel if meta.GetBool(MetaEmbeddingsPatchInputToSlices) { switch v := reqMap["input"].(type) { diff --git a/service/aiproxy/relay/adaptor/openai/image.go b/service/aiproxy/relay/adaptor/openai/image.go index 1614421987f..fa09364cded 100644 --- a/service/aiproxy/relay/adaptor/openai/image.go +++ b/service/aiproxy/relay/adaptor/openai/image.go @@ -22,7 +22,7 @@ func ConvertImageRequest(meta *meta.Meta, req *http.Request) (string, http.Heade } meta.Set(MetaResponseFormat, reqMap["response_format"]) - reqMap["model"] = meta.ActualModelName + reqMap["model"] = meta.ActualModel jsonData, err := json.Marshal(reqMap) if err != nil { return "", nil, nil, err diff --git a/service/aiproxy/relay/adaptor/openai/main.go b/service/aiproxy/relay/adaptor/openai/main.go index e466a8aefb7..0c91b314f4c 100644 --- a/service/aiproxy/relay/adaptor/openai/main.go +++ b/service/aiproxy/relay/adaptor/openai/main.go @@ -82,8 +82,8 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model log.Error("error unmarshalling stream response: " + err.Error()) continue } - if _, ok := respMap["model"]; ok && meta.OriginModelName != "" { - respMap["model"] = meta.OriginModelName + if _, ok := respMap["model"]; ok && meta.OriginModel != "" { + respMap["model"] = meta.OriginModel } err = render.ObjectData(c, respMap) if err != nil { @@ -111,7 +111,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model render.Done(c) if usage == nil || (usage.TotalTokens == 0 && responseText != "") { - usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.InputTokens) + usage = ResponseText2Usage(responseText, meta.ActualModel, meta.InputTokens) } if usage.TotalTokens != 0 && usage.PromptTokens == 0 { // some channels don't return prompt tokens & completion tokens @@ -143,7 +143,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += CountTokenText(choice.Message.StringContent(), meta.ActualModelName) + completionTokens += CountTokenText(choice.Message.StringContent(), meta.ActualModel) } textResponse.Usage = model.Usage{ PromptTokens: meta.InputTokens, @@ -158,8 +158,8 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage return &textResponse.Usage, ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } - if _, ok := respMap["model"]; ok && meta.OriginModelName != "" { - respMap["model"] = meta.OriginModelName + if _, ok := respMap["model"]; ok && meta.OriginModel != "" { + respMap["model"] = meta.OriginModel } newData, err := stdjson.Marshal(respMap) diff --git a/service/aiproxy/relay/adaptor/openai/moderations.go b/service/aiproxy/relay/adaptor/openai/moderations.go index b462a198193..14d2c4919ef 100644 --- a/service/aiproxy/relay/adaptor/openai/moderations.go +++ b/service/aiproxy/relay/adaptor/openai/moderations.go @@ -36,8 +36,8 @@ func ModerationsHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (* return nil, ErrorWrapperWithMessage(errorResp.Error.Message, errorResp.Error.Code, http.StatusBadRequest) } - if _, ok := respMap["model"]; ok && meta.OriginModelName != "" { - respMap["model"] = meta.OriginModelName + if _, ok := respMap["model"]; ok && 
meta.OriginModel != "" { + respMap["model"] = meta.OriginModel } usage := &model.Usage{ diff --git a/service/aiproxy/relay/adaptor/openai/rerank.go b/service/aiproxy/relay/adaptor/openai/rerank.go index f4774c7d847..653415a68f0 100644 --- a/service/aiproxy/relay/adaptor/openai/rerank.go +++ b/service/aiproxy/relay/adaptor/openai/rerank.go @@ -19,7 +19,7 @@ func ConvertRerankRequest(meta *meta.Meta, req *http.Request) (string, http.Head if err != nil { return "", nil, nil, err } - reqMap["model"] = meta.ActualModelName + reqMap["model"] = meta.ActualModel jsonData, err := json.Marshal(reqMap) if err != nil { return "", nil, nil, err diff --git a/service/aiproxy/relay/adaptor/openai/stt.go b/service/aiproxy/relay/adaptor/openai/stt.go index c8f070bd18f..ab32ced6c8e 100644 --- a/service/aiproxy/relay/adaptor/openai/stt.go +++ b/service/aiproxy/relay/adaptor/openai/stt.go @@ -32,7 +32,7 @@ func ConvertSTTRequest(meta *meta.Meta, request *http.Request) (string, http.Hea } value := values[0] if key == "model" { - err = multipartWriter.WriteField(key, meta.ActualModelName) + err = multipartWriter.WriteField(key, meta.ActualModel) if err != nil { return "", nil, nil, err } @@ -113,7 +113,7 @@ func STTHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Us if err != nil { return nil, ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) } - completionTokens := CountTokenText(text, meta.ActualModelName) + completionTokens := CountTokenText(text, meta.ActualModel) for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) diff --git a/service/aiproxy/relay/adaptor/openai/tts.go b/service/aiproxy/relay/adaptor/openai/tts.go index 08f3fa07c20..bdce9258938 100644 --- a/service/aiproxy/relay/adaptor/openai/tts.go +++ b/service/aiproxy/relay/adaptor/openai/tts.go @@ -28,7 +28,7 @@ func ConvertTTSRequest(meta *meta.Meta, req *http.Request) (string, http.Header, if err != nil { return "", nil, nil, err } - reqMap["model"] = meta.ActualModelName + reqMap["model"] = meta.ActualModel jsonData, err := json.Marshal(reqMap) if err != nil { return "", nil, nil, err diff --git a/service/aiproxy/relay/adaptor/siliconflow/image.go b/service/aiproxy/relay/adaptor/siliconflow/image.go index 08853c733a4..1992fccc6ba 100644 --- a/service/aiproxy/relay/adaptor/siliconflow/image.go +++ b/service/aiproxy/relay/adaptor/siliconflow/image.go @@ -33,7 +33,7 @@ func ConvertImageRequest(meta *meta.Meta, request *http.Request) (http.Header, i meta.Set(openai.MetaResponseFormat, reqMap["response_format"]) - reqMap["model"] = meta.ActualModelName + reqMap["model"] = meta.ActualModel reqMap["batch_size"] = reqMap["n"] delete(reqMap, "n") if _, ok := reqMap["steps"]; ok { diff --git a/service/aiproxy/relay/adaptor/vertexai/adaptor.go b/service/aiproxy/relay/adaptor/vertexai/adaptor.go index 737c330a643..075d532cc18 100644 --- a/service/aiproxy/relay/adaptor/vertexai/adaptor.go +++ b/service/aiproxy/relay/adaptor/vertexai/adaptor.go @@ -30,7 +30,7 @@ type Config struct { } func (a *Adaptor) ConvertRequest(meta *meta.Meta, request *http.Request) (string, http.Header, io.Reader, error) { - adaptor := GetAdaptor(meta.ActualModelName) + adaptor := GetAdaptor(meta.ActualModel) if adaptor == nil { return "", nil, nil, errors.New("adaptor not found") } @@ -39,9 +39,9 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, request *http.Request) (string } func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (usage *relaymodel.Usage, err *relaymodel.ErrorWithStatusCode) { - 
-	adaptor := GetAdaptor(meta.ActualModelName)
+	adaptor := GetAdaptor(meta.ActualModel)
 	if adaptor == nil {
-		return nil, openai.ErrorWrapperWithMessage(meta.ActualModelName+" adaptor not found", "adaptor_not_found", http.StatusInternalServerError)
+		return nil, openai.ErrorWrapperWithMessage(meta.ActualModel+" adaptor not found", "adaptor_not_found", http.StatusInternalServerError)
 	}
 	return adaptor.DoResponse(meta, c, resp)
 }
@@ -56,7 +56,7 @@ func (a *Adaptor) GetChannelName() string {
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	var suffix string
-	if strings.HasPrefix(meta.ActualModelName, "gemini") {
+	if strings.HasPrefix(meta.ActualModel, "gemini") {
 		if meta.GetBool("stream") {
 			suffix = "streamGenerateContent?alt=sse"
 		} else {
@@ -81,7 +81,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 			meta.Channel.BaseURL,
 			config.ProjectID,
 			config.Region,
-			meta.ActualModelName,
+			meta.ActualModel,
 			suffix,
 		), nil
 	}
@@ -90,7 +90,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		config.Region,
 		config.ProjectID,
 		config.Region,
-		meta.ActualModelName,
+		meta.ActualModel,
 		suffix,
 	), nil
 }
diff --git a/service/aiproxy/relay/adaptor/xunfei/adaptor.go b/service/aiproxy/relay/adaptor/xunfei/adaptor.go
index faa51f5d4ca..0309a92f258 100644
--- a/service/aiproxy/relay/adaptor/xunfei/adaptor.go
+++ b/service/aiproxy/relay/adaptor/xunfei/adaptor.go
@@ -23,14 +23,14 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 }
 
 func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
-	domain, err := getXunfeiDomain(meta.ActualModelName)
+	domain, err := getXunfeiDomain(meta.ActualModel)
 	if err != nil {
 		return "", nil, nil, err
 	}
-	model := meta.ActualModelName
-	meta.ActualModelName = domain
+	model := meta.ActualModel
+	meta.ActualModel = domain
 	defer func() {
-		meta.ActualModelName = model
+		meta.ActualModel = model
 	}()
 	method, h, body, err := a.Adaptor.ConvertRequest(meta, req)
 	if err != nil {
diff --git a/service/aiproxy/relay/controller/handle.go b/service/aiproxy/relay/controller/handle.go
new file mode 100644
index 00000000000..e46e7c8b991
--- /dev/null
+++ b/service/aiproxy/relay/controller/handle.go
@@ -0,0 +1,111 @@
+package controller
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+	"github.com/labring/sealos/service/aiproxy/common"
+	"github.com/labring/sealos/service/aiproxy/common/conv"
+	"github.com/labring/sealos/service/aiproxy/middleware"
+	"github.com/labring/sealos/service/aiproxy/model"
+	"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
+	"github.com/labring/sealos/service/aiproxy/relay/channeltype"
+	"github.com/labring/sealos/service/aiproxy/relay/meta"
+	relaymodel "github.com/labring/sealos/service/aiproxy/relay/model"
+)
+
+func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBalanceReq, error)) *relaymodel.ErrorWithStatusCode {
+	log := middleware.GetLogger(c)
+	ctx := c.Request.Context()
+
+	// 1. Get adaptor
+	adaptor, ok := channeltype.GetAdaptor(meta.Channel.Type)
+	if !ok {
+		log.Errorf("invalid (%s[%d]) channel type: %d", meta.Channel.Name, meta.Channel.ID, meta.Channel.Type)
+		return openai.ErrorWrapperWithMessage("invalid channel error", "invalid_channel_type", http.StatusInternalServerError)
+	}
+
+	// 2. Get group balance
+	groupRemainBalance, postGroupConsumer, err := getGroupBalance(ctx, meta)
+	if err != nil {
+		log.Errorf("get group (%s) balance failed: %v", meta.Group.ID, err)
+		return openai.ErrorWrapper(
+			fmt.Errorf("get group (%s) balance failed", meta.Group.ID),
+			"get_group_quota_failed",
+			http.StatusInternalServerError,
+		)
+	}
+
+	// 3. Pre-process request
+	preCheckReq, err := preProcess()
+	if err != nil {
+		log.Errorf("pre-process request failed: %s", err.Error())
+		var detail *model.RequestDetail
+		body, bodyErr := common.GetRequestBody(c.Request)
+		if bodyErr != nil {
+			log.Errorf("get request body failed: %s", bodyErr.Error())
+		} else {
+			detail = &model.RequestDetail{
+				RequestBody: conv.BytesToString(body),
+			}
+		}
+		ConsumeWaitGroup.Add(1)
+		go postConsumeAmount(context.Background(),
+			&ConsumeWaitGroup,
+			nil,
+			http.StatusBadRequest,
+			nil,
+			meta,
+			0,
+			0,
+			err.Error(),
+			detail,
+		)
+		return openai.ErrorWrapper(err, "invalid_request", http.StatusBadRequest)
+	}
+
+	// 4. Pre-check balance
+	ok = checkGroupBalance(preCheckReq, meta, groupRemainBalance)
+	if !ok {
+		return openai.ErrorWrapper(errors.New("group balance is not enough"), "insufficient_group_balance", http.StatusForbidden)
+	}
+
+	meta.InputTokens = preCheckReq.InputTokens
+
+	// 5. Do request
+	usage, detail, respErr := DoHelper(adaptor, c, meta)
+	if respErr != nil {
+		ConsumeWaitGroup.Add(1)
+		go postConsumeAmount(context.Background(),
+			&ConsumeWaitGroup,
+			postGroupConsumer,
+			respErr.StatusCode,
+			usage,
+			meta,
+			preCheckReq.InputPrice,
+			preCheckReq.OutputPrice,
+			respErr.String(),
+			detail,
+		)
+		return respErr
+	}
+
+	// 6. Post consume
+	ConsumeWaitGroup.Add(1)
+	go postConsumeAmount(context.Background(),
+		&ConsumeWaitGroup,
+		postGroupConsumer,
+		http.StatusOK,
+		usage,
+		meta,
+		preCheckReq.InputPrice,
+		preCheckReq.OutputPrice,
+		"",
+		nil,
+	)
+
+	return nil
+}
diff --git a/service/aiproxy/relay/controller/helper.go b/service/aiproxy/relay/controller/helper.go
index 9c457a2defa..eb156d5d7c5 100644
--- a/service/aiproxy/relay/controller/helper.go
+++ b/service/aiproxy/relay/controller/helper.go
@@ -24,6 +24,7 @@ import (
 	"github.com/labring/sealos/service/aiproxy/relay/relaymode"
 	"github.com/labring/sealos/service/aiproxy/relay/utils"
 	"github.com/shopspring/decimal"
+	log "github.com/sirupsen/logrus"
 )
 
 var ConsumeWaitGroup sync.WaitGroup
@@ -31,11 +32,12 @@ var ConsumeWaitGroup sync.WaitGroup
 type PreCheckGroupBalanceReq struct {
 	InputTokens int
 	MaxTokens   int
-	Price       float64
+	InputPrice  float64
+	OutputPrice float64
 }
 
 func getPreConsumedAmount(req *PreCheckGroupBalanceReq) float64 {
-	if req == nil || req.Price == 0 || (req.InputTokens == 0 && req.MaxTokens == 0) {
+	if req == nil || req.InputPrice == 0 || (req.InputTokens == 0 && req.MaxTokens == 0) {
 		return 0
 	}
 	preConsumedTokens := int64(req.InputTokens)
@@ -44,15 +46,18 @@ func getPreConsumedAmount(req *PreCheckGroupBalanceReq) float64 {
 	}
 	return decimal.
 		NewFromInt(preConsumedTokens).
-		Mul(decimal.NewFromFloat(req.Price)).
+		Mul(decimal.NewFromFloat(req.InputPrice)).
 		Div(decimal.NewFromInt(billingprice.PriceUnit)).
 		InexactFloat64()
 }
 
-func preCheckGroupBalance(req *PreCheckGroupBalanceReq, meta *meta.Meta, groupRemainBalance float64) bool {
+func checkGroupBalance(req *PreCheckGroupBalanceReq, meta *meta.Meta, groupRemainBalance float64) bool {
 	if meta.IsChannelTest {
 		return true
 	}
+	if groupRemainBalance <= 0 {
+		return false
+	}
 
 	preConsumedAmount := getPreConsumedAmount(req)
 
@@ -79,7 +84,13 @@ func postConsumeAmount(
 	content string,
 	requestDetail *model.RequestDetail,
 ) {
-	defer consumeWaitGroup.Done()
+	defer func() {
+		consumeWaitGroup.Done()
+		if r := recover(); r != nil {
+			log.Errorf("panic in post consume amount: %v", r)
+		}
+	}()
+
 	if meta.IsChannelTest {
 		return
 	}
@@ -94,7 +105,7 @@ func postConsumeAmount(
 		meta.Channel.ID,
 		0,
 		0,
-		meta.OriginModelName,
+		meta.OriginModel,
 		meta.Token.ID,
 		meta.Token.Name,
 		0,
@@ -133,7 +144,7 @@ func postConsumeAmount(
 			meta.RequestAt,
 			meta.Group.ID,
 			meta.Token.Name,
-			meta.OriginModelName,
+			meta.OriginModel,
 			err.Error(),
 			amount,
 			meta.Token.ID,
@@ -154,7 +165,7 @@ func postConsumeAmount(
 			meta.Channel.ID,
 			promptTokens,
 			completionTokens,
-			meta.OriginModelName,
+			meta.OriginModel,
 			meta.Token.ID,
 			meta.Token.Name,
 			amount,
diff --git a/service/aiproxy/relay/controller/image.go b/service/aiproxy/relay/controller/image.go
index d485adc3705..2ff31c2b730 100644
--- a/service/aiproxy/relay/controller/image.go
+++ b/service/aiproxy/relay/controller/image.go
@@ -1,19 +1,10 @@
 package controller
 
 import (
-	"context"
 	"errors"
 	"fmt"
-	"net/http"
 
 	"github.com/gin-gonic/gin"
-	"github.com/labring/sealos/service/aiproxy/common"
-	"github.com/labring/sealos/service/aiproxy/common/config"
-	"github.com/labring/sealos/service/aiproxy/common/conv"
-	"github.com/labring/sealos/service/aiproxy/middleware"
-	"github.com/labring/sealos/service/aiproxy/model"
-	"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
-	"github.com/labring/sealos/service/aiproxy/relay/channeltype"
 	"github.com/labring/sealos/service/aiproxy/relay/meta"
 	relaymodel "github.com/labring/sealos/service/aiproxy/relay/model"
 	billingprice "github.com/labring/sealos/service/aiproxy/relay/price"
@@ -40,8 +31,8 @@ func getImageRequest(c *gin.Context) (*relaymodel.ImageRequest, error) {
 	return imageRequest, nil
 }
 
-func getImageCostPrice(modelName string, reqModel string, size string) (float64, error) {
-	imageCostPrice, ok := billingprice.GetImageSizePrice(modelName, reqModel, size)
+func getImageCostPrice(model string, size string) (float64, error) {
+	imageCostPrice, ok := billingprice.GetImageSizePrice(model, size)
 	if !ok {
 		return 0, fmt.Errorf("invalid image size: %s", size)
 	}
@@ -49,100 +40,20 @@
 }
 
 func RelayImageHelper(meta *meta.Meta, c *gin.Context) *relaymodel.ErrorWithStatusCode {
-	log := middleware.GetLogger(c)
-	ctx := c.Request.Context()
-
-	adaptor, ok := channeltype.GetAdaptor(meta.Channel.Type)
-	if !ok {
-		log.Errorf("invalid (%s[%d]) channel type: %d", meta.Channel.Name, meta.Channel.ID, meta.Channel.Type)
-		return openai.ErrorWrapperWithMessage("invalid channel error", "invalid_channel_type", http.StatusInternalServerError)
-	}
-
-	groupRemainBalance, postGroupConsumer, err := getGroupBalance(ctx, meta)
-	if err != nil {
-		log.Errorf("get group (%s) balance failed: %v", meta.Group.ID, err)
-		return openai.ErrorWrapper(
-			fmt.Errorf("get group (%s) balance failed", meta.Group.ID),
-			"get_group_quota_failed",
-			http.StatusInternalServerError,
-		)
-	}
-
-	imageRequest, err := getImageRequest(c)
-	if err != nil {
log.Errorf("get request failed: %s", err.Error()) - var detail model.RequestDetail - reqDetail, err := common.GetRequestBody(c.Request) + return Handle(meta, c, func() (*PreCheckGroupBalanceReq, error) { + imageRequest, err := getImageRequest(c) if err != nil { - log.Errorf("get request body failed: %s", err.Error()) - } else { - detail.RequestBody = conv.BytesToString(reqDetail) + return nil, err } - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, - nil, - http.StatusBadRequest, - nil, - meta, - 0, - 0, - err.Error(), - &detail, - ) - return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) - } - imageCostPrice, err := getImageCostPrice(meta.OriginModelName, meta.ActualModelName, imageRequest.Size) - if err != nil { - return openai.ErrorWrapper(err, "get_image_cost_price_failed", http.StatusInternalServerError) - } - - meta.InputTokens = imageRequest.N - - ok = preCheckGroupBalance(&PreCheckGroupBalanceReq{ - InputTokens: meta.InputTokens, - Price: imageCostPrice, - }, meta, groupRemainBalance) - if !ok { - return openai.ErrorWrapper(errors.New("group balance is not enough"), "insufficient_group_balance", http.StatusForbidden) - } - - // do response - usage, detail, respErr := DoHelper(adaptor, c, meta) - if respErr != nil { - if detail != nil && config.DebugEnabled { - log.Errorf("do image failed: %s\nrequest detail:\n%s\nresponse detail:\n%s", respErr, detail.RequestBody, detail.ResponseBody) - } else { - log.Errorf("do image failed: %s", respErr) + imageCostPrice, err := getImageCostPrice(meta.OriginModel, imageRequest.Size) + if err != nil { + return nil, err } - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, - postGroupConsumer, - respErr.StatusCode, - usage, - meta, - imageCostPrice, - 0, - respErr.String(), - detail, - ) - return respErr - } - - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, - postGroupConsumer, - http.StatusOK, - usage, - meta, - imageCostPrice, - 0, - imageRequest.Size, - nil, - ) - return nil + return &PreCheckGroupBalanceReq{ + InputTokens: imageRequest.N, + InputPrice: imageCostPrice, + }, nil + }) } diff --git a/service/aiproxy/relay/controller/rerank.go b/service/aiproxy/relay/controller/rerank.go index 8a9f5b783c7..737e2bc2d2e 100644 --- a/service/aiproxy/relay/controller/rerank.go +++ b/service/aiproxy/relay/controller/rerank.go @@ -1,20 +1,11 @@ package controller import ( - "context" "errors" "fmt" - "net/http" "strings" "github.com/gin-gonic/gin" - "github.com/labring/sealos/service/aiproxy/common" - "github.com/labring/sealos/service/aiproxy/common/config" - "github.com/labring/sealos/service/aiproxy/common/conv" - "github.com/labring/sealos/service/aiproxy/middleware" - "github.com/labring/sealos/service/aiproxy/model" - "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai" - "github.com/labring/sealos/service/aiproxy/relay/channeltype" "github.com/labring/sealos/service/aiproxy/relay/meta" relaymodel "github.com/labring/sealos/service/aiproxy/relay/model" billingprice "github.com/labring/sealos/service/aiproxy/relay/price" @@ -22,101 +13,23 @@ import ( ) func RerankHelper(meta *meta.Meta, c *gin.Context) *relaymodel.ErrorWithStatusCode { - log := middleware.GetLogger(c) - ctx := c.Request.Context() - - adaptor, ok := channeltype.GetAdaptor(meta.Channel.Type) - if !ok { - log.Errorf("invalid (%s[%d]) channel type: %d", meta.Channel.Name, meta.Channel.ID, meta.Channel.Type) - return 
openai.ErrorWrapperWithMessage("invalid channel error", "invalid_channel_type", http.StatusInternalServerError) - } - - groupRemainBalance, postGroupConsumer, err := getGroupBalance(ctx, meta) - if err != nil { - log.Errorf("get group (%s) balance failed: %v", meta.Group.ID, err) - return openai.ErrorWrapper( - fmt.Errorf("get group (%s) balance failed", meta.Group.ID), - "get_group_quota_failed", - http.StatusInternalServerError, - ) - } - - rerankRequest, err := getRerankRequest(c) - if err != nil { - log.Errorf("get request failed: %s", err.Error()) - var detail model.RequestDetail - reqDetail, err := common.GetRequestBody(c.Request) - if err != nil { - log.Errorf("get request body failed: %s", err.Error()) - } else { - detail.RequestBody = conv.BytesToString(reqDetail) + return Handle(meta, c, func() (*PreCheckGroupBalanceReq, error) { + price, completionPrice, ok := billingprice.GetModelPrice(meta.OriginModel) + if !ok { + return nil, fmt.Errorf("model price not found: %s", meta.OriginModel) } - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, - nil, - http.StatusBadRequest, - nil, - meta, - 0, - 0, - err.Error(), - &detail, - ) - return openai.ErrorWrapper(err, "invalid_rerank_request", http.StatusBadRequest) - } - price, completionPrice, ok := billingprice.GetModelPrice(meta.OriginModelName, meta.ActualModelName) - if !ok { - return openai.ErrorWrapper(fmt.Errorf("model price not found: %s", meta.OriginModelName), "model_price_not_found", http.StatusInternalServerError) - } - - meta.InputTokens = rerankPromptTokens(rerankRequest) - - ok = preCheckGroupBalance(&PreCheckGroupBalanceReq{ - InputTokens: meta.InputTokens, - Price: price, - }, meta, groupRemainBalance) - if !ok { - return openai.ErrorWrapper(errors.New("group balance is not enough"), "insufficient_group_balance", http.StatusForbidden) - } - - usage, detail, respErr := DoHelper(adaptor, c, meta) - if respErr != nil { - if detail != nil && config.DebugEnabled { - log.Errorf("do rerank failed: %s\nrequest detail:\n%s\nresponse detail:\n%s", respErr, detail.RequestBody, detail.ResponseBody) - } else { - log.Errorf("do rerank failed: %s", respErr) + rerankRequest, err := getRerankRequest(c) + if err != nil { + return nil, err } - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, - postGroupConsumer, - http.StatusInternalServerError, - usage, - meta, - price, - completionPrice, - respErr.String(), - detail, - ) - return respErr - } - - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, - postGroupConsumer, - http.StatusOK, - usage, - meta, - price, - completionPrice, - "", - nil, - ) - return nil + return &PreCheckGroupBalanceReq{ + InputTokens: rerankPromptTokens(rerankRequest), + InputPrice: price, + OutputPrice: completionPrice, + }, nil + }) } func getRerankRequest(c *gin.Context) (*relaymodel.RerankRequest, error) { diff --git a/service/aiproxy/relay/controller/stt.go b/service/aiproxy/relay/controller/stt.go index 317df5ff755..e12d130c4e1 100644 --- a/service/aiproxy/relay/controller/stt.go +++ b/service/aiproxy/relay/controller/stt.go @@ -1,88 +1,24 @@ package controller import ( - "context" - "errors" "fmt" - "net/http" "github.com/gin-gonic/gin" - "github.com/labring/sealos/service/aiproxy/common/config" - "github.com/labring/sealos/service/aiproxy/middleware" - "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai" - "github.com/labring/sealos/service/aiproxy/relay/channeltype" 
"github.com/labring/sealos/service/aiproxy/relay/meta" relaymodel "github.com/labring/sealos/service/aiproxy/relay/model" billingprice "github.com/labring/sealos/service/aiproxy/relay/price" ) func RelaySTTHelper(meta *meta.Meta, c *gin.Context) *relaymodel.ErrorWithStatusCode { - log := middleware.GetLogger(c) - ctx := c.Request.Context() - - adaptor, ok := channeltype.GetAdaptor(meta.Channel.Type) - if !ok { - log.Errorf("invalid (%s[%d]) channel type: %d", meta.Channel.Name, meta.Channel.ID, meta.Channel.Type) - return openai.ErrorWrapperWithMessage("invalid channel error", "invalid_channel_type", http.StatusInternalServerError) - } - - groupRemainBalance, postGroupConsumer, err := getGroupBalance(ctx, meta) - if err != nil { - log.Errorf("get group (%s) balance failed: %v", meta.Group.ID, err) - return openai.ErrorWrapper( - fmt.Errorf("get group (%s) balance failed", meta.Group.ID), - "get_group_quota_failed", - http.StatusInternalServerError, - ) - } - - price, completionPrice, ok := billingprice.GetModelPrice(meta.OriginModelName, meta.ActualModelName) - if !ok { - return openai.ErrorWrapper(fmt.Errorf("model price not found: %s", meta.OriginModelName), "model_price_not_found", http.StatusInternalServerError) - } - - ok = preCheckGroupBalance(&PreCheckGroupBalanceReq{ - InputTokens: meta.InputTokens, - Price: price, - }, meta, groupRemainBalance) - if !ok { - return openai.ErrorWrapper(errors.New("group balance is not enough"), "insufficient_group_balance", http.StatusForbidden) - } - - usage, detail, respErr := DoHelper(adaptor, c, meta) - if respErr != nil { - if detail != nil && config.DebugEnabled { - log.Errorf("do stt failed: %s\nrequest detail:\n%s\nresponse detail:\n%s", respErr, detail.RequestBody, detail.ResponseBody) - } else { - log.Errorf("do stt failed: %s", respErr) + return Handle(meta, c, func() (*PreCheckGroupBalanceReq, error) { + price, completionPrice, ok := billingprice.GetModelPrice(meta.OriginModel) + if !ok { + return nil, fmt.Errorf("model price not found: %s", meta.OriginModel) } - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, - postGroupConsumer, - respErr.StatusCode, - usage, - meta, - price, - completionPrice, - respErr.String(), - detail, - ) - return respErr - } - - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, - postGroupConsumer, - http.StatusOK, - usage, - meta, - price, - completionPrice, - "", - nil, - ) - return nil + return &PreCheckGroupBalanceReq{ + InputPrice: price, + OutputPrice: completionPrice, + }, nil + }) } diff --git a/service/aiproxy/relay/controller/text.go b/service/aiproxy/relay/controller/text.go index 8b6af6b9e6a..da3ee8d40e1 100644 --- a/service/aiproxy/relay/controller/text.go +++ b/service/aiproxy/relay/controller/text.go @@ -1,19 +1,10 @@ package controller import ( - "context" - "errors" "fmt" - "net/http" "github.com/gin-gonic/gin" - "github.com/labring/sealos/service/aiproxy/common" - "github.com/labring/sealos/service/aiproxy/common/config" - "github.com/labring/sealos/service/aiproxy/common/conv" - "github.com/labring/sealos/service/aiproxy/middleware" - "github.com/labring/sealos/service/aiproxy/model" "github.com/labring/sealos/service/aiproxy/relay/adaptor/openai" - "github.com/labring/sealos/service/aiproxy/relay/channeltype" "github.com/labring/sealos/service/aiproxy/relay/meta" relaymodel "github.com/labring/sealos/service/aiproxy/relay/model" billingprice "github.com/labring/sealos/service/aiproxy/relay/price" @@ -21,113 
@@ -21,113 +12,22 @@ import (
 )
 
 func RelayTextHelper(meta *meta.Meta, c *gin.Context) *relaymodel.ErrorWithStatusCode {
-	log := middleware.GetLogger(c)
-	ctx := c.Request.Context()
-
-	adaptor, ok := channeltype.GetAdaptor(meta.Channel.Type)
-	if !ok {
-		log.Errorf("invalid (%s[%d]) channel type: %d", meta.Channel.Name, meta.Channel.ID, meta.Channel.Type)
-		return openai.ErrorWrapperWithMessage("invalid channel error", "invalid_channel_type", http.StatusInternalServerError)
-	}
-
-	groupRemainBalance, postGroupConsumer, err := getGroupBalance(ctx, meta)
-	if err != nil {
-		log.Errorf("get group (%s) balance failed: %v", meta.Group.ID, err)
-		return openai.ErrorWrapper(
-			fmt.Errorf("get group (%s) balance failed", meta.Group.ID),
-			"get_group_quota_failed",
-			http.StatusInternalServerError,
-		)
-	}
+	return Handle(meta, c, func() (*PreCheckGroupBalanceReq, error) {
+		price, completionPrice, ok := billingprice.GetModelPrice(meta.OriginModel)
+		if !ok {
+			return nil, fmt.Errorf("model price not found: %s", meta.OriginModel)
+		}
 
-	textRequest, err := utils.UnmarshalGeneralOpenAIRequest(c.Request)
-	if err != nil {
-		log.Errorf("get request failed: %s", err.Error())
-		var detail model.RequestDetail
-		reqDetail, err := common.GetRequestBody(c.Request)
+		textRequest, err := utils.UnmarshalGeneralOpenAIRequest(c.Request)
 		if err != nil {
-			log.Errorf("get request body failed: %s", err.Error())
-		} else {
-			detail.RequestBody = conv.BytesToString(reqDetail)
+			return nil, err
 		}
-		ConsumeWaitGroup.Add(1)
-		go postConsumeAmount(context.Background(),
-			&ConsumeWaitGroup,
-			nil,
-			http.StatusBadRequest,
-			nil,
-			meta,
-			0,
-			0,
-			err.Error(),
-			&detail,
-		)
-		return openai.ErrorWrapper(fmt.Errorf("get and validate text request failed: %s", err.Error()), "invalid_text_request", http.StatusBadRequest)
-	}
-	// get model price
-	price, completionPrice, ok := billingprice.GetModelPrice(meta.OriginModelName, meta.ActualModelName)
-	if !ok {
-		return openai.ErrorWrapper(fmt.Errorf("model price not found: %s", meta.OriginModelName), "model_price_not_found", http.StatusInternalServerError)
-	}
-	// pre-consume balance
-	meta.InputTokens = openai.GetPromptTokens(meta, textRequest)
-
-	ok = preCheckGroupBalance(&PreCheckGroupBalanceReq{
-		InputTokens: meta.InputTokens,
-		MaxTokens:   textRequest.MaxTokens,
-		Price:       price,
-	}, meta, groupRemainBalance)
-	if !ok {
-		ConsumeWaitGroup.Add(1)
-		go postConsumeAmount(context.Background(),
-			&ConsumeWaitGroup,
-			postGroupConsumer,
-			http.StatusForbidden,
-			nil,
-			meta,
-			0,
-			0,
-			"group balance is not enough",
-			nil,
-		)
-		return openai.ErrorWrapper(errors.New("group balance is not enough"), "insufficient_group_balance", http.StatusForbidden)
-	}
-
-	// do response
-	usage, detail, respErr := DoHelper(adaptor, c, meta)
-	if respErr != nil {
-		if detail != nil && config.DebugEnabled {
-			log.Errorf("do text failed: %s\nrequest detail:\n%s\nresponse detail:\n%s", respErr, detail.RequestBody, detail.ResponseBody)
-		} else {
-			log.Errorf("do text failed: %s", respErr)
-		}
-		ConsumeWaitGroup.Add(1)
-		go postConsumeAmount(context.Background(),
-			&ConsumeWaitGroup,
-			postGroupConsumer,
-			respErr.StatusCode,
-			usage,
-			meta,
-			price,
-			completionPrice,
-			respErr.String(),
-			detail,
-		)
-		return respErr
-	}
-	// post-consume amount
-	ConsumeWaitGroup.Add(1)
-	go postConsumeAmount(context.Background(),
-		&ConsumeWaitGroup,
-		postGroupConsumer,
-		http.StatusOK,
-		usage,
-		meta,
-		price,
-		completionPrice,
-		"",
-		nil,
-	)
-	return nil
+		return &PreCheckGroupBalanceReq{
diff --git a/service/aiproxy/relay/controller/tts.go b/service/aiproxy/relay/controller/tts.go
index 0d7a9c03837..56e0298c9cb 100644
--- a/service/aiproxy/relay/controller/tts.go
+++ b/service/aiproxy/relay/controller/tts.go
@@ -1,19 +1,10 @@
 package controller
 
 import (
-	"context"
-	"errors"
 	"fmt"
-	"net/http"
 
 	"github.com/gin-gonic/gin"
-	"github.com/labring/sealos/service/aiproxy/common"
-	"github.com/labring/sealos/service/aiproxy/common/config"
-	"github.com/labring/sealos/service/aiproxy/common/conv"
-	"github.com/labring/sealos/service/aiproxy/middleware"
-	"github.com/labring/sealos/service/aiproxy/model"
 	"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
-	"github.com/labring/sealos/service/aiproxy/relay/channeltype"
 	"github.com/labring/sealos/service/aiproxy/relay/meta"
 	relaymodel "github.com/labring/sealos/service/aiproxy/relay/model"
 	billingprice "github.com/labring/sealos/service/aiproxy/relay/price"
@@ -21,99 +12,21 @@ import (
 )
 
 func RelayTTSHelper(meta *meta.Meta, c *gin.Context) *relaymodel.ErrorWithStatusCode {
-	log := middleware.GetLogger(c)
-	ctx := c.Request.Context()
-
-	adaptor, ok := channeltype.GetAdaptor(meta.Channel.Type)
-	if !ok {
-		log.Errorf("invalid (%s[%d]) channel type: %d", meta.Channel.Name, meta.Channel.ID, meta.Channel.Type)
-		return openai.ErrorWrapperWithMessage("invalid channel error", "invalid_channel_type", http.StatusInternalServerError)
-	}
-
-	groupRemainBalance, postGroupConsumer, err := getGroupBalance(ctx, meta)
-	if err != nil {
-		log.Errorf("get group (%s) balance failed: %v", meta.Group.ID, err)
-		return openai.ErrorWrapper(
-			fmt.Errorf("get group (%s) balance failed", meta.Group.ID),
-			"get_group_quota_failed",
-			http.StatusInternalServerError,
-		)
-	}
-
-	price, completionPrice, ok := billingprice.GetModelPrice(meta.OriginModelName, meta.ActualModelName)
-	if !ok {
-		return openai.ErrorWrapper(fmt.Errorf("model price not found: %s", meta.OriginModelName), "model_price_not_found", http.StatusInternalServerError)
-	}
-
-	ttsRequest, err := utils.UnmarshalTTSRequest(c.Request)
-	if err != nil {
-		log.Errorf("get request failed: %s", err.Error())
-		var detail model.RequestDetail
-		reqDetail, err := common.GetRequestBody(c.Request)
-		if err != nil {
-			log.Errorf("get request body failed: %s", err.Error())
-		} else {
-			detail.RequestBody = conv.BytesToString(reqDetail)
+	return Handle(meta, c, func() (*PreCheckGroupBalanceReq, error) {
+		price, completionPrice, ok := billingprice.GetModelPrice(meta.OriginModel)
+		if !ok {
+			return nil, fmt.Errorf("model price not found: %s", meta.OriginModel)
 		}
-		ConsumeWaitGroup.Add(1)
-		go postConsumeAmount(context.Background(),
-			&ConsumeWaitGroup,
-			nil,
-			http.StatusBadRequest,
-			nil,
-			meta,
-			0,
-			0,
-			err.Error(),
-			&detail,
-		)
-		return openai.ErrorWrapper(err, "invalid_tts_request", http.StatusBadRequest)
-	}
-	meta.InputTokens = openai.CountTokenText(ttsRequest.Input, meta.ActualModelName)
-
-	ok = preCheckGroupBalance(&PreCheckGroupBalanceReq{
-		InputTokens: meta.InputTokens,
-		Price:       price,
-	}, meta, groupRemainBalance)
-	if !ok {
-		return openai.ErrorWrapper(errors.New("group balance is not enough"), "insufficient_group_balance", http.StatusForbidden)
-	}
-
-	usage, detail, respErr := DoHelper(adaptor, c, meta)
-	if respErr != nil {
-		if detail != nil && config.DebugEnabled {
-			log.Errorf("do tts failed: %s\nrequest detail:\n%s\nresponse detail:\n%s", respErr, detail.RequestBody, detail.ResponseBody)
-		} else {
-			log.Errorf("do tts failed: %s", respErr)
+		ttsRequest, err := utils.UnmarshalTTSRequest(c.Request)
+		if err != nil {
+			return nil, err
 		}
-		ConsumeWaitGroup.Add(1)
-		go postConsumeAmount(context.Background(),
-			&ConsumeWaitGroup,
-			postGroupConsumer,
-			respErr.StatusCode,
-			usage,
-			meta,
-			price,
-			completionPrice,
-			respErr.String(),
-			detail,
-		)
-		return respErr
-	}
-
-	ConsumeWaitGroup.Add(1)
-	go postConsumeAmount(context.Background(),
-		&ConsumeWaitGroup,
-		postGroupConsumer,
-		http.StatusOK,
-		usage,
-		meta,
-		price,
-		completionPrice,
-		"",
-		nil,
-	)
-	return nil
+		return &PreCheckGroupBalanceReq{
+			InputTokens: openai.CountTokenText(ttsRequest.Input, meta.ActualModel),
+			InputPrice:  price,
+			OutputPrice: completionPrice,
+		}, nil
+	})
 }
diff --git a/service/aiproxy/relay/meta/meta.go b/service/aiproxy/relay/meta/meta.go
index 63cc4de07cb..7519d507eec 100644
--- a/service/aiproxy/relay/meta/meta.go
+++ b/service/aiproxy/relay/meta/meta.go
@@ -21,14 +21,14 @@ type Meta struct {
 	Group *model.GroupCache
 	Token *model.TokenCache
 
-	Endpoint        string
-	RequestAt       time.Time
-	RequestID       string
-	OriginModelName string
-	ActualModelName string
-	Mode            int
-	InputTokens     int
-	IsChannelTest   bool
+	Endpoint      string
+	RequestAt     time.Time
+	RequestID     string
+	OriginModel   string
+	ActualModel   string
+	Mode          int
+	InputTokens   int
+	IsChannelTest bool
 }
 
 type Option func(meta *Meta)
@@ -71,10 +71,10 @@ func WithToken(token *model.TokenCache) Option {
 
 func NewMeta(channel *model.Channel, mode int, modelName string, opts ...Option) *Meta {
 	meta := Meta{
-		values:          make(map[string]any),
-		Mode:            mode,
-		OriginModelName: modelName,
-		RequestAt:       time.Now(),
+		values:      make(map[string]any),
+		Mode:        mode,
+		OriginModel: modelName,
+		RequestAt:   time.Now(),
 	}
 
 	for _, opt := range opts {
@@ -94,7 +94,7 @@ func (m *Meta) Reset(channel *model.Channel) {
 		ID:   channel.ID,
 		Type: channel.Type,
 	}
-	m.ActualModelName, _ = GetMappedModelName(m.OriginModelName, channel.ModelMapping)
+	m.ActualModel, _ = GetMappedModelName(m.OriginModel, channel.ModelMapping)
 	m.ClearValues()
 }
diff --git a/service/aiproxy/relay/price/image.go b/service/aiproxy/relay/price/image.go
index 19bcde3f7b2..e5171f1a5e0 100644
--- a/service/aiproxy/relay/price/image.go
+++ b/service/aiproxy/relay/price/image.go
@@ -8,17 +8,11 @@ import (
 	"github.com/labring/sealos/service/aiproxy/model"
 )
 
-func GetImageSizePrice(model string, reqModel string, size string) (float64, bool) {
+func GetImageSizePrice(model string, size string) (float64, bool) {
 	if !config.GetBillingEnabled() {
 		return 0, false
 	}
-	if price, ok := getImageSizePrice(model, size); ok {
-		return price, true
-	}
-	if price, ok := getImageSizePrice(reqModel, size); ok {
-		return price, true
-	}
-	return 0, false
+	return getImageSizePrice(model, size)
 }
 
 func getImageSizePrice(modelName string, size string) (float64, bool) {
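GetImageSizePrice above and GetModelPrice below both drop their reqModel fallback: a price is now looked up exactly once, keyed by the model name the caller passes, which at every call site in this diff is the origin model the client requested. A minimal, self-contained illustration; the model name, size, and printed value are hypothetical:

package main

import (
    "fmt"

    billingprice "github.com/labring/sealos/service/aiproxy/relay/price"
)

func main() {
    // Single lookup; no second attempt with the channel-mapped model name.
    // Returns (0, false) when billing is enabled but no price is configured.
    if price, ok := billingprice.GetImageSizePrice("dall-e-3", "1024x1024"); ok {
        fmt.Printf("price per image: %v\n", price)
    }
}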
diff --git a/service/aiproxy/relay/price/model.go b/service/aiproxy/relay/price/model.go
index 09cda1bd708..48da1919591 100644
--- a/service/aiproxy/relay/price/model.go
+++ b/service/aiproxy/relay/price/model.go
@@ -16,18 +16,10 @@ const (
 
 // https://openai.com/pricing
 // Price unit: CNY per 1K tokens
-func GetModelPrice(mapedName string, reqModel string) (float64, float64, bool) {
+func GetModelPrice(modelName string) (float64, float64, bool) {
 	if !config.GetBillingEnabled() {
 		return 0, 0, true
 	}
-	price, completionPrice, ok := getModelPrice(mapedName)
-	if !ok && reqModel != "" {
-		price, completionPrice, ok = getModelPrice(reqModel)
-	}
-	return price, completionPrice, ok
-}
-
-func getModelPrice(modelName string) (float64, float64, bool) {
 	modelConfig, ok := model.CacheGetModelConfig(modelName)
 	if !ok {
 		return 0, 0, false