From 4f7bfbdb584960bf95e4cd73a78b40da910bacfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BE=84=E6=BD=AD?= Date: Wed, 31 Jul 2024 17:48:38 +0800 Subject: [PATCH 01/71] fix bugs --- .../extensions/ai-cache/cache/cache.go | 93 ++++++++ .../extensions/ai-cache/config/config.go | 69 ++++++ .../ai-cache/embedding/dashscope.go | 190 +++++++++++++++ .../extensions/ai-cache/embedding/provider.go | 76 ++++++ plugins/wasm-go/extensions/ai-cache/go.mod | 1 - plugins/wasm-go/extensions/ai-cache/main.go | 221 ++++-------------- .../extensions/ai-cache/util/cachelogic.go | 211 +++++++++++++++++ .../ai-cache/vectorDatabase/dashvector.go | 153 ++++++++++++ .../ai-cache/vectorDatabase/provider.go | 124 ++++++++++ 9 files changed, 957 insertions(+), 181 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-cache/cache/cache.go create mode 100644 plugins/wasm-go/extensions/ai-cache/config/config.go create mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go create mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/provider.go create mode 100644 plugins/wasm-go/extensions/ai-cache/util/cachelogic.go create mode 100644 plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go create mode 100644 plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go diff --git a/plugins/wasm-go/extensions/ai-cache/cache/cache.go b/plugins/wasm-go/extensions/ai-cache/cache/cache.go new file mode 100644 index 0000000000..e4d1f1e81f --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/cache/cache.go @@ -0,0 +1,93 @@ +// TODO: 在这里写缓存的具体逻辑, 将textEmbeddingPrvider和vectorStoreProvider作为逻辑中的一个函数调用 +package cache + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +type RedisConfig struct { + // @Title zh-CN redis 服务名称 + // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local + RedisServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` + // @Title zh-CN redis 服务端口 + // @Description zh-CN 默认值为6379 + RedisServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` + // @Title zh-CN 用户名 + // @Description zh-CN 登陆 redis 的用户名,非必填 + RedisUsername string `required:"false" yaml:"username" json:"username"` + // @Title zh-CN 密码 + // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 + RedisPassword string `required:"false" yaml:"password" json:"password"` + // @Title zh-CN 请求超时 + // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 + RedisTimeout uint32 `required:"false" yaml:"timeout" json:"timeout"` +} + +func CreateProvider(cf RedisConfig) (Provider, error) { + rp := redisProvider{ + config: cf, + client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: cf.RedisServiceName, + Port: int64(cf.RedisServicePort)}), + } + err := rp.Init(cf.RedisUsername, cf.RedisPassword, cf.RedisTimeout) + return &rp, err +} + +func (c *RedisConfig) FromJson(json gjson.Result) { + c.RedisUsername = json.Get("username").String() + c.RedisPassword = json.Get("password").String() + c.RedisTimeout = uint32(json.Get("timeout").Int()) + c.RedisServiceName = json.Get("serviceName").String() + c.RedisServicePort = int(json.Get("servicePort").Int()) +} + +func (c *RedisConfig) Validate() error { + if len(c.RedisServiceName) == 0 { + return errors.New("serviceName is required") + } + if c.RedisTimeout <= 0 { + return errors.New("timeout must be greater than 0") + } + if c.RedisServicePort <= 0 { + c.RedisServicePort = 6379 + } + if len(c.RedisUsername) == 0 { + c.RedisUsername = "" + } + if len(c.RedisPassword) == 0 { + c.RedisPassword = "" + } + return nil +} + +type Provider interface { + GetProviderType() string + Init(username string, password string, timeout uint32) error + Get(key string, cb wrapper.RedisResponseCallback) + Set(key string, value string, cb wrapper.RedisResponseCallback) +} + +type redisProvider struct { + config RedisConfig + client wrapper.RedisClient +} + +func (rp *redisProvider) GetProviderType() string { + return "redis" +} + +func (rp *redisProvider) Init(username string, password string, timeout uint32) error { + return rp.client.Init(rp.config.RedisUsername, rp.config.RedisPassword, int64(rp.config.RedisTimeout)) +} + +func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) { + rp.client.Get(key, cb) +} + +func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) { + rp.client.Set(key, value, cb) +} diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go new file mode 100644 index 0000000000..378d956fc3 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -0,0 +1,69 @@ +package config + +import ( + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vectorDatabase" + "github.com/tidwall/gjson" +) + +type KVExtractor struct { + // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 + RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` + // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 + ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` +} + +type PluginConfig struct { + EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"` + VectorDatabaseProviderConfig vectorDatabase.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` + CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` + CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` + CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` + + CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` + RedisConfig cache.RedisConfig `required:"true" yaml:"redisConfig" json:"redisConfig"` + // 现在只支持RedisClient作为cacheClient + redisProvider cache.Provider `yaml:"-"` + embeddingProvider embedding.Provider `yaml:"-"` + vectorDatabaseProvider vectorDatabase.Provider `yaml:"-"` +} + +func (c *PluginConfig) FromJson(json gjson.Result) { + c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) + c.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + c.RedisConfig.FromJson(json.Get("redis")) +} + +func (c *PluginConfig) Validate() error { + if err := c.RedisConfig.Validate(); err != nil { + return err + } + if err := c.EmbeddingProviderConfig.Validate(); err != nil { + return err + } + if err := c.VectorDatabaseProviderConfig.Validate(); err != nil { + return err + } + return nil +} + +func (c *PluginConfig) Complete() error { + var err error + c.embeddingProvider, err = embedding.CreateProvider(c.EmbeddingProviderConfig) + c.vectorDatabaseProvider, err = vectorDatabase.CreateProvider(c.VectorDatabaseProviderConfig) + c.redisProvider, err = cache.CreateProvider(c.RedisConfig) + return err +} + +func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider { + return c.embeddingProvider +} + +func (c *PluginConfig) GetVectorDatabaseProvider() vectorDatabase.Provider { + return c.vectorDatabaseProvider +} + +func (c *PluginConfig) GetCacheProvider() cache.Provider { + return c.redisProvider +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go new file mode 100644 index 0000000000..f979e1d4ae --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -0,0 +1,190 @@ +package embedding + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" +) + +const ( + dashScopeDomain = "dashscope.aliyuncs.com" + dashScopePort = 443 +) + +type dashScopeProviderInitializer struct { +} + +func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.DashScopeKey) == 0 { + return errors.New("DashScopeKey is required") + } + if len(config.ServiceName) == 0 { + return errors.New("ServiceName is required") + } + return nil +} + +func (d *dashScopeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &DSProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.ServiceName, + Port: dashScopePort, + Domain: dashScopeDomain, + }), + }, nil +} + +func (d *DSProvider) GetProviderType() string { + return providerTypeDashScope +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type Input struct { + Texts []string `json:"texts"` +} + +type Params struct { + TextType string `json:"text_type"` +} + +type Response struct { + RequestID string `json:"request_id"` + Output Output `json:"output"` + Usage Usage `json:"usage"` +} + +type Output struct { + Embeddings []Embedding `json:"embeddings"` +} + +type Usage struct { + TotalTokens int `json:"total_tokens"` +} + +// EmbeddingRequest 定义请求的数据结构 +type EmbeddingRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Params `json:"parameters"` +} + +// Document 定义每个文档的结构 +type Document struct { + // ID string `json:"id"` + Vector []float64 `json:"vector"` + Fields map[string]string `json:"fields"` +} + +type DSProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { + const ( + endpoint = "/api/v1/services/embeddings/text-embedding/text-embedding" + modelName = "text-embedding-v1" + contentType = "application/json" + ) + + // 构造请求数据 + data := EmbeddingRequest{ + Model: modelName, + Input: Input{ + Texts: texts, + }, + Parameters: Params{ + TextType: "query", + }, + } + + // 序列化请求体并处理错误 + requestBody, err := json.Marshal(data) + if err != nil { + log.Errorf("Failed to marshal request data: %v", err) + return "", nil, nil, err + } + + // 检查 DashScopeKey 是否为空 + if d.config.DashScopeKey == "" { + err := errors.New("DashScopeKey is empty") + log.Errorf("Failed to construct headers: %v", err) + return "", nil, nil, err + } + + // 设置请求头 + headers := [][2]string{ + {"Authorization", "Bearer " + d.config.DashScopeKey}, + {"Content-Type", contentType}, + } + + return endpoint, headers, requestBody, err +} + +// Result 定义查询结果的结构 +type Result struct { + ID string `json:"id"` + Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化 + Fields map[string]interface{} `json:"fields"` + Score float64 `json:"score"` +} + +func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error) { + var resp Response + err := json.Unmarshal(responseBody, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +func (d *DSProvider) GetEmbedding( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte)) error { + + // 构建参数并处理错误 + Emb_url, Emb_headers, Emb_requestBody, err := d.constructParameters([]string{queryString}, log) + if err != nil { + log.Errorf("Failed to construct parameters: %v", err) + return err + } + + // 发起 POST 请求 + d.client.Post(Emb_url, Emb_headers, Emb_requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + defer proxywasm.ResumeHttpRequest() // 确保 HTTP 请求被恢复 + + // 日志记录响应 + log.Infof("Get embedding response: %d, %s", statusCode, responseBody) + + // 解析响应 + resp, err := d.parseTextEmbedding(responseBody) + if err != nil { + log.Errorf("Failed to parse response: %v", err) + callback(nil, statusCode, responseHeaders, responseBody) + return + } + + // 检查是否存在嵌入结果 + if len(resp.Output.Embeddings) == 0 { + log.Errorf("No embedding found in response") + callback(nil, statusCode, responseHeaders, responseBody) + return + } + + // 调用回调函数 + callback(resp.Output.Embeddings[0].Embedding, statusCode, responseHeaders, responseBody) + }) + + return nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go new file mode 100644 index 0000000000..ea0f58398c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -0,0 +1,76 @@ +package embedding + +import ( + "errors" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + providerTypeDashScope = "dashscope" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + providerTypeDashScope: &dashScopeProviderInitializer{}, + } +) + +type ProviderConfig struct { + // @Title zh-CN 文本特征提取服务提供者类型 + // @Description zh-CN 文本特征提取服务提供者类型,例如 DashScope + typ string `json:"TextEmbeddingProviderType"` + // @Title zh-CN DashScope 阿里云大模型服务名 + // @Description zh-CN 调用阿里云的大模型服务 + ServiceName string `require:"true" yaml:"DashScopeServiceName" jaon:"DashScopeServiceName"` + Client wrapper.HttpClient `yaml:"-"` + DashScopeKey string `require:"true" yaml:"DashScopeKey" jaon:"DashScopeKey"` + DashScopeTimeout uint32 `require:"true" yaml:"DashScopeTimeout" jaon:"DashScopeTimeout"` + QueryEmbeddingKey string `require:"true" yaml:"QueryEmbeddingKey" jaon:"QueryEmbeddingKey"` +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("TextEmbeddingProviderType").String() + c.ServiceName = json.Get("DashScopeServiceName").String() + c.DashScopeKey = json.Get("DashScopeKey").String() + c.DashScopeTimeout = uint32(json.Get("DashScopeTimeout").Int()) + c.QueryEmbeddingKey = json.Get("QueryEmbeddingKey").String() +} + +func (c *ProviderConfig) Validate() error { + if len(c.DashScopeKey) == 0 { + return errors.New("DashScopeKey is required") + } + if len(c.ServiceName) == 0 { + return errors.New("DashScopeServiceName is required") + } + return nil +} + +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} + +type Provider interface { + GetProviderType() string + GetEmbedding( + text string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte)) error +} diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod index c9630cfb8a..02d62a48e7 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.mod +++ b/plugins/wasm-go/extensions/ai-cache/go.mod @@ -4,7 +4,6 @@ module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache go 1.19 -replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441 diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index dc5df1a6a8..2334c1a777 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -1,32 +1,33 @@ -// File generated by hgctl. Modify as required. -// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 - +// 这个文件中主要将OnHttpRequestHeaders、OnHttpRequestBody、OnHttpResponseHeaders、OnHttpResponseBody这四个函数实现 +// 其中的缓存思路调用cache.go中的逻辑,然后cache.go中的逻辑会调用textEmbeddingProvider和vectorStoreProvider中的逻辑(实例) package main import ( - "errors" - "fmt" "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" - "github.com/tidwall/resp" ) const ( + pluginName = "ai-cache" CacheKeyContextKey = "cacheKey" CacheContentContextKey = "cacheContent" PartialMessageContextKey = "partialMessage" ToolCallsContextKey = "toolCalls" StreamContextKey = "stream" - DefaultCacheKeyPrefix = "higress-ai-cache:" + CacheKeyPrefix = "higressAiCache" + DefaultCacheKeyPrefix = "higressAiCache" + QueryEmbeddingKey = "queryEmbedding" ) func main() { wrapper.SetCtx( - "ai-cache", + pluginName, wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), @@ -35,143 +36,20 @@ func main() { ) } -// @Name ai-cache -// @Category protocol -// @Phase AUTHN -// @Priority 10 -// @Title zh-CN AI Cache -// @Description zh-CN 大模型结果缓存 -// @IconUrl -// @Version 0.1.0 -// -// @Contact.name johnlanni -// @Contact.url -// @Contact.email -// -// @Example -// redis: -// serviceName: my-redis.dns -// timeout: 2000 -// cacheKeyFrom: -// requestBody: "messages.@reverse.0.content" -// cacheValueFrom: -// responseBody: "choices.0.message.content" -// cacheStreamValueFrom: -// responseBody: "choices.0.delta.content" -// returnResponseTemplate: | -// {"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} -// returnStreamResponseTemplate: | -// data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} -// -// data:[DONE] -// -// @End - -type RedisInfo struct { - // @Title zh-CN redis 服务名称 - // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local - ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` - // @Title zh-CN redis 服务端口 - // @Description zh-CN 默认值为6379 - ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` - // @Title zh-CN 用户名 - // @Description zh-CN 登陆 redis 的用户名,非必填 - Username string `required:"false" yaml:"username" json:"username"` - // @Title zh-CN 密码 - // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 - Password string `required:"false" yaml:"password" json:"password"` - // @Title zh-CN 请求超时 - // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 - Timeout int `required:"false" yaml:"timeout" json:"timeout"` -} - -type KVExtractor struct { - // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` - // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` -} - -type PluginConfig struct { - // @Title zh-CN Redis 地址信息 - // @Description zh-CN 用于存储缓存结果的 Redis 地址 - RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"` - // @Title zh-CN 缓存 key 的来源 - // @Description zh-CN 往 redis 里存时,使用的 key 的提取方式 - CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` - // @Title zh-CN 缓存 value 的来源 - // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 - CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` - // @Title zh-CN 流式响应下,缓存 value 的来源 - // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 - CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` - // @Title zh-CN 返回 HTTP 响应的模版 - // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` - // @Title zh-CN 返回流式 HTTP 响应的模版 - // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"` - // @Title zh-CN 缓存的过期时间 - // @Description zh-CN 单位是秒,默认值为0,即永不过期 - CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` - // @Title zh-CN Redis缓存Key的前缀 - // @Description zh-CN 默认值是"higress-ai-cache:" - CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` - redisClient wrapper.RedisClient `yaml:"-" json:"-"` +func parseConfig(json gjson.Result, config *config.PluginConfig, log wrapper.Log) error { + config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) + config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + if err := config.Validate(); err != nil { + return err + } + return nil } -func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error { - c.RedisInfo.ServiceName = json.Get("redis.serviceName").String() - if c.RedisInfo.ServiceName == "" { - return errors.New("redis service name must not by empty") - } - c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int()) - if c.RedisInfo.ServicePort == 0 { - if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") { - // use default logic port which is 80 for static service - c.RedisInfo.ServicePort = 80 - } else { - c.RedisInfo.ServicePort = 6379 - } - } - c.RedisInfo.Username = json.Get("redis.username").String() - c.RedisInfo.Password = json.Get("redis.password").String() - c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int()) - if c.RedisInfo.Timeout == 0 { - c.RedisInfo.Timeout = 1000 - } - c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String() - if c.CacheKeyFrom.RequestBody == "" { - c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" - } - c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String() - if c.CacheValueFrom.ResponseBody == "" { - c.CacheValueFrom.ResponseBody = "choices.0.message.content" - } - c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String() - if c.CacheStreamValueFrom.ResponseBody == "" { - c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content" - } - c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String() - if c.ReturnResponseTemplate == "" { - c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - } - c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String() - if c.ReturnStreamResponseTemplate == "" { - c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" - } - c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String() - if c.CacheKeyPrefix == "" { - c.CacheKeyPrefix = DefaultCacheKeyPrefix - } - c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ - FQDN: c.RedisInfo.ServiceName, - Port: int64(c.RedisInfo.ServicePort), - }) - return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout)) +func TrimQuote(source string) string { + return strings.Trim(source, `"`) } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpRequestHeader("content-type") // The request does not have a body. if contentType == "" { @@ -188,11 +66,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap return types.HeaderStopIteration } -func TrimQuote(source string) string { - return strings.Trim(source, `"`) -} +func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body []byte, log wrapper.Log) types.Action { -func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { bodyJson := gjson.ParseBytes(body) // TODO: It may be necessary to support stream mode determination for different LLM providers. stream := false @@ -202,39 +77,21 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte } else if ctx.GetContext(StreamContextKey) != nil { stream = true } - key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) + // key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) + key := bodyJson.Get(config.CacheKeyFrom.RequestBody).String() if key == "" { - log.Debug("parse key from request body failed") - return types.ActionContinue - } - ctx.SetContext(CacheKeyContextKey, key) - err := config.redisClient.Get(config.CacheKeyPrefix+key, func(response resp.Value) { - if err := response.Error(); err != nil { - log.Errorf("redis get key:%s failed, err:%v", key, err) - proxywasm.ResumeHttpRequest() - return - } - if response.IsNull() { - log.Debugf("cache miss, key:%s", key) - proxywasm.ResumeHttpRequest() - return - } - log.Debugf("cache hit, key:%s", key) - ctx.SetContext(CacheKeyContextKey, nil) - if !stream { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1) - } else { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1) - } - }) - if err != nil { - log.Error("redis access failed") + log.Debug("[onHttpRquestBody] parse key from request body failed") return types.ActionContinue } + + queryString := config.CacheKeyPrefix + key + + util.RedisSearchHandler(queryString, ctx, config, log, stream, true) + return types.ActionPause } -func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { +func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseMessage string, log wrapper.Log) string { subMessages := strings.Split(sseMessage, "\n") var message string for _, msg := range subMessages { @@ -244,7 +101,7 @@ func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage } } if len(message) < 6 { - log.Errorf("invalid message:%s", message) + log.Warnf("invalid message:%s", message) return "" } // skip the prefix "data:" @@ -265,11 +122,11 @@ func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage ctx.SetContext(ToolCallsContextKey, struct{}{}) return "" } - log.Debugf("unknown message:%s", bodyJson) + log.Warnf("unknown message:%s", bodyJson) return "" } -func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { +func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { ctx.SetContext(StreamContextKey, struct{}{}) @@ -277,12 +134,14 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wra return types.ActionContinue } -func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { +func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { + // log.Infof("I am here") if ctx.GetContext(ToolCallsContextKey) != nil { // we should not cache tool call result return chunk } keyI := ctx.GetContext(CacheKeyContextKey) + // log.Infof("I am here 2: %v", keyI) if keyI == nil { return chunk } @@ -363,9 +222,11 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []by value = tempContentI.(string) } } - config.redisClient.Set(config.CacheKeyPrefix+key, value, nil) - if config.CacheTTL != 0 { - config.redisClient.Expire(config.CacheKeyPrefix+key, config.CacheTTL, nil) - } + log.Infof("[onHttpResponseBody] Setting cache to redis, key:%s, value:%s", key, value) + config.GetCacheProvider().Set(config.CacheKeyPrefix+key, value, nil) + // TODO: 要不要加个Expire方法 + // if config.RedisConfig.RedisTimeout != 0 { + // config.GetCacheProvider().Expire(config.CacheKeyPrefix+key, config.RedisConfig.RedisTimeout, nil) + // } return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go b/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go new file mode 100644 index 0000000000..95b8b6e1cf --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go @@ -0,0 +1,211 @@ +// TODO: 在这里写缓存的具体逻辑, 将textEmbeddingPrvider和vectorStoreProvider作为逻辑中的一个函数调用 +package util + +import ( + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vectorDatabase" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/tidwall/resp" +) + +// ===================== 以下是主要逻辑 ===================== +// 主handler函数,根据key从redis中获取value ,如果不命中,则首先调用文本向量化接口向量化query,然后调用向量搜索接口搜索最相似的出现过的key,最后再次调用redis获取结果 +// 可以把所有handler单独提取为文件,这里为了方便读者复制就和主逻辑放在一个文件中了 +// +// 1. query 进来和 redis 中存的 key 匹配 (redisSearchHandler) ,若完全一致则直接返回 (handleCacheHit) +// 2. 否则请求 text_embdding 接口将 query 转换为 query_embedding (fetchAndProcessEmbeddings) +// 3. 用 query_embedding 和向量数据库中的向量做 ANN search,返回最接近的 key ,并用阈值过滤 (performQueryAndRespond) +// 4. 若返回结果为空或大于阈值,舍去,本轮 cache 未命中, 最后将 query_embedding 存入向量数据库 (uploadQueryEmbedding) +// 5. 若小于阈值,则再次调用 redis对 most similar key 做匹配。 (redisSearchHandler) +// 7. 在 response 阶段请求 redis 新增key/LLM返回结果 + +func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) { + activeCacheProvider := config.GetCacheProvider() + activeCacheProvider.Get(config.CacheKeyPrefix+key, func(response resp.Value) { + if err := response.Error(); err == nil && !response.IsNull() { + log.Warnf("cache hit, key:%s", key) + HandleCacheHit(key, response, stream, ctx, config, log) + } else { + log.Warnf("cache miss, key:%s", key) + if ifUseEmbedding { + HandleCacheMiss(key, err, response, ctx, config, log, key, stream) + } else { + proxywasm.ResumeHttpRequest() + return + } + } + }) +} + +func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { + activeEmbeddingProvider := config.GetEmbeddingProvider() + activeVectorDatabaseProvider := config.GetVectorDatabaseProvider() + + activeEmbeddingProvider.GetEmbedding(key, ctx, log, + func(textEmbedding []float64, stateCode int, responseHeaders http.Header, responseBody []byte) { + activeVectorDatabaseProvider.QueryEmbedding(textEmbedding, ctx, log, + func(queryResponse vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { + if len(queryResponse.Output) < 1 { + log.Warnf("query response is empty") + activeVectorDatabaseProvider.UploadEmbedding(textEmbedding, key, ctx, log, + func(ctx wrapper.HttpContext, log wrapper.Log) { + proxywasm.ResumeHttpRequest() + }) + return + } + mostSimilarKey := queryResponse.Output[0].Fields["query"].(string) + log.Infof("most similar key:%s", mostSimilarKey) + mostSimilarScore := queryResponse.Output[0].Score + if mostSimilarScore < 0.1 { + // ctx.SetContext(config.CacheKeyContextKey, nil) + // RedisSearchHandler(mostSimilarKey, ctx, config, log, stream, false) + } else { + log.Infof("the most similar key's score is too high, key:%s, score:%f", mostSimilarKey, mostSimilarScore) + activeVectorDatabaseProvider.UploadEmbedding(textEmbedding, key, ctx, log, + func(ctx wrapper.HttpContext, log wrapper.Log) { + proxywasm.ResumeHttpRequest() + }) + proxywasm.ResumeHttpRequest() + return + } + + }, + ) + }) +} + +func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { + proxywasm.ResumeHttpRequest() +} + +// // 简单处理缓存命中的情况, 从redis中获取到value后,直接返回 +// func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { +// log.Warnf("cache hit, key:%s", key) +// ctx.SetContext(config.CacheKeyContextKey, nil) +// if !stream { +// proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1) +// } else { +// proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1) +// } +// } + +// // 处理缓存未命中的情况,调用fetchAndProcessEmbeddings函数向量化query +// func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { +// if err != nil { +// log.Warnf("redis get key:%s failed, err:%v", key, err) +// } +// if response.IsNull() { +// log.Warnf("cache miss, key:%s", key) +// } +// FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream) +// } + +// // 调用文本向量化接口向量化query, 向量化成功后调用processFetchedEmbeddings函数处理向量化结果 +// func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { +// Emb_url, Emb_requestBody, Emb_headers := ConstructTextEmbeddingParameters(&config, log, []string{queryString}) +// config.DashVectorInfo.DashScopeClient.Post( +// Emb_url, +// Emb_headers, +// Emb_requestBody, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// // log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) +// log.Infof("Successfully fetched embeddings for key: %s", key) +// if statusCode != 200 { +// log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) +// ctx.SetContext(QueryEmbeddingKey, nil) +// proxywasm.ResumeHttpRequest() +// } else { +// processFetchedEmbeddings(key, responseBody, ctx, config, log, stream) +// } +// }, +// 10000) +// } + +// // 先将向量化的结果存入上下文ctx变量,其次发起向量搜索请求 +// func ProcessFetchedEmbeddings(key string, responseBody []byte, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { +// text_embedding_raw, _ := ParseTextEmbedding(responseBody) +// text_embedding := text_embedding_raw.Output.Embeddings[0].Embedding +// // ctx.SetContext(CacheKeyContextKey, text_embedding) +// ctx.SetContext(QueryEmbeddingKey, text_embedding) +// ctx.SetContext(CacheKeyContextKey, key) +// PerformQueryAndRespond(key, text_embedding, ctx, config, log, stream) +// } + +// // 调用向量搜索接口搜索最相似的key,搜索成功后调用redisSearchHandler函数获取最相似的key的结果 +// func PerformQueryAndRespond(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { +// vector_url, vector_request, vector_headers, err := ConstructEmbeddingQueryParameters(config, text_embedding) +// if err != nil { +// log.Errorf("Failed to perform query, err: %v", err) +// proxywasm.ResumeHttpRequest() +// return +// } +// config.DashVectorInfo.DashVectorClient.Post( +// vector_url, +// vector_headers, +// vector_request, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) +// query_resp, err_query := ParseQueryResponse(responseBody) +// if err_query != nil { +// log.Errorf("Failed to parse response: %v", err) +// proxywasm.ResumeHttpRequest() +// return +// } +// if len(query_resp.Output) < 1 { +// log.Warnf("query response is empty") +// UploadQueryEmbedding(ctx, config, log, key, text_embedding) +// return +// } +// most_similar_key := query_resp.Output[0].Fields["query"].(string) +// log.Infof("most similar key:%s", most_similar_key) +// most_similar_score := query_resp.Output[0].Score +// if most_similar_score < 0.1 { +// ctx.SetContext(CacheKeyContextKey, nil) +// RedisSearchHandler(most_similar_key, ctx, config, log, stream, false) +// } else { +// log.Infof("the most similar key's score is too high, key:%s, score:%f", most_similar_key, most_similar_score) +// UploadQueryEmbedding(ctx, config, log, key, text_embedding) +// proxywasm.ResumeHttpRequest() +// return +// } +// }, +// 100000) +// } + +// // 未命中cache,则将新的query embedding和对应的key存入向量数据库 +// func UploadQueryEmbedding(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, key string, text_embedding []float64) error { +// vector_url, vector_body, err := ConsturctEmbeddingInsertParameters(&config, log, text_embedding, key) +// if err != nil { +// log.Errorf("Failed to construct embedding insert parameters: %v", err) +// proxywasm.ResumeHttpRequest() +// return nil +// } +// err = config.DashVectorInfo.DashVectorClient.Post( +// vector_url, +// [][2]string{ +// {"Content-Type", "application/json"}, +// {"dashvector-auth-token", config.DashVectorInfo.DashVectorKey}, +// }, +// vector_body, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// if statusCode != 200 { +// log.Errorf("Failed to upload query embedding: %s", responseBody) +// } else { +// log.Infof("Successfully uploaded query embedding for key: %s", key) +// } +// proxywasm.ResumeHttpRequest() +// }, +// 10000, +// ) +// if err != nil { +// log.Errorf("Failed to upload query embedding: %v", err) +// proxywasm.ResumeHttpRequest() +// return nil +// } +// return nil +// } + +// // ===================== 以上是主要逻辑 ===================== diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go new file mode 100644 index 0000000000..52afbad67b --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go @@ -0,0 +1,153 @@ +package vectorDatabase + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type dashVectorProviderInitializer struct { +} + +func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.DashVectorKey) == 0 { + return errors.New("DashVectorKey is required") + } + if len(config.DashVectorAuthApiEnd) == 0 { + return errors.New("DashVectorEnd is required") + } + if len(config.DashVectorCollection) == 0 { + return errors.New("DashVectorCollection is required") + } + if len(config.DashVectorServiceName) == 0 { + return errors.New("DashVectorServiceName is required") + } + return nil +} + +func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &DvProvider{config: config}, nil +} + +type DvProvider struct { + config ProviderConfig +} + +func (d *DvProvider) GetProviderType() string { + return providerTypeDashVector +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Params `json:"parameters"` +} + +type Params struct { + TextType string `json:"text_type"` +} + +type Input struct { + Texts []string `json:"texts"` +} + +func (d *DvProvider) ConstructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { + url := fmt.Sprintf("/v1/collections/%s/query", d.config.DashVectorCollection) + + requestData := QueryRequest{ + Vector: vector, + TopK: d.config.DashVectorTopK, + IncludeVector: false, + } + + requestBody, err := json.Marshal(requestData) + if err != nil { + return "", nil, nil, err + } + + header := [][2]string{ + {"Content-Type", "application/json"}, + {"dashvector-auth-token", d.config.DashVectorKey}, + } + + return url, requestBody, header, nil +} + +func (d *DvProvider) ParseQueryResponse(responseBody []byte) (QueryResponse, error) { + var queryResp QueryResponse + err := json.Unmarshal(responseBody, &queryResp) + if err != nil { + return QueryResponse{}, err + } + return queryResp, nil +} + +func (d *DvProvider) QueryEmbedding(queryEmb []float64, + ctx wrapper.HttpContext, log wrapper.Log, + callback func(query_resp QueryResponse, ctx wrapper.HttpContext, log wrapper.Log)) { + url, body, headers, err := d.ConstructEmbeddingQueryParameters(queryEmb) + if err != nil { + log.Infof("Failed to construct embedding query parameters: %v", err) + } + d.config.DashVectorClient.Post( + url, + headers, + body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("Query embedding response: %d, %s", statusCode, responseBody) + query_resp, err_query := d.ParseQueryResponse(responseBody) + if err_query != nil { + log.Infof("Failed to parse response: %v", err_query) + } + callback(query_resp, ctx, log) + }, + d.config.DashVectorTimeout) +} + +type Document struct { + Vector []float64 `json:"vector"` + Fields map[string]string `json:"fields"` +} + +type InsertRequest struct { + Docs []Document `json:"docs"` +} + +func (d *DvProvider) ConstructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { + url := "/v1/collections/" + d.config.DashVectorCollection + "/docs" + + doc := Document{ + Vector: emb, + Fields: map[string]string{ + "query": query_string, + }, + } + + requestBody, err := json.Marshal(InsertRequest{Docs: []Document{doc}}) + if err != nil { + return "", nil, nil, err + } + + header := [][2]string{ + {"Content-Type", "application/json"}, + {"dashvector-auth-token", d.config.DashVectorKey}, + } + + return url, requestBody, header, err +} + +func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { + url, body, headers, _ := d.ConstructEmbeddingUploadParameters(query_emb, queryString) + d.config.DashVectorClient.Post( + url, + headers, + body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log) + }, + d.config.DashVectorTimeout) +} diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go new file mode 100644 index 0000000000..9d810d2d76 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go @@ -0,0 +1,124 @@ +package vectorDatabase + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + providerTypeDashVector = "dashvector" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + providerTypeDashVector: &dashVectorProviderInitializer{}, + } +) + +type ProviderConfig struct { + // @Title zh-CN 向量存储服务提供者类型 + // @Description zh-CN 向量存储服务提供者类型,例如 DashVector、Milvus + typ string `json:"vectorStoreProviderType"` + // @Title zh-CN DashVector 阿里云向量搜索引擎 + // @Description zh-CN 调用阿里云的向量搜索引擎 + DashVectorServiceName string `require:"true" yaml:"DashVectorServiceName" jaon:"DashVectorServiceName"` + // @Title zh-CN DashVector Key + // @Description zh-CN 阿里云向量搜索引擎的 key + DashVectorKey string `require:"true" yaml:"DashVectorKey" jaon:"DashVectorKey"` + // @Title zh-CN DashVector AuthApiEnd + // @Description zh-CN 阿里云向量搜索引擎的 AuthApiEnd + DashVectorAuthApiEnd string `require:"true" yaml:"DashVectorEnd" jaon:"DashVectorEnd"` + // @Title zh-CN DashVector Collection + // @Description zh-CN 指定使用阿里云搜索引擎中的哪个向量集合 + DashVectorCollection string `require:"true" yaml:"DashVectorCollection" jaon:"DashVectorCollection"` + // @Title zh-CN DashVector Client + // @Description zh-CN 阿里云向量搜索引擎的 Client + DashVectorTopK int `require:"true" yaml:"DashVectorTopK" jaon:"DashVectorTopK"` + DashVectorTimeout uint32 `require:"true" yaml:"DashVectorTimeout" jaon:"DashVectorTimeout"` + DashVectorClient wrapper.HttpClient `yaml:"-" json:"-"` +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("vectorStoreProviderType").String() + c.DashVectorServiceName = json.Get("DashVectorServiceName").String() + c.DashVectorKey = json.Get("DashVectorKey").String() + c.DashVectorAuthApiEnd = json.Get("DashVectorEnd").String() + c.DashVectorCollection = json.Get("DashVectorCollection").String() + c.DashVectorTopK = int(json.Get("DashVectorTopK").Int()) + if c.DashVectorTopK == 0 { + c.DashVectorTopK = 1 + } + c.DashVectorTimeout = uint32(json.Get("DashVectorTimeout").Int()) + if c.DashVectorTimeout == 0 { + c.DashVectorTimeout = 10000 + } +} + +func (c *ProviderConfig) Validate() error { + if len(c.DashVectorKey) == 0 { + return errors.New("DashVectorKey is required") + } + if len(c.DashVectorServiceName) == 0 { + return errors.New("DashVectorServiceName is required") + } + if len(c.DashVectorAuthApiEnd) == 0 { + return errors.New("DashVectorAuthApiEnd is required") + } + if len(c.DashVectorCollection) == 0 { + return errors.New("DashVectorCollection is required") + } + return nil +} + +type Provider interface { + GetProviderType() string + QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(query_resp QueryResponse, ctx wrapper.HttpContext, log wrapper.Log)) + UploadEmbedding( + query_emb []float64, + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log)) +} + +// QueryResponse 定义查询响应的结构 +type QueryResponse struct { + Code int `json:"code"` + RequestID string `json:"request_id"` + Message string `json:"message"` + Output []Result `json:"output"` +} + +// QueryRequest 定义查询请求的结构 +type QueryRequest struct { + Vector []float64 `json:"vector"` + TopK int `json:"topk"` + IncludeVector bool `json:"include_vector"` +} + +// Result 定义查询结果的结构 +type Result struct { + ID string `json:"id"` + Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化 + Fields map[string]interface{} `json:"fields"` + Score float64 `json:"score"` +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} From 0f9e816701f0362459d5af267de0d57da431c1a9 Mon Sep 17 00:00:00 2001 From: suchun <2594405419@qq.com> Date: Thu, 1 Aug 2024 15:29:19 +0100 Subject: [PATCH 02/71] fix bugs --- plugins/wasm-go/extensions/ai-cache/go.mod | 4 ++-- plugins/wasm-go/extensions/ai-cache/go.sum | 10 +++------- plugins/wasm-go/go.mod | 2 +- plugins/wasm-go/go.sum | 3 +++ 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod index 02d62a48e7..65c863b4e4 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.mod +++ b/plugins/wasm-go/extensions/ai-cache/go.mod @@ -4,13 +4,13 @@ module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache go 1.19 +replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 - github.com/tidwall/sjson v1.2.5 ) require ( diff --git a/plugins/wasm-go/extensions/ai-cache/go.sum b/plugins/wasm-go/extensions/ai-cache/go.sum index 8246b4de5e..042eae70f2 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.sum +++ b/plugins/wasm-go/extensions/ai-cache/go.sum @@ -3,22 +3,18 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/go.mod b/plugins/wasm-go/go.mod index 999721f3f6..6373ff646e 100644 --- a/plugins/wasm-go/go.mod +++ b/plugins/wasm-go/go.mod @@ -7,7 +7,7 @@ require ( github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 ) diff --git a/plugins/wasm-go/go.sum b/plugins/wasm-go/go.sum index 5b23dc2c4a..f396d4d7d9 100644 --- a/plugins/wasm-go/go.sum +++ b/plugins/wasm-go/go.sum @@ -10,6 +10,7 @@ github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43 h1 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= @@ -19,6 +20,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= From ff1bce6d859015aef193af6f8d97ba415eea535a Mon Sep 17 00:00:00 2001 From: suchun <2594405419@qq.com> Date: Mon, 12 Aug 2024 03:59:16 +0100 Subject: [PATCH 03/71] fix bugs --- .../extensions/ai-cache/cache/cache.go | 17 ++- .../extensions/ai-cache/config/config.go | 61 ++++++++- .../ai-cache/embedding/dashscope.go | 2 +- .../extensions/ai-cache/embedding/provider.go | 17 ++- plugins/wasm-go/extensions/ai-cache/main.go | 20 ++- .../extensions/ai-cache/util/cachelogic.go | 129 +++++++++++++----- .../ai-cache/vectorDatabase/dashvector.go | 34 +++-- .../ai-cache/vectorDatabase/provider.go | 4 +- 8 files changed, 226 insertions(+), 58 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/cache/cache.go b/plugins/wasm-go/extensions/ai-cache/cache/cache.go index e4d1f1e81f..1085937d2d 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/cache.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/cache.go @@ -26,13 +26,24 @@ type RedisConfig struct { RedisTimeout uint32 `required:"false" yaml:"timeout" json:"timeout"` } -func CreateProvider(cf RedisConfig) (Provider, error) { +func CreateProvider(cf RedisConfig, log wrapper.Log) (Provider, error) { rp := redisProvider{ config: cf, client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ FQDN: cf.RedisServiceName, + Host: "redis", Port: int64(cf.RedisServicePort)}), + // client: wrapper.NewRedisClusterClient(wrapper.DnsCluster{ + // ServiceName: cf.RedisServiceName, + // Port: int64(cf.RedisServicePort)}), } + // FQDN := wrapper.FQDNCluster{ + // FQDN: cf.RedisServiceName, + // Host: "redis", + // Port: int64(cf.RedisServicePort)} + // log.Debugf("test:%s", FQDN.ClusterName()) + // log.Debugf("test:%d", cf.RedisServicePort) + // log.Debugf("test:%s", proxywasm.RedisInit(FQDN.ClusterName(), "", "", 100)) err := rp.Init(cf.RedisUsername, cf.RedisPassword, cf.RedisTimeout) return &rp, err } @@ -43,6 +54,9 @@ func (c *RedisConfig) FromJson(json gjson.Result) { c.RedisTimeout = uint32(json.Get("timeout").Int()) c.RedisServiceName = json.Get("serviceName").String() c.RedisServicePort = int(json.Get("servicePort").Int()) + if c.RedisServicePort == 0 { + c.RedisServicePort = 6379 + } } func (c *RedisConfig) Validate() error { @@ -56,6 +70,7 @@ func (c *RedisConfig) Validate() error { c.RedisServicePort = 6379 } if len(c.RedisUsername) == 0 { + // return errors.New("redis.username is required") c.RedisUsername = "" } if len(c.RedisPassword) == 0 { diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 378d956fc3..d31f6e2b70 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -4,6 +4,7 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vectorDatabase" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -20,9 +21,23 @@ type PluginConfig struct { CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` + // @Title zh-CN 返回 HTTP 响应的模版 + // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 + ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` + // @Title zh-CN 返回流式 HTTP 响应的模版 + // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 + ReturnTestResponseTemplate string `required:"true" yaml:"returnTestResponseTemplate" json:"returnTestResponseTemplate"` - CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` - RedisConfig cache.RedisConfig `required:"true" yaml:"redisConfig" json:"redisConfig"` + CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` + + ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"` + // @Title zh-CN 缓存的过期时间 + // @Description zh-CN 单位是秒,默认值为0,即永不过期 + CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` + // @Title zh-CN Redis缓存Key的前缀 + // @Description zh-CN 默认值是"higress-ai-cache:" + + RedisConfig cache.RedisConfig `required:"true" yaml:"redisConfig" json:"redisConfig"` // 现在只支持RedisClient作为cacheClient redisProvider cache.Provider `yaml:"-"` embeddingProvider embedding.Provider `yaml:"-"` @@ -33,6 +48,33 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) c.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) c.RedisConfig.FromJson(json.Get("redis")) + if c.CacheKeyFrom.RequestBody == "" { + c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" + } + c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String() + if c.CacheKeyFrom.RequestBody == "" { + c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" + } + c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String() + if c.CacheValueFrom.ResponseBody == "" { + c.CacheValueFrom.ResponseBody = "choices.0.message.content" + } + c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String() + if c.CacheStreamValueFrom.ResponseBody == "" { + c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content" + } + c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String() + if c.ReturnResponseTemplate == "" { + c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + } + c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String() + if c.ReturnStreamResponseTemplate == "" { + c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + } + c.ReturnTestResponseTemplate = json.Get("returnTestResponseTemplate").String() + if c.ReturnTestResponseTemplate == "" { + c.ReturnTestResponseTemplate = `{"id":"random-generate","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + } } func (c *PluginConfig) Validate() error { @@ -48,12 +90,21 @@ func (c *PluginConfig) Validate() error { return nil } -func (c *PluginConfig) Complete() error { +func (c *PluginConfig) Complete(log wrapper.Log) error { var err error c.embeddingProvider, err = embedding.CreateProvider(c.EmbeddingProviderConfig) + if err != nil { + return err + } c.vectorDatabaseProvider, err = vectorDatabase.CreateProvider(c.VectorDatabaseProviderConfig) - c.redisProvider, err = cache.CreateProvider(c.RedisConfig) - return err + if err != nil { + return err + } + c.redisProvider, err = cache.CreateProvider(c.RedisConfig, log) + if err != nil { + return err + } + return nil } func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider { diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index f979e1d4ae..99f387e21d 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -184,7 +184,7 @@ func (d *DSProvider) GetEmbedding( // 调用回调函数 callback(resp.Output.Embeddings[0].Embedding, statusCode, responseHeaders, responseBody) - }) + }, d.config.DashScopeTimeout) return nil } diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index ea0f58398c..4cf9ebd47f 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -9,7 +9,15 @@ import ( ) const ( - providerTypeDashScope = "dashscope" + providerTypeDashScope = "dashscope" + CacheKeyContextKey = "cacheKey" + CacheContentContextKey = "cacheContent" + PartialMessageContextKey = "partialMessage" + ToolCallsContextKey = "toolCalls" + StreamContextKey = "stream" + CacheKeyPrefix = "higressAiCache" + DefaultCacheKeyPrefix = "higressAiCache" + QueryEmbeddingKey = "queryEmbedding" ) type providerInitializer interface { @@ -32,8 +40,8 @@ type ProviderConfig struct { ServiceName string `require:"true" yaml:"DashScopeServiceName" jaon:"DashScopeServiceName"` Client wrapper.HttpClient `yaml:"-"` DashScopeKey string `require:"true" yaml:"DashScopeKey" jaon:"DashScopeKey"` - DashScopeTimeout uint32 `require:"true" yaml:"DashScopeTimeout" jaon:"DashScopeTimeout"` - QueryEmbeddingKey string `require:"true" yaml:"QueryEmbeddingKey" jaon:"QueryEmbeddingKey"` + DashScopeTimeout uint32 `require:"false" yaml:"DashScopeTimeout" jaon:"DashScopeTimeout"` + QueryEmbeddingKey string `require:"false" yaml:"QueryEmbeddingKey" jaon:"QueryEmbeddingKey"` } func (c *ProviderConfig) FromJson(json gjson.Result) { @@ -41,6 +49,9 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.ServiceName = json.Get("DashScopeServiceName").String() c.DashScopeKey = json.Get("DashScopeKey").String() c.DashScopeTimeout = uint32(json.Get("DashScopeTimeout").Int()) + if c.DashScopeTimeout == 0 { + c.DashScopeTimeout = 1000 + } c.QueryEmbeddingKey = json.Get("QueryEmbeddingKey").String() } diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 2334c1a777..8e018c8b4f 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -37,11 +38,17 @@ func main() { } func parseConfig(json gjson.Result, config *config.PluginConfig, log wrapper.Log) error { - config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) - config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + // config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) + // config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + // config.RedisConfig.FromJson(json.Get("redis")) + config.FromJson(json) if err := config.Validate(); err != nil { return err } + if err := config.Complete(log); err != nil { + log.Errorf("complete config failed:%v", err) + return err + } return nil } @@ -79,6 +86,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body } // key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) key := bodyJson.Get(config.CacheKeyFrom.RequestBody).String() + ctx.SetContext(CacheKeyContextKey, key) + log.Debugf("[onHttpRequestBody] key:%s", key) if key == "" { log.Debug("[onHttpRquestBody] parse key from request body failed") return types.ActionContinue @@ -136,6 +145,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { // log.Infof("I am here") + log.Debugf("[onHttpResponseBody] i am here") if ctx.GetContext(ToolCallsContextKey) != nil { // we should not cache tool call result return chunk @@ -199,6 +209,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu return chunk } } else { + log.Infof("[onHttpResponseBody] stream mode") if len(chunk) > 0 { var lastMessage []byte partialMessageI := ctx.GetContext(PartialMessageContextKey) @@ -208,7 +219,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu lastMessage = chunk } if !strings.HasSuffix(string(lastMessage), "\n\n") { - log.Warnf("invalid lastMessage:%s", lastMessage) + log.Warnf("[onHttpResponseBody] invalid lastMessage:%s", lastMessage) return chunk } // remove the last \n\n @@ -217,13 +228,14 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } else { tempContentI := ctx.GetContext(CacheContentContextKey) if tempContentI == nil { + log.Warnf("[onHttpResponseBody] no content in tempContentI") return chunk } value = tempContentI.(string) } } log.Infof("[onHttpResponseBody] Setting cache to redis, key:%s, value:%s", key, value) - config.GetCacheProvider().Set(config.CacheKeyPrefix+key, value, nil) + config.GetCacheProvider().Set(embedding.CacheKeyPrefix+key, value, nil) // TODO: 要不要加个Expire方法 // if config.RedisConfig.RedisTimeout != 0 { // config.GetCacheProvider().Expire(config.CacheKeyPrefix+key, config.RedisConfig.RedisTimeout, nil) diff --git a/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go b/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go index 95b8b6e1cf..0b817c4121 100644 --- a/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go +++ b/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go @@ -2,9 +2,11 @@ package util import ( + "fmt" "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vectorDatabase" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -24,7 +26,8 @@ import ( func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) { activeCacheProvider := config.GetCacheProvider() - activeCacheProvider.Get(config.CacheKeyPrefix+key, func(response resp.Value) { + log.Debugf("activeCacheProvider:%v", activeCacheProvider) + activeCacheProvider.Get(embedding.CacheKeyPrefix+key, func(response resp.Value) { if err := response.Error(); err == nil && !response.IsNull() { log.Warnf("cache hit, key:%s", key) HandleCacheHit(key, response, stream, ctx, config, log) @@ -41,44 +44,102 @@ func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.Plugi } func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { - activeEmbeddingProvider := config.GetEmbeddingProvider() - activeVectorDatabaseProvider := config.GetVectorDatabaseProvider() + ctx.SetContext(embedding.CacheKeyContextKey, nil) + if !stream { + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "[Test, this is cache]"+response.String())), -1) + } else { + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, "[Test, this is cache]"+response.String())), -1) + } +} - activeEmbeddingProvider.GetEmbedding(key, ctx, log, - func(textEmbedding []float64, stateCode int, responseHeaders http.Header, responseBody []byte) { - activeVectorDatabaseProvider.QueryEmbedding(textEmbedding, ctx, log, - func(queryResponse vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { - if len(queryResponse.Output) < 1 { - log.Warnf("query response is empty") - activeVectorDatabaseProvider.UploadEmbedding(textEmbedding, key, ctx, log, - func(ctx wrapper.HttpContext, log wrapper.Log) { - proxywasm.ResumeHttpRequest() - }) - return - } - mostSimilarKey := queryResponse.Output[0].Fields["query"].(string) - log.Infof("most similar key:%s", mostSimilarKey) - mostSimilarScore := queryResponse.Output[0].Score - if mostSimilarScore < 0.1 { - // ctx.SetContext(config.CacheKeyContextKey, nil) - // RedisSearchHandler(mostSimilarKey, ctx, config, log, stream, false) - } else { - log.Infof("the most similar key's score is too high, key:%s, score:%f", mostSimilarKey, mostSimilarScore) - activeVectorDatabaseProvider.UploadEmbedding(textEmbedding, key, ctx, log, - func(ctx wrapper.HttpContext, log wrapper.Log) { - proxywasm.ResumeHttpRequest() - }) - proxywasm.ResumeHttpRequest() - return - } +func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { + if err != nil { + log.Warnf("redis get key:%s failed, err:%v", key, err) + } + if response.IsNull() { + log.Warnf("cache miss, key:%s", key) + } + FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream) +} - }, - ) +func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { + activeEmbeddingProvider := config.GetEmbeddingProvider() + activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, + func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != 200 { + log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) + } else { + log.Debugf("Successfully fetched embeddings for key: %s", key) + QueryVectorDB(key, emb, ctx, config, log, stream) + } }) } -func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { - proxywasm.ResumeHttpRequest() +func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { + log.Debugf("QueryVectorDB key:%s", key) + activeVectorDatabaseProvider := config.GetVectorDatabaseProvider() + log.Debugf("activeVectorDatabaseProvider:%v", activeVectorDatabaseProvider) + activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, + func(query_resp vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { + if len(query_resp.Output) < 1 { + log.Warnf("query response is empty") + activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + func(ctx wrapper.HttpContext, log wrapper.Log) { + proxywasm.ResumeHttpRequest() + }) + return + } + mostSimilarKey := query_resp.Output[0].Fields["query"].(string) + log.Infof("most similar key:%s", mostSimilarKey) + mostSimilarScore := query_resp.Output[0].Score + if mostSimilarScore < 2000 { + log.Infof("accept most similar key:%s, score:%f", mostSimilarKey, mostSimilarScore) + // ctx.SetContext(embedding.CacheKeyContextKey, nil) + RedisSearchHandler(mostSimilarKey, ctx, config, log, stream, false) + } else { + log.Infof("the most similar key's score is too high, key:%s, score:%f", mostSimilarKey, mostSimilarScore) + activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + func(ctx wrapper.HttpContext, log wrapper.Log) { + proxywasm.ResumeHttpRequest() + }) + proxywasm.ResumeHttpRequest() + return + } + }, + ) + // activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, + // func(query_resp vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { + // if len(query_resp.Output) < 1 { + // log.Warnf("query response is empty") + // // UploadQueryEmbedding(ctx, config, log, key, text_embedding) + // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + // func(ctx wrapper.HttpContext, log wrapper.Log) { + // // proxywasm.ResumeHttpRequest() + // log.Debugf("I am in the 117 line") + // }) + // return + // } + // most_similar_key := query_resp.Output[0].Fields["query"].(string) + // log.Infof("most similar key:%s", most_similar_key) + // most_similar_score := query_resp.Output[0].Score + // if most_similar_score < 0.1 { + // // ctx.SetContext(CacheKeyContextKey, nil) + // // RedisSearchHandler(most_similar_key, ctx, config, log, stream, false) + // } else { + // log.Infof("the most similar key's score is too high, key:%s, score:%f", most_similar_key, most_similar_score) + // // UploadQueryEmbedding(ctx, config, log, key, text_embedding) + // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + // func(ctx wrapper.HttpContext, log wrapper.Log) { + // proxywasm.ResumeHttpRequest() + // }) + // proxywasm.ResumeHttpRequest() + // return + // } + // }) + // ctx.SetContext(embedding.CacheKeyContextKey, text_embedding) + // ctx.SetContext(embedding.QueryEmbeddingKey, text_embedding) + // ctx.SetContext(embedding.CacheKeyContextKey, key) + // PerformQueryAndRespond(key, text_embedding, ctx, config, log, stream) } // // 简单处理缓存命中的情况, 从redis中获取到value后,直接返回 diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go index 52afbad67b..27bf69e65e 100644 --- a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go @@ -9,6 +9,10 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) +const ( + dashVectorPort = 443 +) + type dashVectorProviderInitializer struct { } @@ -29,11 +33,19 @@ func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) er } func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - return &DvProvider{config: config}, nil + return &DvProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.DashVectorServiceName, + Port: dashVectorPort, + Domain: config.DashVectorAuthApiEnd, + }), + }, nil } type DvProvider struct { config ProviderConfig + client wrapper.HttpClient } func (d *DvProvider) GetProviderType() string { @@ -85,17 +97,19 @@ func (d *DvProvider) ParseQueryResponse(responseBody []byte) (QueryResponse, err return queryResp, nil } -func (d *DvProvider) QueryEmbedding(queryEmb []float64, - ctx wrapper.HttpContext, log wrapper.Log, +func (d *DvProvider) QueryEmbedding( + queryEmb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, callback func(query_resp QueryResponse, ctx wrapper.HttpContext, log wrapper.Log)) { + + // 构造请求参数 url, body, headers, err := d.ConstructEmbeddingQueryParameters(queryEmb) if err != nil { log.Infof("Failed to construct embedding query parameters: %v", err) } - d.config.DashVectorClient.Post( - url, - headers, - body, + + err = d.client.Post(url, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Infof("Query embedding response: %d, %s", statusCode, responseBody) query_resp, err_query := d.ParseQueryResponse(responseBody) @@ -105,6 +119,10 @@ func (d *DvProvider) QueryEmbedding(queryEmb []float64, callback(query_resp, ctx, log) }, d.config.DashVectorTimeout) + if err != nil { + log.Infof("Failed to query embedding: %v", err) + } + } type Document struct { @@ -141,7 +159,7 @@ func (d *DvProvider) ConstructEmbeddingUploadParameters(emb []float64, query_str func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { url, body, headers, _ := d.ConstructEmbeddingUploadParameters(query_emb, queryString) - d.config.DashVectorClient.Post( + d.client.Post( url, headers, body, diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go index 9d810d2d76..e4b78dbc93 100644 --- a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go @@ -40,8 +40,8 @@ type ProviderConfig struct { DashVectorCollection string `require:"true" yaml:"DashVectorCollection" jaon:"DashVectorCollection"` // @Title zh-CN DashVector Client // @Description zh-CN 阿里云向量搜索引擎的 Client - DashVectorTopK int `require:"true" yaml:"DashVectorTopK" jaon:"DashVectorTopK"` - DashVectorTimeout uint32 `require:"true" yaml:"DashVectorTimeout" jaon:"DashVectorTimeout"` + DashVectorTopK int `require:"false" yaml:"DashVectorTopK" jaon:"DashVectorTopK"` + DashVectorTimeout uint32 `require:"false" yaml:"DashVectorTimeout" jaon:"DashVectorTimeout"` DashVectorClient wrapper.HttpClient `yaml:"-" json:"-"` } From 1e9d42e379d51bf85bbdd3e7ff7a3be5a0881458 Mon Sep 17 00:00:00 2001 From: Async Date: Thu, 15 Aug 2024 16:28:48 +0800 Subject: [PATCH 04/71] init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit update update: 注意在使用http协议的时候不要用tls update: add lobechat add: makefile for ai-proxy fix bugs fix bugs fix: redis connection fix: dashvector and dashscope cluster fix: change vdb collection feat: add chroma logic docs: 增加 api 说明 update: no callback version fix: change to callback fix: finish chrome remove: key update: gitignore --- docker-compose-test/docker-compose.yml | 93 +++++ docker-compose-test/envoy.yaml | 231 +++++++++++++ .../wasm-go/extensions/ai-cache/.gitignore | 2 +- .../extensions/ai-cache/cache/cache.go | 109 ++++++ .../extensions/ai-cache/config/config.go | 120 +++++++ .../ai-cache/embedding/dashscope.go | 200 +++++++++++ .../extensions/ai-cache/embedding/provider.go | 87 +++++ .../extensions/ai-cache/embedding/weaviate.go | 27 ++ plugins/wasm-go/extensions/ai-cache/go.mod | 7 +- plugins/wasm-go/extensions/ai-cache/go.sum | 17 +- plugins/wasm-go/extensions/ai-cache/main.go | 275 +++++---------- .../extensions/ai-cache/util/cachelogic.go | 325 ++++++++++++++++++ .../ai-cache/vectorDatabase/chroma.go | 184 ++++++++++ .../ai-cache/vectorDatabase/dashvector.go | 210 +++++++++++ .../ai-cache/vectorDatabase/provider.go | 141 ++++++++ .../ai-cache/vectorDatabase/weaviate.go | 172 +++++++++ plugins/wasm-go/extensions/ai-proxy/Makefile | 4 + plugins/wasm-go/extensions/ai-proxy/go.mod | 4 +- plugins/wasm-go/extensions/ai-proxy/go.sum | 4 +- plugins/wasm-go/extensions/ai-proxy/main.go | 3 +- .../extensions/request-block/Dockerfile | 2 + .../wasm-go/extensions/request-block/Makefile | 4 + .../wasm-go/extensions/request-block/main.go | 2 + plugins/wasm-go/go.mod | 2 +- plugins/wasm-go/go.sum | 8 + 25 files changed, 2033 insertions(+), 200 deletions(-) create mode 100644 docker-compose-test/docker-compose.yml create mode 100644 docker-compose-test/envoy.yaml create mode 100644 plugins/wasm-go/extensions/ai-cache/cache/cache.go create mode 100644 plugins/wasm-go/extensions/ai-cache/config/config.go create mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go create mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/provider.go create mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go create mode 100644 plugins/wasm-go/extensions/ai-cache/util/cachelogic.go create mode 100644 plugins/wasm-go/extensions/ai-cache/vectorDatabase/chroma.go create mode 100644 plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go create mode 100644 plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go create mode 100644 plugins/wasm-go/extensions/ai-cache/vectorDatabase/weaviate.go create mode 100644 plugins/wasm-go/extensions/ai-proxy/Makefile create mode 100644 plugins/wasm-go/extensions/request-block/Dockerfile create mode 100644 plugins/wasm-go/extensions/request-block/Makefile diff --git a/docker-compose-test/docker-compose.yml b/docker-compose-test/docker-compose.yml new file mode 100644 index 0000000000..3b96146349 --- /dev/null +++ b/docker-compose-test/docker-compose.yml @@ -0,0 +1,93 @@ +version: '3.7' +services: + envoy: + # image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway:v1.4.0-rc.1 + image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway:1.4.2 + entrypoint: /usr/local/bin/envoy + # 注意这里对wasm开启了debug级别日志,正式部署时则默认info级别 + command: -c /etc/envoy/envoy.yaml --component-log-level wasm:debug + depends_on: + - httpbin + - redis + - chroma + networks: + - wasmtest + ports: + - "10000:10000" + - "9901:9901" + volumes: + - ./envoy.yaml:/etc/envoy/envoy.yaml + # 注意默认没有这两个 wasm 的时候,docker 会创建文件夹,这样会出错,需要有 wasm 文件之后 down 然后重新 up + - ./ai-cache.wasm:/etc/envoy/ai-cache.wasm + - ./ai-proxy.wasm:/etc/envoy/ai-proxy.wasm + + chroma: + image: chromadb/chroma + ports: + - "8001:8000" + volumes: + - chroma-data:/chroma/chroma + + redis: + image: redis:latest + networks: + - wasmtest + ports: + - "6379:6379" + + httpbin: + image: kennethreitz/httpbin:latest + networks: + - wasmtest + ports: + - "12345:80" + + lobechat: + # docker hub 如果访问不了,可以改用这个地址:registry.cn-hangzhou.aliyuncs.com/2456868764/lobe-chat:v1.1.3 + image: lobehub/lobe-chat + environment: + - CODE=admin + - OPENAI_API_KEY=unused + - OPENAI_PROXY_URL=http://envoy:10000/v1 + networks: + - wasmtest + ports: + - "3210:3210/tcp" + + # weaviate: + # command: + # - --host + # - 0.0.0.0 + # - --port + # - '8080' + # - --scheme + # - http + # image: cr.weaviate.io/semitechnologies/weaviate:1.26.1 + # ports: + # - 8081:8080 + # - 50051:50051 + # volumes: + # - weaviate_data:/var/lib/weaviate + # restart: on-failure:0 + # networks: + # - wasmtest + # environment: + # QUERY_DEFAULTS_LIMIT: 25 + # AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + # PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + # DEFAULT_VECTORIZER_MODULE: 'none' + # ENABLE_API_BASED_MODULES: 'true' + # CLUSTER_HOSTNAME: 'node1' + # TRANSFORMERS_INFERENCE_API: http://t2v-transformers:8080 # Set the inference API endpoint + + # t2v-transformers: # Set the name of the inference container + # image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-multi-qa-MiniLM-L6-cos-v1 + # environment: + # ENABLE_CUDA: 0 # Set to 1 to enable +volumes: + weaviate_data: {} + chroma-data: + driver: local + +networks: + wasmtest: {} \ No newline at end of file diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml new file mode 100644 index 0000000000..dc08f2e846 --- /dev/null +++ b/docker-compose-test/envoy.yaml @@ -0,0 +1,231 @@ +admin: + address: + socket_address: + protocol: TCP + address: 0.0.0.0 + port_value: 9901 +static_resources: + listeners: + - name: listener_0 + address: + socket_address: + protocol: TCP + address: 0.0.0.0 + port_value: 10000 + filter_chains: + - filters: + # httpbin + - name: envoy.filters.network.http_connection_manager + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + scheme_header_transformation: + scheme_to_overwrite: https + stat_prefix: ingress_http + route_config: + name: local_route + virtual_hosts: + - name: local_service + domains: ["*"] + routes: + # - match: + # prefix: "/" + # route: + # cluster: httpbin + - match: + prefix: "/" + route: + cluster: llm + timeout: 300s + + http_filters: + # ai-cache + - name: ai-cache + typed_config: + "@type": type.googleapis.com/udpa.type.v1.TypedStruct + type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm + value: + config: + name: ai-cache + vm_config: + runtime: envoy.wasm.runtime.v8 + code: + local: + filename: /etc/envoy/ai-cache.wasm + configuration: + "@type": "type.googleapis.com/google.protobuf.StringValue" + value: | + { + "embeddingProvider": { + "TextEmbeddingProviderType": "dashscope", + "ServiceName": "text-embedding-v2", + "DashScopeKey": "sk-your-key", + "DashScopeServiceName": "dashscope" + }, + "vectorBaseProvider": { + "vectorStoreProviderType": "chroma", + "ChromaServiceName": "chroma", + "ChromaCollectionID": "0294deb1-8ef5-4582-b21c-75f23093db2c" + }, + "cacheKeyFrom": { + "requestBody": "" + }, + "cacheValueFrom": { + "responseBody": "" + }, + "cacheStreamValueFrom": { + "responseBody": "" + }, + "returnResponseTemplate": "", + "returnTestResponseTemplate": "", + "ReturnStreamResponseTemplate": "", + "redis": { + "serviceName": "redis_cluster", + "timeout": 2000 + } + } + # 上面的配置中 redis 的配置名字是 redis,而不是 golang tag 中的 redisConfig + + # llm-proxy + - name: llm-proxy + typed_config: + "@type": type.googleapis.com/udpa.type.v1.TypedStruct + type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm + value: + config: + name: llm + vm_config: + runtime: envoy.wasm.runtime.v8 + code: + local: + filename: /etc/envoy/ai-proxy.wasm + configuration: + "@type": "type.googleapis.com/google.protobuf.StringValue" + value: | # 插件配置 + { + "provider": { + "type": "openai", + "apiTokens": [ + "YOUR_API_TOKEN" + ], + "openaiCustomUrl": "172.17.0.1:8000/v1/chat/completions" + } + } + + + - name: envoy.filters.http.router + + clusters: + - name: httpbin + connect_timeout: 30s + type: LOGICAL_DNS + # Comment out the following line to test on v6 networks + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: httpbin + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: httpbin + port_value: 80 + # - name: redis_cluster + # connect_timeout: 30s + # type: STRICT_DNS + # lb_policy: ROUND_ROBIN + # load_assignment: + # cluster_name: redis + # endpoints: + # - lb_endpoints: + # - endpoint: + # address: + # socket_address: + # address: 172.17.0.1 + # port_value: 6379 + - name: outbound|6379||redis_cluster + connect_timeout: 1s + type: strict_dns + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|6379||redis_cluster + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 172.17.0.1 + port_value: 6379 + typed_extension_protocol_options: + envoy.filters.network.redis_proxy: + "@type": type.googleapis.com/envoy.extensions.filters.network.redis_proxy.v3.RedisProtocolOptions + # chroma + - name: outbound|8001||chroma.dns + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|8001||chroma.dns + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 + port_value: 8001 + # llm + - name: llm + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: llm + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 + port_value: 8000 + # dashvector + - name: outbound|443||dashvector.dns + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|443||dashvector.dns + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: vrs-cn-0dw3vnaqs0002z.dashvector.cn-hangzhou.aliyuncs.com + port_value: 443 + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + "sni": "vrs-cn-0dw3vnaqs0002z.dashvector.cn-hangzhou.aliyuncs.com" + # dashscope + - name: outbound|443||dashscope.dns + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|443||dashscope.dns + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: dashscope.aliyuncs.com + port_value: 443 + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + "sni": "dashscope.aliyuncs.com" \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-cache/.gitignore b/plugins/wasm-go/extensions/ai-cache/.gitignore index 47db8eedba..8a34bf52ad 100644 --- a/plugins/wasm-go/extensions/ai-cache/.gitignore +++ b/plugins/wasm-go/extensions/ai-cache/.gitignore @@ -1,5 +1,5 @@ # File generated by hgctl. Modify as required. - +docker-compose-test/ * !/.gitignore diff --git a/plugins/wasm-go/extensions/ai-cache/cache/cache.go b/plugins/wasm-go/extensions/ai-cache/cache/cache.go new file mode 100644 index 0000000000..8f095aaa24 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/cache/cache.go @@ -0,0 +1,109 @@ +// TODO: 在这里写缓存的具体逻辑, 将textEmbeddingPrvider和vectorStoreProvider作为逻辑中的一个函数调用 +package cache + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +type RedisConfig struct { + // @Title zh-CN redis 服务名称 + // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local + RedisServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` + // @Title zh-CN redis 服务端口 + // @Description zh-CN 默认值为6379 + RedisServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` + // @Title zh-CN 用户名 + // @Description zh-CN 登陆 redis 的用户名,非必填 + RedisUsername string `required:"false" yaml:"username" json:"username"` + // @Title zh-CN 密码 + // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 + RedisPassword string `required:"false" yaml:"password" json:"password"` + // @Title zh-CN 请求超时 + // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 + RedisTimeout uint32 `required:"false" yaml:"timeout" json:"timeout"` +} + +func CreateProvider(cf RedisConfig, log wrapper.Log) (Provider, error) { + log.Warnf("redis config: %v", cf) + rp := redisProvider{ + config: cf, + client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: cf.RedisServiceName, + Host: "redis", + Port: int64(cf.RedisServicePort)}), + // client: wrapper.NewRedisClusterClient(wrapper.DnsCluster{ + // ServiceName: cf.RedisServiceName, + // Port: int64(cf.RedisServicePort)}), + } + // FQDN := wrapper.FQDNCluster{ + // FQDN: cf.RedisServiceName, + // Host: "redis", + // Port: int64(cf.RedisServicePort)} + // log.Debugf("test:%s", FQDN.ClusterName()) + // log.Debugf("test:%d", cf.RedisServicePort) + // log.Debugf("test:%s", proxywasm.RedisInit(FQDN.ClusterName(), "", "", 100)) + err := rp.Init(cf.RedisUsername, cf.RedisPassword, cf.RedisTimeout) + return &rp, err +} + +func (c *RedisConfig) FromJson(json gjson.Result) { + c.RedisUsername = json.Get("username").String() + c.RedisPassword = json.Get("password").String() + c.RedisTimeout = uint32(json.Get("timeout").Int()) + c.RedisServiceName = json.Get("serviceName").String() + c.RedisServicePort = int(json.Get("servicePort").Int()) + if c.RedisServicePort == 0 { + c.RedisServicePort = 6379 + } +} + +func (c *RedisConfig) Validate() error { + if len(c.RedisServiceName) == 0 { + return errors.New("serviceName is required") + } + if c.RedisTimeout <= 0 { + return errors.New("timeout must be greater than 0") + } + if c.RedisServicePort <= 0 { + c.RedisServicePort = 6379 + } + if len(c.RedisUsername) == 0 { + // return errors.New("redis.username is required") + c.RedisUsername = "" + } + if len(c.RedisPassword) == 0 { + c.RedisPassword = "" + } + return nil +} + +type Provider interface { + GetProviderType() string + Init(username string, password string, timeout uint32) error + Get(key string, cb wrapper.RedisResponseCallback) + Set(key string, value string, cb wrapper.RedisResponseCallback) +} + +type redisProvider struct { + config RedisConfig + client wrapper.RedisClient +} + +func (rp *redisProvider) GetProviderType() string { + return "redis" +} + +func (rp *redisProvider) Init(username string, password string, timeout uint32) error { + return rp.client.Init(rp.config.RedisUsername, rp.config.RedisPassword, int64(rp.config.RedisTimeout)) +} + +func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) { + rp.client.Get(key, cb) +} + +func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) { + rp.client.Set(key, value, cb) +} diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go new file mode 100644 index 0000000000..d31f6e2b70 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -0,0 +1,120 @@ +package config + +import ( + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vectorDatabase" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +type KVExtractor struct { + // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 + RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` + // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 + ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` +} + +type PluginConfig struct { + EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"` + VectorDatabaseProviderConfig vectorDatabase.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` + CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` + CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` + CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` + // @Title zh-CN 返回 HTTP 响应的模版 + // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 + ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` + // @Title zh-CN 返回流式 HTTP 响应的模版 + // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 + ReturnTestResponseTemplate string `required:"true" yaml:"returnTestResponseTemplate" json:"returnTestResponseTemplate"` + + CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` + + ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"` + // @Title zh-CN 缓存的过期时间 + // @Description zh-CN 单位是秒,默认值为0,即永不过期 + CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` + // @Title zh-CN Redis缓存Key的前缀 + // @Description zh-CN 默认值是"higress-ai-cache:" + + RedisConfig cache.RedisConfig `required:"true" yaml:"redisConfig" json:"redisConfig"` + // 现在只支持RedisClient作为cacheClient + redisProvider cache.Provider `yaml:"-"` + embeddingProvider embedding.Provider `yaml:"-"` + vectorDatabaseProvider vectorDatabase.Provider `yaml:"-"` +} + +func (c *PluginConfig) FromJson(json gjson.Result) { + c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) + c.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + c.RedisConfig.FromJson(json.Get("redis")) + if c.CacheKeyFrom.RequestBody == "" { + c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" + } + c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String() + if c.CacheKeyFrom.RequestBody == "" { + c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" + } + c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String() + if c.CacheValueFrom.ResponseBody == "" { + c.CacheValueFrom.ResponseBody = "choices.0.message.content" + } + c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String() + if c.CacheStreamValueFrom.ResponseBody == "" { + c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content" + } + c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String() + if c.ReturnResponseTemplate == "" { + c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + } + c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String() + if c.ReturnStreamResponseTemplate == "" { + c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + } + c.ReturnTestResponseTemplate = json.Get("returnTestResponseTemplate").String() + if c.ReturnTestResponseTemplate == "" { + c.ReturnTestResponseTemplate = `{"id":"random-generate","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + } +} + +func (c *PluginConfig) Validate() error { + if err := c.RedisConfig.Validate(); err != nil { + return err + } + if err := c.EmbeddingProviderConfig.Validate(); err != nil { + return err + } + if err := c.VectorDatabaseProviderConfig.Validate(); err != nil { + return err + } + return nil +} + +func (c *PluginConfig) Complete(log wrapper.Log) error { + var err error + c.embeddingProvider, err = embedding.CreateProvider(c.EmbeddingProviderConfig) + if err != nil { + return err + } + c.vectorDatabaseProvider, err = vectorDatabase.CreateProvider(c.VectorDatabaseProviderConfig) + if err != nil { + return err + } + c.redisProvider, err = cache.CreateProvider(c.RedisConfig, log) + if err != nil { + return err + } + return nil +} + +func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider { + return c.embeddingProvider +} + +func (c *PluginConfig) GetVectorDatabaseProvider() vectorDatabase.Provider { + return c.vectorDatabaseProvider +} + +func (c *PluginConfig) GetCacheProvider() cache.Provider { + return c.redisProvider +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go new file mode 100644 index 0000000000..c7be8b8a22 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -0,0 +1,200 @@ +package embedding + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +const ( + dashScopeDomain = "dashscope.aliyuncs.com" + dashScopePort = 443 +) + +type dashScopeProviderInitializer struct { +} + +func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.DashScopeKey) == 0 { + return errors.New("DashScopeKey is required") + } + if len(config.ServiceName) == 0 { + return errors.New("ServiceName is required") + } + return nil +} + +func (d *dashScopeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &DSProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.ServiceName, + Port: dashScopePort, + Domain: dashScopeDomain, + }), + }, nil +} + +func (d *DSProvider) GetProviderType() string { + return providerTypeDashScope +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type Input struct { + Texts []string `json:"texts"` +} + +type Params struct { + TextType string `json:"text_type"` +} + +type Response struct { + RequestID string `json:"request_id"` + Output Output `json:"output"` + Usage Usage `json:"usage"` +} + +type Output struct { + Embeddings []Embedding `json:"embeddings"` +} + +type Usage struct { + TotalTokens int `json:"total_tokens"` +} + +// EmbeddingRequest 定义请求的数据结构 +type EmbeddingRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Params `json:"parameters"` +} + +// Document 定义每个文档的结构 +type Document struct { + // ID string `json:"id"` + Vector []float64 `json:"vector"` + Fields map[string]string `json:"fields"` +} + +type DSProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { + const ( + endpoint = "/api/v1/services/embeddings/text-embedding/text-embedding" + modelName = "text-embedding-v1" + contentType = "application/json" + ) + + // 构造请求数据 + data := EmbeddingRequest{ + Model: modelName, + Input: Input{ + Texts: texts, + }, + Parameters: Params{ + TextType: "query", + }, + } + + // 序列化请求体并处理错误 + requestBody, err := json.Marshal(data) + if err != nil { + log.Errorf("Failed to marshal request data: %v", err) + return "", nil, nil, err + } + + // 检查 DashScopeKey 是否为空 + if d.config.DashScopeKey == "" { + err := errors.New("DashScopeKey is empty") + log.Errorf("Failed to construct headers: %v", err) + return "", nil, nil, err + } + + // 设置请求头 + headers := [][2]string{ + {"Authorization", "Bearer " + d.config.DashScopeKey}, + {"Content-Type", contentType}, + } + + return endpoint, headers, requestBody, err +} + +// Result 定义查询结果的结构 +type Result struct { + ID string `json:"id"` + Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化 + Fields map[string]interface{} `json:"fields"` + Score float64 `json:"score"` +} + +// 返回指针防止拷贝 Embedding +func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error) { + var resp Response + err := json.Unmarshal(responseBody, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +func (d *DSProvider) GetEmbedding( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte)) error { + // 构建参数并处理错误 + Emb_url, Emb_headers, Emb_requestBody, err := d.constructParameters([]string{queryString}, log) + if err != nil { + log.Errorf("Failed to construct parameters: %v", err) + return err + } + + var resp *Response + // 发起 POST 请求 + d.client.Post(Emb_url, Emb_headers, Emb_requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != http.StatusOK { + log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) + err = errors.New("failed to get embedding") + return + } + + // 日志记录响应 + log.Infof("Get embedding response: %d, %s", statusCode, responseBody) + + // 解析响应 + resp, err = d.parseTextEmbedding(responseBody) + if err != nil { + log.Errorf("Failed to parse response: %v", err) + return + } + + // 检查是否存在嵌入结果 + if len(resp.Output.Embeddings) == 0 { + log.Errorf("No embedding found in response") + err = errors.New("no embedding found in response") + return + } + + // 回调函数 + callback(resp.Output.Embeddings[0].Embedding, statusCode, responseHeaders, responseBody) + + // proxywasm.ResumeHttpRequest() // 后续还有其他的 http 请求,所以先不能恢复 + }, d.config.DashScopeTimeout) + // if err != nil { + // log.Errorf("Failed to call client.Post: %v", err) + // return nil, err + // } + // // 这里因为 d.client.Post 是异步的,所以会出现 resp 为 nil 的情况,需要等待回调函数完成 + // return resp.Output.Embeddings[0].Embedding, nil + return nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go new file mode 100644 index 0000000000..46cffb4ff0 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -0,0 +1,87 @@ +package embedding + +import ( + "errors" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + providerTypeDashScope = "dashscope" + CacheKeyContextKey = "cacheKey" + CacheContentContextKey = "cacheContent" + PartialMessageContextKey = "partialMessage" + ToolCallsContextKey = "toolCalls" + StreamContextKey = "stream" + CacheKeyPrefix = "higressAiCache" + DefaultCacheKeyPrefix = "higressAiCache" + QueryEmbeddingKey = "queryEmbedding" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + providerTypeDashScope: &dashScopeProviderInitializer{}, + } +) + +type ProviderConfig struct { + // @Title zh-CN 文本特征提取服务提供者类型 + // @Description zh-CN 文本特征提取服务提供者类型,例如 DashScope + typ string `json:"TextEmbeddingProviderType"` + // @Title zh-CN DashScope 阿里云大模型服务名 + // @Description zh-CN 调用阿里云的大模型服务 + ServiceName string `require:"true" yaml:"DashScopeServiceName" json:"DashScopeServiceName"` + Client wrapper.HttpClient `yaml:"-"` + DashScopeKey string `require:"true" yaml:"DashScopeKey" json:"DashScopeKey"` + DashScopeTimeout uint32 `require:"false" yaml:"DashScopeTimeout" json:"DashScopeTimeout"` + QueryEmbeddingKey string `require:"false" yaml:"QueryEmbeddingKey" json:"QueryEmbeddingKey"` +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("TextEmbeddingProviderType").String() + c.ServiceName = json.Get("DashScopeServiceName").String() + c.DashScopeKey = json.Get("DashScopeKey").String() + c.DashScopeTimeout = uint32(json.Get("DashScopeTimeout").Int()) + if c.DashScopeTimeout == 0 { + c.DashScopeTimeout = 1000 + } + c.QueryEmbeddingKey = json.Get("QueryEmbeddingKey").String() +} + +func (c *ProviderConfig) Validate() error { + if len(c.DashScopeKey) == 0 { + return errors.New("DashScopeKey is required") + } + if len(c.ServiceName) == 0 { + return errors.New("DashScopeServiceName is required") + } + return nil +} + +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} + +type Provider interface { + GetProviderType() string + GetEmbedding( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte)) error +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go b/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go new file mode 100644 index 0000000000..b26d9cea8d --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go @@ -0,0 +1,27 @@ +package embedding + +// import ( +// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +// ) + +// const ( +// weaviateURL = "172.17.0.1:8081" +// ) + +// type weaviateProviderInitializer struct { +// } + +// func (d *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error { +// return nil +// } + +// func (d *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +// return &DSProvider{ +// config: config, +// client: wrapper.NewClusterClient(wrapper.DnsCluster{ +// ServiceName: config.ServiceName, +// Port: dashScopePort, +// Domain: dashScopeDomain, +// }), +// }, nil +// } diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod index c9630cfb8a..bf2a5948dd 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.mod +++ b/plugins/wasm-go/extensions/ai-cache/go.mod @@ -9,15 +9,16 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 - github.com/tidwall/sjson v1.2.5 +// github.com/weaviate/weaviate-go-client/v4 v4.15.1 ) require ( - github.com/google/uuid v1.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect github.com/magefile/mage v1.14.0 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-cache/go.sum b/plugins/wasm-go/extensions/ai-cache/go.sum index 8246b4de5e..7ada0c8b70 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.sum +++ b/plugins/wasm-go/extensions/ai-cache/go.sum @@ -1,24 +1,21 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index dc5df1a6a8..201cf91866 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -1,32 +1,59 @@ -// File generated by hgctl. Modify as required. -// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 - +// 这个文件中主要将OnHttpRequestHeaders、OnHttpRequestBody、OnHttpResponseHeaders、OnHttpResponseBody这四个函数实现 +// 其中的缓存思路调用cache.go中的逻辑,然后cache.go中的逻辑会调用textEmbeddingProvider和vectorStoreProvider中的逻辑(实例) package main import ( - "errors" - "fmt" "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" - "github.com/tidwall/resp" + // "github.com/weaviate/weaviate-go-client/v4/weaviate" ) const ( + pluginName = "ai-cache" CacheKeyContextKey = "cacheKey" CacheContentContextKey = "cacheContent" PartialMessageContextKey = "partialMessage" ToolCallsContextKey = "toolCalls" StreamContextKey = "stream" - DefaultCacheKeyPrefix = "higress-ai-cache:" + CacheKeyPrefix = "higressAiCache" + DefaultCacheKeyPrefix = "higressAiCache" + QueryEmbeddingKey = "queryEmbedding" ) +// // Create the client +// func CreateClient() { +// cfg := weaviate.Config{ +// Host: "172.17.0.1:8081", +// Scheme: "http", +// Headers: nil, +// } + +// client, err := weaviate.NewClient(cfg) +// if err != nil { +// fmt.Println(err) +// } + +// // Check the connection +// live, err := client.Misc().LiveChecker().Do(context.Background()) +// if err != nil { +// panic(err) +// } +// fmt.Printf("%v", live) + +// } + func main() { + // CreateClient() + wrapper.SetCtx( - "ai-cache", + pluginName, wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), @@ -35,143 +62,42 @@ func main() { ) } -// @Name ai-cache -// @Category protocol -// @Phase AUTHN -// @Priority 10 -// @Title zh-CN AI Cache -// @Description zh-CN 大模型结果缓存 -// @IconUrl -// @Version 0.1.0 -// -// @Contact.name johnlanni -// @Contact.url -// @Contact.email -// -// @Example -// redis: -// serviceName: my-redis.dns -// timeout: 2000 -// cacheKeyFrom: -// requestBody: "messages.@reverse.0.content" -// cacheValueFrom: -// responseBody: "choices.0.message.content" -// cacheStreamValueFrom: -// responseBody: "choices.0.delta.content" -// returnResponseTemplate: | -// {"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} -// returnStreamResponseTemplate: | -// data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} -// -// data:[DONE] -// -// @End - -type RedisInfo struct { - // @Title zh-CN redis 服务名称 - // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local - ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` - // @Title zh-CN redis 服务端口 - // @Description zh-CN 默认值为6379 - ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` - // @Title zh-CN 用户名 - // @Description zh-CN 登陆 redis 的用户名,非必填 - Username string `required:"false" yaml:"username" json:"username"` - // @Title zh-CN 密码 - // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 - Password string `required:"false" yaml:"password" json:"password"` - // @Title zh-CN 请求超时 - // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 - Timeout int `required:"false" yaml:"timeout" json:"timeout"` -} - -type KVExtractor struct { - // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` - // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` +func parseConfig(json gjson.Result, config *config.PluginConfig, log wrapper.Log) error { + // config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) + // config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + // config.RedisConfig.FromJson(json.Get("redis")) + config.FromJson(json) + if err := config.Validate(); err != nil { + return err + } + // 注意,在 parseConfig 阶段初始化 client 会出错,比如 docker compose 中的 redis 就无法使用 + if err := config.Complete(log); err != nil { + log.Errorf("complete config failed:%v", err) + return err + } + return nil } -type PluginConfig struct { - // @Title zh-CN Redis 地址信息 - // @Description zh-CN 用于存储缓存结果的 Redis 地址 - RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"` - // @Title zh-CN 缓存 key 的来源 - // @Description zh-CN 往 redis 里存时,使用的 key 的提取方式 - CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` - // @Title zh-CN 缓存 value 的来源 - // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 - CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` - // @Title zh-CN 流式响应下,缓存 value 的来源 - // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 - CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` - // @Title zh-CN 返回 HTTP 响应的模版 - // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` - // @Title zh-CN 返回流式 HTTP 响应的模版 - // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"` - // @Title zh-CN 缓存的过期时间 - // @Description zh-CN 单位是秒,默认值为0,即永不过期 - CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` - // @Title zh-CN Redis缓存Key的前缀 - // @Description zh-CN 默认值是"higress-ai-cache:" - CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` - redisClient wrapper.RedisClient `yaml:"-" json:"-"` +func TrimQuote(source string) string { + return strings.Trim(source, `"`) } -func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error { - c.RedisInfo.ServiceName = json.Get("redis.serviceName").String() - if c.RedisInfo.ServiceName == "" { - return errors.New("redis service name must not by empty") - } - c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int()) - if c.RedisInfo.ServicePort == 0 { - if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") { - // use default logic port which is 80 for static service - c.RedisInfo.ServicePort = 80 - } else { - c.RedisInfo.ServicePort = 6379 - } - } - c.RedisInfo.Username = json.Get("redis.username").String() - c.RedisInfo.Password = json.Get("redis.password").String() - c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int()) - if c.RedisInfo.Timeout == 0 { - c.RedisInfo.Timeout = 1000 - } - c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String() - if c.CacheKeyFrom.RequestBody == "" { - c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" - } - c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String() - if c.CacheValueFrom.ResponseBody == "" { - c.CacheValueFrom.ResponseBody = "choices.0.message.content" - } - c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String() - if c.CacheStreamValueFrom.ResponseBody == "" { - c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content" - } - c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String() - if c.ReturnResponseTemplate == "" { - c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - } - c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String() - if c.ReturnStreamResponseTemplate == "" { - c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" - } - c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String() - if c.CacheKeyPrefix == "" { - c.CacheKeyPrefix = DefaultCacheKeyPrefix - } - c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ - FQDN: c.RedisInfo.ServiceName, - Port: int64(c.RedisInfo.ServicePort), - }) - return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout)) -} +func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) types.Action { + // 这段代码是为了测试,在 parseConfig 阶段初始化 client 会出错,比如 docker compose 中的 redis 就无法使用 + // 但是在 onHttpRequestHeaders 中可以连接到 redis、 + // 修复需要修改 envoy + // ---------------------------------------------------------------------------- + // if err := config.Complete(log); err != nil { + // log.Errorf("complete config failed:%v", err) + // } + // activeCacheProvider := config.GetCacheProvider() + // if err := activeCacheProvider.Init("", "", 2000); err != nil { + // log.Errorf("init redis failed:%v", err) + // } + // activeCacheProvider.Set("test", "test", func(response resp.Value) {}) + // log.Warnf("redis init success") + // ---------------------------------------------------------------------------- -func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpRequestHeader("content-type") // The request does not have a body. if contentType == "" { @@ -188,11 +114,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap return types.HeaderStopIteration } -func TrimQuote(source string) string { - return strings.Trim(source, `"`) -} +func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body []byte, log wrapper.Log) types.Action { -func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { bodyJson := gjson.ParseBytes(body) // TODO: It may be necessary to support stream mode determination for different LLM providers. stream := false @@ -202,39 +125,24 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte } else if ctx.GetContext(StreamContextKey) != nil { stream = true } - key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) - if key == "" { - log.Debug("parse key from request body failed") - return types.ActionContinue - } + // key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) + key := bodyJson.Get(config.CacheKeyFrom.RequestBody).String() ctx.SetContext(CacheKeyContextKey, key) - err := config.redisClient.Get(config.CacheKeyPrefix+key, func(response resp.Value) { - if err := response.Error(); err != nil { - log.Errorf("redis get key:%s failed, err:%v", key, err) - proxywasm.ResumeHttpRequest() - return - } - if response.IsNull() { - log.Debugf("cache miss, key:%s", key) - proxywasm.ResumeHttpRequest() - return - } - log.Debugf("cache hit, key:%s", key) - ctx.SetContext(CacheKeyContextKey, nil) - if !stream { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1) - } else { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1) - } - }) - if err != nil { - log.Error("redis access failed") + log.Debugf("[onHttpRequestBody] key:%s", key) + if key == "" { + log.Debug("[onHttpRquestBody] parse key from request body failed") return types.ActionContinue } + + queryString := config.CacheKeyPrefix + key + + util.RedisSearchHandler(queryString, ctx, config, log, stream, true) + + // 需要等待异步回调完成,返回 Pause 状态,可以被 ResumeHttpRequest 恢复 return types.ActionPause } -func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { +func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseMessage string, log wrapper.Log) string { subMessages := strings.Split(sseMessage, "\n") var message string for _, msg := range subMessages { @@ -244,7 +152,7 @@ func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage } } if len(message) < 6 { - log.Errorf("invalid message:%s", message) + log.Warnf("invalid message:%s", message) return "" } // skip the prefix "data:" @@ -265,11 +173,11 @@ func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage ctx.SetContext(ToolCallsContextKey, struct{}{}) return "" } - log.Debugf("unknown message:%s", bodyJson) + log.Warnf("unknown message:%s", bodyJson) return "" } -func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { +func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { ctx.SetContext(StreamContextKey, struct{}{}) @@ -277,12 +185,15 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wra return types.ActionContinue } -func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { +func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { + // log.Infof("I am here") + log.Debugf("[onHttpResponseBody] i am here") if ctx.GetContext(ToolCallsContextKey) != nil { // we should not cache tool call result return chunk } keyI := ctx.GetContext(CacheKeyContextKey) + // log.Infof("I am here 2: %v", keyI) if keyI == nil { return chunk } @@ -340,6 +251,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []by return chunk } } else { + log.Infof("[onHttpResponseBody] stream mode") if len(chunk) > 0 { var lastMessage []byte partialMessageI := ctx.GetContext(PartialMessageContextKey) @@ -349,7 +261,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []by lastMessage = chunk } if !strings.HasSuffix(string(lastMessage), "\n\n") { - log.Warnf("invalid lastMessage:%s", lastMessage) + log.Warnf("[onHttpResponseBody] invalid lastMessage:%s", lastMessage) return chunk } // remove the last \n\n @@ -358,14 +270,17 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []by } else { tempContentI := ctx.GetContext(CacheContentContextKey) if tempContentI == nil { + log.Warnf("[onHttpResponseBody] no content in tempContentI") return chunk } value = tempContentI.(string) } } - config.redisClient.Set(config.CacheKeyPrefix+key, value, nil) - if config.CacheTTL != 0 { - config.redisClient.Expire(config.CacheKeyPrefix+key, config.CacheTTL, nil) - } + log.Infof("[onHttpResponseBody] Setting cache to redis, key:%s, value:%s", key, value) + config.GetCacheProvider().Set(embedding.CacheKeyPrefix+key, value, nil) + // TODO: 要不要加个Expire方法 + // if config.RedisConfig.RedisTimeout != 0 { + // config.GetCacheProvider().Expire(config.CacheKeyPrefix+key, config.RedisConfig.RedisTimeout, nil) + // } return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go b/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go new file mode 100644 index 0000000000..1135bbc07a --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go @@ -0,0 +1,325 @@ +// TODO: 在这里写缓存的具体逻辑, 将textEmbeddingPrvider和vectorStoreProvider作为逻辑中的一个函数调用 +package util + +import ( + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/tidwall/resp" +) + +// ===================== 以下是主要逻辑 ===================== +// 主handler函数,根据key从redis中获取value ,如果不命中,则首先调用文本向量化接口向量化query,然后调用向量搜索接口搜索最相似的出现过的key,最后再次调用redis获取结果 +// 可以把所有handler单独提取为文件,这里为了方便读者复制就和主逻辑放在一个文件中了 +// +// 1. query 进来和 redis 中存的 key 匹配 (redisSearchHandler) ,若完全一致则直接返回 (handleCacheHit) +// 2. 否则请求 text_embdding 接口将 query 转换为 query_embedding (fetchAndProcessEmbeddings) +// 3. 用 query_embedding 和向量数据库中的向量做 ANN search,返回最接近的 key ,并用阈值过滤 (performQueryAndRespond) +// 4. 若返回结果为空或大于阈值,舍去,本轮 cache 未命中, 最后将 query_embedding 存入向量数据库 (uploadQueryEmbedding) +// 5. 若小于阈值,则再次调用 redis对 most similar key 做匹配。 (redisSearchHandler) +// 7. 在 response 阶段请求 redis 新增key/LLM返回结果 + +func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) { + activeCacheProvider := config.GetCacheProvider() + log.Debugf("activeCacheProvider:%v", activeCacheProvider) + activeCacheProvider.Get(embedding.CacheKeyPrefix+key, func(response resp.Value) { + if err := response.Error(); err == nil && !response.IsNull() { + log.Warnf("cache hit, key:%s", key) + HandleCacheHit(key, response, stream, ctx, config, log) + } else { + log.Warnf("cache miss, key:%s", key) + if ifUseEmbedding { + HandleCacheMiss(key, err, response, ctx, config, log, key, stream) + } else { + proxywasm.ResumeHttpRequest() + return + } + } + }) +} + +func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { + ctx.SetContext(embedding.CacheKeyContextKey, nil) + if !stream { + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "[Test, this is cache]"+response.String())), -1) + } else { + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, "[Test, this is cache]"+response.String())), -1) + } +} + +func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { + if err != nil { + log.Warnf("redis get key:%s failed, err:%v", key, err) + } + if response.IsNull() { + log.Warnf("cache miss, key:%s", key) + } + FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream) +} + +func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { + activeEmbeddingProvider := config.GetEmbeddingProvider() + activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, + func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != 200 { + log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) + } else { + log.Debugf("Successfully fetched embeddings for key: %s", key) + QueryVectorDB(key, emb, ctx, config, log, stream) + } + }) +} + +func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { + log.Debugf("QueryVectorDB key: %s", key) + activeVectorDatabaseProvider := config.GetVectorDatabaseProvider() + log.Debugf("activeVectorDatabaseProvider: %+v", activeVectorDatabaseProvider) + activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, + func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) { + resp, err := activeVectorDatabaseProvider.ParseQueryResponse(responseBody, ctx, log) + if err != nil { + log.Errorf("Failed to query vector database, err: %v", err) + proxywasm.ResumeHttpRequest() + return + } + + if len(resp.MostSimilarData) == 0 { + log.Warnf("Failed to query vector database, no most similar key found") + activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + func(ctx wrapper.HttpContext, log wrapper.Log) { + proxywasm.ResumeHttpRequest() + }) + return + } + + log.Infof("most similar key: %s", resp.MostSimilarData) + if resp.Score < activeVectorDatabaseProvider.GetThreshold() { + log.Infof("accept most similar key: %s, score: %f", resp.MostSimilarData, resp.Score) + // ctx.SetContext(embedding.CacheKeyContextKey, nil) + RedisSearchHandler(resp.MostSimilarData, ctx, config, log, stream, false) + } else { + log.Infof("the most similar key's score is too high, key: %s, score: %f", resp.MostSimilarData, resp.Score) + activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + func(ctx wrapper.HttpContext, log wrapper.Log) { + proxywasm.ResumeHttpRequest() + }) + return + } + }, + ) + + // resp, err := activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log) + // if err != nil { + // log.Errorf("Failed to query vector database, err: %v", err) + // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log) + // proxywasm.ResumeHttpRequest() + // return + // } + + // log.Infof("most similar key: %s", resp.MostSimilarData) + // if resp.Score < activeVectorDatabaseProvider.GetThreshold() { + // log.Infof("accept most similar key: %s, score: %f", resp.MostSimilarData, resp.Score) + // // ctx.SetContext(embedding.CacheKeyContextKey, nil) + // RedisSearchHandler(resp.MostSimilarData, ctx, config, log, stream, false) + // } else { + // log.Infof("the most similar key's score is too high, key: %s, score: %f", resp.MostSimilarData, resp.Score) + // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log) + // proxywasm.ResumeHttpRequest() + // return + // } + + // activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, + // func(query_resp vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { + // if len(query_resp.Output) < 1 { // 向量库不存在查询向量 + // log.Warnf("query response is empty") + // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + // func(ctx wrapper.HttpContext, log wrapper.Log) { + // proxywasm.ResumeHttpRequest() + // }) + // return + // } + // mostSimilarKey := query_resp.Output[0].Fields["query"].(string) + // log.Infof("most similar key:%s", mostSimilarKey) + // mostSimilarScore := query_resp.Output[0].Score + // if mostSimilarScore < 2000 { // 向量库存在满足相似度的向量 + // log.Infof("accept most similar key:%s, score:%f", mostSimilarKey, mostSimilarScore) + // // ctx.SetContext(embedding.CacheKeyContextKey, nil) + // RedisSearchHandler(mostSimilarKey, ctx, config, log, stream, false) + // } else { // 向量库不存在满足相似度的向量 + // log.Infof("the most similar key's score is too high, key:%s, score:%f", mostSimilarKey, mostSimilarScore) + // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + // func(ctx wrapper.HttpContext, log wrapper.Log) { + // proxywasm.ResumeHttpRequest() + // }) + // proxywasm.ResumeHttpRequest() + // return + // } + // }, + // ) + // activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, + // func(query_resp vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { + // if len(query_resp.Output) < 1 { + // log.Warnf("query response is empty") + // // UploadQueryEmbedding(ctx, config, log, key, text_embedding) + // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + // func(ctx wrapper.HttpContext, log wrapper.Log) { + // // proxywasm.ResumeHttpRequest() + // log.Debugf("I am in the 117 line") + // }) + // return + // } + // most_similar_key := query_resp.Output[0].Fields["query"].(string) + // log.Infof("most similar key:%s", most_similar_key) + // most_similar_score := query_resp.Output[0].Score + // if most_similar_score < 0.1 { + // // ctx.SetContext(CacheKeyContextKey, nil) + // // RedisSearchHandler(most_similar_key, ctx, config, log, stream, false) + // } else { + // log.Infof("the most similar key's score is too high, key:%s, score:%f", most_similar_key, most_similar_score) + // // UploadQueryEmbedding(ctx, config, log, key, text_embedding) + // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + // func(ctx wrapper.HttpContext, log wrapper.Log) { + // proxywasm.ResumeHttpRequest() + // }) + // proxywasm.ResumeHttpRequest() + // return + // } + // }) + // ctx.SetContext(embedding.CacheKeyContextKey, text_embedding) + // ctx.SetContext(embedding.QueryEmbeddingKey, text_embedding) + // ctx.SetContext(embedding.CacheKeyContextKey, key) + // PerformQueryAndRespond(key, text_embedding, ctx, config, log, stream) +} + +// // 简单处理缓存命中的情况, 从redis中获取到value后,直接返回 +// func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { +// log.Warnf("cache hit, key:%s", key) +// ctx.SetContext(config.CacheKeyContextKey, nil) +// if !stream { +// proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1) +// } else { +// proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1) +// } +// } + +// // 处理缓存未命中的情况,调用fetchAndProcessEmbeddings函数向量化query +// func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { +// if err != nil { +// log.Warnf("redis get key:%s failed, err:%v", key, err) +// } +// if response.IsNull() { +// log.Warnf("cache miss, key:%s", key) +// } +// FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream) +// } + +// // 调用文本向量化接口向量化query, 向量化成功后调用processFetchedEmbeddings函数处理向量化结果 +// func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { +// Emb_url, Emb_requestBody, Emb_headers := ConstructTextEmbeddingParameters(&config, log, []string{queryString}) +// config.DashVectorInfo.DashScopeClient.Post( +// Emb_url, +// Emb_headers, +// Emb_requestBody, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// // log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) +// log.Infof("Successfully fetched embeddings for key: %s", key) +// if statusCode != 200 { +// log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) +// ctx.SetContext(QueryEmbeddingKey, nil) +// proxywasm.ResumeHttpRequest() +// } else { +// processFetchedEmbeddings(key, responseBody, ctx, config, log, stream) +// } +// }, +// 10000) +// } + +// // 先将向量化的结果存入上下文ctx变量,其次发起向量搜索请求 +// func ProcessFetchedEmbeddings(key string, responseBody []byte, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { +// text_embedding_raw, _ := ParseTextEmbedding(responseBody) +// text_embedding := text_embedding_raw.Output.Embeddings[0].Embedding +// // ctx.SetContext(CacheKeyContextKey, text_embedding) +// ctx.SetContext(QueryEmbeddingKey, text_embedding) +// ctx.SetContext(CacheKeyContextKey, key) +// PerformQueryAndRespond(key, text_embedding, ctx, config, log, stream) +// } + +// // 调用向量搜索接口搜索最相似的key,搜索成功后调用redisSearchHandler函数获取最相似的key的结果 +// func PerformQueryAndRespond(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { +// vector_url, vector_request, vector_headers, err := ConstructEmbeddingQueryParameters(config, text_embedding) +// if err != nil { +// log.Errorf("Failed to perform query, err: %v", err) +// proxywasm.ResumeHttpRequest() +// return +// } +// config.DashVectorInfo.DashVectorClient.Post( +// vector_url, +// vector_headers, +// vector_request, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) +// query_resp, err_query := ParseQueryResponse(responseBody) +// if err_query != nil { +// log.Errorf("Failed to parse response: %v", err) +// proxywasm.ResumeHttpRequest() +// return +// } +// if len(query_resp.Output) < 1 { +// log.Warnf("query response is empty") +// UploadQueryEmbedding(ctx, config, log, key, text_embedding) +// return +// } +// most_similar_key := query_resp.Output[0].Fields["query"].(string) +// log.Infof("most similar key:%s", most_similar_key) +// most_similar_score := query_resp.Output[0].Score +// if most_similar_score < 0.1 { +// ctx.SetContext(CacheKeyContextKey, nil) +// RedisSearchHandler(most_similar_key, ctx, config, log, stream, false) +// } else { +// log.Infof("the most similar key's score is too high, key:%s, score:%f", most_similar_key, most_similar_score) +// UploadQueryEmbedding(ctx, config, log, key, text_embedding) +// proxywasm.ResumeHttpRequest() +// return +// } +// }, +// 100000) +// } + +// // 未命中cache,则将新的query embedding和对应的key存入向量数据库 +// func UploadQueryEmbedding(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, key string, text_embedding []float64) error { +// vector_url, vector_body, err := ConsturctEmbeddingInsertParameters(&config, log, text_embedding, key) +// if err != nil { +// log.Errorf("Failed to construct embedding insert parameters: %v", err) +// proxywasm.ResumeHttpRequest() +// return nil +// } +// err = config.DashVectorInfo.DashVectorClient.Post( +// vector_url, +// [][2]string{ +// {"Content-Type", "application/json"}, +// {"dashvector-auth-token", config.DashVectorInfo.DashVectorKey}, +// }, +// vector_body, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// if statusCode != 200 { +// log.Errorf("Failed to upload query embedding: %s", responseBody) +// } else { +// log.Infof("Successfully uploaded query embedding for key: %s", key) +// } +// proxywasm.ResumeHttpRequest() +// }, +// 10000, +// ) +// if err != nil { +// log.Errorf("Failed to upload query embedding: %v", err) +// proxywasm.ResumeHttpRequest() +// return nil +// } +// return nil +// } + +// // ===================== 以上是主要逻辑 ===================== diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/chroma.go b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/chroma.go new file mode 100644 index 0000000000..2b345d51d9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/chroma.go @@ -0,0 +1,184 @@ +package vectorDatabase + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type chromaProviderInitializer struct{} + +const chromaPort = 8001 + +func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.ChromaCollectionID) == 0 { + return errors.New("ChromaCollectionID is required") + } + if len(config.ChromaServiceName) == 0 { + return errors.New("ChromaServiceName is required") + } + return nil +} + +func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &ChromaProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.ChromaServiceName, + Port: chromaPort, + Domain: config.ChromaServiceName, + }), + }, nil +} + +type ChromaProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *ChromaProvider) GetProviderType() string { + return providerTypeChroma +} + +func (d *ChromaProvider) GetThreshold() float64 { + return d.config.ChromaDistanceThreshold +} + +func (d *ChromaProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { + // 最小需要填写的参数为 collection_id, embeddings 和 ids + // 下面是一个例子 + // { + // "where": {}, 用于 metadata 过滤,可选参数 + // "where_document": {}, 用于 document 过滤,可选参数 + // "query_embeddings": [ + // [1.1, 2.3, 3.2] + // ], + // "n_results": 5, + // "include": [ + // "metadatas", + // "distances" + // ] + // } + requestBody, err := json.Marshal(ChromaQueryRequest{ + QueryEmbeddings: []ChromaEmbedding{emb}, + NResults: d.config.ChromaNResult, + Include: []string{"distances"}, + }) + + if err != nil { + log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err) + return + } + + d.client.Post( + fmt.Sprintf("/api/v1/collections/%s/query", d.config.ChromaCollectionID), + [][2]string{ + {"Content-Type", "application/json"}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("Query embedding response: %d, %s", statusCode, responseBody) + callback(responseBody, ctx, log) + }, + d.config.ChromaTimeout, + ) +} + +func (d *ChromaProvider) UploadEmbedding( + query_emb []float64, + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log)) { + // 最小需要填写的参数为 collection_id, embeddings 和 ids + // 下面是一个例子 + // { + // "embeddings": [ + // [1.1, 2.3, 3.2] + // ], + // "ids": [ + // "你吃了吗?" + // ] + // } + requestBody, err := json.Marshal(ChromaInsertRequest{ + Embeddings: []ChromaEmbedding{query_emb}, + IDs: []string{queryString}, // queryString 指的是用户查询的问题 + }) + + if err != nil { + log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err) + return + } + + d.client.Post( + fmt.Sprintf("/api/v1/collections/%s/add", d.config.ChromaCollectionID), + [][2]string{ + {"Content-Type", "application/json"}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log) + }, + d.config.ChromaTimeout, + ) +} + +// ChromaEmbedding represents the embedding vector for a data point. +type ChromaEmbedding []float64 + +// ChromaMetadataMap is a map from key to value for metadata. +type ChromaMetadataMap map[string]string + +// Dataset represents the entire dataset containing multiple data points. +type ChromaInsertRequest struct { + Embeddings []ChromaEmbedding `json:"embeddings"` + Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional metadata map array + Documents []string `json:"documents,omitempty"` // Optional document array + IDs []string `json:"ids"` +} + +// ChromaQueryRequest represents the query request structure. +type ChromaQueryRequest struct { + Where map[string]string `json:"where,omitempty"` // Optional where filter + WhereDocument map[string]string `json:"where_document,omitempty"` // Optional where_document filter + QueryEmbeddings []ChromaEmbedding `json:"query_embeddings"` + NResults int `json:"n_results"` + Include []string `json:"include"` +} + +// ChromaQueryResponse represents the search result structure. +type ChromaQueryResponse struct { + Ids [][]string `json:"ids"` // 每一个 embedding 相似的 key 可能会有多个,然后会有多个 embedding,所以是一个二维数组 + Distances [][]float64 `json:"distances"` // 与 Ids 一一对应 + Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional, can be null + Embeddings []ChromaEmbedding `json:"embeddings,omitempty"` // Optional, can be null + Documents []string `json:"documents,omitempty"` // Optional, can be null + Uris []string `json:"uris,omitempty"` // Optional, can be null + Data []interface{} `json:"data,omitempty"` // Optional, can be null + Included []string `json:"included"` +} + +func (d *ChromaProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { + var queryResp ChromaQueryResponse + err := json.Unmarshal(responseBody, &queryResp) + if err != nil { + return QueryEmbeddingResult{}, err + } + log.Infof("[Chroma] queryResp: %+v", queryResp) + log.Infof("[Chroma] queryResp Ids len: %d", len(queryResp.Ids)) + if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 { + return QueryEmbeddingResult{}, nil + } + return QueryEmbeddingResult{ + MostSimilarData: queryResp.Ids[0][0], + Score: queryResp.Distances[0][0], + }, nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go new file mode 100644 index 0000000000..2200656c83 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go @@ -0,0 +1,210 @@ +package vectorDatabase + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +const ( + dashVectorPort = 443 + threshold = 2000 +) + +type dashVectorProviderInitializer struct { +} + +func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.DashVectorKey) == 0 { + return errors.New("DashVectorKey is required") + } + if len(config.DashVectorAuthApiEnd) == 0 { + return errors.New("DashVectorEnd is required") + } + if len(config.DashVectorCollection) == 0 { + return errors.New("DashVectorCollection is required") + } + if len(config.DashVectorServiceName) == 0 { + return errors.New("DashVectorServiceName is required") + } + return nil +} + +func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &DvProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.DashVectorServiceName, + Port: dashVectorPort, + Domain: config.DashVectorAuthApiEnd, + }), + }, nil +} + +type DvProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (d *DvProvider) GetProviderType() string { + return providerTypeDashVector +} + +// type embeddingRequest struct { +// Model string `json:"model"` +// Input input `json:"input"` +// Parameters params `json:"parameters"` +// } + +// type params struct { +// TextType string `json:"text_type"` +// } + +// type input struct { +// Texts []string `json:"texts"` +// } + +// queryResponse 定义查询响应的结构 +type queryResponse struct { + Code int `json:"code"` + RequestID string `json:"request_id"` + Message string `json:"message"` + Output []result `json:"output"` +} + +// queryRequest 定义查询请求的结构 +type queryRequest struct { + Vector []float64 `json:"vector"` + TopK int `json:"topk"` + IncludeVector bool `json:"include_vector"` +} + +// result 定义查询结果的结构 +type result struct { + ID string `json:"id"` + Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化 + Fields map[string]interface{} `json:"fields"` + Score float64 `json:"score"` +} + +func (d *DvProvider) constructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { + url := fmt.Sprintf("/v1/collections/%s/query", d.config.DashVectorCollection) + + requestData := queryRequest{ + Vector: vector, + TopK: d.config.DashVectorTopK, + IncludeVector: false, + } + + requestBody, err := json.Marshal(requestData) + if err != nil { + return "", nil, nil, err + } + + header := [][2]string{ + {"Content-Type", "application/json"}, + {"dashvector-auth-token", d.config.DashVectorKey}, + } + + return url, requestBody, header, nil +} + +func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, error) { + var queryResp queryResponse + err := json.Unmarshal(responseBody, &queryResp) + if err != nil { + return queryResponse{}, err + } + return queryResp, nil +} + +func (d *DvProvider) GetThreshold() float64 { + return threshold +} +func (d *DvProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { + // 构造请求参数 + url, body, headers, err := d.constructEmbeddingQueryParameters(emb) + if err != nil { + log.Infof("Failed to construct embedding query parameters: %v", err) + } + + d.client.Post(url, headers, body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != http.StatusOK { + log.Infof("Failed to query embedding: %d", statusCode) + return + } + log.Infof("Query embedding response: %d, %s", statusCode, responseBody) + callback(responseBody, ctx, log) + }, + d.config.DashVectorTimeout) + if err != nil { + log.Infof("Failed to query embedding: %v", err) + } +} + +func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { + resp, err := d.parseQueryResponse(responseBody) + if err != nil { + return QueryEmbeddingResult{}, err + } + if len(resp.Output) == 0 { + return QueryEmbeddingResult{}, nil + } + return QueryEmbeddingResult{ + MostSimilarData: resp.Output[0].Fields["query"].(string), + Score: resp.Output[0].Score, + }, nil +} + +type document struct { + Vector []float64 `json:"vector"` + Fields map[string]string `json:"fields"` +} + +type insertRequest struct { + Docs []document `json:"docs"` +} + +func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { + url := "/v1/collections/" + d.config.DashVectorCollection + "/docs" + + doc := document{ + Vector: emb, + Fields: map[string]string{ + "query": query_string, + }, + } + + requestBody, err := json.Marshal(insertRequest{Docs: []document{doc}}) + if err != nil { + return "", nil, nil, err + } + + header := [][2]string{ + {"Content-Type", "application/json"}, + {"dashvector-auth-token", d.config.DashVectorKey}, + } + + return url, requestBody, header, err +} + +func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { + url, body, headers, _ := d.constructEmbeddingUploadParameters(query_emb, queryString) + d.client.Post( + url, + headers, + body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log) + }, + d.config.DashVectorTimeout) +} diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go new file mode 100644 index 0000000000..0177efe76a --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go @@ -0,0 +1,141 @@ +package vectorDatabase + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + providerTypeDashVector = "dashvector" + providerTypeChroma = "chroma" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + providerTypeDashVector: &dashVectorProviderInitializer{}, + providerTypeChroma: &chromaProviderInitializer{}, + } +) + +type Provider interface { + GetProviderType() string + QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) + UploadEmbedding( + query_emb []float64, + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log)) + GetThreshold() float64 + ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) +} + +// 定义通用的查询结果的结构体 +type QueryEmbeddingResult struct { + MostSimilarData string // 相似的文本 + Score float64 // 文本的向量相似度或距离等度量 +} + +type ProviderConfig struct { + // @Title zh-CN 向量存储服务提供者类型 + // @Description zh-CN 向量存储服务提供者类型,例如 DashVector、Milvus + typ string `json:"vectorStoreProviderType"` + // @Title zh-CN DashVector 阿里云向量搜索引擎 + // @Description zh-CN 调用阿里云的向量搜索引擎 + DashVectorServiceName string `require:"true" yaml:"DashVectorServiceName" json:"DashVectorServiceName"` + // @Title zh-CN DashVector Key + // @Description zh-CN 阿里云向量搜索引擎的 key + DashVectorKey string `require:"true" yaml:"DashVectorKey" json:"DashVectorKey"` + // @Title zh-CN DashVector AuthApiEnd + // @Description zh-CN 阿里云向量搜索引擎的 AuthApiEnd + DashVectorAuthApiEnd string `require:"true" yaml:"DashVectorEnd" json:"DashVectorEnd"` + // @Title zh-CN DashVector Collection + // @Description zh-CN 指定使用阿里云搜索引擎中的哪个向量集合 + DashVectorCollection string `require:"true" yaml:"DashVectorCollection" json:"DashVectorCollection"` + // @Title zh-CN DashVector Client + // @Description zh-CN 阿里云向量搜索引擎的 Client + DashVectorTopK int `require:"false" yaml:"DashVectorTopK" json:"DashVectorTopK"` + DashVectorTimeout uint32 `require:"false" yaml:"DashVectorTimeout" json:"DashVectorTimeout"` + DashVectorClient wrapper.HttpClient `yaml:"-" json:"-"` + + // @Title zh-CN Chroma 的上游服务名称 + // @Description zh-CN Chroma 服务所对应的网关内上游服务名称 + ChromaServiceName string `require:"true" yaml:"ChromaServiceName" json:"ChromaServiceName"` + // @Title zh-CN Chroma Collection ID + // @Description zh-CN Chroma Collection 的 ID + ChromaCollectionID string `require:"false" yaml:"ChromaCollectionID" json:"ChromaCollectionID"` + // @Title zh-CN Chroma 距离阈值 + // @Description zh-CN Chroma 距离阈值,默认为 2000 + ChromaDistanceThreshold float64 `require:"false" yaml:"ChromaDistanceThreshold" json:"ChromaDistanceThreshold"` + // @Title zh-CN Chroma 搜索返回结果数量 + // @Description zh-CN Chroma 搜索返回结果数量,默认为 1 + ChromaNResult int `require:"false" yaml:"ChromaNResult" json:"ChromaNResult"` + // @Title zh-CN Chroma 超时设置 + // @Description zh-CN Chroma 超时设置,默认为 10 秒 + ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"` +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("vectorStoreProviderType").String() + // DashVector + c.DashVectorServiceName = json.Get("DashVectorServiceName").String() + c.DashVectorKey = json.Get("DashVectorKey").String() + c.DashVectorAuthApiEnd = json.Get("DashVectorEnd").String() + c.DashVectorCollection = json.Get("DashVectorCollection").String() + c.DashVectorTopK = int(json.Get("DashVectorTopK").Int()) + if c.DashVectorTopK == 0 { + c.DashVectorTopK = 1 + } + c.DashVectorTimeout = uint32(json.Get("DashVectorTimeout").Int()) + if c.DashVectorTimeout == 0 { + c.DashVectorTimeout = 10000 + } + // Chroma + c.ChromaCollectionID = json.Get("ChromaCollectionID").String() + c.ChromaServiceName = json.Get("ChromaServiceName").String() + c.ChromaDistanceThreshold = json.Get("ChromaDistanceThreshold").Float() + if c.ChromaDistanceThreshold == 0 { + c.ChromaDistanceThreshold = 2000 + } + c.ChromaNResult = int(json.Get("ChromaNResult").Int()) + if c.ChromaNResult == 0 { + c.ChromaNResult = 1 + } + c.ChromaTimeout = uint32(json.Get("ChromaTimeout").Int()) + if c.ChromaTimeout == 0 { + c.ChromaTimeout = 10000 + } +} + +func (c *ProviderConfig) Validate() error { + if c.typ == "" { + return errors.New("[ai-cache] missing type in provider config") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err + } + return nil +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/weaviate.go new file mode 100644 index 0000000000..8bed5d098b --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vectorDatabase/weaviate.go @@ -0,0 +1,172 @@ +// package vectorDatabase + +// import ( +// "encoding/json" +// "errors" +// "fmt" +// "net/http" + +// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +// ) + +// const ( +// dashVectorPort = 443 +// ) + +// type dashVectorProviderInitializer struct { +// } + +// func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { +// if len(config.DashVectorKey) == 0 { +// return errors.New("DashVectorKey is required") +// } +// if len(config.DashVectorAuthApiEnd) == 0 { +// return errors.New("DashVectorEnd is required") +// } +// if len(config.DashVectorCollection) == 0 { +// return errors.New("DashVectorCollection is required") +// } +// if len(config.DashVectorServiceName) == 0 { +// return errors.New("DashVectorServiceName is required") +// } +// return nil +// } + +// func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +// return &DvProvider{ +// config: config, +// client: wrapper.NewClusterClient(wrapper.DnsCluster{ +// ServiceName: config.DashVectorServiceName, +// Port: dashVectorPort, +// Domain: config.DashVectorAuthApiEnd, +// }), +// }, nil +// } + +// type DvProvider struct { +// config ProviderConfig +// client wrapper.HttpClient +// } + +// func (d *DvProvider) GetProviderType() string { +// return providerTypeDashVector +// } + +// type EmbeddingRequest struct { +// Model string `json:"model"` +// Input Input `json:"input"` +// Parameters Params `json:"parameters"` +// } + +// type Params struct { +// TextType string `json:"text_type"` +// } + +// type Input struct { +// Texts []string `json:"texts"` +// } + +// func (d *DvProvider) ConstructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { +// url := fmt.Sprintf("/v1/collections/%s/query", d.config.DashVectorCollection) + +// requestData := QueryRequest{ +// Vector: vector, +// TopK: d.config.DashVectorTopK, +// IncludeVector: false, +// } + +// requestBody, err := json.Marshal(requestData) +// if err != nil { +// return "", nil, nil, err +// } + +// header := [][2]string{ +// {"Content-Type", "application/json"}, +// {"dashvector-auth-token", d.config.DashVectorKey}, +// } + +// return url, requestBody, header, nil +// } + +// func (d *DvProvider) ParseQueryResponse(responseBody []byte) (QueryResponse, error) { +// var queryResp QueryResponse +// err := json.Unmarshal(responseBody, &queryResp) +// if err != nil { +// return QueryResponse{}, err +// } +// return queryResp, nil +// } + +// func (d *DvProvider) QueryEmbedding( +// queryEmb []float64, +// ctx wrapper.HttpContext, +// log wrapper.Log, +// callback func(query_resp QueryResponse, ctx wrapper.HttpContext, log wrapper.Log)) { + +// // 构造请求参数 +// url, body, headers, err := d.ConstructEmbeddingQueryParameters(queryEmb) +// if err != nil { +// log.Infof("Failed to construct embedding query parameters: %v", err) +// } + +// err = d.client.Post(url, headers, body, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// log.Infof("Query embedding response: %d, %s", statusCode, responseBody) +// query_resp, err_query := d.ParseQueryResponse(responseBody) +// if err_query != nil { +// log.Infof("Failed to parse response: %v", err_query) +// } +// callback(query_resp, ctx, log) +// }, +// d.config.DashVectorTimeout) +// if err != nil { +// log.Infof("Failed to query embedding: %v", err) +// } + +// } + +// type Document struct { +// Vector []float64 `json:"vector"` +// Fields map[string]string `json:"fields"` +// } + +// type InsertRequest struct { +// Docs []Document `json:"docs"` +// } + +// func (d *DvProvider) ConstructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { +// url := "/v1/collections/" + d.config.DashVectorCollection + "/docs" + +// doc := Document{ +// Vector: emb, +// Fields: map[string]string{ +// "query": query_string, +// }, +// } + +// requestBody, err := json.Marshal(InsertRequest{Docs: []Document{doc}}) +// if err != nil { +// return "", nil, nil, err +// } + +// header := [][2]string{ +// {"Content-Type", "application/json"}, +// {"dashvector-auth-token", d.config.DashVectorKey}, +// } + +// return url, requestBody, header, err +// } + +// func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { +// url, body, headers, _ := d.ConstructEmbeddingUploadParameters(query_emb, queryString) +// d.client.Post( +// url, +// headers, +// body, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) +// callback(ctx, log) +// }, +// d.config.DashVectorTimeout) +// } +package vectorDatabase diff --git a/plugins/wasm-go/extensions/ai-proxy/Makefile b/plugins/wasm-go/extensions/ai-proxy/Makefile new file mode 100644 index 0000000000..e5c7fa8de9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/Makefile @@ -0,0 +1,4 @@ +.DEFAULT: +build: + tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' ./main.go + mv ai-proxy.wasm ../../../../docker-compose-test/ \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-proxy/go.mod b/plugins/wasm-go/extensions/ai-proxy/go.mod index e2c671d989..a5457b90f8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.mod +++ b/plugins/wasm-go/extensions/ai-proxy/go.mod @@ -10,12 +10,12 @@ require ( github.com/alibaba/higress/plugins/wasm-go v0.0.0 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/uuid v1.3.0 // indirect + github.com/google/uuid v1.3.0 github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect github.com/magefile/mage v1.14.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/plugins/wasm-go/extensions/ai-proxy/go.sum b/plugins/wasm-go/extensions/ai-proxy/go.sum index e5b8b79175..b2d63b5f4b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.sum +++ b/plugins/wasm-go/extensions/ai-proxy/go.sum @@ -13,8 +13,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index f09e0d4afd..20e227ff2d 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -66,7 +66,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf apiName := getOpenAiApiName(path.Path) if apiName == "" { log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) - _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) + // _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) + log.Debugf("[onHttpRequestHeader] no send response") return types.ActionContinue } ctx.SetContext(ctxKeyApiName, apiName) diff --git a/plugins/wasm-go/extensions/request-block/Dockerfile b/plugins/wasm-go/extensions/request-block/Dockerfile new file mode 100644 index 0000000000..9b084e0596 --- /dev/null +++ b/plugins/wasm-go/extensions/request-block/Dockerfile @@ -0,0 +1,2 @@ +FROM scratch +COPY main.wasm plugin.wasm \ No newline at end of file diff --git a/plugins/wasm-go/extensions/request-block/Makefile b/plugins/wasm-go/extensions/request-block/Makefile new file mode 100644 index 0000000000..1210d6ec34 --- /dev/null +++ b/plugins/wasm-go/extensions/request-block/Makefile @@ -0,0 +1,4 @@ +.DEFAULT: +build: + tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer' ./main.go + mv main.wasm ../../../../docker-compose-test/ \ No newline at end of file diff --git a/plugins/wasm-go/extensions/request-block/main.go b/plugins/wasm-go/extensions/request-block/main.go index 224d4b26d6..2a43b4df72 100644 --- a/plugins/wasm-go/extensions/request-block/main.go +++ b/plugins/wasm-go/extensions/request-block/main.go @@ -177,7 +177,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config RequestBlockConfig, lo } func onHttpRequestBody(ctx wrapper.HttpContext, config RequestBlockConfig, body []byte, log wrapper.Log) types.Action { + log.Infof("My request-block body: %s\n", string(body)) bodyStr := string(body) + if !config.caseSensitive { bodyStr = strings.ToLower(bodyStr) } diff --git a/plugins/wasm-go/go.mod b/plugins/wasm-go/go.mod index 999721f3f6..6373ff646e 100644 --- a/plugins/wasm-go/go.mod +++ b/plugins/wasm-go/go.mod @@ -7,7 +7,7 @@ require ( github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 ) diff --git a/plugins/wasm-go/go.sum b/plugins/wasm-go/go.sum index e726b100a5..f396d4d7d9 100644 --- a/plugins/wasm-go/go.sum +++ b/plugins/wasm-go/go.sum @@ -4,6 +4,12 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43 h1:dCw7F/9ciw4NZN7w68wQRaygZ2zGOWMTIEoRvP1tlWs= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= @@ -14,6 +20,8 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= From 27b2f7183b5b2d4e68fa115958b6a2b3773fedcc Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 24 Aug 2024 00:29:22 +0000 Subject: [PATCH 05/71] alter some errors --- .../extensions/ai-cache/cache/cache.go | 17 +- .../extensions/ai-cache/config/config.go | 28 +- plugins/wasm-go/extensions/ai-cache/core.go | 102 +++++ .../ai-cache/embedding/dashscope.go | 93 +---- .../extensions/ai-cache/embedding/provider.go | 49 +-- plugins/wasm-go/extensions/ai-cache/go.mod | 2 +- plugins/wasm-go/extensions/ai-cache/main.go | 33 +- .../extensions/ai-cache/util/cachelogic.go | 358 ------------------ .../{vectorDatabase => vector}/chroma.go | 2 +- .../{vectorDatabase => vector}/dashvector.go | 2 +- .../{vectorDatabase => vector}/provider.go | 2 +- .../{vectorDatabase => vector}/weaviate.go | 3 +- 12 files changed, 165 insertions(+), 526 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-cache/core.go delete mode 100644 plugins/wasm-go/extensions/ai-cache/util/cachelogic.go rename plugins/wasm-go/extensions/ai-cache/{vectorDatabase => vector}/chroma.go (99%) rename plugins/wasm-go/extensions/ai-cache/{vectorDatabase => vector}/dashvector.go (99%) rename plugins/wasm-go/extensions/ai-cache/{vectorDatabase => vector}/provider.go (99%) rename plugins/wasm-go/extensions/ai-cache/{vectorDatabase => vector}/weaviate.go (98%) diff --git a/plugins/wasm-go/extensions/ai-cache/cache/cache.go b/plugins/wasm-go/extensions/ai-cache/cache/cache.go index 626c0f0397..f07c42cf64 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/cache.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/cache.go @@ -24,30 +24,19 @@ type RedisConfig struct { // @Title zh-CN 请求超时 // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 RedisTimeout uint32 `required:"false" yaml:"timeout" json:"timeout"` + + RedisHost string `required:"false" yaml:"host" json:"host"` } func CreateProvider(cf RedisConfig, log wrapper.Log) (Provider, error) { -<<<<<<< HEAD -======= log.Warnf("redis config: %v", cf) ->>>>>>> origin/feat/chroma rp := redisProvider{ config: cf, client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ FQDN: cf.RedisServiceName, - Host: "redis", + Host: cf.RedisHost, Port: int64(cf.RedisServicePort)}), - // client: wrapper.NewRedisClusterClient(wrapper.DnsCluster{ - // ServiceName: cf.RedisServiceName, - // Port: int64(cf.RedisServicePort)}), } - // FQDN := wrapper.FQDNCluster{ - // FQDN: cf.RedisServiceName, - // Host: "redis", - // Port: int64(cf.RedisServicePort)} - // log.Debugf("test:%s", FQDN.ClusterName()) - // log.Debugf("test:%d", cf.RedisServicePort) - // log.Debugf("test:%s", proxywasm.RedisInit(FQDN.ClusterName(), "", "", 100)) err := rp.Init(cf.RedisUsername, cf.RedisPassword, cf.RedisTimeout) return &rp, err } diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index d31f6e2b70..7e1948306c 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -3,7 +3,7 @@ package config import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vectorDatabase" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" ) @@ -16,11 +16,11 @@ type KVExtractor struct { } type PluginConfig struct { - EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"` - VectorDatabaseProviderConfig vectorDatabase.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` - CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` - CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` - CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` + EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"` + vectorProviderConfig vector.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` + CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` + CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` + CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` // @Title zh-CN 返回 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` @@ -39,14 +39,14 @@ type PluginConfig struct { RedisConfig cache.RedisConfig `required:"true" yaml:"redisConfig" json:"redisConfig"` // 现在只支持RedisClient作为cacheClient - redisProvider cache.Provider `yaml:"-"` - embeddingProvider embedding.Provider `yaml:"-"` - vectorDatabaseProvider vectorDatabase.Provider `yaml:"-"` + redisProvider cache.Provider `yaml:"-"` + embeddingProvider embedding.Provider `yaml:"-"` + vectorProvider vector.Provider `yaml:"-"` } func (c *PluginConfig) FromJson(json gjson.Result) { c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) - c.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + c.vectorProviderConfig.FromJson(json.Get("vectorProvider")) c.RedisConfig.FromJson(json.Get("redis")) if c.CacheKeyFrom.RequestBody == "" { c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" @@ -84,7 +84,7 @@ func (c *PluginConfig) Validate() error { if err := c.EmbeddingProviderConfig.Validate(); err != nil { return err } - if err := c.VectorDatabaseProviderConfig.Validate(); err != nil { + if err := c.vectorProviderConfig.Validate(); err != nil { return err } return nil @@ -96,7 +96,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { if err != nil { return err } - c.vectorDatabaseProvider, err = vectorDatabase.CreateProvider(c.VectorDatabaseProviderConfig) + c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) if err != nil { return err } @@ -111,8 +111,8 @@ func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider { return c.embeddingProvider } -func (c *PluginConfig) GetVectorDatabaseProvider() vectorDatabase.Provider { - return c.vectorDatabaseProvider +func (c *PluginConfig) GetvectorProvider() vector.Provider { + return c.vectorProvider } func (c *PluginConfig) GetCacheProvider() cache.Provider { diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go new file mode 100644 index 0000000000..accd11cad6 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -0,0 +1,102 @@ +package main + +import ( + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/tidwall/resp" +) + +func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) { + activeCacheProvider := config.GetCacheProvider() + log.Debugf("activeCacheProvider:%v", activeCacheProvider) + activeCacheProvider.Get(embedding.CacheKeyPrefix+key, func(response resp.Value) { + if err := response.Error(); err == nil && !response.IsNull() { + log.Warnf("cache hit, key:%s", key) + HandleCacheHit(key, response, stream, ctx, config, log) + } else { + log.Warnf("cache miss, key:%s", key) + if ifUseEmbedding { + HandleCacheMiss(key, err, response, ctx, config, log, key, stream) + } else { + proxywasm.ResumeHttpRequest() + return + } + } + }) +} + +func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { + ctx.SetContext(embedding.CacheKeyContextKey, nil) + if !stream { + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "[Test, this is cache]"+response.String())), -1) + } else { + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, "[Test, this is cache]"+response.String())), -1) + } +} + +func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { + if err != nil { + log.Warnf("redis get key:%s failed, err:%v", key, err) + } + if response.IsNull() { + log.Warnf("cache miss, key:%s", key) + } + FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream) +} + +func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { + activeEmbeddingProvider := config.GetEmbeddingProvider() + activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, + func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != 200 { + log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) + } else { + log.Debugf("Successfully fetched embeddings for key: %s", key) + QueryVectorDB(key, emb, ctx, config, log, stream) + } + }) +} + +func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { + log.Debugf("QueryVectorDB key: %s", key) + activeVectorDatabaseProvider := config.GetvectorProvider() + log.Debugf("activeVectorDatabaseProvider: %+v", activeVectorDatabaseProvider) + activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, + func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) { + resp, err := activeVectorDatabaseProvider.ParseQueryResponse(responseBody, ctx, log) + if err != nil { + log.Errorf("Failed to query vector database, err: %v", err) + proxywasm.ResumeHttpRequest() + return + } + + if len(resp.MostSimilarData) == 0 { + log.Warnf("Failed to query vector database, no most similar key found") + activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + func(ctx wrapper.HttpContext, log wrapper.Log) { + proxywasm.ResumeHttpRequest() + }) + return + } + + log.Infof("most similar key: %s", resp.MostSimilarData) + if resp.Score < activeVectorDatabaseProvider.GetThreshold() { + log.Infof("accept most similar key: %s, score: %f", resp.MostSimilarData, resp.Score) + // ctx.SetContext(embedding.CacheKeyContextKey, nil) + RedisSearchHandler(resp.MostSimilarData, ctx, config, log, stream, false) + } else { + log.Infof("the most similar key's score is too high, key: %s, score: %f", resp.MostSimilarData, resp.Score) + activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + func(ctx wrapper.HttpContext, log wrapper.Log) { + proxywasm.ResumeHttpRequest() + }) + return + } + }, + ) +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index f843bb0e7a..90974b0662 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -6,37 +6,38 @@ import ( "net/http" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -<<<<<<< HEAD - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" -======= ->>>>>>> origin/feat/chroma ) const ( - dashScopeDomain = "dashscope.aliyuncs.com" - dashScopePort = 443 + domain = "dashscope.aliyuncs.com" + port = 443 + modelName = "text-embedding-v1" + endpoint = "/api/v1/services/embeddings/text-embedding/text-embedding" ) type dashScopeProviderInitializer struct { } func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error { - if len(config.DashScopeKey) == 0 { + if config.apiKey == "" { return errors.New("DashScopeKey is required") } - if len(config.ServiceName) == 0 { - return errors.New("ServiceName is required") - } return nil } -func (d *dashScopeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { + if c.servicePort == 0 { + c.servicePort = port + } + if c.serviceHost == "" { + c.serviceHost = domain + } return &DSProvider{ - config: config, + config: c, client: wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: config.ServiceName, - Port: dashScopePort, - Domain: dashScopeDomain, + ServiceName: c.serviceName, + Port: c.servicePort, + Domain: c.serviceHost, }), }, nil } @@ -72,16 +73,13 @@ type Usage struct { TotalTokens int `json:"total_tokens"` } -// EmbeddingRequest 定义请求的数据结构 type EmbeddingRequest struct { Model string `json:"model"` Input Input `json:"input"` Parameters Params `json:"parameters"` } -// Document 定义每个文档的结构 type Document struct { - // ID string `json:"id"` Vector []float64 `json:"vector"` Fields map[string]string `json:"fields"` } @@ -92,13 +90,7 @@ type DSProvider struct { } func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { - const ( - endpoint = "/api/v1/services/embeddings/text-embedding/text-embedding" - modelName = "text-embedding-v1" - contentType = "application/json" - ) - // 构造请求数据 data := EmbeddingRequest{ Model: modelName, Input: Input{ @@ -117,7 +109,7 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin } // 检查 DashScopeKey 是否为空 - if d.config.DashScopeKey == "" { + if d.config.apiKey == "" { err := errors.New("DashScopeKey is empty") log.Errorf("Failed to construct headers: %v", err) return "", nil, nil, err @@ -125,8 +117,8 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin // 设置请求头 headers := [][2]string{ - {"Authorization", "Bearer " + d.config.DashScopeKey}, - {"Content-Type", contentType}, + {"Authorization", "Bearer " + d.config.apiKey}, + {"Content-Type", "application/json"}, } return endpoint, headers, requestBody, err @@ -140,10 +132,7 @@ type Result struct { Score float64 `json:"score"` } -<<<<<<< HEAD -======= // 返回指针防止拷贝 Embedding ->>>>>>> origin/feat/chroma func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error) { var resp Response err := json.Unmarshal(responseBody, &resp) @@ -158,25 +147,13 @@ func (d *DSProvider) GetEmbedding( ctx wrapper.HttpContext, log wrapper.Log, callback func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte)) error { -<<<<<<< HEAD - -======= ->>>>>>> origin/feat/chroma - // 构建参数并处理错误 Emb_url, Emb_headers, Emb_requestBody, err := d.constructParameters([]string{queryString}, log) if err != nil { log.Errorf("Failed to construct parameters: %v", err) return err } -<<<<<<< HEAD - // 发起 POST 请求 - d.client.Post(Emb_url, Emb_headers, Emb_requestBody, - func(statusCode int, responseHeaders http.Header, responseBody []byte) { - defer proxywasm.ResumeHttpRequest() // 确保 HTTP 请求被恢复 -======= var resp *Response - // 发起 POST 请求 d.client.Post(Emb_url, Emb_headers, Emb_requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode != http.StatusOK { @@ -184,53 +161,23 @@ func (d *DSProvider) GetEmbedding( err = errors.New("failed to get embedding") return } ->>>>>>> origin/feat/chroma - // 日志记录响应 log.Infof("Get embedding response: %d, %s", statusCode, responseBody) - // 解析响应 -<<<<<<< HEAD - resp, err := d.parseTextEmbedding(responseBody) - if err != nil { - log.Errorf("Failed to parse response: %v", err) - callback(nil, statusCode, responseHeaders, responseBody) -======= resp, err = d.parseTextEmbedding(responseBody) if err != nil { log.Errorf("Failed to parse response: %v", err) ->>>>>>> origin/feat/chroma return } - // 检查是否存在嵌入结果 if len(resp.Output.Embeddings) == 0 { log.Errorf("No embedding found in response") -<<<<<<< HEAD - callback(nil, statusCode, responseHeaders, responseBody) - return - } - - // 调用回调函数 - callback(resp.Output.Embeddings[0].Embedding, statusCode, responseHeaders, responseBody) - }, d.config.DashScopeTimeout) - -======= err = errors.New("no embedding found in response") return } - // 回调函数 callback(resp.Output.Embeddings[0].Embedding, statusCode, responseHeaders, responseBody) - // proxywasm.ResumeHttpRequest() // 后续还有其他的 http 请求,所以先不能恢复 - }, d.config.DashScopeTimeout) - // if err != nil { - // log.Errorf("Failed to call client.Post: %v", err) - // return nil, err - // } - // // 这里因为 d.client.Post 是异步的,所以会出现 resp 为 nil 的情况,需要等待回调函数完成 - // return resp.Output.Embeddings[0].Embedding, nil ->>>>>>> origin/feat/chroma + }, d.config.timeout) return nil } diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index 7ab2b01b73..1b307c8ad9 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -17,7 +17,7 @@ const ( StreamContextKey = "stream" CacheKeyPrefix = "higressAiCache" DefaultCacheKeyPrefix = "higressAiCache" - QueryEmbeddingKey = "queryEmbedding" + queryEmbeddingKey = "queryEmbedding" ) type providerInitializer interface { @@ -34,41 +34,32 @@ var ( type ProviderConfig struct { // @Title zh-CN 文本特征提取服务提供者类型 // @Description zh-CN 文本特征提取服务提供者类型,例如 DashScope - typ string `json:"TextEmbeddingProviderType"` + typ string `json:"type"` // @Title zh-CN DashScope 阿里云大模型服务名 // @Description zh-CN 调用阿里云的大模型服务 -<<<<<<< HEAD - ServiceName string `require:"true" yaml:"DashScopeServiceName" jaon:"DashScopeServiceName"` - Client wrapper.HttpClient `yaml:"-"` - DashScopeKey string `require:"true" yaml:"DashScopeKey" jaon:"DashScopeKey"` - DashScopeTimeout uint32 `require:"false" yaml:"DashScopeTimeout" jaon:"DashScopeTimeout"` - QueryEmbeddingKey string `require:"false" yaml:"QueryEmbeddingKey" jaon:"QueryEmbeddingKey"` -======= - ServiceName string `require:"true" yaml:"DashScopeServiceName" json:"DashScopeServiceName"` - Client wrapper.HttpClient `yaml:"-"` - DashScopeKey string `require:"true" yaml:"DashScopeKey" json:"DashScopeKey"` - DashScopeTimeout uint32 `require:"false" yaml:"DashScopeTimeout" json:"DashScopeTimeout"` - QueryEmbeddingKey string `require:"false" yaml:"QueryEmbeddingKey" json:"QueryEmbeddingKey"` ->>>>>>> origin/feat/chroma + serviceName string `require:"true" yaml:"serviceName" json:"serviceName"` + serviceHost string `require:"false" yaml:"serviceHost" json:"serviceHost"` + servicePort int64 `require:"false" yaml:"servicePort" json:"servicePort"` + apiKey string `require:"false" yaml:"apiKey" json:"apiKey"` + timeout uint32 `require:"false" yaml:"timeout" json:"timeout"` + client wrapper.HttpClient `yaml:"-"` } func (c *ProviderConfig) FromJson(json gjson.Result) { - c.typ = json.Get("TextEmbeddingProviderType").String() - c.ServiceName = json.Get("DashScopeServiceName").String() - c.DashScopeKey = json.Get("DashScopeKey").String() - c.DashScopeTimeout = uint32(json.Get("DashScopeTimeout").Int()) - if c.DashScopeTimeout == 0 { - c.DashScopeTimeout = 1000 + c.typ = json.Get("type").String() + c.serviceName = json.Get("serviceName").String() + c.serviceHost = json.Get("serviceHost").String() + c.servicePort = json.Get("servicePort").Int() + c.apiKey = json.Get("apiKey").String() + c.timeout = uint32(json.Get("timeout").Int()) + if c.timeout == 0 { + c.timeout = 1000 } - c.QueryEmbeddingKey = json.Get("QueryEmbeddingKey").String() } func (c *ProviderConfig) Validate() error { - if len(c.DashScopeKey) == 0 { - return errors.New("DashScopeKey is required") - } - if len(c.ServiceName) == 0 { - return errors.New("DashScopeServiceName is required") + if len(c.serviceName) == 0 { + return errors.New("serviceName is required") } return nil } @@ -88,11 +79,7 @@ func CreateProvider(pc ProviderConfig) (Provider, error) { type Provider interface { GetProviderType() string GetEmbedding( -<<<<<<< HEAD - text string, -======= queryString string, ->>>>>>> origin/feat/chroma ctx wrapper.HttpContext, log wrapper.Log, callback func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte)) error diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod index bf2a5948dd..e4aae265e0 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.mod +++ b/plugins/wasm-go/extensions/ai-cache/go.mod @@ -7,7 +7,7 @@ go 1.19 replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( - github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441 + github.com/alibaba/higress/plugins/wasm-go v1.4.2 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index f498256025..d58350c604 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -7,7 +7,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/util" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" @@ -26,28 +25,6 @@ const ( QueryEmbeddingKey = "queryEmbedding" ) -// // Create the client -// func CreateClient() { -// cfg := weaviate.Config{ -// Host: "172.17.0.1:8081", -// Scheme: "http", -// Headers: nil, -// } - -// client, err := weaviate.NewClient(cfg) -// if err != nil { -// fmt.Println(err) -// } - -// // Check the connection -// live, err := client.Misc().LiveChecker().Do(context.Background()) -// if err != nil { -// panic(err) -// } -// fmt.Printf("%v", live) - -// } - func main() { // CreateClient() @@ -135,7 +112,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body queryString := config.CacheKeyPrefix + key - util.RedisSearchHandler(queryString, ctx, config, log, stream, true) + RedisSearchHandler(queryString, ctx, config, log, stream, true) // 需要等待异步回调完成,返回 Pause 状态,可以被 ResumeHttpRequest 恢复 return types.ActionPause @@ -185,14 +162,13 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, } func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - // log.Infof("I am here") - log.Debugf("[onHttpResponseBody] i am here") + log.Infof("[onHttpResponseBody] chunk:%s", string(chunk)) + log.Infof("[onHttpResponseBody] isLastChunk:%v", isLastChunk) if ctx.GetContext(ToolCallsContextKey) != nil { // we should not cache tool call result return chunk } keyI := ctx.GetContext(CacheKeyContextKey) - // log.Infof("I am here 2: %v", keyI) if keyI == nil { return chunk } @@ -278,8 +254,5 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu log.Infof("[onHttpResponseBody] Setting cache to redis, key:%s, value:%s", key, value) config.GetCacheProvider().Set(embedding.CacheKeyPrefix+key, value, nil) // TODO: 要不要加个Expire方法 - // if config.RedisConfig.RedisTimeout != 0 { - // config.GetCacheProvider().Expire(config.CacheKeyPrefix+key, config.RedisConfig.RedisTimeout, nil) - // } return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go b/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go deleted file mode 100644 index c532aec971..0000000000 --- a/plugins/wasm-go/extensions/ai-cache/util/cachelogic.go +++ /dev/null @@ -1,358 +0,0 @@ -// TODO: 在这里写缓存的具体逻辑, 将textEmbeddingPrvider和vectorStoreProvider作为逻辑中的一个函数调用 -package util - -import ( - "fmt" - "net/http" - - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" -<<<<<<< HEAD - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vectorDatabase" -======= ->>>>>>> origin/feat/chroma - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/tidwall/resp" -) - -// ===================== 以下是主要逻辑 ===================== -// 主handler函数,根据key从redis中获取value ,如果不命中,则首先调用文本向量化接口向量化query,然后调用向量搜索接口搜索最相似的出现过的key,最后再次调用redis获取结果 -// 可以把所有handler单独提取为文件,这里为了方便读者复制就和主逻辑放在一个文件中了 -// -// 1. query 进来和 redis 中存的 key 匹配 (redisSearchHandler) ,若完全一致则直接返回 (handleCacheHit) -// 2. 否则请求 text_embdding 接口将 query 转换为 query_embedding (fetchAndProcessEmbeddings) -// 3. 用 query_embedding 和向量数据库中的向量做 ANN search,返回最接近的 key ,并用阈值过滤 (performQueryAndRespond) -// 4. 若返回结果为空或大于阈值,舍去,本轮 cache 未命中, 最后将 query_embedding 存入向量数据库 (uploadQueryEmbedding) -// 5. 若小于阈值,则再次调用 redis对 most similar key 做匹配。 (redisSearchHandler) -// 7. 在 response 阶段请求 redis 新增key/LLM返回结果 - -func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) { - activeCacheProvider := config.GetCacheProvider() - log.Debugf("activeCacheProvider:%v", activeCacheProvider) - activeCacheProvider.Get(embedding.CacheKeyPrefix+key, func(response resp.Value) { - if err := response.Error(); err == nil && !response.IsNull() { - log.Warnf("cache hit, key:%s", key) - HandleCacheHit(key, response, stream, ctx, config, log) - } else { - log.Warnf("cache miss, key:%s", key) - if ifUseEmbedding { - HandleCacheMiss(key, err, response, ctx, config, log, key, stream) - } else { - proxywasm.ResumeHttpRequest() - return - } - } - }) -} - -func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { - ctx.SetContext(embedding.CacheKeyContextKey, nil) - if !stream { - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "[Test, this is cache]"+response.String())), -1) - } else { - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, "[Test, this is cache]"+response.String())), -1) - } -} - -func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { - if err != nil { - log.Warnf("redis get key:%s failed, err:%v", key, err) - } - if response.IsNull() { - log.Warnf("cache miss, key:%s", key) - } - FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream) -} - -func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { - activeEmbeddingProvider := config.GetEmbeddingProvider() - activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, - func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte) { - if statusCode != 200 { - log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) - } else { - log.Debugf("Successfully fetched embeddings for key: %s", key) - QueryVectorDB(key, emb, ctx, config, log, stream) - } - }) -} - -func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { -<<<<<<< HEAD - log.Debugf("QueryVectorDB key:%s", key) - activeVectorDatabaseProvider := config.GetVectorDatabaseProvider() - log.Debugf("activeVectorDatabaseProvider:%v", activeVectorDatabaseProvider) - activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, - func(query_resp vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { - if len(query_resp.Output) < 1 { - log.Warnf("query response is empty") -======= - log.Debugf("QueryVectorDB key: %s", key) - activeVectorDatabaseProvider := config.GetVectorDatabaseProvider() - log.Debugf("activeVectorDatabaseProvider: %+v", activeVectorDatabaseProvider) - activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, - func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) { - resp, err := activeVectorDatabaseProvider.ParseQueryResponse(responseBody, ctx, log) - if err != nil { - log.Errorf("Failed to query vector database, err: %v", err) - proxywasm.ResumeHttpRequest() - return - } - - if len(resp.MostSimilarData) == 0 { - log.Warnf("Failed to query vector database, no most similar key found") ->>>>>>> origin/feat/chroma - activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, - func(ctx wrapper.HttpContext, log wrapper.Log) { - proxywasm.ResumeHttpRequest() - }) - return - } -<<<<<<< HEAD - mostSimilarKey := query_resp.Output[0].Fields["query"].(string) - log.Infof("most similar key:%s", mostSimilarKey) - mostSimilarScore := query_resp.Output[0].Score - if mostSimilarScore < 2000 { - log.Infof("accept most similar key:%s, score:%f", mostSimilarKey, mostSimilarScore) - // ctx.SetContext(embedding.CacheKeyContextKey, nil) - RedisSearchHandler(mostSimilarKey, ctx, config, log, stream, false) - } else { - log.Infof("the most similar key's score is too high, key:%s, score:%f", mostSimilarKey, mostSimilarScore) -======= - - log.Infof("most similar key: %s", resp.MostSimilarData) - if resp.Score < activeVectorDatabaseProvider.GetThreshold() { - log.Infof("accept most similar key: %s, score: %f", resp.MostSimilarData, resp.Score) - // ctx.SetContext(embedding.CacheKeyContextKey, nil) - RedisSearchHandler(resp.MostSimilarData, ctx, config, log, stream, false) - } else { - log.Infof("the most similar key's score is too high, key: %s, score: %f", resp.MostSimilarData, resp.Score) ->>>>>>> origin/feat/chroma - activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, - func(ctx wrapper.HttpContext, log wrapper.Log) { - proxywasm.ResumeHttpRequest() - }) -<<<<<<< HEAD - proxywasm.ResumeHttpRequest() -======= ->>>>>>> origin/feat/chroma - return - } - }, - ) -<<<<<<< HEAD -======= - - // resp, err := activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log) - // if err != nil { - // log.Errorf("Failed to query vector database, err: %v", err) - // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log) - // proxywasm.ResumeHttpRequest() - // return - // } - - // log.Infof("most similar key: %s", resp.MostSimilarData) - // if resp.Score < activeVectorDatabaseProvider.GetThreshold() { - // log.Infof("accept most similar key: %s, score: %f", resp.MostSimilarData, resp.Score) - // // ctx.SetContext(embedding.CacheKeyContextKey, nil) - // RedisSearchHandler(resp.MostSimilarData, ctx, config, log, stream, false) - // } else { - // log.Infof("the most similar key's score is too high, key: %s, score: %f", resp.MostSimilarData, resp.Score) - // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log) - // proxywasm.ResumeHttpRequest() - // return - // } - - // activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, - // func(query_resp vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { - // if len(query_resp.Output) < 1 { // 向量库不存在查询向量 - // log.Warnf("query response is empty") - // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, - // func(ctx wrapper.HttpContext, log wrapper.Log) { - // proxywasm.ResumeHttpRequest() - // }) - // return - // } - // mostSimilarKey := query_resp.Output[0].Fields["query"].(string) - // log.Infof("most similar key:%s", mostSimilarKey) - // mostSimilarScore := query_resp.Output[0].Score - // if mostSimilarScore < 2000 { // 向量库存在满足相似度的向量 - // log.Infof("accept most similar key:%s, score:%f", mostSimilarKey, mostSimilarScore) - // // ctx.SetContext(embedding.CacheKeyContextKey, nil) - // RedisSearchHandler(mostSimilarKey, ctx, config, log, stream, false) - // } else { // 向量库不存在满足相似度的向量 - // log.Infof("the most similar key's score is too high, key:%s, score:%f", mostSimilarKey, mostSimilarScore) - // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, - // func(ctx wrapper.HttpContext, log wrapper.Log) { - // proxywasm.ResumeHttpRequest() - // }) - // proxywasm.ResumeHttpRequest() - // return - // } - // }, - // ) ->>>>>>> origin/feat/chroma - // activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, - // func(query_resp vectorDatabase.QueryResponse, ctx wrapper.HttpContext, log wrapper.Log) { - // if len(query_resp.Output) < 1 { - // log.Warnf("query response is empty") - // // UploadQueryEmbedding(ctx, config, log, key, text_embedding) - // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, - // func(ctx wrapper.HttpContext, log wrapper.Log) { - // // proxywasm.ResumeHttpRequest() - // log.Debugf("I am in the 117 line") - // }) - // return - // } - // most_similar_key := query_resp.Output[0].Fields["query"].(string) - // log.Infof("most similar key:%s", most_similar_key) - // most_similar_score := query_resp.Output[0].Score - // if most_similar_score < 0.1 { - // // ctx.SetContext(CacheKeyContextKey, nil) - // // RedisSearchHandler(most_similar_key, ctx, config, log, stream, false) - // } else { - // log.Infof("the most similar key's score is too high, key:%s, score:%f", most_similar_key, most_similar_score) - // // UploadQueryEmbedding(ctx, config, log, key, text_embedding) - // activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, - // func(ctx wrapper.HttpContext, log wrapper.Log) { - // proxywasm.ResumeHttpRequest() - // }) - // proxywasm.ResumeHttpRequest() - // return - // } - // }) - // ctx.SetContext(embedding.CacheKeyContextKey, text_embedding) - // ctx.SetContext(embedding.QueryEmbeddingKey, text_embedding) - // ctx.SetContext(embedding.CacheKeyContextKey, key) - // PerformQueryAndRespond(key, text_embedding, ctx, config, log, stream) -} - -// // 简单处理缓存命中的情况, 从redis中获取到value后,直接返回 -// func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { -// log.Warnf("cache hit, key:%s", key) -// ctx.SetContext(config.CacheKeyContextKey, nil) -// if !stream { -// proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1) -// } else { -// proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1) -// } -// } - -// // 处理缓存未命中的情况,调用fetchAndProcessEmbeddings函数向量化query -// func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { -// if err != nil { -// log.Warnf("redis get key:%s failed, err:%v", key, err) -// } -// if response.IsNull() { -// log.Warnf("cache miss, key:%s", key) -// } -// FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream) -// } - -// // 调用文本向量化接口向量化query, 向量化成功后调用processFetchedEmbeddings函数处理向量化结果 -// func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { -// Emb_url, Emb_requestBody, Emb_headers := ConstructTextEmbeddingParameters(&config, log, []string{queryString}) -// config.DashVectorInfo.DashScopeClient.Post( -// Emb_url, -// Emb_headers, -// Emb_requestBody, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// // log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) -// log.Infof("Successfully fetched embeddings for key: %s", key) -// if statusCode != 200 { -// log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) -// ctx.SetContext(QueryEmbeddingKey, nil) -// proxywasm.ResumeHttpRequest() -// } else { -// processFetchedEmbeddings(key, responseBody, ctx, config, log, stream) -// } -// }, -// 10000) -// } - -// // 先将向量化的结果存入上下文ctx变量,其次发起向量搜索请求 -// func ProcessFetchedEmbeddings(key string, responseBody []byte, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { -// text_embedding_raw, _ := ParseTextEmbedding(responseBody) -// text_embedding := text_embedding_raw.Output.Embeddings[0].Embedding -// // ctx.SetContext(CacheKeyContextKey, text_embedding) -// ctx.SetContext(QueryEmbeddingKey, text_embedding) -// ctx.SetContext(CacheKeyContextKey, key) -// PerformQueryAndRespond(key, text_embedding, ctx, config, log, stream) -// } - -// // 调用向量搜索接口搜索最相似的key,搜索成功后调用redisSearchHandler函数获取最相似的key的结果 -// func PerformQueryAndRespond(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { -// vector_url, vector_request, vector_headers, err := ConstructEmbeddingQueryParameters(config, text_embedding) -// if err != nil { -// log.Errorf("Failed to perform query, err: %v", err) -// proxywasm.ResumeHttpRequest() -// return -// } -// config.DashVectorInfo.DashVectorClient.Post( -// vector_url, -// vector_headers, -// vector_request, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) -// query_resp, err_query := ParseQueryResponse(responseBody) -// if err_query != nil { -// log.Errorf("Failed to parse response: %v", err) -// proxywasm.ResumeHttpRequest() -// return -// } -// if len(query_resp.Output) < 1 { -// log.Warnf("query response is empty") -// UploadQueryEmbedding(ctx, config, log, key, text_embedding) -// return -// } -// most_similar_key := query_resp.Output[0].Fields["query"].(string) -// log.Infof("most similar key:%s", most_similar_key) -// most_similar_score := query_resp.Output[0].Score -// if most_similar_score < 0.1 { -// ctx.SetContext(CacheKeyContextKey, nil) -// RedisSearchHandler(most_similar_key, ctx, config, log, stream, false) -// } else { -// log.Infof("the most similar key's score is too high, key:%s, score:%f", most_similar_key, most_similar_score) -// UploadQueryEmbedding(ctx, config, log, key, text_embedding) -// proxywasm.ResumeHttpRequest() -// return -// } -// }, -// 100000) -// } - -// // 未命中cache,则将新的query embedding和对应的key存入向量数据库 -// func UploadQueryEmbedding(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, key string, text_embedding []float64) error { -// vector_url, vector_body, err := ConsturctEmbeddingInsertParameters(&config, log, text_embedding, key) -// if err != nil { -// log.Errorf("Failed to construct embedding insert parameters: %v", err) -// proxywasm.ResumeHttpRequest() -// return nil -// } -// err = config.DashVectorInfo.DashVectorClient.Post( -// vector_url, -// [][2]string{ -// {"Content-Type", "application/json"}, -// {"dashvector-auth-token", config.DashVectorInfo.DashVectorKey}, -// }, -// vector_body, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// if statusCode != 200 { -// log.Errorf("Failed to upload query embedding: %s", responseBody) -// } else { -// log.Infof("Successfully uploaded query embedding for key: %s", key) -// } -// proxywasm.ResumeHttpRequest() -// }, -// 10000, -// ) -// if err != nil { -// log.Errorf("Failed to upload query embedding: %v", err) -// proxywasm.ResumeHttpRequest() -// return nil -// } -// return nil -// } - -// // ===================== 以上是主要逻辑 ===================== diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go similarity index 99% rename from plugins/wasm-go/extensions/ai-cache/vectorDatabase/chroma.go rename to plugins/wasm-go/extensions/ai-cache/vector/chroma.go index 2b345d51d9..e49108bbdf 100644 --- a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -1,4 +1,4 @@ -package vectorDatabase +package vector import ( "encoding/json" diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go similarity index 99% rename from plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go rename to plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 2200656c83..58ded82db8 100644 --- a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -1,4 +1,4 @@ -package vectorDatabase +package vector import ( "encoding/json" diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go similarity index 99% rename from plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go rename to plugins/wasm-go/extensions/ai-cache/vector/provider.go index 0177efe76a..8bf1952149 100644 --- a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -1,4 +1,4 @@ -package vectorDatabase +package vector import ( "errors" diff --git a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go similarity index 98% rename from plugins/wasm-go/extensions/ai-cache/vectorDatabase/weaviate.go rename to plugins/wasm-go/extensions/ai-cache/vector/weaviate.go index 8bed5d098b..0b361a6598 100644 --- a/plugins/wasm-go/extensions/ai-cache/vectorDatabase/weaviate.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go @@ -1,4 +1,4 @@ -// package vectorDatabase +package vector // import ( // "encoding/json" @@ -169,4 +169,3 @@ // }, // d.config.DashVectorTimeout) // } -package vectorDatabase From 130f2ee3402dbbb44c19b83b1c4d7ef8027552b8 Mon Sep 17 00:00:00 2001 From: Async Date: Sat, 24 Aug 2024 03:54:45 +0000 Subject: [PATCH 06/71] fix: embedding error --- docker-compose-test/envoy.yaml | 8 ++++---- plugins/wasm-go/extensions/ai-cache/core.go | 1 + .../wasm-go/extensions/ai-cache/embedding/dashscope.go | 3 +++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml index dc08f2e846..b7d81a78ae 100644 --- a/docker-compose-test/envoy.yaml +++ b/docker-compose-test/envoy.yaml @@ -56,12 +56,12 @@ static_resources: value: | { "embeddingProvider": { - "TextEmbeddingProviderType": "dashscope", - "ServiceName": "text-embedding-v2", - "DashScopeKey": "sk-your-key", + "type": "dashscope", + "serviceName": "dashscope", + "apiKey": "sk-your-key", "DashScopeServiceName": "dashscope" }, - "vectorBaseProvider": { + "vectorProvider": { "vectorStoreProviderType": "chroma", "ChromaServiceName": "chroma", "ChromaCollectionID": "0294deb1-8ef5-4582-b21c-75f23093db2c" diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index accd11cad6..54f2c668fc 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -55,6 +55,7 @@ func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config confi func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode != 200 { log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) + proxywasm.ResumeHttpRequest() // 当取 Embedding 失败了也继续流程 } else { log.Debugf("Successfully fetched embeddings for key: %s", key) QueryVectorDB(key, emb, ctx, config, log, stream) diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index 90974b0662..fc311260cc 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -159,6 +159,7 @@ func (d *DSProvider) GetEmbedding( if statusCode != http.StatusOK { log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) err = errors.New("failed to get embedding") + callback(nil, statusCode, responseHeaders, responseBody) return } @@ -167,12 +168,14 @@ func (d *DSProvider) GetEmbedding( resp, err = d.parseTextEmbedding(responseBody) if err != nil { log.Errorf("Failed to parse response: %v", err) + callback(nil, statusCode, responseHeaders, responseBody) return } if len(resp.Output.Embeddings) == 0 { log.Errorf("No embedding found in response") err = errors.New("no embedding found in response") + callback(nil, statusCode, responseHeaders, responseBody) return } From 56314d719e56070e776580b6f0c3c7b64cd1ed04 Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 24 Aug 2024 23:17:39 +0000 Subject: [PATCH 07/71] fix bugs && update interface design --- plugins/wasm-go/extensions/ai-cache/core.go | 36 +- .../ai-cache/embedding/dashscope.go | 10 +- .../extensions/ai-cache/embedding/provider.go | 2 +- .../extensions/ai-cache/vector/chroma.go | 364 +++++++++--------- .../extensions/ai-cache/vector/dashvector.go | 78 ++-- .../extensions/ai-cache/vector/provider.go | 126 +++--- 6 files changed, 312 insertions(+), 304 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 54f2c668fc..0317b07e1e 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -6,6 +6,7 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/tidwall/resp" @@ -65,34 +66,29 @@ func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config confi func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { log.Debugf("QueryVectorDB key: %s", key) - activeVectorDatabaseProvider := config.GetvectorProvider() - log.Debugf("activeVectorDatabaseProvider: %+v", activeVectorDatabaseProvider) - activeVectorDatabaseProvider.QueryEmbedding(text_embedding, ctx, log, - func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) { - resp, err := activeVectorDatabaseProvider.ParseQueryResponse(responseBody, ctx, log) - if err != nil { - log.Errorf("Failed to query vector database, err: %v", err) - proxywasm.ResumeHttpRequest() - return - } - - if len(resp.MostSimilarData) == 0 { - log.Warnf("Failed to query vector database, no most similar key found") - activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + activeVectorProvider := config.GetvectorProvider() + log.Debugf("activeVectorProvider: %+v", activeVectorProvider) + activeVectorProvider.QueryEmbedding(text_embedding, ctx, log, + func(results []vector.QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log) { + // The baisc logic is to compare the similarity of the embedding with the most similar key in the database + if len(results) == 0 { + log.Warnf("Failed to query vector database, no similar key found") + activeVectorProvider.UploadEmbedding(text_embedding, key, ctx, log, func(ctx wrapper.HttpContext, log wrapper.Log) { proxywasm.ResumeHttpRequest() }) return } - log.Infof("most similar key: %s", resp.MostSimilarData) - if resp.Score < activeVectorDatabaseProvider.GetThreshold() { - log.Infof("accept most similar key: %s, score: %f", resp.MostSimilarData, resp.Score) + mostSimilarData := results[0] + log.Infof("most similar key: %s", mostSimilarData.Text) + if mostSimilarData.Score < activeVectorProvider.GetThreshold() { + log.Infof("accept most similar key: %s, score: %f", mostSimilarData.Text, mostSimilarData.Score) // ctx.SetContext(embedding.CacheKeyContextKey, nil) - RedisSearchHandler(resp.MostSimilarData, ctx, config, log, stream, false) + RedisSearchHandler(mostSimilarData.Text, ctx, config, log, stream, false) } else { - log.Infof("the most similar key's score is too high, key: %s, score: %f", resp.MostSimilarData, resp.Score) - activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, + log.Infof("the most similar key's score is too high, key: %s, score: %f", mostSimilarData.Text, mostSimilarData.Score) + activeVectorProvider.UploadEmbedding(text_embedding, key, ctx, log, func(ctx wrapper.HttpContext, log wrapper.Log) { proxywasm.ResumeHttpRequest() }) diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index fc311260cc..e836a8bd97 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -146,7 +146,7 @@ func (d *DSProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, log wrapper.Log, - callback func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte)) error { + callback func(emb []float64)) error { Emb_url, Emb_headers, Emb_requestBody, err := d.constructParameters([]string{queryString}, log) if err != nil { log.Errorf("Failed to construct parameters: %v", err) @@ -159,7 +159,7 @@ func (d *DSProvider) GetEmbedding( if statusCode != http.StatusOK { log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) err = errors.New("failed to get embedding") - callback(nil, statusCode, responseHeaders, responseBody) + callback(nil) return } @@ -168,18 +168,18 @@ func (d *DSProvider) GetEmbedding( resp, err = d.parseTextEmbedding(responseBody) if err != nil { log.Errorf("Failed to parse response: %v", err) - callback(nil, statusCode, responseHeaders, responseBody) + callback(nil) return } if len(resp.Output.Embeddings) == 0 { log.Errorf("No embedding found in response") err = errors.New("no embedding found in response") - callback(nil, statusCode, responseHeaders, responseBody) + callback(nil) return } - callback(resp.Output.Embeddings[0].Embedding, statusCode, responseHeaders, responseBody) + callback(resp.Output.Embeddings[0].Embedding) }, d.config.timeout) return nil diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index 1b307c8ad9..0c45c9cc83 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -82,5 +82,5 @@ type Provider interface { queryString string, ctx wrapper.HttpContext, log wrapper.Log, - callback func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte)) error + callback func(emb []float64)) error } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index e49108bbdf..b3c3fc42d6 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -1,184 +1,184 @@ package vector -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -) - -type chromaProviderInitializer struct{} - -const chromaPort = 8001 - -func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error { - if len(config.ChromaCollectionID) == 0 { - return errors.New("ChromaCollectionID is required") - } - if len(config.ChromaServiceName) == 0 { - return errors.New("ChromaServiceName is required") - } - return nil -} - -func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { - return &ChromaProvider{ - config: config, - client: wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: config.ChromaServiceName, - Port: chromaPort, - Domain: config.ChromaServiceName, - }), - }, nil -} - -type ChromaProvider struct { - config ProviderConfig - client wrapper.HttpClient -} - -func (c *ChromaProvider) GetProviderType() string { - return providerTypeChroma -} - -func (d *ChromaProvider) GetThreshold() float64 { - return d.config.ChromaDistanceThreshold -} - -func (d *ChromaProvider) QueryEmbedding( - emb []float64, - ctx wrapper.HttpContext, - log wrapper.Log, - callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { - // 最小需要填写的参数为 collection_id, embeddings 和 ids - // 下面是一个例子 - // { - // "where": {}, 用于 metadata 过滤,可选参数 - // "where_document": {}, 用于 document 过滤,可选参数 - // "query_embeddings": [ - // [1.1, 2.3, 3.2] - // ], - // "n_results": 5, - // "include": [ - // "metadatas", - // "distances" - // ] - // } - requestBody, err := json.Marshal(ChromaQueryRequest{ - QueryEmbeddings: []ChromaEmbedding{emb}, - NResults: d.config.ChromaNResult, - Include: []string{"distances"}, - }) - - if err != nil { - log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err) - return - } - - d.client.Post( - fmt.Sprintf("/api/v1/collections/%s/query", d.config.ChromaCollectionID), - [][2]string{ - {"Content-Type", "application/json"}, - }, - requestBody, - func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("Query embedding response: %d, %s", statusCode, responseBody) - callback(responseBody, ctx, log) - }, - d.config.ChromaTimeout, - ) -} - -func (d *ChromaProvider) UploadEmbedding( - query_emb []float64, - queryString string, - ctx wrapper.HttpContext, - log wrapper.Log, - callback func(ctx wrapper.HttpContext, log wrapper.Log)) { - // 最小需要填写的参数为 collection_id, embeddings 和 ids - // 下面是一个例子 - // { - // "embeddings": [ - // [1.1, 2.3, 3.2] - // ], - // "ids": [ - // "你吃了吗?" - // ] - // } - requestBody, err := json.Marshal(ChromaInsertRequest{ - Embeddings: []ChromaEmbedding{query_emb}, - IDs: []string{queryString}, // queryString 指的是用户查询的问题 - }) - - if err != nil { - log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err) - return - } - - d.client.Post( - fmt.Sprintf("/api/v1/collections/%s/add", d.config.ChromaCollectionID), - [][2]string{ - {"Content-Type", "application/json"}, - }, - requestBody, - func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) - callback(ctx, log) - }, - d.config.ChromaTimeout, - ) -} - -// ChromaEmbedding represents the embedding vector for a data point. -type ChromaEmbedding []float64 - -// ChromaMetadataMap is a map from key to value for metadata. -type ChromaMetadataMap map[string]string - -// Dataset represents the entire dataset containing multiple data points. -type ChromaInsertRequest struct { - Embeddings []ChromaEmbedding `json:"embeddings"` - Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional metadata map array - Documents []string `json:"documents,omitempty"` // Optional document array - IDs []string `json:"ids"` -} - -// ChromaQueryRequest represents the query request structure. -type ChromaQueryRequest struct { - Where map[string]string `json:"where,omitempty"` // Optional where filter - WhereDocument map[string]string `json:"where_document,omitempty"` // Optional where_document filter - QueryEmbeddings []ChromaEmbedding `json:"query_embeddings"` - NResults int `json:"n_results"` - Include []string `json:"include"` -} - -// ChromaQueryResponse represents the search result structure. -type ChromaQueryResponse struct { - Ids [][]string `json:"ids"` // 每一个 embedding 相似的 key 可能会有多个,然后会有多个 embedding,所以是一个二维数组 - Distances [][]float64 `json:"distances"` // 与 Ids 一一对应 - Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional, can be null - Embeddings []ChromaEmbedding `json:"embeddings,omitempty"` // Optional, can be null - Documents []string `json:"documents,omitempty"` // Optional, can be null - Uris []string `json:"uris,omitempty"` // Optional, can be null - Data []interface{} `json:"data,omitempty"` // Optional, can be null - Included []string `json:"included"` -} - -func (d *ChromaProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { - var queryResp ChromaQueryResponse - err := json.Unmarshal(responseBody, &queryResp) - if err != nil { - return QueryEmbeddingResult{}, err - } - log.Infof("[Chroma] queryResp: %+v", queryResp) - log.Infof("[Chroma] queryResp Ids len: %d", len(queryResp.Ids)) - if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 { - return QueryEmbeddingResult{}, nil - } - return QueryEmbeddingResult{ - MostSimilarData: queryResp.Ids[0][0], - Score: queryResp.Distances[0][0], - }, nil -} +// import ( +// "encoding/json" +// "errors" +// "fmt" +// "net/http" + +// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +// ) + +// type chromaProviderInitializer struct{} + +// const chromaPort = 8001 + +// func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error { +// if len(config.ChromaCollectionID) == 0 { +// return errors.New("ChromaCollectionID is required") +// } +// if len(config.ChromaServiceName) == 0 { +// return errors.New("ChromaServiceName is required") +// } +// return nil +// } + +// func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +// return &ChromaProvider{ +// config: config, +// client: wrapper.NewClusterClient(wrapper.DnsCluster{ +// ServiceName: config.ChromaServiceName, +// Port: chromaPort, +// Domain: config.ChromaServiceName, +// }), +// }, nil +// } + +// type ChromaProvider struct { +// config ProviderConfig +// client wrapper.HttpClient +// } + +// func (c *ChromaProvider) GetProviderType() string { +// return providerTypeChroma +// } + +// func (d *ChromaProvider) GetThreshold() float64 { +// return d.config.ChromaDistanceThreshold +// } + +// func (d *ChromaProvider) QueryEmbedding( +// emb []float64, +// ctx wrapper.HttpContext, +// log wrapper.Log, +// callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { +// // 最小需要填写的参数为 collection_id, embeddings 和 ids +// // 下面是一个例子 +// // { +// // "where": {}, 用于 metadata 过滤,可选参数 +// // "where_document": {}, 用于 document 过滤,可选参数 +// // "query_embeddings": [ +// // [1.1, 2.3, 3.2] +// // ], +// // "n_results": 5, +// // "include": [ +// // "metadatas", +// // "distances" +// // ] +// // } +// requestBody, err := json.Marshal(ChromaQueryRequest{ +// QueryEmbeddings: []ChromaEmbedding{emb}, +// NResults: d.config.ChromaNResult, +// Include: []string{"distances"}, +// }) + +// if err != nil { +// log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err) +// return +// } + +// d.client.Post( +// fmt.Sprintf("/api/v1/collections/%s/query", d.config.ChromaCollectionID), +// [][2]string{ +// {"Content-Type", "application/json"}, +// }, +// requestBody, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// log.Infof("Query embedding response: %d, %s", statusCode, responseBody) +// callback(responseBody, ctx, log) +// }, +// d.config.ChromaTimeout, +// ) +// } + +// func (d *ChromaProvider) UploadEmbedding( +// query_emb []float64, +// queryString string, +// ctx wrapper.HttpContext, +// log wrapper.Log, +// callback func(ctx wrapper.HttpContext, log wrapper.Log)) { +// // 最小需要填写的参数为 collection_id, embeddings 和 ids +// // 下面是一个例子 +// // { +// // "embeddings": [ +// // [1.1, 2.3, 3.2] +// // ], +// // "ids": [ +// // "你吃了吗?" +// // ] +// // } +// requestBody, err := json.Marshal(ChromaInsertRequest{ +// Embeddings: []ChromaEmbedding{query_emb}, +// IDs: []string{queryString}, // queryString 指的是用户查询的问题 +// }) + +// if err != nil { +// log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err) +// return +// } + +// d.client.Post( +// fmt.Sprintf("/api/v1/collections/%s/add", d.config.ChromaCollectionID), +// [][2]string{ +// {"Content-Type", "application/json"}, +// }, +// requestBody, +// func(statusCode int, responseHeaders http.Header, responseBody []byte) { +// log.Infof("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) +// callback(ctx, log) +// }, +// d.config.ChromaTimeout, +// ) +// } + +// // ChromaEmbedding represents the embedding vector for a data point. +// type ChromaEmbedding []float64 + +// // ChromaMetadataMap is a map from key to value for metadata. +// type ChromaMetadataMap map[string]string + +// // Dataset represents the entire dataset containing multiple data points. +// type ChromaInsertRequest struct { +// Embeddings []ChromaEmbedding `json:"embeddings"` +// Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional metadata map array +// Documents []string `json:"documents,omitempty"` // Optional document array +// IDs []string `json:"ids"` +// } + +// // ChromaQueryRequest represents the query request structure. +// type ChromaQueryRequest struct { +// Where map[string]string `json:"where,omitempty"` // Optional where filter +// WhereDocument map[string]string `json:"where_document,omitempty"` // Optional where_document filter +// QueryEmbeddings []ChromaEmbedding `json:"query_embeddings"` +// NResults int `json:"n_results"` +// Include []string `json:"include"` +// } + +// // ChromaQueryResponse represents the search result structure. +// type ChromaQueryResponse struct { +// Ids [][]string `json:"ids"` // 每一个 embedding 相似的 key 可能会有多个,然后会有多个 embedding,所以是一个二维数组 +// Distances [][]float64 `json:"distances"` // 与 Ids 一一对应 +// Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional, can be null +// Embeddings []ChromaEmbedding `json:"embeddings,omitempty"` // Optional, can be null +// Documents []string `json:"documents,omitempty"` // Optional, can be null +// Uris []string `json:"uris,omitempty"` // Optional, can be null +// Data []interface{} `json:"data,omitempty"` // Optional, can be null +// Included []string `json:"included"` +// } + +// func (d *ChromaProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { +// var queryResp ChromaQueryResponse +// err := json.Unmarshal(responseBody, &queryResp) +// if err != nil { +// return QueryEmbeddingResult{}, err +// } +// log.Infof("[Chroma] queryResp: %+v", queryResp) +// log.Infof("[Chroma] queryResp Ids len: %d", len(queryResp.Ids)) +// if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 { +// return QueryEmbeddingResult{}, nil +// } +// return QueryEmbeddingResult{ +// MostSimilarData: queryResp.Ids[0][0], +// Score: queryResp.Distances[0][0], +// }, nil +// } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 58ded82db8..250a653919 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -10,25 +10,24 @@ import ( ) const ( - dashVectorPort = 443 - threshold = 2000 + threshold = 2000 ) type dashVectorProviderInitializer struct { } func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { - if len(config.DashVectorKey) == 0 { - return errors.New("DashVectorKey is required") + if len(config.apiKey) == 0 { + return errors.New("[DashVector] apiKey is required") } - if len(config.DashVectorAuthApiEnd) == 0 { - return errors.New("DashVectorEnd is required") + if len(config.collectionID) == 0 { + return errors.New("[DashVector] collectionID is required") } - if len(config.DashVectorCollection) == 0 { - return errors.New("DashVectorCollection is required") + if len(config.serviceName) == 0 { + return errors.New("[DashVector] serviceName is required") } - if len(config.DashVectorServiceName) == 0 { - return errors.New("DashVectorServiceName is required") + if len(config.serviceHost) == 0 { + return errors.New("[DashVector] endPoint is required") } return nil } @@ -37,9 +36,9 @@ func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (P return &DvProvider{ config: config, client: wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: config.DashVectorServiceName, - Port: dashVectorPort, - Domain: config.DashVectorAuthApiEnd, + ServiceName: config.serviceName, + Port: config.servicePort, + Domain: config.serviceHost, }), }, nil } @@ -91,11 +90,11 @@ type result struct { } func (d *DvProvider) constructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { - url := fmt.Sprintf("/v1/collections/%s/query", d.config.DashVectorCollection) + url := fmt.Sprintf("/v1/collections/%s/query", d.config.collectionID) requestData := queryRequest{ Vector: vector, - TopK: d.config.DashVectorTopK, + TopK: d.config.topK, IncludeVector: false, } @@ -106,7 +105,7 @@ func (d *DvProvider) constructEmbeddingQueryParameters(vector []float64) (string header := [][2]string{ {"Content-Type", "application/json"}, - {"dashvector-auth-token", d.config.DashVectorKey}, + {"dashvector-auth-token", d.config.apiKey}, } return url, requestBody, header, nil @@ -128,7 +127,7 @@ func (d *DvProvider) QueryEmbedding( emb []float64, ctx wrapper.HttpContext, log wrapper.Log, - callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { + callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log)) { // 构造请求参数 url, body, headers, err := d.constructEmbeddingQueryParameters(emb) if err != nil { @@ -141,27 +140,42 @@ func (d *DvProvider) QueryEmbedding( log.Infof("Failed to query embedding: %d", statusCode) return } - log.Infof("Query embedding response: %d, %s", statusCode, responseBody) - callback(responseBody, ctx, log) + log.Debugf("Query embedding response: %d, %s", statusCode, responseBody) + results, err := d.ParseQueryResponse(responseBody, ctx, log) + // TODO: 如果解析失败,应该如何处理? + if err != nil { + log.Infof("Failed to parse query response: %v", err) + return + } + callback(results, ctx, log) }, - d.config.DashVectorTimeout) + d.config.timeout) if err != nil { log.Infof("Failed to query embedding: %v", err) } } -func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { +func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryEmbeddingResult, error) { resp, err := d.parseQueryResponse(responseBody) if err != nil { - return QueryEmbeddingResult{}, err + return nil, err } if len(resp.Output) == 0 { - return QueryEmbeddingResult{}, nil + return nil, nil } - return QueryEmbeddingResult{ - MostSimilarData: resp.Output[0].Fields["query"].(string), - Score: resp.Output[0].Score, - }, nil + + results := make([]QueryEmbeddingResult, 0, len(resp.Output)) + + for _, output := range resp.Output { + result := QueryEmbeddingResult{ + Text: output.Fields["query"].(string), + Embedding: output.Vector, + Score: output.Score, + } + results = append(results, result) + } + + return results, nil } type document struct { @@ -174,7 +188,7 @@ type insertRequest struct { } func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { - url := "/v1/collections/" + d.config.DashVectorCollection + "/docs" + url := "/v1/collections/" + d.config.collectionID + "/docs" doc := document{ Vector: emb, @@ -190,14 +204,14 @@ func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, query_str header := [][2]string{ {"Content-Type", "application/json"}, - {"dashvector-auth-token", d.config.DashVectorKey}, + {"dashvector-auth-token", d.config.apiKey}, } return url, requestBody, header, err } -func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { - url, body, headers, _ := d.constructEmbeddingUploadParameters(query_emb, queryString) +func (d *DvProvider) UploadEmbedding(queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { + url, body, headers, _ := d.constructEmbeddingUploadParameters(queryEmb, queryString) d.client.Post( url, headers, @@ -206,5 +220,5 @@ func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ct log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) callback(ctx, log) }, - d.config.DashVectorTimeout) + d.config.timeout) } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 8bf1952149..a37eac06fa 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -20,102 +20,100 @@ type providerInitializer interface { var ( providerInitializers = map[string]providerInitializer{ providerTypeDashVector: &dashVectorProviderInitializer{}, - providerTypeChroma: &chromaProviderInitializer{}, + // providerTypeChroma: &chromaProviderInitializer{}, } ) +// 定义通用的查询结果的结构体 +type QueryEmbeddingResult struct { + Text string // 相似的文本 + Embedding []float64 // 相似文本的向量 + Score float64 // 文本的向量相似度或距离等度量 +} + type Provider interface { GetProviderType() string + // TODO: 考虑失败的场景 QueryEmbedding( emb []float64, ctx wrapper.HttpContext, log wrapper.Log, - callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) + callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log)) + // TODO: 考虑失败的场景 UploadEmbedding( - query_emb []float64, + queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) GetThreshold() float64 - ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) -} - -// 定义通用的查询结果的结构体 -type QueryEmbeddingResult struct { - MostSimilarData string // 相似的文本 - Score float64 // 文本的向量相似度或距离等度量 + // ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) } type ProviderConfig struct { // @Title zh-CN 向量存储服务提供者类型 // @Description zh-CN 向量存储服务提供者类型,例如 DashVector、Milvus - typ string `json:"vectorStoreProviderType"` - // @Title zh-CN DashVector 阿里云向量搜索引擎 - // @Description zh-CN 调用阿里云的向量搜索引擎 - DashVectorServiceName string `require:"true" yaml:"DashVectorServiceName" json:"DashVectorServiceName"` - // @Title zh-CN DashVector Key - // @Description zh-CN 阿里云向量搜索引擎的 key - DashVectorKey string `require:"true" yaml:"DashVectorKey" json:"DashVectorKey"` - // @Title zh-CN DashVector AuthApiEnd - // @Description zh-CN 阿里云向量搜索引擎的 AuthApiEnd - DashVectorAuthApiEnd string `require:"true" yaml:"DashVectorEnd" json:"DashVectorEnd"` - // @Title zh-CN DashVector Collection - // @Description zh-CN 指定使用阿里云搜索引擎中的哪个向量集合 - DashVectorCollection string `require:"true" yaml:"DashVectorCollection" json:"DashVectorCollection"` - // @Title zh-CN DashVector Client - // @Description zh-CN 阿里云向量搜索引擎的 Client - DashVectorTopK int `require:"false" yaml:"DashVectorTopK" json:"DashVectorTopK"` - DashVectorTimeout uint32 `require:"false" yaml:"DashVectorTimeout" json:"DashVectorTimeout"` - DashVectorClient wrapper.HttpClient `yaml:"-" json:"-"` + typ string `json:"vectorStoreProviderType"` + serviceName string `require:"true" yaml:"serviceName" json:"serviceName"` + serviceHost string `require:"false" yaml:"serviceHost" json:"serviceHost"` + servicePort int64 `require:"false" yaml:"servicePort" json:"servicePort"` + apiKey string `require:"false" yaml:"apiKey" json:"apiKey"` + topK int `require:"false" yaml:"topK" json:"topK"` + timeout uint32 `require:"false" yaml:"timeout" json:"timeout"` + collectionID string `require:"false" yaml:"collectionID" json:"collectionID"` - // @Title zh-CN Chroma 的上游服务名称 - // @Description zh-CN Chroma 服务所对应的网关内上游服务名称 - ChromaServiceName string `require:"true" yaml:"ChromaServiceName" json:"ChromaServiceName"` - // @Title zh-CN Chroma Collection ID - // @Description zh-CN Chroma Collection 的 ID - ChromaCollectionID string `require:"false" yaml:"ChromaCollectionID" json:"ChromaCollectionID"` + // // @Title zh-CN Chroma 的上游服务名称 + // // @Description zh-CN Chroma 服务所对应的网关内上游服务名称 + // ChromaServiceName string `require:"true" yaml:"ChromaServiceName" json:"ChromaServiceName"` + // // @Title zh-CN Chroma Collection ID + // // @Description zh-CN Chroma Collection 的 ID + // ChromaCollectionID string `require:"false" yaml:"ChromaCollectionID" json:"ChromaCollectionID"` // @Title zh-CN Chroma 距离阈值 // @Description zh-CN Chroma 距离阈值,默认为 2000 ChromaDistanceThreshold float64 `require:"false" yaml:"ChromaDistanceThreshold" json:"ChromaDistanceThreshold"` - // @Title zh-CN Chroma 搜索返回结果数量 - // @Description zh-CN Chroma 搜索返回结果数量,默认为 1 - ChromaNResult int `require:"false" yaml:"ChromaNResult" json:"ChromaNResult"` - // @Title zh-CN Chroma 超时设置 - // @Description zh-CN Chroma 超时设置,默认为 10 秒 - ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"` + // // @Title zh-CN Chroma 搜索返回结果数量 + // // @Description zh-CN Chroma 搜索返回结果数量,默认为 1 + // ChromaNResult int `require:"false" yaml:"ChromaNResult" json:"ChromaNResult"` + // // @Title zh-CN Chroma 超时设置 + // // @Description zh-CN Chroma 超时设置,默认为 10 秒 + // ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"` + vectorClient wrapper.HttpClient `yaml:"-" json:"-"` } func (c *ProviderConfig) FromJson(json gjson.Result) { c.typ = json.Get("vectorStoreProviderType").String() // DashVector - c.DashVectorServiceName = json.Get("DashVectorServiceName").String() - c.DashVectorKey = json.Get("DashVectorKey").String() - c.DashVectorAuthApiEnd = json.Get("DashVectorEnd").String() - c.DashVectorCollection = json.Get("DashVectorCollection").String() - c.DashVectorTopK = int(json.Get("DashVectorTopK").Int()) - if c.DashVectorTopK == 0 { - c.DashVectorTopK = 1 - } - c.DashVectorTimeout = uint32(json.Get("DashVectorTimeout").Int()) - if c.DashVectorTimeout == 0 { - c.DashVectorTimeout = 10000 + c.serviceName = json.Get("serviceName").String() + c.serviceHost = json.Get("serviceHost").String() + c.servicePort = int64(json.Get("servicePort").Int()) + if c.servicePort == 0 { + c.servicePort = 443 } - // Chroma - c.ChromaCollectionID = json.Get("ChromaCollectionID").String() - c.ChromaServiceName = json.Get("ChromaServiceName").String() - c.ChromaDistanceThreshold = json.Get("ChromaDistanceThreshold").Float() - if c.ChromaDistanceThreshold == 0 { - c.ChromaDistanceThreshold = 2000 - } - c.ChromaNResult = int(json.Get("ChromaNResult").Int()) - if c.ChromaNResult == 0 { - c.ChromaNResult = 1 + c.apiKey = json.Get("apiKey").String() + c.collectionID = json.Get("collectionID").String() + c.topK = int(json.Get("topK").Int()) + if c.topK == 0 { + c.topK = 1 } - c.ChromaTimeout = uint32(json.Get("ChromaTimeout").Int()) - if c.ChromaTimeout == 0 { - c.ChromaTimeout = 10000 + c.timeout = uint32(json.Get("timeout").Int()) + if c.timeout == 0 { + c.timeout = 10000 } + // Chroma + // c.ChromaCollectionID = json.Get("ChromaCollectionID").String() + // c.ChromaServiceName = json.Get("ChromaServiceName").String() + // c.ChromaDistanceThreshold = json.Get("ChromaDistanceThreshold").Float() + // if c.ChromaDistanceThreshold == 0 { + // c.ChromaDistanceThreshold = 2000 + // } + // c.ChromaNResult = int(json.Get("ChromaNResult").Int()) + // if c.ChromaNResult == 0 { + // c.ChromaNResult = 1 + // } + // c.ChromaTimeout = uint32(json.Get("ChromaTimeout").Int()) + // if c.ChromaTimeout == 0 { + // c.ChromaTimeout = 10000 + // } } func (c *ProviderConfig) Validate() error { From 3d7e85c71d966e1fefae67bf6afb7f574f88ed17 Mon Sep 17 00:00:00 2001 From: Async Date: Sun, 25 Aug 2024 17:20:36 +0000 Subject: [PATCH 08/71] feat: add elasticsearch --- docker-compose-test/docker-compose.yml | 48 ++++- docker-compose-test/envoy.yaml | 44 +++- .../extensions/ai-cache/config/config.go | 8 +- plugins/wasm-go/extensions/ai-cache/core.go | 30 ++- .../extensions/ai-cache/vector/chroma.go | 6 +- .../ai-cache/vector/elasticsearch.go | 202 ++++++++++++++++++ .../extensions/ai-cache/vector/provider.go | 63 +++++- 7 files changed, 375 insertions(+), 26 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go diff --git a/docker-compose-test/docker-compose.yml b/docker-compose-test/docker-compose.yml index 3b96146349..d98d16a95c 100644 --- a/docker-compose-test/docker-compose.yml +++ b/docker-compose-test/docker-compose.yml @@ -10,6 +10,7 @@ services: - httpbin - redis - chroma + - es networks: - wasmtest ports: @@ -42,17 +43,48 @@ services: ports: - "12345:80" - lobechat: - # docker hub 如果访问不了,可以改用这个地址:registry.cn-hangzhou.aliyuncs.com/2456868764/lobe-chat:v1.1.3 - image: lobehub/lobe-chat + es: + image: elasticsearch:8.15.0 environment: - - CODE=admin - - OPENAI_API_KEY=unused - - OPENAI_PROXY_URL=http://envoy:10000/v1 + - "TZ=Asia/Shanghai" + - "discovery.type=single-node" + - "xpack.security.http.ssl.enabled=false" + - "xpack.license.self_generated.type=trial" + - "ELASTIC_PASSWORD=123456" + ports: + - "9200:9200" + - "9300:9300" networks: - wasmtest - ports: - - "3210:3210/tcp" + + # kibana: + # image: docker.elastic.co/kibana/kibana:8.15.0 + # environment: + # - "TZ=Asia/Shanghai" + # - "ELASTICSEARCH_HOSTS=http://es:9200" + # - "ELASTICSEARCH_URL=http://es:9200" + # - "ELASTICSEARCH_USERNAME=kibana_system" + # - "ELASTICSEARCH_PASSWORD=123456" + # - "xpack.security.enabled=false" + # - "xpack.license.self_generated.type=trial" + # ports: + # - "5601:5601" + # networks: + # - wasmtest + # depends_on: + # - es + + # lobechat: + # # docker hub 如果访问不了,可以改用这个地址:registry.cn-hangzhou.aliyuncs.com/2456868764/lobe-chat:v1.1.3 + # image: lobehub/lobe-chat + # environment: + # - CODE=admin + # - OPENAI_API_KEY=unused + # - OPENAI_PROXY_URL=http://envoy:10000/v1 + # networks: + # - wasmtest + # ports: + # - "3210:3210/tcp" # weaviate: # command: diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml index b7d81a78ae..625a799652 100644 --- a/docker-compose-test/envoy.yaml +++ b/docker-compose-test/envoy.yaml @@ -58,13 +58,17 @@ static_resources: "embeddingProvider": { "type": "dashscope", "serviceName": "dashscope", - "apiKey": "sk-your-key", + "apiKey": "sk-key", "DashScopeServiceName": "dashscope" }, "vectorProvider": { - "vectorStoreProviderType": "chroma", - "ChromaServiceName": "chroma", - "ChromaCollectionID": "0294deb1-8ef5-4582-b21c-75f23093db2c" + "VectorStoreProviderType": "elasticsearch", + "ThresholdRelation": "gte", + "ESThreshold": 0.7, + "ESServiceName": "es", + "ESIndex": "higress", + "ESUsername": "elastic", + "ESPassword": "123456" }, "cacheKeyFrom": { "requestBody": "" @@ -83,8 +87,23 @@ static_resources: "timeout": 2000 } } + # 上面的配置中 redis 的配置名字是 redis,而不是 golang tag 中的 redisConfig + # "vectorProvider": { + # "VectorStoreProviderType": "chroma", + # "ChromaServiceName": "chroma", + # "ChromaCollectionID": "0294deb1-8ef5-4582-b21c-75f23093db2c" + # }, + # "vectorProvider": { + # "VectorStoreProviderType": "elasticsearch", + # "ThresholdRelation": "gte", + # "ESThreshold": 0.7, + # "ESServiceName": "es", + # "ESIndex": "higress", + # "ESUsername": "elastic", + # "ESPassword": "123456" + # }, # llm-proxy - name: llm-proxy typed_config: @@ -174,6 +193,23 @@ static_resources: socket_address: address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 port_value: 8001 + + # es + - name: outbound|9200||es.dns + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|9200||es.dns + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 + port_value: 9200 + # llm - name: llm connect_timeout: 30s diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 7e1948306c..13e4a0f7c9 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -17,7 +17,7 @@ type KVExtractor struct { type PluginConfig struct { EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"` - vectorProviderConfig vector.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` + VectorProviderConfig vector.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` @@ -46,7 +46,7 @@ type PluginConfig struct { func (c *PluginConfig) FromJson(json gjson.Result) { c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) - c.vectorProviderConfig.FromJson(json.Get("vectorProvider")) + c.VectorProviderConfig.FromJson(json.Get("vectorProvider")) c.RedisConfig.FromJson(json.Get("redis")) if c.CacheKeyFrom.RequestBody == "" { c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" @@ -84,7 +84,7 @@ func (c *PluginConfig) Validate() error { if err := c.EmbeddingProviderConfig.Validate(); err != nil { return err } - if err := c.vectorProviderConfig.Validate(); err != nil { + if err := c.VectorProviderConfig.Validate(); err != nil { return err } return nil @@ -96,7 +96,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { if err != nil { return err } - c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) + c.vectorProvider, err = vector.CreateProvider(c.VectorProviderConfig) if err != nil { return err } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 54f2c668fc..8aeff9362f 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "net/http" @@ -86,12 +87,18 @@ func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext } log.Infof("most similar key: %s", resp.MostSimilarData) - if resp.Score < activeVectorDatabaseProvider.GetThreshold() { + res, err := compare(config.VectorProviderConfig.ThresholdRelation, resp.Score, activeVectorDatabaseProvider.GetThreshold()) + if err != nil { + log.Errorf("Failed to compare score, err: %v", err) + proxywasm.ResumeHttpRequest() + return + } + if res { log.Infof("accept most similar key: %s, score: %f", resp.MostSimilarData, resp.Score) // ctx.SetContext(embedding.CacheKeyContextKey, nil) RedisSearchHandler(resp.MostSimilarData, ctx, config, log, stream, false) } else { - log.Infof("the most similar key's score is too high, key: %s, score: %f", resp.MostSimilarData, resp.Score) + log.Infof("the most similar key's score does not meet the threshold, key: %s, score: %f", resp.MostSimilarData, resp.Score) activeVectorDatabaseProvider.UploadEmbedding(text_embedding, key, ctx, log, func(ctx wrapper.HttpContext, log wrapper.Log) { proxywasm.ResumeHttpRequest() @@ -101,3 +108,22 @@ func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext }, ) } + +// 主要用于相似度/距离/点积判断 +// 相似度度量的是两个向量在方向上的相似程度。相似度越高,两个向量越接近。 +// 距离度量的是两个向量在空间上的远近程度。距离越小,两个向量越接近。 +// compare 函数根据操作符进行判断并返回结果 +func compare(operator string, value1 float64, value2 float64) (bool, error) { + switch operator { + case "gt": + return value1 > value2, nil + case "gte": + return value1 >= value2, nil + case "lt": + return value1 < value2, nil + case "lte": + return value1 <= value2, nil + default: + return false, errors.New("unsupported operator: " + operator) + } +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index e49108bbdf..da5392fbb5 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -44,7 +44,7 @@ func (c *ChromaProvider) GetProviderType() string { } func (d *ChromaProvider) GetThreshold() float64 { - return d.config.ChromaDistanceThreshold + return d.config.ChromaThreshold } func (d *ChromaProvider) QueryEmbedding( @@ -52,7 +52,7 @@ func (d *ChromaProvider) QueryEmbedding( ctx wrapper.HttpContext, log wrapper.Log, callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { - // 最小需要填写的参数为 collection_id, embeddings 和 ids + // 最少需要填写的参数为 collection_id, embeddings 和 ids // 下面是一个例子 // { // "where": {}, 用于 metadata 过滤,可选参数 @@ -97,7 +97,7 @@ func (d *ChromaProvider) UploadEmbedding( ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { - // 最小需要填写的参数为 collection_id, embeddings 和 ids + // 最少需要填写的参数为 collection_id, embeddings 和 ids // 下面是一个例子 // { // "embeddings": [ diff --git a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go new file mode 100644 index 0000000000..4203e38688 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go @@ -0,0 +1,202 @@ +package vector + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type esProviderInitializer struct{} + +const esPort = 9200 + +func (c *esProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.ESIndex) == 0 { + return errors.New("ESIndex is required") + } + if len(config.ESServiceName) == 0 { + return errors.New("ESServiceName is required") + } + return nil +} + +func (c *esProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &ESProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.ESServiceName, + Port: esPort, + Domain: config.ESServiceName, + }), + }, nil +} + +type ESProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *ESProvider) GetProviderType() string { + return providerTypeES +} + +func (d *ESProvider) GetThreshold() float64 { + return d.config.ESThreshold +} + +func (d *ESProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { + // 最少需要填写的参数为 index, embeddings 和 ids + // 下面是一个例子 + // { + // "where": {}, 用于 metadata 过滤,可选参数 + // "where_document": {}, 用于 document 过滤,可选参数 + // "query_embeddings": [ + // [1.1, 2.3, 3.2] + // ], + // "n_results": 5, + // "include": [ + // "metadatas", + // "distances" + // ] + // } + requestBody, err := json.Marshal(esQueryRequest{ + Source: Source{Excludes: []string{"embedding"}}, + Knn: knn{ + Field: "embedding", + QueryVector: emb, + K: d.config.ESNResult, + }, + Size: d.config.ESNResult, + }) + + if err != nil { + log.Errorf("[es] Failed to marshal query embedding request body: %v", err) + return + } + + d.client.Post( + fmt.Sprintf("/%s/_search", d.config.ESIndex), + [][2]string{ + {"Content-Type", "application/json"}, + {"Authorization", d.getCredentials()}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("Query embedding response: %d, %s", statusCode, responseBody) + callback(responseBody, ctx, log) + }, + d.config.ESTimeout, + ) +} + +// 编码 ES 身份认证字符串 +func (d *ESProvider) getCredentials() string { + credentials := fmt.Sprintf("%s:%s", d.config.ESUsername, d.config.ESPassword) + encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials)) + return fmt.Sprintf("Basic %s", encodedCredentials) +} + +func (d *ESProvider) UploadEmbedding( + query_emb []float64, + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log)) { + // 最少需要填写的参数为 index, embeddings 和 question + // 下面是一个例子 + // POST //_doc + // { + // "embedding": [ + // [1.1, 2.3, 3.2] + // ], + // "question": [ + // "你吃了吗?" + // ] + // } + requestBody, err := json.Marshal(esInsertRequest{ + Embedding: query_emb, + Question: queryString, + }) + if err != nil { + log.Errorf("[ES] Failed to marshal upload embedding request body: %v", err) + return + } + + d.client.Post( + fmt.Sprintf("/%s/_doc", d.config.ESIndex), + [][2]string{ + {"Content-Type", "application/json"}, + {"Authorization", d.getCredentials()}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[ES] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log) + }, + d.config.ESTimeout, + ) +} + +type esInsertRequest struct { + Embedding []float64 `json:"embedding"` + Question string `json:"question"` +} + +type knn struct { + Field string `json:"field"` + QueryVector []float64 `json:"query_vector"` + K int `json:"k"` +} + +type Source struct { + Excludes []string `json:"excludes"` +} + +type esQueryRequest struct { + Source Source `json:"_source"` + Knn knn `json:"knn"` + Size int `json:"size"` +} + +// esQueryResponse represents the search result structure. +type esQueryResponse struct { + Took int `json:"took"` + TimedOut bool `json:"timed_out"` + Hits struct { + Total struct { + Value int `json:"value"` + Relation string `json:"relation"` + } `json:"total"` + Hits []struct { + Index string `json:"_index"` + ID string `json:"_id"` + Score float64 `json:"_score"` + Source map[string]interface{} `json:"_source"` + } `json:"hits"` + } `json:"hits"` +} + +func (d *ESProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { + var queryResp esQueryResponse + err := json.Unmarshal(responseBody, &queryResp) + if err != nil { + return QueryEmbeddingResult{}, err + } + log.Infof("[ES] queryResp: %+v", queryResp) + log.Infof("[ES] queryResp Hits len: %d", len(queryResp.Hits.Hits)) + if len(queryResp.Hits.Hits) == 0 { + return QueryEmbeddingResult{}, nil + } + return QueryEmbeddingResult{ + MostSimilarData: queryResp.Hits.Hits[0].Source["question"].(string), + Score: queryResp.Hits.Hits[0].Score, + }, nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 8bf1952149..c52a394cb9 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -10,6 +10,7 @@ import ( const ( providerTypeDashVector = "dashvector" providerTypeChroma = "chroma" + providerTypeES = "elasticsearch" ) type providerInitializer interface { @@ -21,6 +22,7 @@ var ( providerInitializers = map[string]providerInitializer{ providerTypeDashVector: &dashVectorProviderInitializer{}, providerTypeChroma: &chromaProviderInitializer{}, + providerTypeES: &esProviderInitializer{}, } ) @@ -51,6 +53,9 @@ type ProviderConfig struct { // @Title zh-CN 向量存储服务提供者类型 // @Description zh-CN 向量存储服务提供者类型,例如 DashVector、Milvus typ string `json:"vectorStoreProviderType"` + // @Title zh-CN ElasticSearch 需要满足的查询条件阈值关系 + // @Title zh-CN ElasticSearch 需要满足的查询条件阈值关系,默认为 lt,所有条件包括 lt (less than,小于)、lte (less than or equal to,小等于)、gt (greater than,大于)、gte (greater than or equal to,大等于) + ThresholdRelation string `require:"false" yaml:"ThresholdRelation" json:"ThresholdRelation"` // @Title zh-CN DashVector 阿里云向量搜索引擎 // @Description zh-CN 调用阿里云的向量搜索引擎 DashVectorServiceName string `require:"true" yaml:"DashVectorServiceName" json:"DashVectorServiceName"` @@ -77,17 +82,42 @@ type ProviderConfig struct { ChromaCollectionID string `require:"false" yaml:"ChromaCollectionID" json:"ChromaCollectionID"` // @Title zh-CN Chroma 距离阈值 // @Description zh-CN Chroma 距离阈值,默认为 2000 - ChromaDistanceThreshold float64 `require:"false" yaml:"ChromaDistanceThreshold" json:"ChromaDistanceThreshold"` + ChromaThreshold float64 `require:"false" yaml:"ChromaThreshold" json:"ChromaThreshold"` // @Title zh-CN Chroma 搜索返回结果数量 // @Description zh-CN Chroma 搜索返回结果数量,默认为 1 ChromaNResult int `require:"false" yaml:"ChromaNResult" json:"ChromaNResult"` // @Title zh-CN Chroma 超时设置 // @Description zh-CN Chroma 超时设置,默认为 10 秒 ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"` + + // @Title zh-CN ElasticSearch 的上游服务名称 + // @Description zh-CN ElasticSearch 服务所对应的网关内上游服务名称 + ESServiceName string `require:"true" yaml:"ESServiceName" json:"ESServiceName"` + // @Title zh-CN ElasticSearch index + // @Description zh-CN ElasticSearch 的 index 名称 + ESIndex string `require:"false" yaml:"ESIndex" json:"ESIndex"` + // @Title zh-CN ElasticSearch 距离阈值 + // @Description zh-CN ElasticSearch 距离阈值,默认为 2000 + ESThreshold float64 `require:"false" yaml:"ESThreshold" json:"ESThreshold"` + // @Description zh-CN ElasticSearch 搜索返回结果数量,默认为 1 + ESNResult int `require:"false" yaml:"ESNResult" json:"ESNResult"` + // @Title zh-CN Chroma 超时设置 + // @Description zh-CN Chroma 超时设置,默认为 10 秒 + ESTimeout uint32 `require:"false" yaml:"ESTimeout" json:"ESTimeout"` + // @Title zh-CN ElasticSearch 用户名 + // @Description zh-CN ElasticSearch 用户名,默认为 elastic + ESUsername string `require:"false" yaml:"ESUsername" json:"ESUsername"` + // @Title zh-CN ElasticSearch 密码 + // @Description zh-CN ElasticSearch 密码,默认为 elastic + ESPassword string `require:"false" yaml:"ESPassword" json:"ESPassword"` } func (c *ProviderConfig) FromJson(json gjson.Result) { - c.typ = json.Get("vectorStoreProviderType").String() + c.typ = json.Get("VectorStoreProviderType").String() + c.ThresholdRelation = json.Get("ThresholdRelation").String() + if c.ThresholdRelation == "" { + c.ThresholdRelation = "lt" + } // DashVector c.DashVectorServiceName = json.Get("DashVectorServiceName").String() c.DashVectorKey = json.Get("DashVectorKey").String() @@ -104,9 +134,9 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { // Chroma c.ChromaCollectionID = json.Get("ChromaCollectionID").String() c.ChromaServiceName = json.Get("ChromaServiceName").String() - c.ChromaDistanceThreshold = json.Get("ChromaDistanceThreshold").Float() - if c.ChromaDistanceThreshold == 0 { - c.ChromaDistanceThreshold = 2000 + c.ChromaThreshold = json.Get("ChromaThreshold").Float() + if c.ChromaThreshold == 0 { + c.ChromaThreshold = 2000 } c.ChromaNResult = int(json.Get("ChromaNResult").Int()) if c.ChromaNResult == 0 { @@ -116,6 +146,29 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.ChromaTimeout == 0 { c.ChromaTimeout = 10000 } + // ElasticSearch + c.ESServiceName = json.Get("ESServiceName").String() + c.ESIndex = json.Get("ESIndex").String() + c.ESThreshold = json.Get("ESThreshold").Float() + if c.ESThreshold == 0 { + c.ESThreshold = 2000 + } + c.ESNResult = int(json.Get("ElasticSearchNResult").Int()) + if c.ESNResult == 0 { + c.ESNResult = 1 + } + c.ESTimeout = uint32(json.Get("ElasticSearchTimeout").Int()) + if c.ESTimeout == 0 { + c.ESTimeout = 10000 + } + c.ESUsername = json.Get("ESUser").String() + if c.ESUsername == "" { + c.ESUsername = "elastic" + } + c.ESPassword = json.Get("ESPassword").String() + if c.ESPassword == "" { + c.ESPassword = "elastic" + } } func (c *ProviderConfig) Validate() error { From 85549d0ed212aded9961396b7f1ea911d0f565f1 Mon Sep 17 00:00:00 2001 From: suchun Date: Sun, 25 Aug 2024 22:50:24 +0000 Subject: [PATCH 09/71] fix bugs && refine the variable names --- .../wasm-go/extensions/ai-cache/config/config.go | 10 +++++----- plugins/wasm-go/extensions/ai-cache/core.go | 14 ++++---------- .../extensions/ai-cache/embedding/dashscope.go | 7 ++----- .../extensions/ai-cache/embedding/provider.go | 1 - 4 files changed, 11 insertions(+), 21 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 7e1948306c..b343564fe6 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -17,7 +17,7 @@ type KVExtractor struct { type PluginConfig struct { EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"` - vectorProviderConfig vector.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` + VectorProviderConfig vector.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` @@ -46,7 +46,7 @@ type PluginConfig struct { func (c *PluginConfig) FromJson(json gjson.Result) { c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) - c.vectorProviderConfig.FromJson(json.Get("vectorProvider")) + c.VectorProviderConfig.FromJson(json.Get("vectorProvider")) c.RedisConfig.FromJson(json.Get("redis")) if c.CacheKeyFrom.RequestBody == "" { c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" @@ -84,7 +84,7 @@ func (c *PluginConfig) Validate() error { if err := c.EmbeddingProviderConfig.Validate(); err != nil { return err } - if err := c.vectorProviderConfig.Validate(); err != nil { + if err := c.VectorProviderConfig.Validate(); err != nil { return err } return nil @@ -96,7 +96,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { if err != nil { return err } - c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) + c.vectorProvider, err = vector.CreateProvider(c.VectorProviderConfig) if err != nil { return err } @@ -111,7 +111,7 @@ func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider { return c.embeddingProvider } -func (c *PluginConfig) GetvectorProvider() vector.Provider { +func (c *PluginConfig) GetVectorProvider() vector.Provider { return c.vectorProvider } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 0317b07e1e..2e11972522 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "net/http" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" @@ -53,20 +52,15 @@ func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.Htt func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { activeEmbeddingProvider := config.GetEmbeddingProvider() activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, - func(emb []float64, statusCode int, responseHeaders http.Header, responseBody []byte) { - if statusCode != 200 { - log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) - proxywasm.ResumeHttpRequest() // 当取 Embedding 失败了也继续流程 - } else { - log.Debugf("Successfully fetched embeddings for key: %s", key) - QueryVectorDB(key, emb, ctx, config, log, stream) - } + func(emb []float64) { + log.Debugf("Successfully fetched embeddings for key: %s", key) + QueryVectorDB(key, emb, ctx, config, log, stream) }) } func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { log.Debugf("QueryVectorDB key: %s", key) - activeVectorProvider := config.GetvectorProvider() + activeVectorProvider := config.GetVectorProvider() log.Debugf("activeVectorProvider: %+v", activeVectorProvider) activeVectorProvider.QueryEmbedding(text_embedding, ctx, log, func(results []vector.QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log) { diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index e836a8bd97..c6e87e7e93 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -101,21 +101,18 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin }, } - // 序列化请求体并处理错误 requestBody, err := json.Marshal(data) if err != nil { log.Errorf("Failed to marshal request data: %v", err) return "", nil, nil, err } - // 检查 DashScopeKey 是否为空 if d.config.apiKey == "" { err := errors.New("DashScopeKey is empty") log.Errorf("Failed to construct headers: %v", err) return "", nil, nil, err } - // 设置请求头 headers := [][2]string{ {"Authorization", "Bearer " + d.config.apiKey}, {"Content-Type", "application/json"}, @@ -147,14 +144,14 @@ func (d *DSProvider) GetEmbedding( ctx wrapper.HttpContext, log wrapper.Log, callback func(emb []float64)) error { - Emb_url, Emb_headers, Emb_requestBody, err := d.constructParameters([]string{queryString}, log) + embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}, log) if err != nil { log.Errorf("Failed to construct parameters: %v", err) return err } var resp *Response - d.client.Post(Emb_url, Emb_headers, Emb_requestBody, + d.client.Post(embUrl, embHeaders, embRequestBody, // TODO: 函数调用返回的error要进行处理 func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode != http.StatusOK { log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index 0c45c9cc83..e9d69a2bed 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -2,7 +2,6 @@ package embedding import ( "errors" - "net/http" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" From 8444f5e127c2d52e75698abe555cd1473569daf9 Mon Sep 17 00:00:00 2001 From: suchun Date: Sun, 25 Aug 2024 23:37:46 +0000 Subject: [PATCH 10/71] update design for cache to support extension --- .../extensions/ai-cache/cache/cache.go | 101 ------------------ .../extensions/ai-cache/cache/provider.go | 92 ++++++++++++++++ .../extensions/ai-cache/cache/redis.go | 50 +++++++++ .../extensions/ai-cache/config/config.go | 9 +- 4 files changed, 146 insertions(+), 106 deletions(-) delete mode 100644 plugins/wasm-go/extensions/ai-cache/cache/cache.go create mode 100644 plugins/wasm-go/extensions/ai-cache/cache/provider.go create mode 100644 plugins/wasm-go/extensions/ai-cache/cache/redis.go diff --git a/plugins/wasm-go/extensions/ai-cache/cache/cache.go b/plugins/wasm-go/extensions/ai-cache/cache/cache.go deleted file mode 100644 index f07c42cf64..0000000000 --- a/plugins/wasm-go/extensions/ai-cache/cache/cache.go +++ /dev/null @@ -1,101 +0,0 @@ -// TODO: 在这里写缓存的具体逻辑, 将textEmbeddingPrvider和vectorStoreProvider作为逻辑中的一个函数调用 -package cache - -import ( - "errors" - - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/tidwall/gjson" -) - -type RedisConfig struct { - // @Title zh-CN redis 服务名称 - // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local - RedisServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` - // @Title zh-CN redis 服务端口 - // @Description zh-CN 默认值为6379 - RedisServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` - // @Title zh-CN 用户名 - // @Description zh-CN 登陆 redis 的用户名,非必填 - RedisUsername string `required:"false" yaml:"username" json:"username"` - // @Title zh-CN 密码 - // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 - RedisPassword string `required:"false" yaml:"password" json:"password"` - // @Title zh-CN 请求超时 - // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 - RedisTimeout uint32 `required:"false" yaml:"timeout" json:"timeout"` - - RedisHost string `required:"false" yaml:"host" json:"host"` -} - -func CreateProvider(cf RedisConfig, log wrapper.Log) (Provider, error) { - log.Warnf("redis config: %v", cf) - rp := redisProvider{ - config: cf, - client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ - FQDN: cf.RedisServiceName, - Host: cf.RedisHost, - Port: int64(cf.RedisServicePort)}), - } - err := rp.Init(cf.RedisUsername, cf.RedisPassword, cf.RedisTimeout) - return &rp, err -} - -func (c *RedisConfig) FromJson(json gjson.Result) { - c.RedisUsername = json.Get("username").String() - c.RedisPassword = json.Get("password").String() - c.RedisTimeout = uint32(json.Get("timeout").Int()) - c.RedisServiceName = json.Get("serviceName").String() - c.RedisServicePort = int(json.Get("servicePort").Int()) - if c.RedisServicePort == 0 { - c.RedisServicePort = 6379 - } -} - -func (c *RedisConfig) Validate() error { - if len(c.RedisServiceName) == 0 { - return errors.New("serviceName is required") - } - if c.RedisTimeout <= 0 { - return errors.New("timeout must be greater than 0") - } - if c.RedisServicePort <= 0 { - c.RedisServicePort = 6379 - } - if len(c.RedisUsername) == 0 { - // return errors.New("redis.username is required") - c.RedisUsername = "" - } - if len(c.RedisPassword) == 0 { - c.RedisPassword = "" - } - return nil -} - -type Provider interface { - GetProviderType() string - Init(username string, password string, timeout uint32) error - Get(key string, cb wrapper.RedisResponseCallback) - Set(key string, value string, cb wrapper.RedisResponseCallback) -} - -type redisProvider struct { - config RedisConfig - client wrapper.RedisClient -} - -func (rp *redisProvider) GetProviderType() string { - return "redis" -} - -func (rp *redisProvider) Init(username string, password string, timeout uint32) error { - return rp.client.Init(rp.config.RedisUsername, rp.config.RedisPassword, int64(rp.config.RedisTimeout)) -} - -func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) { - rp.client.Get(key, cb) -} - -func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) { - rp.client.Set(key, value, cb) -} diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go new file mode 100644 index 0000000000..44be18c872 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -0,0 +1,92 @@ +package cache + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + providerTypeRedis = "redis" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + providerTypeRedis: &redisProviderInitializer{}, + } +) + +type ProviderConfig struct { + // @Title zh-CN redis 服务名称 + // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local + serviceName string `required:"true" yaml:"serviceName" json:"serviceName"` + // @Title zh-CN redis 服务端口 + // @Description zh-CN 默认值为6379 + servicePort int `required:"false" yaml:"servicePort" json:"servicePort"` + // @Title zh-CN redis 服务地址 + // @Description zh-CN redis 服务地址,非必填 + serviceHost string `required:"false" yaml:"serviceHost" json:"servicehost"` + // @Title zh-CN 用户名 + // @Description zh-CN 登陆 redis 的用户名,非必填 + userName string `required:"false" yaml:"username" json:"username"` + // @Title zh-CN 密码 + // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 + password string `required:"false" yaml:"password" json:"password"` + // @Title zh-CN 请求超时 + // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 + timeout uint32 `required:"false" yaml:"timeout" json:"timeout"` +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.serviceName = json.Get("serviceName").String() + c.servicePort = int(json.Get("servicePort").Int()) + if c.servicePort <= 0 { + c.servicePort = 6379 + } + c.serviceHost = json.Get("serviceHost").String() + c.userName = json.Get("username").String() + if len(c.userName) == 0 { + c.userName = "" + } + c.password = json.Get("password").String() + if len(c.password) == 0 { + c.password = "" + } + c.timeout = uint32(json.Get("timeout").Int()) + if c.timeout == 0 { + c.timeout = 1000 + } + +} + +func (c *ProviderConfig) Validate() error { + if len(c.serviceName) == 0 { + return errors.New("serviceName is required") + } + if c.timeout <= 0 { + return errors.New("timeout must be greater than 0") + } + return nil +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[providerTypeRedis] + if !has { + return nil, errors.New("unknown provider type: " + providerTypeRedis) + } + return initializer.CreateProvider(pc) +} + +type Provider interface { + GetProviderType() string + Init(username string, password string, timeout uint32) error + Get(key string, cb wrapper.RedisResponseCallback) error + Set(key string, value string, cb wrapper.RedisResponseCallback) error +} + diff --git a/plugins/wasm-go/extensions/ai-cache/cache/redis.go b/plugins/wasm-go/extensions/ai-cache/cache/redis.go new file mode 100644 index 0000000000..dab26172fd --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/cache/redis.go @@ -0,0 +1,50 @@ +package cache + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type redisProviderInitializer struct { +} + +func (r *redisProviderInitializer) ValidateConfig(cf ProviderConfig) error { + if len(cf.serviceName) == 0 { + return errors.New("serviceName is required") + } + return nil +} + +func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig) (Provider, error) { + rp := redisProvider{ + config: cf, + client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: cf.serviceName, + Host: cf.serviceHost, + Port: int64(cf.servicePort)}), + } + err := rp.Init(cf.userName, cf.password, cf.timeout) + return &rp, err +} + +type redisProvider struct { + config ProviderConfig + client wrapper.RedisClient +} + +func (rp *redisProvider) GetProviderType() string { + return "redis" +} + +func (rp *redisProvider) Init(username string, password string, timeout uint32) error { + return rp.client.Init(rp.config.userName, rp.config.password, int64(rp.config.timeout)) +} + +func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error { + return rp.client.Get(key, cb) +} + +func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error { + return rp.client.Set(key, value, cb) +} diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index b343564fe6..22a0124353 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -18,6 +18,7 @@ type KVExtractor struct { type PluginConfig struct { EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"` VectorProviderConfig vector.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` + CacheProviderConfig cache.ProviderConfig `required:"true" yaml:"cache" json:"cache"` CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` @@ -37,8 +38,6 @@ type PluginConfig struct { // @Title zh-CN Redis缓存Key的前缀 // @Description zh-CN 默认值是"higress-ai-cache:" - RedisConfig cache.RedisConfig `required:"true" yaml:"redisConfig" json:"redisConfig"` - // 现在只支持RedisClient作为cacheClient redisProvider cache.Provider `yaml:"-"` embeddingProvider embedding.Provider `yaml:"-"` vectorProvider vector.Provider `yaml:"-"` @@ -47,7 +46,7 @@ type PluginConfig struct { func (c *PluginConfig) FromJson(json gjson.Result) { c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) c.VectorProviderConfig.FromJson(json.Get("vectorProvider")) - c.RedisConfig.FromJson(json.Get("redis")) + c.CacheProviderConfig.FromJson(json.Get("redis")) if c.CacheKeyFrom.RequestBody == "" { c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" } @@ -78,7 +77,7 @@ func (c *PluginConfig) FromJson(json gjson.Result) { } func (c *PluginConfig) Validate() error { - if err := c.RedisConfig.Validate(); err != nil { + if err := c.CacheProviderConfig.Validate(); err != nil { return err } if err := c.EmbeddingProviderConfig.Validate(); err != nil { @@ -100,7 +99,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { if err != nil { return err } - c.redisProvider, err = cache.CreateProvider(c.RedisConfig, log) + c.redisProvider, err = cache.CreateProvider(c.CacheProviderConfig) if err != nil { return err } From d68fa8822b3c04457fab63f3517082d3eaed4c16 Mon Sep 17 00:00:00 2001 From: suchun Date: Thu, 5 Sep 2024 15:55:26 +0000 Subject: [PATCH 11/71] Refined the code; README.md content needs to be updated. --- plugins/wasm-go/extensions/ai-cache/README.md | 2 + .../extensions/ai-cache/cache/provider.go | 87 ++++++++++----- .../extensions/ai-cache/cache/redis.go | 13 ++- .../extensions/ai-cache/config/config.go | 89 +++++++-------- plugins/wasm-go/extensions/ai-cache/core.go | 22 ++-- .../ai-cache/embedding/dashscope.go | 30 +++-- .../extensions/ai-cache/embedding/provider.go | 56 ++++++---- plugins/wasm-go/extensions/ai-cache/main.go | 105 +++++------------- plugins/wasm-go/extensions/ai-cache/util.go | 48 ++++++++ .../extensions/ai-cache/vector/dashvector.go | 4 +- .../extensions/ai-cache/vector/provider.go | 31 +++--- 11 files changed, 262 insertions(+), 225 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-cache/util.go diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 4e70c4e050..7f4b8f6571 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -8,6 +8,8 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的结果缓存,同时支持流式和非流式响应的缓存。 +## 简介 + ## 配置说明 | Name | Type | Requirement | Default | Description | diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index 44be18c872..e8a5f4ebd1 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -8,7 +8,7 @@ import ( ) const ( - providerTypeRedis = "redis" + PROVIDER_TYPE_REDIS = "redis" ) type providerInitializer interface { @@ -18,67 +18,94 @@ type providerInitializer interface { var ( providerInitializers = map[string]providerInitializer{ - providerTypeRedis: &redisProviderInitializer{}, + PROVIDER_TYPE_REDIS: &redisProviderInitializer{}, } ) type ProviderConfig struct { - // @Title zh-CN redis 服务名称 - // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local - serviceName string `required:"true" yaml:"serviceName" json:"serviceName"` - // @Title zh-CN redis 服务端口 - // @Description zh-CN 默认值为6379 - servicePort int `required:"false" yaml:"servicePort" json:"servicePort"` - // @Title zh-CN redis 服务地址 - // @Description zh-CN redis 服务地址,非必填 - serviceHost string `required:"false" yaml:"serviceHost" json:"servicehost"` - // @Title zh-CN 用户名 - // @Description zh-CN 登陆 redis 的用户名,非必填 - userName string `required:"false" yaml:"username" json:"username"` - // @Title zh-CN 密码 - // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 - password string `required:"false" yaml:"password" json:"password"` + // @Title zh-CN redis 缓存服务提供者类型 + // @Description zh-CN 缓存服务提供者类型,例如 redis + typ string + // @Title zh-CN redis 缓存服务名称 + // @Description zh-CN 缓存服务名称 + serviceName string + // @Title zh-CN redis 缓存服务端口 + // @Description zh-CN 缓存服务端口,默认值为6379 + servicePort int + // @Title zh-CN redis 缓存服务地址 + // @Description zh-CN Cache 缓存服务地址,非必填 + serviceHost string + // @Title zh-CN 缓存服务用户名 + // @Description zh-CN 缓存服务用户名,非必填 + userName string + // @Title zh-CN 缓存服务密码 + // @Description zh-CN 缓存服务密码,非必填 + password string // @Title zh-CN 请求超时 - // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 - timeout uint32 `required:"false" yaml:"timeout" json:"timeout"` + // @Description zh-CN 请求缓存服务的超时时间,单位为毫秒。默认值是1000,即1秒 + timeout uint32 + // @Title zh-CN 缓存过期时间 + // @Description zh-CN 缓存过期时间,单位为秒。默认值是0,即永不过期 + cacheTTL uint32 + // @Title 缓存 Key 前缀 + // @Description 缓存 Key 的前缀,默认值为 "higressAiCache" + cacheKeyPrefix string } func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("type").String() c.serviceName = json.Get("serviceName").String() c.servicePort = int(json.Get("servicePort").Int()) - if c.servicePort <= 0 { + if !json.Get("servicePort").Exists() { c.servicePort = 6379 } c.serviceHost = json.Get("serviceHost").String() c.userName = json.Get("username").String() - if len(c.userName) == 0 { + if !json.Get("username").Exists() { c.userName = "" } c.password = json.Get("password").String() - if len(c.password) == 0 { + if !json.Get("password").Exists() { c.password = "" } c.timeout = uint32(json.Get("timeout").Int()) - if c.timeout == 0 { + if !json.Get("timeout").Exists() { c.timeout = 1000 } - + c.cacheTTL = uint32(json.Get("cacheTTL").Int()) + if !json.Get("cacheTTL").Exists() { + c.cacheTTL = 0 + } + c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String() + if !json.Get("cacheKeyPrefix").Exists() { + c.cacheKeyPrefix = "higressAiCache" + } } func (c *ProviderConfig) Validate() error { - if len(c.serviceName) == 0 { - return errors.New("serviceName is required") + if c.typ == "" { + return errors.New("cache service type is required") + } + if c.serviceName == "" { + return errors.New("cache service name is required") } if c.timeout <= 0 { - return errors.New("timeout must be greater than 0") + return errors.New("cache service timeout must be greater than 0") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown cache service provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err } return nil } func CreateProvider(pc ProviderConfig) (Provider, error) { - initializer, has := providerInitializers[providerTypeRedis] + initializer, has := providerInitializers[pc.typ] if !has { - return nil, errors.New("unknown provider type: " + providerTypeRedis) + return nil, errors.New("unknown provider type: " + pc.typ) } return initializer.CreateProvider(pc) } @@ -88,5 +115,5 @@ type Provider interface { Init(username string, password string, timeout uint32) error Get(key string, cb wrapper.RedisResponseCallback) error Set(key string, value string, cb wrapper.RedisResponseCallback) error + GetCacheKeyPrefix() string } - diff --git a/plugins/wasm-go/extensions/ai-cache/cache/redis.go b/plugins/wasm-go/extensions/ai-cache/cache/redis.go index dab26172fd..082146c0d7 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/redis.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/redis.go @@ -6,12 +6,14 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) +const DEFAULT_CACHE_PREFIX = "higressAiCache" + type redisProviderInitializer struct { } func (r *redisProviderInitializer) ValidateConfig(cf ProviderConfig) error { if len(cf.serviceName) == 0 { - return errors.New("serviceName is required") + return errors.New("[redis] cache service name is required") } return nil } @@ -46,5 +48,12 @@ func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error } func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error { - return rp.client.Set(key, value, cb) + return rp.client.SetEx(key, value, int(rp.config.cacheTTL), cb) +} + +func (rp *redisProvider) GetCacheKeyPrefix() string { + if len(rp.config.cacheKeyPrefix) == 0 { + return DEFAULT_CACHE_PREFIX + } + return rp.config.cacheKeyPrefix } diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 22a0124353..15824268fc 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -15,75 +15,62 @@ type KVExtractor struct { ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` } +func (e *KVExtractor) SetRequestBodyFromJson(json gjson.Result, key string, defaultValue string) { + if json.Get(key).Exists() { + e.RequestBody = json.Get(key).String() + } else { + e.RequestBody = defaultValue + } +} + type PluginConfig struct { - EmbeddingProviderConfig embedding.ProviderConfig `required:"true" yaml:"embeddingProvider" json:"embeddingProvider"` - VectorProviderConfig vector.ProviderConfig `required:"true" yaml:"vectorBaseProvider" json:"vectorBaseProvider"` - CacheProviderConfig cache.ProviderConfig `required:"true" yaml:"cache" json:"cache"` - CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` - CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` - CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` // @Title zh-CN 返回 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` + ResponseTemplate string `required:"true" yaml:"responseTemplate" json:"responseTemplate"` // @Title zh-CN 返回流式 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnTestResponseTemplate string `required:"true" yaml:"returnTestResponseTemplate" json:"returnTestResponseTemplate"` - - CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` - - ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"` - // @Title zh-CN 缓存的过期时间 - // @Description zh-CN 单位是秒,默认值为0,即永不过期 - CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` - // @Title zh-CN Redis缓存Key的前缀 - // @Description zh-CN 默认值是"higress-ai-cache:" + StreamResponseTemplate string `required:"true" yaml:"streamResponseTemplate" json:"streamResponseTemplate"` redisProvider cache.Provider `yaml:"-"` embeddingProvider embedding.Provider `yaml:"-"` vectorProvider vector.Provider `yaml:"-"` + + embeddingProviderConfig embedding.ProviderConfig + vectorProviderConfig vector.ProviderConfig + cacheProviderConfig cache.ProviderConfig + + CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` + CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` + CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` } func (c *PluginConfig) FromJson(json gjson.Result) { - c.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) - c.VectorProviderConfig.FromJson(json.Get("vectorProvider")) - c.CacheProviderConfig.FromJson(json.Get("redis")) - if c.CacheKeyFrom.RequestBody == "" { - c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" - } - c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String() - if c.CacheKeyFrom.RequestBody == "" { - c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" - } - c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String() - if c.CacheValueFrom.ResponseBody == "" { - c.CacheValueFrom.ResponseBody = "choices.0.message.content" - } - c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String() - if c.CacheStreamValueFrom.ResponseBody == "" { - c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content" - } - c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String() - if c.ReturnResponseTemplate == "" { - c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - } - c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String() - if c.ReturnStreamResponseTemplate == "" { - c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + c.embeddingProviderConfig.FromJson(json.Get("embedding")) + c.vectorProviderConfig.FromJson(json.Get("vector")) + c.cacheProviderConfig.FromJson(json.Get("cache")) + + c.CacheKeyFrom.SetRequestBodyFromJson(json, "cacheKeyFrom.requestBody", "messages.@reverse.0.content") + c.CacheValueFrom.SetRequestBodyFromJson(json, "cacheValueFrom.requestBody", "choices.0.message.content") + c.CacheStreamValueFrom.SetRequestBodyFromJson(json, "cacheStreamValueFrom.requestBody", "choices.0.delta.content") + + c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() + if c.StreamResponseTemplate == "" { + c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" } - c.ReturnTestResponseTemplate = json.Get("returnTestResponseTemplate").String() - if c.ReturnTestResponseTemplate == "" { - c.ReturnTestResponseTemplate = `{"id":"random-generate","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + c.ResponseTemplate = json.Get("responseTemplate").String() + if c.ResponseTemplate == "" { + c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } } func (c *PluginConfig) Validate() error { - if err := c.CacheProviderConfig.Validate(); err != nil { + if err := c.cacheProviderConfig.Validate(); err != nil { return err } - if err := c.EmbeddingProviderConfig.Validate(); err != nil { + if err := c.embeddingProviderConfig.Validate(); err != nil { return err } - if err := c.VectorProviderConfig.Validate(); err != nil { + if err := c.vectorProviderConfig.Validate(); err != nil { return err } return nil @@ -91,15 +78,15 @@ func (c *PluginConfig) Validate() error { func (c *PluginConfig) Complete(log wrapper.Log) error { var err error - c.embeddingProvider, err = embedding.CreateProvider(c.EmbeddingProviderConfig) + c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig) if err != nil { return err } - c.vectorProvider, err = vector.CreateProvider(c.VectorProviderConfig) + c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) if err != nil { return err } - c.redisProvider, err = cache.CreateProvider(c.CacheProviderConfig) + c.redisProvider, err = cache.CreateProvider(c.cacheProviderConfig) if err != nil { return err } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 2e11972522..23ecd67902 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" @@ -13,13 +12,14 @@ import ( func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) { activeCacheProvider := config.GetCacheProvider() + key = activeCacheProvider.GetCacheKeyPrefix() + ":" + key log.Debugf("activeCacheProvider:%v", activeCacheProvider) - activeCacheProvider.Get(embedding.CacheKeyPrefix+key, func(response resp.Value) { + activeCacheProvider.Get(key, func(response resp.Value) { if err := response.Error(); err == nil && !response.IsNull() { - log.Warnf("cache hit, key:%s", key) + log.Debugf("cache hit, key:%s", key) HandleCacheHit(key, response, stream, ctx, config, log) } else { - log.Warnf("cache miss, key:%s", key) + log.Debugf("cache miss, key:%s", key) if ifUseEmbedding { HandleCacheMiss(key, err, response, ctx, config, log, key, stream) } else { @@ -31,11 +31,11 @@ func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.Plugi } func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { - ctx.SetContext(embedding.CacheKeyContextKey, nil) + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) if !stream { - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "[Test, this is cache]"+response.String())), -1) + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, response.String())), -1) } else { - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, "[Test, this is cache]"+response.String())), -1) + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.StreamResponseTemplate, response.String())), -1) } } @@ -58,16 +58,16 @@ func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config confi }) } -func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { +func QueryVectorDB(key string, textEmbedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { log.Debugf("QueryVectorDB key: %s", key) activeVectorProvider := config.GetVectorProvider() log.Debugf("activeVectorProvider: %+v", activeVectorProvider) - activeVectorProvider.QueryEmbedding(text_embedding, ctx, log, + activeVectorProvider.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log) { // The baisc logic is to compare the similarity of the embedding with the most similar key in the database if len(results) == 0 { log.Warnf("Failed to query vector database, no similar key found") - activeVectorProvider.UploadEmbedding(text_embedding, key, ctx, log, + activeVectorProvider.UploadEmbedding(textEmbedding, key, ctx, log, func(ctx wrapper.HttpContext, log wrapper.Log) { proxywasm.ResumeHttpRequest() }) @@ -82,7 +82,7 @@ func QueryVectorDB(key string, text_embedding []float64, ctx wrapper.HttpContext RedisSearchHandler(mostSimilarData.Text, ctx, config, log, stream, false) } else { log.Infof("the most similar key's score is too high, key: %s, score: %f", mostSimilarData.Text, mostSimilarData.Score) - activeVectorProvider.UploadEmbedding(text_embedding, key, ctx, log, + activeVectorProvider.UploadEmbedding(textEmbedding, key, ctx, log, func(ctx wrapper.HttpContext, log wrapper.Log) { proxywasm.ResumeHttpRequest() }) diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index c6e87e7e93..0aaa6e2e68 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -9,10 +9,10 @@ import ( ) const ( - domain = "dashscope.aliyuncs.com" - port = 443 - modelName = "text-embedding-v1" - endpoint = "/api/v1/services/embeddings/text-embedding/text-embedding" + DOMAIN = "dashscope.aliyuncs.com" + PORT = 443 + MODEL_NAME = "text-embedding-v1" + END_POINT = "/api/v1/services/embeddings/text-embedding/text-embedding" ) type dashScopeProviderInitializer struct { @@ -20,30 +20,30 @@ type dashScopeProviderInitializer struct { func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error { if config.apiKey == "" { - return errors.New("DashScopeKey is required") + return errors.New("[DashScope] apiKey is required") } return nil } func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { if c.servicePort == 0 { - c.servicePort = port + c.servicePort = PORT } - if c.serviceHost == "" { - c.serviceHost = domain + if c.serviceDomain == "" { + c.serviceDomain = DOMAIN } return &DSProvider{ config: c, client: wrapper.NewClusterClient(wrapper.DnsCluster{ ServiceName: c.serviceName, Port: c.servicePort, - Domain: c.serviceHost, + Domain: c.serviceDomain, }), }, nil } func (d *DSProvider) GetProviderType() string { - return providerTypeDashScope + return PROVIDER_TYPE_DASHSCOPE } type Embedding struct { @@ -92,7 +92,7 @@ type DSProvider struct { func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { data := EmbeddingRequest{ - Model: modelName, + Model: MODEL_NAME, Input: Input{ Texts: texts, }, @@ -118,18 +118,16 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin {"Content-Type", "application/json"}, } - return endpoint, headers, requestBody, err + return END_POINT, headers, requestBody, err } -// Result 定义查询结果的结构 type Result struct { ID string `json:"id"` - Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化 + Vector []float64 `json:"vector,omitempty"` Fields map[string]interface{} `json:"fields"` Score float64 `json:"score"` } -// 返回指针防止拷贝 Embedding func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error) { var resp Response err := json.Unmarshal(responseBody, &resp) @@ -151,7 +149,7 @@ func (d *DSProvider) GetEmbedding( } var resp *Response - d.client.Post(embUrl, embHeaders, embRequestBody, // TODO: 函数调用返回的error要进行处理 + d.client.Post(embUrl, embHeaders, embRequestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode != http.StatusOK { log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index e9d69a2bed..f7066c3761 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -8,15 +8,7 @@ import ( ) const ( - providerTypeDashScope = "dashscope" - CacheKeyContextKey = "cacheKey" - CacheContentContextKey = "cacheContent" - PartialMessageContextKey = "partialMessage" - ToolCallsContextKey = "toolCalls" - StreamContextKey = "stream" - CacheKeyPrefix = "higressAiCache" - DefaultCacheKeyPrefix = "higressAiCache" - queryEmbeddingKey = "queryEmbedding" + PROVIDER_TYPE_DASHSCOPE = "dashscope" ) type providerInitializer interface { @@ -26,28 +18,35 @@ type providerInitializer interface { var ( providerInitializers = map[string]providerInitializer{ - providerTypeDashScope: &dashScopeProviderInitializer{}, + PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{}, } ) type ProviderConfig struct { // @Title zh-CN 文本特征提取服务提供者类型 // @Description zh-CN 文本特征提取服务提供者类型,例如 DashScope - typ string `json:"type"` - // @Title zh-CN DashScope 阿里云大模型服务名 - // @Description zh-CN 调用阿里云的大模型服务 - serviceName string `require:"true" yaml:"serviceName" json:"serviceName"` - serviceHost string `require:"false" yaml:"serviceHost" json:"serviceHost"` - servicePort int64 `require:"false" yaml:"servicePort" json:"servicePort"` - apiKey string `require:"false" yaml:"apiKey" json:"apiKey"` - timeout uint32 `require:"false" yaml:"timeout" json:"timeout"` - client wrapper.HttpClient `yaml:"-"` + typ string + // @Title zh-CN DashScope 文本特征提取服务名称 + // @Description zh-CN 文本特征提取服务名称 + serviceName string + // @Title zh-CN 文本特征提取服务域名 + // @Description zh-CN 文本特征提取服务域名 + serviceDomain string + // @Title zh-CN 文本特征提取服务端口 + // @Description zh-CN 文本特征提取服务端口 + servicePort int64 + // @Title zh-CN 文本特征提取服务 API Key + // @Description zh-CN 文本特征提取服务 API Key + apiKey string + // @Title zh-CN 文本特征提取服务超时时间 + // @Description zh-CN 文本特征提取服务超时时间 + timeout uint32 } func (c *ProviderConfig) FromJson(json gjson.Result) { c.typ = json.Get("type").String() c.serviceName = json.Get("serviceName").String() - c.serviceHost = json.Get("serviceHost").String() + c.serviceDomain = json.Get("serviceDomain").String() c.servicePort = json.Get("servicePort").Int() c.apiKey = json.Get("apiKey").String() c.timeout = uint32(json.Get("timeout").Int()) @@ -57,8 +56,21 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } func (c *ProviderConfig) Validate() error { - if len(c.serviceName) == 0 { - return errors.New("serviceName is required") + if c.serviceName == "" { + return errors.New("embedding service name is required") + } + if c.apiKey == "" { + return errors.New("embedding service API key is required") + } + if c.typ == "" { + return errors.New("embedding service type is required") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown embedding service provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err } return nil } diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index d58350c604..2fc63d8381 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" - "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" @@ -14,22 +13,20 @@ import ( ) const ( - pluginName = "ai-cache" - CacheKeyContextKey = "cacheKey" - CacheContentContextKey = "cacheContent" - PartialMessageContextKey = "partialMessage" - ToolCallsContextKey = "toolCalls" - StreamContextKey = "stream" - CacheKeyPrefix = "higressAiCache" - DefaultCacheKeyPrefix = "higressAiCache" - QueryEmbeddingKey = "queryEmbedding" + PLUGIN_NAME = "ai-cache" + CACHE_KEY_CONTEXT_KEY = "cacheKey" + CACHE_CONTENT_CONTEXT_KEY = "cacheContent" + PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" + TOOL_CALLS_CONTEXT_KEY = "toolCalls" + STREAM_CONTEXT_KEY = "stream" + QUERY_EMBEDDING_KEY = "queryEmbedding" ) func main() { // CreateClient() wrapper.SetCtx( - pluginName, + PLUGIN_NAME, wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), @@ -54,10 +51,6 @@ func parseConfig(json gjson.Result, config *config.PluginConfig, log wrapper.Log return nil } -func TrimQuote(source string) string { - return strings.Trim(source, `"`) -} - func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) types.Action { // 这段代码是为了测试,在 parseConfig 阶段初始化 client 会出错,比如 docker compose 中的 redis 就无法使用 // 但是在 onHttpRequestHeaders 中可以连接到 redis、 @@ -97,95 +90,57 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body stream := false if bodyJson.Get("stream").Bool() { stream = true - ctx.SetContext(StreamContextKey, struct{}{}) - } else if ctx.GetContext(StreamContextKey) != nil { + ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) + } else if ctx.GetContext(STREAM_CONTEXT_KEY) != nil { stream = true } // key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) key := bodyJson.Get(config.CacheKeyFrom.RequestBody).String() - ctx.SetContext(CacheKeyContextKey, key) + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, key) log.Debugf("[onHttpRequestBody] key:%s", key) if key == "" { log.Debug("[onHttpRquestBody] parse key from request body failed") return types.ActionContinue } - queryString := config.CacheKeyPrefix + key - - RedisSearchHandler(queryString, ctx, config, log, stream, true) + RedisSearchHandler(key, ctx, config, log, stream, true) - // 需要等待异步回调完成,返回 Pause 状态,可以被 ResumeHttpRequest 恢复 return types.ActionPause } -func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseMessage string, log wrapper.Log) string { - subMessages := strings.Split(sseMessage, "\n") - var message string - for _, msg := range subMessages { - if strings.HasPrefix(msg, "data:") { - message = msg - break - } - } - if len(message) < 6 { - log.Warnf("invalid message:%s", message) - return "" - } - // skip the prefix "data:" - bodyJson := message[5:] - if gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Exists() { - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI == nil { - content := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) - ctx.SetContext(CacheContentContextKey, content) - return content - } - append := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) - content := tempContentI.(string) + append - ctx.SetContext(CacheContentContextKey, content) - return content - } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { - // TODO: compatible with other providers - ctx.SetContext(ToolCallsContextKey, struct{}{}) - return "" - } - log.Warnf("unknown message:%s", bodyJson) - return "" -} - func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { - ctx.SetContext(StreamContextKey, struct{}{}) + ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) } return types.ActionContinue } func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - log.Infof("[onHttpResponseBody] chunk:%s", string(chunk)) - log.Infof("[onHttpResponseBody] isLastChunk:%v", isLastChunk) - if ctx.GetContext(ToolCallsContextKey) != nil { + log.Debugf("[onHttpResponseBody] chunk:%s", string(chunk)) + log.Debugf("[onHttpResponseBody] isLastChunk:%v", isLastChunk) + if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { // we should not cache tool call result return chunk } - keyI := ctx.GetContext(CacheKeyContextKey) + keyI := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) if keyI == nil { return chunk } if !isLastChunk { - stream := ctx.GetContext(StreamContextKey) + stream := ctx.GetContext(STREAM_CONTEXT_KEY) if stream == nil { - tempContentI := ctx.GetContext(CacheContentContextKey) + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) if tempContentI == nil { - ctx.SetContext(CacheContentContextKey, chunk) + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) return chunk } tempContent := tempContentI.([]byte) tempContent = append(tempContent, chunk...) - ctx.SetContext(CacheContentContextKey, tempContent) + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) } else { var partialMessage []byte - partialMessageI := ctx.GetContext(PartialMessageContextKey) + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) if partialMessageI != nil { partialMessage = append(partialMessageI.([]byte), chunk...) } else { @@ -199,20 +154,20 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } } if !strings.HasSuffix(string(partialMessage), "\n\n") { - ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1])) + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) } else { - ctx.SetContext(PartialMessageContextKey, nil) + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) } } return chunk } // last chunk key := keyI.(string) - stream := ctx.GetContext(StreamContextKey) + stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string if stream == nil { var body []byte - tempContentI := ctx.GetContext(CacheContentContextKey) + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) if tempContentI != nil { body = append(tempContentI.([]byte), chunk...) } else { @@ -229,7 +184,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu log.Infof("[onHttpResponseBody] stream mode") if len(chunk) > 0 { var lastMessage []byte - partialMessageI := ctx.GetContext(PartialMessageContextKey) + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) if partialMessageI != nil { lastMessage = append(partialMessageI.([]byte), chunk...) } else { @@ -243,7 +198,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu lastMessage = lastMessage[:len(lastMessage)-2] value = processSSEMessage(ctx, config, string(lastMessage), log) } else { - tempContentI := ctx.GetContext(CacheContentContextKey) + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) if tempContentI == nil { log.Warnf("[onHttpResponseBody] no content in tempContentI") return chunk @@ -252,7 +207,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } } log.Infof("[onHttpResponseBody] Setting cache to redis, key:%s, value:%s", key, value) - config.GetCacheProvider().Set(embedding.CacheKeyPrefix+key, value, nil) - // TODO: 要不要加个Expire方法 + activeCacheProvider := config.GetCacheProvider() + config.GetCacheProvider().Set(activeCacheProvider.GetCacheKeyPrefix()+":"+key, value, nil) return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go new file mode 100644 index 0000000000..a1f613ce69 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -0,0 +1,48 @@ +package main + +import ( + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +func TrimQuote(source string) string { + return strings.Trim(source, `"`) +} + +func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseMessage string, log wrapper.Log) string { + subMessages := strings.Split(sseMessage, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break + } + } + if len(message) < 6 { + log.Warnf("invalid message:%s", message) + return "" + } + // skip the prefix "data:" + bodyJson := message[5:] + if gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Exists() { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + content := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + return content + } + append := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) + content := tempContentI.(string) + append + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + return content + } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { + // TODO: compatible with other providers + ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, struct{}{}) + return "" + } + log.Warnf("unknown message:%s", bodyJson) + return "" +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 250a653919..b0801eb82d 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -26,7 +26,7 @@ func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) er if len(config.serviceName) == 0 { return errors.New("[DashVector] serviceName is required") } - if len(config.serviceHost) == 0 { + if len(config.serviceDomain) == 0 { return errors.New("[DashVector] endPoint is required") } return nil @@ -38,7 +38,7 @@ func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (P client: wrapper.NewClusterClient(wrapper.DnsCluster{ ServiceName: config.serviceName, Port: config.servicePort, - Domain: config.serviceHost, + Domain: config.serviceDomain, }), }, nil } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index a37eac06fa..7de9e64404 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -26,9 +26,9 @@ var ( // 定义通用的查询结果的结构体 type QueryEmbeddingResult struct { - Text string // 相似的文本 - Embedding []float64 // 相似文本的向量 - Score float64 // 文本的向量相似度或距离等度量 + Text string // 相似的文本 + Embedding []float64 // 相似文本的向量 + Score float64 // 文本的向量相似度或距离等度量 } type Provider interface { @@ -53,14 +53,14 @@ type Provider interface { type ProviderConfig struct { // @Title zh-CN 向量存储服务提供者类型 // @Description zh-CN 向量存储服务提供者类型,例如 DashVector、Milvus - typ string `json:"vectorStoreProviderType"` - serviceName string `require:"true" yaml:"serviceName" json:"serviceName"` - serviceHost string `require:"false" yaml:"serviceHost" json:"serviceHost"` - servicePort int64 `require:"false" yaml:"servicePort" json:"servicePort"` - apiKey string `require:"false" yaml:"apiKey" json:"apiKey"` - topK int `require:"false" yaml:"topK" json:"topK"` - timeout uint32 `require:"false" yaml:"timeout" json:"timeout"` - collectionID string `require:"false" yaml:"collectionID" json:"collectionID"` + typ string + serviceName string + serviceDomain string + servicePort int64 + apiKey string + topK int + timeout uint32 + collectionID string // // @Title zh-CN Chroma 的上游服务名称 // // @Description zh-CN Chroma 服务所对应的网关内上游服务名称 @@ -77,14 +77,13 @@ type ProviderConfig struct { // // @Title zh-CN Chroma 超时设置 // // @Description zh-CN Chroma 超时设置,默认为 10 秒 // ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"` - vectorClient wrapper.HttpClient `yaml:"-" json:"-"` } func (c *ProviderConfig) FromJson(json gjson.Result) { - c.typ = json.Get("vectorStoreProviderType").String() + c.typ = json.Get("type").String() // DashVector c.serviceName = json.Get("serviceName").String() - c.serviceHost = json.Get("serviceHost").String() + c.serviceDomain = json.Get("serviceDomain").String() c.servicePort = int64(json.Get("servicePort").Int()) if c.servicePort == 0 { c.servicePort = 443 @@ -118,11 +117,11 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { func (c *ProviderConfig) Validate() error { if c.typ == "" { - return errors.New("[ai-cache] missing type in provider config") + return errors.New("vector database service is required") } initializer, has := providerInitializers[c.typ] if !has { - return errors.New("unknown provider type: " + c.typ) + return errors.New("unknown vector database service provider type: " + c.typ) } if err := initializer.ValidateConfig(*c); err != nil { return err From d6c643fbd95d472c1077b28f1c7ed20a69f6a70b Mon Sep 17 00:00:00 2001 From: Async Date: Fri, 6 Sep 2024 11:30:36 +0800 Subject: [PATCH 12/71] add: makefile for weaviate --- docker-compose-test/Makefile | 111 +++++++++++++++++++++++++ docker-compose-test/docker-compose.yml | 50 +++++------ docker-compose-test/envoy.yaml | 15 ++++ 3 files changed, 151 insertions(+), 25 deletions(-) create mode 100644 docker-compose-test/Makefile diff --git a/docker-compose-test/Makefile b/docker-compose-test/Makefile new file mode 100644 index 0000000000..c4955a000e --- /dev/null +++ b/docker-compose-test/Makefile @@ -0,0 +1,111 @@ +HEADER := Content-Type: application/json + + +.PHONY: docker docker-down proxy cache weaviate-collection weaviate-obj + +all: docker proxy cache weaviate-collection weaviate-obj + +docker: + docker compose up + +docker-down: + docker compose down + +proxy: + cd ../plugins/wasm-go/extensions/ai-proxy && \ + tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' . && \ + mv ai-proxy.wasm ../../../../docker-compose-test/ + + +cache: + cd ../plugins/wasm-go/extensions/ai-cache && \ + tinygo build -o ai-cache.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' . && \ + mv ai-cache.wasm ../../../../docker-compose-test/ + +weaviate-obj: + curl --request POST --url http://localhost:8081/v1/objects -H "$(HEADER)" --data '{ "class": "", "vectorWeights": {}, "properties": {}, "id": "", "creationTimeUnix": 1, "lastUpdateTimeUnix": 1, "vector": [ 1 ], "vectors": { "ANY_ADDITIONAL_PROPERTY": [1] }, "tenant": "", "additional": { "ANY_ADDITIONAL_PROPERTY": {} }}' + +weaviate-collection: + curl -X POST http://localhost:8081/v1/schema -H "$(HEADER)" -d '{ + "class": "", + "vectorConfig": { + "": { + "vectorizer": "none", + "vectorIndexType": "hnsw", + "vectorIndexConfig": {} + } + }, + "vectorIndexType": "hnsw", + "vectorIndexConfig": {}, + "shardingConfig": { + "desiredCount": 1, + "virtualPerPhysical": 128 + }, + "replicationConfig": { + "factor": 1, + "asyncEnabled": true + }, + "invertedIndexConfig": { + "cleanupIntervalSeconds": 60, + "bm25": { + "k1": 1.2, + "b": 0.75 + }, + "stopwords": { + "preset": "en", + "additions": [], + "removals": [] + }, + "indexTimestamps": false, + "indexNullState": false, + "indexPropertyLength": false + }, + "multiTenancyConfig": { + "enabled": false, + "autoTenantCreation": true, + "autoTenantActivation": true + }, + "vectorizer": "none", + "moduleConfig": { + "": { + "vectorizeClassName": true + } + }, + "description": "", + "properties": [ + { + "dataType": [ + "string" + ], + "description": "", + "moduleConfig": { + "": { + "skip": false, + "vectorizePropertyName": true + } + }, + "name": "", + "indexInverted": true, + "indexFilterable": true, + "indexSearchable": true, + "indexRangeFilters": true, + "tokenization": "word", + "nestedProperties": [ + { + "name": "", + "dataType": [ + null + ], + "description": "", + "indexFilterable": true, + "indexSearchable": true, + "indexRangeFilters": true, + "tokenization": "word", + "nestedProperties": [ + null + ] + } + ] + } + ] + }' \ No newline at end of file diff --git a/docker-compose-test/docker-compose.yml b/docker-compose-test/docker-compose.yml index d98d16a95c..a8457ddd24 100644 --- a/docker-compose-test/docker-compose.yml +++ b/docker-compose-test/docker-compose.yml @@ -86,31 +86,31 @@ services: # ports: # - "3210:3210/tcp" - # weaviate: - # command: - # - --host - # - 0.0.0.0 - # - --port - # - '8080' - # - --scheme - # - http - # image: cr.weaviate.io/semitechnologies/weaviate:1.26.1 - # ports: - # - 8081:8080 - # - 50051:50051 - # volumes: - # - weaviate_data:/var/lib/weaviate - # restart: on-failure:0 - # networks: - # - wasmtest - # environment: - # QUERY_DEFAULTS_LIMIT: 25 - # AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' - # PERSISTENCE_DATA_PATH: '/var/lib/weaviate' - # DEFAULT_VECTORIZER_MODULE: 'none' - # ENABLE_API_BASED_MODULES: 'true' - # CLUSTER_HOSTNAME: 'node1' - # TRANSFORMERS_INFERENCE_API: http://t2v-transformers:8080 # Set the inference API endpoint + weaviate: + command: + - --host + - 0.0.0.0 + - --port + - '8080' + - --scheme + - http + image: cr.weaviate.io/semitechnologies/weaviate:1.26.3 + ports: + - 8081:8080 + - 50051:50051 + volumes: + - weaviate_data:/var/lib/weaviate + restart: on-failure:0 + networks: + - wasmtest + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + ENABLE_API_BASED_MODULES: 'true' + CLUSTER_HOSTNAME: 'node1' + TRANSFORMERS_INFERENCE_API: http://t2v-transformers:8080 # Set the inference API endpoint # t2v-transformers: # Set the name of the inference container # image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-multi-qa-MiniLM-L6-cos-v1 diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml index 625a799652..e67a0ce0d1 100644 --- a/docker-compose-test/envoy.yaml +++ b/docker-compose-test/envoy.yaml @@ -209,6 +209,21 @@ static_resources: socket_address: address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 port_value: 9200 + # weaviate + - name: outbound|8081||weaviate.dns + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|8081||weaviate.dns + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 + port_value: 8081 # llm - name: llm From 3f3a1bcd380f9e225853fae15ddafa7863124557 Mon Sep 17 00:00:00 2001 From: Async Date: Fri, 6 Sep 2024 15:32:09 +0800 Subject: [PATCH 13/71] feat: add weaviate --- docker-compose-test/Makefile | 162 ++++----- docker-compose-test/docker-compose.yml | 47 +-- docker-compose-test/envoy.yaml | 11 +- .../extensions/ai-cache/vector/chroma.go | 11 +- .../extensions/ai-cache/vector/dashvector.go | 8 +- .../ai-cache/vector/elasticsearch.go | 4 +- .../extensions/ai-cache/vector/provider.go | 33 ++ .../extensions/ai-cache/vector/weaviate.go | 336 +++++++++--------- 8 files changed, 307 insertions(+), 305 deletions(-) diff --git a/docker-compose-test/Makefile b/docker-compose-test/Makefile index c4955a000e..b66c628609 100644 --- a/docker-compose-test/Makefile +++ b/docker-compose-test/Makefile @@ -1,111 +1,91 @@ HEADER := Content-Type: application/json +WEAVIATE_PORT = 8081 -.PHONY: docker docker-down proxy cache weaviate-collection weaviate-obj +.PHONY: proxy cache docker weaviate-post-collection weaviate-post-obj weaviate-get-obj -all: docker proxy cache weaviate-collection weaviate-obj +all: proxy cache docker weaviate-post-collection weaviate-post-obj weaviate-get-obj docker: - docker compose up + docker compose up -d docker-down: docker compose down +# 编译 proxy 插件 proxy: cd ../plugins/wasm-go/extensions/ai-proxy && \ tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' . && \ mv ai-proxy.wasm ../../../../docker-compose-test/ - +# 编译 cache 插件 cache: cd ../plugins/wasm-go/extensions/ai-cache && \ tinygo build -o ai-cache.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' . && \ mv ai-cache.wasm ../../../../docker-compose-test/ -weaviate-obj: - curl --request POST --url http://localhost:8081/v1/objects -H "$(HEADER)" --data '{ "class": "", "vectorWeights": {}, "properties": {}, "id": "", "creationTimeUnix": 1, "lastUpdateTimeUnix": 1, "vector": [ 1 ], "vectors": { "ANY_ADDITIONAL_PROPERTY": [1] }, "tenant": "", "additional": { "ANY_ADDITIONAL_PROPERTY": {} }}' - -weaviate-collection: - curl -X POST http://localhost:8081/v1/schema -H "$(HEADER)" -d '{ - "class": "", - "vectorConfig": { - "": { - "vectorizer": "none", - "vectorIndexType": "hnsw", - "vectorIndexConfig": {} - } - }, - "vectorIndexType": "hnsw", - "vectorIndexConfig": {}, - "shardingConfig": { - "desiredCount": 1, - "virtualPerPhysical": 128 - }, - "replicationConfig": { - "factor": 1, - "asyncEnabled": true - }, - "invertedIndexConfig": { - "cleanupIntervalSeconds": 60, - "bm25": { - "k1": 1.2, - "b": 0.75 - }, - "stopwords": { - "preset": "en", - "additions": [], - "removals": [] - }, - "indexTimestamps": false, - "indexNullState": false, - "indexPropertyLength": false - }, - "multiTenancyConfig": { - "enabled": false, - "autoTenantCreation": true, - "autoTenantActivation": true - }, - "vectorizer": "none", - "moduleConfig": { - "": { - "vectorizeClassName": true - } - }, - "description": "", - "properties": [ - { - "dataType": [ - "string" - ], - "description": "", - "moduleConfig": { - "": { - "skip": false, - "vectorizePropertyName": true - } - }, - "name": "", - "indexInverted": true, - "indexFilterable": true, - "indexSearchable": true, - "indexRangeFilters": true, - "tokenization": "word", - "nestedProperties": [ - { - "name": "", - "dataType": [ - null - ], - "description": "", - "indexFilterable": true, - "indexSearchable": true, - "indexRangeFilters": true, - "tokenization": "word", - "nestedProperties": [ - null - ] - } - ] - } - ] - }' \ No newline at end of file +# 创建 object +weaviate-post-obj: + curl --request POST --url http://localhost:$(WEAVIATE_PORT)/v1/objects -H "$(HEADER)" --data '{"class": "Higress", "vector": [0.1, 0.2, 0.3], "properties": {"question": "这里是问题3"}}' + +# 获取 schema +weaviate-get-schema: + curl -X GET "http://localhost:$(WEAVIATE_PORT)/v1/schema" -H "$(HEADER)" + +# 创建 collection +weaviate-post-collection: + curl -X POST "http://localhost:$(WEAVIATE_PORT)/v1/schema" -H "$(HEADER)" -d '{"class": "Higress"}' + +# 获取 objs +weaviate-get-obj: + curl -X GET "http://localhost:$(WEAVIATE_PORT)/v1/objects" + +# 获取具体 obj +weaviate-get-obj-id: + curl -X GET "http://localhost:$(WEAVIATE_PORT)/v1/objects/Higress/8e7df58e-3415-4264-9bcb-afbb3c51318b" + +# 删除 obj +weaviate-delete-obj: + curl -X DELETE "http://localhost:$(WEAVIATE_PORT)/v1/objects/Higress/8e7df58e-3415-4264-9bcb-afbb3c51318b" + +# 删除 collection,这里 classname 会自动大写 +weaviate-delete-collection: + curl -X DELETE "http://localhost:$(WEAVIATE_PORT)/v1/schema/Higress" + +QUERY = "{ \ + Get { \ + Higress ( \ + limit: 5 \ + nearVector: { \ + vector: [0.1, 0.2, 0.3] \ + } \ + ) { \ + question \ + _additional { \ + distance \ + } \ + } \ + } \ +}" +# 搜索,默认按照 distance 升序 +# https://weaviate.io/developers/weaviate/config-refs/distances +weaviate-search: + curl -X POST "http://localhost:$(WEAVIATE_PORT)/v1/graphql" -H "$(HEADER)" -d '{"query": ${QUERY}}' + + +# redis client +redis-cli: + docker run -it --network docker-compose-test_wasmtest --rm redis redis-cli -h docker-compose-test-redis-1 + +# llm request +llm: + curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "你好"}]}' + +llm1: + curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "今天晚上吃什么"}]}' + +llm2: + curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "今天晚上吃什么?"}]}' + +llm3: + curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "今天晚上吃什么呢?有无推荐?"}]}' \ No newline at end of file diff --git a/docker-compose-test/docker-compose.yml b/docker-compose-test/docker-compose.yml index a8457ddd24..9d76fbf987 100644 --- a/docker-compose-test/docker-compose.yml +++ b/docker-compose-test/docker-compose.yml @@ -9,8 +9,9 @@ services: depends_on: - httpbin - redis - - chroma - - es + - weaviate + # - chroma + # - es networks: - wasmtest ports: @@ -22,12 +23,12 @@ services: - ./ai-cache.wasm:/etc/envoy/ai-cache.wasm - ./ai-proxy.wasm:/etc/envoy/ai-proxy.wasm - chroma: - image: chromadb/chroma - ports: - - "8001:8000" - volumes: - - chroma-data:/chroma/chroma + # chroma: + # image: chromadb/chroma + # ports: + # - "8001:8000" + # volumes: + # - chroma-data:/chroma/chroma redis: image: redis:latest @@ -43,19 +44,19 @@ services: ports: - "12345:80" - es: - image: elasticsearch:8.15.0 - environment: - - "TZ=Asia/Shanghai" - - "discovery.type=single-node" - - "xpack.security.http.ssl.enabled=false" - - "xpack.license.self_generated.type=trial" - - "ELASTIC_PASSWORD=123456" - ports: - - "9200:9200" - - "9300:9300" - networks: - - wasmtest + # es: + # image: elasticsearch:8.15.0 + # environment: + # - "TZ=Asia/Shanghai" + # - "discovery.type=single-node" + # - "xpack.security.http.ssl.enabled=false" + # - "xpack.license.self_generated.type=trial" + # - "ELASTIC_PASSWORD=123456" + # ports: + # - "9200:9200" + # - "9300:9300" + # networks: + # - wasmtest # kibana: # image: docker.elastic.co/kibana/kibana:8.15.0 @@ -94,7 +95,8 @@ services: - '8080' - --scheme - http - image: cr.weaviate.io/semitechnologies/weaviate:1.26.3 + # 高于 1.24.x 的版本,单节点部署有问题 + image: cr.weaviate.io/semitechnologies/weaviate:1.24.1 ports: - 8081:8080 - 50051:50051 @@ -110,7 +112,6 @@ services: DEFAULT_VECTORIZER_MODULE: 'none' ENABLE_API_BASED_MODULES: 'true' CLUSTER_HOSTNAME: 'node1' - TRANSFORMERS_INFERENCE_API: http://t2v-transformers:8080 # Set the inference API endpoint # t2v-transformers: # Set the name of the inference container # image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-multi-qa-MiniLM-L6-cos-v1 diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml index e67a0ce0d1..48fafa99b0 100644 --- a/docker-compose-test/envoy.yaml +++ b/docker-compose-test/envoy.yaml @@ -62,13 +62,10 @@ static_resources: "DashScopeServiceName": "dashscope" }, "vectorProvider": { - "VectorStoreProviderType": "elasticsearch", - "ThresholdRelation": "gte", - "ESThreshold": 0.7, - "ESServiceName": "es", - "ESIndex": "higress", - "ESUsername": "elastic", - "ESPassword": "123456" + "VectorStoreProviderType": "weaviate", + "WeaviateServiceName": "weaviate", + "WeaviateCollection": "Higress", + "WeaviateThreshold": "0.3" }, "cacheKeyFrom": { "requestBody": "" diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index da5392fbb5..655fe1c3fd 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -92,7 +92,7 @@ func (d *ChromaProvider) QueryEmbedding( } func (d *ChromaProvider) UploadEmbedding( - query_emb []float64, + queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, @@ -108,7 +108,7 @@ func (d *ChromaProvider) UploadEmbedding( // ] // } requestBody, err := json.Marshal(ChromaInsertRequest{ - Embeddings: []ChromaEmbedding{query_emb}, + Embeddings: []ChromaEmbedding{queryEmb}, IDs: []string{queryString}, // queryString 指的是用户查询的问题 }) @@ -131,13 +131,8 @@ func (d *ChromaProvider) UploadEmbedding( ) } -// ChromaEmbedding represents the embedding vector for a data point. type ChromaEmbedding []float64 - -// ChromaMetadataMap is a map from key to value for metadata. type ChromaMetadataMap map[string]string - -// Dataset represents the entire dataset containing multiple data points. type ChromaInsertRequest struct { Embeddings []ChromaEmbedding `json:"embeddings"` Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional metadata map array @@ -145,7 +140,6 @@ type ChromaInsertRequest struct { IDs []string `json:"ids"` } -// ChromaQueryRequest represents the query request structure. type ChromaQueryRequest struct { Where map[string]string `json:"where,omitempty"` // Optional where filter WhereDocument map[string]string `json:"where_document,omitempty"` // Optional where_document filter @@ -154,7 +148,6 @@ type ChromaQueryRequest struct { Include []string `json:"include"` } -// ChromaQueryResponse represents the search result structure. type ChromaQueryResponse struct { Ids [][]string `json:"ids"` // 每一个 embedding 相似的 key 可能会有多个,然后会有多个 embedding,所以是一个二维数组 Distances [][]float64 `json:"distances"` // 与 Ids 一一对应 diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 58ded82db8..7e076d8db4 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -173,13 +173,13 @@ type insertRequest struct { Docs []document `json:"docs"` } -func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { +func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, queryString string) (string, []byte, [][2]string, error) { url := "/v1/collections/" + d.config.DashVectorCollection + "/docs" doc := document{ Vector: emb, Fields: map[string]string{ - "query": query_string, + "query": queryString, }, } @@ -196,8 +196,8 @@ func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, query_str return url, requestBody, header, err } -func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { - url, body, headers, _ := d.constructEmbeddingUploadParameters(query_emb, queryString) +func (d *DvProvider) UploadEmbedding(queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { + url, body, headers, _ := d.constructEmbeddingUploadParameters(queryEmb, queryString) d.client.Post( url, headers, diff --git a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go index 4203e38688..c84f259270 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go @@ -105,7 +105,7 @@ func (d *ESProvider) getCredentials() string { } func (d *ESProvider) UploadEmbedding( - query_emb []float64, + queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, @@ -122,7 +122,7 @@ func (d *ESProvider) UploadEmbedding( // ] // } requestBody, err := json.Marshal(esInsertRequest{ - Embedding: query_emb, + Embedding: queryEmb, Question: queryString, }) if err != nil { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index c52a394cb9..aa7c69538e 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -11,6 +11,7 @@ const ( providerTypeDashVector = "dashvector" providerTypeChroma = "chroma" providerTypeES = "elasticsearch" + providerTypeWeaviate = "weaviate" ) type providerInitializer interface { @@ -23,6 +24,7 @@ var ( providerTypeDashVector: &dashVectorProviderInitializer{}, providerTypeChroma: &chromaProviderInitializer{}, providerTypeES: &esProviderInitializer{}, + providerTypeWeaviate: &weaviateProviderInitializer{}, } ) @@ -110,6 +112,22 @@ type ProviderConfig struct { // @Title zh-CN ElasticSearch 密码 // @Description zh-CN ElasticSearch 密码,默认为 elastic ESPassword string `require:"false" yaml:"ESPassword" json:"ESPassword"` + + // @Title zh-CN Weaviate 的上游服务名称 + // @Description zh-CN Weaviate 服务所对应的网关内上游服务名称 + WeaviateServiceName string `require:"true" yaml:"WeaviateServiceName" json:"WeaviateServiceName"` + // @Title zh-CN Weaviate 的 Collection 名称 + // @Description zh-CN Weaviate Collection 的名称(class name),注意这里 weaviate 会自动把首字母进行大写 + WeaviateCollection string `require:"true" yaml:"WeaviateCollection" json:"WeaviateCollection"` + // @Title zh-CN Weaviate 的距离阈值 + // @Description zh-CN Weaviate 距离阈值,默认为 0.5,具体见 https://weaviate.io/developers/weaviate/config-refs/distances + WeaviateThreshold float64 `require:"false" yaml:"WeaviateThreshold" json:"WeaviateThreshold"` + // @Title zh-CN 搜索返回结果数量 + // @Description zh-CN 搜索返回结果数量,默认为 1 + WeaviateNResult int `require:"false" yaml:"WeaviateNResult" json:"WeaviateNResult"` + // @Title zh-CN Chroma 超时设置 + // @Description zh-CN Chroma 超时设置,默认为 10 秒 + WeaviateTimeout uint32 `require:"false" yaml:"WeaviateTimeout" json:"WeaviateTimeout"` } func (c *ProviderConfig) FromJson(json gjson.Result) { @@ -169,6 +187,21 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.ESPassword == "" { c.ESPassword = "elastic" } + // Weaviate + c.WeaviateServiceName = json.Get("WeaviateServiceName").String() + c.WeaviateCollection = json.Get("WeaviateCollection").String() + c.WeaviateThreshold = json.Get("WeaviateThreshold").Float() + if c.WeaviateThreshold == 0 { + c.WeaviateThreshold = 0.5 + } + c.WeaviateNResult = int(json.Get("WeaviateNResult").Int()) + if c.WeaviateNResult == 0 { + c.WeaviateNResult = 1 + } + c.WeaviateTimeout = uint32(json.Get("WeaviateTimeout").Int()) + if c.WeaviateTimeout == 0 { + c.WeaviateTimeout = 10000 + } } func (c *ProviderConfig) Validate() error { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go index 0b361a6598..935525e1b3 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go @@ -1,171 +1,169 @@ package vector -// import ( -// "encoding/json" -// "errors" -// "fmt" -// "net/http" - -// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -// ) - -// const ( -// dashVectorPort = 443 -// ) - -// type dashVectorProviderInitializer struct { -// } - -// func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { -// if len(config.DashVectorKey) == 0 { -// return errors.New("DashVectorKey is required") -// } -// if len(config.DashVectorAuthApiEnd) == 0 { -// return errors.New("DashVectorEnd is required") -// } -// if len(config.DashVectorCollection) == 0 { -// return errors.New("DashVectorCollection is required") -// } -// if len(config.DashVectorServiceName) == 0 { -// return errors.New("DashVectorServiceName is required") -// } -// return nil -// } - -// func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { -// return &DvProvider{ -// config: config, -// client: wrapper.NewClusterClient(wrapper.DnsCluster{ -// ServiceName: config.DashVectorServiceName, -// Port: dashVectorPort, -// Domain: config.DashVectorAuthApiEnd, -// }), -// }, nil -// } - -// type DvProvider struct { -// config ProviderConfig -// client wrapper.HttpClient -// } - -// func (d *DvProvider) GetProviderType() string { -// return providerTypeDashVector -// } - -// type EmbeddingRequest struct { -// Model string `json:"model"` -// Input Input `json:"input"` -// Parameters Params `json:"parameters"` -// } - -// type Params struct { -// TextType string `json:"text_type"` -// } - -// type Input struct { -// Texts []string `json:"texts"` -// } - -// func (d *DvProvider) ConstructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { -// url := fmt.Sprintf("/v1/collections/%s/query", d.config.DashVectorCollection) - -// requestData := QueryRequest{ -// Vector: vector, -// TopK: d.config.DashVectorTopK, -// IncludeVector: false, -// } - -// requestBody, err := json.Marshal(requestData) -// if err != nil { -// return "", nil, nil, err -// } - -// header := [][2]string{ -// {"Content-Type", "application/json"}, -// {"dashvector-auth-token", d.config.DashVectorKey}, -// } - -// return url, requestBody, header, nil -// } - -// func (d *DvProvider) ParseQueryResponse(responseBody []byte) (QueryResponse, error) { -// var queryResp QueryResponse -// err := json.Unmarshal(responseBody, &queryResp) -// if err != nil { -// return QueryResponse{}, err -// } -// return queryResp, nil -// } - -// func (d *DvProvider) QueryEmbedding( -// queryEmb []float64, -// ctx wrapper.HttpContext, -// log wrapper.Log, -// callback func(query_resp QueryResponse, ctx wrapper.HttpContext, log wrapper.Log)) { - -// // 构造请求参数 -// url, body, headers, err := d.ConstructEmbeddingQueryParameters(queryEmb) -// if err != nil { -// log.Infof("Failed to construct embedding query parameters: %v", err) -// } - -// err = d.client.Post(url, headers, body, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("Query embedding response: %d, %s", statusCode, responseBody) -// query_resp, err_query := d.ParseQueryResponse(responseBody) -// if err_query != nil { -// log.Infof("Failed to parse response: %v", err_query) -// } -// callback(query_resp, ctx, log) -// }, -// d.config.DashVectorTimeout) -// if err != nil { -// log.Infof("Failed to query embedding: %v", err) -// } - -// } - -// type Document struct { -// Vector []float64 `json:"vector"` -// Fields map[string]string `json:"fields"` -// } - -// type InsertRequest struct { -// Docs []Document `json:"docs"` -// } - -// func (d *DvProvider) ConstructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { -// url := "/v1/collections/" + d.config.DashVectorCollection + "/docs" - -// doc := Document{ -// Vector: emb, -// Fields: map[string]string{ -// "query": query_string, -// }, -// } - -// requestBody, err := json.Marshal(InsertRequest{Docs: []Document{doc}}) -// if err != nil { -// return "", nil, nil, err -// } - -// header := [][2]string{ -// {"Content-Type", "application/json"}, -// {"dashvector-auth-token", d.config.DashVectorKey}, -// } - -// return url, requestBody, header, err -// } - -// func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { -// url, body, headers, _ := d.ConstructEmbeddingUploadParameters(query_emb, queryString) -// d.client.Post( -// url, -// headers, -// body, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) -// callback(ctx, log) -// }, -// d.config.DashVectorTimeout) -// } +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +type weaviateProviderInitializer struct{} + +const weaviatePort = 8081 + +func (c *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.WeaviateCollection) == 0 { + return errors.New("WeaviateCollection is required") + } + if len(config.WeaviateServiceName) == 0 { + return errors.New("WeaviateServiceName is required") + } + return nil +} + +func (c *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &WeaviateProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.WeaviateServiceName, + Port: weaviatePort, + Domain: config.WeaviateServiceName, + }), + }, nil +} + +type WeaviateProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *WeaviateProvider) GetProviderType() string { + return providerTypeWeaviate +} + +func (d *WeaviateProvider) GetThreshold() float64 { + return d.config.WeaviateThreshold +} + +func (d *WeaviateProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { + // 最少需要填写的参数为 class, vector 和 question + // 下面是一个例子 + // {"query": "{ Get { Higress ( limit: 2 nearVector: { vector: [0.1, 0.2, 0.3] } ) { question _additional { distance } } } }"} + embString, err := json.Marshal(emb) + if err != nil { + log.Errorf("[Weaviate] Failed to marshal query embedding: %v", err) + return + } + // 这里默认按照 distance 进行升序,所以不用再次排序 + graphql := fmt.Sprintf(` + { + Get { + %s ( + limit: %d + nearVector: { + vector: %s + } + ) { + question + _additional { + distance + } + } + } + } + `, d.config.WeaviateCollection, d.config.WeaviateNResult, embString) + + requestBody, err := json.Marshal(WeaviateQueryRequest{ + Query: graphql, + }) + + if err != nil { + log.Errorf("[Weaviate] Failed to marshal query embedding request body: %v", err) + return + } + + d.client.Post( + "/v1/graphql", + [][2]string{ + {"Content-Type", "application/json"}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("Query embedding response: %d, %s", statusCode, responseBody) + callback(responseBody, ctx, log) + }, + d.config.WeaviateTimeout, + ) +} + +func (d *WeaviateProvider) UploadEmbedding( + queryEmb []float64, + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log)) { + // 最少需要填写的参数为 class, vector 和 question + // 下面是一个例子 + // {"class": "Higress", "vector": [0.1, 0.2, 0.3], "properties": {"question": "这里是问题"}} + requestBody, err := json.Marshal(WeaviateInsertRequest{ + Class: d.config.WeaviateCollection, + Vector: queryEmb, + Properties: WeaviateProperties{Question: queryString}, // queryString 指的是用户查询的问题 + }) + + if err != nil { + log.Errorf("[Weaviate] Failed to marshal upload embedding request body: %v", err) + return + } + + d.client.Post( + "/v1/objects", + [][2]string{ + {"Content-Type", "application/json"}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Weaviate] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log) + }, + d.config.WeaviateTimeout, + ) +} + +type WeaviateProperties struct { + Question string `json:"question"` +} + +type WeaviateInsertRequest struct { + Class string `json:"class"` + Vector []float64 `json:"vector"` + Properties WeaviateProperties `json:"properties"` +} + +type WeaviateQueryRequest struct { + Query string `json:"query"` +} + +func (d *WeaviateProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { + if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0._additional.distance", d.config.WeaviateCollection)).Exists() { + log.Errorf("[Weaviate] No distance found in response body: %s", responseBody) + return QueryEmbeddingResult{}, nil + } + + if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.question", d.config.WeaviateCollection)).Exists() { + log.Errorf("[Weaviate] No question found in response body: %s", responseBody) + return QueryEmbeddingResult{}, nil + } + + return QueryEmbeddingResult{ + MostSimilarData: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.question", d.config.WeaviateCollection)).String(), + Score: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0._additional.distance", d.config.WeaviateCollection)).Float(), + }, nil +} From 71cc25b6daeacd9db48394f6164e330db955f90e Mon Sep 17 00:00:00 2001 From: Async Date: Fri, 6 Sep 2024 17:54:35 +0800 Subject: [PATCH 14/71] feat: add pinecone fix: remove key --- docker-compose-test/Makefile | 4 + docker-compose-test/envoy.yaml | 41 +++- .../extensions/ai-cache/vector/pinecone.go | 180 ++++++++++++++++++ .../extensions/ai-cache/vector/provider.go | 44 +++++ .../extensions/ai-cache/vector/weaviate.go | 2 +- 5 files changed, 264 insertions(+), 7 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-cache/vector/pinecone.go diff --git a/docker-compose-test/Makefile b/docker-compose-test/Makefile index b66c628609..da7c64ba43 100644 --- a/docker-compose-test/Makefile +++ b/docker-compose-test/Makefile @@ -77,6 +77,10 @@ weaviate-search: redis-cli: docker run -it --network docker-compose-test_wasmtest --rm redis redis-cli -h docker-compose-test-redis-1 +# redis flushall +redis-flushall: + docker run -it --network docker-compose-test_wasmtest --rm redis redis-cli -h docker-compose-test-redis-1 flushall + # llm request llm: curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "你好"}]}' diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml index 48fafa99b0..90211dd5e1 100644 --- a/docker-compose-test/envoy.yaml +++ b/docker-compose-test/envoy.yaml @@ -58,14 +58,16 @@ static_resources: "embeddingProvider": { "type": "dashscope", "serviceName": "dashscope", - "apiKey": "sk-key", + "apiKey": "sk-", "DashScopeServiceName": "dashscope" }, "vectorProvider": { - "VectorStoreProviderType": "weaviate", - "WeaviateServiceName": "weaviate", - "WeaviateCollection": "Higress", - "WeaviateThreshold": "0.3" + "VectorStoreProviderType": "pinecone", + "PineconeServiceName": "pinecone", + "PineconeApiEndpoint": "higress-2bdfipe.svc.aped-4627-b74a.pinecone.io", + "PineconeThreshold": "0.7", + "ThresholdRelation": "gte", + "PineconeApiKey": "key" }, "cacheKeyFrom": { "requestBody": "" @@ -101,6 +103,13 @@ static_resources: # "ESUsername": "elastic", # "ESPassword": "123456" # }, + + # "vectorProvider": { + # "VectorStoreProviderType": "weaviate", + # "WeaviateServiceName": "weaviate", + # "WeaviateCollection": "Higress", + # "WeaviateThreshold": "0.3" + # }, # llm-proxy - name: llm-proxy typed_config: @@ -276,4 +285,24 @@ static_resources: name: envoy.transport_sockets.tls typed_config: "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext - "sni": "dashscope.aliyuncs.com" \ No newline at end of file + "sni": "dashscope.aliyuncs.com" + # pinecone + - name: outbound|443||pinecone.dns + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|443||pinecone.dns + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: higress-2bdfipe.svc.aped-4627-b74a.pinecone.io + port_value: 443 + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + "sni": "higress-2bdfipe.svc.aped-4627-b74a.pinecone.io" \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go new file mode 100644 index 0000000000..815a3a3637 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go @@ -0,0 +1,180 @@ +package vector + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/google/uuid" + "github.com/tidwall/gjson" +) + +type pineconeProviderInitializer struct{} + +const pineconePort = 443 + +func (c *pineconeProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.PineconeApiEndpoint) == 0 { + return errors.New("PineconeApiEndpoint is required") + } + if len(config.PineconeServiceName) == 0 { + return errors.New("PineconeServiceName is required") + } + if len(config.PineconeApiKey) == 0 { + return errors.New("PineconeApiKey is required") + } + return nil +} + +func (c *pineconeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &pineconeProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.PineconeServiceName, + Port: pineconePort, + Domain: config.PineconeApiEndpoint, + }), + }, nil +} + +type pineconeProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *pineconeProvider) GetProviderType() string { + return providerTypePinecone +} + +func (d *pineconeProvider) GetThreshold() float64 { + return d.config.PineconeThreshold +} + +type pineconeMetadata struct { + Question string `json:"question"` +} + +type pineconeVector struct { + ID string `json:"id"` + Values []float64 `json:"values"` + Properties pineconeMetadata `json:"metadata"` +} + +type pineconeInsertRequest struct { + Vectors []pineconeVector `json:"vectors"` + Namespace string `json:"namespace"` +} + +func (d *pineconeProvider) UploadEmbedding( + queryEmb []float64, + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log)) { + // 最少需要填写的参数为 class, vector 和 question + // 下面是一个例子 + // { + // "vectors": [ + // { + // "id": "A", + // "values": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + // "metadata": {"question": "你好"} + // } + // ] + // } + requestBody, err := json.Marshal(pineconeInsertRequest{ + Vectors: []pineconeVector{ + { + ID: uuid.New().String(), + Values: queryEmb, + Properties: pineconeMetadata{Question: queryString}, + }, + }, + Namespace: d.config.PineconeNamespace, + }) + + if err != nil { + log.Errorf("[Pinecone] Failed to marshal upload embedding request body: %v", err) + return + } + + d.client.Post( + "/vectors/upsert", + [][2]string{ + {"Content-Type", "application/json"}, + {"Api-Key", d.config.PineconeApiKey}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Pinecone] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log) + }, + d.config.PineconeTimeout, + ) +} + +type pineconeQueryRequest struct { + Namespace string `json:"namespace"` + Vector []float64 `json:"vector"` + TopK int `json:"topK"` + IncludeMetadata bool `json:"includeMetadata"` + IncludeValues bool `json:"includeValues"` +} + +func (d *pineconeProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { + // 最少需要填写的参数为 vector + // 下面是一个例子 + // { + // "namespace": "higress", + // "vector": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + // "topK": 1, + // "includeMetadata": false + // } + requestBody, err := json.Marshal(pineconeQueryRequest{ + Namespace: d.config.PineconeNamespace, + Vector: emb, + TopK: d.config.PineconeTopK, + IncludeMetadata: true, + IncludeValues: false, + }) + if err != nil { + log.Errorf("[Pinecone] Failed to marshal query embedding: %v", err) + return + } + + d.client.Post( + "/query", + [][2]string{ + {"Content-Type", "application/json"}, + {"Api-Key", d.config.PineconeApiKey}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("Query embedding response: %d, %s", statusCode, responseBody) + callback(responseBody, ctx, log) + }, + d.config.PineconeTimeout, + ) +} + +func (d *pineconeProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { + if !gjson.GetBytes(responseBody, "matches.0.score").Exists() { + log.Errorf("[Pinecone] No distance found in response body: %s", responseBody) + return QueryEmbeddingResult{}, nil + } + + if !gjson.GetBytes(responseBody, "matches.0.metadata.question").Exists() { + log.Errorf("[Pinecone] No question found in response body: %s", responseBody) + return QueryEmbeddingResult{}, nil + } + + return QueryEmbeddingResult{ + MostSimilarData: gjson.GetBytes(responseBody, "matches.0.metadata.question").String(), + Score: gjson.GetBytes(responseBody, "matches.0.score").Float(), + }, nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index aa7c69538e..8ad4907a17 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -12,6 +12,7 @@ const ( providerTypeChroma = "chroma" providerTypeES = "elasticsearch" providerTypeWeaviate = "weaviate" + providerTypePinecone = "pinecone" ) type providerInitializer interface { @@ -25,6 +26,7 @@ var ( providerTypeChroma: &chromaProviderInitializer{}, providerTypeES: &esProviderInitializer{}, providerTypeWeaviate: &weaviateProviderInitializer{}, + providerTypePinecone: &pineconeProviderInitializer{}, } ) @@ -128,6 +130,28 @@ type ProviderConfig struct { // @Title zh-CN Chroma 超时设置 // @Description zh-CN Chroma 超时设置,默认为 10 秒 WeaviateTimeout uint32 `require:"false" yaml:"WeaviateTimeout" json:"WeaviateTimeout"` + + // @Title zh-CN Pinecone 的 upstream service name + // @Description zh-CN Pinecone 服务所对应的网关内上游服务名称 + PineconeServiceName string `require:"true" yaml:"PineconeServiceName" json:"PineconeServiceName"` + // @Title zh-CN Pinecone 的 api endpoint + // @Description zh-CN Pinecone 的 index host endpoint,例如 https://us-west4-gcp-free.pinecone.io + PineconeApiEndpoint string `require:"true" yaml:"PineconeApiEndpoint" json:"PineconeApiEndpoint"` + // @Title zh-CN Pinecone 的 api key + // @Description zh-CN Pinecone 的 api key + PineconeApiKey string `require:"true" yaml:"PineconeApiKey" json:"PineconeApiKey"` + // @Title zh-CN Pinecone 的 namespace + // @Description zh-CN Pinecone 的 namespace,默认为 higress + PineconeNamespace string `require:"false" yaml:"PineconeNamespace" json:"PineconeNamespace"` + // @Title zh-CN Pinecone 的超时设置 + // @Description zh-CN Pinecone 的超时设置,默认为 10 秒 + PineconeTimeout uint32 `require:"false" yaml:"PineconeTimeout" json:"PineconeTimeout"` + // @Title zh-CN Pinecone 的 TopK + // @Description zh-CN Pinecone 的 TopK,默认为 1 + PineconeTopK int `require:"false" yaml:"PineconeTopK" json:"PineconeTopK"` + // @Title zh-CN Pinecone 的距离阈值 + // @Description zh-CN Pinecone 的距离阈值,默认为 0.5,具体见 https://docs.pinecone.io/guides/indexes/understanding-indexes#distance-metrics + PineconeThreshold float64 `require:"false" yaml:"PineconeThreshold" json:"PineconeThreshold"` } func (c *ProviderConfig) FromJson(json gjson.Result) { @@ -202,6 +226,26 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.WeaviateTimeout == 0 { c.WeaviateTimeout = 10000 } + // Pinecone + c.PineconeServiceName = json.Get("PineconeServiceName").String() + c.PineconeApiEndpoint = json.Get("PineconeApiEndpoint").String() + c.PineconeApiKey = json.Get("PineconeApiKey").String() + c.PineconeNamespace = json.Get("PineconeNamespace").String() + if c.PineconeNamespace == "" { + c.PineconeNamespace = "higress" + } + c.PineconeTimeout = uint32(json.Get("PineconeTimeout").Int()) + if c.PineconeTimeout == 0 { + c.PineconeTimeout = 10000 + } + c.PineconeTopK = int(json.Get("PineconeTopK").Int()) + if c.PineconeTopK == 0 { + c.PineconeTopK = 1 + } + c.PineconeThreshold = json.Get("PineconeThreshold").Float() + if c.PineconeThreshold == 0 { + c.PineconeThreshold = 0.5 + } } func (c *ProviderConfig) Validate() error { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go index 935525e1b3..6a01776349 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go @@ -53,7 +53,7 @@ func (d *WeaviateProvider) QueryEmbedding( ctx wrapper.HttpContext, log wrapper.Log, callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { - // 最少需要填写的参数为 class, vector 和 question + // 最少需要填写的参数为 class, vector // 下面是一个例子 // {"query": "{ Get { Higress ( limit: 2 nearVector: { vector: [0.1, 0.2, 0.3] } ) { question _additional { distance } } } }"} embString, err := json.Marshal(emb) From 5179392b3c03169a321e7e2005d6b52a17919894 Mon Sep 17 00:00:00 2001 From: suchun Date: Fri, 6 Sep 2024 10:40:58 +0000 Subject: [PATCH 15/71] fix bugs, README.md to be updated --- plugins/wasm-go/extensions/ai-cache/README.md | 4 ++++ .../extensions/ai-cache/cache/provider.go | 2 +- .../wasm-go/extensions/ai-cache/cache/redis.go | 6 +++--- .../extensions/ai-cache/config/config.go | 18 +++++++++++++----- plugins/wasm-go/extensions/ai-cache/core.go | 2 -- plugins/wasm-go/extensions/ai-cache/main.go | 2 +- 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 7f4b8f6571..a72177661d 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -10,6 +10,10 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 ## 简介 +本插件的逻辑是 1. 通过`文本向量化接口`将请求内容向量化,结果作为key,原请求作为value,存入`向量数据库`。2. 同时,将请求内容作为key,LLM响应作为value,存入`缓存数据库`。3. 当有新请求时,通过向量化结果查询最相似的已有请求,若相似度高于设定阈值,则直接返回缓存的响应,否则将新请求和响应存入数据库,以提升处理效率。 +> TODO: 是否需要将`文本向量化接口`和`缓存数据库`作为可选项?因为部分向量数据库内置了向量化接口,其次直接使用向量数据库存储响应出错几率可能并不大,且配置项更少。 + + ## 配置说明 | Name | Type | Requirement | Default | Description | diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index e8a5f4ebd1..f1d4ebf744 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -74,7 +74,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } c.cacheTTL = uint32(json.Get("cacheTTL").Int()) if !json.Get("cacheTTL").Exists() { - c.cacheTTL = 0 + c.cacheTTL = 3600000 } c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String() if !json.Get("cacheKeyPrefix").Exists() { diff --git a/plugins/wasm-go/extensions/ai-cache/cache/redis.go b/plugins/wasm-go/extensions/ai-cache/cache/redis.go index 082146c0d7..47fab4bfbb 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/redis.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/redis.go @@ -6,7 +6,7 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) -const DEFAULT_CACHE_PREFIX = "higressAiCache" +const DEFAULT_CACHE_PREFIX = "higressAiCache:" type redisProviderInitializer struct { } @@ -44,11 +44,11 @@ func (rp *redisProvider) Init(username string, password string, timeout uint32) } func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error { - return rp.client.Get(key, cb) + return rp.client.Get(DEFAULT_CACHE_PREFIX+key, cb) } func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error { - return rp.client.SetEx(key, value, int(rp.config.cacheTTL), cb) + return rp.client.SetEx(DEFAULT_CACHE_PREFIX+key, value, int(rp.config.cacheTTL), cb) } func (rp *redisProvider) GetCacheKeyPrefix() string { diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 15824268fc..b1c268a1e1 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -23,6 +23,14 @@ func (e *KVExtractor) SetRequestBodyFromJson(json gjson.Result, key string, defa } } +func (e *KVExtractor) SetResponseBodyFromJson(json gjson.Result, key string, defaultValue string) { + if json.Get(key).Exists() { + e.ResponseBody = json.Get(key).String() + } else { + e.ResponseBody = defaultValue + } +} + type PluginConfig struct { // @Title zh-CN 返回 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 @@ -31,7 +39,7 @@ type PluginConfig struct { // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 StreamResponseTemplate string `required:"true" yaml:"streamResponseTemplate" json:"streamResponseTemplate"` - redisProvider cache.Provider `yaml:"-"` + cacheProvider cache.Provider `yaml:"-"` embeddingProvider embedding.Provider `yaml:"-"` vectorProvider vector.Provider `yaml:"-"` @@ -50,8 +58,8 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.cacheProviderConfig.FromJson(json.Get("cache")) c.CacheKeyFrom.SetRequestBodyFromJson(json, "cacheKeyFrom.requestBody", "messages.@reverse.0.content") - c.CacheValueFrom.SetRequestBodyFromJson(json, "cacheValueFrom.requestBody", "choices.0.message.content") - c.CacheStreamValueFrom.SetRequestBodyFromJson(json, "cacheStreamValueFrom.requestBody", "choices.0.delta.content") + c.CacheValueFrom.SetResponseBodyFromJson(json, "cacheValueFrom.responseBody", "choices.0.message.content") + c.CacheStreamValueFrom.SetResponseBodyFromJson(json, "cacheStreamValueFrom.responseBody", "choices.0.delta.content") c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() if c.StreamResponseTemplate == "" { @@ -86,7 +94,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { if err != nil { return err } - c.redisProvider, err = cache.CreateProvider(c.cacheProviderConfig) + c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig) if err != nil { return err } @@ -102,5 +110,5 @@ func (c *PluginConfig) GetVectorProvider() vector.Provider { } func (c *PluginConfig) GetCacheProvider() cache.Provider { - return c.redisProvider + return c.cacheProvider } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 23ecd67902..6d3233880f 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -12,8 +12,6 @@ import ( func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) { activeCacheProvider := config.GetCacheProvider() - key = activeCacheProvider.GetCacheKeyPrefix() + ":" + key - log.Debugf("activeCacheProvider:%v", activeCacheProvider) activeCacheProvider.Get(key, func(response resp.Value) { if err := response.Error(); err == nil && !response.IsNull() { log.Debugf("cache hit, key:%s", key) diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 2fc63d8381..6d8dc6950e 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -208,6 +208,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } log.Infof("[onHttpResponseBody] Setting cache to redis, key:%s, value:%s", key, value) activeCacheProvider := config.GetCacheProvider() - config.GetCacheProvider().Set(activeCacheProvider.GetCacheKeyPrefix()+":"+key, value, nil) + activeCacheProvider.Set(key, value, nil) return chunk } From ece7e2fc6cbae41b736a8edade1dd6d71f1aba79 Mon Sep 17 00:00:00 2001 From: suchun Date: Fri, 6 Sep 2024 11:05:34 +0000 Subject: [PATCH 16/71] fix bugs, refine variable name, update README.md --- plugins/wasm-go/extensions/ai-cache/README.md | 93 ++++++++++++++----- .../extensions/ai-cache/cache/provider.go | 13 +-- .../extensions/ai-cache/cache/redis.go | 5 +- .../extensions/ai-cache/vector/provider.go | 30 ++++-- 4 files changed, 101 insertions(+), 40 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index a72177661d..d9c361684a 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -1,46 +1,94 @@ ## 简介 - **Note** > 需要数据面的proxy wasm版本大于等于0.2.100 +> > 编译时,需要带上版本的tag,例如:`tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./` +> LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的结果缓存,同时支持流式和非流式响应的缓存。 ## 简介 +本插件的逻辑是 1. 通过`文本向量化接口`将请求内容向量化,结果作为 key,原请求作为 value,存入`向量数据库`。2. 同时,将请求内容作为key,LLM响应作为value,存入`缓存数据库`。3. 当有新请求时,通过向量化结果查询最相似的已有请求,若相似度高于设定阈值,则直接返回缓存的响应,否则将新请求和响应存入数据库,以提升处理效率。 -本插件的逻辑是 1. 通过`文本向量化接口`将请求内容向量化,结果作为key,原请求作为value,存入`向量数据库`。2. 同时,将请求内容作为key,LLM响应作为value,存入`缓存数据库`。3. 当有新请求时,通过向量化结果查询最相似的已有请求,若相似度高于设定阈值,则直接返回缓存的响应,否则将新请求和响应存入数据库,以提升处理效率。 > TODO: 是否需要将`文本向量化接口`和`缓存数据库`作为可选项?因为部分向量数据库内置了向量化接口,其次直接使用向量数据库存储响应出错几率可能并不大,且配置项更少。 - +> ## 配置说明 +配置分为 3 个部分:向量数据库(vector);文本向量化接口(embedding);缓存数据库(cache),同时也提供了细粒度的 LLM 请求/响应提取参数配置等。 + +## 向量数据库服务(vector) +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| vector.type | string | required | "" | 向量存储服务提供者类型,例如 DashVector | +| vector.serviceName | string | required | "" | 向量存储服务名称 | +| vector.serviceDomain | string | required | "" | 向量存储服务域名 | +| vector.servicePort | int64 | optional | 443 | 向量存储服务端口 | +| vector.apiKey | string | optional | "" | 向量存储服务 API Key | +| vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | +| vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| vector.collectionID | string | optional | "" | DashVector 向量存储服务 Collection ID | + + +## 文本向量化服务(embedding) +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| embedding.type | string | required | "" | 请求文本向量化服务类型,例如 DashScope | +| embedding.serviceName | string | required | "" | 请求文本向量化服务名称 | +| embedding.serviceDomain | string | required | "" | 请求文本向量化服务域名 | +| embedding.servicePort | int64 | optional | 443 | 请求文本向量化服务端口 | +| embedding.apiKey | string | optional | "" | 请求文本向量化服务的 API Key | +| embedding.timeout | uint32 | optional | 10000 | 请求文本向量化服务的超时时间,单位为毫秒。默认值是10000,即10秒 | + + +## 缓存服务(cache) +| cache.type | string | required | "" | 缓存服务类型,例如 redis | +| --- | --- | --- | --- | --- | +| cache.serviceName | string | required | "" | 缓存服务名称 | +| cache.serviceDomain | string | required | "" | 缓存服务域名 | +| cache.servicePort | int64 | optional | 6379 | 缓存服务端口 | +| cache.userName | string | optional | "" | 缓存服务用户名 | +| cache.password | string | optional | "" | 缓存服务密码 | +| cache.timeout | uint32 | optional | 10000 | 缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| cache.cacheTTL | uint32 | optional | 3600000 | 缓存过期时间,单位为秒。默认值是 3600000,即 1 小时 | +| cacheKeyPrefix | string | optional | "higressAiCache:" | 缓存 Key 的前缀,默认值为 "higressAiCache:" | + + +## 其他配置 +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| cacheKeyFrom.requestBody | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheValueFrom.responseBody | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheStreamValueFrom.responseBody | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| responseTemplate | string | optional | `{"id":"from-cache","choices":[%s],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | +| streamResponseTemplate | string | optional | `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | -| Name | Type | Requirement | Default | Description | -| -------- | -------- | -------- | -------- | -------- | -| cacheKeyFrom.requestBody | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheValueFrom.responseBody | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheStreamValueFrom.responseBody | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheKeyPrefix | string | optional | "higress-ai-cache:" | Redis缓存Key的前缀 | -| cacheTTL | integer | optional | 0 | 缓存的过期时间,单位是秒,默认值为0,即永不过期 | -| redis.serviceName | string | requried | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local | -| redis.servicePort | integer | optional | 6379 | redis 服务端口 | -| redis.timeout | integer | optional | 1000 | 请求 redis 的超时时间,单位为毫秒 | -| redis.username | string | optional | - | 登陆 redis 的用户名 | -| redis.password | string | optional | - | 登陆 redis 的密码 | -| returnResponseTemplate | string | optional | `{"id":"from-cache","choices":[%s],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | -| returnStreamResponseTemplate | string | optional | `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | ## 配置示例 - +### 基础配置 ```yaml -redis: - serviceName: my-redis.dns - timeout: 2000 +embedding: + type: dashscope + serviceName: [Your Service Name] + apiKey: [Your Key] + +vector: + type: dashvector + serviceName: [Your Service Name] + collectionID: [Your Collection ID] + serviceDomain: [Your domain] + apiKey: [Your key] + +cache: + type: redis + serviceName: [Your Service Name] + servicePort: 6379 + timeout: 100 + ``` ## 进阶用法 - 当前默认的缓存 key 是基于 GJSON PATH 的表达式:`messages.@reverse.0.content` 提取,含义是把 messages 数组反转后取第一项的 content; GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 的 content 作为 key,可以写成: `messages.@reverse.#(role=="user").content`; @@ -50,3 +98,4 @@ GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 还可以支持管道语法,例如希望取到数第二个 role 为 user 的 content 作为 key,可以写成:`messages.@reverse.#(role=="user")#.content|1`。 更多用法可以参考[官方文档](https://github.com/tidwall/gjson/blob/master/SYNTAX.md),可以使用 [GJSON Playground](https://gjson.dev/) 进行语法测试。 + diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index f1d4ebf744..c16ce6484a 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -8,7 +8,8 @@ import ( ) const ( - PROVIDER_TYPE_REDIS = "redis" + PROVIDER_TYPE_REDIS = "redis" + DEFAULT_CACHE_PREFIX = "higressAiCache:" ) type providerInitializer interface { @@ -42,13 +43,13 @@ type ProviderConfig struct { // @Description zh-CN 缓存服务密码,非必填 password string // @Title zh-CN 请求超时 - // @Description zh-CN 请求缓存服务的超时时间,单位为毫秒。默认值是1000,即1秒 + // @Description zh-CN 请求缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 timeout uint32 // @Title zh-CN 缓存过期时间 - // @Description zh-CN 缓存过期时间,单位为秒。默认值是0,即永不过期 + // @Description zh-CN 缓存过期时间,单位为秒。默认值是3600000,即1小时 cacheTTL uint32 // @Title 缓存 Key 前缀 - // @Description 缓存 Key 的前缀,默认值为 "higressAiCache" + // @Description 缓存 Key 的前缀,默认值为 "higressAiCache:" cacheKeyPrefix string } @@ -70,7 +71,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } c.timeout = uint32(json.Get("timeout").Int()) if !json.Get("timeout").Exists() { - c.timeout = 1000 + c.timeout = 10000 } c.cacheTTL = uint32(json.Get("cacheTTL").Int()) if !json.Get("cacheTTL").Exists() { @@ -78,7 +79,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String() if !json.Get("cacheKeyPrefix").Exists() { - c.cacheKeyPrefix = "higressAiCache" + c.cacheKeyPrefix = DEFAULT_CACHE_PREFIX } } diff --git a/plugins/wasm-go/extensions/ai-cache/cache/redis.go b/plugins/wasm-go/extensions/ai-cache/cache/redis.go index 47fab4bfbb..0fc17e73fc 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/redis.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/redis.go @@ -6,7 +6,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) -const DEFAULT_CACHE_PREFIX = "higressAiCache:" type redisProviderInitializer struct { } @@ -44,11 +43,11 @@ func (rp *redisProvider) Init(username string, password string, timeout uint32) } func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error { - return rp.client.Get(DEFAULT_CACHE_PREFIX+key, cb) + return rp.client.Get(rp.GetCacheKeyPrefix()+key, cb) } func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error { - return rp.client.SetEx(DEFAULT_CACHE_PREFIX+key, value, int(rp.config.cacheTTL), cb) + return rp.client.SetEx(rp.GetCacheKeyPrefix()+key, value, int(rp.config.cacheTTL), cb) } func (rp *redisProvider) GetCacheKeyPrefix() string { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 7de9e64404..0545608be6 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -33,13 +33,11 @@ type QueryEmbeddingResult struct { type Provider interface { GetProviderType() string - // TODO: 考虑失败的场景 QueryEmbedding( emb []float64, ctx wrapper.HttpContext, log wrapper.Log, callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log)) - // TODO: 考虑失败的场景 UploadEmbedding( queryEmb []float64, queryString string, @@ -53,14 +51,28 @@ type Provider interface { type ProviderConfig struct { // @Title zh-CN 向量存储服务提供者类型 // @Description zh-CN 向量存储服务提供者类型,例如 DashVector、Milvus - typ string - serviceName string + typ string + // @Title zh-CN 向量存储服务名称 + // @Description zh-CN 向量存储服务名称 + serviceName string + // @Title zh-CN 向量存储服务域名 + // @Description zh-CN 向量存储服务域名 serviceDomain string - servicePort int64 - apiKey string - topK int - timeout uint32 - collectionID string + // @Title zh-CN 向量存储服务端口 + // @Description zh-CN 向量存储服务端口 + servicePort int64 + // @Title zh-CN 向量存储服务 API Key + // @Description zh-CN 向量存储服务 API Key + apiKey string + // @Title zh-CN 返回TopK结果 + // @Description zh-CN 返回TopK结果,默认为 1 + topK int + // @Title zh-CN 请求超时 + // @Description zh-CN 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 + timeout uint32 + // @Title zh-CN DashVector 向量存储服务 Collection ID + // @Description zh-CN DashVector 向量存储服务 Collection ID + collectionID string // // @Title zh-CN Chroma 的上游服务名称 // // @Description zh-CN Chroma 服务所对应的网关内上游服务名称 From 138a526e962c87f0bdc430e9cc4f55a0e2217c6a Mon Sep 17 00:00:00 2001 From: suchun Date: Fri, 6 Sep 2024 11:07:30 +0000 Subject: [PATCH 17/71] delete folder --- docker-compose-test/docker-compose.yml | 93 ---------- docker-compose-test/envoy.yaml | 231 ------------------------- 2 files changed, 324 deletions(-) delete mode 100644 docker-compose-test/docker-compose.yml delete mode 100644 docker-compose-test/envoy.yaml diff --git a/docker-compose-test/docker-compose.yml b/docker-compose-test/docker-compose.yml deleted file mode 100644 index 3b96146349..0000000000 --- a/docker-compose-test/docker-compose.yml +++ /dev/null @@ -1,93 +0,0 @@ -version: '3.7' -services: - envoy: - # image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway:v1.4.0-rc.1 - image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway:1.4.2 - entrypoint: /usr/local/bin/envoy - # 注意这里对wasm开启了debug级别日志,正式部署时则默认info级别 - command: -c /etc/envoy/envoy.yaml --component-log-level wasm:debug - depends_on: - - httpbin - - redis - - chroma - networks: - - wasmtest - ports: - - "10000:10000" - - "9901:9901" - volumes: - - ./envoy.yaml:/etc/envoy/envoy.yaml - # 注意默认没有这两个 wasm 的时候,docker 会创建文件夹,这样会出错,需要有 wasm 文件之后 down 然后重新 up - - ./ai-cache.wasm:/etc/envoy/ai-cache.wasm - - ./ai-proxy.wasm:/etc/envoy/ai-proxy.wasm - - chroma: - image: chromadb/chroma - ports: - - "8001:8000" - volumes: - - chroma-data:/chroma/chroma - - redis: - image: redis:latest - networks: - - wasmtest - ports: - - "6379:6379" - - httpbin: - image: kennethreitz/httpbin:latest - networks: - - wasmtest - ports: - - "12345:80" - - lobechat: - # docker hub 如果访问不了,可以改用这个地址:registry.cn-hangzhou.aliyuncs.com/2456868764/lobe-chat:v1.1.3 - image: lobehub/lobe-chat - environment: - - CODE=admin - - OPENAI_API_KEY=unused - - OPENAI_PROXY_URL=http://envoy:10000/v1 - networks: - - wasmtest - ports: - - "3210:3210/tcp" - - # weaviate: - # command: - # - --host - # - 0.0.0.0 - # - --port - # - '8080' - # - --scheme - # - http - # image: cr.weaviate.io/semitechnologies/weaviate:1.26.1 - # ports: - # - 8081:8080 - # - 50051:50051 - # volumes: - # - weaviate_data:/var/lib/weaviate - # restart: on-failure:0 - # networks: - # - wasmtest - # environment: - # QUERY_DEFAULTS_LIMIT: 25 - # AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' - # PERSISTENCE_DATA_PATH: '/var/lib/weaviate' - # DEFAULT_VECTORIZER_MODULE: 'none' - # ENABLE_API_BASED_MODULES: 'true' - # CLUSTER_HOSTNAME: 'node1' - # TRANSFORMERS_INFERENCE_API: http://t2v-transformers:8080 # Set the inference API endpoint - - # t2v-transformers: # Set the name of the inference container - # image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-multi-qa-MiniLM-L6-cos-v1 - # environment: - # ENABLE_CUDA: 0 # Set to 1 to enable -volumes: - weaviate_data: {} - chroma-data: - driver: local - -networks: - wasmtest: {} \ No newline at end of file diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml deleted file mode 100644 index b7d81a78ae..0000000000 --- a/docker-compose-test/envoy.yaml +++ /dev/null @@ -1,231 +0,0 @@ -admin: - address: - socket_address: - protocol: TCP - address: 0.0.0.0 - port_value: 9901 -static_resources: - listeners: - - name: listener_0 - address: - socket_address: - protocol: TCP - address: 0.0.0.0 - port_value: 10000 - filter_chains: - - filters: - # httpbin - - name: envoy.filters.network.http_connection_manager - typed_config: - "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager - scheme_header_transformation: - scheme_to_overwrite: https - stat_prefix: ingress_http - route_config: - name: local_route - virtual_hosts: - - name: local_service - domains: ["*"] - routes: - # - match: - # prefix: "/" - # route: - # cluster: httpbin - - match: - prefix: "/" - route: - cluster: llm - timeout: 300s - - http_filters: - # ai-cache - - name: ai-cache - typed_config: - "@type": type.googleapis.com/udpa.type.v1.TypedStruct - type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm - value: - config: - name: ai-cache - vm_config: - runtime: envoy.wasm.runtime.v8 - code: - local: - filename: /etc/envoy/ai-cache.wasm - configuration: - "@type": "type.googleapis.com/google.protobuf.StringValue" - value: | - { - "embeddingProvider": { - "type": "dashscope", - "serviceName": "dashscope", - "apiKey": "sk-your-key", - "DashScopeServiceName": "dashscope" - }, - "vectorProvider": { - "vectorStoreProviderType": "chroma", - "ChromaServiceName": "chroma", - "ChromaCollectionID": "0294deb1-8ef5-4582-b21c-75f23093db2c" - }, - "cacheKeyFrom": { - "requestBody": "" - }, - "cacheValueFrom": { - "responseBody": "" - }, - "cacheStreamValueFrom": { - "responseBody": "" - }, - "returnResponseTemplate": "", - "returnTestResponseTemplate": "", - "ReturnStreamResponseTemplate": "", - "redis": { - "serviceName": "redis_cluster", - "timeout": 2000 - } - } - # 上面的配置中 redis 的配置名字是 redis,而不是 golang tag 中的 redisConfig - - # llm-proxy - - name: llm-proxy - typed_config: - "@type": type.googleapis.com/udpa.type.v1.TypedStruct - type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm - value: - config: - name: llm - vm_config: - runtime: envoy.wasm.runtime.v8 - code: - local: - filename: /etc/envoy/ai-proxy.wasm - configuration: - "@type": "type.googleapis.com/google.protobuf.StringValue" - value: | # 插件配置 - { - "provider": { - "type": "openai", - "apiTokens": [ - "YOUR_API_TOKEN" - ], - "openaiCustomUrl": "172.17.0.1:8000/v1/chat/completions" - } - } - - - - name: envoy.filters.http.router - - clusters: - - name: httpbin - connect_timeout: 30s - type: LOGICAL_DNS - # Comment out the following line to test on v6 networks - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: httpbin - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: httpbin - port_value: 80 - # - name: redis_cluster - # connect_timeout: 30s - # type: STRICT_DNS - # lb_policy: ROUND_ROBIN - # load_assignment: - # cluster_name: redis - # endpoints: - # - lb_endpoints: - # - endpoint: - # address: - # socket_address: - # address: 172.17.0.1 - # port_value: 6379 - - name: outbound|6379||redis_cluster - connect_timeout: 1s - type: strict_dns - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|6379||redis_cluster - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: 172.17.0.1 - port_value: 6379 - typed_extension_protocol_options: - envoy.filters.network.redis_proxy: - "@type": type.googleapis.com/envoy.extensions.filters.network.redis_proxy.v3.RedisProtocolOptions - # chroma - - name: outbound|8001||chroma.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|8001||chroma.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 - port_value: 8001 - # llm - - name: llm - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: llm - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 - port_value: 8000 - # dashvector - - name: outbound|443||dashvector.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|443||dashvector.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: vrs-cn-0dw3vnaqs0002z.dashvector.cn-hangzhou.aliyuncs.com - port_value: 443 - transport_socket: - name: envoy.transport_sockets.tls - typed_config: - "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext - "sni": "vrs-cn-0dw3vnaqs0002z.dashvector.cn-hangzhou.aliyuncs.com" - # dashscope - - name: outbound|443||dashscope.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|443||dashscope.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: dashscope.aliyuncs.com - port_value: 443 - transport_socket: - name: envoy.transport_sockets.tls - typed_config: - "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext - "sni": "dashscope.aliyuncs.com" \ No newline at end of file From bfaed4c5b10a3e7827d4549c3d58dc2c8166ccfd Mon Sep 17 00:00:00 2001 From: Async Date: Fri, 6 Sep 2024 19:24:30 +0800 Subject: [PATCH 18/71] fix: format --- plugins/wasm-go/extensions/request-block/Dockerfile | 2 -- plugins/wasm-go/extensions/request-block/Makefile | 4 ---- plugins/wasm-go/extensions/request-block/main.go | 2 -- 3 files changed, 8 deletions(-) delete mode 100644 plugins/wasm-go/extensions/request-block/Dockerfile delete mode 100644 plugins/wasm-go/extensions/request-block/Makefile diff --git a/plugins/wasm-go/extensions/request-block/Dockerfile b/plugins/wasm-go/extensions/request-block/Dockerfile deleted file mode 100644 index 9b084e0596..0000000000 --- a/plugins/wasm-go/extensions/request-block/Dockerfile +++ /dev/null @@ -1,2 +0,0 @@ -FROM scratch -COPY main.wasm plugin.wasm \ No newline at end of file diff --git a/plugins/wasm-go/extensions/request-block/Makefile b/plugins/wasm-go/extensions/request-block/Makefile deleted file mode 100644 index 1210d6ec34..0000000000 --- a/plugins/wasm-go/extensions/request-block/Makefile +++ /dev/null @@ -1,4 +0,0 @@ -.DEFAULT: -build: - tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer' ./main.go - mv main.wasm ../../../../docker-compose-test/ \ No newline at end of file diff --git a/plugins/wasm-go/extensions/request-block/main.go b/plugins/wasm-go/extensions/request-block/main.go index 2a43b4df72..224d4b26d6 100644 --- a/plugins/wasm-go/extensions/request-block/main.go +++ b/plugins/wasm-go/extensions/request-block/main.go @@ -177,9 +177,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config RequestBlockConfig, lo } func onHttpRequestBody(ctx wrapper.HttpContext, config RequestBlockConfig, body []byte, log wrapper.Log) types.Action { - log.Infof("My request-block body: %s\n", string(body)) bodyStr := string(body) - if !config.caseSensitive { bodyStr = strings.ToLower(bodyStr) } From e8ad550f3af57c03b8dba2ffd6b4a23ab6f37aa5 Mon Sep 17 00:00:00 2001 From: Yang Beining <35399433+Suchun-sv@users.noreply.github.com> Date: Fri, 6 Sep 2024 12:59:06 +0100 Subject: [PATCH 19/71] fix typos --- plugins/wasm-go/extensions/ai-cache/cache/provider.go | 3 --- plugins/wasm-go/extensions/ai-cache/core.go | 2 +- plugins/wasm-go/extensions/ai-cache/main.go | 1 - plugins/wasm-go/extensions/ai-cache/vector/dashvector.go | 4 ++-- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index c16ce6484a..e22427c6a5 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -90,9 +90,6 @@ func (c *ProviderConfig) Validate() error { if c.serviceName == "" { return errors.New("cache service name is required") } - if c.timeout <= 0 { - return errors.New("cache service timeout must be greater than 0") - } initializer, has := providerInitializers[c.typ] if !has { return errors.New("unknown cache service provider type: " + c.typ) diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 6d3233880f..129969704d 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -62,7 +62,7 @@ func QueryVectorDB(key string, textEmbedding []float64, ctx wrapper.HttpContext, log.Debugf("activeVectorProvider: %+v", activeVectorProvider) activeVectorProvider.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log) { - // The baisc logic is to compare the similarity of the embedding with the most similar key in the database + // The basic logic is to compare the similarity of the embedding with the most similar key in the database if len(results) == 0 { log.Warnf("Failed to query vector database, no similar key found") activeVectorProvider.UploadEmbedding(textEmbedding, key, ctx, log, diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 6d8dc6950e..fdd0fc810c 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -19,7 +19,6 @@ const ( PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" TOOL_CALLS_CONTEXT_KEY = "toolCalls" STREAM_CONTEXT_KEY = "stream" - QUERY_EMBEDDING_KEY = "queryEmbedding" ) func main() { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index b0801eb82d..52fb26a391 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -187,13 +187,13 @@ type insertRequest struct { Docs []document `json:"docs"` } -func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { +func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, queryString string) (string, []byte, [][2]string, error) { url := "/v1/collections/" + d.config.collectionID + "/docs" doc := document{ Vector: emb, Fields: map[string]string{ - "query": query_string, + "query": queryString, }, } From a40f5e9067e625f106e4b97f2e36fd23d4c42f9e Mon Sep 17 00:00:00 2001 From: Async Date: Fri, 6 Sep 2024 20:02:11 +0800 Subject: [PATCH 20/71] update --- plugins/wasm-go/extensions/ai-cache/vector/pinecone.go | 2 +- plugins/wasm-go/extensions/transformer/go.mod | 2 +- plugins/wasm-go/extensions/transformer/go.sum | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go index 815a3a3637..76e33c4b61 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go @@ -72,7 +72,7 @@ func (d *pineconeProvider) UploadEmbedding( ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { - // 最少需要填写的参数为 class, vector 和 question + // 最少需要填写的参数为 vector 和 question // 下面是一个例子 // { // "vectors": [ diff --git a/plugins/wasm-go/extensions/transformer/go.mod b/plugins/wasm-go/extensions/transformer/go.mod index e70583a937..464974140a 100644 --- a/plugins/wasm-go/extensions/transformer/go.mod +++ b/plugins/wasm-go/extensions/transformer/go.mod @@ -9,7 +9,7 @@ require ( github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.17.0 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/pretty v1.2.1 github.com/tidwall/sjson v1.2.5 github.com/wasilibs/go-re2 v1.6.0 diff --git a/plugins/wasm-go/extensions/transformer/go.sum b/plugins/wasm-go/extensions/transformer/go.sum index 76246bba99..897140b6e4 100644 --- a/plugins/wasm-go/extensions/transformer/go.sum +++ b/plugins/wasm-go/extensions/transformer/go.sum @@ -19,6 +19,7 @@ github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXA github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= From c83f5c4f5a6df837b42a4d2c5119a7675b474020 Mon Sep 17 00:00:00 2001 From: suchun <2594405419@qq.com> Date: Fri, 6 Sep 2024 13:15:53 +0100 Subject: [PATCH 21/71] fix typos --- plugins/wasm-go/extensions/ai-cache/core.go | 3 ++- plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go | 2 +- plugins/wasm-go/extensions/ai-cache/main.go | 6 +++--- plugins/wasm-go/extensions/ai-cache/vector/dashvector.go | 5 ++--- plugins/wasm-go/extensions/ai-cache/vector/provider.go | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 129969704d..63da11d444 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -49,7 +49,8 @@ func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.Htt func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { activeEmbeddingProvider := config.GetEmbeddingProvider() - activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, + // The error will be handled within the callback function + _ = activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, func(emb []float64) { log.Debugf("Successfully fetched embeddings for key: %s", key) QueryVectorDB(key, emb, ctx, config, log, stream) diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index 0aaa6e2e68..fd880f2d04 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -149,7 +149,7 @@ func (d *DSProvider) GetEmbedding( } var resp *Response - d.client.Post(embUrl, embHeaders, embRequestBody, + err = d.client.Post(embUrl, embHeaders, embRequestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode != http.StatusOK { log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index fdd0fc810c..d18b10c61c 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -76,7 +76,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.PluginConfig, l ctx.DontReadRequestBody() return types.ActionContinue } - proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") // The request has a body and requires delaying the header transmission until a cache miss occurs, // at which point the header should be sent. return types.HeaderStopIteration @@ -98,7 +98,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body ctx.SetContext(CACHE_KEY_CONTEXT_KEY, key) log.Debugf("[onHttpRequestBody] key:%s", key) if key == "" { - log.Debug("[onHttpRquestBody] parse key from request body failed") + log.Debug("[onHttpRequestBody] parse key from request body failed") return types.ActionContinue } @@ -207,6 +207,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } log.Infof("[onHttpResponseBody] Setting cache to redis, key:%s, value:%s", key, value) activeCacheProvider := config.GetCacheProvider() - activeCacheProvider.Set(key, value, nil) + _ = activeCacheProvider.Set(key, value, nil) return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 52fb26a391..2b4db49d83 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -134,7 +134,7 @@ func (d *DvProvider) QueryEmbedding( log.Infof("Failed to construct embedding query parameters: %v", err) } - d.client.Post(url, headers, body, + err = d.client.Post(url, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { if statusCode != http.StatusOK { log.Infof("Failed to query embedding: %d", statusCode) @@ -142,7 +142,6 @@ func (d *DvProvider) QueryEmbedding( } log.Debugf("Query embedding response: %d, %s", statusCode, responseBody) results, err := d.ParseQueryResponse(responseBody, ctx, log) - // TODO: 如果解析失败,应该如何处理? if err != nil { log.Infof("Failed to parse query response: %v", err) return @@ -212,7 +211,7 @@ func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, queryStri func (d *DvProvider) UploadEmbedding(queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { url, body, headers, _ := d.constructEmbeddingUploadParameters(queryEmb, queryString) - d.client.Post( + _ = d.client.Post( url, headers, body, diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 0545608be6..e7b6c98e5f 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -24,7 +24,7 @@ var ( } ) -// 定义通用的查询结果的结构体 +// QueryEmbeddingResult 定义通用的查询结果的结构体 type QueryEmbeddingResult struct { Text string // 相似的文本 Embedding []float64 // 相似文本的向量 From f3d3292c09c8651dc31a7ffb3954e26f9c81bb9b Mon Sep 17 00:00:00 2001 From: suchun <2594405419@qq.com> Date: Fri, 6 Sep 2024 13:20:26 +0100 Subject: [PATCH 22/71] change append to appendMsg --- plugins/wasm-go/extensions/ai-cache/util.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index a1f613ce69..34ecaf1b1b 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -34,8 +34,8 @@ func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseM ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) return content } - append := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) - content := tempContentI.(string) + append + appendMsg := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) + content := tempContentI.(string) + appendMsg ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) return content } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { From b0cf29dff2d2543b016dee9da42900f8d0b059dd Mon Sep 17 00:00:00 2001 From: suchun Date: Wed, 11 Sep 2024 00:52:51 +0000 Subject: [PATCH 23/71] fix bugs and refine code --- plugins/wasm-go/extensions/ai-cache/README.md | 5 +- .../extensions/ai-cache/cache/provider.go | 18 +-- .../extensions/ai-cache/cache/redis.go | 18 +-- .../extensions/ai-cache/config/config.go | 36 ++--- plugins/wasm-go/extensions/ai-cache/core.go | 129 +++++++++++------- .../ai-cache/embedding/dashscope.go | 47 ++++--- .../extensions/ai-cache/embedding/provider.go | 6 +- plugins/wasm-go/extensions/ai-cache/main.go | 20 ++- plugins/wasm-go/extensions/ai-cache/util.go | 40 ++++-- .../extensions/ai-cache/vector/dashvector.go | 39 +++--- .../extensions/ai-cache/vector/provider.go | 7 +- 11 files changed, 206 insertions(+), 159 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index d9c361684a..6b7fc7c493 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -40,6 +40,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | embedding.servicePort | int64 | optional | 443 | 请求文本向量化服务端口 | | embedding.apiKey | string | optional | "" | 请求文本向量化服务的 API Key | | embedding.timeout | uint32 | optional | 10000 | 请求文本向量化服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| embedding.model | string | optional | "" | 请求文本向量化服务的模型名称 | ## 缓存服务(cache) @@ -48,10 +49,10 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | cache.serviceName | string | required | "" | 缓存服务名称 | | cache.serviceDomain | string | required | "" | 缓存服务域名 | | cache.servicePort | int64 | optional | 6379 | 缓存服务端口 | -| cache.userName | string | optional | "" | 缓存服务用户名 | +| cache.username | string | optional | "" | 缓存服务用户名 | | cache.password | string | optional | "" | 缓存服务密码 | | cache.timeout | uint32 | optional | 10000 | 缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 | -| cache.cacheTTL | uint32 | optional | 3600000 | 缓存过期时间,单位为秒。默认值是 3600000,即 1 小时 | +| cache.cacheTTL | uint32 | optional | 0 | 缓存过期时间,单位为秒。默认值是 0,即 永不过期| | cacheKeyPrefix | string | optional | "higressAiCache:" | 缓存 Key 的前缀,默认值为 "higressAiCache:" | diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index e22427c6a5..53f3d7eb1f 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -38,7 +38,7 @@ type ProviderConfig struct { serviceHost string // @Title zh-CN 缓存服务用户名 // @Description zh-CN 缓存服务用户名,非必填 - userName string + username string // @Title zh-CN 缓存服务密码 // @Description zh-CN 缓存服务密码,非必填 password string @@ -46,7 +46,7 @@ type ProviderConfig struct { // @Description zh-CN 请求缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 timeout uint32 // @Title zh-CN 缓存过期时间 - // @Description zh-CN 缓存过期时间,单位为秒。默认值是3600000,即1小时 + // @Description zh-CN 缓存过期时间,单位为秒。默认值是0,即永不过期 cacheTTL uint32 // @Title 缓存 Key 前缀 // @Description 缓存 Key 的前缀,默认值为 "higressAiCache:" @@ -61,9 +61,9 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.servicePort = 6379 } c.serviceHost = json.Get("serviceHost").String() - c.userName = json.Get("username").String() + c.username = json.Get("username").String() if !json.Get("username").Exists() { - c.userName = "" + c.username = "" } c.password = json.Get("password").String() if !json.Get("password").Exists() { @@ -75,12 +75,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } c.cacheTTL = uint32(json.Get("cacheTTL").Int()) if !json.Get("cacheTTL").Exists() { - c.cacheTTL = 3600000 + c.cacheTTL = 0 } - c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String() - if !json.Get("cacheKeyPrefix").Exists() { + if json.Get("cacheKeyPrefix").Exists() { + c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String() + } else { c.cacheKeyPrefix = DEFAULT_CACHE_PREFIX } + } func (c *ProviderConfig) Validate() error { @@ -113,5 +115,5 @@ type Provider interface { Init(username string, password string, timeout uint32) error Get(key string, cb wrapper.RedisResponseCallback) error Set(key string, value string, cb wrapper.RedisResponseCallback) error - GetCacheKeyPrefix() string + getCacheKeyPrefix() string } diff --git a/plugins/wasm-go/extensions/ai-cache/cache/redis.go b/plugins/wasm-go/extensions/ai-cache/cache/redis.go index 0fc17e73fc..2c3d49047a 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/redis.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/redis.go @@ -6,13 +6,12 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) - type redisProviderInitializer struct { } func (r *redisProviderInitializer) ValidateConfig(cf ProviderConfig) error { if len(cf.serviceName) == 0 { - return errors.New("[redis] cache service name is required") + return errors.New("cache service name is required") } return nil } @@ -25,7 +24,7 @@ func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig) (Provider, Host: cf.serviceHost, Port: int64(cf.servicePort)}), } - err := rp.Init(cf.userName, cf.password, cf.timeout) + err := rp.Init(cf.username, cf.password, cf.timeout) return &rp, err } @@ -35,24 +34,21 @@ type redisProvider struct { } func (rp *redisProvider) GetProviderType() string { - return "redis" + return PROVIDER_TYPE_REDIS } func (rp *redisProvider) Init(username string, password string, timeout uint32) error { - return rp.client.Init(rp.config.userName, rp.config.password, int64(rp.config.timeout)) + return rp.client.Init(rp.config.username, rp.config.password, int64(rp.config.timeout)) } func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error { - return rp.client.Get(rp.GetCacheKeyPrefix()+key, cb) + return rp.client.Get(rp.getCacheKeyPrefix()+key, cb) } func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error { - return rp.client.SetEx(rp.GetCacheKeyPrefix()+key, value, int(rp.config.cacheTTL), cb) + return rp.client.SetEx(rp.getCacheKeyPrefix()+key, value, int(rp.config.cacheTTL), cb) } -func (rp *redisProvider) GetCacheKeyPrefix() string { - if len(rp.config.cacheKeyPrefix) == 0 { - return DEFAULT_CACHE_PREFIX - } +func (rp *redisProvider) getCacheKeyPrefix() string { return rp.config.cacheKeyPrefix } diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index b1c268a1e1..74a4f2e2d7 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -8,26 +8,28 @@ import ( "github.com/tidwall/gjson" ) -type KVExtractor struct { +type BodyPathMapper struct { // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` + RequestPath string `required:"false" yaml:"requestBody" json:"requestBody"` // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` + ResponsePath string `required:"false" yaml:"responseBody" json:"responseBody"` } -func (e *KVExtractor) SetRequestBodyFromJson(json gjson.Result, key string, defaultValue string) { - if json.Get(key).Exists() { - e.RequestBody = json.Get(key).String() +func (e *BodyPathMapper) SetRequestPathFromJson(json gjson.Result, key string, defaultValue string) { + value := json.Get(key) + if value.Exists() { + e.RequestPath = value.String() } else { - e.RequestBody = defaultValue + e.RequestPath = defaultValue } } -func (e *KVExtractor) SetResponseBodyFromJson(json gjson.Result, key string, defaultValue string) { - if json.Get(key).Exists() { - e.ResponseBody = json.Get(key).String() +func (e *BodyPathMapper) SetResponsePathFromJson(json gjson.Result, key string, defaultValue string) { + value := json.Get(key) + if value.Exists() { + e.ResponsePath = value.String() } else { - e.ResponseBody = defaultValue + e.ResponsePath = defaultValue } } @@ -47,9 +49,9 @@ type PluginConfig struct { vectorProviderConfig vector.ProviderConfig cacheProviderConfig cache.ProviderConfig - CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` - CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` - CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` + CacheKeyFrom BodyPathMapper `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` + CacheValueFrom BodyPathMapper `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` + CacheStreamValueFrom BodyPathMapper `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` } func (c *PluginConfig) FromJson(json gjson.Result) { @@ -57,9 +59,9 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.vectorProviderConfig.FromJson(json.Get("vector")) c.cacheProviderConfig.FromJson(json.Get("cache")) - c.CacheKeyFrom.SetRequestBodyFromJson(json, "cacheKeyFrom.requestBody", "messages.@reverse.0.content") - c.CacheValueFrom.SetResponseBodyFromJson(json, "cacheValueFrom.responseBody", "choices.0.message.content") - c.CacheStreamValueFrom.SetResponseBodyFromJson(json, "cacheStreamValueFrom.responseBody", "choices.0.delta.content") + c.CacheKeyFrom.SetRequestPathFromJson(json, "cacheKeyFrom.requestBody", "messages.@reverse.0.content") + c.CacheValueFrom.SetResponsePathFromJson(json, "cacheValueFrom.responseBody", "choices.0.message.content") + c.CacheStreamValueFrom.SetResponsePathFromJson(json, "cacheStreamValueFrom.responseBody", "choices.0.delta.content") c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() if c.StreamResponseTemplate == "" { diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 63da11d444..ca23af4357 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" @@ -10,83 +11,105 @@ import ( "github.com/tidwall/resp" ) -func RedisSearchHandler(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, ifUseEmbedding bool) { +func CheckCacheForKey(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) { activeCacheProvider := config.GetCacheProvider() - activeCacheProvider.Get(key, func(response resp.Value) { + err := activeCacheProvider.Get(key, func(response resp.Value) { if err := response.Error(); err == nil && !response.IsNull() { - log.Debugf("cache hit, key:%s", key) - HandleCacheHit(key, response, stream, ctx, config, log) + log.Infof("cache hit, key: %s", key) + ProcessCacheHit(key, response, stream, ctx, config, log) } else { - log.Debugf("cache miss, key:%s", key) - if ifUseEmbedding { - HandleCacheMiss(key, err, response, ctx, config, log, key, stream) + log.Infof("cache miss, key: %s, error: %s", key, err.Error()) + if useSimilaritySearch { + err = performSimilaritySearch(key, ctx, config, log, key, stream) + if err != nil { + log.Errorf("failed to perform similarity search for key: %s, error: %v", key, err) + proxywasm.ResumeHttpRequest() + } } else { proxywasm.ResumeHttpRequest() return } } }) -} -func HandleCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { - ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) - if !stream { - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, response.String())), -1) - } else { - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.StreamResponseTemplate, response.String())), -1) + if err != nil { + log.Errorf("Failed to retrieve key: %s from cache, error: %v", key, err) + proxywasm.ResumeHttpRequest() } } -func HandleCacheMiss(key string, err error, response resp.Value, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { +func ProcessCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { + escapedResponse, err := json.Marshal(response.String()) if err != nil { - log.Warnf("redis get key:%s failed, err:%v", key, err) + proxywasm.SendHttpResponse(500, [][2]string{{"content-type", "text/plain"}}, []byte("Internal Server Error"), -1) + return } - if response.IsNull() { - log.Warnf("cache miss, key:%s", key) + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) + if !stream { + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1) + } else { + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.StreamResponseTemplate, escapedResponse)), -1) } - FetchAndProcessEmbeddings(key, ctx, config, log, queryString, stream) } -func FetchAndProcessEmbeddings(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) { +func performSimilaritySearch(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) error { activeEmbeddingProvider := config.GetEmbeddingProvider() - // The error will be handled within the callback function - _ = activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, - func(emb []float64) { - log.Debugf("Successfully fetched embeddings for key: %s", key) + err := activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, + func(emb []float64, err error) { + if err != nil { + log.Errorf("failed to fetch embeddings for key: %s, err: %v", key, err) + proxywasm.ResumeHttpRequest() + return + } + log.Debugf("successfully fetched embeddings for key: %s", key) QueryVectorDB(key, emb, ctx, config, log, stream) }) + return err } func QueryVectorDB(key string, textEmbedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { - log.Debugf("QueryVectorDB key: %s", key) + log.Debugf("starting query for key: %s", key) activeVectorProvider := config.GetVectorProvider() - log.Debugf("activeVectorProvider: %+v", activeVectorProvider) - activeVectorProvider.QueryEmbedding(textEmbedding, ctx, log, - func(results []vector.QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log) { - // The basic logic is to compare the similarity of the embedding with the most similar key in the database - if len(results) == 0 { - log.Warnf("Failed to query vector database, no similar key found") - activeVectorProvider.UploadEmbedding(textEmbedding, key, ctx, log, - func(ctx wrapper.HttpContext, log wrapper.Log) { - proxywasm.ResumeHttpRequest() - }) - return - } + log.Debugf("active vector provider configuration: %+v", activeVectorProvider) - mostSimilarData := results[0] - log.Infof("most similar key: %s", mostSimilarData.Text) - if mostSimilarData.Score < activeVectorProvider.GetThreshold() { - log.Infof("accept most similar key: %s, score: %f", mostSimilarData.Text, mostSimilarData.Score) - // ctx.SetContext(embedding.CacheKeyContextKey, nil) - RedisSearchHandler(mostSimilarData.Text, ctx, config, log, stream, false) - } else { - log.Infof("the most similar key's score is too high, key: %s, score: %f", mostSimilarData.Text, mostSimilarData.Score) - activeVectorProvider.UploadEmbedding(textEmbedding, key, ctx, log, - func(ctx wrapper.HttpContext, log wrapper.Log) { - proxywasm.ResumeHttpRequest() - }) - return - } - }, - ) + err := activeVectorProvider.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { + if err != nil { + log.Errorf("error querying vector database: %v", err) + proxywasm.ResumeHttpRequest() + return + } + + if len(results) == 0 { + log.Warnf("no similar keys found in vector database for key: %s", key) + uploadEmbedding(textEmbedding, key, ctx, log, activeVectorProvider) + return + } + + mostSimilarData := results[0] + log.Debugf("most similar key found: %s with score: %f", mostSimilarData.Text, mostSimilarData.Score) + + if mostSimilarData.Score < activeVectorProvider.GetSimThreshold() { + log.Infof("key accepted: %s with score: %f below threshold", mostSimilarData.Text, mostSimilarData.Score) + CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) + } else { + log.Infof("score too high for key: %s with score: %f above threshold", mostSimilarData.Text, mostSimilarData.Score) + uploadEmbedding(textEmbedding, key, ctx, log, activeVectorProvider) + } + }) + + if err != nil { + log.Errorf("error querying vector database: %v", err) + proxywasm.ResumeHttpRequest() + } +} + +func uploadEmbedding(textEmbedding []float64, key string, ctx wrapper.HttpContext, log wrapper.Log, provider vector.Provider) { + provider.UploadEmbedding(textEmbedding, key, ctx, log, func(ctx wrapper.HttpContext, log wrapper.Log, err error) { + if err != nil { + log.Errorf("failed to upload embedding for key: %s, error: %v", key, err) + } else { + log.Debugf("successfully uploaded embedding for key: %s", key) + } + proxywasm.ResumeHttpRequest() + }) } diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index fd880f2d04..e89cf3f6b5 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -4,15 +4,16 @@ import ( "encoding/json" "errors" "net/http" + "strconv" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) const ( - DOMAIN = "dashscope.aliyuncs.com" - PORT = 443 - MODEL_NAME = "text-embedding-v1" - END_POINT = "/api/v1/services/embeddings/text-embedding/text-embedding" + DOMAIN = "dashscope.aliyuncs.com" + PORT = 443 + DEFAULT_MODEL_NAME = "text-embedding-v1" + ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding" ) type dashScopeProviderInitializer struct { @@ -91,8 +92,13 @@ type DSProvider struct { func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { + model := d.config.model + + if model == "" { + model = DEFAULT_MODEL_NAME + } data := EmbeddingRequest{ - Model: MODEL_NAME, + Model: model, Input: Input{ Texts: texts, }, @@ -103,13 +109,13 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin requestBody, err := json.Marshal(data) if err != nil { - log.Errorf("Failed to marshal request data: %v", err) + log.Errorf("failed to marshal request data: %v", err) return "", nil, nil, err } if d.config.apiKey == "" { - err := errors.New("DashScopeKey is empty") - log.Errorf("Failed to construct headers: %v", err) + err := errors.New("dashScopeKey is empty") + log.Errorf("failed to construct headers: %v", err) return "", nil, nil, err } @@ -118,7 +124,7 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin {"Content-Type", "application/json"}, } - return END_POINT, headers, requestBody, err + return ENDPOINT, headers, requestBody, err } type Result struct { @@ -141,41 +147,40 @@ func (d *DSProvider) GetEmbedding( queryString string, ctx wrapper.HttpContext, log wrapper.Log, - callback func(emb []float64)) error { + callback func(emb []float64, err error)) error { embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}, log) if err != nil { - log.Errorf("Failed to construct parameters: %v", err) + log.Errorf("failed to construct parameters: %v", err) return err } var resp *Response err = d.client.Post(embUrl, embHeaders, embRequestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + if statusCode != http.StatusOK { - log.Errorf("Failed to fetch embeddings, statusCode: %d, responseBody: %s", statusCode, string(responseBody)) - err = errors.New("failed to get embedding") - callback(nil) + err = errors.New("failed to get embedding due to status code: " + strconv.Itoa(statusCode)) + callback(nil, err) return } - log.Infof("Get embedding response: %d, %s", statusCode, responseBody) + log.Debugf("get embedding response: %d, %s", statusCode, responseBody) resp, err = d.parseTextEmbedding(responseBody) if err != nil { - log.Errorf("Failed to parse response: %v", err) - callback(nil) + err = errors.New("failed to parse response") + callback(nil, err) return } if len(resp.Output.Embeddings) == 0 { - log.Errorf("No embedding found in response") err = errors.New("no embedding found in response") - callback(nil) + callback(nil, err) return } - callback(resp.Output.Embeddings[0].Embedding) + callback(resp.Output.Embeddings[0].Embedding, nil) }, d.config.timeout) - return nil + return err } diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index f7066c3761..adbbf84888 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -41,6 +41,9 @@ type ProviderConfig struct { // @Title zh-CN 文本特征提取服务超时时间 // @Description zh-CN 文本特征提取服务超时时间 timeout uint32 + // @Title zh-CN 文本特征提取服务使用的模型 + // @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1" + model string } func (c *ProviderConfig) FromJson(json gjson.Result) { @@ -50,6 +53,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.servicePort = json.Get("servicePort").Int() c.apiKey = json.Get("apiKey").String() c.timeout = uint32(json.Get("timeout").Int()) + c.model = json.Get("model").String() if c.timeout == 0 { c.timeout = 1000 } @@ -93,5 +97,5 @@ type Provider interface { queryString string, ctx wrapper.HttpContext, log wrapper.Log, - callback func(emb []float64)) error + callback func(emb []float64, err error)) error } diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index d18b10c61c..7a3f99dae0 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -44,7 +44,7 @@ func parseConfig(json gjson.Result, config *config.PluginConfig, log wrapper.Log } // 注意,在 parseConfig 阶段初始化 client 会出错,比如 docker compose 中的 redis 就无法使用 if err := config.Complete(log); err != nil { - log.Errorf("complete config failed:%v", err) + log.Errorf("complete config failed: %v", err) return err } return nil @@ -72,7 +72,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.PluginConfig, l return types.ActionContinue } if !strings.Contains(contentType, "application/json") { - log.Warnf("content is not json, can't process:%s", contentType) + log.Warnf("content is not json, can't process: %s", contentType) ctx.DontReadRequestBody() return types.ActionContinue } @@ -90,19 +90,17 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body if bodyJson.Get("stream").Bool() { stream = true ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) - } else if ctx.GetContext(STREAM_CONTEXT_KEY) != nil { - stream = true - } - // key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) - key := bodyJson.Get(config.CacheKeyFrom.RequestBody).String() + } + + key := bodyJson.Get(config.CacheKeyFrom.RequestPath).String() ctx.SetContext(CACHE_KEY_CONTEXT_KEY, key) - log.Debugf("[onHttpRequestBody] key:%s", key) + log.Debugf("[onHttpRequestBody] key: %s", key) if key == "" { log.Debug("[onHttpRequestBody] parse key from request body failed") return types.ActionContinue } - RedisSearchHandler(key, ctx, config, log, stream, true) + CheckCacheForKey(key, ctx, config, log, stream, true) return types.ActionPause } @@ -174,7 +172,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } bodyJson := gjson.ParseBytes(body) - value = TrimQuote(bodyJson.Get(config.CacheValueFrom.ResponseBody).Raw) + value = bodyJson.Get(config.CacheValueFrom.RequestPath).String() if value == "" { log.Warnf("parse value from response body failded, body:%s", body) return chunk @@ -205,7 +203,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu value = tempContentI.(string) } } - log.Infof("[onHttpResponseBody] Setting cache to redis, key:%s, value:%s", key, value) + log.Infof("[onHttpResponseBody] setting cache to redis, key:%s, value:%s", key, value) activeCacheProvider := config.GetCacheProvider() _ = activeCacheProvider.Set(key, value, nil) return chunk diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 34ecaf1b1b..2472824b32 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -8,41 +8,51 @@ import ( "github.com/tidwall/gjson" ) -func TrimQuote(source string) string { - return strings.Trim(source, `"`) -} - func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseMessage string, log wrapper.Log) string { subMessages := strings.Split(sseMessage, "\n") var message string for _, msg := range subMessages { - if strings.HasPrefix(msg, "data:") { + if strings.HasPrefix(msg, "data: ") { message = msg break } } if len(message) < 6 { - log.Warnf("invalid message:%s", message) + log.Warnf("invalid message: %s", message) return "" } + // skip the prefix "data:" bodyJson := message[5:] - if gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Exists() { + // Extract values from JSON fields + responseBody := gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponsePath) + toolCalls := gjson.Get(bodyJson, "choices.0.delta.content.tool_calls") + + if toolCalls.Exists() { + // TODO: Temporarily store the tool_calls value in the context for processing + ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String()) + } + + // Check if the ResponseBody field exists + if !responseBody.Exists() { + // Return an empty string if we cannot extract the content + log.Warnf("cannot extract content from message: %s", message) + return "" + } else { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + + // If there is no content in the cache, initialize and set the content if tempContentI == nil { - content := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) + content := responseBody.String() ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) return content } - appendMsg := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) + + // Update the content in the cache + appendMsg := responseBody.String() content := tempContentI.(string) + appendMsg ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) return content - } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { - // TODO: compatible with other providers - ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, struct{}{}) - return "" } - log.Warnf("unknown message:%s", bodyJson) - return "" + } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 2b4db49d83..70a6f4c873 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -120,38 +120,38 @@ func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, err return queryResp, nil } -func (d *DvProvider) GetThreshold() float64 { +func (d *DvProvider) GetSimThreshold() float64 { return threshold } func (d *DvProvider) QueryEmbedding( emb []float64, ctx wrapper.HttpContext, log wrapper.Log, - callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log)) { - // 构造请求参数 + callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { url, body, headers, err := d.constructEmbeddingQueryParameters(emb) if err != nil { - log.Infof("Failed to construct embedding query parameters: %v", err) + err = fmt.Errorf("failed to construct embedding query parameters: %v", err) + return err } err = d.client.Post(url, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { + err = nil if statusCode != http.StatusOK { - log.Infof("Failed to query embedding: %d", statusCode) - return + err = fmt.Errorf("failed to query embedding: %d", statusCode) } - log.Debugf("Query embedding response: %d, %s", statusCode, responseBody) + log.Debugf("query embedding response: %d, %s", statusCode, responseBody) results, err := d.ParseQueryResponse(responseBody, ctx, log) if err != nil { - log.Infof("Failed to parse query response: %v", err) - return + err = fmt.Errorf("failed to parse query response: %v", err) } - callback(results, ctx, log) + callback(results, ctx, log, err) }, d.config.timeout) if err != nil { - log.Infof("Failed to query embedding: %v", err) + err = fmt.Errorf("failed to query embedding: %v", err) } + return err } func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryEmbeddingResult, error) { @@ -209,15 +209,22 @@ func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, queryStri return url, requestBody, header, err } -func (d *DvProvider) UploadEmbedding(queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { - url, body, headers, _ := d.constructEmbeddingUploadParameters(queryEmb, queryString) - _ = d.client.Post( +func (d *DvProvider) UploadEmbedding(queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructEmbeddingUploadParameters(queryEmb, queryString) + if err != nil { + return err + } + err = d.client.Post( url, headers, body, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) - callback(ctx, log) + log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to upload embedding: %d", statusCode) + } + callback(ctx, log, err) }, d.config.timeout) + return err } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index e7b6c98e5f..6e8b2a86b2 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -37,15 +37,14 @@ type Provider interface { emb []float64, ctx wrapper.HttpContext, log wrapper.Log, - callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log)) + callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error UploadEmbedding( queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, - callback func(ctx wrapper.HttpContext, log wrapper.Log)) - GetThreshold() float64 - // ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error + GetSimThreshold() float64 } type ProviderConfig struct { From 21c9a792f0f324e7393f00dcbcff36d272a215cb Mon Sep 17 00:00:00 2001 From: suchun Date: Thu, 12 Sep 2024 04:51:14 +0000 Subject: [PATCH 24/71] fix bugs and update the SetEx function --- plugins/wasm-go/extensions/ai-cache/README.md | 13 ++-- .../extensions/ai-cache/cache/provider.go | 10 ++- .../extensions/ai-cache/cache/redis.go | 10 ++- .../extensions/ai-cache/config/config.go | 65 ++++++++----------- plugins/wasm-go/extensions/ai-cache/core.go | 28 +++++--- .../ai-cache/embedding/dashscope.go | 2 +- .../extensions/ai-cache/embedding/provider.go | 2 +- plugins/wasm-go/extensions/ai-cache/main.go | 20 +++--- plugins/wasm-go/extensions/ai-cache/util.go | 4 +- .../extensions/ai-cache/vector/dashvector.go | 8 ++- plugins/wasm-go/pkg/wrapper/redis_wrapper.go | 3 +- 11 files changed, 90 insertions(+), 75 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 6b7fc7c493..f89e6ee7b1 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -52,18 +52,19 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | cache.username | string | optional | "" | 缓存服务用户名 | | cache.password | string | optional | "" | 缓存服务密码 | | cache.timeout | uint32 | optional | 10000 | 缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 | -| cache.cacheTTL | uint32 | optional | 0 | 缓存过期时间,单位为秒。默认值是 0,即 永不过期| +| cache.cacheTTL | int | optional | 0 | 缓存过期时间,单位为秒。默认值是 0,即 永不过期| | cacheKeyPrefix | string | optional | "higressAiCache:" | 缓存 Key 的前缀,默认值为 "higressAiCache:" | ## 其他配置 | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | -| cacheKeyFrom.requestBody | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheValueFrom.responseBody | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheStreamValueFrom.responseBody | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| responseTemplate | string | optional | `{"id":"from-cache","choices":[%s],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | -| streamResponseTemplate | string | optional | `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | +| cacheKeyFrom | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheValueFrom | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheStreamValueFrom | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheToolCallsFrom | string | optional | "choices.0.delta.content.tool_calls" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | +| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | ## 配置示例 diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index 53f3d7eb1f..0390bb2b11 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -47,7 +47,7 @@ type ProviderConfig struct { timeout uint32 // @Title zh-CN 缓存过期时间 // @Description zh-CN 缓存过期时间,单位为秒。默认值是0,即永不过期 - cacheTTL uint32 + cacheTTL int // @Title 缓存 Key 前缀 // @Description 缓存 Key 的前缀,默认值为 "higressAiCache:" cacheKeyPrefix string @@ -73,9 +73,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if !json.Get("timeout").Exists() { c.timeout = 10000 } - c.cacheTTL = uint32(json.Get("cacheTTL").Int()) + c.cacheTTL = int(json.Get("cacheTTL").Int()) if !json.Get("cacheTTL").Exists() { c.cacheTTL = 0 + // c.cacheTTL = 3600000 } if json.Get("cacheKeyPrefix").Exists() { c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String() @@ -92,6 +93,9 @@ func (c *ProviderConfig) Validate() error { if c.serviceName == "" { return errors.New("cache service name is required") } + if c.cacheTTL < 0 { + return errors.New("cache TTL must be greater than or equal to 0") + } initializer, has := providerInitializers[c.typ] if !has { return errors.New("unknown cache service provider type: " + c.typ) @@ -115,5 +119,5 @@ type Provider interface { Init(username string, password string, timeout uint32) error Get(key string, cb wrapper.RedisResponseCallback) error Set(key string, value string, cb wrapper.RedisResponseCallback) error - getCacheKeyPrefix() string + GetCacheKeyPrefix() string } diff --git a/plugins/wasm-go/extensions/ai-cache/cache/redis.go b/plugins/wasm-go/extensions/ai-cache/cache/redis.go index 2c3d49047a..4cb69744e1 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/redis.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/redis.go @@ -42,13 +42,17 @@ func (rp *redisProvider) Init(username string, password string, timeout uint32) } func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error { - return rp.client.Get(rp.getCacheKeyPrefix()+key, cb) + return rp.client.Get(key, cb) } func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error { - return rp.client.SetEx(rp.getCacheKeyPrefix()+key, value, int(rp.config.cacheTTL), cb) + if rp.config.cacheTTL == 0 { + return rp.client.Set(key, value, cb) + } else { + return rp.client.SetEx(key, value, rp.config.cacheTTL, cb) + } } -func (rp *redisProvider) getCacheKeyPrefix() string { +func (rp *redisProvider) GetCacheKeyPrefix() string { return rp.config.cacheKeyPrefix } diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 74a4f2e2d7..fad64d70ee 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -8,50 +8,26 @@ import ( "github.com/tidwall/gjson" ) -type BodyPathMapper struct { - // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - RequestPath string `required:"false" yaml:"requestBody" json:"requestBody"` - // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - ResponsePath string `required:"false" yaml:"responseBody" json:"responseBody"` -} - -func (e *BodyPathMapper) SetRequestPathFromJson(json gjson.Result, key string, defaultValue string) { - value := json.Get(key) - if value.Exists() { - e.RequestPath = value.String() - } else { - e.RequestPath = defaultValue - } -} - -func (e *BodyPathMapper) SetResponsePathFromJson(json gjson.Result, key string, defaultValue string) { - value := json.Get(key) - if value.Exists() { - e.ResponsePath = value.String() - } else { - e.ResponsePath = defaultValue - } -} - type PluginConfig struct { // @Title zh-CN 返回 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ResponseTemplate string `required:"true" yaml:"responseTemplate" json:"responseTemplate"` + ResponseTemplate string // @Title zh-CN 返回流式 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - StreamResponseTemplate string `required:"true" yaml:"streamResponseTemplate" json:"streamResponseTemplate"` + StreamResponseTemplate string - cacheProvider cache.Provider `yaml:"-"` - embeddingProvider embedding.Provider `yaml:"-"` - vectorProvider vector.Provider `yaml:"-"` + cacheProvider cache.Provider + embeddingProvider embedding.Provider + vectorProvider vector.Provider embeddingProviderConfig embedding.ProviderConfig vectorProviderConfig vector.ProviderConfig cacheProviderConfig cache.ProviderConfig - CacheKeyFrom BodyPathMapper `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` - CacheValueFrom BodyPathMapper `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` - CacheStreamValueFrom BodyPathMapper `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` + CacheKeyFrom string + CacheValueFrom string + CacheStreamValueFrom string + CacheToolCallsFrom string } func (c *PluginConfig) FromJson(json gjson.Result) { @@ -59,17 +35,30 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.vectorProviderConfig.FromJson(json.Get("vector")) c.cacheProviderConfig.FromJson(json.Get("cache")) - c.CacheKeyFrom.SetRequestPathFromJson(json, "cacheKeyFrom.requestBody", "messages.@reverse.0.content") - c.CacheValueFrom.SetResponsePathFromJson(json, "cacheValueFrom.responseBody", "choices.0.message.content") - c.CacheStreamValueFrom.SetResponsePathFromJson(json, "cacheStreamValueFrom.responseBody", "choices.0.delta.content") + c.CacheKeyFrom = json.Get("cacheKeyFrom").String() + if c.CacheKeyFrom == "" { + c.CacheKeyFrom = "messages.@reverse.0.content" + } + c.CacheValueFrom = json.Get("cacheValueFrom").String() + if c.CacheValueFrom == "" { + c.CacheValueFrom = "choices.0.message.content" + } + c.CacheStreamValueFrom = json.Get("cacheStreamValueFrom").String() + if c.CacheStreamValueFrom == "" { + c.CacheStreamValueFrom = "choices.0.delta.content" + } + c.CacheToolCallsFrom = json.Get("cacheToolCallsFrom").String() + if c.CacheToolCallsFrom == "" { + c.CacheToolCallsFrom = "choices.0.delta.content.tool_calls" + } c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() if c.StreamResponseTemplate == "" { - c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + c.StreamResponseTemplate = `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" } c.ResponseTemplate = json.Get("responseTemplate").String() if c.ResponseTemplate == "" { - c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + c.ResponseTemplate = `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index ca23af4357..7845c44496 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -13,22 +13,30 @@ import ( func CheckCacheForKey(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) { activeCacheProvider := config.GetCacheProvider() - err := activeCacheProvider.Get(key, func(response resp.Value) { + queryKey := activeCacheProvider.GetCacheKeyPrefix() + key + // queryKey := key + log.Debugf("query key: %s", queryKey) + err := activeCacheProvider.Get(queryKey, func(response resp.Value) { if err := response.Error(); err == nil && !response.IsNull() { log.Infof("cache hit, key: %s", key) - ProcessCacheHit(key, response, stream, ctx, config, log) + processCacheHit(key, response, stream, ctx, config, log) } else { - log.Infof("cache miss, key: %s, error: %s", key, err.Error()) + if err != nil { + log.Errorf("error retrieving key: %s from cache, error: %v", key, err) + } + if response.IsNull() { + log.Infof("cache miss, key: %s", key) + } if useSimilaritySearch { err = performSimilaritySearch(key, ctx, config, log, key, stream) if err != nil { log.Errorf("failed to perform similarity search for key: %s, error: %v", key, err) proxywasm.ResumeHttpRequest() + return } - } else { - proxywasm.ResumeHttpRequest() - return } + proxywasm.ResumeHttpRequest() + return } }) @@ -38,9 +46,11 @@ func CheckCacheForKey(key string, ctx wrapper.HttpContext, config config.PluginC } } -func ProcessCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { +func processCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { escapedResponse, err := json.Marshal(response.String()) + log.Debugf("cached response: %s", escapedResponse) if err != nil { + log.Errorf("failed to marshal cached response: %v", err) proxywasm.SendHttpResponse(500, [][2]string{{"content-type", "text/plain"}}, []byte("Internal Server Error"), -1) return } @@ -62,12 +72,12 @@ func performSimilaritySearch(key string, ctx wrapper.HttpContext, config config. return } log.Debugf("successfully fetched embeddings for key: %s", key) - QueryVectorDB(key, emb, ctx, config, log, stream) + queryVectorDB(key, emb, ctx, config, log, stream) }) return err } -func QueryVectorDB(key string, textEmbedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { +func queryVectorDB(key string, textEmbedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { log.Debugf("starting query for key: %s", key) activeVectorProvider := config.GetVectorProvider() log.Debugf("active vector provider configuration: %+v", activeVectorProvider) diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index e89cf3f6b5..fe06752398 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -168,7 +168,7 @@ func (d *DSProvider) GetEmbedding( resp, err = d.parseTextEmbedding(responseBody) if err != nil { - err = errors.New("failed to parse response") + err = errors.New("failed to parse response: " + err.Error()) callback(nil, err) return } diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index adbbf84888..b7748f2cc4 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -55,7 +55,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.timeout = uint32(json.Get("timeout").Int()) c.model = json.Get("model").String() if c.timeout == 0 { - c.timeout = 1000 + c.timeout = 10000 } } diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 7a3f99dae0..4bcb3b21ef 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -90,13 +90,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body if bodyJson.Get("stream").Bool() { stream = true ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) - } + } - key := bodyJson.Get(config.CacheKeyFrom.RequestPath).String() + key := bodyJson.Get(config.CacheKeyFrom).String() ctx.SetContext(CACHE_KEY_CONTEXT_KEY, key) log.Debugf("[onHttpRequestBody] key: %s", key) if key == "" { log.Debug("[onHttpRequestBody] parse key from request body failed") + ctx.DontReadResponseBody() return types.ActionContinue } @@ -114,8 +115,8 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, } func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - log.Debugf("[onHttpResponseBody] chunk:%s", string(chunk)) - log.Debugf("[onHttpResponseBody] isLastChunk:%v", isLastChunk) + log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk)) + log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk) if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { // we should not cache tool call result return chunk @@ -159,7 +160,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu return chunk } // last chunk - key := keyI.(string) stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string if stream == nil { @@ -172,9 +172,9 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } bodyJson := gjson.ParseBytes(body) - value = bodyJson.Get(config.CacheValueFrom.RequestPath).String() + value = bodyJson.Get(config.CacheValueFrom).String() if value == "" { - log.Warnf("parse value from response body failded, body:%s", body) + log.Warnf("parse value from response body failded, body: %s", body) return chunk } } else { @@ -203,8 +203,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu value = tempContentI.(string) } } - log.Infof("[onHttpResponseBody] setting cache to redis, key:%s, value:%s", key, value) activeCacheProvider := config.GetCacheProvider() - _ = activeCacheProvider.Set(key, value, nil) + queryKey := activeCacheProvider.GetCacheKeyPrefix() + ctx.GetContext(CACHE_KEY_CONTEXT_KEY).(string) + // queryKey := keyI.(string) + log.Infof("[onHttpResponseBody] setting cache to redis, key: %s, value: %s", queryKey, value) + _ = activeCacheProvider.Set(queryKey, value, nil) return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 2472824b32..32642ec64a 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -25,8 +25,8 @@ func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseM // skip the prefix "data:" bodyJson := message[5:] // Extract values from JSON fields - responseBody := gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponsePath) - toolCalls := gjson.Get(bodyJson, "choices.0.delta.content.tool_calls") + responseBody := gjson.Get(bodyJson, config.CacheStreamValueFrom) + toolCalls := gjson.Get(bodyJson, config.CacheToolCallsFrom) if toolCalls.Exists() { // TODO: Temporarily store the tool_calls value in the context for processing diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 70a6f4c873..5ae320d815 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -10,7 +10,7 @@ import ( ) const ( - threshold = 2000 + threshold = 10000 ) type dashVectorProviderInitializer struct { @@ -129,6 +129,7 @@ func (d *DvProvider) QueryEmbedding( log wrapper.Log, callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { url, body, headers, err := d.constructEmbeddingQueryParameters(emb) + log.Debugf("url:%s, body:%s, headers:%v", url, string(body), headers) if err != nil { err = fmt.Errorf("failed to construct embedding query parameters: %v", err) return err @@ -139,6 +140,8 @@ func (d *DvProvider) QueryEmbedding( err = nil if statusCode != http.StatusOK { err = fmt.Errorf("failed to query embedding: %d", statusCode) + callback(nil, ctx, log, err) + return } log.Debugf("query embedding response: %d, %s", statusCode, responseBody) results, err := d.ParseQueryResponse(responseBody, ctx, log) @@ -159,8 +162,9 @@ func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpCon if err != nil { return nil, err } + if len(resp.Output) == 0 { - return nil, nil + return nil, errors.New("no query results found in response") } results := make([]QueryEmbeddingResult, 0, len(resp.Output)) diff --git a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go index 10aa9020bd..13ba7c0d98 100644 --- a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go @@ -235,8 +235,9 @@ func (c RedisClusterClient[C]) Set(key string, value interface{}, callback Redis func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error { args := make([]interface{}, 0) - args = append(args, "setex") + args = append(args, "set") args = append(args, key) + args = append(args, "ex") args = append(args, ttl) args = append(args, value) return RedisCall(c.cluster, respString(args), callback) From 71b9530b740f637ca682cbecfe0c9b892679686a Mon Sep 17 00:00:00 2001 From: suchun Date: Tue, 17 Sep 2024 21:25:09 +0000 Subject: [PATCH 25/71] Optimize query flow logic (not fully tested) --- plugins/wasm-go/extensions/ai-cache/core.go | 262 +++++++++++++----- plugins/wasm-go/extensions/ai-cache/main.go | 105 ++----- plugins/wasm-go/extensions/ai-cache/util.go | 94 +++++++ .../extensions/ai-cache/vector/dashvector.go | 37 ++- .../extensions/ai-cache/vector/provider.go | 34 ++- 5 files changed, 366 insertions(+), 166 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 7845c44496..1ce9e4b15c 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -7,119 +7,235 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/go-errors/errors" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/tidwall/resp" ) -func CheckCacheForKey(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) { +// CheckCacheForKey checks if the key is in the cache, or triggers similarity search if not found. +func CheckCacheForKey(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) error { activeCacheProvider := config.GetCacheProvider() + if activeCacheProvider == nil { + log.Debug("No cache provider configured, performing similarity search") + return performSimilaritySearch(key, ctx, config, log, key, stream) + } + queryKey := activeCacheProvider.GetCacheKeyPrefix() + key - // queryKey := key - log.Debugf("query key: %s", queryKey) + log.Debugf("Querying cache with key: %s", queryKey) + err := activeCacheProvider.Get(queryKey, func(response resp.Value) { - if err := response.Error(); err == nil && !response.IsNull() { - log.Infof("cache hit, key: %s", key) - processCacheHit(key, response, stream, ctx, config, log) - } else { - if err != nil { - log.Errorf("error retrieving key: %s from cache, error: %v", key, err) - } - if response.IsNull() { - log.Infof("cache miss, key: %s", key) - } - if useSimilaritySearch { - err = performSimilaritySearch(key, ctx, config, log, key, stream) - if err != nil { - log.Errorf("failed to perform similarity search for key: %s, error: %v", key, err) - proxywasm.ResumeHttpRequest() - return - } - } - proxywasm.ResumeHttpRequest() - return - } + handleCacheResponse(key, response, ctx, log, stream, config, useSimilaritySearch) }) if err != nil { log.Errorf("Failed to retrieve key: %s from cache, error: %v", key, err) - proxywasm.ResumeHttpRequest() + return err + } + + return nil +} + +// handleCacheResponse processes cache response and handles cache hits and misses. +func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log wrapper.Log, stream bool, config config.PluginConfig, useSimilaritySearch bool) { + if err := response.Error(); err == nil && !response.IsNull() { + log.Infof("Cache hit for key: %s", key) + processCacheHit(key, response.String(), stream, ctx, config, log) + return + } + + log.Infof("Cache miss for key: %s", key) + if err := response.Error(); err != nil { + log.Errorf("Error retrieving key: %s from cache, error: %v", key, err) } + if useSimilaritySearch { + if err := performSimilaritySearch(key, ctx, config, log, key, stream); err != nil { + log.Errorf("Failed to perform similarity search for key: %s, error: %v", key, err) + } + } + proxywasm.ResumeHttpRequest() } -func processCacheHit(key string, response resp.Value, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { - escapedResponse, err := json.Marshal(response.String()) - log.Debugf("cached response: %s", escapedResponse) +// processCacheHit handles a successful cache hit. +func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { + escapedResponse, err := json.Marshal(response) + log.Debugf("Cached response for key %s: %s", key, escapedResponse) + if err != nil { - log.Errorf("failed to marshal cached response: %v", err) - proxywasm.SendHttpResponse(500, [][2]string{{"content-type", "text/plain"}}, []byte("Internal Server Error"), -1) + handleInternalError(err, "Failed to marshal cached response", log) return } + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) - if !stream { - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1) - } else { - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.StreamResponseTemplate, escapedResponse)), -1) + + contentType := "application/json; charset=utf-8" + if stream { + contentType = "text/event-stream; charset=utf-8" } + + proxywasm.SendHttpResponse(200, [][2]string{{"content-type", contentType}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1) } +// performSimilaritySearch determines the appropriate similarity search method to use. func performSimilaritySearch(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) error { - activeEmbeddingProvider := config.GetEmbeddingProvider() - err := activeEmbeddingProvider.GetEmbedding(queryString, ctx, log, - func(emb []float64, err error) { - if err != nil { - log.Errorf("failed to fetch embeddings for key: %s, err: %v", key, err) - proxywasm.ResumeHttpRequest() - return - } - log.Debugf("successfully fetched embeddings for key: %s", key) - queryVectorDB(key, emb, ctx, config, log, stream) - }) - return err + activeVectorProvider := config.GetVectorProvider() + if activeVectorProvider == nil { + return errors.New("no vector provider configured for similarity search") + } + + // Check if the active vector provider implements the StringQuerier interface. + if _, ok := activeVectorProvider.(vector.StringQuerier); ok { + return performStringQuery(key, queryString, ctx, config, log, stream) + } + + // Check if the active vector provider implements the EmbeddingQuerier interface. + if _, ok := activeVectorProvider.(vector.EmbeddingQuerier); ok { + return performEmbeddingQuery(key, ctx, config, log, stream) + } + + return errors.New("no suitable querier or embedding provider available for similarity search") } -func queryVectorDB(key string, textEmbedding []float64, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) { - log.Debugf("starting query for key: %s", key) - activeVectorProvider := config.GetVectorProvider() - log.Debugf("active vector provider configuration: %+v", activeVectorProvider) +// performStringQuery executes the string-based similarity search. +func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) error { + stringQuerier, ok := config.GetVectorProvider().(vector.StringQuerier) + if !ok { + return logAndReturnError(log, "active vector provider does not implement StringQuerier interface") + } + + return stringQuerier.QueryString(queryString, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { + handleQueryResults(key, results, ctx, log, stream, config, err) + }) +} + +// performEmbeddingQuery executes the embedding-based similarity search. +func performEmbeddingQuery(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) error { + embeddingQuerier, ok := config.GetVectorProvider().(vector.EmbeddingQuerier) + if !ok { + return logAndReturnError(log, "active vector provider does not implement EmbeddingQuerier interface") + } - err := activeVectorProvider.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { + activeEmbeddingProvider := config.GetEmbeddingProvider() + if activeEmbeddingProvider == nil { + return logAndReturnError(log, "no embedding provider configured for similarity search") + } + + return activeEmbeddingProvider.GetEmbedding(key, ctx, log, func(textEmbedding []float64, err error) { if err != nil { - log.Errorf("error querying vector database: %v", err) - proxywasm.ResumeHttpRequest() + handleInternalError(err, fmt.Sprintf("Error getting embedding for key: %s", key), log) return } + ctx.SetContext(CACHE_KEY_EMBEDDING_KEY, textEmbedding) - if len(results) == 0 { - log.Warnf("no similar keys found in vector database for key: %s", key) - uploadEmbedding(textEmbedding, key, ctx, log, activeVectorProvider) - return + err = embeddingQuerier.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { + handleQueryResults(key, results, ctx, log, stream, config, err) + }) + if err != nil { + handleInternalError(err, fmt.Sprintf("Error querying vector database for key: %s", key), log) } + }) +} - mostSimilarData := results[0] - log.Debugf("most similar key found: %s with score: %f", mostSimilarData.Text, mostSimilarData.Score) +// handleQueryResults processes the results of similarity search and determines next actions. +func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, stream bool, config config.PluginConfig, err error) { + if err != nil { + handleInternalError(err, fmt.Sprintf("Error querying vector database for key: %s", key), log) + return + } - if mostSimilarData.Score < activeVectorProvider.GetSimThreshold() { - log.Infof("key accepted: %s with score: %f below threshold", mostSimilarData.Text, mostSimilarData.Score) - CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) + if len(results) == 0 { + log.Warnf("No similar keys found for key: %s", key) + proxywasm.ResumeHttpRequest() + return + } + + mostSimilarData := results[0] + log.Debugf("Most similar key found: %s with score: %f", mostSimilarData.Text, mostSimilarData.Score) + + simThresholdProvider, ok := config.GetVectorProvider().(vector.SimilarityThresholdProvider) + if !ok { + handleInternalError(nil, "Active vector provider does not implement SimilarityThresholdProvider interface", log) + return + } + + simThreshold := simThresholdProvider.GetSimilarityThreshold() + if mostSimilarData.Score < simThreshold { + log.Infof("Key accepted: %s with score: %f below threshold", mostSimilarData.Text, mostSimilarData.Score) + if mostSimilarData.Answer != "" { + // direct return the answer if available + processCacheHit(key, mostSimilarData.Answer, stream, ctx, config, log) } else { - log.Infof("score too high for key: %s with score: %f above threshold", mostSimilarData.Text, mostSimilarData.Score) - uploadEmbedding(textEmbedding, key, ctx, log, activeVectorProvider) + // otherwise, continue to check cache for the most similar key + CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) } - }) + } else { + log.Infof("Score too high for key: %s with score: %f above threshold", mostSimilarData.Text, mostSimilarData.Score) + proxywasm.ResumeHttpRequest() + } +} +// logAndReturnError logs an error and returns it. +func logAndReturnError(log wrapper.Log, message string) error { + log.Errorf(message) + return errors.New(message) +} + +// handleInternalError logs an error and resumes the HTTP request. +func handleInternalError(err error, message string, log wrapper.Log) { if err != nil { - log.Errorf("error querying vector database: %v", err) - proxywasm.ResumeHttpRequest() + log.Errorf("%s: %v", message, err) + } else { + log.Errorf(message) + } + // proxywasm.SendHttpResponse(500, [][2]string{{"content-type", "text/plain"}}, []byte("Internal Server Error"), -1) + proxywasm.ResumeHttpRequest() +} + +// Caches the response value +func cacheResponse(ctx wrapper.HttpContext, config config.PluginConfig, key string, value string, log wrapper.Log) { + activeCacheProvider := config.GetCacheProvider() + if activeCacheProvider != nil { + queryKey := activeCacheProvider.GetCacheKeyPrefix() + key + log.Infof("[onHttpResponseBody] setting cache to redis, key: %s, value: %s", queryKey, value) + _ = activeCacheProvider.Set(queryKey, value, nil) } } -func uploadEmbedding(textEmbedding []float64, key string, ctx wrapper.HttpContext, log wrapper.Log, provider vector.Provider) { - provider.UploadEmbedding(textEmbedding, key, ctx, log, func(ctx wrapper.HttpContext, log wrapper.Log, err error) { +// Handles embedding upload if available +func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, config config.PluginConfig, key string, value string, log wrapper.Log) { + embedding := ctx.GetContext(CACHE_KEY_EMBEDDING_KEY) + if embedding == nil { + return + } + + emb, ok := embedding.([]float64) + if !ok { + log.Errorf("[onHttpResponseBody] embedding is not of expected type []float64") + return + } + + activeVectorProvider := config.GetVectorProvider() + if activeVectorProvider == nil { + log.Debug("[onHttpResponseBody] no vector provider configured for uploading embedding") + return + } + + // Attempt to upload answer embedding first + if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerEmbeddingUploader); ok { + log.Infof("[onHttpResponseBody] uploading answer embedding for key: %s", key) + err := ansEmbUploader.UploadAnswerEmbedding(key, emb, value, ctx, log, nil) if err != nil { - log.Errorf("failed to upload embedding for key: %s, error: %v", key, err) + log.Warnf("[onHttpResponseBody] failed to upload answer embedding for key: %s, error: %v", key, err) } else { - log.Debugf("successfully uploaded embedding for key: %s", key) + return // If successful, return early } - proxywasm.ResumeHttpRequest() - }) + } + + // If answer embedding upload fails, attempt normal embedding upload + if embUploader, ok := activeVectorProvider.(vector.EmbeddingUploader); ok { + log.Infof("[onHttpResponseBody] uploading embedding for key: %s", key) + err := embUploader.UploadEmbedding(key, emb, ctx, log, nil) + if err != nil { + log.Warnf("[onHttpResponseBody] failed to upload embedding for key: %s, error: %v", key, err) + } + } } diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 4bcb3b21ef..80558290a4 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -15,6 +15,7 @@ import ( const ( PLUGIN_NAME = "ai-cache" CACHE_KEY_CONTEXT_KEY = "cacheKey" + CACHE_KEY_EMBEDDING_KEY = "cacheKeyEmbedding" CACHE_CONTENT_CONTEXT_KEY = "cacheContent" PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" TOOL_CALLS_CONTEXT_KEY = "toolCalls" @@ -23,7 +24,6 @@ const ( func main() { // CreateClient() - wrapper.SetCtx( PLUGIN_NAME, wrapper.ParseConfigBy(parseConfig), @@ -101,7 +101,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body return types.ActionContinue } - CheckCacheForKey(key, ctx, config, log, stream, true) + if err := CheckCacheForKey(key, ctx, config, log, stream, true); err != nil { + log.Errorf("check cache for key: %s failed, error: %v", key, err) + return types.ActionContinue + } return types.ActionPause } @@ -117,96 +120,34 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk)) log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk) + + // If the context contains TOOL_CALLS_CONTEXT_KEY, bypass caching if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { - // we should not cache tool call result return chunk } + keyI := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) if keyI == nil { return chunk } + if !isLastChunk { - stream := ctx.GetContext(STREAM_CONTEXT_KEY) - if stream == nil { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI == nil { - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) - return chunk - } - tempContent := tempContentI.([]byte) - tempContent = append(tempContent, chunk...) - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) - } else { - var partialMessage []byte - partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - if partialMessageI != nil { - partialMessage = append(partialMessageI.([]byte), chunk...) - } else { - partialMessage = chunk - } - messages := strings.Split(string(partialMessage), "\n\n") - for i, msg := range messages { - if i < len(messages)-1 { - // process complete message - processSSEMessage(ctx, config, msg, log) - } - } - if !strings.HasSuffix(string(partialMessage), "\n\n") { - ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) - } else { - ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) - } - } + handlePartialChunk(ctx, config, chunk, log) return chunk } - // last chunk - stream := ctx.GetContext(STREAM_CONTEXT_KEY) - var value string - if stream == nil { - var body []byte - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI != nil { - body = append(tempContentI.([]byte), chunk...) - } else { - body = chunk - } - bodyJson := gjson.ParseBytes(body) - - value = bodyJson.Get(config.CacheValueFrom).String() - if value == "" { - log.Warnf("parse value from response body failded, body: %s", body) - return chunk - } - } else { - log.Infof("[onHttpResponseBody] stream mode") - if len(chunk) > 0 { - var lastMessage []byte - partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - if partialMessageI != nil { - lastMessage = append(partialMessageI.([]byte), chunk...) - } else { - lastMessage = chunk - } - if !strings.HasSuffix(string(lastMessage), "\n\n") { - log.Warnf("[onHttpResponseBody] invalid lastMessage:%s", lastMessage) - return chunk - } - // remove the last \n\n - lastMessage = lastMessage[:len(lastMessage)-2] - value = processSSEMessage(ctx, config, string(lastMessage), log) - } else { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI == nil { - log.Warnf("[onHttpResponseBody] no content in tempContentI") - return chunk - } - value = tempContentI.(string) - } + + // Handle last chunk + value, err := processFinalChunk(ctx, config, chunk, log) + if err != nil { + log.Warnf("[onHttpResponseBody] failed to process final chunk: %v", err) + return chunk } - activeCacheProvider := config.GetCacheProvider() - queryKey := activeCacheProvider.GetCacheKeyPrefix() + ctx.GetContext(CACHE_KEY_CONTEXT_KEY).(string) - // queryKey := keyI.(string) - log.Infof("[onHttpResponseBody] setting cache to redis, key: %s, value: %s", queryKey, value) - _ = activeCacheProvider.Set(queryKey, value, nil) + + // Cache the final value + cacheResponse(ctx, config, keyI.(string), value, log) + + // Handle embedding upload if available + uploadEmbeddingAndAnswer(ctx, config, keyI.(string), value, log) + return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 32642ec64a..c13161c129 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" @@ -56,3 +57,96 @@ func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseM } } + +// Handles partial chunks of data when the full response is not received yet. +func handlePartialChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) { + stream := ctx.GetContext(STREAM_CONTEXT_KEY) + + if stream == nil { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) + } else { + tempContent := append(tempContentI.([]byte), chunk...) + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) + } + } else { + partialMessage := appendPartialMessage(ctx, chunk) + messages := strings.Split(string(partialMessage), "\n\n") + for _, msg := range messages[:len(messages)-1] { + processSSEMessage(ctx, config, msg, log) + } + savePartialMessage(ctx, partialMessage, messages) + } +} + +// Appends the partial message chunks +func appendPartialMessage(ctx wrapper.HttpContext, chunk []byte) []byte { + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + if partialMessageI != nil { + return append(partialMessageI.([]byte), chunk...) + } + return chunk +} + +// Saves the remaining partial message chunk +func savePartialMessage(ctx wrapper.HttpContext, partialMessage []byte, messages []string) { + if !strings.HasSuffix(string(partialMessage), "\n\n") { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) + } else { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) + } +} + +// Processes the final chunk and returns the parsed value or an error +func processFinalChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { + stream := ctx.GetContext(STREAM_CONTEXT_KEY) + var value string + + if stream == nil { + body := appendFinalBody(ctx, chunk) + bodyJson := gjson.ParseBytes(body) + value = bodyJson.Get(config.CacheValueFrom).String() + + if value == "" { + return "", fmt.Errorf("failed to parse value from response body: %s", body) + } + } else { + value, err := processFinalStreamMessage(ctx, config, log, chunk) + if err != nil { + return "", err + } + return value, nil + } + + return value, nil +} + +// Appends the final body chunk to the existing body content +func appendFinalBody(ctx wrapper.HttpContext, chunk []byte) []byte { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI != nil { + return append(tempContentI.([]byte), chunk...) + } + return chunk +} + +// Processes the final SSE message chunk +func processFinalStreamMessage(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, chunk []byte) (string, error) { + var lastMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + + if partialMessageI != nil { + lastMessage = append(partialMessageI.([]byte), chunk...) + } else { + lastMessage = chunk + } + + if !strings.HasSuffix(string(lastMessage), "\n\n") { + log.Warnf("[onHttpResponseBody] invalid lastMessage: %s", lastMessage) + return "", fmt.Errorf("invalid lastMessage format") + } + + lastMessage = lastMessage[:len(lastMessage)-2] // Remove the last \n\n + return processSSEMessage(ctx, config, string(lastMessage), log), nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 5ae320d815..a5ca7e0909 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -127,7 +127,7 @@ func (d *DvProvider) QueryEmbedding( emb []float64, ctx wrapper.HttpContext, log wrapper.Log, - callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { url, body, headers, err := d.constructEmbeddingQueryParameters(emb) log.Debugf("url:%s, body:%s, headers:%v", url, string(body), headers) if err != nil { @@ -157,7 +157,7 @@ func (d *DvProvider) QueryEmbedding( return err } -func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryEmbeddingResult, error) { +func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) { resp, err := d.parseQueryResponse(responseBody) if err != nil { return nil, err @@ -167,10 +167,10 @@ func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpCon return nil, errors.New("no query results found in response") } - results := make([]QueryEmbeddingResult, 0, len(resp.Output)) + results := make([]QueryResult, 0, len(resp.Output)) for _, output := range resp.Output { - result := QueryEmbeddingResult{ + result := QueryResult{ Text: output.Fields["query"].(string), Embedding: output.Vector, Score: output.Score, @@ -190,13 +190,14 @@ type insertRequest struct { Docs []document `json:"docs"` } -func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, queryString string) (string, []byte, [][2]string, error) { +func (d *DvProvider) constructUploadParameters(emb []float64, queryString string, answer string) (string, []byte, [][2]string, error) { url := "/v1/collections/" + d.config.collectionID + "/docs" doc := document{ Vector: emb, Fields: map[string]string{ - "query": queryString, + "query": queryString, + "answer": answer, }, } @@ -213,8 +214,28 @@ func (d *DvProvider) constructEmbeddingUploadParameters(emb []float64, queryStri return url, requestBody, header, err } -func (d *DvProvider) UploadEmbedding(queryEmb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { - url, body, headers, err := d.constructEmbeddingUploadParameters(queryEmb, queryString) +func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, "") + if err != nil { + return err + } + err = d.client.Post( + url, + headers, + body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to upload embedding: %d", statusCode) + } + callback(ctx, log, err) + }, + d.config.timeout) + return err +} + +func (d *DvProvider) UploadAnswerEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer) if err != nil { return err } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 6e8b2a86b2..35789e14ca 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -25,26 +25,54 @@ var ( ) // QueryEmbeddingResult 定义通用的查询结果的结构体 -type QueryEmbeddingResult struct { +type QueryResult struct { Text string // 相似的文本 Embedding []float64 // 相似文本的向量 Score float64 // 文本的向量相似度或距离等度量 + Answer string // 相似文本对应的LLM生成的回答 } type Provider interface { GetProviderType() string +} + +type EmbeddingQuerier interface { QueryEmbedding( emb []float64, ctx wrapper.HttpContext, log wrapper.Log, - callback func(results []QueryEmbeddingResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type EmbeddingUploader interface { UploadEmbedding( + queryString string, queryEmb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type AnswerEmbeddingUploader interface { + UploadAnswerEmbedding( queryString string, + queryEmb []float64, + answer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error - GetSimThreshold() float64 +} + +type StringQuerier interface { + QueryString( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type SimilarityThresholdProvider interface { + GetSimilarityThreshold() float64 } type ProviderConfig struct { From 51b9cccf78f233166abd7caa2454966a89827250 Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 21 Sep 2024 00:53:30 +0000 Subject: [PATCH 26/71] Fix bugs and verify removal of cache setting --- .../extensions/ai-cache/cache/provider.go | 4 ++ .../extensions/ai-cache/config/config.go | 43 +++++++++++++------ plugins/wasm-go/extensions/ai-cache/core.go | 30 ++++++------- plugins/wasm-go/extensions/ai-cache/main.go | 28 ++++++++---- plugins/wasm-go/extensions/ai-cache/util.go | 11 ++++- .../extensions/ai-cache/vector/dashvector.go | 14 +++++- .../extensions/ai-cache/vector/provider.go | 4 ++ 7 files changed, 95 insertions(+), 39 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index 0390bb2b11..9cdeaac262 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -53,6 +53,10 @@ type ProviderConfig struct { cacheKeyPrefix string } +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + func (c *ProviderConfig) FromJson(json gjson.Result) { c.typ = json.Get("type").String() c.serviceName = json.Get("serviceName").String() diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index fad64d70ee..10aa354e76 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -31,8 +31,8 @@ type PluginConfig struct { } func (c *PluginConfig) FromJson(json gjson.Result) { - c.embeddingProviderConfig.FromJson(json.Get("embedding")) c.vectorProviderConfig.FromJson(json.Get("vector")) + c.embeddingProviderConfig.FromJson(json.Get("embedding")) c.cacheProviderConfig.FromJson(json.Get("cache")) c.CacheKeyFrom = json.Get("cacheKeyFrom").String() @@ -54,20 +54,25 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() if c.StreamResponseTemplate == "" { - c.StreamResponseTemplate = `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" } c.ResponseTemplate = json.Get("responseTemplate").String() if c.ResponseTemplate == "" { - c.ResponseTemplate = `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } } func (c *PluginConfig) Validate() error { - if err := c.cacheProviderConfig.Validate(); err != nil { - return err + // if cache provider is configured, validate it + if c.cacheProviderConfig.GetProviderType() != "" { + if err := c.cacheProviderConfig.Validate(); err != nil { + return err + } } - if err := c.embeddingProviderConfig.Validate(); err != nil { - return err + if c.embeddingProviderConfig.GetProviderType() != "" { + if err := c.embeddingProviderConfig.Validate(); err != nil { + return err + } } if err := c.vectorProviderConfig.Validate(); err != nil { return err @@ -77,15 +82,25 @@ func (c *PluginConfig) Validate() error { func (c *PluginConfig) Complete(log wrapper.Log) error { var err error - c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig) - if err != nil { - return err + if c.embeddingProviderConfig.GetProviderType() != "" { + c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig) + if err != nil { + return err + } + } else { + log.Info("embedding provider is not configured") + c.embeddingProvider = nil } - c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) - if err != nil { - return err + if c.cacheProviderConfig.GetProviderType() != "" { + c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig) + if err != nil { + return err + } + } else { + log.Info("cache provider is not configured") + c.cacheProvider = nil } - c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig) + c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) if err != nil { return err } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 1ce9e4b15c..1e5d433ca1 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -1,13 +1,12 @@ package main import ( - "encoding/json" + "errors" "fmt" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" - "github.com/go-errors/errors" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/tidwall/resp" ) @@ -57,22 +56,23 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex // processCacheHit handles a successful cache hit. func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { - escapedResponse, err := json.Marshal(response) - log.Debugf("Cached response for key %s: %s", key, escapedResponse) - - if err != nil { - handleInternalError(err, "Failed to marshal cached response", log) - return + if stream { + log.Debug("streaming response is not supported for cache hit yet") + stream = false } + // escapedResponse, err := json.Marshal(response) + // log.Debugf("Cached response for key %s: %s", key, escapedResponse) - ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) + // if err != nil { + // handleInternalError(err, "Failed to marshal cached response", log) + // return + // } + log.Debugf("Cached response for key %s: %s", key, response) - contentType := "application/json; charset=utf-8" - if stream { - contentType = "text/event-stream; charset=utf-8" - } + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) - proxywasm.SendHttpResponse(200, [][2]string{{"content-type", contentType}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1) + // proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1) + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, response)), -1) } // performSimilaritySearch determines the appropriate similarity search method to use. @@ -149,7 +149,7 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht } mostSimilarData := results[0] - log.Debugf("Most similar key found: %s with score: %f", mostSimilarData.Text, mostSimilarData.Score) + log.Debugf("For key: %s, the most similar key found: %s with score: %f", key, mostSimilarData.Text, mostSimilarData.Score) simThresholdProvider, ok := config.GetVectorProvider().(vector.SimilarityThresholdProvider) if !ok { diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 80558290a4..4aa04c0f35 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -118,16 +118,21 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, } func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk)) + log.Debugf("[onHttpResponseBody] escaped chunk: %q", string(chunk)) log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk) + // if strings.HasSuffix(string(chunk), "[DONE] \n\n") { + // isLastChunk = true + // } + // If the context contains TOOL_CALLS_CONTEXT_KEY, bypass caching if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { return chunk } - keyI := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) - if keyI == nil { + key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) + if key == nil { + log.Debug("[onHttpResponseBody] key is nil, bypass caching") return chunk } @@ -137,17 +142,24 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } // Handle last chunk - value, err := processFinalChunk(ctx, config, chunk, log) + var value string + var err error + + if len(chunk) > 0 { + value, err = processNonEmptyChunk(ctx, config, chunk, log) + } else { + value, err = processEmptyChunk(ctx, config, chunk, log) + } + if err != nil { - log.Warnf("[onHttpResponseBody] failed to process final chunk: %v", err) + log.Warnf("[onHttpResponseBody] failed to process chunk: %v", err) return chunk } - // Cache the final value - cacheResponse(ctx, config, keyI.(string), value, log) + cacheResponse(ctx, config, key.(string), value, log) // Handle embedding upload if available - uploadEmbeddingAndAnswer(ctx, config, keyI.(string), value, log) + uploadEmbeddingAndAnswer(ctx, config, key.(string), value, log) return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index c13161c129..4459ce427b 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -99,7 +99,7 @@ func savePartialMessage(ctx wrapper.HttpContext, partialMessage []byte, messages } // Processes the final chunk and returns the parsed value or an error -func processFinalChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { +func processNonEmptyChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string @@ -122,6 +122,15 @@ func processFinalChunk(ctx wrapper.HttpContext, config config.PluginConfig, chun return value, nil } +func processEmptyChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + return string(chunk), nil + } + value := tempContentI.(string) + return value, nil +} + // Appends the final body chunk to the existing body content func appendFinalBody(ctx wrapper.HttpContext, chunk []byte) []byte { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index a5ca7e0909..a39e29f551 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -157,6 +157,13 @@ func (d *DvProvider) QueryEmbedding( return err } +func checkField(fields map[string]interface{}, key string) string { + if val, ok := fields[key]; ok { + return val.(string) + } + return "" +} + func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) { resp, err := d.parseQueryResponse(responseBody) if err != nil { @@ -171,9 +178,10 @@ func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpCon for _, output := range resp.Output { result := QueryResult{ - Text: output.Fields["query"].(string), + Text: checkField(output.Fields, "query"), Embedding: output.Vector, Score: output.Score, + Answer: checkField(output.Fields, "answer"), } results = append(results, result) } @@ -234,6 +242,10 @@ func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx return err } +func (d *DvProvider) GetSimilarityThreshold() float64 { + return threshold +} + func (d *DvProvider) UploadAnswerEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer) if err != nil { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 35789e14ca..56e0e0c71d 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -118,6 +118,10 @@ type ProviderConfig struct { // ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"` } +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + func (c *ProviderConfig) FromJson(json gjson.Result) { c.typ = json.Get("type").String() // DashVector From 3583bc9215d83054427c2b735213a4e458269fc4 Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 21 Sep 2024 01:17:42 +0000 Subject: [PATCH 27/71] fix bugs and update logic as requested --- plugins/wasm-go/extensions/ai-cache/core.go | 14 ++++++++++---- .../ai-cache/embedding/dashscope.go | 19 ++++++++++--------- .../extensions/ai-cache/vector/dashvector.go | 2 +- .../extensions/ai-cache/vector/provider.go | 4 ++-- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 1e5d433ca1..028e85cfc5 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -49,9 +49,11 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex if useSimilaritySearch { if err := performSimilaritySearch(key, ctx, config, log, key, stream); err != nil { log.Errorf("Failed to perform similarity search for key: %s, error: %v", key, err) + proxywasm.ResumeHttpRequest() } + } else { + proxywasm.ResumeHttpRequest() } - proxywasm.ResumeHttpRequest() } // processCacheHit handles a successful cache hit. @@ -165,7 +167,11 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht processCacheHit(key, mostSimilarData.Answer, stream, ctx, config, log) } else { // otherwise, continue to check cache for the most similar key - CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) + err = CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) + if err != nil { + log.Errorf("check cache for key: %s failed, error: %v", mostSimilarData.Text, err) + proxywasm.ResumeHttpRequest() + } } } else { log.Infof("Score too high for key: %s with score: %f above threshold", mostSimilarData.Text, mostSimilarData.Score) @@ -220,9 +226,9 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, config config.PluginConfi } // Attempt to upload answer embedding first - if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerEmbeddingUploader); ok { + if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerAndEmbeddingUploader); ok { log.Infof("[onHttpResponseBody] uploading answer embedding for key: %s", key) - err := ansEmbUploader.UploadAnswerEmbedding(key, emb, value, ctx, log, nil) + err := ansEmbUploader.UploadAnswerAndEmbedding(key, emb, value, ctx, log, nil) if err != nil { log.Warnf("[onHttpResponseBody] failed to upload answer embedding for key: %s, error: %v", key, err) } else { diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index fe06752398..eef81ea64c 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -3,6 +3,7 @@ package embedding import ( "encoding/json" "errors" + "fmt" "net/http" "strconv" @@ -10,10 +11,10 @@ import ( ) const ( - DOMAIN = "dashscope.aliyuncs.com" - PORT = 443 - DEFAULT_MODEL_NAME = "text-embedding-v1" - ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding" + DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com" + DASHSCOPE_PORT = 443 + DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v1" + DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding" ) type dashScopeProviderInitializer struct { @@ -28,10 +29,10 @@ func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) err func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { if c.servicePort == 0 { - c.servicePort = PORT + c.servicePort = DASHSCOPE_PORT } if c.serviceDomain == "" { - c.serviceDomain = DOMAIN + c.serviceDomain = DASHSCOPE_DOMAIN } return &DSProvider{ config: c, @@ -95,7 +96,7 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin model := d.config.model if model == "" { - model = DEFAULT_MODEL_NAME + model = DASHSCOPE_DEFAULT_MODEL_NAME } data := EmbeddingRequest{ Model: model, @@ -124,7 +125,7 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin {"Content-Type", "application/json"}, } - return ENDPOINT, headers, requestBody, err + return DASHSCOPE_ENDPOINT, headers, requestBody, err } type Result struct { @@ -168,7 +169,7 @@ func (d *DSProvider) GetEmbedding( resp, err = d.parseTextEmbedding(responseBody) if err != nil { - err = errors.New("failed to parse response: " + err.Error()) + err = fmt.Errorf("failed to parse response: %v", err) callback(nil, err) return } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index a39e29f551..a5fcfcf957 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -246,7 +246,7 @@ func (d *DvProvider) GetSimilarityThreshold() float64 { return threshold } -func (d *DvProvider) UploadAnswerEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { +func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer) if err != nil { return err diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 56e0e0c71d..32031bc467 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -53,8 +53,8 @@ type EmbeddingUploader interface { callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error } -type AnswerEmbeddingUploader interface { - UploadAnswerEmbedding( +type AnswerAndEmbeddingUploader interface { + UploadAnswerAndEmbedding( queryString string, queryEmb []float64, answer string, From c2615834332ce70bc0f29a99019e43bb45d12519 Mon Sep 17 00:00:00 2001 From: suchun Date: Mon, 14 Oct 2024 11:28:17 +0000 Subject: [PATCH 28/71] add cacheKeyStrategy and enableSemanticCache --- plugins/wasm-go/extensions/ai-cache/README.md | 12 ++++++ .../extensions/ai-cache/config/config.go | 39 +++++++++++++++++-- plugins/wasm-go/extensions/ai-cache/core.go | 19 +++++---- plugins/wasm-go/extensions/ai-cache/main.go | 29 +++++++++++--- .../extensions/ai-cache/vector/dashvector.go | 6 +-- 5 files changed, 85 insertions(+), 20 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 6537842013..88bd265341 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -23,6 +23,18 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 ## 配置说明 配置分为 3 个部分:向量数据库(vector);文本向量化接口(embedding);缓存数据库(cache),同时也提供了细粒度的 LLM 请求/响应提取参数配置等。 +## 配置说明 + +首先本插件必须配置向量数据库服务(vector),然后根据向量数据库服务类型来决定是否配置文本向量化接口(embedding)来将问题转换为向量,最后根据缓存服务类型来决定是否配置缓存服务(cache)来缓存LLM的回答。 + +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| vector.type | string | required | "" | 向量存储服务提供者类型,例如 DashVector | +| embedding.type | string | optional | "" | 请求文本向量化服务类型,例如 DashScope | +| cache.type | string | optional | "" | 缓存服务类型,例如 redis | +| cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disable" (禁用缓存) | +| enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用逐字匹配的方式来查找缓存,此时需要配置cache服务 | + ## 向量数据库服务(vector) | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 10aa354e76..249c768527 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -1,6 +1,8 @@ package config import ( + "fmt" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" @@ -24,10 +26,18 @@ type PluginConfig struct { vectorProviderConfig vector.ProviderConfig cacheProviderConfig cache.ProviderConfig - CacheKeyFrom string + // CacheKeyFrom string CacheValueFrom string CacheStreamValueFrom string CacheToolCallsFrom string + + // @Title zh-CN 启用语义化缓存 + // @Description zh-CN 控制是否启用语义化缓存功能。true 表示启用,false 表示禁用。 + EnableSemanticCache bool + + // @Title zh-CN 缓存键策略 + // @Description zh-CN 决定如何生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disable" (禁用缓存) + CacheKeyStrategy string } func (c *PluginConfig) FromJson(json gjson.Result) { @@ -35,10 +45,14 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.embeddingProviderConfig.FromJson(json.Get("embedding")) c.cacheProviderConfig.FromJson(json.Get("cache")) - c.CacheKeyFrom = json.Get("cacheKeyFrom").String() - if c.CacheKeyFrom == "" { - c.CacheKeyFrom = "messages.@reverse.0.content" + c.CacheKeyStrategy = json.Get("cacheKeyStrategy").String() + if c.CacheKeyStrategy == "" { + c.CacheKeyStrategy = "lastQuestion" // 设置默认值 } + // c.CacheKeyFrom = json.Get("cacheKeyFrom").String() + // if c.CacheKeyFrom == "" { + // c.CacheKeyFrom = "messages.@reverse.0.content" + // } c.CacheValueFrom = json.Get("cacheValueFrom").String() if c.CacheValueFrom == "" { c.CacheValueFrom = "choices.0.message.content" @@ -60,6 +74,13 @@ func (c *PluginConfig) FromJson(json gjson.Result) { if c.ResponseTemplate == "" { c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } + + // 默认值为 true + if json.Get("enableSemanticCache").Exists() { + c.EnableSemanticCache = json.Get("enableSemanticCache").Bool() + } else { + c.EnableSemanticCache = true // 设置默认值为 true + } } func (c *PluginConfig) Validate() error { @@ -77,6 +98,16 @@ func (c *PluginConfig) Validate() error { if err := c.vectorProviderConfig.Validate(); err != nil { return err } + // 验证 CacheKeyStrategy 的值 + if c.CacheKeyStrategy != "lastQuestion" && c.CacheKeyStrategy != "allQuestions" && c.CacheKeyStrategy != "disable" { + return fmt.Errorf("invalid CacheKeyStrategy: %s", c.CacheKeyStrategy) + } + // 如果启用了语义化缓存,确保必要的组件已配置 + if c.EnableSemanticCache { + if c.embeddingProviderConfig.GetProviderType() == "" { + return fmt.Errorf("semantic cache is enabled but embedding provider is not configured") + } + } return nil } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 028e85cfc5..9b679a6947 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -46,7 +46,8 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex if err := response.Error(); err != nil { log.Errorf("Error retrieving key: %s from cache, error: %v", key, err) } - if useSimilaritySearch { + + if useSimilaritySearch && config.EnableSemanticCache { if err := performSimilaritySearch(key, ctx, config, log, key, stream); err != nil { log.Errorf("Failed to perform similarity search for key: %s, error: %v", key, err) proxywasm.ResumeHttpRequest() @@ -166,12 +167,16 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht // direct return the answer if available processCacheHit(key, mostSimilarData.Answer, stream, ctx, config, log) } else { - // otherwise, continue to check cache for the most similar key - err = CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) - if err != nil { - log.Errorf("check cache for key: %s failed, error: %v", mostSimilarData.Text, err) - proxywasm.ResumeHttpRequest() - } + // // otherwise, continue to check cache for the most similar key + // err = CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) + // if err != nil { + // log.Errorf("check cache for key: %s failed, error: %v", mostSimilarData.Text, err) + // proxywasm.ResumeHttpRequest() + // } + + // Otherwise, do not check the cache, directly return + log.Warnf("No cache hit for key: %s, however, no answer found in vector database", mostSimilarData.Text) + proxywasm.ResumeHttpRequest() } } else { log.Infof("Score too high for key: %s with score: %f above threshold", mostSimilarData.Text, mostSimilarData.Score) diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 4aa04c0f35..5025ab0e70 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -92,7 +92,29 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) } - key := bodyJson.Get(config.CacheKeyFrom).String() + var key string + if config.CacheKeyStrategy == "lastQuestion" { + key = bodyJson.Get("messages.@reverse.0.content").String() + } else if config.CacheKeyStrategy == "allQuestions" { + // Retrieve all user messages and concatenate them + messages := bodyJson.Get("messages").Array() + var userMessages []string + for _, msg := range messages { + if msg.Get("role").String() == "user" { + userMessages = append(userMessages, msg.Get("content").String()) + } + } + key = strings.Join(userMessages, " ") + } else if config.CacheKeyStrategy == "disable" { + log.Debugf("[onHttpRequestBody] cache key strategy is disabled") + ctx.DontReadRequestBody() + return types.ActionContinue + } else { + log.Warnf("[onHttpRequestBody] unknown cache key strategy: %s", config.CacheKeyStrategy) + ctx.DontReadRequestBody() + return types.ActionContinue + } + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, key) log.Debugf("[onHttpRequestBody] key: %s", key) if key == "" { @@ -121,11 +143,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu log.Debugf("[onHttpResponseBody] escaped chunk: %q", string(chunk)) log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk) - // if strings.HasSuffix(string(chunk), "[DONE] \n\n") { - // isLastChunk = true - // } - - // If the context contains TOOL_CALLS_CONTEXT_KEY, bypass caching if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index a5fcfcf957..3e7114212b 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -157,7 +157,7 @@ func (d *DvProvider) QueryEmbedding( return err } -func checkField(fields map[string]interface{}, key string) string { +func getStringValue(fields map[string]interface{}, key string) string { if val, ok := fields[key]; ok { return val.(string) } @@ -178,10 +178,10 @@ func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpCon for _, output := range resp.Output { result := QueryResult{ - Text: checkField(output.Fields, "query"), + Text: getStringValue(output.Fields, "query"), Embedding: output.Vector, Score: output.Score, - Answer: checkField(output.Fields, "answer"), + Answer: getStringValue(output.Fields, "answer"), } results = append(results, result) } From fa22d63a37b88df098b71b33969923f64e1fb4e1 Mon Sep 17 00:00:00 2001 From: suchun Date: Mon, 14 Oct 2024 11:28:54 +0000 Subject: [PATCH 29/71] add cacheKeyStrategy and enableSemanticCache --- plugins/wasm-go/extensions/ai-cache/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 88bd265341..c614c3bf77 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -25,7 +25,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 ## 配置说明 -首先本插件必须配置向量数据库服务(vector),然后根据向量数据库服务类型来决定是否配置文本向量化接口(embedding)来将问题转换为向量,最后根据缓存服务类型来决定是否配置缓存服务(cache)来缓存LLM的回答。 +本插件需要配置向量数据库服务(vector),根据所选的向量数据库服务类型,您可以决定是否配置文本向量化接口(embedding)以将问题转换为向量。最后,根据所选的缓存服务类型,您可以决定是否配置缓存服务(cache)以存储LLM的响应结果。 | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | @@ -35,6 +35,8 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disable" (禁用缓存) | | enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用逐字匹配的方式来查找缓存,此时需要配置cache服务 | +以下是vector、embedding、cache的具体配置说明,注意若不配置embedding或cache服务,则可忽略以下相应配置中的 `required` 字段。 + ## 向量数据库服务(vector) | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | From 9145132057e61155c14a30d8060e2c938a047e93 Mon Sep 17 00:00:00 2001 From: suchun Date: Mon, 14 Oct 2024 11:48:16 +0000 Subject: [PATCH 30/71] Vector or cache database must be configured --- plugins/wasm-go/extensions/ai-cache/README.md | 8 +++++--- plugins/wasm-go/extensions/ai-cache/config/config.go | 12 ++++++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index c614c3bf77..17b91d8a68 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -25,15 +25,17 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 ## 配置说明 -本插件需要配置向量数据库服务(vector),根据所选的向量数据库服务类型,您可以决定是否配置文本向量化接口(embedding)以将问题转换为向量。最后,根据所选的缓存服务类型,您可以决定是否配置缓存服务(cache)以存储LLM的响应结果。 +本插件同时支持基于向量数据库的语义化缓存和基于字符串匹配的缓存方法,如果同时配置了向量数据库和缓存数据库,优先使用向量数据库。 + +*Note*: 向量数据库(vector) 和 缓存数据库(cache) 不能同时为空,否则本插件无法提供缓存服务。 | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | -| vector.type | string | required | "" | 向量存储服务提供者类型,例如 DashVector | +| vector.type | string | optional | "" | 向量存储服务提供者类型,例如 DashVector | | embedding.type | string | optional | "" | 请求文本向量化服务类型,例如 DashScope | | cache.type | string | optional | "" | 缓存服务类型,例如 redis | | cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disable" (禁用缓存) | -| enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用逐字匹配的方式来查找缓存,此时需要配置cache服务 | +| enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用字符串匹配的方式来查找缓存,此时需要配置cache服务 | 以下是vector、embedding、cache的具体配置说明,注意若不配置embedding或cache服务,则可忽略以下相应配置中的 `required` 字段。 diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 249c768527..ac9a46bdc5 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -95,9 +95,17 @@ func (c *PluginConfig) Validate() error { return err } } - if err := c.vectorProviderConfig.Validate(); err != nil { - return err + if c.vectorProviderConfig.GetProviderType() != "" { + if err := c.vectorProviderConfig.Validate(); err != nil { + return err + } } + + // vector 和 embedding 不能同时为空 + if c.vectorProviderConfig.GetProviderType() == "" && c.embeddingProviderConfig.GetProviderType() == "" { + return fmt.Errorf("vector and embedding provider cannot be both empty") + } + // 验证 CacheKeyStrategy 的值 if c.CacheKeyStrategy != "lastQuestion" && c.CacheKeyStrategy != "allQuestions" && c.CacheKeyStrategy != "disable" { return fmt.Errorf("invalid CacheKeyStrategy: %s", c.CacheKeyStrategy) From ef443bf3492efb5cb234c71c27a500eeafb65c6b Mon Sep 17 00:00:00 2001 From: async Date: Fri, 18 Oct 2024 17:36:49 +0800 Subject: [PATCH 31/71] new version envoy --- docker-compose-test/envoy.yaml | 84 +++++++++++++++++++++------------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml index 90211dd5e1..411f3b04fb 100644 --- a/docker-compose-test/envoy.yaml +++ b/docker-compose-test/envoy.yaml @@ -34,7 +34,7 @@ static_resources: - match: prefix: "/" route: - cluster: llm + cluster: outbound|443||bigmodel.dns timeout: 300s http_filters: @@ -55,39 +55,33 @@ static_resources: "@type": "type.googleapis.com/google.protobuf.StringValue" value: | { - "embeddingProvider": { + "embedding": { "type": "dashscope", "serviceName": "dashscope", - "apiKey": "sk-", - "DashScopeServiceName": "dashscope" + "apiKey": "sk-346fadc4ed8448e487ea84b57788d816" }, - "vectorProvider": { - "VectorStoreProviderType": "pinecone", - "PineconeServiceName": "pinecone", - "PineconeApiEndpoint": "higress-2bdfipe.svc.aped-4627-b74a.pinecone.io", - "PineconeThreshold": "0.7", - "ThresholdRelation": "gte", - "PineconeApiKey": "key" + "vector": { + "type": "chroma", + "serviceName": "chroma", + "collectionID": "8f8accc8-b68c-4fb1-84af-0292f9236e8a", + "servicePort": 8001 }, - "cacheKeyFrom": { - "requestBody": "" - }, - "cacheValueFrom": { - "responseBody": "" - }, - "cacheStreamValueFrom": { - "responseBody": "" - }, - "returnResponseTemplate": "", - "returnTestResponseTemplate": "", - "ReturnStreamResponseTemplate": "", - "redis": { + "cache": { + "type": "redis", "serviceName": "redis_cluster", - "timeout": 2000 + "timeout": 1000 } } # 上面的配置中 redis 的配置名字是 redis,而不是 golang tag 中的 redisConfig + # "vector": { + # "type": "dashvector", + # "serviceName": "dashvector", + # "collectionID": "higress_euclidean", + # "serviceDomain": "vrs-cn-g6z3yq2wy0001z.dashvector.cn-hangzhou.aliyuncs.com", + # "apiKey": "sk-nAn4GfZFrbLNhGffVIIq6tdgWNjV7D8A0F7CC5E1011EF9A1EB61E393DC850" + # }, + # "vectorProvider": { # "VectorStoreProviderType": "chroma", # "ChromaServiceName": "chroma", @@ -110,6 +104,15 @@ static_resources: # "WeaviateCollection": "Higress", # "WeaviateThreshold": "0.3" # }, + + # "vector": { + # "type": "pinecone", + # "PineconeServiceName": "pinecone", + # "PineconeApiEndpoint": "higress-2bdfipe.svc.aped-4627-b74a.pinecone.io", + # "PineconeThreshold": "0.7", + # "ThresholdRelation": "gte", + # "PineconeApiKey": "key" + # }, # llm-proxy - name: llm-proxy typed_config: @@ -128,11 +131,10 @@ static_resources: value: | # 插件配置 { "provider": { - "type": "openai", + "type": "zhipuai", "apiTokens": [ - "YOUR_API_TOKEN" - ], - "openaiCustomUrl": "172.17.0.1:8000/v1/chat/completions" + "67e93d524df46fca3640df67a7461c04.qOksqKAoWHcv03aV" + ] } } @@ -259,13 +261,13 @@ static_resources: - endpoint: address: socket_address: - address: vrs-cn-0dw3vnaqs0002z.dashvector.cn-hangzhou.aliyuncs.com + address: vrs-cn-g6z3yq2wy0001z.dashvector.cn-hangzhou.aliyuncs.com port_value: 443 transport_socket: name: envoy.transport_sockets.tls typed_config: "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext - "sni": "vrs-cn-0dw3vnaqs0002z.dashvector.cn-hangzhou.aliyuncs.com" + "sni": "vrs-cn-g6z3yq2wy0001z.dashvector.cn-hangzhou.aliyuncs.com" # dashscope - name: outbound|443||dashscope.dns connect_timeout: 30s @@ -286,6 +288,26 @@ static_resources: typed_config: "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext "sni": "dashscope.aliyuncs.com" + # bigmodel + - name: outbound|443||bigmodel.dns + connect_timeout: 30s + type: LOGICAL_DNS + dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: outbound|443||bigmodel.dns + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: open.bigmodel.cn + port_value: 443 + transport_socket: + name: envoy.transport_sockets.tls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + "sni": "open.bigmodel.cn" # pinecone - name: outbound|443||pinecone.dns connect_timeout: 30s From 14a2a3d4ac3826178d815c2f87de795ab475a995 Mon Sep 17 00:00:00 2001 From: async Date: Fri, 18 Oct 2024 18:31:15 +0800 Subject: [PATCH 32/71] fix: GetContext type --- plugins/wasm-go/extensions/ai-cache/README.md | 2 +- plugins/wasm-go/extensions/ai-cache/util.go | 7 +++++-- plugins/wasm-go/extensions/ai-cache/vector/dashvector.go | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 17b91d8a68..fcd0faab05 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -57,7 +57,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | --- | --- | --- | --- | --- | | embedding.type | string | required | "" | 请求文本向量化服务类型,例如 DashScope | | embedding.serviceName | string | required | "" | 请求文本向量化服务名称 | -| embedding.serviceDomain | string | required | "" | 请求文本向量化服务域名 | +| embedding.serviceDomain | string | optional | "" | 请求文本向量化服务域名 | | embedding.servicePort | int64 | optional | 443 | 请求文本向量化服务端口 | | embedding.apiKey | string | optional | "" | 请求文本向量化服务的 API Key | | embedding.timeout | uint32 | optional | 10000 | 请求文本向量化服务的超时时间,单位为毫秒。默认值是10000,即10秒 | diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 4459ce427b..64f3d6dce9 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -127,8 +127,11 @@ func processEmptyChunk(ctx wrapper.HttpContext, config config.PluginConfig, chun if tempContentI == nil { return string(chunk), nil } - value := tempContentI.(string) - return value, nil + value, ok := tempContentI.([]byte) + if !ok { + return "", fmt.Errorf("invalid type for tempContentI") + } + return string(value), nil } // Appends the final body chunk to the existing body content diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 3e7114212b..8b5848a119 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -27,7 +27,7 @@ func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) er return errors.New("[DashVector] serviceName is required") } if len(config.serviceDomain) == 0 { - return errors.New("[DashVector] endPoint is required") + return errors.New("[DashVector] serviceDomain is required") } return nil } From b862ef948b575d6b5a33a289e9d32b8d9b63d951 Mon Sep 17 00:00:00 2001 From: async Date: Fri, 18 Oct 2024 18:37:21 +0800 Subject: [PATCH 33/71] feat: chroma --- .../extensions/ai-cache/vector/chroma.go | 389 ++++++++++-------- .../extensions/ai-cache/vector/provider.go | 35 +- 2 files changed, 209 insertions(+), 215 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index b3c3fc42d6..99148ad63b 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -1,184 +1,209 @@ package vector -// import ( -// "encoding/json" -// "errors" -// "fmt" -// "net/http" - -// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -// ) - -// type chromaProviderInitializer struct{} - -// const chromaPort = 8001 - -// func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error { -// if len(config.ChromaCollectionID) == 0 { -// return errors.New("ChromaCollectionID is required") -// } -// if len(config.ChromaServiceName) == 0 { -// return errors.New("ChromaServiceName is required") -// } -// return nil -// } - -// func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { -// return &ChromaProvider{ -// config: config, -// client: wrapper.NewClusterClient(wrapper.DnsCluster{ -// ServiceName: config.ChromaServiceName, -// Port: chromaPort, -// Domain: config.ChromaServiceName, -// }), -// }, nil -// } - -// type ChromaProvider struct { -// config ProviderConfig -// client wrapper.HttpClient -// } - -// func (c *ChromaProvider) GetProviderType() string { -// return providerTypeChroma -// } - -// func (d *ChromaProvider) GetThreshold() float64 { -// return d.config.ChromaDistanceThreshold -// } - -// func (d *ChromaProvider) QueryEmbedding( -// emb []float64, -// ctx wrapper.HttpContext, -// log wrapper.Log, -// callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { -// // 最小需要填写的参数为 collection_id, embeddings 和 ids -// // 下面是一个例子 -// // { -// // "where": {}, 用于 metadata 过滤,可选参数 -// // "where_document": {}, 用于 document 过滤,可选参数 -// // "query_embeddings": [ -// // [1.1, 2.3, 3.2] -// // ], -// // "n_results": 5, -// // "include": [ -// // "metadatas", -// // "distances" -// // ] -// // } -// requestBody, err := json.Marshal(ChromaQueryRequest{ -// QueryEmbeddings: []ChromaEmbedding{emb}, -// NResults: d.config.ChromaNResult, -// Include: []string{"distances"}, -// }) - -// if err != nil { -// log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err) -// return -// } - -// d.client.Post( -// fmt.Sprintf("/api/v1/collections/%s/query", d.config.ChromaCollectionID), -// [][2]string{ -// {"Content-Type", "application/json"}, -// }, -// requestBody, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("Query embedding response: %d, %s", statusCode, responseBody) -// callback(responseBody, ctx, log) -// }, -// d.config.ChromaTimeout, -// ) -// } - -// func (d *ChromaProvider) UploadEmbedding( -// query_emb []float64, -// queryString string, -// ctx wrapper.HttpContext, -// log wrapper.Log, -// callback func(ctx wrapper.HttpContext, log wrapper.Log)) { -// // 最小需要填写的参数为 collection_id, embeddings 和 ids -// // 下面是一个例子 -// // { -// // "embeddings": [ -// // [1.1, 2.3, 3.2] -// // ], -// // "ids": [ -// // "你吃了吗?" -// // ] -// // } -// requestBody, err := json.Marshal(ChromaInsertRequest{ -// Embeddings: []ChromaEmbedding{query_emb}, -// IDs: []string{queryString}, // queryString 指的是用户查询的问题 -// }) - -// if err != nil { -// log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err) -// return -// } - -// d.client.Post( -// fmt.Sprintf("/api/v1/collections/%s/add", d.config.ChromaCollectionID), -// [][2]string{ -// {"Content-Type", "application/json"}, -// }, -// requestBody, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) -// callback(ctx, log) -// }, -// d.config.ChromaTimeout, -// ) -// } - -// // ChromaEmbedding represents the embedding vector for a data point. -// type ChromaEmbedding []float64 - -// // ChromaMetadataMap is a map from key to value for metadata. -// type ChromaMetadataMap map[string]string - -// // Dataset represents the entire dataset containing multiple data points. -// type ChromaInsertRequest struct { -// Embeddings []ChromaEmbedding `json:"embeddings"` -// Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional metadata map array -// Documents []string `json:"documents,omitempty"` // Optional document array -// IDs []string `json:"ids"` -// } - -// // ChromaQueryRequest represents the query request structure. -// type ChromaQueryRequest struct { -// Where map[string]string `json:"where,omitempty"` // Optional where filter -// WhereDocument map[string]string `json:"where_document,omitempty"` // Optional where_document filter -// QueryEmbeddings []ChromaEmbedding `json:"query_embeddings"` -// NResults int `json:"n_results"` -// Include []string `json:"include"` -// } - -// // ChromaQueryResponse represents the search result structure. -// type ChromaQueryResponse struct { -// Ids [][]string `json:"ids"` // 每一个 embedding 相似的 key 可能会有多个,然后会有多个 embedding,所以是一个二维数组 -// Distances [][]float64 `json:"distances"` // 与 Ids 一一对应 -// Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional, can be null -// Embeddings []ChromaEmbedding `json:"embeddings,omitempty"` // Optional, can be null -// Documents []string `json:"documents,omitempty"` // Optional, can be null -// Uris []string `json:"uris,omitempty"` // Optional, can be null -// Data []interface{} `json:"data,omitempty"` // Optional, can be null -// Included []string `json:"included"` -// } - -// func (d *ChromaProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { -// var queryResp ChromaQueryResponse -// err := json.Unmarshal(responseBody, &queryResp) -// if err != nil { -// return QueryEmbeddingResult{}, err -// } -// log.Infof("[Chroma] queryResp: %+v", queryResp) -// log.Infof("[Chroma] queryResp Ids len: %d", len(queryResp.Ids)) -// if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 { -// return QueryEmbeddingResult{}, nil -// } -// return QueryEmbeddingResult{ -// MostSimilarData: queryResp.Ids[0][0], -// Score: queryResp.Distances[0][0], -// }, nil -// } +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type chromaProviderInitializer struct{} + +const chromaThreshold = 1000 + +func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.collectionID) == 0 { + return errors.New("[Chroma] collectionID is required") + } + if len(config.serviceName) == 0 { + return errors.New("[Chroma] serviceName is required") + } + return nil +} + +func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &ChromaProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.serviceName, + Port: config.servicePort, + Domain: config.serviceDomain, + }), + }, nil +} + +type ChromaProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *ChromaProvider) GetProviderType() string { + return providerTypeChroma +} + +func (d *ChromaProvider) GetSimilarityThreshold() float64 { + return chromaThreshold +} + +func (d *ChromaProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 collection_id, embeddings 和 ids + // 下面是一个例子 + // { + // "where": {}, // 用于 metadata 过滤,可选参数 + // "where_document": {}, // 用于 document 过滤,可选参数 + // "query_embeddings": [ + // [1.1, 2.3, 3.2] + // ], + // "limit": 5, + // "include": [ + // "metadatas", // 可选 + // "documents", // 如果需要答案则需要 + // "distances" + // ] + // } + + requestBody, err := json.Marshal(chromaQueryRequest{ + QueryEmbeddings: []chromaEmbedding{emb}, + Limit: d.config.topK, + Include: []string{"distances", "documents"}, + }) + + if err != nil { + log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err) + return err + } + + d.client.Post( + fmt.Sprintf("/api/v1/collections/%s/query", d.config.collectionID), + [][2]string{ + {"Content-Type", "application/json"}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Chroma] Query embedding response: %d, %s", statusCode, responseBody) + results, err := d.parseQueryResponse(responseBody, ctx, log) + if err != nil { + err = fmt.Errorf("[Chroma] Failed to parse query response: %v", err) + } + callback(results, ctx, log, err) + }, + d.config.timeout, + ) + return errors.New("[Chroma] QueryEmbedding Not implemented") +} + +func (d *ChromaProvider) UploadAnswerAndEmbedding( + queryString string, + queryEmb []float64, + queryAnswer string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 collection_id, embeddings 和 ids + // 下面是一个例子 + // { + // "embeddings": [ + // [1.1, 2.3, 3.2] + // ], + // "ids": [ + // "你吃了吗?" + // ], + // "documents": [ + // "我吃了。" + // ] + // } + // 如果要添加 answer,则按照以下例子 + // { + // "embeddings": [ + // [1.1, 2.3, 3.2] + // ], + // "documents": [ + // "answer1" + // ], + // "ids": [ + // "id1" + // ] + // } + requestBody, err := json.Marshal(chromaInsertRequest{ + Embeddings: []chromaEmbedding{queryEmb}, + IDs: []string{queryString}, // queryString 指的是用户查询的问题 + Documents: []string{queryAnswer}, // queryAnswer 指的是用户查询的问题的答案 + }) + + if err != nil { + log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err) + return err + } + + err = d.client.Post( + fmt.Sprintf("/api/v1/collections/%s/add", d.config.collectionID), + [][2]string{ + {"Content-Type", "application/json"}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log, err) + }, + d.config.timeout, + ) + return err +} + +type chromaEmbedding []float64 +type chromaMetadataMap map[string]string +type chromaInsertRequest struct { + Embeddings []chromaEmbedding `json:"embeddings"` + Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // Optional metadata map array + Documents []string `json:"documents,omitempty"` // Optional document array + IDs []string `json:"ids"` +} + +type chromaQueryRequest struct { + Where map[string]string `json:"where,omitempty"` // Optional where filter + WhereDocument map[string]string `json:"where_document,omitempty"` // Optional where_document filter + QueryEmbeddings []chromaEmbedding `json:"query_embeddings"` + Limit int `json:"limit"` + Include []string `json:"include"` +} + +type chromaQueryResponse struct { + Ids [][]string `json:"ids"` // 第一维是 batch query,第二维是查询到的多个 ids + Distances [][]float64 `json:"distances,omitempty"` // 与 Ids 一一对应 + Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // Optional, can be null + Embeddings []chromaEmbedding `json:"embeddings,omitempty"` // Optional, can be null + Documents [][]string `json:"documents,omitempty"` // 与 Ids 一一对应 + Uris []string `json:"uris,omitempty"` // Optional, can be null + Data []interface{} `json:"data,omitempty"` // Optional, can be null + Included []string `json:"included"` +} + +func (d *ChromaProvider) parseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) { + log.Infof("[Chroma] queryResp: %s", string(responseBody)) + var queryResp chromaQueryResponse + err := json.Unmarshal(responseBody, &queryResp) + if err != nil { + return nil, err + } + + log.Infof("[Chroma] queryResp Ids len: %d", len(queryResp.Ids)) + if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 { + return nil, errors.New("no query results found in response") + } + results := make([]QueryResult, 0, len(queryResp.Ids[0])) + for i := range queryResp.Ids[0] { + result := QueryResult{ + Text: queryResp.Documents[0][i], + Score: queryResp.Distances[0][i], + Answer: queryResp.Documents[0][i], + } + results = append(results, result) + } + return results, nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 32031bc467..d52ade72f8 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -20,11 +20,11 @@ type providerInitializer interface { var ( providerInitializers = map[string]providerInitializer{ providerTypeDashVector: &dashVectorProviderInitializer{}, - // providerTypeChroma: &chromaProviderInitializer{}, + providerTypeChroma: &chromaProviderInitializer{}, } ) -// QueryEmbeddingResult 定义通用的查询结果的结构体 +// QueryResult 定义通用的查询结果的结构体 type QueryResult struct { Text string // 相似的文本 Embedding []float64 // 相似文本的向量 @@ -100,22 +100,6 @@ type ProviderConfig struct { // @Title zh-CN DashVector 向量存储服务 Collection ID // @Description zh-CN DashVector 向量存储服务 Collection ID collectionID string - - // // @Title zh-CN Chroma 的上游服务名称 - // // @Description zh-CN Chroma 服务所对应的网关内上游服务名称 - // ChromaServiceName string `require:"true" yaml:"ChromaServiceName" json:"ChromaServiceName"` - // // @Title zh-CN Chroma Collection ID - // // @Description zh-CN Chroma Collection 的 ID - // ChromaCollectionID string `require:"false" yaml:"ChromaCollectionID" json:"ChromaCollectionID"` - // @Title zh-CN Chroma 距离阈值 - // @Description zh-CN Chroma 距离阈值,默认为 2000 - ChromaDistanceThreshold float64 `require:"false" yaml:"ChromaDistanceThreshold" json:"ChromaDistanceThreshold"` - // // @Title zh-CN Chroma 搜索返回结果数量 - // // @Description zh-CN Chroma 搜索返回结果数量,默认为 1 - // ChromaNResult int `require:"false" yaml:"ChromaNResult" json:"ChromaNResult"` - // // @Title zh-CN Chroma 超时设置 - // // @Description zh-CN Chroma 超时设置,默认为 10 秒 - // ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"` } func (c *ProviderConfig) GetProviderType() string { @@ -141,21 +125,6 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.timeout == 0 { c.timeout = 10000 } - // Chroma - // c.ChromaCollectionID = json.Get("ChromaCollectionID").String() - // c.ChromaServiceName = json.Get("ChromaServiceName").String() - // c.ChromaDistanceThreshold = json.Get("ChromaDistanceThreshold").Float() - // if c.ChromaDistanceThreshold == 0 { - // c.ChromaDistanceThreshold = 2000 - // } - // c.ChromaNResult = int(json.Get("ChromaNResult").Int()) - // if c.ChromaNResult == 0 { - // c.ChromaNResult = 1 - // } - // c.ChromaTimeout = uint32(json.Get("ChromaTimeout").Int()) - // if c.ChromaTimeout == 0 { - // c.ChromaTimeout = 10000 - // } } func (c *ProviderConfig) Validate() error { From 303f6edeb909d71f157e092352f2e9e58dcdcd4e Mon Sep 17 00:00:00 2001 From: async Date: Fri, 18 Oct 2024 22:28:35 +0800 Subject: [PATCH 34/71] feat: weaviate --- .../extensions/ai-cache/vector/chroma.go | 5 +- .../extensions/ai-cache/vector/provider.go | 4 + .../extensions/ai-cache/vector/weaviate.go | 359 +++++++++--------- 3 files changed, 196 insertions(+), 172 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index 99148ad63b..5978bced72 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -87,7 +87,7 @@ func (d *ChromaProvider) QueryEmbedding( requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { log.Infof("[Chroma] Query embedding response: %d, %s", statusCode, responseBody) - results, err := d.parseQueryResponse(responseBody, ctx, log) + results, err := d.parseQueryResponse(responseBody, log) if err != nil { err = fmt.Errorf("[Chroma] Failed to parse query response: %v", err) } @@ -184,8 +184,7 @@ type chromaQueryResponse struct { Included []string `json:"included"` } -func (d *ChromaProvider) parseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) { - log.Infof("[Chroma] queryResp: %s", string(responseBody)) +func (d *ChromaProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) { var queryResp chromaQueryResponse err := json.Unmarshal(responseBody, &queryResp) if err != nil { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index d52ade72f8..6aa6f68f14 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -10,6 +10,9 @@ import ( const ( providerTypeDashVector = "dashvector" providerTypeChroma = "chroma" + providerTypeES = "elasticsearch" + providerTypeWeaviate = "weaviate" + providerTypePinecone = "pinecone" ) type providerInitializer interface { @@ -21,6 +24,7 @@ var ( providerInitializers = map[string]providerInitializer{ providerTypeDashVector: &dashVectorProviderInitializer{}, providerTypeChroma: &chromaProviderInitializer{}, + providerTypeWeaviate: &weaviateProviderInitializer{}, } ) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go index 0b361a6598..25b6908e9a 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go @@ -1,171 +1,192 @@ package vector -// import ( -// "encoding/json" -// "errors" -// "fmt" -// "net/http" - -// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -// ) - -// const ( -// dashVectorPort = 443 -// ) - -// type dashVectorProviderInitializer struct { -// } - -// func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { -// if len(config.DashVectorKey) == 0 { -// return errors.New("DashVectorKey is required") -// } -// if len(config.DashVectorAuthApiEnd) == 0 { -// return errors.New("DashVectorEnd is required") -// } -// if len(config.DashVectorCollection) == 0 { -// return errors.New("DashVectorCollection is required") -// } -// if len(config.DashVectorServiceName) == 0 { -// return errors.New("DashVectorServiceName is required") -// } -// return nil -// } - -// func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { -// return &DvProvider{ -// config: config, -// client: wrapper.NewClusterClient(wrapper.DnsCluster{ -// ServiceName: config.DashVectorServiceName, -// Port: dashVectorPort, -// Domain: config.DashVectorAuthApiEnd, -// }), -// }, nil -// } - -// type DvProvider struct { -// config ProviderConfig -// client wrapper.HttpClient -// } - -// func (d *DvProvider) GetProviderType() string { -// return providerTypeDashVector -// } - -// type EmbeddingRequest struct { -// Model string `json:"model"` -// Input Input `json:"input"` -// Parameters Params `json:"parameters"` -// } - -// type Params struct { -// TextType string `json:"text_type"` -// } - -// type Input struct { -// Texts []string `json:"texts"` -// } - -// func (d *DvProvider) ConstructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { -// url := fmt.Sprintf("/v1/collections/%s/query", d.config.DashVectorCollection) - -// requestData := QueryRequest{ -// Vector: vector, -// TopK: d.config.DashVectorTopK, -// IncludeVector: false, -// } - -// requestBody, err := json.Marshal(requestData) -// if err != nil { -// return "", nil, nil, err -// } - -// header := [][2]string{ -// {"Content-Type", "application/json"}, -// {"dashvector-auth-token", d.config.DashVectorKey}, -// } - -// return url, requestBody, header, nil -// } - -// func (d *DvProvider) ParseQueryResponse(responseBody []byte) (QueryResponse, error) { -// var queryResp QueryResponse -// err := json.Unmarshal(responseBody, &queryResp) -// if err != nil { -// return QueryResponse{}, err -// } -// return queryResp, nil -// } - -// func (d *DvProvider) QueryEmbedding( -// queryEmb []float64, -// ctx wrapper.HttpContext, -// log wrapper.Log, -// callback func(query_resp QueryResponse, ctx wrapper.HttpContext, log wrapper.Log)) { - -// // 构造请求参数 -// url, body, headers, err := d.ConstructEmbeddingQueryParameters(queryEmb) -// if err != nil { -// log.Infof("Failed to construct embedding query parameters: %v", err) -// } - -// err = d.client.Post(url, headers, body, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("Query embedding response: %d, %s", statusCode, responseBody) -// query_resp, err_query := d.ParseQueryResponse(responseBody) -// if err_query != nil { -// log.Infof("Failed to parse response: %v", err_query) -// } -// callback(query_resp, ctx, log) -// }, -// d.config.DashVectorTimeout) -// if err != nil { -// log.Infof("Failed to query embedding: %v", err) -// } - -// } - -// type Document struct { -// Vector []float64 `json:"vector"` -// Fields map[string]string `json:"fields"` -// } - -// type InsertRequest struct { -// Docs []Document `json:"docs"` -// } - -// func (d *DvProvider) ConstructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { -// url := "/v1/collections/" + d.config.DashVectorCollection + "/docs" - -// doc := Document{ -// Vector: emb, -// Fields: map[string]string{ -// "query": query_string, -// }, -// } - -// requestBody, err := json.Marshal(InsertRequest{Docs: []Document{doc}}) -// if err != nil { -// return "", nil, nil, err -// } - -// header := [][2]string{ -// {"Content-Type", "application/json"}, -// {"dashvector-auth-token", d.config.DashVectorKey}, -// } - -// return url, requestBody, header, err -// } - -// func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { -// url, body, headers, _ := d.ConstructEmbeddingUploadParameters(query_emb, queryString) -// d.client.Post( -// url, -// headers, -// body, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) -// callback(ctx, log) -// }, -// d.config.DashVectorTimeout) -// } +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +type weaviateProviderInitializer struct{} + +const weaviateThreshold = 0.5 + +func (c *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.collectionID) == 0 { + return errors.New("[Weaviate] collectionID is required") + } + if len(config.serviceName) == 0 { + return errors.New("[Weaviate] serviceName is required") + } + return nil +} + +func (c *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &WeaviateProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.serviceName, + Port: config.servicePort, + Domain: config.serviceDomain, + }), + }, nil +} + +type WeaviateProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *WeaviateProvider) GetProviderType() string { + return providerTypeWeaviate +} + +func (d *WeaviateProvider) GetSimilarityThreshold() float64 { + return weaviateThreshold +} + +func (d *WeaviateProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 class, vector + // 下面是一个例子 + // {"query": "{ Get { Higress ( limit: 2 nearVector: { vector: [0.1, 0.2, 0.3] } ) { question _additional { distance } } } }"} + embString, err := json.Marshal(emb) + if err != nil { + log.Errorf("[Weaviate] Failed to marshal query embedding: %v", err) + return err + } + // 这里默认按照 distance 进行升序,所以不用再次排序 + graphql := fmt.Sprintf(` + { + Get { + %s ( + limit: %d + nearVector: { + vector: %s + } + ) { + question + answer + _additional { + distance + } + } + } + } + `, d.config.collectionID, d.config.topK, embString) + + requestBody, err := json.Marshal(weaviateQueryRequest{ + Query: graphql, + }) + + if err != nil { + log.Errorf("[Weaviate] Failed to marshal query embedding request body: %v", err) + return err + } + + err = d.client.Post( + "/v1/graphql", + [][2]string{ + {"Content-Type", "application/json"}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Weaviate] Query embedding response: %d, %s", statusCode, responseBody) + results, err := d.parseQueryResponse(responseBody, log) + if err != nil { + err = fmt.Errorf("[Weaviate] Failed to parse query response: %v", err) + } + callback(results, ctx, log, err) + }, + d.config.timeout, + ) + return err +} + +func (d *WeaviateProvider) UploadAnswerAndEmbedding( + queryString string, + queryEmb []float64, + queryAnswer string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 class, vector 和 question 和 answer + // 下面是一个例子 + // {"class": "Higress", "vector": [0.1, 0.2, 0.3], "properties": {"question": "这里是问题", "answer": "这里是答案"}} + requestBody, err := json.Marshal(weaviateInsertRequest{ + Class: d.config.collectionID, + Vector: queryEmb, + Properties: weaviateProperties{Question: queryString, Answer: queryAnswer}, // queryString 指的是用户查询的问题 + }) + + if err != nil { + log.Errorf("[Weaviate] Failed to marshal upload embedding request body: %v", err) + return err + } + + return d.client.Post( + "/v1/objects", + [][2]string{ + {"Content-Type", "application/json"}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Weaviate] statusCode: %d, responseBody: %s", statusCode, string(responseBody)) + callback(ctx, log, err) + }, + d.config.timeout, + ) +} + +type weaviateProperties struct { + Question string `json:"question"` + Answer string `json:"answer"` +} + +type weaviateInsertRequest struct { + Class string `json:"class"` + Vector []float64 `json:"vector"` + Properties weaviateProperties `json:"properties"` +} + +type weaviateQueryRequest struct { + Query string `json:"query"` +} + +func (d *WeaviateProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) { + log.Infof("[Weaviate] queryResp: %s", string(responseBody)) + + if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0._additional.distance", d.config.collectionID)).Exists() { + log.Errorf("[Weaviate] No distance found in response body: %s", responseBody) + return nil, errors.New("[Weaviate] No distance found in response body") + } + + if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.question", d.config.collectionID)).Exists() { + log.Errorf("[Weaviate] No question found in response body: %s", responseBody) + return nil, errors.New("[Weaviate] No question found in response body") + } + + if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.answer", d.config.collectionID)).Exists() { + log.Errorf("[Weaviate] No answer found in response body: %s", responseBody) + return nil, errors.New("[Weaviate] No answer found in response body") + } + + resultNum := gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.#", d.config.collectionID)).Int() + results := make([]QueryResult, 0, resultNum) + for i := 0; i < int(resultNum); i++ { + result := QueryResult{ + Text: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d.question", d.config.collectionID, i)).String(), + Score: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d._additional.distance", d.config.collectionID, i)).Float(), + Answer: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d.answer", d.config.collectionID, i)).String(), + } + results = append(results, result) + } + + return results, nil +} From fb2c26c358932425fb25b5e99f98fb912823d58a Mon Sep 17 00:00:00 2001 From: async Date: Fri, 18 Oct 2024 22:31:54 +0800 Subject: [PATCH 35/71] fix: clean useless code --- .../extensions/ai-cache/vector/chroma.go | 184 ------------------ .../extensions/ai-cache/vector/provider.go | 33 +--- .../extensions/ai-cache/vector/weaviate.go | 171 ---------------- 3 files changed, 1 insertion(+), 387 deletions(-) delete mode 100644 plugins/wasm-go/extensions/ai-cache/vector/chroma.go delete mode 100644 plugins/wasm-go/extensions/ai-cache/vector/weaviate.go diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go deleted file mode 100644 index b3c3fc42d6..0000000000 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ /dev/null @@ -1,184 +0,0 @@ -package vector - -// import ( -// "encoding/json" -// "errors" -// "fmt" -// "net/http" - -// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -// ) - -// type chromaProviderInitializer struct{} - -// const chromaPort = 8001 - -// func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error { -// if len(config.ChromaCollectionID) == 0 { -// return errors.New("ChromaCollectionID is required") -// } -// if len(config.ChromaServiceName) == 0 { -// return errors.New("ChromaServiceName is required") -// } -// return nil -// } - -// func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { -// return &ChromaProvider{ -// config: config, -// client: wrapper.NewClusterClient(wrapper.DnsCluster{ -// ServiceName: config.ChromaServiceName, -// Port: chromaPort, -// Domain: config.ChromaServiceName, -// }), -// }, nil -// } - -// type ChromaProvider struct { -// config ProviderConfig -// client wrapper.HttpClient -// } - -// func (c *ChromaProvider) GetProviderType() string { -// return providerTypeChroma -// } - -// func (d *ChromaProvider) GetThreshold() float64 { -// return d.config.ChromaDistanceThreshold -// } - -// func (d *ChromaProvider) QueryEmbedding( -// emb []float64, -// ctx wrapper.HttpContext, -// log wrapper.Log, -// callback func(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log)) { -// // 最小需要填写的参数为 collection_id, embeddings 和 ids -// // 下面是一个例子 -// // { -// // "where": {}, 用于 metadata 过滤,可选参数 -// // "where_document": {}, 用于 document 过滤,可选参数 -// // "query_embeddings": [ -// // [1.1, 2.3, 3.2] -// // ], -// // "n_results": 5, -// // "include": [ -// // "metadatas", -// // "distances" -// // ] -// // } -// requestBody, err := json.Marshal(ChromaQueryRequest{ -// QueryEmbeddings: []ChromaEmbedding{emb}, -// NResults: d.config.ChromaNResult, -// Include: []string{"distances"}, -// }) - -// if err != nil { -// log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err) -// return -// } - -// d.client.Post( -// fmt.Sprintf("/api/v1/collections/%s/query", d.config.ChromaCollectionID), -// [][2]string{ -// {"Content-Type", "application/json"}, -// }, -// requestBody, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("Query embedding response: %d, %s", statusCode, responseBody) -// callback(responseBody, ctx, log) -// }, -// d.config.ChromaTimeout, -// ) -// } - -// func (d *ChromaProvider) UploadEmbedding( -// query_emb []float64, -// queryString string, -// ctx wrapper.HttpContext, -// log wrapper.Log, -// callback func(ctx wrapper.HttpContext, log wrapper.Log)) { -// // 最小需要填写的参数为 collection_id, embeddings 和 ids -// // 下面是一个例子 -// // { -// // "embeddings": [ -// // [1.1, 2.3, 3.2] -// // ], -// // "ids": [ -// // "你吃了吗?" -// // ] -// // } -// requestBody, err := json.Marshal(ChromaInsertRequest{ -// Embeddings: []ChromaEmbedding{query_emb}, -// IDs: []string{queryString}, // queryString 指的是用户查询的问题 -// }) - -// if err != nil { -// log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err) -// return -// } - -// d.client.Post( -// fmt.Sprintf("/api/v1/collections/%s/add", d.config.ChromaCollectionID), -// [][2]string{ -// {"Content-Type", "application/json"}, -// }, -// requestBody, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) -// callback(ctx, log) -// }, -// d.config.ChromaTimeout, -// ) -// } - -// // ChromaEmbedding represents the embedding vector for a data point. -// type ChromaEmbedding []float64 - -// // ChromaMetadataMap is a map from key to value for metadata. -// type ChromaMetadataMap map[string]string - -// // Dataset represents the entire dataset containing multiple data points. -// type ChromaInsertRequest struct { -// Embeddings []ChromaEmbedding `json:"embeddings"` -// Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional metadata map array -// Documents []string `json:"documents,omitempty"` // Optional document array -// IDs []string `json:"ids"` -// } - -// // ChromaQueryRequest represents the query request structure. -// type ChromaQueryRequest struct { -// Where map[string]string `json:"where,omitempty"` // Optional where filter -// WhereDocument map[string]string `json:"where_document,omitempty"` // Optional where_document filter -// QueryEmbeddings []ChromaEmbedding `json:"query_embeddings"` -// NResults int `json:"n_results"` -// Include []string `json:"include"` -// } - -// // ChromaQueryResponse represents the search result structure. -// type ChromaQueryResponse struct { -// Ids [][]string `json:"ids"` // 每一个 embedding 相似的 key 可能会有多个,然后会有多个 embedding,所以是一个二维数组 -// Distances [][]float64 `json:"distances"` // 与 Ids 一一对应 -// Metadatas []ChromaMetadataMap `json:"metadatas,omitempty"` // Optional, can be null -// Embeddings []ChromaEmbedding `json:"embeddings,omitempty"` // Optional, can be null -// Documents []string `json:"documents,omitempty"` // Optional, can be null -// Uris []string `json:"uris,omitempty"` // Optional, can be null -// Data []interface{} `json:"data,omitempty"` // Optional, can be null -// Included []string `json:"included"` -// } - -// func (d *ChromaProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) (QueryEmbeddingResult, error) { -// var queryResp ChromaQueryResponse -// err := json.Unmarshal(responseBody, &queryResp) -// if err != nil { -// return QueryEmbeddingResult{}, err -// } -// log.Infof("[Chroma] queryResp: %+v", queryResp) -// log.Infof("[Chroma] queryResp Ids len: %d", len(queryResp.Ids)) -// if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 { -// return QueryEmbeddingResult{}, nil -// } -// return QueryEmbeddingResult{ -// MostSimilarData: queryResp.Ids[0][0], -// Score: queryResp.Distances[0][0], -// }, nil -// } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 32031bc467..cd99ebaf3f 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -24,7 +24,7 @@ var ( } ) -// QueryEmbeddingResult 定义通用的查询结果的结构体 +// QueryResult 定义通用的查询结果的结构体 type QueryResult struct { Text string // 相似的文本 Embedding []float64 // 相似文本的向量 @@ -100,22 +100,6 @@ type ProviderConfig struct { // @Title zh-CN DashVector 向量存储服务 Collection ID // @Description zh-CN DashVector 向量存储服务 Collection ID collectionID string - - // // @Title zh-CN Chroma 的上游服务名称 - // // @Description zh-CN Chroma 服务所对应的网关内上游服务名称 - // ChromaServiceName string `require:"true" yaml:"ChromaServiceName" json:"ChromaServiceName"` - // // @Title zh-CN Chroma Collection ID - // // @Description zh-CN Chroma Collection 的 ID - // ChromaCollectionID string `require:"false" yaml:"ChromaCollectionID" json:"ChromaCollectionID"` - // @Title zh-CN Chroma 距离阈值 - // @Description zh-CN Chroma 距离阈值,默认为 2000 - ChromaDistanceThreshold float64 `require:"false" yaml:"ChromaDistanceThreshold" json:"ChromaDistanceThreshold"` - // // @Title zh-CN Chroma 搜索返回结果数量 - // // @Description zh-CN Chroma 搜索返回结果数量,默认为 1 - // ChromaNResult int `require:"false" yaml:"ChromaNResult" json:"ChromaNResult"` - // // @Title zh-CN Chroma 超时设置 - // // @Description zh-CN Chroma 超时设置,默认为 10 秒 - // ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"` } func (c *ProviderConfig) GetProviderType() string { @@ -141,21 +125,6 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.timeout == 0 { c.timeout = 10000 } - // Chroma - // c.ChromaCollectionID = json.Get("ChromaCollectionID").String() - // c.ChromaServiceName = json.Get("ChromaServiceName").String() - // c.ChromaDistanceThreshold = json.Get("ChromaDistanceThreshold").Float() - // if c.ChromaDistanceThreshold == 0 { - // c.ChromaDistanceThreshold = 2000 - // } - // c.ChromaNResult = int(json.Get("ChromaNResult").Int()) - // if c.ChromaNResult == 0 { - // c.ChromaNResult = 1 - // } - // c.ChromaTimeout = uint32(json.Get("ChromaTimeout").Int()) - // if c.ChromaTimeout == 0 { - // c.ChromaTimeout = 10000 - // } } func (c *ProviderConfig) Validate() error { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go deleted file mode 100644 index 0b361a6598..0000000000 --- a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go +++ /dev/null @@ -1,171 +0,0 @@ -package vector - -// import ( -// "encoding/json" -// "errors" -// "fmt" -// "net/http" - -// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -// ) - -// const ( -// dashVectorPort = 443 -// ) - -// type dashVectorProviderInitializer struct { -// } - -// func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { -// if len(config.DashVectorKey) == 0 { -// return errors.New("DashVectorKey is required") -// } -// if len(config.DashVectorAuthApiEnd) == 0 { -// return errors.New("DashVectorEnd is required") -// } -// if len(config.DashVectorCollection) == 0 { -// return errors.New("DashVectorCollection is required") -// } -// if len(config.DashVectorServiceName) == 0 { -// return errors.New("DashVectorServiceName is required") -// } -// return nil -// } - -// func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { -// return &DvProvider{ -// config: config, -// client: wrapper.NewClusterClient(wrapper.DnsCluster{ -// ServiceName: config.DashVectorServiceName, -// Port: dashVectorPort, -// Domain: config.DashVectorAuthApiEnd, -// }), -// }, nil -// } - -// type DvProvider struct { -// config ProviderConfig -// client wrapper.HttpClient -// } - -// func (d *DvProvider) GetProviderType() string { -// return providerTypeDashVector -// } - -// type EmbeddingRequest struct { -// Model string `json:"model"` -// Input Input `json:"input"` -// Parameters Params `json:"parameters"` -// } - -// type Params struct { -// TextType string `json:"text_type"` -// } - -// type Input struct { -// Texts []string `json:"texts"` -// } - -// func (d *DvProvider) ConstructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { -// url := fmt.Sprintf("/v1/collections/%s/query", d.config.DashVectorCollection) - -// requestData := QueryRequest{ -// Vector: vector, -// TopK: d.config.DashVectorTopK, -// IncludeVector: false, -// } - -// requestBody, err := json.Marshal(requestData) -// if err != nil { -// return "", nil, nil, err -// } - -// header := [][2]string{ -// {"Content-Type", "application/json"}, -// {"dashvector-auth-token", d.config.DashVectorKey}, -// } - -// return url, requestBody, header, nil -// } - -// func (d *DvProvider) ParseQueryResponse(responseBody []byte) (QueryResponse, error) { -// var queryResp QueryResponse -// err := json.Unmarshal(responseBody, &queryResp) -// if err != nil { -// return QueryResponse{}, err -// } -// return queryResp, nil -// } - -// func (d *DvProvider) QueryEmbedding( -// queryEmb []float64, -// ctx wrapper.HttpContext, -// log wrapper.Log, -// callback func(query_resp QueryResponse, ctx wrapper.HttpContext, log wrapper.Log)) { - -// // 构造请求参数 -// url, body, headers, err := d.ConstructEmbeddingQueryParameters(queryEmb) -// if err != nil { -// log.Infof("Failed to construct embedding query parameters: %v", err) -// } - -// err = d.client.Post(url, headers, body, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("Query embedding response: %d, %s", statusCode, responseBody) -// query_resp, err_query := d.ParseQueryResponse(responseBody) -// if err_query != nil { -// log.Infof("Failed to parse response: %v", err_query) -// } -// callback(query_resp, ctx, log) -// }, -// d.config.DashVectorTimeout) -// if err != nil { -// log.Infof("Failed to query embedding: %v", err) -// } - -// } - -// type Document struct { -// Vector []float64 `json:"vector"` -// Fields map[string]string `json:"fields"` -// } - -// type InsertRequest struct { -// Docs []Document `json:"docs"` -// } - -// func (d *DvProvider) ConstructEmbeddingUploadParameters(emb []float64, query_string string) (string, []byte, [][2]string, error) { -// url := "/v1/collections/" + d.config.DashVectorCollection + "/docs" - -// doc := Document{ -// Vector: emb, -// Fields: map[string]string{ -// "query": query_string, -// }, -// } - -// requestBody, err := json.Marshal(InsertRequest{Docs: []Document{doc}}) -// if err != nil { -// return "", nil, nil, err -// } - -// header := [][2]string{ -// {"Content-Type", "application/json"}, -// {"dashvector-auth-token", d.config.DashVectorKey}, -// } - -// return url, requestBody, header, err -// } - -// func (d *DvProvider) UploadEmbedding(query_emb []float64, queryString string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log)) { -// url, body, headers, _ := d.ConstructEmbeddingUploadParameters(query_emb, queryString) -// d.client.Post( -// url, -// headers, -// body, -// func(statusCode int, responseHeaders http.Header, responseBody []byte) { -// log.Infof("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) -// callback(ctx, log) -// }, -// d.config.DashVectorTimeout) -// } From e9a14d8d439bd8bd500c20efde0e18b69be859d3 Mon Sep 17 00:00:00 2001 From: async Date: Fri, 18 Oct 2024 23:52:25 +0800 Subject: [PATCH 36/71] feat: es --- .../extensions/ai-cache/embedding/weaviate.go | 27 --- .../extensions/ai-cache/vector/chroma.go | 2 +- .../ai-cache/vector/elasticsearch.go | 202 ++++++++++++++++++ .../extensions/ai-cache/vector/provider.go | 14 +- 4 files changed, 216 insertions(+), 29 deletions(-) delete mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go create mode 100644 plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go b/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go deleted file mode 100644 index b26d9cea8d..0000000000 --- a/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go +++ /dev/null @@ -1,27 +0,0 @@ -package embedding - -// import ( -// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -// ) - -// const ( -// weaviateURL = "172.17.0.1:8081" -// ) - -// type weaviateProviderInitializer struct { -// } - -// func (d *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error { -// return nil -// } - -// func (d *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { -// return &DSProvider{ -// config: config, -// client: wrapper.NewClusterClient(wrapper.DnsCluster{ -// ServiceName: config.ServiceName, -// Port: dashScopePort, -// Domain: dashScopeDomain, -// }), -// }, nil -// } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index 5978bced72..7004702c13 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -198,7 +198,7 @@ func (d *ChromaProvider) parseQueryResponse(responseBody []byte, log wrapper.Log results := make([]QueryResult, 0, len(queryResp.Ids[0])) for i := range queryResp.Ids[0] { result := QueryResult{ - Text: queryResp.Documents[0][i], + Text: queryResp.Ids[0][i], Score: queryResp.Distances[0][i], Answer: queryResp.Documents[0][i], } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go new file mode 100644 index 0000000000..27884aa7ce --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go @@ -0,0 +1,202 @@ +package vector + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type esProviderInitializer struct{} + +const esThreshold = 1000 + +func (c *esProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.collectionID) == 0 { + return errors.New("[ES] collectionID is required") + } + if len(config.serviceName) == 0 { + return errors.New("[ES] serviceName is required") + } + return nil +} + +func (c *esProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &ESProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.serviceName, + Port: config.servicePort, + Domain: config.serviceDomain, + }), + }, nil +} + +type ESProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *ESProvider) GetProviderType() string { + return providerTypeES +} + +func (d *ESProvider) GetSimilarityThreshold() float64 { + return esThreshold +} + +func (d *ESProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + + requestBody, err := json.Marshal(esQueryRequest{ + Source: Source{Excludes: []string{"embedding"}}, + Knn: knn{ + Field: "embedding", + QueryVector: emb, + K: d.config.topK, + }, + Size: d.config.topK, + }) + + if err != nil { + log.Errorf("[ES] Failed to marshal query embedding request body: %v", err) + return err + } + + return d.client.Post( + fmt.Sprintf("/%s/_search", d.config.collectionID), + [][2]string{ + {"Content-Type", "application/json"}, + {"Authorization", d.getCredentials()}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[ES] Query embedding response: %d, %s", statusCode, responseBody) + results, err := d.parseQueryResponse(responseBody, log) + if err != nil { + err = fmt.Errorf("[ES] Failed to parse query response: %v", err) + } + callback(results, ctx, log, err) + }, + d.config.timeout, + ) +} + +// 编码 ES 身份认证字符串 +func (d *ESProvider) getCredentials() string { + credentials := fmt.Sprintf("%s:%s", d.config.esUsername, d.config.esPassword) + encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials)) + return fmt.Sprintf("Basic %s", encodedCredentials) +} + +func (d *ESProvider) UploadAnswerAndEmbedding( + queryString string, + queryEmb []float64, + queryAnswer string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 index, embeddings 和 question + // 下面是一个例子 + // POST //_doc + // { + // "embedding": [ + // [1.1, 2.3, 3.2] + // ], + // "question": [ + // "你吃了吗?" + // ] + // } + requestBody, err := json.Marshal(esInsertRequest{ + Embedding: queryEmb, + Question: queryString, + Answer: queryAnswer, + }) + if err != nil { + log.Errorf("[ES] Failed to marshal upload embedding request body: %v", err) + return err + } + + return d.client.Post( + fmt.Sprintf("/%s/_doc", d.config.collectionID), + [][2]string{ + {"Content-Type", "application/json"}, + {"Authorization", d.getCredentials()}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[ES] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log, err) + }, + d.config.timeout, + ) +} + +type esInsertRequest struct { + Embedding []float64 `json:"embedding"` + Question string `json:"question"` + Answer string `json:"answer"` +} + +type knn struct { + Field string `json:"field"` + QueryVector []float64 `json:"query_vector"` + K int `json:"k"` +} + +type Source struct { + Excludes []string `json:"excludes"` +} + +type esQueryRequest struct { + Source Source `json:"_source"` + Knn knn `json:"knn"` + Size int `json:"size"` +} + +// esQueryResponse represents the search result structure. +type esQueryResponse struct { + Took int `json:"took"` + TimedOut bool `json:"timed_out"` + Hits struct { + Total struct { + Value int `json:"value"` + Relation string `json:"relation"` + } `json:"total"` + Hits []struct { + Index string `json:"_index"` + ID string `json:"_id"` + Score float64 `json:"_score"` + Source map[string]interface{} `json:"_source"` + } `json:"hits"` + } `json:"hits"` +} + +func (d *ESProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) { + log.Infof("[ES] responseBody: %s", string(responseBody)) + var queryResp esQueryResponse + err := json.Unmarshal(responseBody, &queryResp) + if err != nil { + return []QueryResult{}, err + } + log.Infof("[ES] queryResp Hits len: %d", len(queryResp.Hits.Hits)) + if len(queryResp.Hits.Hits) == 0 { + return nil, errors.New("no query results found in response") + } + results := make([]QueryResult, 0, queryResp.Hits.Total.Value) + for _, hit := range queryResp.Hits.Hits { + result := QueryResult{ + Text: hit.Source["question"].(string), + Score: hit.Score, + Answer: hit.Source["answer"].(string), + } + results = append(results, result) + } + return results, nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 6aa6f68f14..dff9ac962e 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -25,6 +25,7 @@ var ( providerTypeDashVector: &dashVectorProviderInitializer{}, providerTypeChroma: &chromaProviderInitializer{}, providerTypeWeaviate: &weaviateProviderInitializer{}, + providerTypeES: &esProviderInitializer{}, } ) @@ -104,6 +105,14 @@ type ProviderConfig struct { // @Title zh-CN DashVector 向量存储服务 Collection ID // @Description zh-CN DashVector 向量存储服务 Collection ID collectionID string + + // ES 配置 + // @Title zh-CN ES 用户名 + // @Description zh-CN ES 用户名 + esUsername string + // @Title zh-CN ES 密码 + // @Description zh-CN ES 密码 + esPassword string } func (c *ProviderConfig) GetProviderType() string { @@ -112,7 +121,6 @@ func (c *ProviderConfig) GetProviderType() string { func (c *ProviderConfig) FromJson(json gjson.Result) { c.typ = json.Get("type").String() - // DashVector c.serviceName = json.Get("serviceName").String() c.serviceDomain = json.Get("serviceDomain").String() c.servicePort = int64(json.Get("servicePort").Int()) @@ -129,6 +137,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.timeout == 0 { c.timeout = 10000 } + + // ES + c.esUsername = json.Get("esUsername").String() + c.esPassword = json.Get("esPassword").String() } func (c *ProviderConfig) Validate() error { From 32eccd7171d5ce25bd706874c70c78fdfe6440e9 Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 00:41:23 +0800 Subject: [PATCH 37/71] feat: pinecone --- .../extensions/ai-cache/vector/pinecone.go | 203 ++++++++++++++++++ .../extensions/ai-cache/vector/provider.go | 1 + 2 files changed, 204 insertions(+) create mode 100644 plugins/wasm-go/extensions/ai-cache/vector/pinecone.go diff --git a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go new file mode 100644 index 0000000000..f040b5b4a9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go @@ -0,0 +1,203 @@ +package vector + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/google/uuid" + "github.com/tidwall/gjson" +) + +type pineconeProviderInitializer struct{} + +const pineconeThreshold = 1000 + +func (c *pineconeProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.serviceDomain) == 0 { + return errors.New("[Pinecone] serviceDomain is required") + } + if len(config.serviceName) == 0 { + return errors.New("[Pinecone] serviceName is required") + } + if len(config.apiKey) == 0 { + return errors.New("[Pinecone] apiKey is required") + } + if len(config.collectionID) == 0 { + return errors.New("[Pinecone] collectionID is required") + } + return nil +} + +func (c *pineconeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &pineconeProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.serviceName, + Port: config.servicePort, + Domain: config.serviceDomain, + }), + }, nil +} + +type pineconeProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *pineconeProvider) GetProviderType() string { + return providerTypePinecone +} + +func (d *pineconeProvider) GetSimilarityThreshold() float64 { + return pineconeThreshold +} + +type pineconeMetadata struct { + Question string `json:"question"` + Answer string `json:"answer"` +} + +type pineconeVector struct { + ID string `json:"id"` + Values []float64 `json:"values"` + Properties pineconeMetadata `json:"metadata"` +} + +type pineconeInsertRequest struct { + Vectors []pineconeVector `json:"vectors"` + Namespace string `json:"namespace"` +} + +func (d *pineconeProvider) UploadAnswerAndEmbedding( + queryString string, + queryEmb []float64, + queryAnswer string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 vector 和 question + // 下面是一个例子 + // { + // "vectors": [ + // { + // "id": "A", + // "values": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + // "metadata": {"question": "你好", "answer": "你也好"} + // } + // ] + // } + requestBody, err := json.Marshal(pineconeInsertRequest{ + Vectors: []pineconeVector{ + { + ID: uuid.New().String(), + Values: queryEmb, + Properties: pineconeMetadata{Question: queryString, Answer: queryAnswer}, + }, + }, + Namespace: d.config.collectionID, + }) + + if err != nil { + log.Errorf("[Pinecone] Failed to marshal upload embedding request body: %v", err) + return err + } + + return d.client.Post( + "/vectors/upsert", + [][2]string{ + {"Content-Type", "application/json"}, + {"Api-Key", d.config.apiKey}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Pinecone] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log, err) + }, + d.config.timeout, + ) +} + +type pineconeQueryRequest struct { + Namespace string `json:"namespace"` + Vector []float64 `json:"vector"` + TopK int `json:"topK"` + IncludeMetadata bool `json:"includeMetadata"` + IncludeValues bool `json:"includeValues"` +} + +func (d *pineconeProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 vector + // 下面是一个例子 + // { + // "namespace": "higress", + // "vector": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + // "topK": 1, + // "includeMetadata": false + // } + requestBody, err := json.Marshal(pineconeQueryRequest{ + Namespace: d.config.collectionID, + Vector: emb, + TopK: d.config.topK, + IncludeMetadata: true, + IncludeValues: false, + }) + if err != nil { + log.Errorf("[Pinecone] Failed to marshal query embedding: %v", err) + return err + } + + return d.client.Post( + "/query", + [][2]string{ + {"Content-Type", "application/json"}, + {"Api-Key", d.config.apiKey}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Pinecone] Query embedding response: %d, %s", statusCode, responseBody) + results, err := d.parseQueryResponse(responseBody, log) + if err != nil { + err = fmt.Errorf("[Pinecone] Failed to parse query response: %v", err) + } + callback(results, ctx, log, err) + }, + d.config.timeout, + ) +} + +func (d *pineconeProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) { + if !gjson.GetBytes(responseBody, "matches.0.score").Exists() { + log.Errorf("[Pinecone] No distance found in response body: %s", responseBody) + return nil, errors.New("[Pinecone] No distance found in response body") + } + + if !gjson.GetBytes(responseBody, "matches.0.metadata.question").Exists() { + log.Errorf("[Pinecone] No question found in response body: %s", responseBody) + return nil, errors.New("[Pinecone] No question found in response body") + } + + if !gjson.GetBytes(responseBody, "matches.0.metadata.answer").Exists() { + log.Errorf("[Pinecone] No question found in response body: %s", responseBody) + return nil, errors.New("[Pinecone] No question found in response body") + } + + resultNum := gjson.GetBytes(responseBody, "matches.#").Int() + results := make([]QueryResult, 0, resultNum) + for i := 0; i < int(resultNum); i++ { + result := QueryResult{ + Text: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.metadata.question", i)).String(), + Score: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.score", i)).Float(), + Answer: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.metadata.answer", i)).String(), + } + results = append(results, result) + } + + return results, nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index dff9ac962e..660197f8e8 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -26,6 +26,7 @@ var ( providerTypeChroma: &chromaProviderInitializer{}, providerTypeWeaviate: &weaviateProviderInitializer{}, providerTypeES: &esProviderInitializer{}, + providerTypePinecone: &pineconeProviderInitializer{}, } ) From 440cd8ded6b3b641b822caee9fd2fe10251f7788 Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 00:54:33 +0800 Subject: [PATCH 38/71] fix: bugs --- plugins/wasm-go/extensions/ai-proxy/Makefile | 4 ---- plugins/wasm-go/extensions/ai-proxy/go.mod | 2 +- plugins/wasm-go/extensions/ai-proxy/go.sum | 4 ++-- plugins/wasm-go/extensions/ai-proxy/main.go | 3 +-- 4 files changed, 4 insertions(+), 9 deletions(-) delete mode 100644 plugins/wasm-go/extensions/ai-proxy/Makefile diff --git a/plugins/wasm-go/extensions/ai-proxy/Makefile b/plugins/wasm-go/extensions/ai-proxy/Makefile deleted file mode 100644 index e5c7fa8de9..0000000000 --- a/plugins/wasm-go/extensions/ai-proxy/Makefile +++ /dev/null @@ -1,4 +0,0 @@ -.DEFAULT: -build: - tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' ./main.go - mv ai-proxy.wasm ../../../../docker-compose-test/ \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-proxy/go.mod b/plugins/wasm-go/extensions/ai-proxy/go.mod index a5457b90f8..7fed801fab 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.mod +++ b/plugins/wasm-go/extensions/ai-proxy/go.mod @@ -10,7 +10,7 @@ require ( github.com/alibaba/higress/plugins/wasm-go v0.0.0 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.17.3 + github.com/tidwall/gjson v1.14.3 ) require ( diff --git a/plugins/wasm-go/extensions/ai-proxy/go.sum b/plugins/wasm-go/extensions/ai-proxy/go.sum index b2d63b5f4b..e5b8b79175 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.sum +++ b/plugins/wasm-go/extensions/ai-proxy/go.sum @@ -13,8 +13,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= -github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= +github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 7b19d03fc2..9e0fafe179 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -82,8 +82,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf providerConfig := pluginConfig.GetProviderConfig() if apiName == "" && !providerConfig.IsOriginal() { log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) - // _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) - log.Debugf("[onHttpRequestHeader] no send response") + _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) return types.ActionContinue } ctx.SetContext(ctxKeyApiName, apiName) From 628b74b6ce8826ddee139afe660a543737303401 Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 00:56:58 +0800 Subject: [PATCH 39/71] fix: bugs --- plugins/wasm-go/extensions/transformer/go.mod | 2 +- plugins/wasm-go/extensions/transformer/go.sum | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/plugins/wasm-go/extensions/transformer/go.mod b/plugins/wasm-go/extensions/transformer/go.mod index 464974140a..e70583a937 100644 --- a/plugins/wasm-go/extensions/transformer/go.mod +++ b/plugins/wasm-go/extensions/transformer/go.mod @@ -9,7 +9,7 @@ require ( github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.17.3 + github.com/tidwall/gjson v1.17.0 github.com/tidwall/pretty v1.2.1 github.com/tidwall/sjson v1.2.5 github.com/wasilibs/go-re2 v1.6.0 diff --git a/plugins/wasm-go/extensions/transformer/go.sum b/plugins/wasm-go/extensions/transformer/go.sum index 897140b6e4..76246bba99 100644 --- a/plugins/wasm-go/extensions/transformer/go.sum +++ b/plugins/wasm-go/extensions/transformer/go.sum @@ -19,7 +19,6 @@ github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXA github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= From 342bd949cc67ab35d527eed33009e3732e575afc Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 00:57:23 +0800 Subject: [PATCH 40/71] fix: remove uesless files --- docker-compose-test/Makefile | 95 ------- docker-compose-test/docker-compose.yml | 126 ---------- docker-compose-test/envoy.yaml | 330 ------------------------- 3 files changed, 551 deletions(-) delete mode 100644 docker-compose-test/Makefile delete mode 100644 docker-compose-test/docker-compose.yml delete mode 100644 docker-compose-test/envoy.yaml diff --git a/docker-compose-test/Makefile b/docker-compose-test/Makefile deleted file mode 100644 index da7c64ba43..0000000000 --- a/docker-compose-test/Makefile +++ /dev/null @@ -1,95 +0,0 @@ -HEADER := Content-Type: application/json -WEAVIATE_PORT = 8081 - - -.PHONY: proxy cache docker weaviate-post-collection weaviate-post-obj weaviate-get-obj - -all: proxy cache docker weaviate-post-collection weaviate-post-obj weaviate-get-obj - -docker: - docker compose up -d - -docker-down: - docker compose down - -# 编译 proxy 插件 -proxy: - cd ../plugins/wasm-go/extensions/ai-proxy && \ - tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' . && \ - mv ai-proxy.wasm ../../../../docker-compose-test/ - -# 编译 cache 插件 -cache: - cd ../plugins/wasm-go/extensions/ai-cache && \ - tinygo build -o ai-cache.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' . && \ - mv ai-cache.wasm ../../../../docker-compose-test/ - -# 创建 object -weaviate-post-obj: - curl --request POST --url http://localhost:$(WEAVIATE_PORT)/v1/objects -H "$(HEADER)" --data '{"class": "Higress", "vector": [0.1, 0.2, 0.3], "properties": {"question": "这里是问题3"}}' - -# 获取 schema -weaviate-get-schema: - curl -X GET "http://localhost:$(WEAVIATE_PORT)/v1/schema" -H "$(HEADER)" - -# 创建 collection -weaviate-post-collection: - curl -X POST "http://localhost:$(WEAVIATE_PORT)/v1/schema" -H "$(HEADER)" -d '{"class": "Higress"}' - -# 获取 objs -weaviate-get-obj: - curl -X GET "http://localhost:$(WEAVIATE_PORT)/v1/objects" - -# 获取具体 obj -weaviate-get-obj-id: - curl -X GET "http://localhost:$(WEAVIATE_PORT)/v1/objects/Higress/8e7df58e-3415-4264-9bcb-afbb3c51318b" - -# 删除 obj -weaviate-delete-obj: - curl -X DELETE "http://localhost:$(WEAVIATE_PORT)/v1/objects/Higress/8e7df58e-3415-4264-9bcb-afbb3c51318b" - -# 删除 collection,这里 classname 会自动大写 -weaviate-delete-collection: - curl -X DELETE "http://localhost:$(WEAVIATE_PORT)/v1/schema/Higress" - -QUERY = "{ \ - Get { \ - Higress ( \ - limit: 5 \ - nearVector: { \ - vector: [0.1, 0.2, 0.3] \ - } \ - ) { \ - question \ - _additional { \ - distance \ - } \ - } \ - } \ -}" -# 搜索,默认按照 distance 升序 -# https://weaviate.io/developers/weaviate/config-refs/distances -weaviate-search: - curl -X POST "http://localhost:$(WEAVIATE_PORT)/v1/graphql" -H "$(HEADER)" -d '{"query": ${QUERY}}' - - -# redis client -redis-cli: - docker run -it --network docker-compose-test_wasmtest --rm redis redis-cli -h docker-compose-test-redis-1 - -# redis flushall -redis-flushall: - docker run -it --network docker-compose-test_wasmtest --rm redis redis-cli -h docker-compose-test-redis-1 flushall - -# llm request -llm: - curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "你好"}]}' - -llm1: - curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "今天晚上吃什么"}]}' - -llm2: - curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "今天晚上吃什么?"}]}' - -llm3: - curl -X POST http://localhost:10000/v1/chat/completions -H "Content-Type: application/json" -d '{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "今天晚上吃什么呢?有无推荐?"}]}' \ No newline at end of file diff --git a/docker-compose-test/docker-compose.yml b/docker-compose-test/docker-compose.yml deleted file mode 100644 index 9d76fbf987..0000000000 --- a/docker-compose-test/docker-compose.yml +++ /dev/null @@ -1,126 +0,0 @@ -version: '3.7' -services: - envoy: - # image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway:v1.4.0-rc.1 - image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway:1.4.2 - entrypoint: /usr/local/bin/envoy - # 注意这里对wasm开启了debug级别日志,正式部署时则默认info级别 - command: -c /etc/envoy/envoy.yaml --component-log-level wasm:debug - depends_on: - - httpbin - - redis - - weaviate - # - chroma - # - es - networks: - - wasmtest - ports: - - "10000:10000" - - "9901:9901" - volumes: - - ./envoy.yaml:/etc/envoy/envoy.yaml - # 注意默认没有这两个 wasm 的时候,docker 会创建文件夹,这样会出错,需要有 wasm 文件之后 down 然后重新 up - - ./ai-cache.wasm:/etc/envoy/ai-cache.wasm - - ./ai-proxy.wasm:/etc/envoy/ai-proxy.wasm - - # chroma: - # image: chromadb/chroma - # ports: - # - "8001:8000" - # volumes: - # - chroma-data:/chroma/chroma - - redis: - image: redis:latest - networks: - - wasmtest - ports: - - "6379:6379" - - httpbin: - image: kennethreitz/httpbin:latest - networks: - - wasmtest - ports: - - "12345:80" - - # es: - # image: elasticsearch:8.15.0 - # environment: - # - "TZ=Asia/Shanghai" - # - "discovery.type=single-node" - # - "xpack.security.http.ssl.enabled=false" - # - "xpack.license.self_generated.type=trial" - # - "ELASTIC_PASSWORD=123456" - # ports: - # - "9200:9200" - # - "9300:9300" - # networks: - # - wasmtest - - # kibana: - # image: docker.elastic.co/kibana/kibana:8.15.0 - # environment: - # - "TZ=Asia/Shanghai" - # - "ELASTICSEARCH_HOSTS=http://es:9200" - # - "ELASTICSEARCH_URL=http://es:9200" - # - "ELASTICSEARCH_USERNAME=kibana_system" - # - "ELASTICSEARCH_PASSWORD=123456" - # - "xpack.security.enabled=false" - # - "xpack.license.self_generated.type=trial" - # ports: - # - "5601:5601" - # networks: - # - wasmtest - # depends_on: - # - es - - # lobechat: - # # docker hub 如果访问不了,可以改用这个地址:registry.cn-hangzhou.aliyuncs.com/2456868764/lobe-chat:v1.1.3 - # image: lobehub/lobe-chat - # environment: - # - CODE=admin - # - OPENAI_API_KEY=unused - # - OPENAI_PROXY_URL=http://envoy:10000/v1 - # networks: - # - wasmtest - # ports: - # - "3210:3210/tcp" - - weaviate: - command: - - --host - - 0.0.0.0 - - --port - - '8080' - - --scheme - - http - # 高于 1.24.x 的版本,单节点部署有问题 - image: cr.weaviate.io/semitechnologies/weaviate:1.24.1 - ports: - - 8081:8080 - - 50051:50051 - volumes: - - weaviate_data:/var/lib/weaviate - restart: on-failure:0 - networks: - - wasmtest - environment: - QUERY_DEFAULTS_LIMIT: 25 - AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' - PERSISTENCE_DATA_PATH: '/var/lib/weaviate' - DEFAULT_VECTORIZER_MODULE: 'none' - ENABLE_API_BASED_MODULES: 'true' - CLUSTER_HOSTNAME: 'node1' - - # t2v-transformers: # Set the name of the inference container - # image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-multi-qa-MiniLM-L6-cos-v1 - # environment: - # ENABLE_CUDA: 0 # Set to 1 to enable -volumes: - weaviate_data: {} - chroma-data: - driver: local - -networks: - wasmtest: {} \ No newline at end of file diff --git a/docker-compose-test/envoy.yaml b/docker-compose-test/envoy.yaml deleted file mode 100644 index 411f3b04fb..0000000000 --- a/docker-compose-test/envoy.yaml +++ /dev/null @@ -1,330 +0,0 @@ -admin: - address: - socket_address: - protocol: TCP - address: 0.0.0.0 - port_value: 9901 -static_resources: - listeners: - - name: listener_0 - address: - socket_address: - protocol: TCP - address: 0.0.0.0 - port_value: 10000 - filter_chains: - - filters: - # httpbin - - name: envoy.filters.network.http_connection_manager - typed_config: - "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager - scheme_header_transformation: - scheme_to_overwrite: https - stat_prefix: ingress_http - route_config: - name: local_route - virtual_hosts: - - name: local_service - domains: ["*"] - routes: - # - match: - # prefix: "/" - # route: - # cluster: httpbin - - match: - prefix: "/" - route: - cluster: outbound|443||bigmodel.dns - timeout: 300s - - http_filters: - # ai-cache - - name: ai-cache - typed_config: - "@type": type.googleapis.com/udpa.type.v1.TypedStruct - type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm - value: - config: - name: ai-cache - vm_config: - runtime: envoy.wasm.runtime.v8 - code: - local: - filename: /etc/envoy/ai-cache.wasm - configuration: - "@type": "type.googleapis.com/google.protobuf.StringValue" - value: | - { - "embedding": { - "type": "dashscope", - "serviceName": "dashscope", - "apiKey": "sk-346fadc4ed8448e487ea84b57788d816" - }, - "vector": { - "type": "chroma", - "serviceName": "chroma", - "collectionID": "8f8accc8-b68c-4fb1-84af-0292f9236e8a", - "servicePort": 8001 - }, - "cache": { - "type": "redis", - "serviceName": "redis_cluster", - "timeout": 1000 - } - } - - # 上面的配置中 redis 的配置名字是 redis,而不是 golang tag 中的 redisConfig - # "vector": { - # "type": "dashvector", - # "serviceName": "dashvector", - # "collectionID": "higress_euclidean", - # "serviceDomain": "vrs-cn-g6z3yq2wy0001z.dashvector.cn-hangzhou.aliyuncs.com", - # "apiKey": "sk-nAn4GfZFrbLNhGffVIIq6tdgWNjV7D8A0F7CC5E1011EF9A1EB61E393DC850" - # }, - - # "vectorProvider": { - # "VectorStoreProviderType": "chroma", - # "ChromaServiceName": "chroma", - # "ChromaCollectionID": "0294deb1-8ef5-4582-b21c-75f23093db2c" - # }, - - # "vectorProvider": { - # "VectorStoreProviderType": "elasticsearch", - # "ThresholdRelation": "gte", - # "ESThreshold": 0.7, - # "ESServiceName": "es", - # "ESIndex": "higress", - # "ESUsername": "elastic", - # "ESPassword": "123456" - # }, - - # "vectorProvider": { - # "VectorStoreProviderType": "weaviate", - # "WeaviateServiceName": "weaviate", - # "WeaviateCollection": "Higress", - # "WeaviateThreshold": "0.3" - # }, - - # "vector": { - # "type": "pinecone", - # "PineconeServiceName": "pinecone", - # "PineconeApiEndpoint": "higress-2bdfipe.svc.aped-4627-b74a.pinecone.io", - # "PineconeThreshold": "0.7", - # "ThresholdRelation": "gte", - # "PineconeApiKey": "key" - # }, - # llm-proxy - - name: llm-proxy - typed_config: - "@type": type.googleapis.com/udpa.type.v1.TypedStruct - type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm - value: - config: - name: llm - vm_config: - runtime: envoy.wasm.runtime.v8 - code: - local: - filename: /etc/envoy/ai-proxy.wasm - configuration: - "@type": "type.googleapis.com/google.protobuf.StringValue" - value: | # 插件配置 - { - "provider": { - "type": "zhipuai", - "apiTokens": [ - "67e93d524df46fca3640df67a7461c04.qOksqKAoWHcv03aV" - ] - } - } - - - - name: envoy.filters.http.router - - clusters: - - name: httpbin - connect_timeout: 30s - type: LOGICAL_DNS - # Comment out the following line to test on v6 networks - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: httpbin - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: httpbin - port_value: 80 - # - name: redis_cluster - # connect_timeout: 30s - # type: STRICT_DNS - # lb_policy: ROUND_ROBIN - # load_assignment: - # cluster_name: redis - # endpoints: - # - lb_endpoints: - # - endpoint: - # address: - # socket_address: - # address: 172.17.0.1 - # port_value: 6379 - - name: outbound|6379||redis_cluster - connect_timeout: 1s - type: strict_dns - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|6379||redis_cluster - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: 172.17.0.1 - port_value: 6379 - typed_extension_protocol_options: - envoy.filters.network.redis_proxy: - "@type": type.googleapis.com/envoy.extensions.filters.network.redis_proxy.v3.RedisProtocolOptions - # chroma - - name: outbound|8001||chroma.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|8001||chroma.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 - port_value: 8001 - - # es - - name: outbound|9200||es.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|9200||es.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 - port_value: 9200 - # weaviate - - name: outbound|8081||weaviate.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|8081||weaviate.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 - port_value: 8081 - - # llm - - name: llm - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: llm - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: 172.17.0.1 # 本地 API 服务地址,这里是 docker0 - port_value: 8000 - # dashvector - - name: outbound|443||dashvector.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|443||dashvector.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: vrs-cn-g6z3yq2wy0001z.dashvector.cn-hangzhou.aliyuncs.com - port_value: 443 - transport_socket: - name: envoy.transport_sockets.tls - typed_config: - "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext - "sni": "vrs-cn-g6z3yq2wy0001z.dashvector.cn-hangzhou.aliyuncs.com" - # dashscope - - name: outbound|443||dashscope.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|443||dashscope.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: dashscope.aliyuncs.com - port_value: 443 - transport_socket: - name: envoy.transport_sockets.tls - typed_config: - "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext - "sni": "dashscope.aliyuncs.com" - # bigmodel - - name: outbound|443||bigmodel.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|443||bigmodel.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: open.bigmodel.cn - port_value: 443 - transport_socket: - name: envoy.transport_sockets.tls - typed_config: - "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext - "sni": "open.bigmodel.cn" - # pinecone - - name: outbound|443||pinecone.dns - connect_timeout: 30s - type: LOGICAL_DNS - dns_lookup_family: V4_ONLY - lb_policy: ROUND_ROBIN - load_assignment: - cluster_name: outbound|443||pinecone.dns - endpoints: - - lb_endpoints: - - endpoint: - address: - socket_address: - address: higress-2bdfipe.svc.aped-4627-b74a.pinecone.io - port_value: 443 - transport_socket: - name: envoy.transport_sockets.tls - typed_config: - "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext - "sni": "higress-2bdfipe.svc.aped-4627-b74a.pinecone.io" \ No newline at end of file From cbeb71b15ef712a1515a82030b4c3bf165d69694 Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 00:58:51 +0800 Subject: [PATCH 41/71] fix: remove uesless files --- plugins/wasm-go/extensions/ai-cache/.gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasm-go/extensions/ai-cache/.gitignore b/plugins/wasm-go/extensions/ai-cache/.gitignore index 8a34bf52ad..47db8eedba 100644 --- a/plugins/wasm-go/extensions/ai-cache/.gitignore +++ b/plugins/wasm-go/extensions/ai-cache/.gitignore @@ -1,5 +1,5 @@ # File generated by hgctl. Modify as required. -docker-compose-test/ + * !/.gitignore From 43cfdafe8c4f2658b3c9a44f758fd447ce10bc7a Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 10:55:56 +0800 Subject: [PATCH 42/71] feat: qdrant --- .../extensions/ai-cache/vector/pinecone.go | 4 +- .../extensions/ai-cache/vector/provider.go | 2 + .../extensions/ai-cache/vector/qdrant.go | 214 ++++++++++++++++++ 3 files changed, 218 insertions(+), 2 deletions(-) create mode 100644 plugins/wasm-go/extensions/ai-cache/vector/qdrant.go diff --git a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go index f040b5b4a9..8d21f7b64e 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go @@ -184,8 +184,8 @@ func (d *pineconeProvider) parseQueryResponse(responseBody []byte, log wrapper.L } if !gjson.GetBytes(responseBody, "matches.0.metadata.answer").Exists() { - log.Errorf("[Pinecone] No question found in response body: %s", responseBody) - return nil, errors.New("[Pinecone] No question found in response body") + log.Errorf("[Pinecone] No answer found in response body: %s", responseBody) + return nil, errors.New("[Pinecone] No answer found in response body") } resultNum := gjson.GetBytes(responseBody, "matches.#").Int() diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 660197f8e8..8194d33b30 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -13,6 +13,7 @@ const ( providerTypeES = "elasticsearch" providerTypeWeaviate = "weaviate" providerTypePinecone = "pinecone" + providerTypeQdrant = "qdrant" ) type providerInitializer interface { @@ -27,6 +28,7 @@ var ( providerTypeWeaviate: &weaviateProviderInitializer{}, providerTypeES: &esProviderInitializer{}, providerTypePinecone: &pineconeProviderInitializer{}, + providerTypeQdrant: &qdrantProviderInitializer{}, } ) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go new file mode 100644 index 0000000000..cddea1ec2f --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go @@ -0,0 +1,214 @@ +package vector + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/google/uuid" + "github.com/tidwall/gjson" +) + +type qdrantProviderInitializer struct{} + +const qdrantThreshold = 50 + +func (c *qdrantProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.serviceName) == 0 { + return errors.New("[Qdrant] serviceName is required") + } + if len(config.collectionID) == 0 { + return errors.New("[Qdrant] collectionID is required") + } + return nil +} + +func (c *qdrantProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &qdrantProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.serviceName, + Port: config.servicePort, + Domain: config.serviceDomain, + }), + }, nil +} + +type qdrantProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *qdrantProvider) GetProviderType() string { + return providerTypeQdrant +} + +func (d *qdrantProvider) GetSimilarityThreshold() float64 { + return qdrantThreshold +} + +type qdrantPayload struct { + Question string `json:"question"` + Answer string `json:"answer"` +} + +type qdrantPoint struct { + ID string `json:"id"` + Vector []float64 `json:"vector"` + Payload qdrantPayload `json:"payload"` +} + +type qdrantInsertRequest struct { + Points []qdrantPoint `json:"points"` +} + +func (d *qdrantProvider) UploadAnswerAndEmbedding( + queryString string, + queryEmb []float64, + queryAnswer string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 id 和 vector. payload 可选 + // 下面是一个例子 + // { + // "points": [ + // { + // "id": "76874cce-1fb9-4e16-9b0b-f085ac06ed6f", + // "payload": { + // "question": "这里是问题", + // "ansower": "这里是答案" + // }, + // "vector": [ + // 0.9, + // 0.1, + // 0.1 + // ] + // } + // ] + // } + requestBody, err := json.Marshal(qdrantInsertRequest{ + Points: []qdrantPoint{ + { + ID: uuid.New().String(), + Vector: queryEmb, + Payload: qdrantPayload{Question: queryString, Answer: queryAnswer}, + }, + }, + }) + + if err != nil { + log.Errorf("[Qdrant] Failed to marshal upload embedding request body: %v", err) + return err + } + + return d.client.Put( + fmt.Sprintf("/collections/%s/points", d.config.collectionID), + [][2]string{ + {"Content-Type", "application/json"}, + {"api-key", d.config.apiKey}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Qdrant] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log, err) + }, + d.config.timeout, + ) +} + +type qdrantQueryRequest struct { + Vector []float64 `json:"vector"` + Limit int `json:"limit"` + WithPayload bool `json:"with_payload"` +} + +func (d *qdrantProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 vector 和 limit. with_payload 可选,为了直接得到问题答案,所以这里需要 + // 下面是一个例子 + // { + // "vector": [ + // 0.2, + // 0.1, + // 0.9, + // 0.7 + // ], + // "limit": 1 + // } + requestBody, err := json.Marshal(qdrantQueryRequest{ + Vector: emb, + Limit: d.config.topK, + WithPayload: true, + }) + if err != nil { + log.Errorf("[Qdrant] Failed to marshal query embedding: %v", err) + return err + } + + return d.client.Post( + fmt.Sprintf("/collections/%s/points/search", d.config.collectionID), + [][2]string{ + {"Content-Type", "application/json"}, + {"api-key", d.config.apiKey}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Qdrant] Query embedding response: %d, %s", statusCode, responseBody) + results, err := d.parseQueryResponse(responseBody, log) + if err != nil { + err = fmt.Errorf("[Qdrant] Failed to parse query response: %v", err) + } + callback(results, ctx, log, err) + }, + d.config.timeout, + ) +} + +func (d *qdrantProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) { + // 返回的内容例子如下 + // { + // "time": 0.002, + // "status": "ok", + // "result": [ + // { + // "id": 42, + // "version": 3, + // "score": 0.75, + // "payload": { + // "question": "London", + // "answer": "green" + // }, + // "shard_key": "region_1", + // "order_value": 42 + // } + // ] + // } + if !gjson.GetBytes(responseBody, "result.0.score").Exists() { + log.Errorf("[Qdrant] No distance found in response body: %s", responseBody) + return nil, errors.New("[Qdrant] No distance found in response body") + } + + if !gjson.GetBytes(responseBody, "result.0.payload.answer").Exists() { + log.Errorf("[Qdrant] No answer found in response body: %s", responseBody) + return nil, errors.New("[Qdrant] No answer found in response body") + } + + resultNum := gjson.GetBytes(responseBody, "result.#").Int() + results := make([]QueryResult, 0, resultNum) + for i := 0; i < int(resultNum); i++ { + result := QueryResult{ + Text: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.payload.question", i)).String(), + Score: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.score", i)).Float(), + Answer: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.payload.answer", i)).String(), + } + results = append(results, result) + } + + return results, nil +} From 2a4363aab565feaba4950110ca8ab4158e40fa56 Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 15:43:36 +0800 Subject: [PATCH 43/71] feat: milvus --- .../extensions/ai-cache/vector/milvus.go | 213 ++++++++++++++++++ .../extensions/ai-cache/vector/provider.go | 2 + .../extensions/ai-cache/vector/qdrant.go | 2 +- 3 files changed, 216 insertions(+), 1 deletion(-) create mode 100644 plugins/wasm-go/extensions/ai-cache/vector/milvus.go diff --git a/plugins/wasm-go/extensions/ai-cache/vector/milvus.go b/plugins/wasm-go/extensions/ai-cache/vector/milvus.go new file mode 100644 index 0000000000..a8ac265151 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/milvus.go @@ -0,0 +1,213 @@ +package vector + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +type milvusProviderInitializer struct{} + +const milvusThreshold = 0.5 + +func (c *milvusProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.serviceName) == 0 { + return errors.New("[Milvus] serviceName is required") + } + if len(config.collectionID) == 0 { + return errors.New("[Milvus] collectionID is required") + } + return nil +} + +func (c *milvusProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &milvusProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.DnsCluster{ + ServiceName: config.serviceName, + Port: config.servicePort, + Domain: config.serviceDomain, + }), + }, nil +} + +type milvusProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (c *milvusProvider) GetProviderType() string { + return providerTypeMilvus +} + +func (d *milvusProvider) GetSimilarityThreshold() float64 { + return milvusThreshold +} + +type milvusData struct { + ID int `json:"id"` + Vector []float64 `json:"vector"` + Question string `json:"question,omitempty"` + Answer string `json:"answer,omitempty"` +} + +type milvusInsertRequest struct { + CollectionName string `json:"collectionName"` + Data []milvusData `json:"data"` +} + +func (d *milvusProvider) UploadAnswerAndEmbedding( + queryString string, + queryEmb []float64, + queryAnswer string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 collectionName, data 和 Authorization. question, answer 可选 + // invalid syntax: invalid parameter[expected=Int64][actual=] + // 下面是一个例子 + // { + // "collectionName": "higress", + // "data": [ + // { + // "question": "这里是问题", + // "answer": "这里是答案" + // "vector": [ + // 0.9, + // 0.1, + // 0.1 + // ] + // } + // ] + // } + requestBody, err := json.Marshal(milvusInsertRequest{ + CollectionName: d.config.collectionID, + Data: []milvusData{ + { + ID: 0, + Question: queryString, + Answer: queryAnswer, + Vector: queryEmb, + }, + }, + }) + + if err != nil { + log.Errorf("[Milvus] Failed to marshal upload embedding request body: %v", err) + return err + } + + return d.client.Post( + "/v2/vectordb/entities/insert", + [][2]string{ + {"Content-Type", "application/json"}, + {"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Milvus] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + callback(ctx, log, err) + }, + d.config.timeout, + ) +} + +type milvusQueryRequest struct { + CollectionName string `json:"collectionName"` + Data [][]float64 `json:"data"` + AnnsField string `json:"annsField"` + Limit int `json:"limit"` + OutputFields []string `json:"outputFields"` +} + +func (d *milvusProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + // 最少需要填写的参数为 collectionName, data, annsField. outputFields 为可选参数 + // 下面是一个例子 + // { + // "collectionName": "quick_setup", + // "data": [ + // [ + // 0.3580376395471989, + // "Unknown type", + // 0.18414012509913835, + // "Unknown type", + // 0.9029438446296592 + // ] + // ], + // "annsField": "vector", + // "limit": 3, + // "outputFields": [ + // "color" + // ] + // } + requestBody, err := json.Marshal(milvusQueryRequest{ + CollectionName: d.config.collectionID, + Data: [][]float64{emb}, + AnnsField: "vector", + Limit: d.config.topK, + OutputFields: []string{ + "question", + "answer", + }, + }) + if err != nil { + log.Errorf("[Milvus] Failed to marshal query embedding: %v", err) + return err + } + + return d.client.Post( + "/v2/vectordb/entities/search", + [][2]string{ + {"Content-Type", "application/json"}, + {"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)}, + }, + requestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Infof("[Milvus] Query embedding response: %d, %s", statusCode, responseBody) + results, err := d.parseQueryResponse(responseBody, log) + if err != nil { + err = fmt.Errorf("[Milvus] Failed to parse query response: %v", err) + } + callback(results, ctx, log, err) + }, + d.config.timeout, + ) +} + +func (d *milvusProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) { + if !gjson.GetBytes(responseBody, "data.0.distance").Exists() { + log.Errorf("[Milvus] No distance found in response body: %s", responseBody) + return nil, errors.New("[Milvus] No distance found in response body") + } + + if !gjson.GetBytes(responseBody, "data.0.question").Exists() { + log.Errorf("[Milvus] No question found in response body: %s", responseBody) + return nil, errors.New("[Milvus] No question found in response body") + } + + if !gjson.GetBytes(responseBody, "data.0.answer").Exists() { + log.Errorf("[Milvus] No answer found in response body: %s", responseBody) + return nil, errors.New("[Milvus] No answer found in response body") + } + + resultNum := gjson.GetBytes(responseBody, "data.#").Int() + results := make([]QueryResult, 0, resultNum) + for i := 0; i < int(resultNum); i++ { + result := QueryResult{ + Text: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.question", i)).String(), + Score: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.distance", i)).Float(), + Answer: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.answer", i)).String(), + } + results = append(results, result) + } + + return results, nil +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 8194d33b30..62875fc88c 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -14,6 +14,7 @@ const ( providerTypeWeaviate = "weaviate" providerTypePinecone = "pinecone" providerTypeQdrant = "qdrant" + providerTypeMilvus = "milvus" ) type providerInitializer interface { @@ -29,6 +30,7 @@ var ( providerTypeES: &esProviderInitializer{}, providerTypePinecone: &pineconeProviderInitializer{}, providerTypeQdrant: &qdrantProviderInitializer{}, + providerTypeMilvus: &milvusProviderInitializer{}, } ) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go index cddea1ec2f..8e6e0a8979 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go @@ -79,7 +79,7 @@ func (d *qdrantProvider) UploadAnswerAndEmbedding( // "id": "76874cce-1fb9-4e16-9b0b-f085ac06ed6f", // "payload": { // "question": "这里是问题", - // "ansower": "这里是答案" + // "answer": "这里是答案" // }, // "vector": [ // 0.9, From 9603479a2690d5236eadeb5fe69da5a18ffeba58 Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 20:12:09 +0800 Subject: [PATCH 44/71] feat: custom threshold --- plugins/wasm-go/extensions/ai-cache/README.md | 3 +- .../extensions/ai-cache/config/config.go | 4 ++ plugins/wasm-go/extensions/ai-cache/core.go | 28 ++++++-------- .../extensions/ai-cache/vector/chroma.go | 6 --- .../extensions/ai-cache/vector/dashvector.go | 11 ------ .../ai-cache/vector/elasticsearch.go | 6 --- .../extensions/ai-cache/vector/milvus.go | 11 +----- .../extensions/ai-cache/vector/pinecone.go | 6 --- .../extensions/ai-cache/vector/provider.go | 37 ++++++++++++++++--- .../extensions/ai-cache/vector/qdrant.go | 6 --- .../extensions/ai-cache/vector/weaviate.go | 6 --- 11 files changed, 50 insertions(+), 74 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index fcd0faab05..1894a18993 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -50,7 +50,8 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | | vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 | | vector.collectionID | string | optional | "" | DashVector 向量存储服务 Collection ID | - +| vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 | +| vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大等于) | ## 文本向量化服务(embedding) | Name | Type | Requirement | Default | Description | diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index ac9a46bdc5..0a22b3c032 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -154,6 +154,10 @@ func (c *PluginConfig) GetVectorProvider() vector.Provider { return c.vectorProvider } +func (c *PluginConfig) GetVectorProviderConfig() vector.ProviderConfig { + return c.vectorProviderConfig +} + func (c *PluginConfig) GetCacheProvider() cache.Provider { return c.cacheProvider } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index b8b7c49f32..3ffa493e40 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -153,15 +153,9 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht mostSimilarData := results[0] log.Debugf("For key: %s, the most similar key found: %s with score: %f", key, mostSimilarData.Text, mostSimilarData.Score) - - simThresholdProvider, ok := config.GetVectorProvider().(vector.SimilarityThresholdProvider) - if !ok { - handleInternalError(nil, "Active vector provider does not implement SimilarityThresholdProvider interface", log) - return - } - - simThreshold := simThresholdProvider.GetSimilarityThreshold() - if mostSimilarData.Score < simThreshold { + simThreshold := config.GetVectorProviderConfig().Threshold + simThresholdRelation := config.GetVectorProviderConfig().ThresholdRelation + if compare(simThresholdRelation, mostSimilarData.Score, simThreshold) { log.Infof("Key accepted: %s with score: %f below threshold", mostSimilarData.Text, mostSimilarData.Score) if mostSimilarData.Answer != "" { // direct return the answer if available @@ -179,7 +173,7 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht proxywasm.ResumeHttpRequest() } } else { - log.Infof("Score too high for key: %s with score: %f above threshold", mostSimilarData.Text, mostSimilarData.Score) + log.Infof("Score not meet the threshold %f: %s with score %f", simThreshold, mostSimilarData.Text, mostSimilarData.Score) proxywasm.ResumeHttpRequest() } } @@ -252,20 +246,20 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, config config.PluginConfi } // 主要用于相似度/距离/点积判断 -// 相似度度量的是两个向量在方向上的相似程度。相似度越高,两个向量越接近。 +// 余弦相似度度量的是两个向量在方向上的相似程度。相似度越高,两个向量越接近。 // 距离度量的是两个向量在空间上的远近程度。距离越小,两个向量越接近。 // compare 函数根据操作符进行判断并返回结果 -func compare(operator string, value1 float64, value2 float64) (bool, error) { +func compare(operator string, value1 float64, value2 float64) bool { switch operator { case "gt": - return value1 > value2, nil + return value1 > value2 case "gte": - return value1 >= value2, nil + return value1 >= value2 case "lt": - return value1 < value2, nil + return value1 < value2 case "lte": - return value1 <= value2, nil + return value1 <= value2 default: - return false, errors.New("unsupported operator: " + operator) + return false } } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index 7004702c13..acf49bf4f1 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -11,8 +11,6 @@ import ( type chromaProviderInitializer struct{} -const chromaThreshold = 1000 - func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error { if len(config.collectionID) == 0 { return errors.New("[Chroma] collectionID is required") @@ -43,10 +41,6 @@ func (c *ChromaProvider) GetProviderType() string { return providerTypeChroma } -func (d *ChromaProvider) GetSimilarityThreshold() float64 { - return chromaThreshold -} - func (d *ChromaProvider) QueryEmbedding( emb []float64, ctx wrapper.HttpContext, diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 8b5848a119..ccafaa4caf 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -9,10 +9,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) -const ( - threshold = 10000 -) - type dashVectorProviderInitializer struct { } @@ -120,9 +116,6 @@ func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, err return queryResp, nil } -func (d *DvProvider) GetSimThreshold() float64 { - return threshold -} func (d *DvProvider) QueryEmbedding( emb []float64, ctx wrapper.HttpContext, @@ -242,10 +235,6 @@ func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx return err } -func (d *DvProvider) GetSimilarityThreshold() float64 { - return threshold -} - func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer) if err != nil { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go index 27884aa7ce..8b52852525 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go @@ -12,8 +12,6 @@ import ( type esProviderInitializer struct{} -const esThreshold = 1000 - func (c *esProviderInitializer) ValidateConfig(config ProviderConfig) error { if len(config.collectionID) == 0 { return errors.New("[ES] collectionID is required") @@ -44,10 +42,6 @@ func (c *ESProvider) GetProviderType() string { return providerTypeES } -func (d *ESProvider) GetSimilarityThreshold() float64 { - return esThreshold -} - func (d *ESProvider) QueryEmbedding( emb []float64, ctx wrapper.HttpContext, diff --git a/plugins/wasm-go/extensions/ai-cache/vector/milvus.go b/plugins/wasm-go/extensions/ai-cache/vector/milvus.go index a8ac265151..b2043e30a5 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/milvus.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/milvus.go @@ -12,8 +12,6 @@ import ( type milvusProviderInitializer struct{} -const milvusThreshold = 0.5 - func (c *milvusProviderInitializer) ValidateConfig(config ProviderConfig) error { if len(config.serviceName) == 0 { return errors.New("[Milvus] serviceName is required") @@ -44,12 +42,7 @@ func (c *milvusProvider) GetProviderType() string { return providerTypeMilvus } -func (d *milvusProvider) GetSimilarityThreshold() float64 { - return milvusThreshold -} - type milvusData struct { - ID int `json:"id"` Vector []float64 `json:"vector"` Question string `json:"question,omitempty"` Answer string `json:"answer,omitempty"` @@ -68,7 +61,8 @@ func (d *milvusProvider) UploadAnswerAndEmbedding( log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { // 最少需要填写的参数为 collectionName, data 和 Authorization. question, answer 可选 - // invalid syntax: invalid parameter[expected=Int64][actual=] + // 需要填写 id,否则 v2.4.13-hotfix 提示 invalid syntax: invalid parameter[expected=Int64][actual=] + // 如果不填写 id,一定要在创建 collection 的时候设置 autoId 为 true // 下面是一个例子 // { // "collectionName": "higress", @@ -88,7 +82,6 @@ func (d *milvusProvider) UploadAnswerAndEmbedding( CollectionName: d.config.collectionID, Data: []milvusData{ { - ID: 0, Question: queryString, Answer: queryAnswer, Vector: queryEmb, diff --git a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go index 8d21f7b64e..8acb8d7efc 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go @@ -13,8 +13,6 @@ import ( type pineconeProviderInitializer struct{} -const pineconeThreshold = 1000 - func (c *pineconeProviderInitializer) ValidateConfig(config ProviderConfig) error { if len(config.serviceDomain) == 0 { return errors.New("[Pinecone] serviceDomain is required") @@ -51,10 +49,6 @@ func (c *pineconeProvider) GetProviderType() string { return providerTypePinecone } -func (d *pineconeProvider) GetSimilarityThreshold() float64 { - return pineconeThreshold -} - type pineconeMetadata struct { Question string `json:"question"` Answer string `json:"answer"` diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 62875fc88c..63e344be22 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -81,10 +81,6 @@ type StringQuerier interface { callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error } -type SimilarityThresholdProvider interface { - GetSimilarityThreshold() float64 -} - type ProviderConfig struct { // @Title zh-CN 向量存储服务提供者类型 // @Description zh-CN 向量存储服务提供者类型,例如 DashVector、Milvus @@ -107,9 +103,18 @@ type ProviderConfig struct { // @Title zh-CN 请求超时 // @Description zh-CN 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 timeout uint32 - // @Title zh-CN DashVector 向量存储服务 Collection ID - // @Description zh-CN DashVector 向量存储服务 Collection ID + // @Title zh-CN 向量存储服务 Collection ID + // @Description zh-CN 向量存储服务的 Collection ID collectionID string + // @Title zh-CN 相似度度量阈值 + // @Description zh-CN 默认相似度度量阈值,默认为 1000。 + Threshold float64 + // @Title zh-CN 相似度度量比较方式 + // @Description zh-CN 相似度度量比较方式,默认为小于。 + // 相似度度量方式有 Cosine, DotProduct, Euclidean 等,前两者值越大相似度越高,后者值越小相似度越高。 + // 所以需要允许自定义比较方式,对于 Cosine 和 DotProduct 选择 gt,对于 Euclidean 则选择 lt。 + // 默认为 lt,所有条件包括 lt (less than,小于)、lte (less than or equal to,小等于)、gt (greater than,大于)、gte (greater than or equal to,大等于) + ThresholdRelation string // ES 配置 // @Title zh-CN ES 用户名 @@ -142,6 +147,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.timeout == 0 { c.timeout = 10000 } + c.Threshold = json.Get("threshold").Float() + if c.Threshold == 0 { + c.Threshold = 1000 + } + c.ThresholdRelation = json.Get("thresholdRelation").String() + if c.ThresholdRelation == "" { + c.ThresholdRelation = "lt" + } // ES c.esUsername = json.Get("esUsername").String() @@ -156,6 +169,9 @@ func (c *ProviderConfig) Validate() error { if !has { return errors.New("unknown vector database service provider type: " + c.typ) } + if !isRelationValid(c.ThresholdRelation) { + return errors.New("invalid thresholdRelation: " + c.ThresholdRelation) + } if err := initializer.ValidateConfig(*c); err != nil { return err } @@ -169,3 +185,12 @@ func CreateProvider(pc ProviderConfig) (Provider, error) { } return initializer.CreateProvider(pc) } + +func isRelationValid(relation string) bool { + for _, r := range []string{"lt", "lte", "gt", "gte"} { + if r == relation { + return true + } + } + return false +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go index 8e6e0a8979..187a80e07a 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go @@ -13,8 +13,6 @@ import ( type qdrantProviderInitializer struct{} -const qdrantThreshold = 50 - func (c *qdrantProviderInitializer) ValidateConfig(config ProviderConfig) error { if len(config.serviceName) == 0 { return errors.New("[Qdrant] serviceName is required") @@ -45,10 +43,6 @@ func (c *qdrantProvider) GetProviderType() string { return providerTypeQdrant } -func (d *qdrantProvider) GetSimilarityThreshold() float64 { - return qdrantThreshold -} - type qdrantPayload struct { Question string `json:"question"` Answer string `json:"answer"` diff --git a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go index 25b6908e9a..3289183471 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go @@ -12,8 +12,6 @@ import ( type weaviateProviderInitializer struct{} -const weaviateThreshold = 0.5 - func (c *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error { if len(config.collectionID) == 0 { return errors.New("[Weaviate] collectionID is required") @@ -44,10 +42,6 @@ func (c *WeaviateProvider) GetProviderType() string { return providerTypeWeaviate } -func (d *WeaviateProvider) GetSimilarityThreshold() float64 { - return weaviateThreshold -} - func (d *WeaviateProvider) QueryEmbedding( emb []float64, ctx wrapper.HttpContext, From 3d615cc31d954a7b9c8acc3f417e5bdfbf621edb Mon Sep 17 00:00:00 2001 From: async Date: Sat, 19 Oct 2024 20:18:55 +0800 Subject: [PATCH 45/71] feat: custom threshold --- plugins/wasm-go/extensions/ai-cache/README.md | 3 +- .../extensions/ai-cache/config/config.go | 4 +++ plugins/wasm-go/extensions/ai-cache/core.go | 33 +++++++++++++------ .../extensions/ai-cache/vector/dashvector.go | 11 ------- 4 files changed, 29 insertions(+), 22 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index fcd0faab05..1894a18993 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -50,7 +50,8 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | | vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 | | vector.collectionID | string | optional | "" | DashVector 向量存储服务 Collection ID | - +| vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 | +| vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大等于) | ## 文本向量化服务(embedding) | Name | Type | Requirement | Default | Description | diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index ac9a46bdc5..0a22b3c032 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -154,6 +154,10 @@ func (c *PluginConfig) GetVectorProvider() vector.Provider { return c.vectorProvider } +func (c *PluginConfig) GetVectorProviderConfig() vector.ProviderConfig { + return c.vectorProviderConfig +} + func (c *PluginConfig) GetCacheProvider() cache.Provider { return c.cacheProvider } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 9b679a6947..3ffa493e40 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -153,15 +153,9 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht mostSimilarData := results[0] log.Debugf("For key: %s, the most similar key found: %s with score: %f", key, mostSimilarData.Text, mostSimilarData.Score) - - simThresholdProvider, ok := config.GetVectorProvider().(vector.SimilarityThresholdProvider) - if !ok { - handleInternalError(nil, "Active vector provider does not implement SimilarityThresholdProvider interface", log) - return - } - - simThreshold := simThresholdProvider.GetSimilarityThreshold() - if mostSimilarData.Score < simThreshold { + simThreshold := config.GetVectorProviderConfig().Threshold + simThresholdRelation := config.GetVectorProviderConfig().ThresholdRelation + if compare(simThresholdRelation, mostSimilarData.Score, simThreshold) { log.Infof("Key accepted: %s with score: %f below threshold", mostSimilarData.Text, mostSimilarData.Score) if mostSimilarData.Answer != "" { // direct return the answer if available @@ -179,7 +173,7 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht proxywasm.ResumeHttpRequest() } } else { - log.Infof("Score too high for key: %s with score: %f above threshold", mostSimilarData.Text, mostSimilarData.Score) + log.Infof("Score not meet the threshold %f: %s with score %f", simThreshold, mostSimilarData.Text, mostSimilarData.Score) proxywasm.ResumeHttpRequest() } } @@ -250,3 +244,22 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, config config.PluginConfi } } } + +// 主要用于相似度/距离/点积判断 +// 余弦相似度度量的是两个向量在方向上的相似程度。相似度越高,两个向量越接近。 +// 距离度量的是两个向量在空间上的远近程度。距离越小,两个向量越接近。 +// compare 函数根据操作符进行判断并返回结果 +func compare(operator string, value1 float64, value2 float64) bool { + switch operator { + case "gt": + return value1 > value2 + case "gte": + return value1 >= value2 + case "lt": + return value1 < value2 + case "lte": + return value1 <= value2 + default: + return false + } +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 8b5848a119..ccafaa4caf 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -9,10 +9,6 @@ import ( "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) -const ( - threshold = 10000 -) - type dashVectorProviderInitializer struct { } @@ -120,9 +116,6 @@ func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, err return queryResp, nil } -func (d *DvProvider) GetSimThreshold() float64 { - return threshold -} func (d *DvProvider) QueryEmbedding( emb []float64, ctx wrapper.HttpContext, @@ -242,10 +235,6 @@ func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx return err } -func (d *DvProvider) GetSimilarityThreshold() float64 { - return threshold -} - func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer) if err != nil { From 558e75e10b55bfb4a03ca4c03fca163091269e05 Mon Sep 17 00:00:00 2001 From: async Date: Sun, 20 Oct 2024 12:23:13 +0800 Subject: [PATCH 46/71] fix: code format --- .../extensions/ai-cache/vector/chroma.go | 25 +++++++++---------- .../extensions/ai-cache/vector/dashvector.go | 19 +------------- .../ai-cache/vector/elasticsearch.go | 9 +++---- .../extensions/ai-cache/vector/milvus.go | 6 ++--- .../extensions/ai-cache/vector/pinecone.go | 4 +-- .../extensions/ai-cache/vector/qdrant.go | 4 +-- .../extensions/ai-cache/vector/weaviate.go | 4 +-- 7 files changed, 26 insertions(+), 45 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index acf49bf4f1..43ff80013a 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -73,14 +73,14 @@ func (d *ChromaProvider) QueryEmbedding( return err } - d.client.Post( + return d.client.Post( fmt.Sprintf("/api/v1/collections/%s/query", d.config.collectionID), [][2]string{ {"Content-Type", "application/json"}, }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Chroma] Query embedding response: %d, %s", statusCode, responseBody) + log.Debugf("[Chroma] Query embedding response: %d, %s", statusCode, responseBody) results, err := d.parseQueryResponse(responseBody, log) if err != nil { err = fmt.Errorf("[Chroma] Failed to parse query response: %v", err) @@ -89,7 +89,6 @@ func (d *ChromaProvider) QueryEmbedding( }, d.config.timeout, ) - return errors.New("[Chroma] QueryEmbedding Not implemented") } func (d *ChromaProvider) UploadAnswerAndEmbedding( @@ -142,7 +141,7 @@ func (d *ChromaProvider) UploadAnswerAndEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + log.Debugf("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) callback(ctx, log, err) }, d.config.timeout, @@ -154,14 +153,14 @@ type chromaEmbedding []float64 type chromaMetadataMap map[string]string type chromaInsertRequest struct { Embeddings []chromaEmbedding `json:"embeddings"` - Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // Optional metadata map array - Documents []string `json:"documents,omitempty"` // Optional document array + Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // 可选参数 + Documents []string `json:"documents,omitempty"` // 可选参数 IDs []string `json:"ids"` } type chromaQueryRequest struct { - Where map[string]string `json:"where,omitempty"` // Optional where filter - WhereDocument map[string]string `json:"where_document,omitempty"` // Optional where_document filter + Where map[string]string `json:"where,omitempty"` // 可选参数 + WhereDocument map[string]string `json:"where_document,omitempty"` // 可选参数 QueryEmbeddings []chromaEmbedding `json:"query_embeddings"` Limit int `json:"limit"` Include []string `json:"include"` @@ -170,11 +169,11 @@ type chromaQueryRequest struct { type chromaQueryResponse struct { Ids [][]string `json:"ids"` // 第一维是 batch query,第二维是查询到的多个 ids Distances [][]float64 `json:"distances,omitempty"` // 与 Ids 一一对应 - Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // Optional, can be null - Embeddings []chromaEmbedding `json:"embeddings,omitempty"` // Optional, can be null + Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // 可选参数 + Embeddings []chromaEmbedding `json:"embeddings,omitempty"` // 可选参数 Documents [][]string `json:"documents,omitempty"` // 与 Ids 一一对应 - Uris []string `json:"uris,omitempty"` // Optional, can be null - Data []interface{} `json:"data,omitempty"` // Optional, can be null + Uris []string `json:"uris,omitempty"` // 可选参数 + Data []interface{} `json:"data,omitempty"` // 可选参数 Included []string `json:"included"` } @@ -185,7 +184,7 @@ func (d *ChromaProvider) parseQueryResponse(responseBody []byte, log wrapper.Log return nil, err } - log.Infof("[Chroma] queryResp Ids len: %d", len(queryResp.Ids)) + log.Debugf("[Chroma] queryResp Ids len: %d", len(queryResp.Ids)) if len(queryResp.Ids) == 1 && len(queryResp.Ids[0]) == 0 { return nil, errors.New("no query results found in response") } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index ccafaa4caf..493bd2ab04 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -48,21 +48,6 @@ func (d *DvProvider) GetProviderType() string { return providerTypeDashVector } -// type embeddingRequest struct { -// Model string `json:"model"` -// Input input `json:"input"` -// Parameters params `json:"parameters"` -// } - -// type params struct { -// TextType string `json:"text_type"` -// } - -// type input struct { -// Texts []string `json:"texts"` -// } - -// queryResponse 定义查询响应的结构 type queryResponse struct { Code int `json:"code"` RequestID string `json:"request_id"` @@ -70,17 +55,15 @@ type queryResponse struct { Output []result `json:"output"` } -// queryRequest 定义查询请求的结构 type queryRequest struct { Vector []float64 `json:"vector"` TopK int `json:"topk"` IncludeVector bool `json:"include_vector"` } -// result 定义查询结果的结构 type result struct { ID string `json:"id"` - Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化 + Vector []float64 `json:"vector,omitempty"` // 如果 vector 是空,vecotr 字段将不会被序列化 Fields map[string]interface{} `json:"fields"` Score float64 `json:"score"` } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go index 8b52852525..1e0d19d4a3 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go @@ -71,7 +71,7 @@ func (d *ESProvider) QueryEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[ES] Query embedding response: %d, %s", statusCode, responseBody) + log.Debugf("[ES] Query embedding response: %d, %s", statusCode, responseBody) results, err := d.parseQueryResponse(responseBody, log) if err != nil { err = fmt.Errorf("[ES] Failed to parse query response: %v", err) @@ -82,7 +82,7 @@ func (d *ESProvider) QueryEmbedding( ) } -// 编码 ES 身份认证字符串 +// base64 编码 ES 身份认证字符串 func (d *ESProvider) getCredentials() string { credentials := fmt.Sprintf("%s:%s", d.config.esUsername, d.config.esPassword) encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials)) @@ -125,7 +125,7 @@ func (d *ESProvider) UploadAnswerAndEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[ES] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + log.Debugf("[ES] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) callback(ctx, log, err) }, d.config.timeout, @@ -154,7 +154,6 @@ type esQueryRequest struct { Size int `json:"size"` } -// esQueryResponse represents the search result structure. type esQueryResponse struct { Took int `json:"took"` TimedOut bool `json:"timed_out"` @@ -179,7 +178,7 @@ func (d *ESProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([ if err != nil { return []QueryResult{}, err } - log.Infof("[ES] queryResp Hits len: %d", len(queryResp.Hits.Hits)) + log.Debugf("[ES] queryResp Hits len: %d", len(queryResp.Hits.Hits)) if len(queryResp.Hits.Hits) == 0 { return nil, errors.New("no query results found in response") } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/milvus.go b/plugins/wasm-go/extensions/ai-cache/vector/milvus.go index b2043e30a5..f057c33b88 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/milvus.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/milvus.go @@ -62,7 +62,7 @@ func (d *milvusProvider) UploadAnswerAndEmbedding( callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { // 最少需要填写的参数为 collectionName, data 和 Authorization. question, answer 可选 // 需要填写 id,否则 v2.4.13-hotfix 提示 invalid syntax: invalid parameter[expected=Int64][actual=] - // 如果不填写 id,一定要在创建 collection 的时候设置 autoId 为 true + // 如果不填写 id,要在创建 collection 的时候设置 autoId 为 true // 下面是一个例子 // { // "collectionName": "higress", @@ -102,7 +102,7 @@ func (d *milvusProvider) UploadAnswerAndEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Milvus] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + log.Debugf("[Milvus] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) callback(ctx, log, err) }, d.config.timeout, @@ -164,7 +164,7 @@ func (d *milvusProvider) QueryEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Milvus] Query embedding response: %d, %s", statusCode, responseBody) + log.Debugf("[Milvus] Query embedding response: %d, %s", statusCode, responseBody) results, err := d.parseQueryResponse(responseBody, log) if err != nil { err = fmt.Errorf("[Milvus] Failed to parse query response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go index 8acb8d7efc..51b561c66a 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go @@ -107,7 +107,7 @@ func (d *pineconeProvider) UploadAnswerAndEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Pinecone] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + log.Debugf("[Pinecone] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) callback(ctx, log, err) }, d.config.timeout, @@ -155,7 +155,7 @@ func (d *pineconeProvider) QueryEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Pinecone] Query embedding response: %d, %s", statusCode, responseBody) + log.Debugf("[Pinecone] Query embedding response: %d, %s", statusCode, responseBody) results, err := d.parseQueryResponse(responseBody, log) if err != nil { err = fmt.Errorf("[Pinecone] Failed to parse query response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go index 187a80e07a..8f402373c0 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go @@ -106,7 +106,7 @@ func (d *qdrantProvider) UploadAnswerAndEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Qdrant] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + log.Debugf("[Qdrant] statusCode:%d, responseBody:%s", statusCode, string(responseBody)) callback(ctx, log, err) }, d.config.timeout, @@ -153,7 +153,7 @@ func (d *qdrantProvider) QueryEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Qdrant] Query embedding response: %d, %s", statusCode, responseBody) + log.Debugf("[Qdrant] Query embedding response: %d, %s", statusCode, responseBody) results, err := d.parseQueryResponse(responseBody, log) if err != nil { err = fmt.Errorf("[Qdrant] Failed to parse query response: %v", err) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go index 3289183471..61fd355f53 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go @@ -91,7 +91,7 @@ func (d *WeaviateProvider) QueryEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Weaviate] Query embedding response: %d, %s", statusCode, responseBody) + log.Debugf("[Weaviate] Query embedding response: %d, %s", statusCode, responseBody) results, err := d.parseQueryResponse(responseBody, log) if err != nil { err = fmt.Errorf("[Weaviate] Failed to parse query response: %v", err) @@ -131,7 +131,7 @@ func (d *WeaviateProvider) UploadAnswerAndEmbedding( }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - log.Infof("[Weaviate] statusCode: %d, responseBody: %s", statusCode, string(responseBody)) + log.Debugf("[Weaviate] statusCode: %d, responseBody: %s", statusCode, string(responseBody)) callback(ctx, log, err) }, d.config.timeout, From 2cfcda6f0100d5211b5f53cb3db2da8ee9f5f5bb Mon Sep 17 00:00:00 2001 From: suchun Date: Sun, 20 Oct 2024 13:48:10 +0000 Subject: [PATCH 47/71] add ai cache test --- .../e2e/conformance/tests/go-wasm-ai-cache.go | 115 ++++++++++++++++++ .../conformance/tests/go-wasm-ai-cache.yaml | 102 ++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 test/e2e/conformance/tests/go-wasm-ai-cache.go create mode 100644 test/e2e/conformance/tests/go-wasm-ai-cache.yaml diff --git a/test/e2e/conformance/tests/go-wasm-ai-cache.go b/test/e2e/conformance/tests/go-wasm-ai-cache.go new file mode 100644 index 0000000000..dc71a6efd4 --- /dev/null +++ b/test/e2e/conformance/tests/go-wasm-ai-cache.go @@ -0,0 +1,115 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "github.com/alibaba/higress/test/e2e/conformance/utils/http" + "github.com/alibaba/higress/test/e2e/conformance/utils/suite" +) + +func init() { + Register(WasmPluginsAiCache) +} + +var WasmPluginsAiCache = suite.ConformanceTest{ + ShortName: "WasmPluginAiCache", + Description: "The Ingress in the higress-conformance-infra namespace test the ai-cache WASM plugin.", + Features: []suite.SupportedFeature{suite.WASMGoConformanceFeature}, + Manifests: []string{"tests/go-wasm-ai-cache.yaml"}, + Test: func(t *testing.T, suite *suite.ConformanceTestSuite) { + testcases := []http.Assertion{ + { + Meta: http.AssertionMeta{ + TestCaseName: "case 1: openai", + TargetBackend: "infra-backend-v1", + TargetNamespace: "higress-conformance-infra", + }, + Request: http.AssertionRequest{ + ActualRequest: http.Request{ + Host: "openai.ai.com", + Path: "/v1/chat/completions", + Method:"POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "gpt-3", + "messages": [{"role":"user","content":"hi"}]}`), + }, + ExpectedRequest: &http.ExpectedRequest{ + Request: http.Request{ + Host: "api.openai.com", + Path: "/v1/chat/completions", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "gpt-3", + "messages": [{"role":"user","content":"hi"}], + "max_tokens": 123, + "temperature": 0.66}`), + }, + }, + }, + Response: http.AssertionResponse{ + ExpectedResponse: http.Response{ + StatusCode: 200, + }, + }, + }, + { + Meta: http.AssertionMeta{ + TestCaseName: "case 2: qwen", + TargetBackend: "infra-backend-v1", + TargetNamespace: "higress-conformance-infra", + }, + Request: http.AssertionRequest{ + ActualRequest: http.Request{ + Host: "qwen.ai.com", + Path: "/v1/chat/completions", + Method:"POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "qwen-long", + "input": {"messages": [{"role":"user","content":"hi"}]}, + "parameters": {"max_tokens": 321, "temperature": 0.7}}`), + }, + ExpectedRequest: &http.ExpectedRequest{ + Request: http.Request{ + Host: "dashscope.aliyuncs.com", + Path: "/api/v1/services/aigc/text-generation/generation", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "qwen-long", + "input": {"messages": [{"role":"user","content":"hi"}]}, + "parameters": {"max_tokens": 321, "temperature": 0.66}}`), + }, + }, + }, + Response: http.AssertionResponse{ + ExpectedResponse: http.Response{ + StatusCode: 500, + }, + }, + }, + + } + t.Run("WasmPlugins ai-cache", func(t *testing.T) { + for _, testcase := range testcases { + http.MakeRequestAndExpectEventuallyConsistentResponse(t, suite.RoundTripper, suite.TimeoutConfig, suite.GatewayAddress, testcase) + } + }) + }, +} diff --git a/test/e2e/conformance/tests/go-wasm-ai-cache.yaml b/test/e2e/conformance/tests/go-wasm-ai-cache.yaml new file mode 100644 index 0000000000..53db5d3e8b --- /dev/null +++ b/test/e2e/conformance/tests/go-wasm-ai-cache.yaml @@ -0,0 +1,102 @@ +# Copyright (c) 2022 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + name: wasmplugin-ai-cache-openai + namespace: higress-conformance-infra +spec: + ingressClassName: higress + rules: + - host: "openai.ai.com" + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: infra-backend-v1 + port: + number: 8080 +--- +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + name: wasmplugin-ai-cache-qwen + namespace: higress-conformance-infra +spec: + ingressClassName: higress + rules: + - host: "qwen.ai.com" + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: infra-backend-v1 + port: + number: 8080 +--- +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-cache + namespace: higress-system +spec: + priority: 400 + matchRules: + - config: + embedding: + type: "dashscope" + serviceName: "qwen" + apiKey: "{{Github.QwenApiKey}}" + timeout: 12000 + vector: + type: "dashvector" + serviceName: "dashvector" + collectionID: "{{Github.DashVectorCollectionID}}" + serviceDomain: "{{Github.DashVectorServiceDomain}}" + apiKey: "{{Github.DashVectorApiKey}}" + timeout: 12000 + cache: + + ingress: + - higress-conformance-infra/wasmplugin-ai-cache-openai + - higress-conformance-infra/wasmplugin-ai-cache-qwen + url: file:///opt/plugins/wasm-go/extensions/ai-cache/plugin.wasm +--- +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-proxy + namespace: higress-system +spec: + priority: 201 + matchRules: + - config: + provider: + type: "qwen" + qwenEnableCompatible: true + apiTokens: + - "{{Github.QwenKey}}" + timeout: 1200000 + modelMapping: + "*": "qwen-long" + ingress: + - higress-conformance-infra/wasmplugin-ai-cache-openai + - higress-conformance-infra/wasmplugin-ai-cache-qwen + url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/ai-proxy:1.0.0 From 4caf9be83baba360ff94458a86e750e27f753bfb Mon Sep 17 00:00:00 2001 From: suchun Date: Sun, 20 Oct 2024 18:46:10 +0000 Subject: [PATCH 48/71] update test --- .../e2e/conformance/tests/go-wasm-ai-cache.go | 59 ++++--------------- .../conformance/tests/go-wasm-ai-cache.yaml | 15 ++--- 2 files changed, 18 insertions(+), 56 deletions(-) diff --git a/test/e2e/conformance/tests/go-wasm-ai-cache.go b/test/e2e/conformance/tests/go-wasm-ai-cache.go index dc71a6efd4..30ac248916 100644 --- a/test/e2e/conformance/tests/go-wasm-ai-cache.go +++ b/test/e2e/conformance/tests/go-wasm-ai-cache.go @@ -34,77 +34,38 @@ var WasmPluginsAiCache = suite.ConformanceTest{ testcases := []http.Assertion{ { Meta: http.AssertionMeta{ - TestCaseName: "case 1: openai", + TestCaseName: "case 1: basic", TargetBackend: "infra-backend-v1", TargetNamespace: "higress-conformance-infra", }, Request: http.AssertionRequest{ ActualRequest: http.Request{ - Host: "openai.ai.com", - Path: "/v1/chat/completions", - Method:"POST", - ContentType: http.ContentTypeApplicationJson, - Body: []byte(`{ - "model": "gpt-3", - "messages": [{"role":"user","content":"hi"}]}`), - }, - ExpectedRequest: &http.ExpectedRequest{ - Request: http.Request{ - Host: "api.openai.com", - Path: "/v1/chat/completions", - Method: "POST", - ContentType: http.ContentTypeApplicationJson, - Body: []byte(`{ - "model": "gpt-3", - "messages": [{"role":"user","content":"hi"}], - "max_tokens": 123, - "temperature": 0.66}`), - }, - }, - }, - Response: http.AssertionResponse{ - ExpectedResponse: http.Response{ - StatusCode: 200, - }, - }, - }, - { - Meta: http.AssertionMeta{ - TestCaseName: "case 2: qwen", - TargetBackend: "infra-backend-v1", - TargetNamespace: "higress-conformance-infra", - }, - Request: http.AssertionRequest{ - ActualRequest: http.Request{ - Host: "qwen.ai.com", - Path: "/v1/chat/completions", - Method:"POST", - ContentType: http.ContentTypeApplicationJson, + Host: "dashscope.aliyuncs.com", + Path: "/v1/chat/completions", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, Body: []byte(`{ "model": "qwen-long", - "input": {"messages": [{"role":"user","content":"hi"}]}, - "parameters": {"max_tokens": 321, "temperature": 0.7}}`), + "messages": [{"role":"user","content":"hi"}]}`), }, ExpectedRequest: &http.ExpectedRequest{ Request: http.Request{ Host: "dashscope.aliyuncs.com", - Path: "/api/v1/services/aigc/text-generation/generation", + Path: "/compatible-mode/v1/chat/completions", Method: "POST", ContentType: http.ContentTypeApplicationJson, Body: []byte(`{ - "model": "qwen-long", - "input": {"messages": [{"role":"user","content":"hi"}]}, - "parameters": {"max_tokens": 321, "temperature": 0.66}}`), + "model": "qwen-long", + "messages": [{"role":"user","content":"hi"}]}`), }, }, }, Response: http.AssertionResponse{ ExpectedResponse: http.Response{ - StatusCode: 500, + StatusCode: 200, }, }, }, - } t.Run("WasmPlugins ai-cache", func(t *testing.T) { for _, testcase := range testcases { diff --git a/test/e2e/conformance/tests/go-wasm-ai-cache.yaml b/test/e2e/conformance/tests/go-wasm-ai-cache.yaml index 53db5d3e8b..c7d6b0c46b 100644 --- a/test/e2e/conformance/tests/go-wasm-ai-cache.yaml +++ b/test/e2e/conformance/tests/go-wasm-ai-cache.yaml @@ -20,7 +20,7 @@ metadata: spec: ingressClassName: higress rules: - - host: "openai.ai.com" + - host: "dashscope.aliyuncs.com" http: paths: - pathType: Prefix @@ -63,21 +63,22 @@ spec: embedding: type: "dashscope" serviceName: "qwen" - apiKey: "{{Github.QwenApiKey}}" + apiKey: "{{secret.qwenApiKey}}" timeout: 12000 vector: type: "dashvector" serviceName: "dashvector" - collectionID: "{{Github.DashVectorCollectionID}}" - serviceDomain: "{{Github.DashVectorServiceDomain}}" - apiKey: "{{Github.DashVectorApiKey}}" + collectionID: "{{secret.collectionID}}" + serviceDomain: "{{secret.serviceDomain}}" + apiKey: "{{secret.apiKey}}" timeout: 12000 cache: ingress: - higress-conformance-infra/wasmplugin-ai-cache-openai - higress-conformance-infra/wasmplugin-ai-cache-qwen - url: file:///opt/plugins/wasm-go/extensions/ai-cache/plugin.wasm + # url: file:///opt/plugins/wasm-go/extensions/ai-cache/plugin.wasm + url: oci://registry.cn-shanghai.aliyuncs.com/suchunsv/higress_ai:1.18 --- apiVersion: extensions.higress.io/v1alpha1 kind: WasmPlugin @@ -92,7 +93,7 @@ spec: type: "qwen" qwenEnableCompatible: true apiTokens: - - "{{Github.QwenKey}}" + - "{{secret.qwenApiKey}}" timeout: 1200000 modelMapping: "*": "qwen-long" From d04d78a63453d6cc5b80da2d3fc86036484658aa Mon Sep 17 00:00:00 2001 From: suchun Date: Thu, 24 Oct 2024 00:08:45 +0000 Subject: [PATCH 49/71] fix bugs --- plugins/wasm-go/extensions/ai-cache/README.md | 2 +- .../extensions/ai-cache/config/config.go | 12 +++- plugins/wasm-go/extensions/ai-cache/core.go | 60 +++++++++---------- plugins/wasm-go/extensions/ai-cache/main.go | 40 ++++++------- plugins/wasm-go/extensions/ai-cache/util.go | 29 +++++---- 5 files changed, 78 insertions(+), 65 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 1894a18993..a788bc9257 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -34,7 +34,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | vector.type | string | optional | "" | 向量存储服务提供者类型,例如 DashVector | | embedding.type | string | optional | "" | 请求文本向量化服务类型,例如 DashScope | | cache.type | string | optional | "" | 缓存服务类型,例如 redis | -| cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disable" (禁用缓存) | +| cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) | | enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用字符串匹配的方式来查找缓存,此时需要配置cache服务 | 以下是vector、embedding、cache的具体配置说明,注意若不配置embedding或cache服务,则可忽略以下相应配置中的 `required` 字段。 diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 0a22b3c032..87e27452db 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -10,6 +10,12 @@ import ( "github.com/tidwall/gjson" ) +const ( + CACHE_KEY_STRATEGY_LAST_QUESTION = "lastQuestion" + CACHE_KEY_STRATEGY_ALL_QUESTIONS = "allQuestions" + CACHE_KEY_STRATEGY_DISABLED = "disabled" +) + type PluginConfig struct { // @Title zh-CN 返回 HTTP 响应的模版 // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 @@ -36,7 +42,7 @@ type PluginConfig struct { EnableSemanticCache bool // @Title zh-CN 缓存键策略 - // @Description zh-CN 决定如何生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disable" (禁用缓存) + // @Description zh-CN 决定如何生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) CacheKeyStrategy string } @@ -107,7 +113,9 @@ func (c *PluginConfig) Validate() error { } // 验证 CacheKeyStrategy 的值 - if c.CacheKeyStrategy != "lastQuestion" && c.CacheKeyStrategy != "allQuestions" && c.CacheKeyStrategy != "disable" { + if c.CacheKeyStrategy != CACHE_KEY_STRATEGY_LAST_QUESTION && + c.CacheKeyStrategy != CACHE_KEY_STRATEGY_ALL_QUESTIONS && + c.CacheKeyStrategy != CACHE_KEY_STRATEGY_DISABLED { return fmt.Errorf("invalid CacheKeyStrategy: %s", c.CacheKeyStrategy) } // 如果启用了语义化缓存,确保必要的组件已配置 diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 3ffa493e40..d3e9c81b5f 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -12,18 +12,18 @@ import ( ) // CheckCacheForKey checks if the key is in the cache, or triggers similarity search if not found. -func CheckCacheForKey(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) error { - activeCacheProvider := config.GetCacheProvider() +func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) error { + activeCacheProvider := c.GetCacheProvider() if activeCacheProvider == nil { log.Debug("No cache provider configured, performing similarity search") - return performSimilaritySearch(key, ctx, config, log, key, stream) + return performSimilaritySearch(key, ctx, c, log, key, stream) } queryKey := activeCacheProvider.GetCacheKeyPrefix() + key log.Debugf("Querying cache with key: %s", queryKey) err := activeCacheProvider.Get(queryKey, func(response resp.Value) { - handleCacheResponse(key, response, ctx, log, stream, config, useSimilaritySearch) + handleCacheResponse(key, response, ctx, log, stream, c, useSimilaritySearch) }) if err != nil { @@ -35,10 +35,10 @@ func CheckCacheForKey(key string, ctx wrapper.HttpContext, config config.PluginC } // handleCacheResponse processes cache response and handles cache hits and misses. -func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log wrapper.Log, stream bool, config config.PluginConfig, useSimilaritySearch bool) { +func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, useSimilaritySearch bool) { if err := response.Error(); err == nil && !response.IsNull() { log.Infof("Cache hit for key: %s", key) - processCacheHit(key, response.String(), stream, ctx, config, log) + processCacheHit(key, response.String(), stream, ctx, c, log) return } @@ -47,8 +47,8 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex log.Errorf("Error retrieving key: %s from cache, error: %v", key, err) } - if useSimilaritySearch && config.EnableSemanticCache { - if err := performSimilaritySearch(key, ctx, config, log, key, stream); err != nil { + if useSimilaritySearch && c.EnableSemanticCache { + if err := performSimilaritySearch(key, ctx, c, log, key, stream); err != nil { log.Errorf("Failed to perform similarity search for key: %s, error: %v", key, err) proxywasm.ResumeHttpRequest() } @@ -58,7 +58,7 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex } // processCacheHit handles a successful cache hit. -func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) { +func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) { if stream { log.Debug("streaming response is not supported for cache hit yet") stream = false @@ -75,49 +75,49 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) // proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1) - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, response)), -1) + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, response)), -1) } // performSimilaritySearch determines the appropriate similarity search method to use. -func performSimilaritySearch(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, queryString string, stream bool) error { - activeVectorProvider := config.GetVectorProvider() +func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, queryString string, stream bool) error { + activeVectorProvider := c.GetVectorProvider() if activeVectorProvider == nil { return errors.New("no vector provider configured for similarity search") } // Check if the active vector provider implements the StringQuerier interface. if _, ok := activeVectorProvider.(vector.StringQuerier); ok { - return performStringQuery(key, queryString, ctx, config, log, stream) + return performStringQuery(key, queryString, ctx, c, log, stream) } // Check if the active vector provider implements the EmbeddingQuerier interface. if _, ok := activeVectorProvider.(vector.EmbeddingQuerier); ok { - return performEmbeddingQuery(key, ctx, config, log, stream) + return performEmbeddingQuery(key, ctx, c, log, stream) } return errors.New("no suitable querier or embedding provider available for similarity search") } // performStringQuery executes the string-based similarity search. -func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) error { - stringQuerier, ok := config.GetVectorProvider().(vector.StringQuerier) +func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error { + stringQuerier, ok := c.GetVectorProvider().(vector.StringQuerier) if !ok { return logAndReturnError(log, "active vector provider does not implement StringQuerier interface") } return stringQuerier.QueryString(queryString, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { - handleQueryResults(key, results, ctx, log, stream, config, err) + handleQueryResults(key, results, ctx, log, stream, c, err) }) } // performEmbeddingQuery executes the embedding-based similarity search. -func performEmbeddingQuery(key string, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, stream bool) error { - embeddingQuerier, ok := config.GetVectorProvider().(vector.EmbeddingQuerier) +func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error { + embeddingQuerier, ok := c.GetVectorProvider().(vector.EmbeddingQuerier) if !ok { return logAndReturnError(log, "active vector provider does not implement EmbeddingQuerier interface") } - activeEmbeddingProvider := config.GetEmbeddingProvider() + activeEmbeddingProvider := c.GetEmbeddingProvider() if activeEmbeddingProvider == nil { return logAndReturnError(log, "no embedding provider configured for similarity search") } @@ -130,7 +130,7 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, config config.Pl ctx.SetContext(CACHE_KEY_EMBEDDING_KEY, textEmbedding) err = embeddingQuerier.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { - handleQueryResults(key, results, ctx, log, stream, config, err) + handleQueryResults(key, results, ctx, log, stream, c, err) }) if err != nil { handleInternalError(err, fmt.Sprintf("Error querying vector database for key: %s", key), log) @@ -139,7 +139,7 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, config config.Pl } // handleQueryResults processes the results of similarity search and determines next actions. -func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, stream bool, config config.PluginConfig, err error) { +func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, err error) { if err != nil { handleInternalError(err, fmt.Sprintf("Error querying vector database for key: %s", key), log) return @@ -153,13 +153,13 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht mostSimilarData := results[0] log.Debugf("For key: %s, the most similar key found: %s with score: %f", key, mostSimilarData.Text, mostSimilarData.Score) - simThreshold := config.GetVectorProviderConfig().Threshold - simThresholdRelation := config.GetVectorProviderConfig().ThresholdRelation + simThreshold := c.GetVectorProviderConfig().Threshold + simThresholdRelation := c.GetVectorProviderConfig().ThresholdRelation if compare(simThresholdRelation, mostSimilarData.Score, simThreshold) { log.Infof("Key accepted: %s with score: %f below threshold", mostSimilarData.Text, mostSimilarData.Score) if mostSimilarData.Answer != "" { // direct return the answer if available - processCacheHit(key, mostSimilarData.Answer, stream, ctx, config, log) + processCacheHit(key, mostSimilarData.Answer, stream, ctx, c, log) } else { // // otherwise, continue to check cache for the most similar key // err = CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) @@ -169,7 +169,7 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht // } // Otherwise, do not check the cache, directly return - log.Warnf("No cache hit for key: %s, however, no answer found in vector database", mostSimilarData.Text) + log.Warnf("Cache hit for key: %s, but no corresponding answer found in the vector database", mostSimilarData.Text) proxywasm.ResumeHttpRequest() } } else { @@ -196,8 +196,8 @@ func handleInternalError(err error, message string, log wrapper.Log) { } // Caches the response value -func cacheResponse(ctx wrapper.HttpContext, config config.PluginConfig, key string, value string, log wrapper.Log) { - activeCacheProvider := config.GetCacheProvider() +func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) { + activeCacheProvider := c.GetCacheProvider() if activeCacheProvider != nil { queryKey := activeCacheProvider.GetCacheKeyPrefix() + key log.Infof("[onHttpResponseBody] setting cache to redis, key: %s, value: %s", queryKey, value) @@ -206,7 +206,7 @@ func cacheResponse(ctx wrapper.HttpContext, config config.PluginConfig, key stri } // Handles embedding upload if available -func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, config config.PluginConfig, key string, value string, log wrapper.Log) { +func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) { embedding := ctx.GetContext(CACHE_KEY_EMBEDDING_KEY) if embedding == nil { return @@ -218,7 +218,7 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, config config.PluginConfi return } - activeVectorProvider := config.GetVectorProvider() + activeVectorProvider := c.GetVectorProvider() if activeVectorProvider == nil { log.Debug("[onHttpResponseBody] no vector provider configured for uploading embedding") return diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 5025ab0e70..eea6bb05ee 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -34,23 +34,23 @@ func main() { ) } -func parseConfig(json gjson.Result, config *config.PluginConfig, log wrapper.Log) error { +func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) error { // config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) // config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) // config.RedisConfig.FromJson(json.Get("redis")) - config.FromJson(json) - if err := config.Validate(); err != nil { + c.FromJson(json) + if err := c.Validate(); err != nil { return err } // 注意,在 parseConfig 阶段初始化 client 会出错,比如 docker compose 中的 redis 就无法使用 - if err := config.Complete(log); err != nil { + if err := c.Complete(log); err != nil { log.Errorf("complete config failed: %v", err) return err } return nil } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) types.Action { +func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { // 这段代码是为了测试,在 parseConfig 阶段初始化 client 会出错,比如 docker compose 中的 redis 就无法使用 // 但是在 onHttpRequestHeaders 中可以连接到 redis、 // 修复需要修改 envoy @@ -82,7 +82,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config config.PluginConfig, l return types.HeaderStopIteration } -func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body []byte, log wrapper.Log) types.Action { +func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []byte, log wrapper.Log) types.Action { bodyJson := gjson.ParseBytes(body) // TODO: It may be necessary to support stream mode determination for different LLM providers. @@ -93,9 +93,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body } var key string - if config.CacheKeyStrategy == "lastQuestion" { + if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_LAST_QUESTION { key = bodyJson.Get("messages.@reverse.0.content").String() - } else if config.CacheKeyStrategy == "allQuestions" { + } else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_ALL_QUESTIONS { // Retrieve all user messages and concatenate them messages := bodyJson.Get("messages").Array() var userMessages []string @@ -104,13 +104,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body userMessages = append(userMessages, msg.Get("content").String()) } } - key = strings.Join(userMessages, " ") - } else if config.CacheKeyStrategy == "disable" { + key = strings.Join(userMessages, "\n") + } else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_DISABLED { log.Debugf("[onHttpRequestBody] cache key strategy is disabled") ctx.DontReadRequestBody() return types.ActionContinue } else { - log.Warnf("[onHttpRequestBody] unknown cache key strategy: %s", config.CacheKeyStrategy) + log.Warnf("[onHttpRequestBody] unknown cache key strategy: %s", c.CacheKeyStrategy) ctx.DontReadRequestBody() return types.ActionContinue } @@ -123,7 +123,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body return types.ActionContinue } - if err := CheckCacheForKey(key, ctx, config, log, stream, true); err != nil { + if err := CheckCacheForKey(key, ctx, c, log, stream, true); err != nil { log.Errorf("check cache for key: %s failed, error: %v", key, err) return types.ActionContinue } @@ -131,7 +131,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config config.PluginConfig, body return types.ActionPause } -func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) types.Action { +func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) @@ -139,8 +139,8 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig, return types.ActionContinue } -func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - log.Debugf("[onHttpResponseBody] escaped chunk: %q", string(chunk)) +func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { + log.Debugf("[onHttpResponseBody] chunk: %q", string(chunk)) log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk) if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { @@ -154,7 +154,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu } if !isLastChunk { - handlePartialChunk(ctx, config, chunk, log) + handlePartialChunk(ctx, c, chunk, log) return chunk } @@ -163,9 +163,9 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu var err error if len(chunk) > 0 { - value, err = processNonEmptyChunk(ctx, config, chunk, log) + value, err = processNonEmptyChunk(ctx, c, chunk, log) } else { - value, err = processEmptyChunk(ctx, config, chunk, log) + value, err = processEmptyChunk(ctx, c, chunk, log) } if err != nil { @@ -173,10 +173,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu return chunk } // Cache the final value - cacheResponse(ctx, config, key.(string), value, log) + cacheResponse(ctx, c, key.(string), value, log) // Handle embedding upload if available - uploadEmbeddingAndAnswer(ctx, config, key.(string), value, log) + uploadEmbeddingAndAnswer(ctx, c, key.(string), value, log) return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 64f3d6dce9..590d5213d4 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -9,7 +9,7 @@ import ( "github.com/tidwall/gjson" ) -func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseMessage string, log wrapper.Log) string { +func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) string { subMessages := strings.Split(sseMessage, "\n") var message string for _, msg := range subMessages { @@ -26,8 +26,8 @@ func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseM // skip the prefix "data:" bodyJson := message[5:] // Extract values from JSON fields - responseBody := gjson.Get(bodyJson, config.CacheStreamValueFrom) - toolCalls := gjson.Get(bodyJson, config.CacheToolCallsFrom) + responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom) + toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom) if toolCalls.Exists() { // TODO: Temporarily store the tool_calls value in the context for processing @@ -59,7 +59,7 @@ func processSSEMessage(ctx wrapper.HttpContext, config config.PluginConfig, sseM } // Handles partial chunks of data when the full response is not received yet. -func handlePartialChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) { +func handlePartialChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) { stream := ctx.GetContext(STREAM_CONTEXT_KEY) if stream == nil { @@ -74,7 +74,7 @@ func handlePartialChunk(ctx wrapper.HttpContext, config config.PluginConfig, chu partialMessage := appendPartialMessage(ctx, chunk) messages := strings.Split(string(partialMessage), "\n\n") for _, msg := range messages[:len(messages)-1] { - processSSEMessage(ctx, config, msg, log) + processSSEMessage(ctx, c, msg, log) } savePartialMessage(ctx, partialMessage, messages) } @@ -91,6 +91,11 @@ func appendPartialMessage(ctx wrapper.HttpContext, chunk []byte) []byte { // Saves the remaining partial message chunk func savePartialMessage(ctx wrapper.HttpContext, partialMessage []byte, messages []string) { + if len(messages) == 0 { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) + return + } + if !strings.HasSuffix(string(partialMessage), "\n\n") { ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) } else { @@ -98,21 +103,21 @@ func savePartialMessage(ctx wrapper.HttpContext, partialMessage []byte, messages } } -// Processes the final chunk and returns the parsed value or an error -func processNonEmptyChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { +// Processes a non-empty data chunk and returns the parsed value or an error +func processNonEmptyChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string if stream == nil { body := appendFinalBody(ctx, chunk) bodyJson := gjson.ParseBytes(body) - value = bodyJson.Get(config.CacheValueFrom).String() + value = bodyJson.Get(c.CacheValueFrom).String() if value == "" { return "", fmt.Errorf("failed to parse value from response body: %s", body) } } else { - value, err := processFinalStreamMessage(ctx, config, log, chunk) + value, err := processFinalStreamMessage(ctx, c, log, chunk) if err != nil { return "", err } @@ -122,7 +127,7 @@ func processNonEmptyChunk(ctx wrapper.HttpContext, config config.PluginConfig, c return value, nil } -func processEmptyChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { +func processEmptyChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) if tempContentI == nil { return string(chunk), nil @@ -144,7 +149,7 @@ func appendFinalBody(ctx wrapper.HttpContext, chunk []byte) []byte { } // Processes the final SSE message chunk -func processFinalStreamMessage(ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log, chunk []byte) (string, error) { +func processFinalStreamMessage(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, chunk []byte) (string, error) { var lastMessage []byte partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) @@ -160,5 +165,5 @@ func processFinalStreamMessage(ctx wrapper.HttpContext, config config.PluginConf } lastMessage = lastMessage[:len(lastMessage)-2] // Remove the last \n\n - return processSSEMessage(ctx, config, string(lastMessage), log), nil + return processSSEMessage(ctx, c, string(lastMessage), log), nil } From 81bde6dd2e199c91ff1fd31a7503aec3ee8884c0 Mon Sep 17 00:00:00 2001 From: async Date: Thu, 24 Oct 2024 09:37:42 +0800 Subject: [PATCH 50/71] update --- plugins/wasm-go/extensions/ai-cache/vector/provider.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index cd99ebaf3f..afbfc5e4e7 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -100,6 +100,15 @@ type ProviderConfig struct { // @Title zh-CN DashVector 向量存储服务 Collection ID // @Description zh-CN DashVector 向量存储服务 Collection ID collectionID string + // @Title zh-CN 相似度度量阈值 + // @Description zh-CN 默认相似度度量阈值,默认为 1000。 + Threshold float64 + // @Title zh-CN 相似度度量比较方式 + // @Description zh-CN 相似度度量比较方式,默认为小于。 + // 相似度度量方式有 Cosine, DotProduct, Euclidean 等,前两者值越大相似度越高,后者值越小相似度越高。 + // 所以需要允许自定义比较方式,对于 Cosine 和 DotProduct 选择 gt,对于 Euclidean 则选择 lt。 + // 默认为 lt,所有条件包括 lt (less than,小于)、lte (less than or equal to,小等于)、gt (greater than,大于)、gte (greater than or equal to,大等于) + ThresholdRelation string } func (c *ProviderConfig) GetProviderType() string { From ea34f4a478bc7d7d212416e32029f4b04358ef79 Mon Sep 17 00:00:00 2001 From: async Date: Thu, 24 Oct 2024 10:12:53 +0800 Subject: [PATCH 51/71] fix: bugs --- plugins/wasm-go/extensions/ai-cache/vector/provider.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index afbfc5e4e7..98ee13c42d 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -134,6 +134,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.timeout == 0 { c.timeout = 10000 } + c.Threshold = json.Get("threshold").Float() + if c.Threshold == 0 { + c.Threshold = 1000 + } + c.ThresholdRelation = json.Get("thresholdRelation").String() + if c.ThresholdRelation == "" { + c.ThresholdRelation = "lt" + } } func (c *ProviderConfig) Validate() error { From f5b50fd35af851a9a791440ffbd42568048d123f Mon Sep 17 00:00:00 2001 From: suchun Date: Thu, 24 Oct 2024 08:38:33 +0000 Subject: [PATCH 52/71] add support for skip-cache --- plugins/wasm-go/extensions/ai-cache/main.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 5be78c01db..a0e56964da 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -20,6 +20,7 @@ const ( PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" TOOL_CALLS_CONTEXT_KEY = "toolCalls" STREAM_CONTEXT_KEY = "stream" + SKIP_CACHE_HEADER = "skip-cache" ) func main() { @@ -50,11 +51,10 @@ func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) err return nil } - -func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { - skipCache, _ := proxywasm.GetHttpRequestHeader(SkipCacheHeader) +func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { + skipCache, _ := proxywasm.GetHttpRequestHeader(SKIP_CACHE_HEADER) if skipCache == "on" { - ctx.SetContext(SkipCacheHeader, struct{}{}) + ctx.SetContext(SKIP_CACHE_HEADER, struct{}{}) ctx.DontReadRequestBody() return types.ActionContinue } @@ -123,9 +123,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by return types.ActionPause } - -func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { - skipCache := ctx.GetContext(SkipCacheHeader) +func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { + skipCache := ctx.GetContext(SKIP_CACHE_HEADER) if skipCache != nil { ctx.DontReadResponseBody() return types.ActionContinue From a1fe701a51416c29a1c11dbf04300d37b0a3018e Mon Sep 17 00:00:00 2001 From: suchun Date: Thu, 24 Oct 2024 08:42:34 +0000 Subject: [PATCH 53/71] update README.md and change to FQDNCluster --- plugins/wasm-go/extensions/ai-cache/README.md | 14 +++++++------- .../extensions/ai-cache/vector/dashvector.go | 8 ++++---- .../wasm-go/extensions/ai-cache/vector/provider.go | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index d1bda6ee33..df656babd0 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -36,8 +36,8 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | -| vector.type | string | optional | "" | 向量存储服务提供者类型,例如 DashVector | -| embedding.type | string | optional | "" | 请求文本向量化服务类型,例如 DashScope | +| vector.type | string | optional | "" | 向量存储服务提供者类型,例如 dashvector | +| embedding.type | string | optional | "" | 请求文本向量化服务类型,例如 dashscope | | cache.type | string | optional | "" | 缓存服务类型,例如 redis | | cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) | | enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用字符串匹配的方式来查找缓存,此时需要配置cache服务 | @@ -47,23 +47,23 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 ## 向量数据库服务(vector) | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | -| vector.type | string | required | "" | 向量存储服务提供者类型,例如 DashVector | +| vector.type | string | required | "" | 向量存储服务提供者类型,例如 dashvector | | vector.serviceName | string | required | "" | 向量存储服务名称 | | vector.serviceDomain | string | required | "" | 向量存储服务域名 | | vector.servicePort | int64 | optional | 443 | 向量存储服务端口 | | vector.apiKey | string | optional | "" | 向量存储服务 API Key | | vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | | vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 | -| vector.collectionID | string | optional | "" | DashVector 向量存储服务 Collection ID | +| vector.collectionID | string | optional | "" | dashvector 向量存储服务 Collection ID | | vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 | | vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大等于) | ## 文本向量化服务(embedding) | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | -| embedding.type | string | required | "" | 请求文本向量化服务类型,例如 DashScope | +| embedding.type | string | required | "" | 请求文本向量化服务类型,例如 dashscope | | embedding.serviceName | string | required | "" | 请求文本向量化服务名称 | -| embedding.serviceDomain | string | optional | "" | 请求文本向量化服务域名 | +| embedding.serviceHost | string | optional | "" | 请求文本向量化服务域名 | | embedding.servicePort | int64 | optional | 443 | 请求文本向量化服务端口 | | embedding.apiKey | string | optional | "" | 请求文本向量化服务的 API Key | | embedding.timeout | uint32 | optional | 10000 | 请求文本向量化服务的超时时间,单位为毫秒。默认值是10000,即10秒 | @@ -74,7 +74,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | cache.type | string | required | "" | 缓存服务类型,例如 redis | | --- | --- | --- | --- | --- | | cache.serviceName | string | required | "" | 缓存服务名称 | -| cache.serviceDomain | string | required | "" | 缓存服务域名 | +| cache.serviceHost | string | required | "" | 缓存服务域名 | | cache.servicePort | int64 | optional | 6379 | 缓存服务端口 | | cache.username | string | optional | "" | 缓存服务用户名 | | cache.password | string | optional | "" | 缓存服务密码 | diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index ccafaa4caf..3be568801f 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -31,10 +31,10 @@ func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) er func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { return &DvProvider{ config: config, - client: wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: config.serviceName, - Port: config.servicePort, - Domain: config.serviceDomain, + client: wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: config.serviceName, + Host: config.serviceDomain, + Port: int64(config.servicePort), }), }, nil } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 98ee13c42d..3a2ef2bed4 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -84,7 +84,7 @@ type ProviderConfig struct { serviceName string // @Title zh-CN 向量存储服务域名 // @Description zh-CN 向量存储服务域名 - serviceDomain string + serviceHost string // @Title zh-CN 向量存储服务端口 // @Description zh-CN 向量存储服务端口 servicePort int64 From 730d951272bd588eba8a16f95571520c546a63ee Mon Sep 17 00:00:00 2001 From: suchun Date: Thu, 24 Oct 2024 08:49:02 +0000 Subject: [PATCH 54/71] change to FQDNCluster --- plugins/wasm-go/extensions/ai-cache/README.md | 2 +- .../extensions/ai-cache/embedding/dashscope.go | 12 ++++++------ .../extensions/ai-cache/embedding/provider.go | 4 ++-- .../wasm-go/extensions/ai-cache/vector/dashvector.go | 8 ++++---- .../wasm-go/extensions/ai-cache/vector/provider.go | 12 ++++++------ 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index df656babd0..35f8da0099 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -49,7 +49,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | --- | --- | --- | --- | --- | | vector.type | string | required | "" | 向量存储服务提供者类型,例如 dashvector | | vector.serviceName | string | required | "" | 向量存储服务名称 | -| vector.serviceDomain | string | required | "" | 向量存储服务域名 | +| vector.serviceHost | string | required | "" | 向量存储服务域名 | | vector.servicePort | int64 | optional | 443 | 向量存储服务端口 | | vector.apiKey | string | optional | "" | 向量存储服务 API Key | | vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index eef81ea64c..ba65d3a3b5 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -31,15 +31,15 @@ func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provide if c.servicePort == 0 { c.servicePort = DASHSCOPE_PORT } - if c.serviceDomain == "" { - c.serviceDomain = DASHSCOPE_DOMAIN + if c.serviceHost == "" { + c.serviceHost = DASHSCOPE_DOMAIN } return &DSProvider{ config: c, - client: wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: c.serviceName, - Port: c.servicePort, - Domain: c.serviceDomain, + client: wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: c.serviceName, + Host: c.serviceHost, + Port: int64(c.servicePort), }), }, nil } diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go index b7748f2cc4..909edf129c 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -31,7 +31,7 @@ type ProviderConfig struct { serviceName string // @Title zh-CN 文本特征提取服务域名 // @Description zh-CN 文本特征提取服务域名 - serviceDomain string + serviceHost string // @Title zh-CN 文本特征提取服务端口 // @Description zh-CN 文本特征提取服务端口 servicePort int64 @@ -49,7 +49,7 @@ type ProviderConfig struct { func (c *ProviderConfig) FromJson(json gjson.Result) { c.typ = json.Get("type").String() c.serviceName = json.Get("serviceName").String() - c.serviceDomain = json.Get("serviceDomain").String() + c.serviceHost = json.Get("serviceHost").String() c.servicePort = json.Get("servicePort").Int() c.apiKey = json.Get("apiKey").String() c.timeout = uint32(json.Get("timeout").Int()) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go index 3be568801f..7bdb0a76d0 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -22,8 +22,8 @@ func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) er if len(config.serviceName) == 0 { return errors.New("[DashVector] serviceName is required") } - if len(config.serviceDomain) == 0 { - return errors.New("[DashVector] serviceDomain is required") + if len(config.serviceHost) == 0 { + return errors.New("[DashVector] serviceHost is required") } return nil } @@ -33,7 +33,7 @@ func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (P config: config, client: wrapper.NewClusterClient(wrapper.FQDNCluster{ FQDN: config.serviceName, - Host: config.serviceDomain, + Host: config.serviceHost, Port: int64(config.servicePort), }), }, nil @@ -45,7 +45,7 @@ type DvProvider struct { } func (d *DvProvider) GetProviderType() string { - return providerTypeDashVector + return PROVIDER_TYPE_DASH_VECTOR } // type embeddingRequest struct { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 3a2ef2bed4..a04123a166 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -8,8 +8,8 @@ import ( ) const ( - providerTypeDashVector = "dashvector" - providerTypeChroma = "chroma" + PROVIDER_TYPE_DASH_VECTOR = "dashvector" + PROVIDER_TYPE_CHROMA = "chroma" ) type providerInitializer interface { @@ -19,8 +19,8 @@ type providerInitializer interface { var ( providerInitializers = map[string]providerInitializer{ - providerTypeDashVector: &dashVectorProviderInitializer{}, - // providerTypeChroma: &chromaProviderInitializer{}, + PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{}, + // PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{}, } ) @@ -77,7 +77,7 @@ type SimilarityThresholdProvider interface { type ProviderConfig struct { // @Title zh-CN 向量存储服务提供者类型 - // @Description zh-CN 向量存储服务提供者类型,例如 DashVector、Milvus + // @Description zh-CN 向量存储服务提供者类型,例如 dashvector、chroma typ string // @Title zh-CN 向量存储服务名称 // @Description zh-CN 向量存储服务名称 @@ -119,7 +119,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.typ = json.Get("type").String() // DashVector c.serviceName = json.Get("serviceName").String() - c.serviceDomain = json.Get("serviceDomain").String() + c.serviceHost = json.Get("serviceHost").String() c.servicePort = int64(json.Get("servicePort").Int()) if c.servicePort == 0 { c.servicePort = 443 From 335c04c0a78f83b0c9aa3df823cd9f8c147bf9b6 Mon Sep 17 00:00:00 2001 From: suchun Date: Fri, 25 Oct 2024 23:36:32 +0000 Subject: [PATCH 55/71] provide support for the legacy configuration --- plugins/wasm-go/extensions/ai-cache/README.md | 25 +++- .../extensions/ai-cache/cache/provider.go | 5 + .../extensions/ai-cache/config/config.go | 46 ++++-- plugins/wasm-go/extensions/ai-cache/core.go | 45 +++--- .../ai-cache/embedding/dashscope.go | 2 +- plugins/wasm-go/extensions/ai-cache/main.go | 134 +++++++++++++++--- 6 files changed, 194 insertions(+), 63 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 35f8da0099..ca7d0f6265 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -42,7 +42,13 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) | | enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用字符串匹配的方式来查找缓存,此时需要配置cache服务 | -以下是vector、embedding、cache的具体配置说明,注意若不配置embedding或cache服务,则可忽略以下相应配置中的 `required` 字段。 +根据是否需要启用语义缓存,可以只配置组件的组合为: +1. `cache`: 仅启用字符串匹配缓存 +3. `vector (+ embedding)`: 启用语义化缓存, 其中若 `vector` 未提供字符串表征服务,则需要自行配置 `embedding` 服务 +2. `vector (+ embedding) + cache`: 启用语义化缓存并用缓存服务存储LLM响应以加速 + +注意若不配置相关组件,则可以忽略相应组件的`required`字段。 + ## 向量数据库服务(vector) | Name | Type | Requirement | Default | Description | @@ -99,24 +105,32 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 ```yaml embedding: type: dashscope - serviceName: [Your Service Name] + serviceName: my_dashscope.dns apiKey: [Your Key] vector: type: dashvector - serviceName: [Your Service Name] + serviceName: my_dashvector.dns collectionID: [Your Collection ID] serviceDomain: [Your domain] apiKey: [Your key] cache: type: redis - serviceName: [Your Service Name] + serviceName: my_redis.dns servicePort: 6379 timeout: 100 ``` +旧版本配置兼容 +```yaml +redis: + serviceName: my_redis.dns + servicePort: 6379 + timeout: 100 +``` + ## 进阶用法 当前默认的缓存 key 是基于 GJSON PATH 的表达式:`messages.@reverse.0.content` 提取,含义是把 messages 数组反转后取第一项的 content; @@ -128,3 +142,6 @@ GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 更多用法可以参考[官方文档](https://github.com/tidwall/gjson/blob/master/SYNTAX.md),可以使用 [GJSON Playground](https://gjson.dev/) 进行语法测试。 +## 常见问题 + +1. 如果返回的错误为 `error status returned by host: bad argument`,请检查`serviceName`是否正确包含了服务的类型后缀(.dns等)。 \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index 9cdeaac262..e368b2412c 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -90,6 +90,11 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } +func (c *ProviderConfig) ConvertLegacyJson(json gjson.Result) { + c.FromJson(json) + c.typ = "redis" +} + func (c *ProviderConfig) Validate() error { if c.typ == "" { return errors.New("cache service type is required") diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 87e27452db..9c05587409 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -50,10 +50,14 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.vectorProviderConfig.FromJson(json.Get("vector")) c.embeddingProviderConfig.FromJson(json.Get("embedding")) c.cacheProviderConfig.FromJson(json.Get("cache")) + if json.Get("redis").Exists() { + // compatible with legacy config + c.cacheProviderConfig.ConvertLegacyJson(json.Get("redis")) + } c.CacheKeyStrategy = json.Get("cacheKeyStrategy").String() if c.CacheKeyStrategy == "" { - c.CacheKeyStrategy = "lastQuestion" // 设置默认值 + c.CacheKeyStrategy = "lastQuestion" // set default value } // c.CacheKeyFrom = json.Get("cacheKeyFrom").String() // if c.CacheKeyFrom == "" { @@ -85,7 +89,7 @@ func (c *PluginConfig) FromJson(json gjson.Result) { if json.Get("enableSemanticCache").Exists() { c.EnableSemanticCache = json.Get("enableSemanticCache").Bool() } else { - c.EnableSemanticCache = true // 设置默认值为 true + c.EnableSemanticCache = true // set default value to true } } @@ -107,29 +111,34 @@ func (c *PluginConfig) Validate() error { } } - // vector 和 embedding 不能同时为空 - if c.vectorProviderConfig.GetProviderType() == "" && c.embeddingProviderConfig.GetProviderType() == "" { - return fmt.Errorf("vector and embedding provider cannot be both empty") + // cache, vector, and embedding cannot all be empty + if c.vectorProviderConfig.GetProviderType() == "" && + c.embeddingProviderConfig.GetProviderType() == "" && + c.cacheProviderConfig.GetProviderType() == "" { + return fmt.Errorf("vector, embedding and cache provider cannot be all empty") } - // 验证 CacheKeyStrategy 的值 + // Validate the value of CacheKeyStrategy if c.CacheKeyStrategy != CACHE_KEY_STRATEGY_LAST_QUESTION && c.CacheKeyStrategy != CACHE_KEY_STRATEGY_ALL_QUESTIONS && c.CacheKeyStrategy != CACHE_KEY_STRATEGY_DISABLED { return fmt.Errorf("invalid CacheKeyStrategy: %s", c.CacheKeyStrategy) } - // 如果启用了语义化缓存,确保必要的组件已配置 - if c.EnableSemanticCache { - if c.embeddingProviderConfig.GetProviderType() == "" { - return fmt.Errorf("semantic cache is enabled but embedding provider is not configured") - } - } + + // If semantic cache is enabled, ensure necessary components are configured + // if c.EnableSemanticCache { + // if c.embeddingProviderConfig.GetProviderType() == "" { + // return fmt.Errorf("semantic cache is enabled but embedding provider is not configured") + // } + // // if only configure cache, just warn the user + // } return nil } func (c *PluginConfig) Complete(log wrapper.Log) error { var err error if c.embeddingProviderConfig.GetProviderType() != "" { + log.Debugf("embedding provider is set to %s", c.embeddingProviderConfig.GetProviderType()) c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig) if err != nil { return err @@ -139,6 +148,7 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { c.embeddingProvider = nil } if c.cacheProviderConfig.GetProviderType() != "" { + log.Debugf("cache provider is set to %s", c.cacheProviderConfig.GetProviderType()) c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig) if err != nil { return err @@ -147,9 +157,15 @@ func (c *PluginConfig) Complete(log wrapper.Log) error { log.Info("cache provider is not configured") c.cacheProvider = nil } - c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) - if err != nil { - return err + if c.vectorProviderConfig.GetProviderType() != "" { + log.Debugf("vector provider is set to %s", c.vectorProviderConfig.GetProviderType()) + c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) + if err != nil { + return err + } + } else { + log.Info("vector provider is not configured") + c.vectorProvider = nil } return nil } diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index d3e9c81b5f..b4c7d33e6e 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -59,23 +59,15 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex // processCacheHit handles a successful cache hit. func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) { - if stream { - log.Debug("streaming response is not supported for cache hit yet") - stream = false - } - // escapedResponse, err := json.Marshal(response) - // log.Debugf("Cached response for key %s: %s", key, escapedResponse) - - // if err != nil { - // handleInternalError(err, "Failed to marshal cached response", log) - // return - // } log.Debugf("Cached response for key %s: %s", key, response) ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) - // proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1) - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, response)), -1) + if stream { + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, response)), -1) + } else { + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, response)), -1) + } } // performSimilaritySearch determines the appropriate similarity search method to use. @@ -87,11 +79,13 @@ func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.Plugi // Check if the active vector provider implements the StringQuerier interface. if _, ok := activeVectorProvider.(vector.StringQuerier); ok { + log.Debugf("[%s] [performSimilaritySearch] active vector provider implements StringQuerier interface, performing string query", PLUGIN_NAME) return performStringQuery(key, queryString, ctx, c, log, stream) } // Check if the active vector provider implements the EmbeddingQuerier interface. if _, ok := activeVectorProvider.(vector.EmbeddingQuerier); ok { + log.Debugf("[%s] [performSimilaritySearch] active vector provider implements EmbeddingQuerier interface, performing embedding query", PLUGIN_NAME) return performEmbeddingQuery(key, ctx, c, log, stream) } @@ -123,6 +117,7 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginC } return activeEmbeddingProvider.GetEmbedding(key, ctx, log, func(textEmbedding []float64, err error) { + log.Debugf("[%s] [performEmbeddingQuery] GetEmbedding success, length of embedding: %d, error: %v", PLUGIN_NAME, len(textEmbedding), err) if err != nil { handleInternalError(err, fmt.Sprintf("Error getting embedding for key: %s", key), log) return @@ -146,17 +141,17 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht } if len(results) == 0 { - log.Warnf("No similar keys found for key: %s", key) + log.Warnf("[%s] [handleQueryResults] No similar keys found for key: %s", PLUGIN_NAME, key) proxywasm.ResumeHttpRequest() return } mostSimilarData := results[0] - log.Debugf("For key: %s, the most similar key found: %s with score: %f", key, mostSimilarData.Text, mostSimilarData.Score) + log.Debugf("[%s] [handleQueryResults] For key: %s, the most similar key found: %s with score: %f", PLUGIN_NAME, key, mostSimilarData.Text, mostSimilarData.Score) simThreshold := c.GetVectorProviderConfig().Threshold simThresholdRelation := c.GetVectorProviderConfig().ThresholdRelation if compare(simThresholdRelation, mostSimilarData.Score, simThreshold) { - log.Infof("Key accepted: %s with score: %f below threshold", mostSimilarData.Text, mostSimilarData.Score) + log.Infof("[%s] [handleQueryResults] Key accepted: %s with score: %f below threshold", PLUGIN_NAME, mostSimilarData.Text, mostSimilarData.Score) if mostSimilarData.Answer != "" { // direct return the answer if available processCacheHit(key, mostSimilarData.Answer, stream, ctx, c, log) @@ -169,11 +164,11 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht // } // Otherwise, do not check the cache, directly return - log.Warnf("Cache hit for key: %s, but no corresponding answer found in the vector database", mostSimilarData.Text) + log.Warnf("[%s] [handleQueryResults] Cache hit for key: %s, but no corresponding answer found in the vector database", PLUGIN_NAME, mostSimilarData.Text) proxywasm.ResumeHttpRequest() } } else { - log.Infof("Score not meet the threshold %f: %s with score %f", simThreshold, mostSimilarData.Text, mostSimilarData.Score) + log.Infof("[%s] [handleQueryResults] Score not meet the threshold %f: %s with score %f", PLUGIN_NAME, simThreshold, mostSimilarData.Text, mostSimilarData.Score) proxywasm.ResumeHttpRequest() } } @@ -200,7 +195,7 @@ func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, v activeCacheProvider := c.GetCacheProvider() if activeCacheProvider != nil { queryKey := activeCacheProvider.GetCacheKeyPrefix() + key - log.Infof("[onHttpResponseBody] setting cache to redis, key: %s, value: %s", queryKey, value) + log.Infof("[%s] [cacheResponse] setting cache to redis, key: %s, value: %s", PLUGIN_NAME, queryKey, value) _ = activeCacheProvider.Set(queryKey, value, nil) } } @@ -214,22 +209,22 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, ke emb, ok := embedding.([]float64) if !ok { - log.Errorf("[onHttpResponseBody] embedding is not of expected type []float64") + log.Errorf("[%s] [uploadEmbeddingAndAnswer] embedding is not of expected type []float64", PLUGIN_NAME) return } activeVectorProvider := c.GetVectorProvider() if activeVectorProvider == nil { - log.Debug("[onHttpResponseBody] no vector provider configured for uploading embedding") + log.Debugf("[%s] [uploadEmbeddingAndAnswer] no vector provider configured for uploading embedding", PLUGIN_NAME) return } // Attempt to upload answer embedding first if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerAndEmbeddingUploader); ok { - log.Infof("[onHttpResponseBody] uploading answer embedding for key: %s", key) + log.Infof("[%s] [uploadEmbeddingAndAnswer] uploading answer embedding for key: %s", PLUGIN_NAME, key) err := ansEmbUploader.UploadAnswerAndEmbedding(key, emb, value, ctx, log, nil) if err != nil { - log.Warnf("[onHttpResponseBody] failed to upload answer embedding for key: %s, error: %v", key, err) + log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload answer embedding for key: %s, error: %v", PLUGIN_NAME, key, err) } else { return // If successful, return early } @@ -237,10 +232,10 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, ke // If answer embedding upload fails, attempt normal embedding upload if embUploader, ok := activeVectorProvider.(vector.EmbeddingUploader); ok { - log.Infof("[onHttpResponseBody] uploading embedding for key: %s", key) + log.Infof("[%s] [uploadEmbeddingAndAnswer] uploading embedding for key: %s", PLUGIN_NAME, key) err := embUploader.UploadEmbedding(key, emb, ctx, log, nil) if err != nil { - log.Warnf("[onHttpResponseBody] failed to upload embedding for key: %s, error: %v", key, err) + log.Warnf("[uploadEmbeddingAndAnswer] failed to upload embedding for key: %s, error: %v",key, err) } } } diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go index ba65d3a3b5..35c897cce5 100644 --- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -13,7 +13,7 @@ import ( const ( DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com" DASHSCOPE_PORT = 443 - DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v1" + DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v2" DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding" ) diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index a0e56964da..3851188b96 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -43,7 +43,7 @@ func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) err if err := c.Validate(); err != nil { return err } - // 注意,在 parseConfig 阶段初始化 client 会出错,比如 docker compose 中的 redis 就无法使用 + // Note that initializing the client during the parseConfig phase may cause errors, such as Redis not being usable in Docker Compose. if err := c.Complete(log); err != nil { log.Errorf("complete config failed: %v", err) return err @@ -136,38 +136,136 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log w return types.ActionContinue } +// func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { +// log.Debugf("[onHttpResponseBody] chunk: %q", string(chunk)) +// log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk) + +// if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { +// return chunk +// } + +// key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) +// if key == nil { +// log.Debug("[onHttpResponseBody] key is nil, bypass caching") +// return chunk +// } + +// if !isLastChunk { +// handlePartialChunk(ctx, c, chunk, log) +// return chunk +// } + +// // Handle last chunk +// var value string +// var err error + +// if len(chunk) > 0 { +// value, err = processNonEmptyChunk(ctx, c, chunk, log) +// } else { +// value, err = processEmptyChunk(ctx, c, chunk, log) +// } + +// if err != nil { +// log.Warnf("[onHttpResponseBody] failed to process chunk: %v", err) +// return chunk +// } +// // Cache the final value +// cacheResponse(ctx, c, key.(string), value, log) + +// // Handle embedding upload if available +// uploadEmbeddingAndAnswer(ctx, c, key.(string), value, log) + +// return chunk +// } + func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - log.Debugf("[onHttpResponseBody] chunk: %q", string(chunk)) - log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk) + if string(chunk) == "data: [DONE]" { + return nil + } if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { + // we should not cache tool call result return chunk } - key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) if key == nil { - log.Debug("[onHttpResponseBody] key is nil, bypass caching") return chunk } - if !isLastChunk { - handlePartialChunk(ctx, c, chunk, log) + stream := ctx.GetContext(STREAM_CONTEXT_KEY) + if stream == nil { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) + return chunk + } + tempContent := tempContentI.([]byte) + tempContent = append(tempContent, chunk...) + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) + } else { + var partialMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + if partialMessageI != nil { + partialMessage = append(partialMessageI.([]byte), chunk...) + } else { + partialMessage = chunk + } + messages := strings.Split(string(partialMessage), "\n\n") + for i, msg := range messages { + if i < len(messages)-1 { + // process complete message + processSSEMessage(ctx, c, msg, log) + } + } + if !strings.HasSuffix(string(partialMessage), "\n\n") { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) + } else { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) + } + } return chunk } - - // Handle last chunk + // last chunk + stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string - var err error + if stream == nil { + var body []byte + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI != nil { + body = append(tempContentI.([]byte), chunk...) + } else { + body = chunk + } + bodyJson := gjson.ParseBytes(body) - if len(chunk) > 0 { - value, err = processNonEmptyChunk(ctx, c, chunk, log) + value = bodyJson.Get(c.CacheValueFrom).String() + if value == "" { + log.Warnf("parse value from response body failded, body:%s", body) + return chunk + } } else { - value, err = processEmptyChunk(ctx, c, chunk, log) - } - - if err != nil { - log.Warnf("[onHttpResponseBody] failed to process chunk: %v", err) - return chunk + if len(chunk) > 0 { + var lastMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + if partialMessageI != nil { + lastMessage = append(partialMessageI.([]byte), chunk...) + } else { + lastMessage = chunk + } + if !strings.HasSuffix(string(lastMessage), "\n\n") { + log.Warnf("invalid lastMessage:%s", lastMessage) + return chunk + } + // remove the last \n\n + lastMessage = lastMessage[:len(lastMessage)-2] + value = processSSEMessage(ctx, c, string(lastMessage), log) + } else { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + return chunk + } + value = tempContentI.(string) + } } // Cache the final value cacheResponse(ctx, c, key.(string), value, log) From 59bddf6d92682d638be427a3fc9c511768587021 Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 00:17:41 +0000 Subject: [PATCH 56/71] simplify resp func, add func name when debug --- plugins/wasm-go/extensions/ai-cache/core.go | 55 +++--- plugins/wasm-go/extensions/ai-cache/main.go | 125 +------------ plugins/wasm-go/extensions/ai-cache/util.go | 197 +++++++++----------- 3 files changed, 119 insertions(+), 258 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index b4c7d33e6e..49286a28ef 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -15,19 +15,19 @@ import ( func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) error { activeCacheProvider := c.GetCacheProvider() if activeCacheProvider == nil { - log.Debug("No cache provider configured, performing similarity search") + log.Debugf("[%s] [CheckCacheForKey] no cache provider configured, performing similarity search", PLUGIN_NAME) return performSimilaritySearch(key, ctx, c, log, key, stream) } queryKey := activeCacheProvider.GetCacheKeyPrefix() + key - log.Debugf("Querying cache with key: %s", queryKey) + log.Debugf("[%s] [CheckCacheForKey] querying cache with key: %s", PLUGIN_NAME, queryKey) err := activeCacheProvider.Get(queryKey, func(response resp.Value) { handleCacheResponse(key, response, ctx, log, stream, c, useSimilaritySearch) }) if err != nil { - log.Errorf("Failed to retrieve key: %s from cache, error: %v", key, err) + log.Errorf("[%s] [CheckCacheForKey] failed to retrieve key: %s from cache, error: %v", PLUGIN_NAME, key, err) return err } @@ -37,19 +37,19 @@ func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig // handleCacheResponse processes cache response and handles cache hits and misses. func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, useSimilaritySearch bool) { if err := response.Error(); err == nil && !response.IsNull() { - log.Infof("Cache hit for key: %s", key) + log.Infof("[%s] cache hit for key: %s", PLUGIN_NAME, key) processCacheHit(key, response.String(), stream, ctx, c, log) return } - log.Infof("Cache miss for key: %s", key) + log.Infof("[%s] [handleCacheResponse] cache miss for key: %s", PLUGIN_NAME, key) if err := response.Error(); err != nil { - log.Errorf("Error retrieving key: %s from cache, error: %v", key, err) + log.Errorf("[%s] [handleCacheResponse] error retrieving key: %s from cache, error: %v", PLUGIN_NAME, key, err) } if useSimilaritySearch && c.EnableSemanticCache { if err := performSimilaritySearch(key, ctx, c, log, key, stream); err != nil { - log.Errorf("Failed to perform similarity search for key: %s, error: %v", key, err) + log.Errorf("[%s] [handleCacheResponse] failed to perform similarity search for key: %s, error: %v", PLUGIN_NAME, key, err) proxywasm.ResumeHttpRequest() } } else { @@ -59,7 +59,7 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex // processCacheHit handles a successful cache hit. func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) { - log.Debugf("Cached response for key %s: %s", key, response) + log.Debugf("[%s] [processCacheHit] cached response for key %s: %s", PLUGIN_NAME, key, response) ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) @@ -74,7 +74,7 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, queryString string, stream bool) error { activeVectorProvider := c.GetVectorProvider() if activeVectorProvider == nil { - return errors.New("no vector provider configured for similarity search") + return logAndReturnError(log, "[performSimilaritySearch] no vector provider configured for similarity search") } // Check if the active vector provider implements the StringQuerier interface. @@ -89,14 +89,14 @@ func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.Plugi return performEmbeddingQuery(key, ctx, c, log, stream) } - return errors.New("no suitable querier or embedding provider available for similarity search") + return logAndReturnError(log, "[performSimilaritySearch] no suitable querier or embedding provider available for similarity search") } // performStringQuery executes the string-based similarity search. func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error { stringQuerier, ok := c.GetVectorProvider().(vector.StringQuerier) if !ok { - return logAndReturnError(log, "active vector provider does not implement StringQuerier interface") + return logAndReturnError(log, "[performStringQuery] active vector provider does not implement StringQuerier interface") } return stringQuerier.QueryString(queryString, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { @@ -108,18 +108,18 @@ func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error { embeddingQuerier, ok := c.GetVectorProvider().(vector.EmbeddingQuerier) if !ok { - return logAndReturnError(log, "active vector provider does not implement EmbeddingQuerier interface") + return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] active vector provider does not implement EmbeddingQuerier interface")) } activeEmbeddingProvider := c.GetEmbeddingProvider() if activeEmbeddingProvider == nil { - return logAndReturnError(log, "no embedding provider configured for similarity search") + return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] no embedding provider configured for similarity search")) } return activeEmbeddingProvider.GetEmbedding(key, ctx, log, func(textEmbedding []float64, err error) { log.Debugf("[%s] [performEmbeddingQuery] GetEmbedding success, length of embedding: %d, error: %v", PLUGIN_NAME, len(textEmbedding), err) if err != nil { - handleInternalError(err, fmt.Sprintf("Error getting embedding for key: %s", key), log) + handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error getting embedding for key: %s", PLUGIN_NAME, key), log) return } ctx.SetContext(CACHE_KEY_EMBEDDING_KEY, textEmbedding) @@ -128,7 +128,7 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginC handleQueryResults(key, results, ctx, log, stream, c, err) }) if err != nil { - handleInternalError(err, fmt.Sprintf("Error querying vector database for key: %s", key), log) + handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error querying vector database for key: %s", PLUGIN_NAME, key), log) } }) } @@ -136,22 +136,22 @@ func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginC // handleQueryResults processes the results of similarity search and determines next actions. func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, err error) { if err != nil { - handleInternalError(err, fmt.Sprintf("Error querying vector database for key: %s", key), log) + handleInternalError(err, fmt.Sprintf("[%s] [handleQueryResults] error querying vector database for key: %s", PLUGIN_NAME, key), log) return } if len(results) == 0 { - log.Warnf("[%s] [handleQueryResults] No similar keys found for key: %s", PLUGIN_NAME, key) + log.Warnf("[%s] [handleQueryResults] no similar keys found for key: %s", PLUGIN_NAME, key) proxywasm.ResumeHttpRequest() return } mostSimilarData := results[0] - log.Debugf("[%s] [handleQueryResults] For key: %s, the most similar key found: %s with score: %f", PLUGIN_NAME, key, mostSimilarData.Text, mostSimilarData.Score) + log.Debugf("[%s] [handleQueryResults] for key: %s, the most similar key found: %s with score: %f", PLUGIN_NAME, key, mostSimilarData.Text, mostSimilarData.Score) simThreshold := c.GetVectorProviderConfig().Threshold simThresholdRelation := c.GetVectorProviderConfig().ThresholdRelation if compare(simThresholdRelation, mostSimilarData.Score, simThreshold) { - log.Infof("[%s] [handleQueryResults] Key accepted: %s with score: %f below threshold", PLUGIN_NAME, mostSimilarData.Text, mostSimilarData.Score) + log.Infof("[%s] key accepted: %s with score: %f", PLUGIN_NAME, mostSimilarData.Text, mostSimilarData.Score) if mostSimilarData.Answer != "" { // direct return the answer if available processCacheHit(key, mostSimilarData.Answer, stream, ctx, c, log) @@ -164,17 +164,18 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht // } // Otherwise, do not check the cache, directly return - log.Warnf("[%s] [handleQueryResults] Cache hit for key: %s, but no corresponding answer found in the vector database", PLUGIN_NAME, mostSimilarData.Text) + log.Infof("[%s] cache hit for key: %s, but no corresponding answer found in the vector database", PLUGIN_NAME, mostSimilarData.Text) proxywasm.ResumeHttpRequest() } } else { - log.Infof("[%s] [handleQueryResults] Score not meet the threshold %f: %s with score %f", PLUGIN_NAME, simThreshold, mostSimilarData.Text, mostSimilarData.Score) + log.Infof("[%s] score not meet the threshold %f: %s with score %f", PLUGIN_NAME, simThreshold, mostSimilarData.Text, mostSimilarData.Score) proxywasm.ResumeHttpRequest() } } // logAndReturnError logs an error and returns it. func logAndReturnError(log wrapper.Log, message string) error { + message = fmt.Sprintf("[%s] %s", PLUGIN_NAME, message) log.Errorf(message) return errors.New(message) } @@ -182,9 +183,9 @@ func logAndReturnError(log wrapper.Log, message string) error { // handleInternalError logs an error and resumes the HTTP request. func handleInternalError(err error, message string, log wrapper.Log) { if err != nil { - log.Errorf("%s: %v", message, err) + log.Errorf("[%s] [handleInternalError] %s: %v", PLUGIN_NAME, message, err) } else { - log.Errorf(message) + log.Errorf("[%s] [handleInternalError] %s", PLUGIN_NAME, message) } // proxywasm.SendHttpResponse(500, [][2]string{{"content-type", "text/plain"}}, []byte("Internal Server Error"), -1) proxywasm.ResumeHttpRequest() @@ -195,7 +196,7 @@ func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, v activeCacheProvider := c.GetCacheProvider() if activeCacheProvider != nil { queryKey := activeCacheProvider.GetCacheKeyPrefix() + key - log.Infof("[%s] [cacheResponse] setting cache to redis, key: %s, value: %s", PLUGIN_NAME, queryKey, value) + log.Infof("[%s] setting cache to redis, key: %s, value: %s", PLUGIN_NAME, queryKey, value) _ = activeCacheProvider.Set(queryKey, value, nil) } } @@ -221,7 +222,7 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, ke // Attempt to upload answer embedding first if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerAndEmbeddingUploader); ok { - log.Infof("[%s] [uploadEmbeddingAndAnswer] uploading answer embedding for key: %s", PLUGIN_NAME, key) + log.Infof("[%s] uploading answer embedding for key: %s", PLUGIN_NAME, key) err := ansEmbUploader.UploadAnswerAndEmbedding(key, emb, value, ctx, log, nil) if err != nil { log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload answer embedding for key: %s, error: %v", PLUGIN_NAME, key, err) @@ -232,10 +233,10 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, ke // If answer embedding upload fails, attempt normal embedding upload if embUploader, ok := activeVectorProvider.(vector.EmbeddingUploader); ok { - log.Infof("[%s] [uploadEmbeddingAndAnswer] uploading embedding for key: %s", PLUGIN_NAME, key) + log.Infof("[%s] uploading embedding for key: %s", PLUGIN_NAME, key) err := embUploader.UploadEmbedding(key, emb, ctx, log, nil) if err != nil { - log.Warnf("[uploadEmbeddingAndAnswer] failed to upload embedding for key: %s, error: %v",key, err) + log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload embedding for key: %s, error: %v", PLUGIN_NAME, key, err) } } } diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 3851188b96..85234d2df8 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -116,7 +116,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by } if err := CheckCacheForKey(key, ctx, c, log, stream, true); err != nil { - log.Errorf("check cache for key: %s failed, error: %v", key, err) + log.Errorf("[onHttpRequestBody] check cache for key: %s failed, error: %v", key, err) return types.ActionContinue } @@ -136,142 +136,33 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log w return types.ActionContinue } -// func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { -// log.Debugf("[onHttpResponseBody] chunk: %q", string(chunk)) -// log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk) - -// if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { -// return chunk -// } - -// key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) -// if key == nil { -// log.Debug("[onHttpResponseBody] key is nil, bypass caching") -// return chunk -// } - -// if !isLastChunk { -// handlePartialChunk(ctx, c, chunk, log) -// return chunk -// } - -// // Handle last chunk -// var value string -// var err error - -// if len(chunk) > 0 { -// value, err = processNonEmptyChunk(ctx, c, chunk, log) -// } else { -// value, err = processEmptyChunk(ctx, c, chunk, log) -// } - -// if err != nil { -// log.Warnf("[onHttpResponseBody] failed to process chunk: %v", err) -// return chunk -// } -// // Cache the final value -// cacheResponse(ctx, c, key.(string), value, log) - -// // Handle embedding upload if available -// uploadEmbeddingAndAnswer(ctx, c, key.(string), value, log) - -// return chunk -// } - func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { if string(chunk) == "data: [DONE]" { return nil } if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { - // we should not cache tool call result return chunk } + key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) if key == nil { return chunk } + if !isLastChunk { - stream := ctx.GetContext(STREAM_CONTEXT_KEY) - if stream == nil { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI == nil { - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) - return chunk - } - tempContent := tempContentI.([]byte) - tempContent = append(tempContent, chunk...) - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) - } else { - var partialMessage []byte - partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - if partialMessageI != nil { - partialMessage = append(partialMessageI.([]byte), chunk...) - } else { - partialMessage = chunk - } - messages := strings.Split(string(partialMessage), "\n\n") - for i, msg := range messages { - if i < len(messages)-1 { - // process complete message - processSSEMessage(ctx, c, msg, log) - } - } - if !strings.HasSuffix(string(partialMessage), "\n\n") { - ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) - } else { - ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) - } - } - return chunk + return handleNonLastChunk(ctx, c, chunk, log) } - // last chunk + stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string if stream == nil { - var body []byte - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI != nil { - body = append(tempContentI.([]byte), chunk...) - } else { - body = chunk - } - bodyJson := gjson.ParseBytes(body) - - value = bodyJson.Get(c.CacheValueFrom).String() - if value == "" { - log.Warnf("parse value from response body failded, body:%s", body) - return chunk - } + value = processNonStreamLastChunk(ctx, c, chunk, log) } else { - if len(chunk) > 0 { - var lastMessage []byte - partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - if partialMessageI != nil { - lastMessage = append(partialMessageI.([]byte), chunk...) - } else { - lastMessage = chunk - } - if !strings.HasSuffix(string(lastMessage), "\n\n") { - log.Warnf("invalid lastMessage:%s", lastMessage) - return chunk - } - // remove the last \n\n - lastMessage = lastMessage[:len(lastMessage)-2] - value = processSSEMessage(ctx, c, string(lastMessage), log) - } else { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI == nil { - return chunk - } - value = tempContentI.(string) - } + value = processStreamLastChunk(ctx, c, chunk, log) } - // Cache the final value - cacheResponse(ctx, c, key.(string), value, log) - // Handle embedding upload if available + cacheResponse(ctx, c, key.(string), value, log) uploadEmbeddingAndAnswer(ctx, c, key.(string), value, log) - return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 590d5213d4..ba26442d94 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" @@ -9,6 +8,87 @@ import ( "github.com/tidwall/gjson" ) +func handleNonLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) []byte { + stream := ctx.GetContext(STREAM_CONTEXT_KEY) + if stream == nil { + return handleNonStreamChunk(ctx, c, chunk, log) + } + return handleStreamChunk(ctx, c, chunk, log) +} + +func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) []byte { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) + return chunk + } + tempContent := tempContentI.([]byte) + tempContent = append(tempContent, chunk...) + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) + return chunk +} + +func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) []byte { + var partialMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + if partialMessageI != nil { + partialMessage = append(partialMessageI.([]byte), chunk...) + } else { + partialMessage = chunk + } + messages := strings.Split(string(partialMessage), "\n\n") + for i, msg := range messages { + if i < len(messages)-1 { + processSSEMessage(ctx, c, msg, log) + } + } + if !strings.HasSuffix(string(partialMessage), "\n\n") { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) + } else { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) + } + return chunk +} + +func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) string { + var body []byte + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI != nil { + body = append(tempContentI.([]byte), chunk...) + } else { + body = chunk + } + bodyJson := gjson.ParseBytes(body) + value := bodyJson.Get(c.CacheValueFrom).String() + if value == "" { + log.Warnf("[%s] [processNonStreamLastChunk] parse value from response body failed, body:%s", PLUGIN_NAME, body) + } + return value +} + +func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) string { + if len(chunk) > 0 { + var lastMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + if partialMessageI != nil { + lastMessage = append(partialMessageI.([]byte), chunk...) + } else { + lastMessage = chunk + } + if !strings.HasSuffix(string(lastMessage), "\n\n") { + log.Warnf("[%s] [processStreamLastChunk] invalid lastMessage:%s", PLUGIN_NAME, lastMessage) + return "" + } + lastMessage = lastMessage[:len(lastMessage)-2] + return processSSEMessage(ctx, c, string(lastMessage), log) + } + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + return "" + } + return tempContentI.(string) +} + func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) string { subMessages := strings.Split(sseMessage, "\n") var message string @@ -19,7 +99,7 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag } } if len(message) < 6 { - log.Warnf("invalid message: %s", message) + log.Warnf("[%s] [processSSEMessage] invalid message: %s", PLUGIN_NAME, message) return "" } @@ -37,7 +117,7 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag // Check if the ResponseBody field exists if !responseBody.Exists() { // Return an empty string if we cannot extract the content - log.Warnf("cannot extract content from message: %s", message) + log.Warnf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) return "" } else { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) @@ -55,115 +135,4 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) return content } - -} - -// Handles partial chunks of data when the full response is not received yet. -func handlePartialChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) { - stream := ctx.GetContext(STREAM_CONTEXT_KEY) - - if stream == nil { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI == nil { - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) - } else { - tempContent := append(tempContentI.([]byte), chunk...) - ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) - } - } else { - partialMessage := appendPartialMessage(ctx, chunk) - messages := strings.Split(string(partialMessage), "\n\n") - for _, msg := range messages[:len(messages)-1] { - processSSEMessage(ctx, c, msg, log) - } - savePartialMessage(ctx, partialMessage, messages) - } -} - -// Appends the partial message chunks -func appendPartialMessage(ctx wrapper.HttpContext, chunk []byte) []byte { - partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - if partialMessageI != nil { - return append(partialMessageI.([]byte), chunk...) - } - return chunk -} - -// Saves the remaining partial message chunk -func savePartialMessage(ctx wrapper.HttpContext, partialMessage []byte, messages []string) { - if len(messages) == 0 { - ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) - return - } - - if !strings.HasSuffix(string(partialMessage), "\n\n") { - ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) - } else { - ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) - } -} - -// Processes a non-empty data chunk and returns the parsed value or an error -func processNonEmptyChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { - stream := ctx.GetContext(STREAM_CONTEXT_KEY) - var value string - - if stream == nil { - body := appendFinalBody(ctx, chunk) - bodyJson := gjson.ParseBytes(body) - value = bodyJson.Get(c.CacheValueFrom).String() - - if value == "" { - return "", fmt.Errorf("failed to parse value from response body: %s", body) - } - } else { - value, err := processFinalStreamMessage(ctx, c, log, chunk) - if err != nil { - return "", err - } - return value, nil - } - - return value, nil -} - -func processEmptyChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI == nil { - return string(chunk), nil - } - value, ok := tempContentI.([]byte) - if !ok { - return "", fmt.Errorf("invalid type for tempContentI") - } - return string(value), nil -} - -// Appends the final body chunk to the existing body content -func appendFinalBody(ctx wrapper.HttpContext, chunk []byte) []byte { - tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) - if tempContentI != nil { - return append(tempContentI.([]byte), chunk...) - } - return chunk -} - -// Processes the final SSE message chunk -func processFinalStreamMessage(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, chunk []byte) (string, error) { - var lastMessage []byte - partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - - if partialMessageI != nil { - lastMessage = append(partialMessageI.([]byte), chunk...) - } else { - lastMessage = chunk - } - - if !strings.HasSuffix(string(lastMessage), "\n\n") { - log.Warnf("[onHttpResponseBody] invalid lastMessage: %s", lastMessage) - return "", fmt.Errorf("invalid lastMessage format") - } - - lastMessage = lastMessage[:len(lastMessage)-2] // Remove the last \n\n - return processSSEMessage(ctx, c, string(lastMessage), log), nil } From 36f0d7758731ec877feb2f232919dc1c3171e69c Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 00:22:16 +0000 Subject: [PATCH 57/71] change *.typ to * --- plugins/wasm-go/extensions/ai-cache/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index ca7d0f6265..edbd0562cc 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -36,9 +36,9 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | Name | Type | Requirement | Default | Description | | --- | --- | --- | --- | --- | -| vector.type | string | optional | "" | 向量存储服务提供者类型,例如 dashvector | -| embedding.type | string | optional | "" | 请求文本向量化服务类型,例如 dashscope | -| cache.type | string | optional | "" | 缓存服务类型,例如 redis | +| vector | string | optional | "" | 向量存储服务提供者类型,例如 dashvector | +| embedding | string | optional | "" | 请求文本向量化服务类型,例如 dashscope | +| cache | string | optional | "" | 缓存服务类型,例如 redis | | cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) | | enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用字符串匹配的方式来查找缓存,此时需要配置cache服务 | From 009a1b10923cb6d93a1d00a2a9254a62ad74a5a0 Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 01:13:26 +0000 Subject: [PATCH 58/71] add support for legacy config --- .../extensions/ai-cache/config/config.go | 54 ++++++++++++++++--- plugins/wasm-go/extensions/ai-cache/main.go | 9 ++-- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 9c05587409..a8d9246abe 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -32,7 +32,7 @@ type PluginConfig struct { vectorProviderConfig vector.ProviderConfig cacheProviderConfig cache.ProviderConfig - // CacheKeyFrom string + CacheKeyFrom string CacheValueFrom string CacheStreamValueFrom string CacheToolCallsFrom string @@ -46,7 +46,8 @@ type PluginConfig struct { CacheKeyStrategy string } -func (c *PluginConfig) FromJson(json gjson.Result) { +func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) { + c.vectorProviderConfig.FromJson(json.Get("vector")) c.embeddingProviderConfig.FromJson(json.Get("embedding")) c.cacheProviderConfig.FromJson(json.Get("cache")) @@ -57,12 +58,12 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.CacheKeyStrategy = json.Get("cacheKeyStrategy").String() if c.CacheKeyStrategy == "" { - c.CacheKeyStrategy = "lastQuestion" // set default value + c.CacheKeyStrategy = CACHE_KEY_STRATEGY_LAST_QUESTION // set default value + } + c.CacheKeyFrom = json.Get("cacheKeyFrom").String() + if c.CacheKeyFrom == "" { + c.CacheKeyFrom = "messages.@reverse.0.content" } - // c.CacheKeyFrom = json.Get("cacheKeyFrom").String() - // if c.CacheKeyFrom == "" { - // c.CacheKeyFrom = "messages.@reverse.0.content" - // } c.CacheValueFrom = json.Get("cacheValueFrom").String() if c.CacheValueFrom == "" { c.CacheValueFrom = "choices.0.message.content" @@ -85,12 +86,14 @@ func (c *PluginConfig) FromJson(json gjson.Result) { c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } - // 默认值为 true if json.Get("enableSemanticCache").Exists() { c.EnableSemanticCache = json.Get("enableSemanticCache").Bool() } else { c.EnableSemanticCache = true // set default value to true } + + // compatible with legacy config + convertLegacyMapFields(c, json, log) } func (c *PluginConfig) Validate() error { @@ -185,3 +188,38 @@ func (c *PluginConfig) GetVectorProviderConfig() vector.ProviderConfig { func (c *PluginConfig) GetCacheProvider() cache.Provider { return c.cacheProvider } + +func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log) { + keyMap := map[string]string{ + "cacheKeyFrom.requestBody": "cacheKeyFrom", + "cacheValueFrom.requestBody": "cacheValueFrom", + "cacheStreamValueFrom.requestBody": "cacheStreamValueFrom", + "returnResponseTemplate": "responseTemplate", + "returnStreamResponseTemplate": "streamResponseTemplate", + } + + for oldKey, newKey := range keyMap { + if json.Get(oldKey).Exists() { + log.Debugf("[convertLegacyMapFields] mapping %s to %s", oldKey, newKey) + setField(c, newKey, json.Get(oldKey).String(), log) + } else { + log.Debugf("[convertLegacyMapFields] %s not exists", oldKey) + } + } +} + +func setField(c *PluginConfig, fieldName string, value string, log wrapper.Log) { + switch fieldName { + case "cacheKeyFrom": + c.CacheKeyFrom = value + case "cacheValueFrom": + c.CacheValueFrom = value + case "cacheStreamValueFrom": + c.CacheStreamValueFrom = value + case "responseTemplate": + c.ResponseTemplate = value + case "streamResponseTemplate": + c.StreamResponseTemplate = value + } + log.Debugf("[setField] set %s to %s", fieldName, value) +} diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 85234d2df8..137932bc3e 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -39,7 +39,7 @@ func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) err // config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) // config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) // config.RedisConfig.FromJson(json.Get("redis")) - c.FromJson(json) + c.FromJson(json, log) if err := c.Validate(); err != nil { return err } @@ -86,9 +86,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by var key string if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_LAST_QUESTION { - key = bodyJson.Get("messages.@reverse.0.content").String() + log.Debugf("[onHttpRequestBody] cache key strategy is last question, cache key from: %s", c.CacheKeyFrom) + key = bodyJson.Get(c.CacheKeyFrom).String() } else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_ALL_QUESTIONS { - // Retrieve all user messages and concatenate them + log.Debugf("[onHttpRequestBody] cache key strategy is all questions, cache key from: messages") messages := bodyJson.Get("messages").Array() var userMessages []string for _, msg := range messages { @@ -98,7 +99,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by } key = strings.Join(userMessages, "\n") } else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_DISABLED { - log.Debugf("[onHttpRequestBody] cache key strategy is disabled") + log.Info("[onHttpRequestBody] cache key strategy is disabled") ctx.DontReadRequestBody() return types.ActionContinue } else { From 4515f43cafc9dd3b3e30a735b2784ab62e18132d Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 01:31:39 +0000 Subject: [PATCH 59/71] update content_type in stream resp --- plugins/wasm-go/extensions/ai-cache/config/config.go | 10 +++++----- plugins/wasm-go/extensions/ai-cache/core.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index a8d9246abe..51e056ceba 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -191,11 +191,11 @@ func (c *PluginConfig) GetCacheProvider() cache.Provider { func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log) { keyMap := map[string]string{ - "cacheKeyFrom.requestBody": "cacheKeyFrom", - "cacheValueFrom.requestBody": "cacheValueFrom", - "cacheStreamValueFrom.requestBody": "cacheStreamValueFrom", - "returnResponseTemplate": "responseTemplate", - "returnStreamResponseTemplate": "streamResponseTemplate", + "`cacheKeyFrom.requestBody`": "cacheKeyFrom", + "`cacheValueFrom.requestBody`": "cacheValueFrom", + "`cacheStreamValueFrom.requestBody`": "cacheStreamValueFrom", + "returnResponseTemplate": "responseTemplate", + "returnStreamResponseTemplate": "streamResponseTemplate", } for oldKey, newKey := range keyMap { diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 49286a28ef..9612fde481 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -64,7 +64,7 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) if stream { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, response)), -1) + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, response)), -1) } else { proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, response)), -1) } From c0482807f66069a6b670251ec444679bbd63c7b8 Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 10:25:36 +0000 Subject: [PATCH 60/71] fix bugs --- plugins/wasm-go/extensions/ai-cache/README.md | 2 +- .../extensions/ai-cache/cache/provider.go | 7 ++- .../extensions/ai-cache/config/config.go | 2 +- plugins/wasm-go/extensions/ai-cache/core.go | 33 +++++++++----- plugins/wasm-go/extensions/ai-cache/main.go | 30 ++++++++++--- plugins/wasm-go/extensions/ai-cache/util.go | 43 ++++++++++--------- 6 files changed, 76 insertions(+), 41 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index edbd0562cc..ca91bdf5a1 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -86,7 +86,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | cache.password | string | optional | "" | 缓存服务密码 | | cache.timeout | uint32 | optional | 10000 | 缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 | | cache.cacheTTL | int | optional | 0 | 缓存过期时间,单位为秒。默认值是 0,即 永不过期| -| cacheKeyPrefix | string | optional | "higressAiCache:" | 缓存 Key 的前缀,默认值为 "higressAiCache:" | +| cacheKeyPrefix | string | optional | "higress-ai-cache:" | 缓存 Key 的前缀,默认值为 "higress-ai-cache:" | ## 其他配置 diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go index e368b2412c..1238d21570 100644 --- a/plugins/wasm-go/extensions/ai-cache/cache/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -9,7 +9,7 @@ import ( const ( PROVIDER_TYPE_REDIS = "redis" - DEFAULT_CACHE_PREFIX = "higressAiCache:" + DEFAULT_CACHE_PREFIX = "higress-ai-cache:" ) type providerInitializer interface { @@ -91,8 +91,11 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { } func (c *ProviderConfig) ConvertLegacyJson(json gjson.Result) { - c.FromJson(json) + c.FromJson(json.Get("redis")) c.typ = "redis" + if json.Get("cacheTTL").Exists() { + c.cacheTTL = int(json.Get("cacheTTL").Int()) + } } func (c *ProviderConfig) Validate() error { diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 51e056ceba..77648e80ab 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -53,7 +53,7 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) { c.cacheProviderConfig.FromJson(json.Get("cache")) if json.Get("redis").Exists() { // compatible with legacy config - c.cacheProviderConfig.ConvertLegacyJson(json.Get("redis")) + c.cacheProviderConfig.ConvertLegacyJson(json) } c.CacheKeyStrategy = json.Get("cacheKeyStrategy").String() diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 9612fde481..507dd8e1cb 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -3,6 +3,7 @@ package main import ( "errors" "fmt" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" @@ -59,8 +60,17 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex // processCacheHit handles a successful cache hit. func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) { + if response == "" || response == " " { + log.Warnf("[%s] [processCacheHit] cached response for key %s is empty", PLUGIN_NAME, key) + proxywasm.ResumeHttpRequest() + return + } + log.Debugf("[%s] [processCacheHit] cached response for key %s: %s", PLUGIN_NAME, key, response) + // Replace newline characters in the response with escaped characters to ensure consistent formatting + response = strings.ReplaceAll(response, "\n", "\\n") + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) if stream { @@ -154,18 +164,16 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht log.Infof("[%s] key accepted: %s with score: %f", PLUGIN_NAME, mostSimilarData.Text, mostSimilarData.Score) if mostSimilarData.Answer != "" { // direct return the answer if available + cacheResponse(ctx, c, key, mostSimilarData.Answer, log) processCacheHit(key, mostSimilarData.Answer, stream, ctx, c, log) } else { - // // otherwise, continue to check cache for the most similar key - // err = CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false) - // if err != nil { - // log.Errorf("check cache for key: %s failed, error: %v", mostSimilarData.Text, err) - // proxywasm.ResumeHttpRequest() - // } - - // Otherwise, do not check the cache, directly return - log.Infof("[%s] cache hit for key: %s, but no corresponding answer found in the vector database", PLUGIN_NAME, mostSimilarData.Text) - proxywasm.ResumeHttpRequest() + if c.GetCacheProvider() != nil { + CheckCacheForKey(mostSimilarData.Text, ctx, c, log, stream, false) + } else { + // Otherwise, do not check the cache, directly return + log.Infof("[%s] cache hit for key: %s, but no corresponding answer found in the vector database", PLUGIN_NAME, mostSimilarData.Text) + proxywasm.ResumeHttpRequest() + } } } else { log.Infof("[%s] score not meet the threshold %f: %s with score %f", PLUGIN_NAME, simThreshold, mostSimilarData.Text, mostSimilarData.Score) @@ -193,6 +201,11 @@ func handleInternalError(err error, message string, log wrapper.Log) { // Caches the response value func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) { + if value == "" || value == " " { + log.Warnf("[%s] [cacheResponse] cached value for key %s is empty", PLUGIN_NAME, key) + return + } + activeCacheProvider := c.GetCacheProvider() if activeCacheProvider != nil { queryKey := activeCacheProvider.GetCacheKeyPrefix() + key diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 137932bc3e..de36738b07 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -20,7 +20,8 @@ const ( PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" TOOL_CALLS_CONTEXT_KEY = "toolCalls" STREAM_CONTEXT_KEY = "stream" - SKIP_CACHE_HEADER = "skip-cache" + SKIP_CACHE_HEADER = "x-higress-skip-ai-cache" + ERROR_PARTIAL_MESSAGE_KEY = "errorPartialMessage" ) func main() { @@ -134,14 +135,18 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log w if strings.Contains(contentType, "text/event-stream") { ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) } + return types.ActionContinue } func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - if string(chunk) == "data: [DONE]" { - return nil + if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil { + if isLastChunk { + // If the last chunk is an error, clear the error flag + ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, nil) + } + return chunk } - if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { return chunk } @@ -152,15 +157,26 @@ func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk [] } if !isLastChunk { - return handleNonLastChunk(ctx, c, chunk, log) + if err := handleNonLastChunk(ctx, c, chunk, log); err != nil { + log.Errorf("[onHttpResponseBody] handle non last chunk failed, error: %v", err) + // Set an empty struct in the context to indicate an error in processing the partial message + ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, struct{}{}) + } + return chunk } stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string + var err error if stream == nil { - value = processNonStreamLastChunk(ctx, c, chunk, log) + value, err = processNonStreamLastChunk(ctx, c, chunk, log) } else { - value = processStreamLastChunk(ctx, c, chunk, log) + value, err = processStreamLastChunk(ctx, c, chunk, log) + } + + if err != nil { + log.Errorf("[onHttpResponseBody] process last chunk failed, error: %v", err) + return chunk } cacheResponse(ctx, c, key.(string), value, log) diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index ba26442d94..bca8efa26c 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -8,27 +8,30 @@ import ( "github.com/tidwall/gjson" ) -func handleNonLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) []byte { +func handleNonLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { stream := ctx.GetContext(STREAM_CONTEXT_KEY) + err := error(nil) if stream == nil { - return handleNonStreamChunk(ctx, c, chunk, log) + err = handleNonStreamChunk(ctx, c, chunk, log) + } else { + err = handleStreamChunk(ctx, c, chunk, log) } - return handleStreamChunk(ctx, c, chunk, log) + return err } -func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) []byte { +func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) if tempContentI == nil { ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) - return chunk + return nil } tempContent := tempContentI.([]byte) tempContent = append(tempContent, chunk...) ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) - return chunk + return nil } -func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) []byte { +func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { var partialMessage []byte partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) if partialMessageI != nil { @@ -47,10 +50,10 @@ func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []b } else { ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) } - return chunk + return nil } -func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) string { +func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { var body []byte tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) if tempContentI != nil { @@ -63,10 +66,10 @@ func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, c if value == "" { log.Warnf("[%s] [processNonStreamLastChunk] parse value from response body failed, body:%s", PLUGIN_NAME, body) } - return value + return value, nil } -func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) string { +func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { if len(chunk) > 0 { var lastMessage []byte partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) @@ -77,30 +80,30 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun } if !strings.HasSuffix(string(lastMessage), "\n\n") { log.Warnf("[%s] [processStreamLastChunk] invalid lastMessage:%s", PLUGIN_NAME, lastMessage) - return "" + return "", nil } lastMessage = lastMessage[:len(lastMessage)-2] return processSSEMessage(ctx, c, string(lastMessage), log) } tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) if tempContentI == nil { - return "" + return "", nil } - return tempContentI.(string) + return tempContentI.(string), nil } -func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) string { +func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) { subMessages := strings.Split(sseMessage, "\n") var message string for _, msg := range subMessages { - if strings.HasPrefix(msg, "data: ") { + if strings.HasPrefix(msg, "data:") { message = msg break } } if len(message) < 6 { log.Warnf("[%s] [processSSEMessage] invalid message: %s", PLUGIN_NAME, message) - return "" + return "", nil } // skip the prefix "data:" @@ -118,7 +121,7 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag if !responseBody.Exists() { // Return an empty string if we cannot extract the content log.Warnf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) - return "" + return "", nil } else { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) @@ -126,13 +129,13 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag if tempContentI == nil { content := responseBody.String() ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) - return content + return content, nil } // Update the content in the cache appendMsg := responseBody.String() content := tempContentI.(string) + appendMsg ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) - return content + return content, nil } } From 0ec24f3bfb21b0fc2a40442581b5a4168bf7ab95 Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 11:15:38 +0000 Subject: [PATCH 61/71] add support for legacy configuration --- plugins/wasm-go/extensions/ai-cache/config/config.go | 10 +++++----- plugins/wasm-go/extensions/ai-cache/util.go | 7 +++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 77648e80ab..4bd6e2a18f 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -191,11 +191,11 @@ func (c *PluginConfig) GetCacheProvider() cache.Provider { func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log) { keyMap := map[string]string{ - "`cacheKeyFrom.requestBody`": "cacheKeyFrom", - "`cacheValueFrom.requestBody`": "cacheValueFrom", - "`cacheStreamValueFrom.requestBody`": "cacheStreamValueFrom", - "returnResponseTemplate": "responseTemplate", - "returnStreamResponseTemplate": "streamResponseTemplate", + "cacheKeyFrom.requestBody": "cacheKeyFrom", + "cacheValueFrom.requestBody": "cacheValueFrom", + "cacheStreamValueFrom.requestBody": "cacheStreamValueFrom", + "returnResponseTemplate": "responseTemplate", + "returnStreamResponseTemplate": "streamResponseTemplate", } for oldKey, newKey := range keyMap { diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index bca8efa26c..9cd0b184bc 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" @@ -34,6 +35,8 @@ func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { var partialMessage []byte partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + log.Debugf("[%s] [handleStreamChunk] chunk: %s", PLUGIN_NAME, chunk) + log.Debugf("[%s] [handleStreamChunk] partialMessageI: %v", PLUGIN_NAME, partialMessageI) if partialMessageI != nil { partialMessage = append(partialMessageI.([]byte), chunk...) } else { @@ -120,8 +123,8 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag // Check if the ResponseBody field exists if !responseBody.Exists() { // Return an empty string if we cannot extract the content - log.Warnf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) - return "", nil + // log.Warnf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) + return "", fmt.Errorf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) } else { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) From a658bfe2efb4c5af6bbd9b1f72e248e558e83ad8 Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 12:23:15 +0000 Subject: [PATCH 62/71] fix bugs --- .../extensions/ai-cache/config/config.go | 4 +-- plugins/wasm-go/extensions/ai-cache/core.go | 14 ++++++----- plugins/wasm-go/extensions/ai-cache/main.go | 6 ++--- plugins/wasm-go/extensions/ai-cache/util.go | 25 +++++++++++-------- 4 files changed, 27 insertions(+), 22 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index 4bd6e2a18f..d69e19d3c0 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -79,11 +79,11 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) { c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() if c.StreamResponseTemplate == "" { - c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" } c.ResponseTemplate = json.Get("responseTemplate").String() if c.ResponseTemplate == "" { - c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } if json.Get("enableSemanticCache").Exists() { diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index 507dd8e1cb..b3f4e59278 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -3,6 +3,7 @@ package main import ( "errors" "fmt" + "strconv" "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" @@ -60,7 +61,7 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex // processCacheHit handles a successful cache hit. func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) { - if response == "" || response == " " { + if strings.TrimSpace(response) == "" { log.Warnf("[%s] [processCacheHit] cached response for key %s is empty", PLUGIN_NAME, key) proxywasm.ResumeHttpRequest() return @@ -69,14 +70,15 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC log.Debugf("[%s] [processCacheHit] cached response for key %s: %s", PLUGIN_NAME, key, response) // Replace newline characters in the response with escaped characters to ensure consistent formatting - response = strings.ReplaceAll(response, "\n", "\\n") + // response = strings.ReplaceAll(response, "\n", "\\n") + escapedResponse := strconv.Quote(response) ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) if stream { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, response)), -1) + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1) } else { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, response)), -1) + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, escapedResponse)), -1) } } @@ -201,7 +203,7 @@ func handleInternalError(err error, message string, log wrapper.Log) { // Caches the response value func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) { - if value == "" || value == " " { + if strings.TrimSpace(value) == "" { log.Warnf("[%s] [cacheResponse] cached value for key %s is empty", PLUGIN_NAME, key) return } @@ -209,8 +211,8 @@ func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, v activeCacheProvider := c.GetCacheProvider() if activeCacheProvider != nil { queryKey := activeCacheProvider.GetCacheKeyPrefix() + key - log.Infof("[%s] setting cache to redis, key: %s, value: %s", PLUGIN_NAME, queryKey, value) _ = activeCacheProvider.Set(queryKey, value, nil) + log.Debugf("[%s] [cacheResponse] cache set success, key: %s, length of value: %d", PLUGIN_NAME, queryKey, len(value)) } } diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index de36738b07..16a2d1e52a 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -141,18 +141,16 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log w func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil { - if isLastChunk { - // If the last chunk is an error, clear the error flag - ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, nil) - } return chunk } + if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { return chunk } key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) if key == nil { + log.Debug("[onHttpResponseBody] key is nil, skip cache") return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 9cd0b184bc..1c00adb631 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -45,7 +45,10 @@ func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []b messages := strings.Split(string(partialMessage), "\n\n") for i, msg := range messages { if i < len(messages)-1 { - processSSEMessage(ctx, c, msg, log) + _, err := processSSEMessage(ctx, c, msg, log) + if err != nil { + return fmt.Errorf("[%s] [handleStreamChunk] processSSEMessage failed, error: %v", PLUGIN_NAME, err) + } } } if !strings.HasSuffix(string(partialMessage), "\n\n") { @@ -66,8 +69,8 @@ func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, c } bodyJson := gjson.ParseBytes(body) value := bodyJson.Get(c.CacheValueFrom).String() - if value == "" { - log.Warnf("[%s] [processNonStreamLastChunk] parse value from response body failed, body:%s", PLUGIN_NAME, body) + if strings.TrimSpace(value) == "" { + return "", fmt.Errorf("[%s] [processNonStreamLastChunk] parse value from response body failed, body:%s", PLUGIN_NAME, body) } return value, nil } @@ -82,11 +85,14 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun lastMessage = chunk } if !strings.HasSuffix(string(lastMessage), "\n\n") { - log.Warnf("[%s] [processStreamLastChunk] invalid lastMessage:%s", PLUGIN_NAME, lastMessage) - return "", nil + return "", fmt.Errorf("[%s] [processStreamLastChunk] invalid lastMessage:%s", PLUGIN_NAME, lastMessage) } lastMessage = lastMessage[:len(lastMessage)-2] - return processSSEMessage(ctx, c, string(lastMessage), log) + value, err := processSSEMessage(ctx, c, string(lastMessage), log) + if err != nil { + return "", fmt.Errorf("[%s] [processStreamLastChunk] processSSEMessage failed, error: %v", PLUGIN_NAME, err) + } + return value, nil } tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) if tempContentI == nil { @@ -105,8 +111,7 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag } } if len(message) < 6 { - log.Warnf("[%s] [processSSEMessage] invalid message: %s", PLUGIN_NAME, message) - return "", nil + return "", fmt.Errorf("[%s] [processSSEMessage] invalid message: %s", PLUGIN_NAME, message) } // skip the prefix "data:" @@ -123,8 +128,8 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag // Check if the ResponseBody field exists if !responseBody.Exists() { // Return an empty string if we cannot extract the content - // log.Warnf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) - return "", fmt.Errorf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) + log.Warnf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) + return "", nil } else { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) From a19914414624e062b75f8282fffecfbc29fb0c5f Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 21:52:34 +0000 Subject: [PATCH 63/71] handle the data: [DONE] and return in escaped string --- .../wasm-go/extensions/ai-cache/config/config.go | 4 ++-- plugins/wasm-go/extensions/ai-cache/core.go | 2 +- plugins/wasm-go/extensions/ai-cache/main.go | 2 ++ plugins/wasm-go/extensions/ai-cache/util.go | 16 +++++++++++----- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go index d69e19d3c0..4bd6e2a18f 100644 --- a/plugins/wasm-go/extensions/ai-cache/config/config.go +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -79,11 +79,11 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) { c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() if c.StreamResponseTemplate == "" { - c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" } c.ResponseTemplate = json.Get("responseTemplate").String() if c.ResponseTemplate == "" { - c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } if json.Get("enableSemanticCache").Exists() { diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index b3f4e59278..c70e386ff9 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -71,7 +71,7 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC // Replace newline characters in the response with escaped characters to ensure consistent formatting // response = strings.ReplaceAll(response, "\n", "\\n") - escapedResponse := strconv.Quote(response) + escapedResponse := strings.Trim(strconv.Quote(response), "\"") ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 16a2d1e52a..cd91f9a848 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -140,6 +140,8 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log w } func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { + log.Debugf("[onHttpResponseBody] is last chunk: %v", isLastChunk) + log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk)) if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil { return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index 1c00adb631..b36801b49a 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -35,8 +35,7 @@ func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { var partialMessage []byte partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - log.Debugf("[%s] [handleStreamChunk] chunk: %s", PLUGIN_NAME, chunk) - log.Debugf("[%s] [handleStreamChunk] partialMessageI: %v", PLUGIN_NAME, partialMessageI) + log.Debugf("[%s] [handleStreamChunk] cache content: %v", PLUGIN_NAME, ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)) if partialMessageI != nil { partialMessage = append(partialMessageI.([]byte), chunk...) } else { @@ -116,6 +115,11 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag // skip the prefix "data:" bodyJson := message[5:] + + if strings.TrimSpace(bodyJson) == "[DONE]" { + return "", nil + } + // Extract values from JSON fields responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom) toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom) @@ -127,9 +131,11 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag // Check if the ResponseBody field exists if !responseBody.Exists() { - // Return an empty string if we cannot extract the content - log.Warnf("[%s] [processSSEMessage] cannot extract content from message: %s", PLUGIN_NAME, message) - return "", nil + if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { + log.Debugf("[%s] [processSSEMessage] unable to extract content from message; cache content is not nil: %s", PLUGIN_NAME, message) + return "", nil + } + return "", fmt.Errorf("[%s] [processSSEMessage] unable to extract content from message; cache content is nil: %s", PLUGIN_NAME, message) } else { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) From 77f05d64ba503ba09cb53dcd0ab6cb5a8de59a4e Mon Sep 17 00:00:00 2001 From: suchun Date: Sat, 26 Oct 2024 22:12:43 +0000 Subject: [PATCH 64/71] dont read resp when ERROR_PARTIAL_MESSAGE_KEY not nil --- plugins/wasm-go/extensions/ai-cache/core.go | 3 +-- plugins/wasm-go/extensions/ai-cache/main.go | 8 +++++--- plugins/wasm-go/extensions/ai-cache/util.go | 16 ++++++++-------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go index c70e386ff9..19a9b2b856 100644 --- a/plugins/wasm-go/extensions/ai-cache/core.go +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -69,8 +69,7 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC log.Debugf("[%s] [processCacheHit] cached response for key %s: %s", PLUGIN_NAME, key, response) - // Replace newline characters in the response with escaped characters to ensure consistent formatting - // response = strings.ReplaceAll(response, "\n", "\\n") + // Escape the response to ensure consistent formatting escapedResponse := strings.Trim(strconv.Quote(response), "\"") ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index cd91f9a848..1aca29f0ec 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -136,15 +136,17 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log w ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) } + if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil { + ctx.DontReadResponseBody() + return types.ActionContinue + } + return types.ActionContinue } func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { log.Debugf("[onHttpResponseBody] is last chunk: %v", isLastChunk) log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk)) - if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil { - return chunk - } if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { return chunk diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go index b36801b49a..983dfbb25a 100644 --- a/plugins/wasm-go/extensions/ai-cache/util.go +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -35,7 +35,7 @@ func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { var partialMessage []byte partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) - log.Debugf("[%s] [handleStreamChunk] cache content: %v", PLUGIN_NAME, ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)) + log.Debugf("[handleStreamChunk] cache content: %v", ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)) if partialMessageI != nil { partialMessage = append(partialMessageI.([]byte), chunk...) } else { @@ -46,7 +46,7 @@ func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []b if i < len(messages)-1 { _, err := processSSEMessage(ctx, c, msg, log) if err != nil { - return fmt.Errorf("[%s] [handleStreamChunk] processSSEMessage failed, error: %v", PLUGIN_NAME, err) + return fmt.Errorf("[handleStreamChunk] processSSEMessage failed, error: %v", err) } } } @@ -69,7 +69,7 @@ func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, c bodyJson := gjson.ParseBytes(body) value := bodyJson.Get(c.CacheValueFrom).String() if strings.TrimSpace(value) == "" { - return "", fmt.Errorf("[%s] [processNonStreamLastChunk] parse value from response body failed, body:%s", PLUGIN_NAME, body) + return "", fmt.Errorf("[processNonStreamLastChunk] parse value from response body failed, body:%s", body) } return value, nil } @@ -84,12 +84,12 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun lastMessage = chunk } if !strings.HasSuffix(string(lastMessage), "\n\n") { - return "", fmt.Errorf("[%s] [processStreamLastChunk] invalid lastMessage:%s", PLUGIN_NAME, lastMessage) + return "", fmt.Errorf("[processStreamLastChunk] invalid lastMessage:%s", lastMessage) } lastMessage = lastMessage[:len(lastMessage)-2] value, err := processSSEMessage(ctx, c, string(lastMessage), log) if err != nil { - return "", fmt.Errorf("[%s] [processStreamLastChunk] processSSEMessage failed, error: %v", PLUGIN_NAME, err) + return "", fmt.Errorf("[processStreamLastChunk] processSSEMessage failed, error: %v", err) } return value, nil } @@ -110,7 +110,7 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag } } if len(message) < 6 { - return "", fmt.Errorf("[%s] [processSSEMessage] invalid message: %s", PLUGIN_NAME, message) + return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message) } // skip the prefix "data:" @@ -132,10 +132,10 @@ func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessag // Check if the ResponseBody field exists if !responseBody.Exists() { if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { - log.Debugf("[%s] [processSSEMessage] unable to extract content from message; cache content is not nil: %s", PLUGIN_NAME, message) + log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message) return "", nil } - return "", fmt.Errorf("[%s] [processSSEMessage] unable to extract content from message; cache content is nil: %s", PLUGIN_NAME, message) + return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message) } else { tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) From 28c629c1a006c1a1ffc88736196a3618efc426d3 Mon Sep 17 00:00:00 2001 From: Kent Dong Date: Sun, 27 Oct 2024 11:40:17 +0800 Subject: [PATCH 65/71] Update redis_wrapper.go --- plugins/wasm-go/pkg/wrapper/redis_wrapper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go index 13ba7c0d98..c619c3e191 100644 --- a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go @@ -237,9 +237,9 @@ func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, cal args := make([]interface{}, 0) args = append(args, "set") args = append(args, key) + args = append(args, value) args = append(args, "ex") args = append(args, ttl) - args = append(args, value) return RedisCall(c.cluster, respString(args), callback) } From d9ce358ac9bd009e30ca9567c2384758feb9cea1 Mon Sep 17 00:00:00 2001 From: async Date: Tue, 29 Oct 2024 11:20:06 +0800 Subject: [PATCH 66/71] merge --- .../extensions/ai-cache/vector/chroma.go | 2 +- .../ai-cache/vector/elasticsearch.go | 2 +- .../extensions/ai-cache/vector/milvus.go | 10 +++---- .../extensions/ai-cache/vector/pinecone.go | 2 +- .../extensions/ai-cache/vector/provider.go | 28 +++++++++---------- .../extensions/ai-cache/vector/qdrant.go | 2 +- .../extensions/ai-cache/vector/weaviate.go | 2 +- 7 files changed, 24 insertions(+), 24 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go index 1dd46ac790..a15e72bd40 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/chroma.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/chroma.go @@ -38,7 +38,7 @@ type ChromaProvider struct { } func (c *ChromaProvider) GetProviderType() string { - return providerTypeChroma + return PROVIDER_TYPE_CHROMA } func (d *ChromaProvider) QueryEmbedding( diff --git a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go index 989a34cce0..7a3b9d5b3e 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go @@ -39,7 +39,7 @@ type ESProvider struct { } func (c *ESProvider) GetProviderType() string { - return providerTypeES + return PROVIDER_TYPE_ES } func (d *ESProvider) QueryEmbedding( diff --git a/plugins/wasm-go/extensions/ai-cache/vector/milvus.go b/plugins/wasm-go/extensions/ai-cache/vector/milvus.go index f057c33b88..7e5ee205f4 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/milvus.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/milvus.go @@ -25,10 +25,10 @@ func (c *milvusProviderInitializer) ValidateConfig(config ProviderConfig) error func (c *milvusProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { return &milvusProvider{ config: config, - client: wrapper.NewClusterClient(wrapper.DnsCluster{ - ServiceName: config.serviceName, - Port: config.servicePort, - Domain: config.serviceDomain, + client: wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: config.serviceName, + Host: config.serviceHost, + Port: int64(config.servicePort), }), }, nil } @@ -39,7 +39,7 @@ type milvusProvider struct { } func (c *milvusProvider) GetProviderType() string { - return providerTypeMilvus + return PROVIDER_TYPE_MILVUS } type milvusData struct { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go index 7beb14be3e..f661ee0725 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go @@ -46,7 +46,7 @@ type pineconeProvider struct { } func (c *pineconeProvider) GetProviderType() string { - return providerTypePinecone + return PROVIDER_TYPE_PINECONE } type pineconeMetadata struct { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 08f949eaac..44c0290c8a 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -8,13 +8,13 @@ import ( ) const ( - providerTypeDashVector = "dashvector" - providerTypeChroma = "chroma" - providerTypeES = "elasticsearch" - providerTypeWeaviate = "weaviate" - providerTypePinecone = "pinecone" - providerTypeQdrant = "qdrant" - providerTypeMilvus = "milvus" + PROVIDER_TYPE_DASH_VECTOR = "dashvector" + PROVIDER_TYPE_CHROMA = "chroma" + PROVIDER_TYPE_ES = "elasticsearch" + PROVIDER_TYPE_WEAVIATE = "weaviate" + PROVIDER_TYPE_PINECONE = "pinecone" + PROVIDER_TYPE_QDRANT = "qdrant" + PROVIDER_TYPE_MILVUS = "milvus" ) type providerInitializer interface { @@ -24,13 +24,13 @@ type providerInitializer interface { var ( providerInitializers = map[string]providerInitializer{ - providerTypeDashVector: &dashVectorProviderInitializer{}, - providerTypeChroma: &chromaProviderInitializer{}, - providerTypeWeaviate: &weaviateProviderInitializer{}, - providerTypeES: &esProviderInitializer{}, - providerTypePinecone: &pineconeProviderInitializer{}, - providerTypeQdrant: &qdrantProviderInitializer{}, - providerTypeMilvus: &milvusProviderInitializer{}, + PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{}, + PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{}, + PROVIDER_TYPE_ES: &weaviateProviderInitializer{}, + PROVIDER_TYPE_WEAVIATE: &esProviderInitializer{}, + PROVIDER_TYPE_PINECONE: &pineconeProviderInitializer{}, + PROVIDER_TYPE_QDRANT: &qdrantProviderInitializer{}, + PROVIDER_TYPE_MILVUS: &milvusProviderInitializer{}, } ) diff --git a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go index 2a583f2491..3355d0d9ab 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/qdrant.go @@ -40,7 +40,7 @@ type qdrantProvider struct { } func (c *qdrantProvider) GetProviderType() string { - return providerTypeQdrant + return PROVIDER_TYPE_QDRANT } type qdrantPayload struct { diff --git a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go index e057d3dfb8..f4b06d5d4b 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go @@ -39,7 +39,7 @@ type WeaviateProvider struct { } func (c *WeaviateProvider) GetProviderType() string { - return providerTypeWeaviate + return PROVIDER_TYPE_WEAVIATE } func (d *WeaviateProvider) QueryEmbedding( From 4a9555754c90bf9d20b87eb46e94524ee2aee845 Mon Sep 17 00:00:00 2001 From: async Date: Tue, 29 Oct 2024 12:19:15 +0800 Subject: [PATCH 67/71] update: README.md --- plugins/wasm-go/extensions/ai-cache/README.md | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index ca91bdf5a1..6150b4ebae 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -99,6 +99,31 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | | streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | +# 向量数据库提供商特有配置 +## Chroma +Chroma 所对应的 `vector.type` 为 `chroma`。它并无特有的配置字段。 + +## DashVector +DashVector 所对应的 `vector.type` 为 `dashvector`。它并无特有的配置字段。 + +## ElasticSearch +ElasticSearch 所对应的 `vector.type` 为 `elasticsearch`。它特有的配置字段如下: +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|-------------------|----------|----------|--------|-------------------------------------------------------------------------------| +| `vector.esUsername` | string | 非必填 | - | ElasticSearch 用户名 | +| `vector.esPassword` | object | 非必填 | - | ElasticSearch 密码 | + +## Milvus +Milvus 所对应的 `vector.type` 为 `milvus`。它并无特有的配置字段。 + +## Pinecone +Pinecone 所对应的 `vector.type` 为 `pinecone`。它并无特有的配置字段。 + +## Qdrant +Qdrant 所对应的 `vector.type` 为 `qdrant`。它并无特有的配置字段。 + +## Weaviate +Weaviate 所对应的 `vector.type` 为 `weaviate`。它并无特有的配置字段。 ## 配置示例 ### 基础配置 From 902d810af03982c39e978f39378fb7ab1bd54367 Mon Sep 17 00:00:00 2001 From: async Date: Tue, 29 Oct 2024 12:33:44 +0800 Subject: [PATCH 68/71] fix: READMME.md --- plugins/wasm-go/extensions/ai-cache/README.md | 73 ------------------- 1 file changed, 73 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index e867d65f71..6150b4ebae 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -124,79 +124,6 @@ Qdrant 所对应的 `vector.type` 为 `qdrant`。它并无特有的配置字段 ## Weaviate Weaviate 所对应的 `vector.type` 为 `weaviate`。它并无特有的配置字段。 -配置分为 3 个部分:向量数据库(vector);文本向量化接口(embedding);缓存数据库(cache),同时也提供了细粒度的 LLM 请求/响应提取参数配置等。 - -## 配置说明 - -本插件同时支持基于向量数据库的语义化缓存和基于字符串匹配的缓存方法,如果同时配置了向量数据库和缓存数据库,优先使用向量数据库。 - -*Note*: 向量数据库(vector) 和 缓存数据库(cache) 不能同时为空,否则本插件无法提供缓存服务。 - -| Name | Type | Requirement | Default | Description | -| --- | --- | --- | --- | --- | -| vector | string | optional | "" | 向量存储服务提供者类型,例如 dashvector | -| embedding | string | optional | "" | 请求文本向量化服务类型,例如 dashscope | -| cache | string | optional | "" | 缓存服务类型,例如 redis | -| cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) | -| enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用字符串匹配的方式来查找缓存,此时需要配置cache服务 | - -根据是否需要启用语义缓存,可以只配置组件的组合为: -1. `cache`: 仅启用字符串匹配缓存 -3. `vector (+ embedding)`: 启用语义化缓存, 其中若 `vector` 未提供字符串表征服务,则需要自行配置 `embedding` 服务 -2. `vector (+ embedding) + cache`: 启用语义化缓存并用缓存服务存储LLM响应以加速 - -注意若不配置相关组件,则可以忽略相应组件的`required`字段。 - - -## 向量数据库服务(vector) -| Name | Type | Requirement | Default | Description | -| --- | --- | --- | --- | --- | -| vector.type | string | required | "" | 向量存储服务提供者类型,例如 dashvector | -| vector.serviceName | string | required | "" | 向量存储服务名称 | -| vector.serviceHost | string | required | "" | 向量存储服务域名 | -| vector.servicePort | int64 | optional | 443 | 向量存储服务端口 | -| vector.apiKey | string | optional | "" | 向量存储服务 API Key | -| vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | -| vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 | -| vector.collectionID | string | optional | "" | dashvector 向量存储服务 Collection ID | -| vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 | -| vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大等于) | - -## 文本向量化服务(embedding) -| Name | Type | Requirement | Default | Description | -| --- | --- | --- | --- | --- | -| embedding.type | string | required | "" | 请求文本向量化服务类型,例如 dashscope | -| embedding.serviceName | string | required | "" | 请求文本向量化服务名称 | -| embedding.serviceHost | string | optional | "" | 请求文本向量化服务域名 | -| embedding.servicePort | int64 | optional | 443 | 请求文本向量化服务端口 | -| embedding.apiKey | string | optional | "" | 请求文本向量化服务的 API Key | -| embedding.timeout | uint32 | optional | 10000 | 请求文本向量化服务的超时时间,单位为毫秒。默认值是10000,即10秒 | -| embedding.model | string | optional | "" | 请求文本向量化服务的模型名称 | - - -## 缓存服务(cache) -| cache.type | string | required | "" | 缓存服务类型,例如 redis | -| --- | --- | --- | --- | --- | -| cache.serviceName | string | required | "" | 缓存服务名称 | -| cache.serviceHost | string | required | "" | 缓存服务域名 | -| cache.servicePort | int64 | optional | 6379 | 缓存服务端口 | -| cache.username | string | optional | "" | 缓存服务用户名 | -| cache.password | string | optional | "" | 缓存服务密码 | -| cache.timeout | uint32 | optional | 10000 | 缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 | -| cache.cacheTTL | int | optional | 0 | 缓存过期时间,单位为秒。默认值是 0,即 永不过期| -| cacheKeyPrefix | string | optional | "higress-ai-cache:" | 缓存 Key 的前缀,默认值为 "higress-ai-cache:" | - - -## 其他配置 -| Name | Type | Requirement | Default | Description | -| --- | --- | --- | --- | --- | -| cacheKeyFrom | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheValueFrom | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheStreamValueFrom | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheToolCallsFrom | string | optional | "choices.0.delta.content.tool_calls" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | -| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | - ## 配置示例 ### 基础配置 From a1a7eefdd848136ec12e5b1ca2fe745e5719ccc9 Mon Sep 17 00:00:00 2001 From: EnableAsync <43645467+EnableAsync@users.noreply.github.com> Date: Tue, 29 Oct 2024 13:06:02 +0800 Subject: [PATCH 69/71] Update README.md --- plugins/wasm-go/extensions/ai-cache/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 6150b4ebae..8a1ab495d5 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -111,7 +111,7 @@ ElasticSearch 所对应的 `vector.type` 为 `elasticsearch`。它特有的配 | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | |-------------------|----------|----------|--------|-------------------------------------------------------------------------------| | `vector.esUsername` | string | 非必填 | - | ElasticSearch 用户名 | -| `vector.esPassword` | object | 非必填 | - | ElasticSearch 密码 | +| `vector.esPassword` | string | 非必填 | - | ElasticSearch 密码 | ## Milvus Milvus 所对应的 `vector.type` 为 `milvus`。它并无特有的配置字段。 @@ -169,4 +169,4 @@ GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user ## 常见问题 -1. 如果返回的错误为 `error status returned by host: bad argument`,请检查`serviceName`是否正确包含了服务的类型后缀(.dns等)。 \ No newline at end of file +1. 如果返回的错误为 `error status returned by host: bad argument`,请检查`serviceName`是否正确包含了服务的类型后缀(.dns等)。 From 6a782a40ecd2d3a8eb0beef6c32f382d4d526c41 Mon Sep 17 00:00:00 2001 From: async Date: Mon, 18 Nov 2024 00:18:33 +0800 Subject: [PATCH 70/71] update --- plugins/wasm-go/extensions/ai-cache/README.md | 18 ++++++++----- .../extensions/ai-cache/embedding/weaviate.go | 27 ------------------- .../ai-cache/vector/elasticsearch.go | 13 ++++++--- .../extensions/ai-cache/vector/pinecone.go | 5 +--- .../extensions/ai-cache/vector/provider.go | 4 +-- .../extensions/ai-cache/vector/weaviate.go | 2 ++ 6 files changed, 26 insertions(+), 43 deletions(-) delete mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 8a1ab495d5..3d44563f44 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -101,29 +101,35 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 # 向量数据库提供商特有配置 ## Chroma -Chroma 所对应的 `vector.type` 为 `chroma`。它并无特有的配置字段。 +Chroma 所对应的 `vector.type` 为 `chroma`。它并无特有的配置字段。需要提前创建 Collection。 ## DashVector -DashVector 所对应的 `vector.type` 为 `dashvector`。它并无特有的配置字段。 +DashVector 所对应的 `vector.type` 为 `dashvector`。它并无特有的配置字段。需要提前创建 Collection。 ## ElasticSearch -ElasticSearch 所对应的 `vector.type` 为 `elasticsearch`。它特有的配置字段如下: +ElasticSearch 所对应的 `vector.type` 为 `elasticsearch`。需要提前创建 Index 并填入在 `vector.collectionID` 中。当前依赖于 [KNN](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) 方法,请保证 ES 版本支持 `KNN`,当前已在 `8.16` 版本测试。 +它特有的配置字段如下: | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | |-------------------|----------|----------|--------|-------------------------------------------------------------------------------| | `vector.esUsername` | string | 非必填 | - | ElasticSearch 用户名 | | `vector.esPassword` | string | 非必填 | - | ElasticSearch 密码 | +`vector.esUsername` 和 `vector.esPassword` 用于 Basic 认证。同时也支持 Api Key 认证,当填写了 `vector.apiKey` 时,则启用 Api Key 认证,如果使用 SaaS 版本需要填写 `encoded` 的值。 + ## Milvus -Milvus 所对应的 `vector.type` 为 `milvus`。它并无特有的配置字段。 +Milvus 所对应的 `vector.type` 为 `milvus`。它并无特有的配置字段。需要提前创建 Collection。 ## Pinecone -Pinecone 所对应的 `vector.type` 为 `pinecone`。它并无特有的配置字段。 +Pinecone 所对应的 `vector.type` 为 `pinecone`。它并无特有的配置字段。需要提前创建 Index,并填写 Index 访问域名至 `serviceHost`。 +Pinecone 中的 `Namespace` 参数通过插件的 `vector.collectionID` 进行配置。 ## Qdrant -Qdrant 所对应的 `vector.type` 为 `qdrant`。它并无特有的配置字段。 +Qdrant 所对应的 `vector.type` 为 `qdrant`。它并无特有的配置字段。需要提前创建 Collection。 ## Weaviate Weaviate 所对应的 `vector.type` 为 `weaviate`。它并无特有的配置字段。 +需要提前创建 Collection。需要注意的是 Weaviate 会设置首字母自动大写,在填写配置 `collectionID` 的时候需要将首字母设置为大写。 +如果使用 SaaS 需要填写 `serviceHost` 参数。 ## 配置示例 ### 基础配置 diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go b/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go deleted file mode 100644 index b26d9cea8d..0000000000 --- a/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go +++ /dev/null @@ -1,27 +0,0 @@ -package embedding - -// import ( -// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" -// ) - -// const ( -// weaviateURL = "172.17.0.1:8081" -// ) - -// type weaviateProviderInitializer struct { -// } - -// func (d *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error { -// return nil -// } - -// func (d *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { -// return &DSProvider{ -// config: config, -// client: wrapper.NewClusterClient(wrapper.DnsCluster{ -// ServiceName: config.ServiceName, -// Port: dashScopePort, -// Domain: dashScopeDomain, -// }), -// }, nil -// } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go index 7a3b9d5b3e..263bdd2850 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/elasticsearch.go @@ -82,11 +82,16 @@ func (d *ESProvider) QueryEmbedding( ) } -// base64 编码 ES 身份认证字符串 +// base64 编码 ES 身份认证字符串或使用 Apikey func (d *ESProvider) getCredentials() string { - credentials := fmt.Sprintf("%s:%s", d.config.esUsername, d.config.esPassword) - encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials)) - return fmt.Sprintf("Basic %s", encodedCredentials) + if len(d.config.apiKey) != 0 { + return fmt.Sprintf("ApiKey %s", d.config.apiKey) + } else { + credentials := fmt.Sprintf("%s:%s", d.config.esUsername, d.config.esPassword) + encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials)) + return fmt.Sprintf("Basic %s", encodedCredentials) + } + } func (d *ESProvider) UploadAnswerAndEmbedding( diff --git a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go index f661ee0725..9f490a5a49 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/pinecone.go @@ -15,7 +15,7 @@ type pineconeProviderInitializer struct{} func (c *pineconeProviderInitializer) ValidateConfig(config ProviderConfig) error { if len(config.serviceHost) == 0 { - return errors.New("[Pinecone] serviceDomain is required") + return errors.New("[Pinecone] serviceHost is required") } if len(config.serviceName) == 0 { return errors.New("[Pinecone] serviceName is required") @@ -23,9 +23,6 @@ func (c *pineconeProviderInitializer) ValidateConfig(config ProviderConfig) erro if len(config.apiKey) == 0 { return errors.New("[Pinecone] apiKey is required") } - if len(config.collectionID) == 0 { - return errors.New("[Pinecone] collectionID is required") - } return nil } diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go index 44c0290c8a..cbaff3691b 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/provider.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -26,8 +26,8 @@ var ( providerInitializers = map[string]providerInitializer{ PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{}, PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{}, - PROVIDER_TYPE_ES: &weaviateProviderInitializer{}, - PROVIDER_TYPE_WEAVIATE: &esProviderInitializer{}, + PROVIDER_TYPE_ES: &esProviderInitializer{}, + PROVIDER_TYPE_WEAVIATE: &weaviateProviderInitializer{}, PROVIDER_TYPE_PINECONE: &pineconeProviderInitializer{}, PROVIDER_TYPE_QDRANT: &qdrantProviderInitializer{}, PROVIDER_TYPE_MILVUS: &milvusProviderInitializer{}, diff --git a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go index f4b06d5d4b..668e2d7bc4 100644 --- a/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go +++ b/plugins/wasm-go/extensions/ai-cache/vector/weaviate.go @@ -88,6 +88,7 @@ func (d *WeaviateProvider) QueryEmbedding( "/v1/graphql", [][2]string{ {"Content-Type", "application/json"}, + {"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)}, }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { @@ -128,6 +129,7 @@ func (d *WeaviateProvider) UploadAnswerAndEmbedding( "/v1/objects", [][2]string{ {"Content-Type", "application/json"}, + {"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)}, }, requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { From 014d3eac2ce596531812b8edc8a1c01da01fcd65 Mon Sep 17 00:00:00 2001 From: async Date: Tue, 19 Nov 2024 23:01:09 +0800 Subject: [PATCH 71/71] update --- plugins/wasm-go/extensions/ai-cache/README.md | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 3d44563f44..b6e59e8e16 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -60,7 +60,7 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 | vector.apiKey | string | optional | "" | 向量存储服务 API Key | | vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | | vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 | -| vector.collectionID | string | optional | "" | dashvector 向量存储服务 Collection ID | +| vector.collectionID | string | optional | "" | 向量存储服务 Collection ID | | vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 | | vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大等于) | @@ -101,35 +101,43 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 # 向量数据库提供商特有配置 ## Chroma -Chroma 所对应的 `vector.type` 为 `chroma`。它并无特有的配置字段。需要提前创建 Collection。 +Chroma 所对应的 `vector.type` 为 `chroma`。它并无特有的配置字段。需要提前创建 Collection,并填写 Collection ID 至配置项 `vector.collectionID`,一个 Collection ID 的示例为 `52bbb8b3-724c-477b-a4ce-d5b578214612`。 ## DashVector -DashVector 所对应的 `vector.type` 为 `dashvector`。它并无特有的配置字段。需要提前创建 Collection。 +DashVector 所对应的 `vector.type` 为 `dashvector`。它并无特有的配置字段。需要提前创建 Collection,并填写 `Collection 名称` 至配置项 `vector.collectionID`。 ## ElasticSearch -ElasticSearch 所对应的 `vector.type` 为 `elasticsearch`。需要提前创建 Index 并填入在 `vector.collectionID` 中。当前依赖于 [KNN](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) 方法,请保证 ES 版本支持 `KNN`,当前已在 `8.16` 版本测试。 +ElasticSearch 所对应的 `vector.type` 为 `elasticsearch`。需要提前创建 Index 并填写 Index Name 至配置项 `vector.collectionID` 。 + +当前依赖于 [KNN](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) 方法,请保证 ES 版本支持 `KNN`,当前已在 `8.16` 版本测试。 + 它特有的配置字段如下: | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | |-------------------|----------|----------|--------|-------------------------------------------------------------------------------| | `vector.esUsername` | string | 非必填 | - | ElasticSearch 用户名 | | `vector.esPassword` | string | 非必填 | - | ElasticSearch 密码 | + `vector.esUsername` 和 `vector.esPassword` 用于 Basic 认证。同时也支持 Api Key 认证,当填写了 `vector.apiKey` 时,则启用 Api Key 认证,如果使用 SaaS 版本需要填写 `encoded` 的值。 ## Milvus -Milvus 所对应的 `vector.type` 为 `milvus`。它并无特有的配置字段。需要提前创建 Collection。 +Milvus 所对应的 `vector.type` 为 `milvus`。它并无特有的配置字段。需要提前创建 Collection,并填写 Collection Name 至配置项 `vector.collectionID`。 ## Pinecone -Pinecone 所对应的 `vector.type` 为 `pinecone`。它并无特有的配置字段。需要提前创建 Index,并填写 Index 访问域名至 `serviceHost`。 -Pinecone 中的 `Namespace` 参数通过插件的 `vector.collectionID` 进行配置。 +Pinecone 所对应的 `vector.type` 为 `pinecone`。它并无特有的配置字段。需要提前创建 Index,并填写 Index 访问域名至 `vector.serviceHost`。 + +Pinecone 中的 `Namespace` 参数通过插件的 `vector.collectionID` 进行配置,如果不填写 `vector.collectionID`,则默认为 Default Namespace。 ## Qdrant -Qdrant 所对应的 `vector.type` 为 `qdrant`。它并无特有的配置字段。需要提前创建 Collection。 +Qdrant 所对应的 `vector.type` 为 `qdrant`。它并无特有的配置字段。需要提前创建 Collection,并填写 Collection Name 至配置项 `vector.collectionID`。 ## Weaviate Weaviate 所对应的 `vector.type` 为 `weaviate`。它并无特有的配置字段。 -需要提前创建 Collection。需要注意的是 Weaviate 会设置首字母自动大写,在填写配置 `collectionID` 的时候需要将首字母设置为大写。 -如果使用 SaaS 需要填写 `serviceHost` 参数。 +需要提前创建 Collection,并填写 Collection Name 至配置项 `vector.collectionID`。 + +需要注意的是 Weaviate 会设置首字母自动大写,在填写配置 `collectionID` 的时候需要将首字母设置为大写。 + +如果使用 SaaS 需要填写 `vector.serviceHost` 参数。 ## 配置示例 ### 基础配置