Skip to content

Commit

Permalink
bugfix: plugin will block GET request (#1428)
Browse files Browse the repository at this point in the history
  • Loading branch information
rinfx authored Oct 24, 2024
1 parent e7561c3 commit d952fa5
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions plugins/wasm-go/extensions/ai-security-guard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,6 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
log.Debugf("request checking is disabled")
ctx.DontReadRequestBody()
}
if !config.checkResponse {
log.Debugf("response checking is disabled")
ctx.DontReadResponseBody()
}
return types.ActionContinue
}

Expand All @@ -199,7 +195,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
content := gjson.GetBytes(body, config.requestContentJsonPath).Raw
model := gjson.GetBytes(body, "model").Raw
ctx.SetContext("requestModel", model)
log.Debugf("Raw response content is: %s", content)
log.Debugf("Raw request content is: %s", content)
if len(content) > 0 {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
Expand Down Expand Up @@ -321,6 +317,11 @@ func reconvertHeaders(hs map[string][]string) [][2]string {
}

func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkResponse {
log.Debugf("response checking is disabled")
ctx.DontReadResponseBody()
return types.ActionContinue
}
headers, err := proxywasm.GetHttpResponseHeaders()
if err != nil {
log.Warnf("failed to get response headers: %v", err)
Expand Down Expand Up @@ -399,7 +400,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
var jsonData []byte
if config.protocolOriginal {
jsonData = []byte(denyMessage)
} else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
} else if isStreamingResponse {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
} else {
Expand Down

0 comments on commit d952fa5

Please sign in to comment.