From aac266c497881712d91d9b2899622f31b9233bce Mon Sep 17 00:00:00 2001 From: mxdlzg Date: Sat, 27 Apr 2024 11:26:18 +0800 Subject: [PATCH] Refactor Gemini adaptor to support streaming content generation --- .github/workflows/docker-image-amd64.yml | 2 +- relay/adaptor/gemini/adaptor.go | 7 +++--- relay/adaptor/gemini/main.go | 32 ++++++++++-------------- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index 1b9983c608..be818d4e9a 100644 --- a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -29,7 +29,7 @@ jobs: - name: Save version info run: | - git describe --tags > VERSION + git describe --tags > VERSION && cat VERSION - name: Log in to Docker Hub uses: docker/login-action@v2 diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index 839e45d6b5..a4dcae931a 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -3,6 +3,9 @@ package gemini import ( "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -10,8 +13,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) type Adaptor struct { @@ -25,7 +26,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) action := "generateContent" if meta.IsStream { - action = "streamGenerateContent" + action = "streamGenerateContent?alt=sse" } return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil } diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 8b934d305d..f1b4855179 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -232,8 +232,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" - dataChan := make(chan string) - stopChan := make(chan bool) scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -247,14 +245,16 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) + dataChan := make(chan string) + stopChan := make(chan bool) go func() { for scanner.Scan() { data := scanner.Text() data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "\"text\": \"") { + if !strings.HasPrefix(data, "data: ") { continue } - data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimPrefix(data, "data: ") data = strings.TrimSuffix(data, "\"") dataChan <- data } @@ -264,23 +264,17 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - // this is used to prevent annoying \ related format bug - data = fmt.Sprintf("{\"content\": \"%s\"}", data) - type dummyStruct struct { - Content string `json:"content"` + var geminiResponse ChatResponse + err := json.Unmarshal([]byte(data), &geminiResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true } - var dummy dummyStruct - err := json.Unmarshal([]byte(data), &dummy) - responseText += dummy.Content - var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = dummy.Content - response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), - Object: "chat.completion.chunk", - Created: helper.GetTimestamp(), - Model: "gemini-pro", - Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + response := streamResponseGeminiChat2OpenAI(&geminiResponse) + if response == nil { + return true } + responseText += fmt.Sprintf("%v", response.Choices[0].Delta.Content) jsonResponse, err := json.Marshal(response) if err != nil { logger.SysError("error marshalling stream response: " + err.Error())