Skip to content

Commit

Permalink
fix special charactor handle in ai-security-guard plugin (#1394)
Browse files Browse the repository at this point in the history
  • Loading branch information
rinfx authored Oct 18, 2024
1 parent 49bb5ec commit 32e5a59
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions plugins/wasm-go/extensions/ai-security-guard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64)

func urlEncoding(rawStr string) string {
encodedStr := url.PathEscape(rawStr)
encodedStr = strings.ReplaceAll(encodedStr, "+", "%20")
encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B")
encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
Expand All @@ -106,7 +106,7 @@ func getSign(params map[string]string, secret string) string {
})
canonicalStr := strings.Join(paramArray, "&")
signStr := "POST&%2F&" + urlEncoding(canonicalStr)
// proxywasm.LogInfo(signStr)
proxywasm.LogDebugf("String to sign is: %s", signStr)
return hmacSha1(signStr, secret)
}

Expand Down Expand Up @@ -196,10 +196,11 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log

func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking request body...")
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
content := gjson.GetBytes(body, config.requestContentJsonPath).Raw
model := gjson.GetBytes(body, "model").Raw
ctx.SetContext("requestModel", model)
if content != "" {
log.Debugf("Raw response content is: %s", content)
if len(content) > 0 {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
Expand All @@ -212,7 +213,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.requestCheckService,
"ServiceParameters": `{"content": "` + content + `"}`,
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
Expand Down Expand Up @@ -339,7 +340,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
} else {
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
content = gjson.GetBytes(body, config.responseContentJsonPath).Raw
}
log.Debugf("Raw response content is: %s", content)
if len(content) > 0 {
Expand All @@ -355,7 +356,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.responseCheckService,
"ServiceParameters": `{"content": "` + content + `"}`,
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
Expand Down Expand Up @@ -400,10 +401,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
jsonData = []byte(denyMessage)
} else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String(), randomID, model))
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
} else {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, respAdvice.Array()[0].Get("Answer").String()))
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
Expand Down Expand Up @@ -432,10 +433,10 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
strChunks := []string{}
for _, chunk := range chunks {
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() {
strChunks = append(strChunks, jsonObj.String())
jsonRaw := gjson.GetBytes(chunk, jsonPath).Raw
if len(jsonRaw) > 2 {
strChunks = append(strChunks, jsonRaw[1:len(jsonRaw)-1])
}
}
return strings.Join(strChunks, "")
return fmt.Sprintf(`"%s"`, strings.Join(strChunks, ""))
}

0 comments on commit 32e5a59

Please sign in to comment.