Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement apiToken failover mechanism #1256

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
095b25e
feat: implement apiToken failover mechanism
cr7258 Aug 27, 2024
4af200c
Use SetSharedData for leader election and syncing apiTokens between W…
cr7258 Aug 31, 2024
192d855
Merge branch 'main' into failover
cr7258 Sep 1, 2024
856343c
support failover for all models
cr7258 Sep 1, 2024
7d5f427
add cas retry logic
cr7258 Sep 7, 2024
ee49848
wrap getApiTokenInUse funtion
cr7258 Sep 7, 2024
1e40d82
only removed the apiToken when the number of consecutive request fail…
cr7258 Sep 25, 2024
432395b
use uuid as vmid
cr7258 Sep 25, 2024
67551f2
fix byte covert
cr7258 Sep 26, 2024
82b2284
reset shared data during initialization
cr7258 Sep 26, 2024
daa48fe
Merge branch 'main' into failover
cr7258 Sep 26, 2024
8a818ed
failover support new model
cr7258 Sep 26, 2024
0554c85
fix
cr7258 Sep 26, 2024
e3401d5
move SetApiTokensFailover to complete function
cr7258 Sep 28, 2024
0f79913
wrap failover logic into ProviderConfig
cr7258 Sep 28, 2024
bda87f1
fix
cr7258 Sep 28, 2024
263c38c
config envoy local cluster and isolate apiToken ctx between different…
cr7258 Oct 5, 2024
374d5be
update README.md
cr7258 Oct 7, 2024
fd49f2d
add description
cr7258 Oct 7, 2024
66c371b
fix nil point exception when don't set failover config
cr7258 Oct 7, 2024
2130c00
Merge branch 'main' into failover
cr7258 Oct 7, 2024
7f36c09
support github provider
cr7258 Oct 7, 2024
01b92d8
fix
cr7258 Oct 10, 2024
a11a38b
Merge branch 'main' into failover
cr7258 Oct 10, 2024
01b0eec
unified the transformation of HTTP headers and body for ai-proxy and …
cr7258 Oct 17, 2024
a180e65
fix readme
cr7258 Oct 17, 2024
a72a8a1
optimize
cr7258 Oct 17, 2024
6a62333
refine transform headers and body
cr7258 Nov 3, 2024
f1f375e
move defaultInsertHttpContextMessage to context.go
cr7258 Nov 3, 2024
0296110
fix
cr7258 Nov 5, 2024
f164854
remove get context in original protocol
cr7258 Nov 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ description: AI 代理插件配置参考

`provider`的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------| --------------- | -------- | ------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------|
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟 |
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |

`context`的配置字段说明如下:

Expand Down Expand Up @@ -75,6 +76,16 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,而不对参数名称做任何限制和修改。
对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。

`failover` 的配置字段说明如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|------|-------|-----------------------------|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
| healthCheckModel | string | 必填 | | 健康检测使用的模型 |

### 提供商特有配置

Expand Down
10 changes: 7 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/config/config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package config

import (
"github.com/tidwall/gjson"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)

// @Name ai-proxy
Expand Down Expand Up @@ -75,13 +75,17 @@ func (c *PluginConfig) Validate() error {
return nil
}

func (c *PluginConfig) Complete() error {
func (c *PluginConfig) Complete(log wrapper.Log) error {
if c.activeProviderConfig == nil {
c.activeProvider = nil
return nil
}
var err error
c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig)

providerConfig := c.GetProviderConfig()
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)

return err
}

Expand Down
21 changes: 18 additions & 3 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}

return nil
}

Expand All @@ -59,9 +60,10 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}

return nil
}

Expand All @@ -88,8 +90,11 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
ctx.SetContext(ctxKeyApiName, apiName)

if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)

hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
Expand All @@ -101,6 +106,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
return action
}

_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
}
Expand Down Expand Up @@ -155,15 +161,24 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo

log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())

providerConfig := pluginConfig.GetProviderConfig()
apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)

status, err := proxywasm.GetHttpResponseHeader(":status")
if err != nil || status != "200" {
if err != nil {
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)

return types.ActionContinue
}

// Reset ctxApiTokenRequestFailureCount if the request is successful,
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)

if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
Expand Down
7 changes: 3 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/ai360.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

// ai360Provider is the provider for 360 OpenAI service.
Expand Down Expand Up @@ -49,7 +48,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
_ = util.OverwriteRequestHost(ai360Domain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetApiTokenInUse(ctx))
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (m *azureProvider) GetProviderType() string {
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestPath(m.serviceUrl.RequestURI())
_ = util.OverwriteRequestHost(m.serviceUrl.Host)
_ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0])
_ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.GetApiTokenInUse(ctx))
if apiName == ApiNameChatCompletion {
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
} else {
Expand Down
3 changes: 1 addition & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package provider
import (
"errors"
"fmt"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
Expand Down Expand Up @@ -49,7 +48,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
}
_ = util.OverwriteRequestPath(baichuanChatCompletionPath)
_ = util.OverwriteRequestHost(baichuanDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetApiTokenInUse(ctx))
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
}
Expand Down
8 changes: 4 additions & 4 deletions plugins/wasm-go/extensions/ai-proxy/provider/baidu.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
return types.ActionContinue, errors.New("request model is empty")
}
// 根据模型重写requestPath
path := b.getRequestPath(request.Model)
path := b.getRequestPath(ctx, request.Model)
_ = util.OverwriteRequestPath(path)

if b.config.context == nil {
Expand Down Expand Up @@ -126,7 +126,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := b.getRequestPath(mappedModel)
path := b.getRequestPath(ctx, mappedModel)
_ = util.OverwriteRequestPath(path)

if b.config.context == nil {
Expand Down Expand Up @@ -226,13 +226,13 @@ type baiduTextGenRequest struct {
UserId string `json:"user_id,omitempty"`
}

func (b *baiduProvider) getRequestPath(baiduModel string) string {
func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
if !ok {
suffix = baiduModel
}
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken())
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx))
}

func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
Expand Down
110 changes: 35 additions & 75 deletions plugins/wasm-go/extensions/ai-proxy/provider/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -105,102 +106,46 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil
}

_ = util.OverwriteRequestPath(claudeChatCompletionPath)
_ = util.OverwriteRequestHost(claudeDomain)
_ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken())
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteHttpRequestPath(headers, claudeChatCompletionPath)
util.OverwriteHttpRequestHost(headers, claudeDomain)

headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))

if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion
}
_ = proxywasm.AddHttpRequestHeader("anthropic-version", c.config.claudeVersion)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")

return types.ActionContinue, nil
headers.Add("anthropic-version", c.config.claudeVersion)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}

func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}

// use original protocol
if c.config.protocol == protocolOriginal {
if c.config.context == nil {
return types.ActionContinue, nil
}

request := &claudeTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}

err := c.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()

if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}

// use openai protocol
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}

model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, c.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
err := c.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)

streaming := request.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
}

if c.config.context == nil {
claudeRequest := c.buildClaudeTextGenRequest(request)
return types.ActionContinue, replaceJsonRequestBody(claudeRequest, log)
}

err := c.contextCache.GetContent(func(content string, err error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

context逻辑

defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
claudeRequest := c.buildClaudeTextGenRequest(request)
if err := replaceJsonRequestBody(claudeRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
claudeRequest := c.buildClaudeTextGenRequest(request)
return json.Marshal(claudeRequest)
}

func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
Expand Down Expand Up @@ -369,3 +314,18 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG
func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}

func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
request := &claudeTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
}

if request.System == "" {
request.System = content
} else {
request.System = content + "\n" + request.System
}

return json.Marshal(request)
}
Loading
Loading