generated from songquanpeng/gin-template
-
-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* feat: initial support of Dall-E * fix: fix N not timed --------- Co-authored-by: JustSong <[email protected]> Co-authored-by: JustSong <[email protected]>
- Loading branch information
1 parent
81c5901
commit b520b54
Showing
7 changed files
with
191 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,181 @@ | ||
package controller | ||
|
||
import ( | ||
"github.com/gin-gonic/gin" | ||
"bytes" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"one-api/common" | ||
"one-api/model" | ||
|
||
"github.com/gin-gonic/gin" | ||
) | ||
|
||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | ||
// TODO: this part is not finished | ||
req, err := http.NewRequest(c.Request.Method, c.Request.RequestURI, c.Request.Body) | ||
imageModel := "dall-e" | ||
|
||
tokenId := c.GetInt("token_id") | ||
channelType := c.GetInt("channel") | ||
userId := c.GetInt("id") | ||
consumeQuota := c.GetBool("consume_quota") | ||
group := c.GetString("group") | ||
|
||
var imageRequest ImageRequest | ||
if consumeQuota { | ||
err := common.UnmarshalBodyReusable(c, &imageRequest) | ||
if err != nil { | ||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) | ||
} | ||
} | ||
|
||
// Prompt validation | ||
if imageRequest.Prompt == "" { | ||
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) | ||
} | ||
|
||
// Not "256x256", "512x512", or "1024x1024" | ||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { | ||
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) | ||
} | ||
|
||
// N should between 1 and 10 | ||
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { | ||
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) | ||
} | ||
|
||
// map model name | ||
modelMapping := c.GetString("model_mapping") | ||
isModelMapped := false | ||
if modelMapping != "" { | ||
modelMap := make(map[string]string) | ||
err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||
if err != nil { | ||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) | ||
} | ||
if modelMap[imageModel] != "" { | ||
imageModel = modelMap[imageModel] | ||
isModelMapped = true | ||
} | ||
} | ||
|
||
baseURL := common.ChannelBaseURLs[channelType] | ||
requestURL := c.Request.URL.String() | ||
|
||
if c.GetString("base_url") != "" { | ||
baseURL = c.GetString("base_url") | ||
} | ||
|
||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | ||
|
||
var requestBody io.Reader | ||
if isModelMapped { | ||
jsonStr, err := json.Marshal(imageRequest) | ||
if err != nil { | ||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | ||
} | ||
requestBody = bytes.NewBuffer(jsonStr) | ||
} else { | ||
requestBody = c.Request.Body | ||
} | ||
|
||
modelRatio := common.GetModelRatio(imageModel) | ||
groupRatio := common.GetGroupRatio(group) | ||
ratio := modelRatio * groupRatio | ||
userQuota, err := model.CacheGetUserQuota(userId) | ||
|
||
sizeRatio := 1.0 | ||
// Size | ||
if imageRequest.Size == "256x256" { | ||
sizeRatio = 1 | ||
} else if imageRequest.Size == "512x512" { | ||
sizeRatio = 1.125 | ||
} else if imageRequest.Size == "1024x1024" { | ||
sizeRatio = 1.25 | ||
} | ||
quota := int(ratio*sizeRatio*1000) * imageRequest.N | ||
|
||
if consumeQuota && userQuota-quota < 0 { | ||
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) | ||
} | ||
|
||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | ||
if err != nil { | ||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | ||
} | ||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | ||
|
||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | ||
req.Header.Set("Accept", c.Request.Header.Get("Accept")) | ||
|
||
client := &http.Client{} | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return errorWrapper(err, "do_request_failed", http.StatusOK) | ||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) | ||
} | ||
|
||
err = req.Body.Close() | ||
if err != nil { | ||
return errorWrapper(err, "close_request_body_failed", http.StatusOK) | ||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||
} | ||
err = c.Request.Body.Close() | ||
if err != nil { | ||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) | ||
} | ||
var textResponse ImageResponse | ||
|
||
defer func() { | ||
if consumeQuota { | ||
err := model.PostConsumeTokenQuota(tokenId, quota) | ||
if err != nil { | ||
common.SysError("error consuming token remain quota: " + err.Error()) | ||
} | ||
err = model.CacheUpdateUserQuota(userId) | ||
if err != nil { | ||
common.SysError("error update user quota cache: " + err.Error()) | ||
} | ||
if quota != 0 { | ||
tokenName := c.GetString("token_name") | ||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) | ||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) | ||
channelId := c.GetInt("channel_id") | ||
model.UpdateChannelUsedQuota(channelId, quota) | ||
} | ||
} | ||
}() | ||
|
||
if consumeQuota { | ||
responseBody, err := io.ReadAll(resp.Body) | ||
|
||
if err != nil { | ||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | ||
} | ||
err = resp.Body.Close() | ||
if err != nil { | ||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||
} | ||
err = json.Unmarshal(responseBody, &textResponse) | ||
if err != nil { | ||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | ||
} | ||
|
||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | ||
} | ||
|
||
for k, v := range resp.Header { | ||
c.Writer.Header().Set(k, v[0]) | ||
} | ||
c.Writer.WriteHeader(resp.StatusCode) | ||
|
||
_, err = io.Copy(c.Writer, resp.Body) | ||
if err != nil { | ||
return errorWrapper(err, "copy_response_body_failed", http.StatusOK) | ||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | ||
} | ||
err = resp.Body.Close() | ||
if err != nil { | ||
return errorWrapper(err, "close_response_body_failed", http.StatusOK) | ||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters