refactor: relay handler
zijiren233 committed Jan 2, 2025
1 parent 38f9f4f commit 9d83c84
Showing 51 changed files with 299 additions and 611 deletions.
7 changes: 7 additions & 0 deletions service/aiproxy/common/gin.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 	json "github.com/json-iterator/go"
@@ -40,6 +41,12 @@ func (l *LimitedReader) Read(p []byte) (n int, err error) {
 }
 
 func GetRequestBody(req *http.Request) ([]byte, error) {
+	contentType := req.Header.Get("Content-Type")
+	if contentType == "application/x-www-form-urlencoded" ||
+		strings.HasPrefix(contentType, "multipart/form-data") {
+		return nil, nil
+	}
+
 	requestBody := req.Context().Value(RequestBodyKey{})
 	if requestBody != nil {
 		return requestBody.([]byte), nil
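The new guard means GetRequestBody no longer buffers form-encoded or multipart bodies, so uploads are not copied into memory before being relayed. Below is a minimal, self-contained sketch (not the repository's code) of how a caller might handle the resulting nil body; the handler, route, and the 32 MiB parse limit are illustrative assumptions.

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

// getRequestBody mirrors the guard added in the diff above: form-encoded and
// multipart bodies are not buffered; everything else is read (the repository
// actually returns a body cached by earlier middleware, elided here).
func getRequestBody(req *http.Request) ([]byte, error) {
	contentType := req.Header.Get("Content-Type")
	if contentType == "application/x-www-form-urlencoded" ||
		strings.HasPrefix(contentType, "multipart/form-data") {
		return nil, nil // leave the body on the wire for form handlers
	}
	return io.ReadAll(req.Body)
}

func handler(w http.ResponseWriter, r *http.Request) {
	body, err := getRequestBody(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if body == nil {
		// Form or multipart request: parse it directly instead of relying on a
		// cached copy. 32 MiB is an arbitrary in-memory limit for this example.
		if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") {
			_ = r.ParseMultipartForm(32 << 20)
		} else {
			_ = r.ParseForm()
		}
	}
	fmt.Fprintln(w, "ok")
}

func main() {
	http.HandleFunc("/upload", handler)
	_ = http.ListenAndServe(":8080", nil)
}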
4 changes: 2 additions & 2 deletions service/aiproxy/controller/channel-test.go
@@ -69,8 +69,8 @@ func testSingleModel(channel *model.Channel, modelName string) (*model.ChannelTe
 
 	return channel.UpdateModelTest(
 		meta.RequestAt,
-		meta.OriginModelName,
-		meta.ActualModelName,
+		meta.OriginModel,
+		meta.ActualModel,
 		meta.Mode,
 		time.Since(meta.RequestAt).Seconds(),
 		success,
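Most of the remaining hunks are the same mechanical rename: meta.OriginModelName becomes meta.OriginModel and meta.ActualModelName becomes meta.ActualModel. A rough sketch of the affected fields follows; only the old and new names come from this diff, while the field types and comments are assumptions about the Meta struct in service/aiproxy/relay/meta.

package meta

// Sketch only; field types are assumed, not taken from the repository.
type Meta struct {
	// Renamed in this commit:
	//   OriginModelName -> OriginModel
	//   ActualModelName -> ActualModel
	OriginModel string // model name as requested by the client
	ActualModel string // model name actually forwarded to the upstream channel
	// ...other fields (Mode, RequestAt, Channel, InputTokens, ...) unchanged
}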
4 changes: 2 additions & 2 deletions service/aiproxy/controller/relay.go
@@ -57,7 +57,7 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
 	if err == nil {
 		if err := monitor.AddRequest(
 			c.Request.Context(),
-			meta.OriginModelName,
+			meta.OriginModel,
 			int64(meta.Channel.ID),
 			false,
 		); err != nil {
@@ -68,7 +68,7 @@ func RelayHelper(meta *meta.Meta, c *gin.Context, relayController RelayControlle
 	if shouldRetry(c, err.StatusCode) {
 		if err := monitor.AddRequest(
 			c.Request.Context(),
-			meta.OriginModelName,
+			meta.OriginModel,
 			int64(meta.Channel.ID),
 			true,
 		); err != nil {
4 changes: 2 additions & 2 deletions service/aiproxy/middleware/auth.go
@@ -102,8 +102,8 @@ func SetLogFieldsFromMeta(m *meta.Meta, fields logrus.Fields) {
 	SetLogRequestIDField(fields, m.RequestID)
 
 	SetLogModeField(fields, m.Mode)
-	SetLogModelFields(fields, m.OriginModelName)
-	SetLogActualModelFields(fields, m.ActualModelName)
+	SetLogModelFields(fields, m.OriginModel)
+	SetLogActualModelFields(fields, m.ActualModel)
 
 	if m.IsChannelTest {
 		SetLogIsChannelTestField(fields, true)
4 changes: 2 additions & 2 deletions service/aiproxy/relay/adaptor/ali/adaptor.go
@@ -76,7 +76,7 @@ func (a *Adaptor) DoRequest(meta *meta.Meta, _ *gin.Context, req *http.Request)
 	case relaymode.AudioTranscription:
 		return STTDoRequest(meta, req)
 	case relaymode.ChatCompletions:
-		if meta.IsChannelTest && strings.Contains(meta.ActualModelName, "-ocr") {
+		if meta.IsChannelTest && strings.Contains(meta.ActualModel, "-ocr") {
 			return &http.Response{
 				StatusCode: http.StatusOK,
 				Body:       io.NopCloser(bytes.NewReader(nil)),
@@ -95,7 +95,7 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons
 	case relaymode.ImagesGenerations:
 		usage, err = ImageHandler(meta, c, resp)
 	case relaymode.ChatCompletions:
-		if meta.IsChannelTest && strings.Contains(meta.ActualModelName, "-ocr") {
+		if meta.IsChannelTest && strings.Contains(meta.ActualModel, "-ocr") {
 			return nil, nil
 		}
 		usage, err = openai.DoResponse(meta, c, resp)
4 changes: 2 additions & 2 deletions service/aiproxy/relay/adaptor/ali/embeddings.go
@@ -21,7 +21,7 @@ func ConvertEmbeddingsRequest(meta *meta.Meta, req *http.Request) (string, http.
 	if err != nil {
 		return "", nil, nil, err
 	}
-	reqMap["model"] = meta.ActualModelName
+	reqMap["model"] = meta.ActualModel
 	input, ok := reqMap["input"]
 	if !ok {
 		return "", nil, nil, errors.New("input is required")
@@ -56,7 +56,7 @@ func embeddingResponse2OpenAI(meta *meta.Meta, response *EmbeddingResponse) *ope
 	openAIEmbeddingResponse := openai.EmbeddingResponse{
 		Object: "list",
 		Data:   make([]*openai.EmbeddingResponseItem, 0, 1),
-		Model:  meta.OriginModelName,
+		Model:  meta.OriginModel,
 		Usage:  response.Usage,
 	}
 
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/ali/image.go
@@ -27,7 +27,7 @@ func ConvertImageRequest(meta *meta.Meta, req *http.Request) (string, http.Heade
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 
 	var imageRequest ImageRequest
 	imageRequest.Input.Prompt = request.Prompt
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/ali/rerank.go
@@ -32,7 +32,7 @@ func ConvertRerankRequest(meta *meta.Meta, req *http.Request) (string, http.Head
 	if err != nil {
 		return "", nil, nil, err
 	}
-	reqMap["model"] = meta.ActualModelName
+	reqMap["model"] = meta.ActualModel
 	reqMap["input"] = map[string]any{
 		"query":     reqMap["query"],
 		"documents": reqMap["documents"],
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/ali/stt-realtime.go
@@ -89,7 +89,7 @@ func ConvertSTTRequest(meta *meta.Meta, request *http.Request) (string, http.Hea
 			TaskID: uuid.New().String(),
 		},
 		Payload: STTPayload{
-			Model:     meta.ActualModelName,
+			Model:     meta.ActualModel,
 			Task:      "asr",
 			TaskGroup: "audio",
 			Function:  "recognition",
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/ali/tts.go
@@ -107,7 +107,7 @@ func ConvertTTSRequest(meta *meta.Meta, req *http.Request) (string, http.Header,
 	if ok {
 		sampleRate = int(sampleRateI)
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 
 	if strings.HasPrefix(request.Model, "sambert-v") {
 		voice := request.Voice
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/anthropic/adaptor.go
@@ -37,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, c *gin.Context, req *http.
 
 	// https://x.com/alexalbert__/status/1812921642143900036
 	// claude-3-5-sonnet can support 8k context
-	if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
+	if strings.HasPrefix(meta.ActualModel, "claude-3-5-sonnet") {
 		req.Header.Set("Anthropic-Beta", "max-tokens-3-5-sonnet-2024-07-15")
 	}
 
4 changes: 2 additions & 2 deletions service/aiproxy/relay/adaptor/anthropic/main.go
@@ -44,7 +44,7 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (*Request, error) {
 	if err != nil {
 		return nil, err
 	}
-	textRequest.Model = meta.ActualModelName
+	textRequest.Model = meta.ActualModel
 	meta.Set("stream", textRequest.Stream)
 	claudeTools := make([]Tool, 0, len(textRequest.Tools))
 
@@ -372,7 +372,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Error
 		}, nil
 	}
 	fullTextResponse := ResponseClaude2OpenAI(&claudeResponse)
-	fullTextResponse.Model = meta.OriginModelName
+	fullTextResponse.Model = meta.OriginModel
 	usage := model.Usage{
 		PromptTokens:     claudeResponse.Usage.InputTokens,
 		CompletionTokens: claudeResponse.Usage.OutputTokens,
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/aws/adaptor.go
@@ -18,7 +18,7 @@ var _ adaptor.Adaptor = new(Adaptor)
 type Adaptor struct{}
 
 func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
-	adaptor := GetAdaptor(meta.ActualModelName)
+	adaptor := GetAdaptor(meta.ActualModel)
 	if adaptor == nil {
 		return "", nil, nil, errors.New("adaptor not found")
 	}
8 changes: 4 additions & 4 deletions service/aiproxy/relay/adaptor/aws/claude/main.go
@@ -93,7 +93,7 @@ func awsModelID(requestModel string) (string, error) {
 }
 
 func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
-	awsModelID, err := awsModelID(meta.ActualModelName)
+	awsModelID, err := awsModelID(meta.ActualModel)
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
@@ -138,7 +138,7 @@ func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode,
 	}
 
 	openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
-	openaiResp.Model = meta.OriginModelName
+	openaiResp.Model = meta.OriginModel
 	usage := relaymodel.Usage{
 		PromptTokens:     claudeResponse.Usage.InputTokens,
 		CompletionTokens: claudeResponse.Usage.OutputTokens,
@@ -153,8 +153,8 @@ func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode,
 func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 	log := middleware.GetLogger(c)
 	createdTime := time.Now().Unix()
-	originModelName := meta.OriginModelName
-	awsModelID, err := awsModelID(meta.ActualModelName)
+	originModelName := meta.OriginModel
+	awsModelID, err := awsModelID(meta.ActualModel)
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/aws/llama3/adapter.go
@@ -24,7 +24,7 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, ht
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	meta.Set("stream", request.Stream)
 	llamaReq := ConvertRequest(request)
 	meta.Set(ConvertedRequest, llamaReq)
8 changes: 4 additions & 4 deletions service/aiproxy/relay/adaptor/aws/llama3/main.go
@@ -94,7 +94,7 @@ func ConvertRequest(textRequest *relaymodel.GeneralOpenAIRequest) *Request {
 }
 
 func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
-	awsModelID, err := awsModelID(meta.ActualModelName)
+	awsModelID, err := awsModelID(meta.ActualModel)
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
@@ -132,7 +132,7 @@ func Handler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatusCode,
 	}
 
 	openaiResp := ResponseLlama2OpenAI(&llamaResponse)
-	openaiResp.Model = meta.OriginModelName
+	openaiResp.Model = meta.OriginModel
 	usage := relaymodel.Usage{
 		PromptTokens:     llamaResponse.PromptTokenCount,
 		CompletionTokens: llamaResponse.GenerationTokenCount,
@@ -171,7 +171,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
 	log := middleware.GetLogger(c)
 
 	createdTime := time.Now().Unix()
-	awsModelID, err := awsModelID(meta.ActualModelName)
+	awsModelID, err := awsModelID(meta.ActualModel)
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}
@@ -231,7 +231,7 @@ func StreamHandler(meta *meta.Meta, c *gin.Context) (*relaymodel.ErrorWithStatus
 		}
 		response := StreamResponseLlama2OpenAI(&llamaResp)
 		response.ID = "chatcmpl-" + random.GetUUID()
-		response.Model = meta.OriginModelName
+		response.Model = meta.OriginModel
 		response.Created = createdTime
 		err = render.ObjectData(c, response)
 		if err != nil {
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/azure/constants.go
@@ -20,7 +20,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 	if err != nil {
 		return "", err
 	}
-	model := strings.ReplaceAll(meta.ActualModelName, ".", "")
+	model := strings.ReplaceAll(meta.ActualModel, ".", "")
 	switch meta.Mode {
 	case relaymode.ImagesGenerations:
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
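Only the field rename changes here, but the line it touches is worth a note: the actual model name has its dots stripped before being used to build the Azure deployment URL. A standalone illustration of that normalisation (not the repository's code) follows.

package main

import (
	"fmt"
	"strings"
)

// azureDeploymentName reproduces the normalisation shown above: dots are
// removed, so "gpt-3.5-turbo" maps to a deployment named "gpt-35-turbo".
func azureDeploymentName(actualModel string) string {
	return strings.ReplaceAll(actualModel, ".", "")
}

func main() {
	fmt.Println(azureDeploymentName("gpt-3.5-turbo")) // gpt-35-turbo
	fmt.Println(azureDeploymentName("gpt-4o"))        // gpt-4o (unchanged)
}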
4 changes: 2 additions & 2 deletions service/aiproxy/relay/adaptor/baidu/adaptor.go
@@ -64,9 +64,9 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		pathSuffix = "text2image"
 	}
 
-	modelEndpoint, ok := modelEndpointMap[meta.ActualModelName]
-	if !ok {
-		modelEndpoint = strings.ToLower(meta.ActualModelName)
+	modelEndpoint, ok := modelEndpointMap[meta.ActualModel]
+	if !ok {
+		modelEndpoint = strings.ToLower(meta.ActualModel)
 	}
 
 	// Construct full URL
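For Baidu, the endpoint is resolved from a lookup table keyed by the actual model name, with a fallback to the lower-cased model name when no explicit entry exists. A self-contained sketch of that pattern is below; the map entry is a hypothetical example, not copied from modelEndpointMap.

package main

import (
	"fmt"
	"strings"
)

// Hypothetical entry for illustration only.
var modelEndpointMap = map[string]string{
	"ERNIE-4.0-8K": "completions_pro",
}

// endpointFor prefers an explicit mapping and otherwise falls back to the
// lower-cased model name, mirroring the lookup in GetRequestURL above.
func endpointFor(actualModel string) string {
	if endpoint, ok := modelEndpointMap[actualModel]; ok {
		return endpoint
	}
	return strings.ToLower(actualModel)
}

func main() {
	fmt.Println(endpointFor("ERNIE-4.0-8K"))   // completions_pro
	fmt.Println(endpointFor("ERNIE-Speed-8K")) // ernie-speed-8k
}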
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/baidu/embeddings.go
@@ -41,7 +41,7 @@ func EmbeddingsHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*r
 	if err != nil {
 		return &baiduResponse.Usage, openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
-	respMap["model"] = meta.OriginModelName
+	respMap["model"] = meta.OriginModel
 	respMap["object"] = "list"
 
 	data, err := json.Marshal(respMap)
6 changes: 3 additions & 3 deletions service/aiproxy/relay/adaptor/baidu/main.go
@@ -46,7 +46,7 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	baiduRequest := ChatRequest{
 		Messages:    request.Messages,
 		Temperature: request.Temperature,
@@ -117,7 +117,7 @@ func streamResponseBaidu2OpenAI(meta *meta.Meta, baiduResponse *ChatStreamRespon
 		ID:      baiduResponse.ID,
 		Object:  "chat.completion.chunk",
 		Created: baiduResponse.Created,
-		Model:   meta.OriginModelName,
+		Model:   meta.OriginModel,
 		Choices: []*openai.ChatCompletionsStreamResponseChoice{&choice},
 		Usage:   baiduResponse.Usage,
 	}
@@ -185,7 +185,7 @@ func Handler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage
 		return nil, openai.ErrorWrapperWithMessage(baiduResponse.Error.ErrorMsg, "baidu_error_"+strconv.Itoa(baiduResponse.Error.ErrorCode), http.StatusInternalServerError)
 	}
 	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
-	fullTextResponse.Model = meta.OriginModelName
+	fullTextResponse.Model = meta.OriginModel
 	jsonResponse, err := json.Marshal(fullTextResponse)
 	if err != nil {
 		return nil, openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
6 changes: 3 additions & 3 deletions service/aiproxy/relay/adaptor/baiduv2/adaptor.go
@@ -61,11 +61,11 @@ func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, _ *gin.Context, req *http.
 func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
 	switch meta.Mode {
 	case relaymode.ChatCompletions:
-		actModel := meta.ActualModelName
+		actModel := meta.ActualModel
 		v2Model := toV2ModelName(actModel)
 		if v2Model != actModel {
-			meta.ActualModelName = v2Model
-			defer func() { meta.ActualModelName = actModel }()
+			meta.ActualModel = v2Model
+			defer func() { meta.ActualModel = actModel }()
 		}
 		return openai.ConvertRequest(meta, req)
 	default:
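The baiduv2 adaptor temporarily rewrites meta.ActualModel to the v2 endpoint name and relies on defer to restore the original value once the converted request has been built. A standalone sketch of that save-and-restore pattern, with a placeholder toV2ModelName mapping, is shown below.

package main

import "fmt"

type Meta struct{ ActualModel string }

// toV2ModelName is a placeholder; the real mapping lives in the baiduv2 package.
func toV2ModelName(model string) string { return model + "-v2" }

// convert swaps in the v2 name for the duration of the call and restores the
// original via defer, the same shape as ConvertRequest in the hunk above.
func convert(meta *Meta) string {
	actModel := meta.ActualModel
	v2Model := toV2ModelName(actModel)
	if v2Model != actModel {
		meta.ActualModel = v2Model
		defer func() { meta.ActualModel = actModel }()
	}
	return fmt.Sprintf("request built for %q", meta.ActualModel)
}

func main() {
	m := &Meta{ActualModel: "ernie-4.0"}
	fmt.Println(convert(m))    // request built for "ernie-4.0-v2"
	fmt.Println(m.ActualModel) // ernie-4.0 (restored)
}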
4 changes: 2 additions & 2 deletions service/aiproxy/relay/adaptor/cloudflare/adaptor.go
@@ -43,9 +43,9 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
 		return urlPrefix + "/v1/embeddings", nil
 	default:
 		if isAIGateWay {
-			return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModelName), nil
+			return fmt.Sprintf("%s/%s", urlPrefix, meta.ActualModel), nil
 		}
-		return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModelName), nil
+		return fmt.Sprintf("%s/run/%s", urlPrefix, meta.ActualModel), nil
 	}
 }
 
4 changes: 2 additions & 2 deletions service/aiproxy/relay/adaptor/cohere/adaptor.go
@@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, ht
 	if err != nil {
 		return "", nil, nil, err
 	}
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	requestBody := ConvertRequest(request)
 	if requestBody == nil {
 		return "", nil, nil, errors.New("request body is nil")
@@ -62,7 +62,7 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons
 		if utils.IsStreamResponse(resp) {
 			err, usage = StreamHandler(c, resp)
 		} else {
-			err, usage = Handler(c, resp, meta.InputTokens, meta.ActualModelName)
+			err, usage = Handler(c, resp, meta.InputTokens, meta.ActualModel)
 		}
 	}
 	return
8 changes: 4 additions & 4 deletions service/aiproxy/relay/adaptor/coze/adaptor.go
@@ -51,11 +51,11 @@ func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, ht
 		return "", nil, nil, err
 	}
 	request.User = userID
-	request.Model = meta.ActualModelName
+	request.Model = meta.ActualModel
 	cozeRequest := Request{
 		Stream: request.Stream,
 		User:   request.User,
-		BotID:  strings.TrimPrefix(meta.ActualModelName, "bot-"),
+		BotID:  strings.TrimPrefix(meta.ActualModel, "bot-"),
 	}
 	for i, message := range request.Messages {
 		if i == len(request.Messages)-1 {
@@ -84,10 +84,10 @@ func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Respons
 	if utils.IsStreamResponse(resp) {
 		err, responseText = StreamHandler(c, resp)
 	} else {
-		err, responseText = Handler(c, resp, meta.InputTokens, meta.ActualModelName)
+		err, responseText = Handler(c, resp, meta.InputTokens, meta.ActualModel)
 	}
 	if responseText != nil {
-		usage = openai.ResponseText2Usage(*responseText, meta.ActualModelName, meta.InputTokens)
+		usage = openai.ResponseText2Usage(*responseText, meta.ActualModel, meta.InputTokens)
 	} else {
 		usage = &relaymodel.Usage{}
 	}
2 changes: 1 addition & 1 deletion service/aiproxy/relay/adaptor/doubao/main.go
@@ -17,7 +17,7 @@ func GetRequestURL(meta *meta.Meta) (string, error) {
 	}
 	switch meta.Mode {
 	case relaymode.ChatCompletions:
-		if strings.HasPrefix(meta.ActualModelName, "bot-") {
+		if strings.HasPrefix(meta.ActualModel, "bot-") {
 			return u + "/api/v3/bots/chat/completions", nil
 		}
 		return u + "/api/v3/chat/completions", nil
4 changes: 2 additions & 2 deletions service/aiproxy/relay/adaptor/gemini/adaptor.go
@@ -27,10 +27,10 @@ func getRequestURL(meta *meta.Meta, action string) string {
 		u = baseURL
 	}
 	version := "v1beta"
-	if _, ok := v1ModelMap[meta.ActualModelName]; ok {
+	if _, ok := v1ModelMap[meta.ActualModel]; ok {
 		version = "v1"
 	}
-	return fmt.Sprintf("%s/%s/models/%s:%s", u, version, meta.ActualModelName, action)
+	return fmt.Sprintf("%s/%s/models/%s:%s", u, version, meta.ActualModel, action)
 }
 
 func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
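getRequestURL picks the API version from v1ModelMap (defaulting to v1beta) and formats the model URL from it. The sketch below shows the same shape in isolation; the base URL and the empty v1ModelMap are assumptions made for the example, not values from the adaptor.

package main

import "fmt"

// Assume no models are pinned to v1 for this example; the real map lives in
// the gemini adaptor package.
var v1ModelMap = map[string]struct{}{}

func requestURL(base, actualModel, action string) string {
	version := "v1beta"
	if _, ok := v1ModelMap[actualModel]; ok {
		version = "v1"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s", base, version, actualModel, action)
}

func main() {
	fmt.Println(requestURL("https://generativelanguage.googleapis.com",
		"gemini-1.5-flash", "generateContent"))
	// => https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent
}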