Skip to content

Commit

Permalink
fix token billing error
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Dec 26, 2023
1 parent 6487a48 commit 13c1295
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 13 deletions.
6 changes: 2 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ RUN wget https://nodejs.org/dist/v16.14.0/node-v16.14.0-linux-x64.tar.xz && \

ENV PATH=$PATH:/usr/local/go/bin:/usr/local/node/bin

# Install npm
RUN npm install -g pnpm


# Copy source code
COPY . .
Expand All @@ -40,7 +37,8 @@ RUN go install && \
go build .

# Build frontend
RUN cd /app && \
RUN npm install -g pnpm && \
cd /app && \
pnpm install && \
pnpm run build && \
rm -rf node_modules
Expand Down
12 changes: 8 additions & 4 deletions manager/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ import (
const defaultMessage = "Sorry, I don't understand. Please try again."
const defaultQuotaMessage = "You don't have enough quota or you don't have permission to use this model. please [buy](/buy) or [subscribe](/subscribe) to get more."

func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool) {
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
db := utils.GetDBFromContext(c)
quota := buffer.GetQuota()
if buffer.IsEmpty() || buffer.GetCharge().IsBillingType(globals.TimesBilling) {
if buffer.IsEmpty() {
return
} else if buffer.GetCharge().IsBillingType(globals.TimesBilling) && err != nil {
// billing type is times, but error occurred
return
}

// collect quota for tokens billing (though error occurred) or times billing
if !uncountable && quota > 0 && user != nil {
user.UseQuota(db, quota)
}
Expand Down Expand Up @@ -115,15 +119,15 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))

auth.RevertSubscriptionUsage(db, cache, user, model)
CollectQuota(conn.GetCtx(), user, buffer, plan)
CollectQuota(conn.GetCtx(), user, buffer, plan, err)
conn.Send(globals.ChatSegmentResponse{
Message: err.Error(),
End: true,
})
return err.Error()
}

CollectQuota(conn.GetCtx(), user, buffer, plan)
CollectQuota(conn.GetCtx(), user, buffer, plan, err)

if buffer.IsEmpty() {
conn.Send(globals.ChatSegmentResponse{
Expand Down
4 changes: 2 additions & 2 deletions manager/completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
admin.AnalysisRequest(model, buffer, err)
if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, model)
CollectQuota(c, user, buffer, plan)
CollectQuota(c, user, buffer, plan, err)
return err.Error(), 0
}

CollectQuota(c, user, buffer, plan)
CollectQuota(c, user, buffer, plan, err)

SaveCacheData(c, &CacheProps{
Message: segment,
Expand Down
4 changes: 2 additions & 2 deletions manager/transhipment.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
return
}

CollectQuota(c, user, buffer, plan)
CollectQuota(c, user, buffer, plan, err)
c.JSON(http.StatusOK, TranshipmentResponse{
Id: fmt.Sprintf("chatcmpl-%s", id),
Object: "chat.completion",
Expand Down Expand Up @@ -266,7 +266,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
}

partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil)
CollectQuota(c, user, buffer, plan)
CollectQuota(c, user, buffer, plan, err)
close(partial)
return
}()
Expand Down
6 changes: 5 additions & 1 deletion middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ func AuthMiddleware() gin.HandlerFunc {
instance := ProcessAuthorization(c)

if viper.GetBool("serve_static") {
path = strings.TrimPrefix(path, "/api")
if !strings.HasPrefix(path, "/api") {
return
} else {
path = strings.TrimPrefix(path, "/api")
}
}

db := utils.GetDBFromContext(c)
Expand Down

0 comments on commit 13c1295

Please sign in to comment.