diff --git a/plugins/wasm-go/extensions/ai-proxy/README.md b/plugins/wasm-go/extensions/ai-proxy/README.md index 9d7dbc796b..8317f653d4 100644 --- a/plugins/wasm-go/extensions/ai-proxy/README.md +++ b/plugins/wasm-go/extensions/ai-proxy/README.md @@ -148,7 +148,15 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。 #### 文心一言(Baidu) -文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。 +文心一言所对应的 `type` 为 `baidu`。它特有的配置字段如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|--------------------|-----------------|------|-----|-----------------------------------------------------------| +| `baiduAccessKeyAndSecret` | array of string | 必填 | - | Baidu 的 Access Key 和 Secret Key,每个元素的格式为 `AccessKey:SecretKey`(中间用 `:` 分隔),用于申请 apiToken。 | +| `baiduApiTokenServiceName` | string | 必填 | - | 请求刷新百度 apiToken 服务名称。 | +| `baiduApiTokenServiceHost` | string | 非必填 | iam.bj.baidubce.com | 请求刷新百度 apiToken 服务域名。 | +| `baiduApiTokenServicePort` | int64 | 非必填 | 443 | 请求刷新百度 apiToken 服务端口。 | + #### 360智脑 diff --git a/plugins/wasm-go/extensions/ai-proxy/config/config.go b/plugins/wasm-go/extensions/ai-proxy/config/config.go index 48f08dd9e4..a510115f40 100644 --- a/plugins/wasm-go/extensions/ai-proxy/config/config.go +++ b/plugins/wasm-go/extensions/ai-proxy/config/config.go @@ -86,6 +86,11 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { providerConfig := c.GetProviderConfig() err = providerConfig.SetApiTokensFailover(log, c.activeProvider) + if handler, ok := c.activeProvider.(provider.TickFuncHandler); ok { + tickPeriod, tickFunc := handler.GetTickFunc(log) + wrapper.RegisteTickFunc(tickPeriod, tickFunc) + } + return err } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index 42a1bc723d..0908836290 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -1,48 +1,53 @@ package provider import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" "net/http" + 
"sort" "strings" "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" ) -// baiduProvider is the provider for baidu ernie bot service. - +// baiduProvider is the provider for baidu service. const ( - baiduDomain = "aip.baidubce.com" - baiduChatCompletionPath = "/chat" + baiduDomain = "qianfan.baidubce.com" + baiduChatCompletionPath = "/v2/chat/completions" + baiduApiTokenDomain = "iam.bj.baidubce.com" + baiduApiTokenPort = 443 + baiduApiTokenPath = "/v1/BCE-BEARER/token" + // refresh apiToken every 1 hour + baiduApiTokenRefreshInterval = 3600 + // authorizationString expires in 30 minutes, authorizationString is used to generate apiToken + // the default expiration time of apiToken is 24 hours + baiduAuthorizationStringExpirationSeconds = 1800 + bce_prefix = "x-bce-" ) -var baiduModelToPathSuffixMap = map[string]string{ - "ERNIE-4.0-8K": "completions_pro", - "ERNIE-3.5-8K": "completions", - "ERNIE-3.5-128K": "ernie-3.5-128k", - "ERNIE-Speed-8K": "ernie_speed", - "ERNIE-Speed-128K": "ernie-speed-128k", - "ERNIE-Tiny-8K": "ernie-tiny-8k", - "ERNIE-Bot-8K": "ernie_bot_8k", - "BLOOMZ-7B": "bloomz_7b1", -} - -type baiduProviderInitializer struct { -} +type baiduProviderInitializer struct{} -func (b *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error { - if config.apiTokens == nil || len(config.apiTokens) == 0 { - return errors.New("no apiToken found in provider config") +func (g *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error { + if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 { + return errors.New("no baiduAccessKeyAndSecret found in provider config") + } + if config.baiduApiTokenServiceName == "" { + return errors.New("no baiduApiTokenServiceName found in provider config") + } + if 
!config.failover.enabled { + config.useGlobalApiToken = true } return nil } -func (b *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { return &baiduProvider{ config: config, contextCache: createContextCache(&config), @@ -54,234 +59,235 @@ type baiduProvider struct { contextCache *contextCache } -func (b *baiduProvider) GetProviderType() string { +func (g *baiduProvider) GetProviderType() string { return providerTypeBaidu } -func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { +func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { if apiName != ApiNameChatCompletion { return types.ActionContinue, errUnsupportedApiName } - b.config.handleRequestHeaders(b, ctx, apiName, log) - // Delay the header processing to allow changing streaming mode in OnRequestBody - return types.HeaderStopIteration, nil + g.config.handleRequestHeaders(g, ctx, apiName, log) + return types.ActionContinue, nil } -func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { +func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { + if apiName != ApiNameChatCompletion { + return types.ActionContinue, errUnsupportedApiName + } + return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log) +} + +func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { + util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath) util.OverwriteRequestHostHeader(headers, baiduDomain) - headers.Del("Accept-Encoding") + util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx)) 
headers.Del("Content-Length") } -func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - if apiName != ApiNameChatCompletion { - return types.ActionContinue, errUnsupportedApiName +func (g *baiduProvider) GetApiName(path string) ApiName { + if strings.Contains(path, baiduChatCompletionPath) { + return ApiNameChatCompletion } - return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log) + return "" } -func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { - request := &chatCompletionRequest{} - err := b.config.parseRequestAndMapModel(ctx, request, body, log) - if err != nil { - return nil, err +func generateAuthorizationString(accessKeyAndSecret string, expirationInSeconds int) string { + c := strings.Split(accessKeyAndSecret, ":") + credentials := BceCredentials{ + AccessKeyId: c[0], + SecretAccessKey: c[1], + } + httpMethod := "GET" + path := baiduApiTokenPath + headers := map[string]string{"host": baiduApiTokenDomain} + timestamp := time.Now().Unix() + + headersToSign := make([]string, 0, len(headers)) + for k := range headers { + headersToSign = append(headersToSign, k) } - path := b.getRequestPath(ctx, request.Model) - util.OverwriteRequestPathHeader(headers, path) - baiduRequest := b.baiduTextGenRequest(request) - return json.Marshal(baiduRequest) + return sign(credentials, httpMethod, path, headers, timestamp, expirationInSeconds, headersToSign) } -func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { - // 使用文心一言接口协议,跳过OnStreamingResponseBody()和OnResponseBody() - if b.config.protocol == protocolOriginal { - ctx.DontReadResponseBody() - return types.ActionContinue, nil - } - - _ = proxywasm.RemoveHttpResponseHeader("Content-Length") - return types.ActionContinue, nil +// BceCredentials 
holds the access key and secret key +type BceCredentials struct { + AccessKeyId string + SecretAccessKey string } -func (b *baiduProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { - if isLastChunk || len(chunk) == 0 { - return nil, nil +// normalizeString performs URI encoding according to RFC 3986 +func normalizeString(inStr string, encodingSlash bool) string { + if inStr == "" { + return "" } - // sample event response: - // data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089502,"sentence_id":0,"is_end":false,"is_truncated":false,"result":"当然可以,","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}} - - // sample end event response: - // data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089531,"sentence_id":20,"is_end":true,"is_truncated":false,"result":"","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":420,"total_tokens":425}} - responseBuilder := &strings.Builder{} - lines := strings.Split(string(chunk), "\n") - for _, data := range lines { - if len(data) < 6 { - // ignore blank line or wrong format - continue - } - data = data[6:] - var baiduResponse baiduTextGenStreamResponse - if err := json.Unmarshal([]byte(data), &baiduResponse); err != nil { - log.Errorf("unable to unmarshal baidu response: %v", err) - continue - } - response := b.streamResponseBaidu2OpenAI(ctx, &baiduResponse) - responseBody, err := json.Marshal(response) - if err != nil { - log.Errorf("unable to marshal response: %v", err) - return nil, err + + var result strings.Builder + for _, ch := range []byte(inStr) { + if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || + (ch >= '0' && ch <= '9') || ch == '.' 
|| ch == '-' || + ch == '_' || ch == '~' || (!encodingSlash && ch == '/') { + result.WriteByte(ch) + } else { + result.WriteString(fmt.Sprintf("%%%02X", ch)) } - b.appendResponse(responseBuilder, string(responseBody)) } - modifiedResponseChunk := responseBuilder.String() - log.Debugf("=== modified response chunk: %s", modifiedResponseChunk) - return []byte(modifiedResponseChunk), nil + return result.String() } -func (b *baiduProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { - baiduResponse := &baiduTextGenResponse{} - if err := json.Unmarshal(body, baiduResponse); err != nil { - return types.ActionContinue, fmt.Errorf("unable to unmarshal baidu response: %v", err) +// getCanonicalTime generates a timestamp in UTC format +func getCanonicalTime(timestamp int64) string { + if timestamp == 0 { + timestamp = time.Now().Unix() } - if baiduResponse.ErrorMsg != "" { - return types.ActionContinue, fmt.Errorf("baidu response error, error_code: %d, error_message: %s", baiduResponse.ErrorCode, baiduResponse.ErrorMsg) - } - response := b.responseBaidu2OpenAI(ctx, baiduResponse) - return types.ActionContinue, replaceJsonResponseBody(response, log) + t := time.Unix(timestamp, 0).UTC() + return t.Format("2006-01-02T15:04:05Z") } -type baiduTextGenRequest struct { - Model string `json:"model"` - Messages []chatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - PenaltyScore float64 `json:"penalty_score,omitempty"` - Stream bool `json:"stream,omitempty"` - System string `json:"system,omitempty"` - DisableSearch bool `json:"disable_search,omitempty"` - EnableCitation bool `json:"enable_citation,omitempty"` - MaxOutputTokens int `json:"max_output_tokens,omitempty"` - UserId string `json:"user_id,omitempty"` +// getCanonicalUri generates a canonical URI +func getCanonicalUri(path string) string { + return normalizeString(path, false) } -func (b 
*baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string { - // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t - suffix, ok := baiduModelToPathSuffixMap[baiduModel] - if !ok { - suffix = baiduModel +// getCanonicalHeaders generates canonical headers +func getCanonicalHeaders(headers map[string]string, headersToSign []string) string { + if len(headers) == 0 { + return "" } - return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx)) -} -func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) { - request.System = content -} + // If headersToSign is not specified, use default headers + if len(headersToSign) == 0 { + headersToSign = []string{"host", "content-md5", "content-length", "content-type"} + } -func (b *baiduProvider) baiduTextGenRequest(request *chatCompletionRequest) *baiduTextGenRequest { - baiduRequest := baiduTextGenRequest{ - Messages: make([]chatMessage, 0, len(request.Messages)), - Temperature: request.Temperature, - TopP: request.TopP, - PenaltyScore: request.FrequencyPenalty, - Stream: request.Stream, - DisableSearch: false, - EnableCitation: false, - MaxOutputTokens: request.MaxTokens, - UserId: request.User, + // Convert headersToSign to a map for easier lookup + headerMap := make(map[string]bool) + for _, header := range headersToSign { + headerMap[strings.ToLower(strings.TrimSpace(header))] = true } - for _, message := range request.Messages { - if message.Role == roleSystem { - baiduRequest.System = message.StringContent() - } else { - baiduRequest.Messages = append(baiduRequest.Messages, chatMessage{ - Role: message.Role, - Content: message.Content, - }) + + // Create a slice to hold the canonical headers + var canonicalHeaders []string + for k, v := range headers { + k = strings.ToLower(strings.TrimSpace(k)) + v = strings.TrimSpace(v) + + // Add headers that start with x-bce- or are in headersToSign + if 
strings.HasPrefix(k, bce_prefix) || headerMap[k] { + canonicalHeaders = append(canonicalHeaders, + fmt.Sprintf("%s:%s", normalizeString(k, true), normalizeString(v, true))) } } - return &baiduRequest -} - -type baiduTextGenResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Result string `json:"result"` - IsTruncated bool `json:"is_truncated"` - NeedClearHistory bool `json:"need_clear_history"` - Usage baiduTextGenResponseUsage `json:"usage"` - baiduTextGenResponseError -} -type baiduTextGenResponseError struct { - ErrorCode int `json:"error_code"` - ErrorMsg string `json:"error_msg"` -} + // Sort the canonical headers + sort.Strings(canonicalHeaders) -type baiduTextGenStreamResponse struct { - baiduTextGenResponse - SentenceId int `json:"sentence_id"` - IsEnd bool `json:"is_end"` + return strings.Join(canonicalHeaders, "\n") } -type baiduTextGenResponseUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` +// sign generates the authorization string +func sign(credentials BceCredentials, httpMethod, path string, headers map[string]string, + timestamp int64, expirationInSeconds int, + headersToSign []string) string { + + // Generate sign key + signKeyInfo := fmt.Sprintf("bce-auth-v1/%s/%s/%d", + credentials.AccessKeyId, + getCanonicalTime(timestamp), + expirationInSeconds) + + // Generate sign key using HMAC-SHA256 + h := hmac.New(sha256.New, []byte(credentials.SecretAccessKey)) + h.Write([]byte(signKeyInfo)) + signKey := hex.EncodeToString(h.Sum(nil)) + + // Generate canonical URI + canonicalUri := getCanonicalUri(path) + + // Generate canonical headers + canonicalHeaders := getCanonicalHeaders(headers, headersToSign) + + // Generate string to sign + stringToSign := strings.Join([]string{ + httpMethod, + canonicalUri, + "", + canonicalHeaders, + }, "\n") + + // Calculate final signature + h = hmac.New(sha256.New, 
[]byte(signKey)) + h.Write([]byte(stringToSign)) + signature := hex.EncodeToString(h.Sum(nil)) + + // Generate final authorization string + if len(headersToSign) > 0 { + return fmt.Sprintf("%s/%s/%s", signKeyInfo, strings.Join(headersToSign, ";"), signature) + } + return fmt.Sprintf("%s//%s", signKeyInfo, signature) } -func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenResponse) *chatCompletionResponse { - choice := chatCompletionChoice{ - Index: 0, - Message: &chatMessage{Role: roleAssistant, Content: response.Result}, - FinishReason: finishReasonStop, - } - return &chatCompletionResponse{ - Id: response.Id, - Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), - SystemFingerprint: "", - Object: objectChatCompletion, - Choices: []chatCompletionChoice{choice}, - Usage: usage{ - PromptTokens: response.Usage.PromptTokens, - CompletionTokens: response.Usage.CompletionTokens, - TotalTokens: response.Usage.TotalTokens, - }, +// GetTickFunc returns the tick period and the callback that periodically refreshes the apiToken; the maximum apiToken expiration time is 24 hours +func (g *baiduProvider) GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) { + vmID := generateVMID() + + return baiduApiTokenRefreshInterval * 1000, func() { + // Only the Wasm VM that successfully acquires the lease will refresh the apiToken + if g.config.tryAcquireOrRenewLease(vmID, log) { + log.Debugf("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v", vmID) + // Get the apiToken that is about to expire, will be removed after the new apiToken is obtained + oldApiTokens, _, err := getApiTokens(g.config.failover.ctxApiTokens) + if err != nil { + log.Errorf("Get old apiToken failed: %v", err) + return + } + log.Debugf("Old apiTokens: %v", oldApiTokens) + + for _, accessKeyAndSecret := range g.config.baiduAccessKeyAndSecret { + authorizationString := generateAuthorizationString(accessKeyAndSecret, 
baiduAuthorizationStringExpirationSeconds) + log.Debugf("Generate authorizationString: %v", authorizationString) + g.generateNewApiToken(authorizationString, log) + } + + // remove the old apiTokens + for _, token := range oldApiTokens { + log.Debugf("Remove old apiToken: %v", token) + removeApiToken(g.config.failover.ctxApiTokens, token, log) + } + } } } -func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenStreamResponse) *chatCompletionResponse { - choice := chatCompletionChoice{ - Index: 0, - Message: &chatMessage{Role: roleAssistant, Content: response.Result}, - } - if response.IsEnd { - choice.FinishReason = finishReasonStop - } - return &chatCompletionResponse{ - Id: response.Id, - Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), - SystemFingerprint: "", - Object: objectChatCompletionChunk, - Choices: []chatCompletionChoice{choice}, - Usage: usage{ - PromptTokens: response.Usage.PromptTokens, - CompletionTokens: response.Usage.CompletionTokens, - TotalTokens: response.Usage.TotalTokens, - }, +func (g *baiduProvider) generateNewApiToken(authorizationString string, log wrapper.Log) { + client := wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: g.config.baiduApiTokenServiceName, + Host: g.config.baiduApiTokenServiceHost, + Port: g.config.baiduApiTokenServicePort, + }) + + headers := [][2]string{ + {"content-type", "application/json"}, + {"Authorization", authorizationString}, } -func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) { - responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody)) -} + var apiToken string + err := client.Get(baiduApiTokenPath, headers, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode == 201 { + var response map[string]interface{} + err := json.Unmarshal(responseBody, &response) + if err != nil { + log.Errorf("Unmarshal 
response failed: %v", err) + } else { + apiToken = response["token"].(string) + addApiToken(g.config.failover.ctxApiTokens, apiToken, log) + } + } else { + log.Errorf("Get apiToken failed, status code: %d, response body: %s", statusCode, string(responseBody)) + } + }, 30000) -func (b *baiduProvider) GetApiName(path string) ApiName { - if strings.Contains(path, baiduChatCompletionPath) { - return ApiNameChatCompletion + if err != nil { + log.Errorf("Get apiToken failed: %v", err) } - return "" } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go index 32e92a4db4..56b03fbd72 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/failover.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/failover.go @@ -467,7 +467,7 @@ func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log.Errorf("failed to get failureApiTokenRequestCount: %v", err) } if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok { - log.Infof("reset apiToken %s request failure count", apiTokenInUse) + log.Infof("Reset apiToken %s request failure count", apiTokenInUse) resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log) } } @@ -489,7 +489,7 @@ func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount) if err != nil { - log.Errorf("failed to marshal apiTokenRequestCount: %v", err) + log.Errorf("Failed to marshal apiTokenRequestCount: %v", err) } if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil { @@ -551,7 +551,7 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { var apiToken string - if c.isFailoverEnabled() { + if c.isFailoverEnabled() || c.useGlobalApiToken { // if enable apiToken failover, only use available apiToken 
apiToken = c.GetGlobalRandomToken(log) } else { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index f74805c912..9799974bae 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -151,6 +151,12 @@ type ResponseBodyHandler interface { OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) } +// TickFuncHandler allows the provider to execute a function periodically +// Use case: the maximum expiration time of baidu apiToken is 24 hours, need to refresh periodically +type TickFuncHandler interface { + GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) +} + type ProviderConfig struct { // @Title zh-CN ID // @Description zh-CN AI服务提供商标识 @@ -227,6 +233,17 @@ type ProviderConfig struct { // @Title zh-CN 自定义大模型参数配置 // @Description zh-CN 用于填充或者覆盖大模型调用时的参数 customSettings []CustomSetting + // @Title zh-CN Baidu 的 Access Key 和 Secret Key,中间用 : 分隔,用于申请 apiToken + baiduAccessKeyAndSecret []string `required:"false" yaml:"baiduAccessKeyAndSecret" json:"baiduAccessKeyAndSecret"` + // @Title zh-CN 请求刷新百度 apiToken 服务名称 + baiduApiTokenServiceName string `required:"false" yaml:"baiduApiTokenServiceName" json:"baiduApiTokenServiceName"` + // @Title zh-CN 请求刷新百度 apiToken 服务域名 + baiduApiTokenServiceHost string `required:"false" yaml:"baiduApiTokenServiceHost" json:"baiduApiTokenServiceHost"` + // @Title zh-CN 请求刷新百度 apiToken 服务端口 + baiduApiTokenServicePort int64 `required:"false" yaml:"baiduApiTokenServicePort" json:"baiduApiTokenServicePort"` + // @Title zh-CN 是否使用全局的 apiToken + // @Description zh-CN 如果没有启用 apiToken failover,但是 apiToken 的状态又需要在多个 Wasm VM 中同步时需要将该参数设置为 true,例如 Baidu 的 apiToken 需要定时刷新 + useGlobalApiToken bool `required:"false" yaml:"useGlobalApiToken" json:"useGlobalApiToken"` } func (c *ProviderConfig) GetId() string { @@ -321,6 +338,19 @@ func 
(c *ProviderConfig) FromJson(json gjson.Result) { if failoverJson.Exists() { c.failover.FromJson(failoverJson) } + + for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() { + c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String()) + } + c.baiduApiTokenServiceName = json.Get("baiduApiTokenServiceName").String() + c.baiduApiTokenServiceHost = json.Get("baiduApiTokenServiceHost").String() + if c.baiduApiTokenServiceHost == "" { + c.baiduApiTokenServiceHost = baiduApiTokenDomain + } + c.baiduApiTokenServicePort = json.Get("baiduApiTokenServicePort").Int() + if c.baiduApiTokenServicePort == 0 { + c.baiduApiTokenServicePort = baiduApiTokenPort + } } func (c *ProviderConfig) Validate() error {