diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 4a2d1fb98c..92ec3940f3 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -15,6 +15,7 @@ import ( "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) const ( @@ -149,7 +150,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig ) return types.ActionContinue } - + // Default setting include_usage. + if gjson.GetBytes(body, "stream").Bool() { + var err error + newBody, err = sjson.SetBytes(newBody, "stream_options.include_usage", true) + if err != nil { + log.Errorf("set include_usage failed, err:%s", err) + } + } log.Debugf("[onHttpRequestBody] newBody=%s", newBody) body = newBody action, err := handler.OnRequestBody(ctx, apiName, body, log) diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go index 2def57aa62..f875dbaa40 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/openai.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/openai.go @@ -127,21 +127,14 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, } func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) { - request := &chatCompletionRequest{} - if err := decodeChatCompletionRequest(body, request); err != nil { - return nil, err - } if m.config.responseJsonSchema != nil { + request := &chatCompletionRequest{} + if err := decodeChatCompletionRequest(body, request); err != nil { + return nil, err + } log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema) request.ResponseFormat = m.config.responseJsonSchema + body, _ = json.Marshal(request) } - if request.Stream { - // For stream requests, we need to include usage in the response. - if request.StreamOptions == nil { - request.StreamOptions = &streamOptions{IncludeUsage: true} - } else if !request.StreamOptions.IncludeUsage { - request.StreamOptions.IncludeUsage = true - } - } - return json.Marshal(request) + return m.config.defaultTransformRequestBody(ctx, apiName, body, log) }