diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index ca59f7f6a1..5b61589616 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -187,10 +187,6 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log log.Debugf("request checking is disabled") ctx.DontReadRequestBody() } - if !config.checkResponse { - log.Debugf("response checking is disabled") - ctx.DontReadResponseBody() - } return types.ActionContinue } @@ -199,7 +195,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] content := gjson.GetBytes(body, config.requestContentJsonPath).Raw model := gjson.GetBytes(body, "model").Raw ctx.SetContext("requestModel", model) - log.Debugf("Raw response content is: %s", content) + log.Debugf("Raw request content is: %s", content) if len(content) > 0 { timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") randomID, _ := generateHexID(16) @@ -321,6 +317,11 @@ func reconvertHeaders(hs map[string][]string) [][2]string { } func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { + if !config.checkResponse { + log.Debugf("response checking is disabled") + ctx.DontReadResponseBody() + return types.ActionContinue + } headers, err := proxywasm.GetHttpResponseHeaders() if err != nil { log.Warnf("failed to get response headers: %v", err) @@ -399,7 +400,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ var jsonData []byte if config.protocolOriginal { jsonData = []byte(denyMessage) - } else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") { + } else if isStreamingResponse { randomID := generateRandomID() jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model)) } else {