diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 60773f71c0..35d06b9502 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -108,6 +108,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { // Set the apiToken for the current request. providerConfig.SetApiTokenInUse(ctx, log) + // Set available apiTokens of current request in the context, will be used in the retryOnFailure + providerConfig.SetAvailableApiTokens(ctx, log) err := handler.OnRequestHeaders(ctx, apiName, log) if err != nil { @@ -179,6 +181,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo providerConfig := pluginConfig.GetProviderConfig() apiTokenInUse := providerConfig.GetApiTokenInUse(ctx) + apiTokens := providerConfig.GetAvailableApiToken(ctx) status, err := proxywasm.GetHttpResponseHeader(":status") if err != nil || status != "200" { @@ -186,7 +189,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo log.Errorf("unable to load :status header from response: %v", err) } ctx.DontReadResponseBody() - return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log) + return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, apiTokens, log) } // Reset ctxApiTokenRequestFailureCount if the request is successful, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 6c8259949b..9644693f5e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -32,6 +32,8 @@ type failover struct { healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"` // @Title zh-CN 本次请求使用的 apiToken ctxApiTokenInUse string + // @Title zh-CN 记录本次请求时所有可用的 apiToken + ctxAvailableApiTokensInRequest string // @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数 ctxApiTokenRequestFailureCount string // @Title zh-CN 记录 apiToken 健康检测成功的次数,key 为 apiToken,value 为成功次数 @@ -527,6 +529,22 @@ func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string { } } +func (c *ProviderConfig) GetAvailableApiToken(ctx wrapper.HttpContext) []string { + apiTokens, _ := ctx.GetContext(c.failover.ctxAvailableApiTokensInRequest).([]string) + return apiTokens +} + +// SetAvailableApiTokens set available apiTokens of current request in the context, will be used in the retryOnFailure +func (c *ProviderConfig) SetAvailableApiTokens(ctx wrapper.HttpContext, log wrapper.Log) { + var apiTokens []string + if c.isFailoverEnabled() { + apiTokens, _, _ = getApiTokens(c.failover.ctxApiTokens) + } else { + apiTokens = c.apiTokens + } + ctx.SetContext(c.failover.ctxAvailableApiTokensInRequest, apiTokens) +} + func (c *ProviderConfig) isFailoverEnabled() bool { return c.failover.enabled } @@ -539,12 +557,12 @@ func (c *ProviderConfig) resetSharedData() { _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0) } -func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) types.Action { +func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, log wrapper.Log) types.Action { if c.isFailoverEnabled() { c.handleUnavailableApiToken(ctx, apiTokenInUse, log) } if c.isRetryOnFailureEnabled() && ctx.GetContext(ctxKeyIsStreaming) != nil && !ctx.GetContext(ctxKeyIsStreaming).(bool) { - c.retryFailedRequest(activeProvider, ctx, log) + c.retryFailedRequest(activeProvider, ctx, apiTokenInUse, apiTokens, log) return types.HeaderStopAllIterationAndWatermark } return types.ActionContinue diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/retry.go b/plugins/wasm-go/extensions/ai-proxy/provider/retry.go index 033a8cd8c5..59691d855f 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/retry.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/retry.go @@ -1,11 +1,13 @@ package provider import ( + "math/rand" + "net/http" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/tidwall/gjson" - "net/http" ) const ( @@ -38,12 +40,12 @@ func (c *ProviderConfig) isRetryOnFailureEnabled() bool { return c.retryOnFailure.enabled } -func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, log wrapper.Log) { +func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, log wrapper.Log) { log.Debugf("Retry failed request: provider=%s", activeProvider.GetProviderType()) retryClient := createRetryClient(ctx) apiName, _ := ctx.GetContext(CtxKeyApiName).(ApiName) ctx.SetContext(ctxRetryCount, 1) - c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log) + c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens, log) } func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte, log wrapper.Log) ([][2]string, []byte) { @@ -67,7 +69,8 @@ func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext func (c *ProviderConfig) retryCall( ctx wrapper.HttpContext, log wrapper.Log, activeProvider Provider, apiName ApiName, statusCode int, responseHeaders http.Header, responseBody []byte, - retryClient *wrapper.ClusterClient[wrapper.RouteCluster]) { + retryClient *wrapper.ClusterClient[wrapper.RouteCluster], + apiTokenInUse string, apiTokens []string) { retryCount := ctx.GetContext(ctxRetryCount).(int) log.Debugf("Sent retry request: %d/%d", retryCount, c.retryOnFailure.maxRetries) @@ -76,6 +79,7 @@ func (c *ProviderConfig) retryCall( log.Debugf("Retry request succeeded") headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody, log) proxywasm.SendHttpResponse(200, headers, body, -1) + return } else { log.Debugf("The retry request still failed, status: %d, responseHeaders: %v, responseBody: %s", statusCode, responseHeaders, string(responseBody)) } @@ -83,26 +87,41 @@ func (c *ProviderConfig) retryCall( retryCount++ if retryCount <= int(c.retryOnFailure.maxRetries) { ctx.SetContext(ctxRetryCount, retryCount) - c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log) + c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens, log) } else { log.Debugf("Reached the maximum retry count: %d", c.retryOnFailure.maxRetries) proxywasm.ResumeHttpResponse() + return } } func (c *ProviderConfig) sendRetryRequest( ctx wrapper.HttpContext, apiName ApiName, activeProvider Provider, - retryClient *wrapper.ClusterClient[wrapper.RouteCluster], log wrapper.Log) { + retryClient *wrapper.ClusterClient[wrapper.RouteCluster], + apiTokenInUse string, apiTokens []string, log wrapper.Log) { + + // Remove last failed token from retry apiTokens list + apiTokens = removeApiTokenFromRetryList(apiTokens, apiTokenInUse, log) + if len(apiTokens) == 0 { + log.Debugf("No more apiTokens to retry") + proxywasm.ResumeHttpResponse() + return + } + // Set apiTokenInUse for the retry request + apiTokenInUse = GetRandomToken(apiTokens) + log.Debugf("Retry request with apiToken: %s", apiTokenInUse) + ctx.SetContext(c.failover.ctxApiTokenInUse, apiTokenInUse) requestHeaders, requestBody := c.getRetryRequestHeadersAndBody(ctx, activeProvider, apiName, log) path := getRetryPath(ctx) err := retryClient.Post(path, util.HeaderToSlice(requestHeaders), requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient) + c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient, apiTokenInUse, apiTokens) }, uint32(c.retryOnFailure.retryTimeout)) if err != nil { log.Errorf("Failed to send retry request: %v", err) proxywasm.ResumeHttpResponse() + return } } @@ -126,9 +145,7 @@ func getRetryPath(ctx wrapper.HttpContext) string { } func (c *ProviderConfig) getRetryRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, log wrapper.Log) (http.Header, []byte) { - // The retry request may be sent with different apiToken, so the header needs to be regenerated - c.SetApiTokenInUse(ctx, log) - + // The retry request is sent with different apiToken, so the header needs to be regenerated requestHeaders := http.Header{ "Content-Type": []string{"application/json"}, } @@ -139,3 +156,27 @@ func (c *ProviderConfig) getRetryRequestHeadersAndBody(ctx wrapper.HttpContext, return requestHeaders, requestBody } + +func removeApiTokenFromRetryList(apiTokens []string, removedApiToken string, log wrapper.Log) []string { + var availableApiTokens []string + for _, s := range apiTokens { + if s != removedApiToken { + availableApiTokens = append(availableApiTokens, s) + } + } + log.Debugf("Remove apiToken %s from retry apiTokens list", removedApiToken) + log.Debugf("Available retry apiTokens: %v", availableApiTokens) + return availableApiTokens +} + +func GetRandomToken(apiTokens []string) string { + count := len(apiTokens) + switch count { + case 0: + return "" + case 1: + return apiTokens[0] + default: + return apiTokens[rand.Intn(count)] + } +}