diff --git a/CODEOWNERS b/CODEOWNERS index fab96a6ec9..3d36c596c3 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -3,7 +3,7 @@ /istio @SpecialYang @johnlanni /pkg @SpecialYang @johnlanni @CH3CHO /plugins @johnlanni @WeixinX @CH3CHO -/plugins/wasm-rust @007gzs +/plugins/wasm-rust @007gzs @jizhuozhi /registry @NameHaibinZhang @2456868764 @johnlanni /test @Xunzhuo @2456868764 @CH3CHO /tools @johnlanni @Xunzhuo @2456868764 diff --git a/CONTRIBUTING_EN.md b/CONTRIBUTING_EN.md index 25539cc314..30b3410838 100644 --- a/CONTRIBUTING_EN.md +++ b/CONTRIBUTING_EN.md @@ -1,6 +1,6 @@ # Contributing to Higress -It is warmly welcomed if you have interest to hack on Higress. First, we encourage this kind of willing very much. And here is a list of contributing guide for you. +Your interest in contributing to Higress is warmly welcomed. First, we encourage this kind of willing very much. And here is a list of contributing guide for you. [[中文贡献文档](./CONTRIBUTING_CN.md)] diff --git a/README.md b/README.md index 761f8591a8..c93bf3b15e 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ docker run -d --rm --name higress-ai -v ${PWD}:/data \ - 8080 端口:网关 HTTP 协议入口 - 8443 端口:网关 HTTPS 协议入口 -**Higress 的所有 Docker 镜像都一直使用自己独享的仓库,不受 Docker Hub 境内不可访问的影响** +**Higress 的所有 Docker 镜像都一直使用自己独享的仓库,不受 Docker Hub 境内访问受限的影响** K8s 下使用 Helm 部署等其他安装方式可以参考官网 [Quick Start 文档](https://higress.cn/docs/latest/user/quickstart/)。 diff --git a/VERSION b/VERSION index 0ac852dded..f3b15f3f8e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v2.0.1 +v2.0.2 diff --git a/helm/core/Chart.yaml b/helm/core/Chart.yaml index 1dca35d253..04d287a95f 100644 --- a/helm/core/Chart.yaml +++ b/helm/core/Chart.yaml @@ -1,5 +1,5 @@ apiVersion: v2 -appVersion: 2.0.1 +appVersion: 2.0.2 description: Helm chart for deploying higress gateways icon: https://higress.io/img/higress_logo_small.png home: http://higress.io/ @@ -10,4 +10,4 @@ name: higress-core sources: - http://github.com/alibaba/higress type: application -version: 2.0.1 +version: 2.0.2 diff --git a/helm/core/templates/configmap.yaml b/helm/core/templates/configmap.yaml index b7814f5bf7..a915604d15 100644 --- a/helm/core/templates/configmap.yaml +++ b/helm/core/templates/configmap.yaml @@ -116,6 +116,12 @@ data: {{- $existingData = index $existingConfig.data "higress" | default "{}" | fromYaml }} {{- end }} {{- $newData := dict }} + {{- if hasKey .Values "upstream" }} + {{- $_ := set $newData "upstream" .Values.upstream }} + {{- end }} + {{- if hasKey .Values "downstream" }} + {{- $_ := set $newData "downstream" .Values.downstream }} + {{- end }} {{- if and (hasKey .Values "tracing") .Values.tracing.enable }} {{- $_ := set $newData "tracing" .Values.tracing }} {{- end }} diff --git a/helm/core/templates/controller-clusterrole.yaml b/helm/core/templates/controller-clusterrole.yaml index 8c05467e72..1f0f574555 100644 --- a/helm/core/templates/controller-clusterrole.yaml +++ b/helm/core/templates/controller-clusterrole.yaml @@ -129,3 +129,10 @@ rules: - apiGroups: ["networking.internal.knative.dev"] resources: ["ingresses/status"] verbs: ["get","patch","update"] + # gateway api need + - apiGroups: ["apps"] + verbs: [ "get", "watch", "list", "update", "patch", "create", "delete" ] + resources: [ "deployments" ] + - apiGroups: [""] + verbs: [ "get", "watch", "list", "update", "patch", "create", "delete" ] + resources: [ "serviceaccounts"] diff --git a/helm/core/templates/controller-deployment.yaml b/helm/core/templates/controller-deployment.yaml index 98bfe8f6d8..dda26d2433 100644 --- a/helm/core/templates/controller-deployment.yaml +++ b/helm/core/templates/controller-deployment.yaml @@ -69,6 +69,12 @@ spec: fieldPath: spec.serviceAccountName - name: DOMAIN_SUFFIX value: {{ .Values.global.proxy.clusterDomain }} + - name: GATEWAY_NAME + value: {{ include "gateway.name" . }} + - name: PILOT_ENABLE_GATEWAY_API + value: "{{ .Values.global.enableGatewayAPI }}" + - name: PILOT_ENABLE_ALPHA_GATEWAY_API + value: "{{ .Values.global.enableGatewayAPI }}" {{- if .Values.controller.env }} {{- range $key, $val := .Values.controller.env }} - name: {{ $key }} @@ -215,14 +221,14 @@ spec: - name: HIGRESS_ENABLE_ISTIO_API value: "true" {{- end }} - {{- if .Values.global.enableGatewayAPI }} - name: PILOT_ENABLE_GATEWAY_API - value: "true" + value: "false" + - name: PILOT_ENABLE_ALPHA_GATEWAY_API + value: "false" - name: PILOT_ENABLE_GATEWAY_API_STATUS - value: "true" + value: "false" - name: PILOT_ENABLE_GATEWAY_API_DEPLOYMENT_CONTROLLER value: "false" - {{- end }} {{- if not .Values.global.enableHigressIstio }} - name: CUSTOM_CA_CERT_NAME value: "higress-ca-root-cert" diff --git a/helm/core/values.yaml b/helm/core/values.yaml index 39582f748d..fb03a2f85f 100644 --- a/helm/core/values.yaml +++ b/helm/core/values.yaml @@ -684,3 +684,19 @@ tracing: # zipkin: # service: "" # port: 9411 + +# Downstream config settings +downstream: + idleTimeout: 180 + maxRequestHeadersKb: 60 + connectionBufferLimits: 32768 + http2: + maxConcurrentStreams: 100 + initialStreamWindowSize: 65535 + initialConnectionWindowSize: 1048576 + routeTimeout: 0 + +# Upstream config settings +upstream: + idleTimeout: 10 + connectionBufferLimits: 10485760 diff --git a/helm/higress/Chart.lock b/helm/higress/Chart.lock index e2397ea62a..9c71c9f1c8 100644 --- a/helm/higress/Chart.lock +++ b/helm/higress/Chart.lock @@ -1,9 +1,9 @@ dependencies: - name: higress-core repository: file://../core - version: 2.0.1 + version: 2.0.2 - name: higress-console repository: https://higress.io/helm-charts/ version: 1.4.4 -digest: sha256:6e4d77c31c834a404a728ec5a8379dd5df27a7e9b998a08e6524dc6534b07c1d -generated: "2024-10-09T20:07:21.857942+08:00" +digest: sha256:a424449caa01a71798c7fec9769ef97be7658354c028a3cede4790e4b6094532 +generated: "2024-10-28T18:50:27.528097+08:00" diff --git a/helm/higress/Chart.yaml b/helm/higress/Chart.yaml index 7d683863ba..51fe7f7de2 100644 --- a/helm/higress/Chart.yaml +++ b/helm/higress/Chart.yaml @@ -1,5 +1,5 @@ apiVersion: v2 -appVersion: 2.0.1 +appVersion: 2.0.2 description: Helm chart for deploying Higress gateways icon: https://higress.io/img/higress_logo_small.png home: http://higress.io/ @@ -12,9 +12,9 @@ sources: dependencies: - name: higress-core repository: "file://../core" - version: 2.0.1 + version: 2.0.2 - name: higress-console repository: "https://higress.io/helm-charts/" version: 1.4.4 type: application -version: 2.0.1 +version: 2.0.2 diff --git a/pkg/config/envs.go b/pkg/config/envs.go index 38ff64bf26..99a67edc65 100644 --- a/pkg/config/envs.go +++ b/pkg/config/envs.go @@ -19,6 +19,7 @@ import "istio.io/pkg/env" var ( PodNamespace = env.RegisterStringVar("POD_NAMESPACE", "higress-system", "").Get() PodName = env.RegisterStringVar("POD_NAME", "", "").Get() + GatewayName = env.RegisterStringVar("GATEWAY_NAME", "higress-gateway", "").Get() // Revision is the value of the Istio control plane revision, e.g. "canary", // and is the value used by the "istio.io/rev" label. Revision = env.Register("REVISION", "", "").Get() diff --git a/pkg/ingress/config/ingress_config.go b/pkg/ingress/config/ingress_config.go index e41181676a..6030ae50af 100644 --- a/pkg/ingress/config/ingress_config.go +++ b/pkg/ingress/config/ingress_config.go @@ -34,6 +34,7 @@ import ( extensions "istio.io/api/extensions/v1alpha1" networking "istio.io/api/networking/v1alpha3" istiotype "istio.io/api/type/v1beta1" + "istio.io/istio/pilot/pkg/features" istiomodel "istio.io/istio/pilot/pkg/model" "istio.io/istio/pilot/pkg/util/protoconv" "istio.io/istio/pkg/cluster" @@ -235,8 +236,9 @@ func (m *IngressConfig) AddLocalCluster(options common.Options) { ingressController = ingressv1.NewController(m.localKubeClient, m.localKubeClient, options, secretController) } m.remoteIngressControllers[options.ClusterId] = ingressController - - m.remoteGatewayControllers[options.ClusterId] = gateway.NewController(m.localKubeClient, options) + if features.EnableGatewayAPI { + m.remoteGatewayControllers[options.ClusterId] = gateway.NewController(m.localKubeClient, options) + } } func (m *IngressConfig) List(typ config.GroupVersionKind, namespace string) []config.Config { @@ -719,9 +721,9 @@ func (m *IngressConfig) convertDestinationRule(configs []common.WrapperConfig) [ } else if dr.DestinationRule.TrafficPolicy != nil { portTrafficPolicy := destinationRuleWrapper.DestinationRule.TrafficPolicy.PortLevelSettings[0] portUpdated := false - for _, portTrafficPolicy := range dr.DestinationRule.TrafficPolicy.PortLevelSettings { - if portTrafficPolicy.Port.Number == portTrafficPolicy.Port.Number { - portTrafficPolicy.Tls = portTrafficPolicy.Tls + for _, policy := range dr.DestinationRule.TrafficPolicy.PortLevelSettings { + if policy.Port.Number == portTrafficPolicy.Port.Number { + policy.Tls = portTrafficPolicy.Tls portUpdated = true break } diff --git a/pkg/ingress/kube/gateway/istio/conversion.go b/pkg/ingress/kube/gateway/istio/conversion.go index 83c817a717..4cf5dee9e5 100644 --- a/pkg/ingress/kube/gateway/istio/conversion.go +++ b/pkg/ingress/kube/gateway/istio/conversion.go @@ -25,6 +25,7 @@ import ( "strings" higressconfig "github.com/alibaba/higress/pkg/config" + "github.com/alibaba/higress/pkg/ingress/kube/util" istio "istio.io/api/networking/v1alpha3" "istio.io/istio/pilot/pkg/features" "istio.io/istio/pilot/pkg/model" @@ -1880,7 +1881,7 @@ func extractGatewayServices(r GatewayResources, kgw *k8s.GatewaySpec, obj config if len(name) > 0 { return []string{fmt.Sprintf("%s.%s.svc.%v", name, obj.Namespace, r.Domain)}, false, nil } - return []string{}, true, nil + return []string{fmt.Sprintf("%s.%s.svc.%s", higressconfig.GatewayName, higressconfig.PodNamespace, util.GetDomainSuffix())}, true, nil } gatewayServices := []string{} skippedAddresses := []string{} diff --git a/pkg/ingress/kube/gateway/istio/conversion_test.go b/pkg/ingress/kube/gateway/istio/conversion_test.go index 986aecbc69..28f8d26962 100644 --- a/pkg/ingress/kube/gateway/istio/conversion_test.go +++ b/pkg/ingress/kube/gateway/istio/conversion_test.go @@ -919,7 +919,7 @@ func TestExtractGatewayServices(t *testing.T) { Namespace: "default", }, }, - gatewayServices: []string{}, + gatewayServices: []string{"higress-gateway.higress-system.svc.cluster.local"}, useDefaultService: true, }, { @@ -1039,7 +1039,7 @@ func TestExtractGatewayServices(t *testing.T) { Namespace: "default", }, }, - gatewayServices: []string{}, + gatewayServices: []string{"higress-gateway.higress-system.svc.cluster.local"}, useDefaultService: true, }, } diff --git a/pkg/ingress/translation/translation.go b/pkg/ingress/translation/translation.go index 1545a1749a..bb95d57682 100644 --- a/pkg/ingress/translation/translation.go +++ b/pkg/ingress/translation/translation.go @@ -187,10 +187,9 @@ func (m *IngressTranslation) List(typ config.GroupVersionKind, namespace string) higressConfig = append(higressConfig, ingressConfig...) if m.kingressConfig != nil { kingressConfig := m.kingressConfig.List(typ, namespace) - if kingressConfig == nil { - return nil + if kingressConfig != nil { + higressConfig = append(higressConfig, kingressConfig...) } - higressConfig = append(higressConfig, kingressConfig...) } return higressConfig } diff --git a/plugins/wasm-cpp/extensions/model_router/plugin.cc b/plugins/wasm-cpp/extensions/model_router/plugin.cc index 457864d268..66a90973ff 100644 --- a/plugins/wasm-cpp/extensions/model_router/plugin.cc +++ b/plugins/wasm-cpp/extensions/model_router/plugin.cc @@ -101,7 +101,7 @@ bool PluginRootContext::configure(size_t configuration_size) { configuration_data->view())); return false; } - if (!parseAuthRuleConfig(result.value())) { + if (!parseRuleConfig(result.value())) { LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ", configuration_data->view())); return false; diff --git a/plugins/wasm-cpp/extensions/model_router/plugin_test.cc b/plugins/wasm-cpp/extensions/model_router/plugin_test.cc index 9ce5998051..dc351ecdc8 100644 --- a/plugins/wasm-cpp/extensions/model_router/plugin_test.cc +++ b/plugins/wasm-cpp/extensions/model_router/plugin_test.cc @@ -40,6 +40,11 @@ class MockContext : public proxy_wasm::ContextBase { MOCK_METHOD(WasmResult, getHeaderMapValue, (WasmHeaderMapType /* type */, std::string_view /* key */, std::string_view* /*result */)); + MOCK_METHOD(WasmResult, replaceHeaderMapValue, + (WasmHeaderMapType /* type */, std::string_view /* key */, + std::string_view /* value */)); + MOCK_METHOD(WasmResult, removeHeaderMapValue, + (WasmHeaderMapType /* type */, std::string_view /* key */)); MOCK_METHOD(WasmResult, addHeaderMapValue, (WasmHeaderMapType, std::string_view, std::string_view)); MOCK_METHOD(WasmResult, getProperty, (std::string_view, std::string*)); @@ -87,6 +92,16 @@ class ModelRouterTest : public ::testing::Test { } return WasmResult::Ok; }); + ON_CALL(*mock_context_, + replaceHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_, + testing::_)) + .WillByDefault([&](WasmHeaderMapType, std::string_view key, + std::string_view value) { return WasmResult::Ok; }); + ON_CALL(*mock_context_, + removeHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_)) + .WillByDefault([&](WasmHeaderMapType, std::string_view key) { + return WasmResult::Ok; + }); ON_CALL(*mock_context_, addHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_, testing::_)) .WillByDefault([&](WasmHeaderMapType, std::string_view header, @@ -128,12 +143,45 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) { return WasmResult::Ok; }); - EXPECT_CALL( - *mock_context_, - addHeaderMapValue(testing::_, std::string_view("x-higress-llm-provider"), - std::string_view("qwen"))); + EXPECT_CALL(*mock_context_, + replaceHeaderMapValue(testing::_, + std::string_view("x-higress-llm-provider"), + std::string_view("qwen"))); + + body_.set(request_json); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue); +} + +TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) { + std::string configuration = R"( +{ + "_rules_": [ + { + "_match_route_": ["route-a"], + "enable": true + } +]})"; + + config_.set(configuration); + EXPECT_TRUE(root_context_->configure(configuration.size())); + + std::string request_json = R"({"model": "qwen/qwen-long"})"; + EXPECT_CALL(*mock_context_, + setBuffer(testing::_, testing::_, testing::_, testing::_)) + .WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) { + EXPECT_EQ(body, R"({"model":"qwen-long"})"); + return WasmResult::Ok; + }); + + EXPECT_CALL(*mock_context_, + replaceHeaderMapValue(testing::_, + std::string_view("x-higress-llm-provider"), + std::string_view("qwen"))); body_.set(request_json); + route_name_ = "route-a"; EXPECT_EQ(context_->onRequestHeaders(0, false), FilterHeadersStatus::StopIteration); EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue); diff --git a/plugins/wasm-go/extensions/ai-cache/.gitignore b/plugins/wasm-go/extensions/ai-cache/.gitignore index 47db8eedba..8a34bf52ad 100644 --- a/plugins/wasm-go/extensions/ai-cache/.gitignore +++ b/plugins/wasm-go/extensions/ai-cache/.gitignore @@ -1,5 +1,5 @@ # File generated by hgctl. Modify as required. - +docker-compose-test/ * !/.gitignore diff --git a/plugins/wasm-go/extensions/ai-cache/README.md b/plugins/wasm-go/extensions/ai-cache/README.md index 1de252f12c..ca91bdf5a1 100644 --- a/plugins/wasm-go/extensions/ai-cache/README.md +++ b/plugins/wasm-go/extensions/ai-cache/README.md @@ -1,9 +1,15 @@ +## 简介 --- title: AI 缓存 keywords: [higress,ai cache] description: AI 缓存插件配置参考 --- +**Note** + +> 需要数据面的proxy wasm版本大于等于0.2.100 +> 编译时,需要带上版本的tag,例如:`tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./` +> ## 功能说明 @@ -19,33 +25,113 @@ LLM 结果缓存插件,默认配置方式可以直接用于 openai 协议的 插件执行阶段:`认证阶段` 插件执行优先级:`10` +## 配置说明 +配置分为 3 个部分:向量数据库(vector);文本向量化接口(embedding);缓存数据库(cache),同时也提供了细粒度的 LLM 请求/响应提取参数配置等。 + ## 配置说明 -| Name | Type | Requirement | Default | Description | -| -------- | -------- | -------- | -------- | -------- | -| cacheKeyFrom.requestBody | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheValueFrom.responseBody | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheStreamValueFrom.responseBody | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | -| cacheKeyPrefix | string | optional | "higress-ai-cache:" | Redis缓存Key的前缀 | -| cacheTTL | integer | optional | 0 | 缓存的过期时间,单位是秒,默认值为0,即永不过期 | -| redis.serviceName | string | requried | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local | -| redis.servicePort | integer | optional | 6379 | redis 服务端口 | -| redis.timeout | integer | optional | 1000 | 请求 redis 的超时时间,单位为毫秒 | -| redis.username | string | optional | - | 登陆 redis 的用户名 | -| redis.password | string | optional | - | 登陆 redis 的密码 | -| returnResponseTemplate | string | optional | `{"id":"from-cache","choices":[%s],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | -| returnStreamResponseTemplate | string | optional | `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | +本插件同时支持基于向量数据库的语义化缓存和基于字符串匹配的缓存方法,如果同时配置了向量数据库和缓存数据库,优先使用向量数据库。 + +*Note*: 向量数据库(vector) 和 缓存数据库(cache) 不能同时为空,否则本插件无法提供缓存服务。 + +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| vector | string | optional | "" | 向量存储服务提供者类型,例如 dashvector | +| embedding | string | optional | "" | 请求文本向量化服务类型,例如 dashscope | +| cache | string | optional | "" | 缓存服务类型,例如 redis | +| cacheKeyStrategy | string | optional | "lastQuestion" | 决定如何根据历史问题生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) | +| enableSemanticCache | bool | optional | true | 是否启用语义化缓存, 若不启用,则使用字符串匹配的方式来查找缓存,此时需要配置cache服务 | + +根据是否需要启用语义缓存,可以只配置组件的组合为: +1. `cache`: 仅启用字符串匹配缓存 +3. `vector (+ embedding)`: 启用语义化缓存, 其中若 `vector` 未提供字符串表征服务,则需要自行配置 `embedding` 服务 +2. `vector (+ embedding) + cache`: 启用语义化缓存并用缓存服务存储LLM响应以加速 + +注意若不配置相关组件,则可以忽略相应组件的`required`字段。 + + +## 向量数据库服务(vector) +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| vector.type | string | required | "" | 向量存储服务提供者类型,例如 dashvector | +| vector.serviceName | string | required | "" | 向量存储服务名称 | +| vector.serviceHost | string | required | "" | 向量存储服务域名 | +| vector.servicePort | int64 | optional | 443 | 向量存储服务端口 | +| vector.apiKey | string | optional | "" | 向量存储服务 API Key | +| vector.topK | int | optional | 1 | 返回TopK结果,默认为 1 | +| vector.timeout | uint32 | optional | 10000 | 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| vector.collectionID | string | optional | "" | dashvector 向量存储服务 Collection ID | +| vector.threshold | float64 | optional | 1000 | 向量相似度度量阈值 | +| vector.thresholdRelation | string | optional | lt | 相似度度量方式有 `Cosine`, `DotProduct`, `Euclidean` 等,前两者值越大相似度越高,后者值越小相似度越高。对于 `Cosine` 和 `DotProduct` 选择 `gt`,对于 `Euclidean` 则选择 `lt`。默认为 `lt`,所有条件包括 `lt` (less than,小于)、`lte` (less than or equal to,小等于)、`gt` (greater than,大于)、`gte` (greater than or equal to,大等于) | + +## 文本向量化服务(embedding) +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| embedding.type | string | required | "" | 请求文本向量化服务类型,例如 dashscope | +| embedding.serviceName | string | required | "" | 请求文本向量化服务名称 | +| embedding.serviceHost | string | optional | "" | 请求文本向量化服务域名 | +| embedding.servicePort | int64 | optional | 443 | 请求文本向量化服务端口 | +| embedding.apiKey | string | optional | "" | 请求文本向量化服务的 API Key | +| embedding.timeout | uint32 | optional | 10000 | 请求文本向量化服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| embedding.model | string | optional | "" | 请求文本向量化服务的模型名称 | + + +## 缓存服务(cache) +| cache.type | string | required | "" | 缓存服务类型,例如 redis | +| --- | --- | --- | --- | --- | +| cache.serviceName | string | required | "" | 缓存服务名称 | +| cache.serviceHost | string | required | "" | 缓存服务域名 | +| cache.servicePort | int64 | optional | 6379 | 缓存服务端口 | +| cache.username | string | optional | "" | 缓存服务用户名 | +| cache.password | string | optional | "" | 缓存服务密码 | +| cache.timeout | uint32 | optional | 10000 | 缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 | +| cache.cacheTTL | int | optional | 0 | 缓存过期时间,单位为秒。默认值是 0,即 永不过期| +| cacheKeyPrefix | string | optional | "higress-ai-cache:" | 缓存 Key 的前缀,默认值为 "higress-ai-cache:" | + + +## 其他配置 +| Name | Type | Requirement | Default | Description | +| --- | --- | --- | --- | --- | +| cacheKeyFrom | string | optional | "messages.@reverse.0.content" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheValueFrom | string | optional | "choices.0.message.content" | 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheStreamValueFrom | string | optional | "choices.0.delta.content" | 从流式响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| cacheToolCallsFrom | string | optional | "choices.0.delta.content.tool_calls" | 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 | +| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | 返回 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | +| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | 返回流式 HTTP 响应的模版,用 %s 标记需要被 cache value 替换的部分 | + ## 配置示例 +### 基础配置 +```yaml +embedding: + type: dashscope + serviceName: my_dashscope.dns + apiKey: [Your Key] + +vector: + type: dashvector + serviceName: my_dashvector.dns + collectionID: [Your Collection ID] + serviceDomain: [Your domain] + apiKey: [Your key] + +cache: + type: redis + serviceName: my_redis.dns + servicePort: 6379 + timeout: 100 +``` + +旧版本配置兼容 ```yaml redis: - serviceName: my-redis.dns - timeout: 2000 + serviceName: my_redis.dns + servicePort: 6379 + timeout: 100 ``` ## 进阶用法 - 当前默认的缓存 key 是基于 GJSON PATH 的表达式:`messages.@reverse.0.content` 提取,含义是把 messages 数组反转后取第一项的 content; GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 的 content 作为 key,可以写成: `messages.@reverse.#(role=="user").content`; @@ -55,3 +141,7 @@ GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 还可以支持管道语法,例如希望取到数第二个 role 为 user 的 content 作为 key,可以写成:`messages.@reverse.#(role=="user")#.content|1`。 更多用法可以参考[官方文档](https://github.com/tidwall/gjson/blob/master/SYNTAX.md),可以使用 [GJSON Playground](https://gjson.dev/) 进行语法测试。 + +## 常见问题 + +1. 如果返回的错误为 `error status returned by host: bad argument`,请检查`serviceName`是否正确包含了服务的类型后缀(.dns等)。 \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-cache/cache/provider.go b/plugins/wasm-go/extensions/ai-cache/cache/provider.go new file mode 100644 index 0000000000..1238d21570 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/cache/provider.go @@ -0,0 +1,135 @@ +package cache + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + PROVIDER_TYPE_REDIS = "redis" + DEFAULT_CACHE_PREFIX = "higress-ai-cache:" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + PROVIDER_TYPE_REDIS: &redisProviderInitializer{}, + } +) + +type ProviderConfig struct { + // @Title zh-CN redis 缓存服务提供者类型 + // @Description zh-CN 缓存服务提供者类型,例如 redis + typ string + // @Title zh-CN redis 缓存服务名称 + // @Description zh-CN 缓存服务名称 + serviceName string + // @Title zh-CN redis 缓存服务端口 + // @Description zh-CN 缓存服务端口,默认值为6379 + servicePort int + // @Title zh-CN redis 缓存服务地址 + // @Description zh-CN Cache 缓存服务地址,非必填 + serviceHost string + // @Title zh-CN 缓存服务用户名 + // @Description zh-CN 缓存服务用户名,非必填 + username string + // @Title zh-CN 缓存服务密码 + // @Description zh-CN 缓存服务密码,非必填 + password string + // @Title zh-CN 请求超时 + // @Description zh-CN 请求缓存服务的超时时间,单位为毫秒。默认值是10000,即10秒 + timeout uint32 + // @Title zh-CN 缓存过期时间 + // @Description zh-CN 缓存过期时间,单位为秒。默认值是0,即永不过期 + cacheTTL int + // @Title 缓存 Key 前缀 + // @Description 缓存 Key 的前缀,默认值为 "higressAiCache:" + cacheKeyPrefix string +} + +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("type").String() + c.serviceName = json.Get("serviceName").String() + c.servicePort = int(json.Get("servicePort").Int()) + if !json.Get("servicePort").Exists() { + c.servicePort = 6379 + } + c.serviceHost = json.Get("serviceHost").String() + c.username = json.Get("username").String() + if !json.Get("username").Exists() { + c.username = "" + } + c.password = json.Get("password").String() + if !json.Get("password").Exists() { + c.password = "" + } + c.timeout = uint32(json.Get("timeout").Int()) + if !json.Get("timeout").Exists() { + c.timeout = 10000 + } + c.cacheTTL = int(json.Get("cacheTTL").Int()) + if !json.Get("cacheTTL").Exists() { + c.cacheTTL = 0 + // c.cacheTTL = 3600000 + } + if json.Get("cacheKeyPrefix").Exists() { + c.cacheKeyPrefix = json.Get("cacheKeyPrefix").String() + } else { + c.cacheKeyPrefix = DEFAULT_CACHE_PREFIX + } + +} + +func (c *ProviderConfig) ConvertLegacyJson(json gjson.Result) { + c.FromJson(json.Get("redis")) + c.typ = "redis" + if json.Get("cacheTTL").Exists() { + c.cacheTTL = int(json.Get("cacheTTL").Int()) + } +} + +func (c *ProviderConfig) Validate() error { + if c.typ == "" { + return errors.New("cache service type is required") + } + if c.serviceName == "" { + return errors.New("cache service name is required") + } + if c.cacheTTL < 0 { + return errors.New("cache TTL must be greater than or equal to 0") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown cache service provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err + } + return nil +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} + +type Provider interface { + GetProviderType() string + Init(username string, password string, timeout uint32) error + Get(key string, cb wrapper.RedisResponseCallback) error + Set(key string, value string, cb wrapper.RedisResponseCallback) error + GetCacheKeyPrefix() string +} diff --git a/plugins/wasm-go/extensions/ai-cache/cache/redis.go b/plugins/wasm-go/extensions/ai-cache/cache/redis.go new file mode 100644 index 0000000000..4cb69744e1 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/cache/redis.go @@ -0,0 +1,58 @@ +package cache + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type redisProviderInitializer struct { +} + +func (r *redisProviderInitializer) ValidateConfig(cf ProviderConfig) error { + if len(cf.serviceName) == 0 { + return errors.New("cache service name is required") + } + return nil +} + +func (r *redisProviderInitializer) CreateProvider(cf ProviderConfig) (Provider, error) { + rp := redisProvider{ + config: cf, + client: wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: cf.serviceName, + Host: cf.serviceHost, + Port: int64(cf.servicePort)}), + } + err := rp.Init(cf.username, cf.password, cf.timeout) + return &rp, err +} + +type redisProvider struct { + config ProviderConfig + client wrapper.RedisClient +} + +func (rp *redisProvider) GetProviderType() string { + return PROVIDER_TYPE_REDIS +} + +func (rp *redisProvider) Init(username string, password string, timeout uint32) error { + return rp.client.Init(rp.config.username, rp.config.password, int64(rp.config.timeout)) +} + +func (rp *redisProvider) Get(key string, cb wrapper.RedisResponseCallback) error { + return rp.client.Get(key, cb) +} + +func (rp *redisProvider) Set(key string, value string, cb wrapper.RedisResponseCallback) error { + if rp.config.cacheTTL == 0 { + return rp.client.Set(key, value, cb) + } else { + return rp.client.SetEx(key, value, rp.config.cacheTTL, cb) + } +} + +func (rp *redisProvider) GetCacheKeyPrefix() string { + return rp.config.cacheKeyPrefix +} diff --git a/plugins/wasm-go/extensions/ai-cache/config/config.go b/plugins/wasm-go/extensions/ai-cache/config/config.go new file mode 100644 index 0000000000..4bd6e2a18f --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/config/config.go @@ -0,0 +1,225 @@ +package config + +import ( + "fmt" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/cache" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/embedding" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + CACHE_KEY_STRATEGY_LAST_QUESTION = "lastQuestion" + CACHE_KEY_STRATEGY_ALL_QUESTIONS = "allQuestions" + CACHE_KEY_STRATEGY_DISABLED = "disabled" +) + +type PluginConfig struct { + // @Title zh-CN 返回 HTTP 响应的模版 + // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 + ResponseTemplate string + // @Title zh-CN 返回流式 HTTP 响应的模版 + // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 + StreamResponseTemplate string + + cacheProvider cache.Provider + embeddingProvider embedding.Provider + vectorProvider vector.Provider + + embeddingProviderConfig embedding.ProviderConfig + vectorProviderConfig vector.ProviderConfig + cacheProviderConfig cache.ProviderConfig + + CacheKeyFrom string + CacheValueFrom string + CacheStreamValueFrom string + CacheToolCallsFrom string + + // @Title zh-CN 启用语义化缓存 + // @Description zh-CN 控制是否启用语义化缓存功能。true 表示启用,false 表示禁用。 + EnableSemanticCache bool + + // @Title zh-CN 缓存键策略 + // @Description zh-CN 决定如何生成缓存键的策略。可选值: "lastQuestion" (使用最后一个问题), "allQuestions" (拼接所有问题) 或 "disabled" (禁用缓存) + CacheKeyStrategy string +} + +func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) { + + c.vectorProviderConfig.FromJson(json.Get("vector")) + c.embeddingProviderConfig.FromJson(json.Get("embedding")) + c.cacheProviderConfig.FromJson(json.Get("cache")) + if json.Get("redis").Exists() { + // compatible with legacy config + c.cacheProviderConfig.ConvertLegacyJson(json) + } + + c.CacheKeyStrategy = json.Get("cacheKeyStrategy").String() + if c.CacheKeyStrategy == "" { + c.CacheKeyStrategy = CACHE_KEY_STRATEGY_LAST_QUESTION // set default value + } + c.CacheKeyFrom = json.Get("cacheKeyFrom").String() + if c.CacheKeyFrom == "" { + c.CacheKeyFrom = "messages.@reverse.0.content" + } + c.CacheValueFrom = json.Get("cacheValueFrom").String() + if c.CacheValueFrom == "" { + c.CacheValueFrom = "choices.0.message.content" + } + c.CacheStreamValueFrom = json.Get("cacheStreamValueFrom").String() + if c.CacheStreamValueFrom == "" { + c.CacheStreamValueFrom = "choices.0.delta.content" + } + c.CacheToolCallsFrom = json.Get("cacheToolCallsFrom").String() + if c.CacheToolCallsFrom == "" { + c.CacheToolCallsFrom = "choices.0.delta.content.tool_calls" + } + + c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() + if c.StreamResponseTemplate == "" { + c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" + } + c.ResponseTemplate = json.Get("responseTemplate").String() + if c.ResponseTemplate == "" { + c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + } + + if json.Get("enableSemanticCache").Exists() { + c.EnableSemanticCache = json.Get("enableSemanticCache").Bool() + } else { + c.EnableSemanticCache = true // set default value to true + } + + // compatible with legacy config + convertLegacyMapFields(c, json, log) +} + +func (c *PluginConfig) Validate() error { + // if cache provider is configured, validate it + if c.cacheProviderConfig.GetProviderType() != "" { + if err := c.cacheProviderConfig.Validate(); err != nil { + return err + } + } + if c.embeddingProviderConfig.GetProviderType() != "" { + if err := c.embeddingProviderConfig.Validate(); err != nil { + return err + } + } + if c.vectorProviderConfig.GetProviderType() != "" { + if err := c.vectorProviderConfig.Validate(); err != nil { + return err + } + } + + // cache, vector, and embedding cannot all be empty + if c.vectorProviderConfig.GetProviderType() == "" && + c.embeddingProviderConfig.GetProviderType() == "" && + c.cacheProviderConfig.GetProviderType() == "" { + return fmt.Errorf("vector, embedding and cache provider cannot be all empty") + } + + // Validate the value of CacheKeyStrategy + if c.CacheKeyStrategy != CACHE_KEY_STRATEGY_LAST_QUESTION && + c.CacheKeyStrategy != CACHE_KEY_STRATEGY_ALL_QUESTIONS && + c.CacheKeyStrategy != CACHE_KEY_STRATEGY_DISABLED { + return fmt.Errorf("invalid CacheKeyStrategy: %s", c.CacheKeyStrategy) + } + + // If semantic cache is enabled, ensure necessary components are configured + // if c.EnableSemanticCache { + // if c.embeddingProviderConfig.GetProviderType() == "" { + // return fmt.Errorf("semantic cache is enabled but embedding provider is not configured") + // } + // // if only configure cache, just warn the user + // } + return nil +} + +func (c *PluginConfig) Complete(log wrapper.Log) error { + var err error + if c.embeddingProviderConfig.GetProviderType() != "" { + log.Debugf("embedding provider is set to %s", c.embeddingProviderConfig.GetProviderType()) + c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig) + if err != nil { + return err + } + } else { + log.Info("embedding provider is not configured") + c.embeddingProvider = nil + } + if c.cacheProviderConfig.GetProviderType() != "" { + log.Debugf("cache provider is set to %s", c.cacheProviderConfig.GetProviderType()) + c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig) + if err != nil { + return err + } + } else { + log.Info("cache provider is not configured") + c.cacheProvider = nil + } + if c.vectorProviderConfig.GetProviderType() != "" { + log.Debugf("vector provider is set to %s", c.vectorProviderConfig.GetProviderType()) + c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig) + if err != nil { + return err + } + } else { + log.Info("vector provider is not configured") + c.vectorProvider = nil + } + return nil +} + +func (c *PluginConfig) GetEmbeddingProvider() embedding.Provider { + return c.embeddingProvider +} + +func (c *PluginConfig) GetVectorProvider() vector.Provider { + return c.vectorProvider +} + +func (c *PluginConfig) GetVectorProviderConfig() vector.ProviderConfig { + return c.vectorProviderConfig +} + +func (c *PluginConfig) GetCacheProvider() cache.Provider { + return c.cacheProvider +} + +func convertLegacyMapFields(c *PluginConfig, json gjson.Result, log wrapper.Log) { + keyMap := map[string]string{ + "cacheKeyFrom.requestBody": "cacheKeyFrom", + "cacheValueFrom.requestBody": "cacheValueFrom", + "cacheStreamValueFrom.requestBody": "cacheStreamValueFrom", + "returnResponseTemplate": "responseTemplate", + "returnStreamResponseTemplate": "streamResponseTemplate", + } + + for oldKey, newKey := range keyMap { + if json.Get(oldKey).Exists() { + log.Debugf("[convertLegacyMapFields] mapping %s to %s", oldKey, newKey) + setField(c, newKey, json.Get(oldKey).String(), log) + } else { + log.Debugf("[convertLegacyMapFields] %s not exists", oldKey) + } + } +} + +func setField(c *PluginConfig, fieldName string, value string, log wrapper.Log) { + switch fieldName { + case "cacheKeyFrom": + c.CacheKeyFrom = value + case "cacheValueFrom": + c.CacheValueFrom = value + case "cacheStreamValueFrom": + c.CacheStreamValueFrom = value + case "responseTemplate": + c.ResponseTemplate = value + case "streamResponseTemplate": + c.StreamResponseTemplate = value + } + log.Debugf("[setField] set %s to %s", fieldName, value) +} diff --git a/plugins/wasm-go/extensions/ai-cache/core.go b/plugins/wasm-go/extensions/ai-cache/core.go new file mode 100644 index 0000000000..19a9b2b856 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/core.go @@ -0,0 +1,275 @@ +package main + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" + "github.com/tidwall/resp" +) + +// CheckCacheForKey checks if the key is in the cache, or triggers similarity search if not found. +func CheckCacheForKey(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool, useSimilaritySearch bool) error { + activeCacheProvider := c.GetCacheProvider() + if activeCacheProvider == nil { + log.Debugf("[%s] [CheckCacheForKey] no cache provider configured, performing similarity search", PLUGIN_NAME) + return performSimilaritySearch(key, ctx, c, log, key, stream) + } + + queryKey := activeCacheProvider.GetCacheKeyPrefix() + key + log.Debugf("[%s] [CheckCacheForKey] querying cache with key: %s", PLUGIN_NAME, queryKey) + + err := activeCacheProvider.Get(queryKey, func(response resp.Value) { + handleCacheResponse(key, response, ctx, log, stream, c, useSimilaritySearch) + }) + + if err != nil { + log.Errorf("[%s] [CheckCacheForKey] failed to retrieve key: %s from cache, error: %v", PLUGIN_NAME, key, err) + return err + } + + return nil +} + +// handleCacheResponse processes cache response and handles cache hits and misses. +func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, useSimilaritySearch bool) { + if err := response.Error(); err == nil && !response.IsNull() { + log.Infof("[%s] cache hit for key: %s", PLUGIN_NAME, key) + processCacheHit(key, response.String(), stream, ctx, c, log) + return + } + + log.Infof("[%s] [handleCacheResponse] cache miss for key: %s", PLUGIN_NAME, key) + if err := response.Error(); err != nil { + log.Errorf("[%s] [handleCacheResponse] error retrieving key: %s from cache, error: %v", PLUGIN_NAME, key, err) + } + + if useSimilaritySearch && c.EnableSemanticCache { + if err := performSimilaritySearch(key, ctx, c, log, key, stream); err != nil { + log.Errorf("[%s] [handleCacheResponse] failed to perform similarity search for key: %s, error: %v", PLUGIN_NAME, key, err) + proxywasm.ResumeHttpRequest() + } + } else { + proxywasm.ResumeHttpRequest() + } +} + +// processCacheHit handles a successful cache hit. +func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) { + if strings.TrimSpace(response) == "" { + log.Warnf("[%s] [processCacheHit] cached response for key %s is empty", PLUGIN_NAME, key) + proxywasm.ResumeHttpRequest() + return + } + + log.Debugf("[%s] [processCacheHit] cached response for key %s: %s", PLUGIN_NAME, key, response) + + // Escape the response to ensure consistent formatting + escapedResponse := strings.Trim(strconv.Quote(response), "\"") + + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) + + if stream { + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1) + } else { + proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(c.ResponseTemplate, escapedResponse)), -1) + } +} + +// performSimilaritySearch determines the appropriate similarity search method to use. +func performSimilaritySearch(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, queryString string, stream bool) error { + activeVectorProvider := c.GetVectorProvider() + if activeVectorProvider == nil { + return logAndReturnError(log, "[performSimilaritySearch] no vector provider configured for similarity search") + } + + // Check if the active vector provider implements the StringQuerier interface. + if _, ok := activeVectorProvider.(vector.StringQuerier); ok { + log.Debugf("[%s] [performSimilaritySearch] active vector provider implements StringQuerier interface, performing string query", PLUGIN_NAME) + return performStringQuery(key, queryString, ctx, c, log, stream) + } + + // Check if the active vector provider implements the EmbeddingQuerier interface. + if _, ok := activeVectorProvider.(vector.EmbeddingQuerier); ok { + log.Debugf("[%s] [performSimilaritySearch] active vector provider implements EmbeddingQuerier interface, performing embedding query", PLUGIN_NAME) + return performEmbeddingQuery(key, ctx, c, log, stream) + } + + return logAndReturnError(log, "[performSimilaritySearch] no suitable querier or embedding provider available for similarity search") +} + +// performStringQuery executes the string-based similarity search. +func performStringQuery(key string, queryString string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error { + stringQuerier, ok := c.GetVectorProvider().(vector.StringQuerier) + if !ok { + return logAndReturnError(log, "[performStringQuery] active vector provider does not implement StringQuerier interface") + } + + return stringQuerier.QueryString(queryString, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { + handleQueryResults(key, results, ctx, log, stream, c, err) + }) +} + +// performEmbeddingQuery executes the embedding-based similarity search. +func performEmbeddingQuery(key string, ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log, stream bool) error { + embeddingQuerier, ok := c.GetVectorProvider().(vector.EmbeddingQuerier) + if !ok { + return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] active vector provider does not implement EmbeddingQuerier interface")) + } + + activeEmbeddingProvider := c.GetEmbeddingProvider() + if activeEmbeddingProvider == nil { + return logAndReturnError(log, fmt.Sprintf("[performEmbeddingQuery] no embedding provider configured for similarity search")) + } + + return activeEmbeddingProvider.GetEmbedding(key, ctx, log, func(textEmbedding []float64, err error) { + log.Debugf("[%s] [performEmbeddingQuery] GetEmbedding success, length of embedding: %d, error: %v", PLUGIN_NAME, len(textEmbedding), err) + if err != nil { + handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error getting embedding for key: %s", PLUGIN_NAME, key), log) + return + } + ctx.SetContext(CACHE_KEY_EMBEDDING_KEY, textEmbedding) + + err = embeddingQuerier.QueryEmbedding(textEmbedding, ctx, log, func(results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error) { + handleQueryResults(key, results, ctx, log, stream, c, err) + }) + if err != nil { + handleInternalError(err, fmt.Sprintf("[%s] [performEmbeddingQuery] error querying vector database for key: %s", PLUGIN_NAME, key), log) + } + }) +} + +// handleQueryResults processes the results of similarity search and determines next actions. +func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.HttpContext, log wrapper.Log, stream bool, c config.PluginConfig, err error) { + if err != nil { + handleInternalError(err, fmt.Sprintf("[%s] [handleQueryResults] error querying vector database for key: %s", PLUGIN_NAME, key), log) + return + } + + if len(results) == 0 { + log.Warnf("[%s] [handleQueryResults] no similar keys found for key: %s", PLUGIN_NAME, key) + proxywasm.ResumeHttpRequest() + return + } + + mostSimilarData := results[0] + log.Debugf("[%s] [handleQueryResults] for key: %s, the most similar key found: %s with score: %f", PLUGIN_NAME, key, mostSimilarData.Text, mostSimilarData.Score) + simThreshold := c.GetVectorProviderConfig().Threshold + simThresholdRelation := c.GetVectorProviderConfig().ThresholdRelation + if compare(simThresholdRelation, mostSimilarData.Score, simThreshold) { + log.Infof("[%s] key accepted: %s with score: %f", PLUGIN_NAME, mostSimilarData.Text, mostSimilarData.Score) + if mostSimilarData.Answer != "" { + // direct return the answer if available + cacheResponse(ctx, c, key, mostSimilarData.Answer, log) + processCacheHit(key, mostSimilarData.Answer, stream, ctx, c, log) + } else { + if c.GetCacheProvider() != nil { + CheckCacheForKey(mostSimilarData.Text, ctx, c, log, stream, false) + } else { + // Otherwise, do not check the cache, directly return + log.Infof("[%s] cache hit for key: %s, but no corresponding answer found in the vector database", PLUGIN_NAME, mostSimilarData.Text) + proxywasm.ResumeHttpRequest() + } + } + } else { + log.Infof("[%s] score not meet the threshold %f: %s with score %f", PLUGIN_NAME, simThreshold, mostSimilarData.Text, mostSimilarData.Score) + proxywasm.ResumeHttpRequest() + } +} + +// logAndReturnError logs an error and returns it. +func logAndReturnError(log wrapper.Log, message string) error { + message = fmt.Sprintf("[%s] %s", PLUGIN_NAME, message) + log.Errorf(message) + return errors.New(message) +} + +// handleInternalError logs an error and resumes the HTTP request. +func handleInternalError(err error, message string, log wrapper.Log) { + if err != nil { + log.Errorf("[%s] [handleInternalError] %s: %v", PLUGIN_NAME, message, err) + } else { + log.Errorf("[%s] [handleInternalError] %s", PLUGIN_NAME, message) + } + // proxywasm.SendHttpResponse(500, [][2]string{{"content-type", "text/plain"}}, []byte("Internal Server Error"), -1) + proxywasm.ResumeHttpRequest() +} + +// Caches the response value +func cacheResponse(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) { + if strings.TrimSpace(value) == "" { + log.Warnf("[%s] [cacheResponse] cached value for key %s is empty", PLUGIN_NAME, key) + return + } + + activeCacheProvider := c.GetCacheProvider() + if activeCacheProvider != nil { + queryKey := activeCacheProvider.GetCacheKeyPrefix() + key + _ = activeCacheProvider.Set(queryKey, value, nil) + log.Debugf("[%s] [cacheResponse] cache set success, key: %s, length of value: %d", PLUGIN_NAME, queryKey, len(value)) + } +} + +// Handles embedding upload if available +func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, c config.PluginConfig, key string, value string, log wrapper.Log) { + embedding := ctx.GetContext(CACHE_KEY_EMBEDDING_KEY) + if embedding == nil { + return + } + + emb, ok := embedding.([]float64) + if !ok { + log.Errorf("[%s] [uploadEmbeddingAndAnswer] embedding is not of expected type []float64", PLUGIN_NAME) + return + } + + activeVectorProvider := c.GetVectorProvider() + if activeVectorProvider == nil { + log.Debugf("[%s] [uploadEmbeddingAndAnswer] no vector provider configured for uploading embedding", PLUGIN_NAME) + return + } + + // Attempt to upload answer embedding first + if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerAndEmbeddingUploader); ok { + log.Infof("[%s] uploading answer embedding for key: %s", PLUGIN_NAME, key) + err := ansEmbUploader.UploadAnswerAndEmbedding(key, emb, value, ctx, log, nil) + if err != nil { + log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload answer embedding for key: %s, error: %v", PLUGIN_NAME, key, err) + } else { + return // If successful, return early + } + } + + // If answer embedding upload fails, attempt normal embedding upload + if embUploader, ok := activeVectorProvider.(vector.EmbeddingUploader); ok { + log.Infof("[%s] uploading embedding for key: %s", PLUGIN_NAME, key) + err := embUploader.UploadEmbedding(key, emb, ctx, log, nil) + if err != nil { + log.Warnf("[%s] [uploadEmbeddingAndAnswer] failed to upload embedding for key: %s, error: %v", PLUGIN_NAME, key, err) + } + } +} + +// 主要用于相似度/距离/点积判断 +// 余弦相似度度量的是两个向量在方向上的相似程度。相似度越高,两个向量越接近。 +// 距离度量的是两个向量在空间上的远近程度。距离越小,两个向量越接近。 +// compare 函数根据操作符进行判断并返回结果 +func compare(operator string, value1 float64, value2 float64) bool { + switch operator { + case "gt": + return value1 > value2 + case "gte": + return value1 >= value2 + case "lt": + return value1 < value2 + case "lte": + return value1 <= value2 + default: + return false + } +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go new file mode 100644 index 0000000000..35c897cce5 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go @@ -0,0 +1,187 @@ +package embedding + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +const ( + DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com" + DASHSCOPE_PORT = 443 + DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v2" + DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding" +) + +type dashScopeProviderInitializer struct { +} + +func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error { + if config.apiKey == "" { + return errors.New("[DashScope] apiKey is required") + } + return nil +} + +func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) { + if c.servicePort == 0 { + c.servicePort = DASHSCOPE_PORT + } + if c.serviceHost == "" { + c.serviceHost = DASHSCOPE_DOMAIN + } + return &DSProvider{ + config: c, + client: wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: c.serviceName, + Host: c.serviceHost, + Port: int64(c.servicePort), + }), + }, nil +} + +func (d *DSProvider) GetProviderType() string { + return PROVIDER_TYPE_DASHSCOPE +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type Input struct { + Texts []string `json:"texts"` +} + +type Params struct { + TextType string `json:"text_type"` +} + +type Response struct { + RequestID string `json:"request_id"` + Output Output `json:"output"` + Usage Usage `json:"usage"` +} + +type Output struct { + Embeddings []Embedding `json:"embeddings"` +} + +type Usage struct { + TotalTokens int `json:"total_tokens"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input Input `json:"input"` + Parameters Params `json:"parameters"` +} + +type Document struct { + Vector []float64 `json:"vector"` + Fields map[string]string `json:"fields"` +} + +type DSProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) { + + model := d.config.model + + if model == "" { + model = DASHSCOPE_DEFAULT_MODEL_NAME + } + data := EmbeddingRequest{ + Model: model, + Input: Input{ + Texts: texts, + }, + Parameters: Params{ + TextType: "query", + }, + } + + requestBody, err := json.Marshal(data) + if err != nil { + log.Errorf("failed to marshal request data: %v", err) + return "", nil, nil, err + } + + if d.config.apiKey == "" { + err := errors.New("dashScopeKey is empty") + log.Errorf("failed to construct headers: %v", err) + return "", nil, nil, err + } + + headers := [][2]string{ + {"Authorization", "Bearer " + d.config.apiKey}, + {"Content-Type", "application/json"}, + } + + return DASHSCOPE_ENDPOINT, headers, requestBody, err +} + +type Result struct { + ID string `json:"id"` + Vector []float64 `json:"vector,omitempty"` + Fields map[string]interface{} `json:"fields"` + Score float64 `json:"score"` +} + +func (d *DSProvider) parseTextEmbedding(responseBody []byte) (*Response, error) { + var resp Response + err := json.Unmarshal(responseBody, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +func (d *DSProvider) GetEmbedding( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(emb []float64, err error)) error { + embUrl, embHeaders, embRequestBody, err := d.constructParameters([]string{queryString}, log) + if err != nil { + log.Errorf("failed to construct parameters: %v", err) + return err + } + + var resp *Response + err = d.client.Post(embUrl, embHeaders, embRequestBody, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + + if statusCode != http.StatusOK { + err = errors.New("failed to get embedding due to status code: " + strconv.Itoa(statusCode)) + callback(nil, err) + return + } + + log.Debugf("get embedding response: %d, %s", statusCode, responseBody) + + resp, err = d.parseTextEmbedding(responseBody) + if err != nil { + err = fmt.Errorf("failed to parse response: %v", err) + callback(nil, err) + return + } + + if len(resp.Output.Embeddings) == 0 { + err = errors.New("no embedding found in response") + callback(nil, err) + return + } + + callback(resp.Output.Embeddings[0].Embedding, nil) + + }, d.config.timeout) + return err +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go new file mode 100644 index 0000000000..909edf129c --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go @@ -0,0 +1,101 @@ +package embedding + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + PROVIDER_TYPE_DASHSCOPE = "dashscope" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{}, + } +) + +type ProviderConfig struct { + // @Title zh-CN 文本特征提取服务提供者类型 + // @Description zh-CN 文本特征提取服务提供者类型,例如 DashScope + typ string + // @Title zh-CN DashScope 文本特征提取服务名称 + // @Description zh-CN 文本特征提取服务名称 + serviceName string + // @Title zh-CN 文本特征提取服务域名 + // @Description zh-CN 文本特征提取服务域名 + serviceHost string + // @Title zh-CN 文本特征提取服务端口 + // @Description zh-CN 文本特征提取服务端口 + servicePort int64 + // @Title zh-CN 文本特征提取服务 API Key + // @Description zh-CN 文本特征提取服务 API Key + apiKey string + // @Title zh-CN 文本特征提取服务超时时间 + // @Description zh-CN 文本特征提取服务超时时间 + timeout uint32 + // @Title zh-CN 文本特征提取服务使用的模型 + // @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1" + model string +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("type").String() + c.serviceName = json.Get("serviceName").String() + c.serviceHost = json.Get("serviceHost").String() + c.servicePort = json.Get("servicePort").Int() + c.apiKey = json.Get("apiKey").String() + c.timeout = uint32(json.Get("timeout").Int()) + c.model = json.Get("model").String() + if c.timeout == 0 { + c.timeout = 10000 + } +} + +func (c *ProviderConfig) Validate() error { + if c.serviceName == "" { + return errors.New("embedding service name is required") + } + if c.apiKey == "" { + return errors.New("embedding service API key is required") + } + if c.typ == "" { + return errors.New("embedding service type is required") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown embedding service provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err + } + return nil +} + +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} + +type Provider interface { + GetProviderType() string + GetEmbedding( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(emb []float64, err error)) error +} diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go b/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go new file mode 100644 index 0000000000..b26d9cea8d --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/embedding/weaviate.go @@ -0,0 +1,27 @@ +package embedding + +// import ( +// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +// ) + +// const ( +// weaviateURL = "172.17.0.1:8081" +// ) + +// type weaviateProviderInitializer struct { +// } + +// func (d *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error { +// return nil +// } + +// func (d *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { +// return &DSProvider{ +// config: config, +// client: wrapper.NewClusterClient(wrapper.DnsCluster{ +// ServiceName: config.ServiceName, +// Port: dashScopePort, +// Domain: dashScopeDomain, +// }), +// }, nil +// } diff --git a/plugins/wasm-go/extensions/ai-cache/go.mod b/plugins/wasm-go/extensions/ai-cache/go.mod index c9630cfb8a..e4aae265e0 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.mod +++ b/plugins/wasm-go/extensions/ai-cache/go.mod @@ -7,17 +7,18 @@ go 1.19 replace github.com/alibaba/higress/plugins/wasm-go => ../.. require ( - github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441 + github.com/alibaba/higress/plugins/wasm-go v1.4.2 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 - github.com/tidwall/sjson v1.2.5 +// github.com/weaviate/weaviate-go-client/v4 v4.15.1 ) require ( - github.com/google/uuid v1.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect github.com/magefile/mage v1.14.0 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect ) diff --git a/plugins/wasm-go/extensions/ai-cache/go.sum b/plugins/wasm-go/extensions/ai-cache/go.sum index 8246b4de5e..7ada0c8b70 100644 --- a/plugins/wasm-go/extensions/ai-cache/go.sum +++ b/plugins/wasm-go/extensions/ai-cache/go.sum @@ -1,24 +1,21 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE= github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/plugins/wasm-go/extensions/ai-cache/main.go b/plugins/wasm-go/extensions/ai-cache/main.go index 7886d5698f..1aca29f0ec 100644 --- a/plugins/wasm-go/extensions/ai-cache/main.go +++ b/plugins/wasm-go/extensions/ai-cache/main.go @@ -1,33 +1,33 @@ -// File generated by hgctl. Modify as required. -// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6 - +// 这个文件中主要将OnHttpRequestHeaders、OnHttpRequestBody、OnHttpResponseHeaders、OnHttpResponseBody这四个函数实现 +// 其中的缓存思路调用cache.go中的逻辑,然后cache.go中的逻辑会调用textEmbeddingProvider和vectorStoreProvider中的逻辑(实例) package main import ( - "errors" - "fmt" "strings" + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/tidwall/gjson" - "github.com/tidwall/resp" ) const ( - CacheKeyContextKey = "cacheKey" - CacheContentContextKey = "cacheContent" - PartialMessageContextKey = "partialMessage" - ToolCallsContextKey = "toolCalls" - StreamContextKey = "stream" - DefaultCacheKeyPrefix = "higress-ai-cache:" - SkipCacheHeader = "x-higress-skip-ai-cache" + PLUGIN_NAME = "ai-cache" + CACHE_KEY_CONTEXT_KEY = "cacheKey" + CACHE_KEY_EMBEDDING_KEY = "cacheKeyEmbedding" + CACHE_CONTENT_CONTEXT_KEY = "cacheContent" + PARTIAL_MESSAGE_CONTEXT_KEY = "partialMessage" + TOOL_CALLS_CONTEXT_KEY = "toolCalls" + STREAM_CONTEXT_KEY = "stream" + SKIP_CACHE_HEADER = "x-higress-skip-ai-cache" + ERROR_PARTIAL_MESSAGE_KEY = "errorPartialMessage" ) func main() { + // CreateClient() wrapper.SetCtx( - "ai-cache", + PLUGIN_NAME, wrapper.ParseConfigBy(parseConfig), wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders), wrapper.ProcessRequestBodyBy(onHttpRequestBody), @@ -36,146 +36,26 @@ func main() { ) } -// @Name ai-cache -// @Category protocol -// @Phase AUTHN -// @Priority 10 -// @Title zh-CN AI Cache -// @Description zh-CN 大模型结果缓存 -// @IconUrl -// @Version 0.1.0 -// -// @Contact.name johnlanni -// @Contact.url -// @Contact.email -// -// @Example -// redis: -// serviceName: my-redis.dns -// timeout: 2000 -// cacheKeyFrom: -// requestBody: "messages.@reverse.0.content" -// cacheValueFrom: -// responseBody: "choices.0.message.content" -// cacheStreamValueFrom: -// responseBody: "choices.0.delta.content" -// returnResponseTemplate: | -// {"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} -// returnStreamResponseTemplate: | -// data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}} -// -// data:[DONE] -// -// @End - -type RedisInfo struct { - // @Title zh-CN redis 服务名称 - // @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local - ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"` - // @Title zh-CN redis 服务端口 - // @Description zh-CN 默认值为6379 - ServicePort int `required:"false" yaml:"servicePort" json:"servicePort"` - // @Title zh-CN 用户名 - // @Description zh-CN 登陆 redis 的用户名,非必填 - Username string `required:"false" yaml:"username" json:"username"` - // @Title zh-CN 密码 - // @Description zh-CN 登陆 redis 的密码,非必填,可以只填密码 - Password string `required:"false" yaml:"password" json:"password"` - // @Title zh-CN 请求超时 - // @Description zh-CN 请求 redis 的超时时间,单位为毫秒。默认值是1000,即1秒 - Timeout int `required:"false" yaml:"timeout" json:"timeout"` -} - -type KVExtractor struct { - // @Title zh-CN 从请求 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - RequestBody string `required:"false" yaml:"requestBody" json:"requestBody"` - // @Title zh-CN 从响应 Body 中基于 [GJSON PATH](https://github.com/tidwall/gjson/blob/master/SYNTAX.md) 语法提取字符串 - ResponseBody string `required:"false" yaml:"responseBody" json:"responseBody"` -} - -type PluginConfig struct { - // @Title zh-CN Redis 地址信息 - // @Description zh-CN 用于存储缓存结果的 Redis 地址 - RedisInfo RedisInfo `required:"true" yaml:"redis" json:"redis"` - // @Title zh-CN 缓存 key 的来源 - // @Description zh-CN 往 redis 里存时,使用的 key 的提取方式 - CacheKeyFrom KVExtractor `required:"true" yaml:"cacheKeyFrom" json:"cacheKeyFrom"` - // @Title zh-CN 缓存 value 的来源 - // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 - CacheValueFrom KVExtractor `required:"true" yaml:"cacheValueFrom" json:"cacheValueFrom"` - // @Title zh-CN 流式响应下,缓存 value 的来源 - // @Description zh-CN 往 redis 里存时,使用的 value 的提取方式 - CacheStreamValueFrom KVExtractor `required:"true" yaml:"cacheStreamValueFrom" json:"cacheStreamValueFrom"` - // @Title zh-CN 返回 HTTP 响应的模版 - // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"` - // @Title zh-CN 返回流式 HTTP 响应的模版 - // @Description zh-CN 用 %s 标记需要被 cache value 替换的部分 - ReturnStreamResponseTemplate string `required:"true" yaml:"returnStreamResponseTemplate" json:"returnStreamResponseTemplate"` - // @Title zh-CN 缓存的过期时间 - // @Description zh-CN 单位是秒,默认值为0,即永不过期 - CacheTTL int `required:"false" yaml:"cacheTTL" json:"cacheTTL"` - // @Title zh-CN Redis缓存Key的前缀 - // @Description zh-CN 默认值是"higress-ai-cache:" - CacheKeyPrefix string `required:"false" yaml:"cacheKeyPrefix" json:"cacheKeyPrefix"` - redisClient wrapper.RedisClient `yaml:"-" json:"-"` -} - -func parseConfig(json gjson.Result, c *PluginConfig, log wrapper.Log) error { - c.RedisInfo.ServiceName = json.Get("redis.serviceName").String() - if c.RedisInfo.ServiceName == "" { - return errors.New("redis service name must not by empty") - } - c.RedisInfo.ServicePort = int(json.Get("redis.servicePort").Int()) - if c.RedisInfo.ServicePort == 0 { - if strings.HasSuffix(c.RedisInfo.ServiceName, ".static") { - // use default logic port which is 80 for static service - c.RedisInfo.ServicePort = 80 - } else { - c.RedisInfo.ServicePort = 6379 - } - } - c.RedisInfo.Username = json.Get("redis.username").String() - c.RedisInfo.Password = json.Get("redis.password").String() - c.RedisInfo.Timeout = int(json.Get("redis.timeout").Int()) - if c.RedisInfo.Timeout == 0 { - c.RedisInfo.Timeout = 1000 - } - c.CacheKeyFrom.RequestBody = json.Get("cacheKeyFrom.requestBody").String() - if c.CacheKeyFrom.RequestBody == "" { - c.CacheKeyFrom.RequestBody = "messages.@reverse.0.content" - } - c.CacheValueFrom.ResponseBody = json.Get("cacheValueFrom.responseBody").String() - if c.CacheValueFrom.ResponseBody == "" { - c.CacheValueFrom.ResponseBody = "choices.0.message.content" - } - c.CacheStreamValueFrom.ResponseBody = json.Get("cacheStreamValueFrom.responseBody").String() - if c.CacheStreamValueFrom.ResponseBody == "" { - c.CacheStreamValueFrom.ResponseBody = "choices.0.delta.content" - } - c.ReturnResponseTemplate = json.Get("returnResponseTemplate").String() - if c.ReturnResponseTemplate == "" { - c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` - } - c.ReturnStreamResponseTemplate = json.Get("returnStreamResponseTemplate").String() - if c.ReturnStreamResponseTemplate == "" { - c.ReturnStreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" +func parseConfig(json gjson.Result, c *config.PluginConfig, log wrapper.Log) error { + // config.EmbeddingProviderConfig.FromJson(json.Get("embeddingProvider")) + // config.VectorDatabaseProviderConfig.FromJson(json.Get("vectorBaseProvider")) + // config.RedisConfig.FromJson(json.Get("redis")) + c.FromJson(json, log) + if err := c.Validate(); err != nil { + return err } - c.CacheKeyPrefix = json.Get("cacheKeyPrefix").String() - if c.CacheKeyPrefix == "" { - c.CacheKeyPrefix = DefaultCacheKeyPrefix + // Note that initializing the client during the parseConfig phase may cause errors, such as Redis not being usable in Docker Compose. + if err := c.Complete(log); err != nil { + log.Errorf("complete config failed: %v", err) + return err } - c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ - FQDN: c.RedisInfo.ServiceName, - Port: int64(c.RedisInfo.ServicePort), - }) - return c.redisClient.Init(c.RedisInfo.Username, c.RedisInfo.Password, int64(c.RedisInfo.Timeout)) + return nil } -func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { - skipCache, _ := proxywasm.GetHttpRequestHeader(SkipCacheHeader) +func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { + skipCache, _ := proxywasm.GetHttpRequestHeader(SKIP_CACHE_HEADER) if skipCache == "on" { - ctx.SetContext(SkipCacheHeader, struct{}{}) + ctx.SetContext(SKIP_CACHE_HEADER, struct{}{}) ctx.DontReadRequestBody() return types.ActionContinue } @@ -185,199 +65,123 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap return types.ActionContinue } if !strings.Contains(contentType, "application/json") { - log.Warnf("content is not json, can't process:%s", contentType) + log.Warnf("content is not json, can't process: %s", contentType) ctx.DontReadRequestBody() return types.ActionContinue } - proxywasm.RemoveHttpRequestHeader("Accept-Encoding") + _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") // The request has a body and requires delaying the header transmission until a cache miss occurs, // at which point the header should be sent. return types.HeaderStopIteration } -func TrimQuote(source string) string { - return strings.Trim(source, `"`) -} +func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []byte, log wrapper.Log) types.Action { -func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action { bodyJson := gjson.ParseBytes(body) // TODO: It may be necessary to support stream mode determination for different LLM providers. stream := false if bodyJson.Get("stream").Bool() { stream = true - ctx.SetContext(StreamContextKey, struct{}{}) - } else if ctx.GetContext(StreamContextKey) != nil { - stream = true + ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) + } + + var key string + if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_LAST_QUESTION { + log.Debugf("[onHttpRequestBody] cache key strategy is last question, cache key from: %s", c.CacheKeyFrom) + key = bodyJson.Get(c.CacheKeyFrom).String() + } else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_ALL_QUESTIONS { + log.Debugf("[onHttpRequestBody] cache key strategy is all questions, cache key from: messages") + messages := bodyJson.Get("messages").Array() + var userMessages []string + for _, msg := range messages { + if msg.Get("role").String() == "user" { + userMessages = append(userMessages, msg.Get("content").String()) + } + } + key = strings.Join(userMessages, "\n") + } else if c.CacheKeyStrategy == config.CACHE_KEY_STRATEGY_DISABLED { + log.Info("[onHttpRequestBody] cache key strategy is disabled") + ctx.DontReadRequestBody() + return types.ActionContinue + } else { + log.Warnf("[onHttpRequestBody] unknown cache key strategy: %s", c.CacheKeyStrategy) + ctx.DontReadRequestBody() + return types.ActionContinue } - key := TrimQuote(bodyJson.Get(config.CacheKeyFrom.RequestBody).Raw) + + ctx.SetContext(CACHE_KEY_CONTEXT_KEY, key) + log.Debugf("[onHttpRequestBody] key: %s", key) if key == "" { - log.Debug("parse key from request body failed") + log.Debug("[onHttpRequestBody] parse key from request body failed") + ctx.DontReadResponseBody() return types.ActionContinue } - ctx.SetContext(CacheKeyContextKey, key) - err := config.redisClient.Get(config.CacheKeyPrefix+key, func(response resp.Value) { - if err := response.Error(); err != nil { - log.Errorf("redis get key:%s failed, err:%v", key, err) - proxywasm.ResumeHttpRequest() - return - } - if response.IsNull() { - log.Debugf("cache miss, key:%s", key) - proxywasm.ResumeHttpRequest() - return - } - log.Debugf("cache hit, key:%s", key) - ctx.SetContext(CacheKeyContextKey, nil) - if !stream { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, response.String())), -1) - } else { - proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnStreamResponseTemplate, response.String())), -1) - } - }) - if err != nil { - log.Error("redis access failed") + + if err := CheckCacheForKey(key, ctx, c, log, stream, true); err != nil { + log.Errorf("[onHttpRequestBody] check cache for key: %s failed, error: %v", key, err) return types.ActionContinue } - return types.ActionPause -} -func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { - subMessages := strings.Split(sseMessage, "\n") - var message string - for _, msg := range subMessages { - if strings.HasPrefix(msg, "data:") { - message = msg - break - } - } - if len(message) < 6 { - log.Errorf("invalid message:%s", message) - return "" - } - // skip the prefix "data:" - bodyJson := message[5:] - if gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Exists() { - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI == nil { - content := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) - ctx.SetContext(CacheContentContextKey, content) - return content - } - append := TrimQuote(gjson.Get(bodyJson, config.CacheStreamValueFrom.ResponseBody).Raw) - content := tempContentI.(string) + append - ctx.SetContext(CacheContentContextKey, content) - return content - } else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { - // TODO: compatible with other providers - ctx.SetContext(ToolCallsContextKey, struct{}{}) - return "" - } - log.Debugf("unknown message:%s", bodyJson) - return "" + return types.ActionPause } -func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { - skipCache := ctx.GetContext(SkipCacheHeader) +func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { + skipCache := ctx.GetContext(SKIP_CACHE_HEADER) if skipCache != nil { ctx.DontReadResponseBody() return types.ActionContinue } contentType, _ := proxywasm.GetHttpResponseHeader("content-type") if strings.Contains(contentType, "text/event-stream") { - ctx.SetContext(StreamContextKey, struct{}{}) + ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) } + + if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil { + ctx.DontReadResponseBody() + return types.ActionContinue + } + return types.ActionContinue } -func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { - if ctx.GetContext(ToolCallsContextKey) != nil { - // we should not cache tool call result +func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte { + log.Debugf("[onHttpResponseBody] is last chunk: %v", isLastChunk) + log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk)) + + if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil { return chunk } - keyI := ctx.GetContext(CacheKeyContextKey) - if keyI == nil { + + key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY) + if key == nil { + log.Debug("[onHttpResponseBody] key is nil, skip cache") return chunk } + if !isLastChunk { - stream := ctx.GetContext(StreamContextKey) - if stream == nil { - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI == nil { - ctx.SetContext(CacheContentContextKey, chunk) - return chunk - } - tempContent := tempContentI.([]byte) - tempContent = append(tempContent, chunk...) - ctx.SetContext(CacheContentContextKey, tempContent) - } else { - var partialMessage []byte - partialMessageI := ctx.GetContext(PartialMessageContextKey) - if partialMessageI != nil { - partialMessage = append(partialMessageI.([]byte), chunk...) - } else { - partialMessage = chunk - } - messages := strings.Split(string(partialMessage), "\n\n") - for i, msg := range messages { - if i < len(messages)-1 { - // process complete message - processSSEMessage(ctx, config, msg, log) - } - } - if !strings.HasSuffix(string(partialMessage), "\n\n") { - ctx.SetContext(PartialMessageContextKey, []byte(messages[len(messages)-1])) - } else { - ctx.SetContext(PartialMessageContextKey, nil) - } + if err := handleNonLastChunk(ctx, c, chunk, log); err != nil { + log.Errorf("[onHttpResponseBody] handle non last chunk failed, error: %v", err) + // Set an empty struct in the context to indicate an error in processing the partial message + ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, struct{}{}) } return chunk } - // last chunk - key := keyI.(string) - stream := ctx.GetContext(StreamContextKey) + + stream := ctx.GetContext(STREAM_CONTEXT_KEY) var value string + var err error if stream == nil { - var body []byte - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI != nil { - body = append(tempContentI.([]byte), chunk...) - } else { - body = chunk - } - bodyJson := gjson.ParseBytes(body) - - value = TrimQuote(bodyJson.Get(config.CacheValueFrom.ResponseBody).Raw) - if value == "" { - log.Warnf("parse value from response body failded, body:%s", body) - return chunk - } + value, err = processNonStreamLastChunk(ctx, c, chunk, log) } else { - if len(chunk) > 0 { - var lastMessage []byte - partialMessageI := ctx.GetContext(PartialMessageContextKey) - if partialMessageI != nil { - lastMessage = append(partialMessageI.([]byte), chunk...) - } else { - lastMessage = chunk - } - if !strings.HasSuffix(string(lastMessage), "\n\n") { - log.Warnf("invalid lastMessage:%s", lastMessage) - return chunk - } - // remove the last \n\n - lastMessage = lastMessage[:len(lastMessage)-2] - value = processSSEMessage(ctx, config, string(lastMessage), log) - } else { - tempContentI := ctx.GetContext(CacheContentContextKey) - if tempContentI == nil { - return chunk - } - value = tempContentI.(string) - } + value, err = processStreamLastChunk(ctx, c, chunk, log) } - config.redisClient.Set(config.CacheKeyPrefix+key, value, nil) - if config.CacheTTL != 0 { - config.redisClient.Expire(config.CacheKeyPrefix+key, config.CacheTTL, nil) + + if err != nil { + log.Errorf("[onHttpResponseBody] process last chunk failed, error: %v", err) + return chunk } + + cacheResponse(ctx, c, key.(string), value, log) + uploadEmbeddingAndAnswer(ctx, c, key.(string), value, log) return chunk } diff --git a/plugins/wasm-go/extensions/ai-cache/util.go b/plugins/wasm-go/extensions/ai-cache/util.go new file mode 100644 index 0000000000..983dfbb25a --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/util.go @@ -0,0 +1,155 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config" + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +func handleNonLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { + stream := ctx.GetContext(STREAM_CONTEXT_KEY) + err := error(nil) + if stream == nil { + err = handleNonStreamChunk(ctx, c, chunk, log) + } else { + err = handleStreamChunk(ctx, c, chunk, log) + } + return err +} + +func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, chunk) + return nil + } + tempContent := tempContentI.([]byte) + tempContent = append(tempContent, chunk...) + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContent) + return nil +} + +func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { + var partialMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + log.Debugf("[handleStreamChunk] cache content: %v", ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)) + if partialMessageI != nil { + partialMessage = append(partialMessageI.([]byte), chunk...) + } else { + partialMessage = chunk + } + messages := strings.Split(string(partialMessage), "\n\n") + for i, msg := range messages { + if i < len(messages)-1 { + _, err := processSSEMessage(ctx, c, msg, log) + if err != nil { + return fmt.Errorf("[handleStreamChunk] processSSEMessage failed, error: %v", err) + } + } + } + if !strings.HasSuffix(string(partialMessage), "\n\n") { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, []byte(messages[len(messages)-1])) + } else { + ctx.SetContext(PARTIAL_MESSAGE_CONTEXT_KEY, nil) + } + return nil +} + +func processNonStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { + var body []byte + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI != nil { + body = append(tempContentI.([]byte), chunk...) + } else { + body = chunk + } + bodyJson := gjson.ParseBytes(body) + value := bodyJson.Get(c.CacheValueFrom).String() + if strings.TrimSpace(value) == "" { + return "", fmt.Errorf("[processNonStreamLastChunk] parse value from response body failed, body:%s", body) + } + return value, nil +} + +func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) { + if len(chunk) > 0 { + var lastMessage []byte + partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) + if partialMessageI != nil { + lastMessage = append(partialMessageI.([]byte), chunk...) + } else { + lastMessage = chunk + } + if !strings.HasSuffix(string(lastMessage), "\n\n") { + return "", fmt.Errorf("[processStreamLastChunk] invalid lastMessage:%s", lastMessage) + } + lastMessage = lastMessage[:len(lastMessage)-2] + value, err := processSSEMessage(ctx, c, string(lastMessage), log) + if err != nil { + return "", fmt.Errorf("[processStreamLastChunk] processSSEMessage failed, error: %v", err) + } + return value, nil + } + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + if tempContentI == nil { + return "", nil + } + return tempContentI.(string), nil +} + +func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) { + subMessages := strings.Split(sseMessage, "\n") + var message string + for _, msg := range subMessages { + if strings.HasPrefix(msg, "data:") { + message = msg + break + } + } + if len(message) < 6 { + return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message) + } + + // skip the prefix "data:" + bodyJson := message[5:] + + if strings.TrimSpace(bodyJson) == "[DONE]" { + return "", nil + } + + // Extract values from JSON fields + responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom) + toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom) + + if toolCalls.Exists() { + // TODO: Temporarily store the tool_calls value in the context for processing + ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String()) + } + + // Check if the ResponseBody field exists + if !responseBody.Exists() { + if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil { + log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message) + return "", nil + } + return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message) + } else { + tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) + + // If there is no content in the cache, initialize and set the content + if tempContentI == nil { + content := responseBody.String() + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + return content, nil + } + + // Update the content in the cache + appendMsg := responseBody.String() + content := tempContentI.(string) + appendMsg + ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) + return content, nil + } +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go new file mode 100644 index 0000000000..7bdb0a76d0 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/dashvector.go @@ -0,0 +1,256 @@ +package vector + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" +) + +type dashVectorProviderInitializer struct { +} + +func (d *dashVectorProviderInitializer) ValidateConfig(config ProviderConfig) error { + if len(config.apiKey) == 0 { + return errors.New("[DashVector] apiKey is required") + } + if len(config.collectionID) == 0 { + return errors.New("[DashVector] collectionID is required") + } + if len(config.serviceName) == 0 { + return errors.New("[DashVector] serviceName is required") + } + if len(config.serviceHost) == 0 { + return errors.New("[DashVector] serviceHost is required") + } + return nil +} + +func (d *dashVectorProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) { + return &DvProvider{ + config: config, + client: wrapper.NewClusterClient(wrapper.FQDNCluster{ + FQDN: config.serviceName, + Host: config.serviceHost, + Port: int64(config.servicePort), + }), + }, nil +} + +type DvProvider struct { + config ProviderConfig + client wrapper.HttpClient +} + +func (d *DvProvider) GetProviderType() string { + return PROVIDER_TYPE_DASH_VECTOR +} + +// type embeddingRequest struct { +// Model string `json:"model"` +// Input input `json:"input"` +// Parameters params `json:"parameters"` +// } + +// type params struct { +// TextType string `json:"text_type"` +// } + +// type input struct { +// Texts []string `json:"texts"` +// } + +// queryResponse 定义查询响应的结构 +type queryResponse struct { + Code int `json:"code"` + RequestID string `json:"request_id"` + Message string `json:"message"` + Output []result `json:"output"` +} + +// queryRequest 定义查询请求的结构 +type queryRequest struct { + Vector []float64 `json:"vector"` + TopK int `json:"topk"` + IncludeVector bool `json:"include_vector"` +} + +// result 定义查询结果的结构 +type result struct { + ID string `json:"id"` + Vector []float64 `json:"vector,omitempty"` // omitempty 使得如果 vector 是空,它将不会被序列化 + Fields map[string]interface{} `json:"fields"` + Score float64 `json:"score"` +} + +func (d *DvProvider) constructEmbeddingQueryParameters(vector []float64) (string, []byte, [][2]string, error) { + url := fmt.Sprintf("/v1/collections/%s/query", d.config.collectionID) + + requestData := queryRequest{ + Vector: vector, + TopK: d.config.topK, + IncludeVector: false, + } + + requestBody, err := json.Marshal(requestData) + if err != nil { + return "", nil, nil, err + } + + header := [][2]string{ + {"Content-Type", "application/json"}, + {"dashvector-auth-token", d.config.apiKey}, + } + + return url, requestBody, header, nil +} + +func (d *DvProvider) parseQueryResponse(responseBody []byte) (queryResponse, error) { + var queryResp queryResponse + err := json.Unmarshal(responseBody, &queryResp) + if err != nil { + return queryResponse{}, err + } + return queryResp, nil +} + +func (d *DvProvider) QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructEmbeddingQueryParameters(emb) + log.Debugf("url:%s, body:%s, headers:%v", url, string(body), headers) + if err != nil { + err = fmt.Errorf("failed to construct embedding query parameters: %v", err) + return err + } + + err = d.client.Post(url, headers, body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + err = nil + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to query embedding: %d", statusCode) + callback(nil, ctx, log, err) + return + } + log.Debugf("query embedding response: %d, %s", statusCode, responseBody) + results, err := d.ParseQueryResponse(responseBody, ctx, log) + if err != nil { + err = fmt.Errorf("failed to parse query response: %v", err) + } + callback(results, ctx, log, err) + }, + d.config.timeout) + if err != nil { + err = fmt.Errorf("failed to query embedding: %v", err) + } + return err +} + +func getStringValue(fields map[string]interface{}, key string) string { + if val, ok := fields[key]; ok { + return val.(string) + } + return "" +} + +func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) { + resp, err := d.parseQueryResponse(responseBody) + if err != nil { + return nil, err + } + + if len(resp.Output) == 0 { + return nil, errors.New("no query results found in response") + } + + results := make([]QueryResult, 0, len(resp.Output)) + + for _, output := range resp.Output { + result := QueryResult{ + Text: getStringValue(output.Fields, "query"), + Embedding: output.Vector, + Score: output.Score, + Answer: getStringValue(output.Fields, "answer"), + } + results = append(results, result) + } + + return results, nil +} + +type document struct { + Vector []float64 `json:"vector"` + Fields map[string]string `json:"fields"` +} + +type insertRequest struct { + Docs []document `json:"docs"` +} + +func (d *DvProvider) constructUploadParameters(emb []float64, queryString string, answer string) (string, []byte, [][2]string, error) { + url := "/v1/collections/" + d.config.collectionID + "/docs" + + doc := document{ + Vector: emb, + Fields: map[string]string{ + "query": queryString, + "answer": answer, + }, + } + + requestBody, err := json.Marshal(insertRequest{Docs: []document{doc}}) + if err != nil { + return "", nil, nil, err + } + + header := [][2]string{ + {"Content-Type", "application/json"}, + {"dashvector-auth-token", d.config.apiKey}, + } + + return url, requestBody, header, err +} + +func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, "") + if err != nil { + return err + } + err = d.client.Post( + url, + headers, + body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to upload embedding: %d", statusCode) + } + callback(ctx, log, err) + }, + d.config.timeout) + return err +} + +func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error { + url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer) + if err != nil { + return err + } + err = d.client.Post( + url, + headers, + body, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + log.Debugf("statusCode:%d, responseBody:%s", statusCode, string(responseBody)) + if statusCode != http.StatusOK { + err = fmt.Errorf("failed to upload embedding: %d", statusCode) + } + callback(ctx, log, err) + }, + d.config.timeout) + return err +} diff --git a/plugins/wasm-go/extensions/ai-cache/vector/provider.go b/plugins/wasm-go/extensions/ai-cache/vector/provider.go new file mode 100644 index 0000000000..a04123a166 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-cache/vector/provider.go @@ -0,0 +1,167 @@ +package vector + +import ( + "errors" + + "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" + "github.com/tidwall/gjson" +) + +const ( + PROVIDER_TYPE_DASH_VECTOR = "dashvector" + PROVIDER_TYPE_CHROMA = "chroma" +) + +type providerInitializer interface { + ValidateConfig(ProviderConfig) error + CreateProvider(ProviderConfig) (Provider, error) +} + +var ( + providerInitializers = map[string]providerInitializer{ + PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{}, + // PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{}, + } +) + +// QueryResult 定义通用的查询结果的结构体 +type QueryResult struct { + Text string // 相似的文本 + Embedding []float64 // 相似文本的向量 + Score float64 // 文本的向量相似度或距离等度量 + Answer string // 相似文本对应的LLM生成的回答 +} + +type Provider interface { + GetProviderType() string +} + +type EmbeddingQuerier interface { + QueryEmbedding( + emb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type EmbeddingUploader interface { + UploadEmbedding( + queryString string, + queryEmb []float64, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type AnswerAndEmbeddingUploader interface { + UploadAnswerAndEmbedding( + queryString string, + queryEmb []float64, + answer string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type StringQuerier interface { + QueryString( + queryString string, + ctx wrapper.HttpContext, + log wrapper.Log, + callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error +} + +type SimilarityThresholdProvider interface { + GetSimilarityThreshold() float64 +} + +type ProviderConfig struct { + // @Title zh-CN 向量存储服务提供者类型 + // @Description zh-CN 向量存储服务提供者类型,例如 dashvector、chroma + typ string + // @Title zh-CN 向量存储服务名称 + // @Description zh-CN 向量存储服务名称 + serviceName string + // @Title zh-CN 向量存储服务域名 + // @Description zh-CN 向量存储服务域名 + serviceHost string + // @Title zh-CN 向量存储服务端口 + // @Description zh-CN 向量存储服务端口 + servicePort int64 + // @Title zh-CN 向量存储服务 API Key + // @Description zh-CN 向量存储服务 API Key + apiKey string + // @Title zh-CN 返回TopK结果 + // @Description zh-CN 返回TopK结果,默认为 1 + topK int + // @Title zh-CN 请求超时 + // @Description zh-CN 请求向量存储服务的超时时间,单位为毫秒。默认值是10000,即10秒 + timeout uint32 + // @Title zh-CN DashVector 向量存储服务 Collection ID + // @Description zh-CN DashVector 向量存储服务 Collection ID + collectionID string + // @Title zh-CN 相似度度量阈值 + // @Description zh-CN 默认相似度度量阈值,默认为 1000。 + Threshold float64 + // @Title zh-CN 相似度度量比较方式 + // @Description zh-CN 相似度度量比较方式,默认为小于。 + // 相似度度量方式有 Cosine, DotProduct, Euclidean 等,前两者值越大相似度越高,后者值越小相似度越高。 + // 所以需要允许自定义比较方式,对于 Cosine 和 DotProduct 选择 gt,对于 Euclidean 则选择 lt。 + // 默认为 lt,所有条件包括 lt (less than,小于)、lte (less than or equal to,小等于)、gt (greater than,大于)、gte (greater than or equal to,大等于) + ThresholdRelation string +} + +func (c *ProviderConfig) GetProviderType() string { + return c.typ +} + +func (c *ProviderConfig) FromJson(json gjson.Result) { + c.typ = json.Get("type").String() + // DashVector + c.serviceName = json.Get("serviceName").String() + c.serviceHost = json.Get("serviceHost").String() + c.servicePort = int64(json.Get("servicePort").Int()) + if c.servicePort == 0 { + c.servicePort = 443 + } + c.apiKey = json.Get("apiKey").String() + c.collectionID = json.Get("collectionID").String() + c.topK = int(json.Get("topK").Int()) + if c.topK == 0 { + c.topK = 1 + } + c.timeout = uint32(json.Get("timeout").Int()) + if c.timeout == 0 { + c.timeout = 10000 + } + c.Threshold = json.Get("threshold").Float() + if c.Threshold == 0 { + c.Threshold = 1000 + } + c.ThresholdRelation = json.Get("thresholdRelation").String() + if c.ThresholdRelation == "" { + c.ThresholdRelation = "lt" + } +} + +func (c *ProviderConfig) Validate() error { + if c.typ == "" { + return errors.New("vector database service is required") + } + initializer, has := providerInitializers[c.typ] + if !has { + return errors.New("unknown vector database service provider type: " + c.typ) + } + if err := initializer.ValidateConfig(*c); err != nil { + return err + } + return nil +} + +func CreateProvider(pc ProviderConfig) (Provider, error) { + initializer, has := providerInitializers[pc.typ] + if !has { + return nil, errors.New("unknown provider type: " + pc.typ) + } + return initializer.CreateProvider(pc) +} diff --git a/plugins/wasm-go/extensions/ai-proxy/Makefile b/plugins/wasm-go/extensions/ai-proxy/Makefile new file mode 100644 index 0000000000..e5c7fa8de9 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/Makefile @@ -0,0 +1,4 @@ +.DEFAULT: +build: + tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' ./main.go + mv ai-proxy.wasm ../../../../docker-compose-test/ \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-proxy/go.mod b/plugins/wasm-go/extensions/ai-proxy/go.mod index 7fed801fab..a5457b90f8 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.mod +++ b/plugins/wasm-go/extensions/ai-proxy/go.mod @@ -10,7 +10,7 @@ require ( github.com/alibaba/higress/plugins/wasm-go v0.0.0 github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 ) require ( diff --git a/plugins/wasm-go/extensions/ai-proxy/go.sum b/plugins/wasm-go/extensions/ai-proxy/go.sum index e5b8b79175..b2d63b5f4b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/go.sum +++ b/plugins/wasm-go/extensions/ai-proxy/go.sum @@ -13,8 +13,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index 9e0fafe179..7b19d03fc2 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -82,7 +82,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf providerConfig := pluginConfig.GetProviderConfig() if apiName == "" && !providerConfig.IsOriginal() { log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) - _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) + // _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path) + log.Debugf("[onHttpRequestHeader] no send response") return types.ActionContinue } ctx.SetContext(ctxKeyApiName, apiName) diff --git a/plugins/wasm-go/extensions/ai-security-guard/main.go b/plugins/wasm-go/extensions/ai-security-guard/main.go index ca59f7f6a1..5b61589616 100644 --- a/plugins/wasm-go/extensions/ai-security-guard/main.go +++ b/plugins/wasm-go/extensions/ai-security-guard/main.go @@ -187,10 +187,6 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log log.Debugf("request checking is disabled") ctx.DontReadRequestBody() } - if !config.checkResponse { - log.Debugf("response checking is disabled") - ctx.DontReadResponseBody() - } return types.ActionContinue } @@ -199,7 +195,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body [] content := gjson.GetBytes(body, config.requestContentJsonPath).Raw model := gjson.GetBytes(body, "model").Raw ctx.SetContext("requestModel", model) - log.Debugf("Raw response content is: %s", content) + log.Debugf("Raw request content is: %s", content) if len(content) > 0 { timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") randomID, _ := generateHexID(16) @@ -321,6 +317,11 @@ func reconvertHeaders(hs map[string][]string) [][2]string { } func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { + if !config.checkResponse { + log.Debugf("response checking is disabled") + ctx.DontReadResponseBody() + return types.ActionContinue + } headers, err := proxywasm.GetHttpResponseHeaders() if err != nil { log.Warnf("failed to get response headers: %v", err) @@ -399,7 +400,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [ var jsonData []byte if config.protocolOriginal { jsonData = []byte(denyMessage) - } else if strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") { + } else if isStreamingResponse { randomID := generateRandomID() jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model)) } else { diff --git a/plugins/wasm-go/extensions/request-block/Dockerfile b/plugins/wasm-go/extensions/request-block/Dockerfile new file mode 100644 index 0000000000..9b084e0596 --- /dev/null +++ b/plugins/wasm-go/extensions/request-block/Dockerfile @@ -0,0 +1,2 @@ +FROM scratch +COPY main.wasm plugin.wasm \ No newline at end of file diff --git a/plugins/wasm-go/extensions/request-block/Makefile b/plugins/wasm-go/extensions/request-block/Makefile new file mode 100644 index 0000000000..1210d6ec34 --- /dev/null +++ b/plugins/wasm-go/extensions/request-block/Makefile @@ -0,0 +1,4 @@ +.DEFAULT: +build: + tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer' ./main.go + mv main.wasm ../../../../docker-compose-test/ \ No newline at end of file diff --git a/plugins/wasm-go/extensions/request-block/main.go b/plugins/wasm-go/extensions/request-block/main.go index 224d4b26d6..2a43b4df72 100644 --- a/plugins/wasm-go/extensions/request-block/main.go +++ b/plugins/wasm-go/extensions/request-block/main.go @@ -177,7 +177,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config RequestBlockConfig, lo } func onHttpRequestBody(ctx wrapper.HttpContext, config RequestBlockConfig, body []byte, log wrapper.Log) types.Action { + log.Infof("My request-block body: %s\n", string(body)) bodyStr := string(body) + if !config.caseSensitive { bodyStr = strings.ToLower(bodyStr) } diff --git a/plugins/wasm-go/extensions/traffic-tag/content.go b/plugins/wasm-go/extensions/traffic-tag/content.go index 07750f3314..4d62c02c99 100644 --- a/plugins/wasm-go/extensions/traffic-tag/content.go +++ b/plugins/wasm-go/extensions/traffic-tag/content.go @@ -26,10 +26,6 @@ import ( ) func onContentRequestHeaders(conditionGroups []ConditionGroup, log wrapper.Log) bool { - if len(conditionGroups) == 0 { - return false - } - for _, cg := range conditionGroups { if matchCondition(&cg, log) { addTagHeader(cg.HeaderName, cg.HeaderValue, log) diff --git a/plugins/wasm-go/extensions/traffic-tag/main.go b/plugins/wasm-go/extensions/traffic-tag/main.go index 6f69f23d64..0e58f95084 100644 --- a/plugins/wasm-go/extensions/traffic-tag/main.go +++ b/plugins/wasm-go/extensions/traffic-tag/main.go @@ -16,6 +16,7 @@ package main import ( "math/rand" + "strings" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" @@ -60,7 +61,6 @@ type TrafficTagConfig struct { WeightGroups []WeightGroup `json:"weightGroups,omitempty"` DefaultTagKey string `json:"defaultTagKey,omitempty"` DefaultTagVal string `json:"defaultTagVal,omitempty"` - randGen *rand.Rand } type ConditionGroup struct { @@ -93,8 +93,11 @@ func main() { } func parseConfig(json gjson.Result, config *TrafficTagConfig, log wrapper.Log) error { - if err := jsonValidate(json, log); err != nil { - return err + + jsonStr := strings.TrimSpace(json.Raw) + if jsonStr == "{}" || jsonStr == "" { + log.Error("plugin config is empty") + return nil } err := parseContentConfig(json, config, log) @@ -106,7 +109,17 @@ func parseConfig(json gjson.Result, config *TrafficTagConfig, log wrapper.Log) e } func onHttpRequestHeaders(ctx wrapper.HttpContext, config TrafficTagConfig, log wrapper.Log) types.Action { - if add := (onContentRequestHeaders(config.ConditionGroups, log) || onWeightRequestHeaders(config.WeightGroups, config.randGen, log)); !add { + + add := false + if len(config.ConditionGroups) != 0 { + add = add || onContentRequestHeaders(config.ConditionGroups, log) + } + + if !add && len(config.WeightGroups) != 0 { + add = add || onWeightRequestHeaders(config.WeightGroups, rand.Uint64(), log) + } + + if !add { setDefaultTag(config.DefaultTagKey, config.DefaultTagVal, log) } diff --git a/plugins/wasm-go/extensions/traffic-tag/parse.go b/plugins/wasm-go/extensions/traffic-tag/parse.go index 7878c3ba9b..c41890c1a7 100644 --- a/plugins/wasm-go/extensions/traffic-tag/parse.go +++ b/plugins/wasm-go/extensions/traffic-tag/parse.go @@ -17,10 +17,8 @@ package main import ( "errors" "fmt" - "math/rand" "strconv" "strings" - "time" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" regexp "github.com/wasilibs/go-re2" @@ -85,7 +83,6 @@ func parseWeightConfig(json gjson.Result, config *TrafficTagConfig, log wrapper. var parseError error var accumulatedWeight int64 config.WeightGroups = []WeightGroup{} - config.randGen = rand.New(rand.NewSource(time.Now().UnixNano())) // parse default tag key and value if k, v := json.Get(DefaultTagKey), json.Get(DefaultTagVal); k.Exists() && v.Exists() { diff --git a/plugins/wasm-go/extensions/traffic-tag/utils.go b/plugins/wasm-go/extensions/traffic-tag/utils.go index 870150dcc0..dd33efae84 100644 --- a/plugins/wasm-go/extensions/traffic-tag/utils.go +++ b/plugins/wasm-go/extensions/traffic-tag/utils.go @@ -15,14 +15,12 @@ package main import ( - "errors" "fmt" "net/url" "strings" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" - "github.com/tidwall/gjson" ) func setDefaultTag(k string, v string, log wrapper.Log) { @@ -75,23 +73,3 @@ func addTagHeader(key string, value string, log wrapper.Log) { } log.Infof("ADD HEADER: %s, value: %s", key, value) } - -func jsonValidate(json gjson.Result, log wrapper.Log) error { - if !json.Exists() { - log.Error("plugin config is missing in JSON") - return errors.New("plugin config is missing in JSON") - } - - jsonStr := strings.TrimSpace(json.Raw) - if jsonStr == "{}" || jsonStr == "" { - log.Error("plugin config is empty") - return errors.New("plugin config is empty") - } - - if !gjson.Valid(json.Raw) { - log.Error("plugin config is invalid JSON") - return errors.New("plugin config is invalid JSON") - } - - return nil -} diff --git a/plugins/wasm-go/extensions/traffic-tag/weight.go b/plugins/wasm-go/extensions/traffic-tag/weight.go index d6ea49f5d9..f825f1ec7c 100644 --- a/plugins/wasm-go/extensions/traffic-tag/weight.go +++ b/plugins/wasm-go/extensions/traffic-tag/weight.go @@ -15,16 +15,11 @@ package main import ( - "math/rand" - "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" ) -func onWeightRequestHeaders(weightGroups []WeightGroup, randGen *rand.Rand, log wrapper.Log) bool { - if len(weightGroups) == 0 { - return false - } - randomValue := randGen.Uint64() % TotalWeight +func onWeightRequestHeaders(weightGroups []WeightGroup, randomNum uint64, log wrapper.Log) bool { + randomValue := randomNum % TotalWeight log.Debugf("random value for weighted headers : %d", randomValue) // CDF for _, wg := range weightGroups { diff --git a/plugins/wasm-go/go.mod b/plugins/wasm-go/go.mod index 999721f3f6..2b21cb8360 100644 --- a/plugins/wasm-go/go.mod +++ b/plugins/wasm-go/go.mod @@ -4,15 +4,15 @@ go 1.19 require ( github.com/google/uuid v1.3.0 - github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 - github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f + github.com/higress-group/proxy-wasm-go-sdk v1.0.0 github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/resp v0.1.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect github.com/magefile/mage v1.14.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/plugins/wasm-go/go.sum b/plugins/wasm-go/go.sum index e726b100a5..679d810a90 100644 --- a/plugins/wasm-go/go.sum +++ b/plugins/wasm-go/go.sum @@ -4,16 +4,16 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= -github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU= +github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/plugins/wasm-go/pkg/wrapper/log_wrapper.go b/plugins/wasm-go/pkg/wrapper/log_wrapper.go index 6c6312f31a..65c0aa346c 100644 --- a/plugins/wasm-go/pkg/wrapper/log_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/log_wrapper.go @@ -36,7 +36,12 @@ type Log struct { } func (l Log) log(level LogLevel, msg string) { - msg = fmt.Sprintf("[%s] %s", l.pluginName, msg) + requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"}) + requestID := string(requestIDRaw) + if requestID == "" { + requestID = "nil" + } + msg = fmt.Sprintf("[%s] [%s] %s", l.pluginName, requestID, msg) switch level { case LogLevelTrace: proxywasm.LogTrace(msg) @@ -54,7 +59,12 @@ func (l Log) log(level LogLevel, msg string) { } func (l Log) logFormat(level LogLevel, format string, args ...interface{}) { - format = fmt.Sprintf("[%s] %s", l.pluginName, format) + requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"}) + requestID := string(requestIDRaw) + if requestID == "" { + requestID = "nil" + } + format = fmt.Sprintf("[%s] [%s] %s", l.pluginName, requestID, format) switch level { case LogLevelTrace: proxywasm.LogTracef(format, args...) diff --git a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go index 19a2357aca..cce0456f8c 100644 --- a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go @@ -24,14 +24,8 @@ import ( "github.com/tidwall/gjson" "github.com/alibaba/higress/plugins/wasm-go/pkg/matcher" - _ "github.com/higress-group/nottinygc" ) -//export sched_yield -func sched_yield() int32 { - return 0 -} - type HttpContext interface { Scheme() string Host() string @@ -370,6 +364,8 @@ func (ctx *CommonHttpCtx[PluginConfig]) SetResponseBodyBufferLimit(size uint32) } func (ctx *CommonHttpCtx[PluginConfig]) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action { + requestID, _ := proxywasm.GetHttpRequestHeader("x-request-id") + _ = proxywasm.SetProperty([]string{"x_request_id"}, []byte(requestID)) config, err := ctx.plugin.GetMatchConfig() if err != nil { ctx.plugin.vm.log.Errorf("get match config failed, err:%v", err) diff --git a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go index 10aa9020bd..c619c3e191 100644 --- a/plugins/wasm-go/pkg/wrapper/redis_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/redis_wrapper.go @@ -235,10 +235,11 @@ func (c RedisClusterClient[C]) Set(key string, value interface{}, callback Redis func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error { args := make([]interface{}, 0) - args = append(args, "setex") + args = append(args, "set") args = append(args, key) - args = append(args, ttl) args = append(args, value) + args = append(args, "ex") + args = append(args, ttl) return RedisCall(c.cluster, respString(args), callback) } diff --git a/plugins/wasm-rust/Cargo.lock b/plugins/wasm-rust/Cargo.lock index 9cf5140c77..63272f9a61 100644 --- a/plugins/wasm-rust/Cargo.lock +++ b/plugins/wasm-rust/Cargo.lock @@ -20,11 +20,23 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cfg-if" @@ -32,6 +44,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "downcast-rs" version = "1.2.1" @@ -44,6 +66,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -74,6 +105,7 @@ dependencies = [ "lazy_static", "multimap", "proxy-wasm", + "redis", "serde", "serde_json", "uuid", @@ -90,6 +122,16 @@ dependencies = [ "itoa", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "itoa" version = "1.0.11" @@ -104,9 +146,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "log" @@ -129,12 +171,46 @@ dependencies = [ "serde", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "proc-macro2" version = "1.0.88" @@ -147,7 +223,7 @@ dependencies = [ [[package]] name = "proxy-wasm" version = "0.2.2" -source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#6735737fad486c8a7cc324241f58df4a160e7887" +source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#8c902102091698bec953471c850bdf9799bc344d" dependencies = [ "downcast-rs", "hashbrown", @@ -163,6 +239,21 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redis" +version = "0.27.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cccf17a692ce51b86564334614d72dcae1def0fd5ecebc9f02956da74352b5" +dependencies = [ + "arc-swap", + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "url", +] + [[package]] name = "ryu" version = "1.0.18" @@ -171,18 +262,18 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "1ac55e59090389fb9f0dd9e0f3c09615afed1d19094284d0b200441f13550793" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "54be4f245ce16bc58d57ef2716271d0d4519e0f6defa147f6e081005bcb278ff" dependencies = [ "proc-macro2", "quote", @@ -191,9 +282,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -203,21 +294,62 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" + [[package]] name = "unicode-ident" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "uuid" version = "1.11.0" diff --git a/plugins/wasm-rust/Cargo.toml b/plugins/wasm-rust/Cargo.toml index e4fc274311..f58f2c8069 100644 --- a/plugins/wasm-rust/Cargo.toml +++ b/plugins/wasm-rust/Cargo.toml @@ -14,3 +14,4 @@ multimap = "0" http = "1" lazy_static = "1" downcast-rs="1" +redis={version = "0", default-features = false} diff --git a/plugins/wasm-rust/Makefile b/plugins/wasm-rust/Makefile index 25587e9698..5bc1809a38 100644 --- a/plugins/wasm-rust/Makefile +++ b/plugins/wasm-rust/Makefile @@ -6,12 +6,6 @@ IMAGE_TAG = $(if $(strip $(PLUGIN_VERSION)),${PLUGIN_VERSION},${BUILD_TIME}-${CO IMG ?= ${REGISTRY}${PLUGIN_NAME}:${IMAGE_TAG} .DEFAULT: -lint-base: - cargo fmt --all --check - cargo clippy --workspace --all-features --all-targets -lint: - cargo fmt --all --check --manifest-path extensions/${PLUGIN_NAME}/Cargo.toml - cargo clippy --workspace --all-features --all-targets --manifest-path extensions/${PLUGIN_NAME}/Cargo.toml build: DOCKER_BUILDKIT=1 docker build \ --build-arg PLUGIN_NAME=${PLUGIN_NAME} \ @@ -20,3 +14,10 @@ build: . @echo "" @echo "output wasm file: extensions/${PLUGIN_NAME}/plugin.wasm" + +lint-base: + cargo fmt --all --check + cargo clippy --workspace --all-features --all-targets +lint: + cargo fmt --all --check --manifest-path extensions/${PLUGIN_NAME}/Cargo.toml + cargo clippy --workspace --all-features --all-targets --manifest-path extensions/${PLUGIN_NAME}/Cargo.toml diff --git a/plugins/wasm-rust/example/sse-timing/Cargo.lock b/plugins/wasm-rust/example/sse-timing/Cargo.lock new file mode 100644 index 0000000000..9123a0e01c --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/Cargo.lock @@ -0,0 +1,270 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "allocator-api2" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" + +[[package]] +name = "bytes" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "downcast-rs" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] + +[[package]] +name = "higress-wasm-rust" +version = "0.1.0" +dependencies = [ + "downcast-rs", + "http", + "lazy_static", + "multimap", + "proxy-wasm", + "serde", + "serde_json", + "uuid", +] + +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.161" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "multimap" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +dependencies = [ + "serde", +] + +[[package]] +name = "once_cell" +version = "1.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" + +[[package]] +name = "proc-macro2" +version = "1.0.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proxy-wasm" +version = "0.2.2" +source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#6735737fad486c8a7cc324241f58df4a160e7887" +dependencies = [ + "downcast-rs", + "hashbrown", + "log", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "serde" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.132" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "sse-timing" +version = "0.1.0" +dependencies = [ + "higress-wasm-rust", + "proxy-wasm", + "serde", + "serde_json", +] + +[[package]] +name = "syn" +version = "2.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "uuid" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +dependencies = [ + "getrandom", +] + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/plugins/wasm-rust/example/sse-timing/Cargo.toml b/plugins/wasm-rust/example/sse-timing/Cargo.toml new file mode 100644 index 0000000000..44c4c5dfd7 --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "sse-timing" +version = "0.1.0" +edition = "2021" +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +crate-type = ["cdylib"] + +[dependencies] +higress-wasm-rust = { path = "../../", version = "0.1.0" } +proxy-wasm = { git="https://github.com/higress-group/proxy-wasm-rust-sdk", branch="main", version="0.2.2" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" diff --git a/plugins/wasm-rust/example/sse-timing/Makefile b/plugins/wasm-rust/example/sse-timing/Makefile new file mode 100644 index 0000000000..22ec19d791 --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/Makefile @@ -0,0 +1,10 @@ +BUILD_OPTS="--release" + +.DEFAULT: +build: + cargo build --target wasm32-wasi ${BUILD_OPTS} + find target -name "*.wasm" -d 3 -exec cp "{}" plugin.wasm \; + +clean: + cargo clean + rm -f plugin.wasm diff --git a/plugins/wasm-rust/example/sse-timing/README.md b/plugins/wasm-rust/example/sse-timing/README.md new file mode 100644 index 0000000000..b7d8d4bfdc --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/README.md @@ -0,0 +1,26 @@ +## Proxy-Wasm plugin example: SSE Timing + +Proxy-Wasm plugin that traces Server-Side Event(SSE) duration from request start. + +### Building + +```sh +$ make +``` + +### Using in Envoy + +This example can be run with [`docker compose`](https://docs.docker.com/compose/install/) +and has a matching Envoy configuration. + +```sh +$ docker compose up +``` + +#### Access granted. + +Send HTTP request to `localhost:10000/`: + +```sh +$ curl localhost:10000/ +``` diff --git a/plugins/wasm-rust/example/sse-timing/docker-compose.yaml b/plugins/wasm-rust/example/sse-timing/docker-compose.yaml new file mode 100644 index 0000000000..78549a2ac7 --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/docker-compose.yaml @@ -0,0 +1,35 @@ +# Copyright (c) 2023 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +services: + envoy: + image: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/all-in-one:latest + entrypoint: /usr/local/bin/envoy + command: -c /etc/envoy/envoy.yaml --component-log-level wasm:debug + depends_on: + - sse-server + hostname: envoy + ports: + - "10000:10000" + volumes: + - ./envoy.yaml:/etc/envoy/envoy.yaml + - ./target/wasm32-wasi/release:/etc/envoy/proxy-wasm-plugins + networks: + - envoymesh + sse-server: + build: sse-server + networks: + - envoymesh +networks: + envoymesh: {} \ No newline at end of file diff --git a/plugins/wasm-rust/example/sse-timing/envoy.yaml b/plugins/wasm-rust/example/sse-timing/envoy.yaml new file mode 100644 index 0000000000..6281aad0d9 --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/envoy.yaml @@ -0,0 +1,76 @@ +# Copyright (c) 2023 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +static_resources: + listeners: + - name: listener_0 + address: + socket_address: + protocol: TCP + address: 0.0.0.0 + port_value: 10000 + filter_chains: + - filters: + - name: envoy.filters.network.http_connection_manager + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + stat_prefix: ingress_http + route_config: + name: local_route + virtual_hosts: + - name: local_service + domains: ["*"] + routes: + - match: + prefix: "/" + route: + cluster: sse-server + http_filters: + - name: envoy.filters.http.wasm + typed_config: + "@type": type.googleapis.com/udpa.type.v1.TypedStruct + type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm + value: + config: + name: "http_body" + configuration: + "@type": type.googleapis.com/google.protobuf.StringValue + value: |- + { + "name": "sse_timing", + "_rules_": [] + } + vm_config: + runtime: "envoy.wasm.runtime.v8" + code: + local: + filename: "/etc/envoy/proxy-wasm-plugins/sse_timing.wasm" + - name: envoy.filters.http.router + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router + clusters: + - name: sse-server + connect_timeout: 30s + type: LOGICAL_DNS +# dns_lookup_family: V4_ONLY + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: sse-server + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: sse-server + port_value: 8080 \ No newline at end of file diff --git a/plugins/wasm-rust/example/sse-timing/src/lib.rs b/plugins/wasm-rust/example/sse-timing/src/lib.rs new file mode 100644 index 0000000000..fbff5daf9e --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/src/lib.rs @@ -0,0 +1,198 @@ +// Copyright (c) 2023 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use higress_wasm_rust::event_stream::EventStream; +use higress_wasm_rust::log::Log; +use higress_wasm_rust::rule_matcher::{on_configure, RuleMatcher, SharedRuleMatcher}; +use proxy_wasm::traits::{Context, HttpContext, RootContext}; +use proxy_wasm::types::{ContextType, DataAction, HeaderAction, LogLevel}; +use serde::Deserialize; +use std::cell::RefCell; +use std::ops::DerefMut; +use std::rc::Rc; +use std::str::from_utf8; +use std::time::{Duration, SystemTime}; + +proxy_wasm::main! {{ + proxy_wasm::set_log_level(LogLevel::Trace); + proxy_wasm::set_root_context(|_|Box::new(SseTimingRoot::new())); +}} + +struct SseTimingRoot { + log: Rc, + rule_matcher: SharedRuleMatcher, +} + +struct SseTiming { + log: Rc, + rule_matcher: SharedRuleMatcher, + vendor: String, + is_event_stream: bool, + event_stream: EventStream, + start_time: SystemTime, +} + +#[derive(Default, Clone, Debug, Deserialize)] +struct SseTimingConfig { + vendor: Option, +} + +impl SseTimingRoot { + fn new() -> Self { + SseTimingRoot { + log: Rc::new(Log::new("sse_timing".to_string())), + rule_matcher: Rc::new(RefCell::new(RuleMatcher::default())), + } + } +} + +impl Context for SseTimingRoot {} + +impl RootContext for SseTimingRoot { + fn on_configure(&mut self, plugin_configuration_size: usize) -> bool { + on_configure( + self, + plugin_configuration_size, + self.rule_matcher.borrow_mut().deref_mut(), + &self.log, + ) + } + + fn create_http_context(&self, _context_id: u32) -> Option> { + Some(Box::new(SseTiming { + log: self.log.clone(), + rule_matcher: self.rule_matcher.clone(), + vendor: "higress".into(), + is_event_stream: false, + event_stream: EventStream::new(), + start_time: self.get_current_time(), + })) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } +} + +impl Context for SseTiming {} + +impl HttpContext for SseTiming { + fn on_http_request_headers( + &mut self, + _num_headers: usize, + _end_of_stream: bool, + ) -> HeaderAction { + self.start_time = self.get_current_time(); + + let binding = self.rule_matcher.borrow(); + let config = match binding.get_match_config() { + None => { + return HeaderAction::Continue; + } + Some(config) => config.1, + }; + match config.vendor.clone() { + None => {} + Some(vendor) => self.vendor = vendor, + } + HeaderAction::Continue + } + + fn on_http_response_headers( + &mut self, + _num_headers: usize, + _end_of_stream: bool, + ) -> HeaderAction { + match self.get_http_response_header("Content-Type") { + None => self + .log + .warn("upstream response is not set Content-Type, skipped"), + Some(content_type) => { + if content_type.starts_with("text/event-stream") { + self.is_event_stream = true + } else { + self.log.warn(format!("upstream response Content-Type is not text/event-stream, but {}, skipped", content_type).as_str()) + } + } + } + HeaderAction::Continue + } + + fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction { + if !self.is_event_stream { + return DataAction::Continue; + } + + let body = self + .get_http_response_body(0, body_size) + .unwrap_or_default(); + self.event_stream.update(body); + self.process_event_stream(end_of_stream) + } +} + +impl SseTiming { + fn process_event_stream(&mut self, end_of_stream: bool) -> DataAction { + let mut modified_events = Vec::new(); + + loop { + match self.event_stream.next() { + None => break, + Some(raw_event) => { + if !raw_event.is_empty() { + // according to spec, event-stream must be utf-8 encoding + let event = from_utf8(raw_event.as_slice()).unwrap(); + let processed_event = self.process_event(event.to_string()); + modified_events.push(processed_event); + } + } + } + } + + if end_of_stream { + match self.event_stream.flush() { + None => {} + Some(raw_event) => { + if !raw_event.is_empty() { + // according to spec, event-stream must be utf-8 encoding + let event = from_utf8(raw_event.as_slice()).unwrap(); + let modified_event = self.process_event(event.into()); + modified_events.push(modified_event); + } + } + } + } + + if !modified_events.is_empty() { + let modified_body = modified_events.concat(); + self.set_http_response_body(0, modified_body.len(), modified_body.as_bytes()); + DataAction::Continue + } else { + DataAction::StopIterationNoBuffer + } + } + + fn process_event(&self, event: String) -> String { + let duration = self + .get_current_time() + .duration_since(self.start_time) + .unwrap_or(Duration::ZERO); + format!( + ": server-timing: {};dur={}\n{}\n\n", + self.vendor, + duration.as_millis(), + event + ) + } +} diff --git a/plugins/wasm-rust/example/sse-timing/sse-server/Dockerfile b/plugins/wasm-rust/example/sse-timing/sse-server/Dockerfile new file mode 100644 index 0000000000..7d251e48f9 --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/sse-server/Dockerfile @@ -0,0 +1,5 @@ +FROM golang:latest AS builder +WORKDIR /workspace +COPY . . +RUN GOOS=linux GOARCH=amd64 go build -o main . +CMD ./main \ No newline at end of file diff --git a/plugins/wasm-rust/example/sse-timing/sse-server/go.mod b/plugins/wasm-rust/example/sse-timing/sse-server/go.mod new file mode 100644 index 0000000000..63e8515e6c --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/sse-server/go.mod @@ -0,0 +1,3 @@ +module sse + +go 1.22 diff --git a/plugins/wasm-rust/example/sse-timing/sse-server/main.go b/plugins/wasm-rust/example/sse-timing/sse-server/main.go new file mode 100644 index 0000000000..bb4a3089fd --- /dev/null +++ b/plugins/wasm-rust/example/sse-timing/sse-server/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "log" + "net/http" + "time" +) + +var events = []string{ + ": this is a test stream\n\n", + + "data: some text\n", + "data: another message\n", + "data: with two lines\n\n", + + "event: userconnect\n", + "data: {\"username\": \"bobby\", \"time\": \"02:33:48\"}\n\n", + + "event: usermessage\n", + "data: {\"username\": \"bobby\", \"time\": \"02:34:11\", \"text\": \"Hi everyone.\"}\n\n", + + "event: userdisconnect\n", + "data: {\"username\": \"bobby\", \"time\": \"02:34:23\"}\n\n", + + "event: usermessage\n", + "data: {\"username\": \"sean\", \"time\": \"02:34:36\", \"text\": \"Bye, bobby.\"}\n\n", +} + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + log.Println("receive request") + w.Header().Set("Content-Type", "text/event-stream") + for _, e := range events { + _, _ = w.Write([]byte(e)) + time.Sleep(1 * time.Second) + w.(http.Flusher).Flush() + } + }) + if err := http.ListenAndServe("0.0.0.0:8080", nil); err != nil { + panic(err) + } +} diff --git a/plugins/wasm-rust/extensions/ai-data-masking/Cargo.lock b/plugins/wasm-rust/extensions/ai-data-masking/Cargo.lock index 914d5b9689..e635ae454f 100644 --- a/plugins/wasm-rust/extensions/ai-data-masking/Cargo.lock +++ b/plugins/wasm-rust/extensions/ai-data-masking/Cargo.lock @@ -46,6 +46,18 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "bit-set" version = "0.5.3" @@ -84,15 +96,15 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cc" -version = "1.1.30" +version = "1.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" +checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" dependencies = [ "shlex", ] @@ -112,6 +124,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "cpufeatures" version = "0.2.14" @@ -230,6 +252,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "fxhash" version = "0.2.1" @@ -295,6 +326,7 @@ dependencies = [ "lazy_static", "multimap", "proxy-wasm", + "redis", "serde", "serde_json", "uuid", @@ -317,6 +349,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "itoa" version = "1.0.11" @@ -359,9 +401,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "log" @@ -390,6 +432,34 @@ dependencies = [ "serde", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -418,6 +488,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "pest" version = "2.7.14" @@ -519,7 +595,7 @@ dependencies = [ [[package]] name = "proxy-wasm" version = "0.2.2" -source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#6735737fad486c8a7cc324241f58df4a160e7887" +source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#8c902102091698bec953471c850bdf9799bc344d" dependencies = [ "downcast-rs", "hashbrown", @@ -550,6 +626,21 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +[[package]] +name = "redis" +version = "0.27.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cccf17a692ce51b86564334614d72dcae1def0fd5ecebc9f02956da74352b5" +dependencies = [ + "arc-swap", + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "url", +] + [[package]] name = "regex" version = "1.11.0" @@ -630,18 +721,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "1ac55e59090389fb9f0dd9e0f3c09615afed1d19094284d0b200441f13550793" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "54be4f245ce16bc58d57ef2716271d0d4519e0f6defa147f6e081005bcb278ff" dependencies = [ "proc-macro2", "quote", @@ -650,9 +741,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -697,9 +788,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "syn" -version = "2.0.79" +version = "2.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" dependencies = [ "proc-macro2", "quote", @@ -726,6 +817,21 @@ dependencies = [ "syn", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "typenum" version = "1.17.0" @@ -738,12 +844,38 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "unicode-bidi" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" + [[package]] name = "unicode-ident" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "uuid" version = "1.11.0" diff --git a/plugins/wasm-rust/extensions/ai-data-masking/Cargo.toml b/plugins/wasm-rust/extensions/ai-data-masking/Cargo.toml index b8bd6df5b2..aa7372fb20 100644 --- a/plugins/wasm-rust/extensions/ai-data-masking/Cargo.toml +++ b/plugins/wasm-rust/extensions/ai-data-masking/Cargo.toml @@ -18,5 +18,5 @@ md5 = "0" grok = "2" lazy_static = "1" jieba-rs = "0" -rust-embed="8.5.0" +rust-embed = "8.5.0" jsonpath-rust = "0" diff --git a/plugins/wasm-rust/extensions/ai-data-masking/README.md b/plugins/wasm-rust/extensions/ai-data-masking/README.md index b892b3dab8..c1e7f73369 100644 --- a/plugins/wasm-rust/extensions/ai-data-masking/README.md +++ b/plugins/wasm-rust/extensions/ai-data-masking/README.md @@ -148,4 +148,4 @@ curl -X POST \ - 流模式中,如果敏感词语被多个chunk拆分,可能会有敏感词的一部分返回给用户的情况 - grok 内置规则列表 https://help.aliyun.com/zh/sls/user-guide/grok-patterns - 内置敏感词库数据来源 https://github.com/houbb/sensitive-word/tree/master/src/main/resources - + - 由于敏感词列表是在文本分词后进行匹配的,所以请将 `deny_words` 设置为单个单词,英文多单词情况如 `hello word` 可能无法匹配 diff --git a/plugins/wasm-rust/extensions/ai-data-masking/README_EN.md b/plugins/wasm-rust/extensions/ai-data-masking/README_EN.md index 45e1622875..d7cb72378e 100644 --- a/plugins/wasm-rust/extensions/ai-data-masking/README_EN.md +++ b/plugins/wasm-rust/extensions/ai-data-masking/README_EN.md @@ -129,3 +129,4 @@ Please note that you need to replace `"key":"value"` with the actual data conten - In streaming mode, if sensitive words are split across multiple chunks, there may be cases where part of the sensitive word is returned to the user - Grok built-in rule list: https://help.aliyun.com/zh/sls/user-guide/grok-patterns - Built-in sensitive word library data source: https://github.com/houbb/sensitive-word/tree/master/src/main/resources + - Since the sensitive word list is matched after tokenizing the text, please set `deny_words` to single words. In the case of multiple words in English, such as `hello world`, the match may not be successful. diff --git a/plugins/wasm-rust/extensions/ai-data-masking/src/deny_word.rs b/plugins/wasm-rust/extensions/ai-data-masking/src/deny_word.rs new file mode 100644 index 0000000000..99ecab4d5b --- /dev/null +++ b/plugins/wasm-rust/extensions/ai-data-masking/src/deny_word.rs @@ -0,0 +1,54 @@ +use std::collections::HashSet; + +use jieba_rs::Jieba; + +use crate::Asset; + +#[derive(Default, Debug, Clone)] +pub(crate) struct DenyWord { + jieba: Jieba, + words: HashSet, +} + +impl DenyWord { + pub(crate) fn from_iter>>(words: T) -> Self { + let mut deny_word = DenyWord::default(); + + for word in words { + let word_s = word.into(); + let w = word_s.trim(); + if w.is_empty() { + continue; + } + deny_word.jieba.add_word(w, None, None); + deny_word.words.insert(w.to_string()); + } + + deny_word + } + + pub(crate) fn empty() -> Self { + DenyWord { + jieba: Jieba::empty(), + words: HashSet::new(), + } + } + + pub(crate) fn system() -> Self { + if let Some(file) = Asset::get("sensitive_word_dict.txt") { + if let Ok(data) = std::str::from_utf8(file.data.as_ref()) { + return DenyWord::from_iter(data.split('\n')); + } + } + Self::empty() + } + + pub(crate) fn check(&self, message: &str) -> Option { + for word in self.jieba.cut(message, true) { + if self.words.contains(word) { + return Some(word.to_string()); + } + } + None + } +} diff --git a/plugins/wasm-rust/extensions/ai-data-masking/src/lib.rs b/plugins/wasm-rust/extensions/ai-data-masking/src/lib.rs index 89c0da5219..ca2db3da42 100644 --- a/plugins/wasm-rust/extensions/ai-data-masking/src/lib.rs +++ b/plugins/wasm-rust/extensions/ai-data-masking/src/lib.rs @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod deny_word; + +use crate::deny_word::DenyWord; use fancy_regex::Regex; use grok::patterns; use higress_wasm_rust::log::Log; use higress_wasm_rust::plugin_wrapper::{HttpContextWrapper, RootContextWrapper}; use higress_wasm_rust::request_wrapper::has_request_body; use higress_wasm_rust::rule_matcher::{on_configure, RuleMatcher, SharedRuleMatcher}; -use jieba_rs::Jieba; use jsonpath_rust::{JsonPath, JsonPathValue}; use lazy_static::lazy_static; use proxy_wasm::traits::{Context, HttpContext, RootContext}; @@ -29,7 +31,7 @@ use serde::Deserialize; use serde::Deserializer; use serde_json::{json, Value}; use std::cell::RefCell; -use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; +use std::collections::{BTreeMap, HashMap, VecDeque}; use std::ops::DerefMut; use std::rc::Rc; use std::str::FromStr; @@ -47,11 +49,6 @@ const GROK_PATTERN: &str = r"%\{(?(?[A-z0-9]+)(?::(?[A-z0- #[folder = "res/"] struct Asset; -#[derive(Default, Debug, Clone)] -struct DenyWord { - jieba: Jieba, - words: HashSet, -} struct System { deny_word: DenyWord, grok_regex: Regex, @@ -97,13 +94,13 @@ where D: Deserializer<'de>, { let value: Value = Deserialize::deserialize(deserializer)?; - if let Some(_type) = value.as_str() { - if _type == "replace" { + if let Some(t) = value.as_str() { + if t == "replace" { Ok(Type::Replace) - } else if _type == "hash" { + } else if t == "hash" { Ok(Type::Hash) } else { - Err(Error::custom(format!("regexp error value {}", _type))) + Err(Error::custom(format!("regexp error value {}", t))) } } else { Err(Error::custom("type error not string".to_string())) @@ -227,52 +224,12 @@ static SYSTEM_PATTERNS: &[(&str, &str)] = &[ ("IDCARD", r#"\d{17}[0-9xX]|\d{15}"#), ]; -impl DenyWord { - fn empty() -> Self { - DenyWord { - jieba: Jieba::empty(), - words: HashSet::new(), - } - } - fn from_iter>>(words: T) -> Self { - let mut deny_word = DenyWord::empty(); - - for word in words { - let _w = word.into(); - let w = _w.trim(); - if w.is_empty() { - continue; - } - deny_word.jieba.add_word(w, None, None); - deny_word.words.insert(w.to_string()); - } - - deny_word - } - fn default() -> Self { - if let Some(file) = Asset::get("sensitive_word_dict.txt") { - if let Ok(data) = std::str::from_utf8(file.data.as_ref()) { - return DenyWord::from_iter(data.split('\n')); - } - } - DenyWord::empty() - } - - fn check(&self, message: &str) -> Option { - for word in self.jieba.cut(message, true) { - if self.words.contains(word) { - return Some(word.to_string()); - } - } - None - } -} impl System { fn new() -> Self { let grok_regex = Regex::new(GROK_PATTERN).unwrap(); let grok_patterns = BTreeMap::new(); let mut system = System { - deny_word: DenyWord::default(), + deny_word: DenyWord::system(), grok_regex, grok_patterns, }; @@ -314,12 +271,12 @@ impl System { fn grok_to_pattern(&self, pattern: &str) -> (String, bool) { let mut ok = true; let mut ret = pattern.to_string(); - for _c in self.grok_regex.captures_iter(pattern) { - if _c.is_err() { + for capture in self.grok_regex.captures_iter(pattern) { + if capture.is_err() { ok = false; continue; } - let c = _c.unwrap(); + let c = capture.unwrap(); if let (Some(full), Some(name)) = (c.get(0), c.name("pattern")) { if let Some(p) = self.grok_patterns.get(name.as_str()) { if let Some(alias) = c.name("alias") { @@ -335,6 +292,7 @@ impl System { (ret, ok) } } + impl AiDataMaskingRoot { fn new() -> Self { AiDataMaskingRoot { @@ -347,16 +305,16 @@ impl AiDataMaskingRoot { impl Context for AiDataMaskingRoot {} impl RootContext for AiDataMaskingRoot { - fn on_configure(&mut self, _plugin_configuration_size: usize) -> bool { + fn on_configure(&mut self, plugin_configuration_size: usize) -> bool { on_configure( self, - _plugin_configuration_size, + plugin_configuration_size, self.rule_matcher.borrow_mut().deref_mut(), &self.log, ) } - fn create_http_context(&self, _context_id: u32) -> Option> { - self.create_http_context_use_wrapper(_context_id) + fn create_http_context(&self, context_id: u32) -> Option> { + self.create_http_context_use_wrapper(context_id) } fn get_type(&self) -> Option { Some(ContextType::HttpContext) @@ -382,6 +340,7 @@ impl RootContextWrapper for AiDataMaskingRoot { })) } } + impl AiDataMasking { fn check_message(&self, message: &str) -> bool { if let Some(config) = &self.config { @@ -491,11 +450,11 @@ impl AiDataMasking { if rule.type_ == Type::Replace && !rule.restore { msg = rule.regex.replace_all(&msg, &rule.value).to_string(); } else { - for _m in rule.regex.find_iter(&msg) { - if _m.is_err() { + for mc in rule.regex.find_iter(&msg) { + if mc.is_err() { continue; } - let m = _m.unwrap(); + let m = mc.unwrap(); let from_word = m.as_str(); let to_word = match rule.type_ { @@ -532,6 +491,7 @@ impl AiDataMasking { } impl Context for AiDataMasking {} + impl HttpContext for AiDataMasking { fn on_http_request_headers( &mut self, @@ -607,6 +567,7 @@ impl HttpContext for AiDataMasking { DataAction::Continue } } + impl HttpContextWrapper for AiDataMasking { fn log(&self) -> &Log { &self.log diff --git a/plugins/wasm-rust/extensions/demo-wasm/Cargo.lock b/plugins/wasm-rust/extensions/demo-wasm/Cargo.lock index c1e197caf6..1c6c135617 100644 --- a/plugins/wasm-rust/extensions/demo-wasm/Cargo.lock +++ b/plugins/wasm-rust/extensions/demo-wasm/Cargo.lock @@ -20,11 +20,23 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cfg-if" @@ -32,6 +44,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "demo-wasm" version = "0.1.0" @@ -40,7 +62,9 @@ dependencies = [ "http", "multimap", "proxy-wasm", + "redis", "serde", + "serde_json", ] [[package]] @@ -55,6 +79,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -85,6 +118,7 @@ dependencies = [ "lazy_static", "multimap", "proxy-wasm", + "redis", "serde", "serde_json", "uuid", @@ -101,6 +135,16 @@ dependencies = [ "itoa", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "itoa" version = "1.0.11" @@ -115,9 +159,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "log" @@ -140,12 +184,46 @@ dependencies = [ "serde", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "proc-macro2" version = "1.0.88" @@ -158,7 +236,7 @@ dependencies = [ [[package]] name = "proxy-wasm" version = "0.2.2" -source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#6735737fad486c8a7cc324241f58df4a160e7887" +source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#8c902102091698bec953471c850bdf9799bc344d" dependencies = [ "downcast-rs", "hashbrown", @@ -174,6 +252,21 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redis" +version = "0.27.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cccf17a692ce51b86564334614d72dcae1def0fd5ecebc9f02956da74352b5" +dependencies = [ + "arc-swap", + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "url", +] + [[package]] name = "ryu" version = "1.0.18" @@ -182,18 +275,18 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "1ac55e59090389fb9f0dd9e0f3c09615afed1d19094284d0b200441f13550793" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "54be4f245ce16bc58d57ef2716271d0d4519e0f6defa147f6e081005bcb278ff" dependencies = [ "proc-macro2", "quote", @@ -202,9 +295,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -214,21 +307,62 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" + [[package]] name = "unicode-ident" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "uuid" version = "1.11.0" diff --git a/plugins/wasm-rust/extensions/demo-wasm/Cargo.toml b/plugins/wasm-rust/extensions/demo-wasm/Cargo.toml index a517c2b531..3f1e3fd864 100644 --- a/plugins/wasm-rust/extensions/demo-wasm/Cargo.toml +++ b/plugins/wasm-rust/extensions/demo-wasm/Cargo.toml @@ -11,5 +11,7 @@ crate-type = ["cdylib"] higress-wasm-rust = { path = "../../", version = "0.1.0" } proxy-wasm = { git="https://github.com/higress-group/proxy-wasm-rust-sdk", branch="main", version="0.2.2" } serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" multimap = "*" -http = "*" \ No newline at end of file +http = "*" +redis={version = "0", default-features = false} \ No newline at end of file diff --git a/plugins/wasm-rust/extensions/demo-wasm/src/lib.rs b/plugins/wasm-rust/extensions/demo-wasm/src/lib.rs index 62dfa055d0..ed432e69ad 100644 --- a/plugins/wasm-rust/extensions/demo-wasm/src/lib.rs +++ b/plugins/wasm-rust/extensions/demo-wasm/src/lib.rs @@ -1,13 +1,16 @@ -use higress_wasm_rust::cluster_wrapper::DnsCluster; +use higress_wasm_rust::cluster_wrapper::{DnsCluster, StaticIpCluster}; use higress_wasm_rust::log::Log; use higress_wasm_rust::plugin_wrapper::{HttpContextWrapper, RootContextWrapper}; +use higress_wasm_rust::redis_wrapper::{RedisClient, RedisClientBuilder, RedisClientConfig}; use higress_wasm_rust::rule_matcher::{on_configure, RuleMatcher, SharedRuleMatcher}; use http::Method; use multimap::MultiMap; use proxy_wasm::traits::{Context, HttpContext, RootContext}; use proxy_wasm::types::{Bytes, ContextType, DataAction, HeaderAction, LogLevel}; +use redis::Value; use serde::Deserialize; +use serde_json::json; use std::cell::RefCell; use std::ops::DerefMut; use std::rc::{Rc, Weak}; @@ -24,6 +27,8 @@ const PLUGIN_NAME: &str = "demo-wasm"; struct DemoWasmConfig { // 配置文件结构体 test: String, + #[serde(default)] + password: Option, } fn format_body(body: Option>) -> String { @@ -40,6 +45,8 @@ struct DemoWasm { log: Log, config: Option>, weak: Weak>>>, + redis_client: Option, + cid: i64, } impl Context for DemoWasm {} @@ -58,7 +65,7 @@ impl HttpContextWrapper for DemoWasm { fn on_config(&mut self, config: Rc) { // 获取config self.log.info(&format!("on_config {}", config.test)); - self.config = Some(config.clone()) + self.config = Some(config.clone()); } fn on_http_request_complete_headers( &mut self, @@ -67,6 +74,55 @@ impl HttpContextWrapper for DemoWasm { // 请求header获取完成回调 self.log .info(&format!("on_http_request_complete_headers {:?}", headers)); + if let Some(config) = &self.config { + let _redis_client = RedisClientBuilder::new( + &StaticIpCluster::new("redis", 80, ""), + Duration::from_secs(5), + ) + .password(config.password.as_ref()) + .build(); + + let redis_client = RedisClient::new( + RedisClientConfig::new( + &StaticIpCluster::new("redis", 80, ""), + Duration::from_secs(5), + ) + .password(config.password.as_ref()), + ); + + if let Some(self_rc) = self.weak.upgrade() { + let init_res = redis_client.init(); + self.log.info(&format!("redis init {:?}", init_res)); + if init_res.is_ok() { + let incr_res = redis_client.incr( + "connect", + Box::new(move |res, status, token_id| { + self_rc.borrow().log().info(&format!( + "redis incr finish value_res:{:?}, status: {}, token_id: {}", + res, status, token_id + )); + if let Some(this) = self_rc.borrow_mut().downcast_mut::() { + if let Ok(Value::Int(value)) = res { + this.cid = *value; + } + } + self_rc.borrow().resume_http_request(); + }), + ); + match incr_res { + Ok(s) => { + self.log.info(&format!("redis incr ok {}", s)); + return HeaderAction::StopAllIterationAndBuffer; + } + Err(e) => self.log.info(&format!("redis incr error {:?}", e)), + }; + } + self.redis_client = Some(redis_client); + } else { + self.log.error("self_weak upgrade error"); + } + } + HeaderAction::Continue } fn on_http_response_complete_headers( @@ -76,6 +132,38 @@ impl HttpContextWrapper for DemoWasm { // 返回header获取完成回调 self.log .info(&format!("on_http_response_complete_headers {:?}", headers)); + self.set_http_response_header("Content-Length", None); + let self_rc = match self.weak.upgrade() { + Some(rc) => rc.clone(), + None => { + self.log.error("self_weak upgrade error"); + return HeaderAction::Continue; + } + }; + if let Some(redis_client) = &self.redis_client { + match redis_client.get( + "connect", + Box::new(move |res, status, token_id| { + if let Some(this) = self_rc.borrow().downcast_ref::() { + this.log.info(&format!( + "redis get connect value_res:{:?}, status: {}, token_id: {}", + res, status, token_id + )); + this.resume_http_response(); + } else { + self_rc.borrow().resume_http_response(); + } + }), + ) { + Ok(o) => { + self.log.info(&format!("redis get ok {}", o)); + return HeaderAction::StopIteration; + } + Err(e) => { + self.log.info(&format!("redis get fail {:?}", e)); + } + } + } HeaderAction::Continue } fn cache_request_body(&self) -> bool { @@ -92,6 +180,16 @@ impl HttpContextWrapper for DemoWasm { "on_http_request_complete_body {}", String::from_utf8(req_body.clone()).unwrap_or("".to_string()) )); + DataAction::Continue + } + fn on_http_response_complete_body(&mut self, res_body: &Bytes) -> DataAction { + // 返回body获取完成回调 + let res_body_string = String::from_utf8(res_body.clone()).unwrap_or("".to_string()); + self.log.info(&format!( + "on_http_response_complete_body {}", + res_body_string + )); + let cluster = DnsCluster::new("httpbin", "httpbin.org", 80); let self_rc = match self.weak.upgrade() { @@ -101,6 +199,7 @@ impl HttpContextWrapper for DemoWasm { return DataAction::Continue; } }; + let http_call_res = self.http_call( &cluster, &Method::POST, @@ -108,34 +207,29 @@ impl HttpContextWrapper for DemoWasm { MultiMap::new(), Some("test_body".as_bytes()), Box::new(move |status_code, headers, body| { - if let Some(this) = self_rc.borrow().downcast_ref::() { + if let Some(this) = self_rc.borrow_mut().downcast_mut::() { + let body_string = format_body(body); this.log.info(&format!( "test_callback status_code:{}, headers: {:?}, body: {}", status_code, headers, - format_body(body) + body_string )); - this.resume_http_request(); + let data = json!({"redis_cid": this.cid, "http_call_body": body_string, "res_body": res_body_string}); + this.replace_http_response_body(data.to_string().as_bytes()); + this.resume_http_response(); } else { - self_rc.borrow().resume_http_request(); + self_rc.borrow().resume_http_response(); } }), Duration::from_secs(5), ); match http_call_res { - Ok(_) => DataAction::StopIterationAndBuffer, + Ok(_) => return DataAction::StopIterationAndBuffer, Err(e) => { self.log.info(&format!("http_call fail {:?}", e)); - DataAction::Continue } } - } - fn on_http_response_complete_body(&mut self, res_body: &Bytes) -> DataAction { - // 返回body获取完成回调 - self.log.info(&format!( - "on_http_response_complete_body {}", - String::from_utf8(res_body.clone()).unwrap_or("".to_string()) - )); DataAction::Continue } } @@ -147,6 +241,7 @@ impl DemoWasmRoot { fn new() -> Self { let log = Log::new(PLUGIN_NAME.to_string()); log.info("DemoWasmRoot::new"); + DemoWasmRoot { log, rule_matcher: Rc::new(RefCell::new(RuleMatcher::default())), @@ -171,6 +266,7 @@ impl RootContext for DemoWasmRoot { "DemoWasmRoot::create_http_context({})", context_id )); + self.create_http_context_use_wrapper(context_id) } fn get_type(&self) -> Option { @@ -191,6 +287,8 @@ impl RootContextWrapper for DemoWasmRoot { config: None, log: Log::new(PLUGIN_NAME.to_string()), weak: Weak::default(), + redis_client: None, + cid: -1, })) } } diff --git a/plugins/wasm-rust/extensions/request-block/Cargo.lock b/plugins/wasm-rust/extensions/request-block/Cargo.lock index 243fe87a49..36c0b812b7 100644 --- a/plugins/wasm-rust/extensions/request-block/Cargo.lock +++ b/plugins/wasm-rust/extensions/request-block/Cargo.lock @@ -29,11 +29,23 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cfg-if" @@ -41,6 +53,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "downcast-rs" version = "1.2.1" @@ -53,6 +75,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -83,6 +114,7 @@ dependencies = [ "lazy_static", "multimap", "proxy-wasm", + "redis", "serde", "serde_json", "uuid", @@ -99,6 +131,16 @@ dependencies = [ "itoa", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "itoa" version = "1.0.11" @@ -113,9 +155,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "log" @@ -138,12 +180,46 @@ dependencies = [ "serde", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "proc-macro2" version = "1.0.88" @@ -156,7 +232,7 @@ dependencies = [ [[package]] name = "proxy-wasm" version = "0.2.2" -source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#6735737fad486c8a7cc324241f58df4a160e7887" +source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#8c902102091698bec953471c850bdf9799bc344d" dependencies = [ "downcast-rs", "hashbrown", @@ -172,6 +248,21 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redis" +version = "0.27.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cccf17a692ce51b86564334614d72dcae1def0fd5ecebc9f02956da74352b5" +dependencies = [ + "arc-swap", + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "url", +] + [[package]] name = "regex" version = "1.11.0" @@ -221,18 +312,18 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "1ac55e59090389fb9f0dd9e0f3c09615afed1d19094284d0b200441f13550793" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "54be4f245ce16bc58d57ef2716271d0d4519e0f6defa147f6e081005bcb278ff" dependencies = [ "proc-macro2", "quote", @@ -241,9 +332,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -253,21 +344,62 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" + [[package]] name = "unicode-ident" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "uuid" version = "1.11.0" diff --git a/plugins/wasm-rust/extensions/request-block/src/lib.rs b/plugins/wasm-rust/extensions/request-block/src/lib.rs index d120acb902..822601cd76 100644 --- a/plugins/wasm-rust/extensions/request-block/src/lib.rs +++ b/plugins/wasm-rust/extensions/request-block/src/lib.rs @@ -104,17 +104,17 @@ impl RquestBlockRoot { impl Context for RquestBlockRoot {} impl RootContext for RquestBlockRoot { - fn on_configure(&mut self, _plugin_configuration_size: usize) -> bool { + fn on_configure(&mut self, plugin_configuration_size: usize) -> bool { let ret = on_configure( self, - _plugin_configuration_size, + plugin_configuration_size, self.rule_matcher.borrow_mut().deref_mut(), &self.log, ); ret } - fn create_http_context(&self, _context_id: u32) -> Option> { - self.create_http_context_use_wrapper(_context_id) + fn create_http_context(&self, context_id: u32) -> Option> { + self.create_http_context_use_wrapper(context_id) } fn get_type(&self) -> Option { Some(ContextType::HttpContext) diff --git a/plugins/wasm-rust/extensions/say-hello/Cargo.lock b/plugins/wasm-rust/extensions/say-hello/Cargo.lock index dc98fd709a..758fc617e3 100644 --- a/plugins/wasm-rust/extensions/say-hello/Cargo.lock +++ b/plugins/wasm-rust/extensions/say-hello/Cargo.lock @@ -20,11 +20,23 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cfg-if" @@ -32,6 +44,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "downcast-rs" version = "1.2.1" @@ -44,6 +66,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -74,6 +105,7 @@ dependencies = [ "lazy_static", "multimap", "proxy-wasm", + "redis", "serde", "serde_json", "uuid", @@ -90,6 +122,16 @@ dependencies = [ "itoa", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "itoa" version = "1.0.11" @@ -104,9 +146,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.159" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "log" @@ -129,12 +171,46 @@ dependencies = [ "serde", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "proc-macro2" version = "1.0.88" @@ -147,7 +223,7 @@ dependencies = [ [[package]] name = "proxy-wasm" version = "0.2.2" -source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#6735737fad486c8a7cc324241f58df4a160e7887" +source = "git+https://github.com/higress-group/proxy-wasm-rust-sdk?branch=main#8c902102091698bec953471c850bdf9799bc344d" dependencies = [ "downcast-rs", "hashbrown", @@ -163,6 +239,21 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redis" +version = "0.27.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cccf17a692ce51b86564334614d72dcae1def0fd5ecebc9f02956da74352b5" +dependencies = [ + "arc-swap", + "combine", + "itoa", + "num-bigint", + "percent-encoding", + "ryu", + "url", +] + [[package]] name = "ryu" version = "1.0.18" @@ -181,18 +272,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "1ac55e59090389fb9f0dd9e0f3c09615afed1d19094284d0b200441f13550793" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.211" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "54be4f245ce16bc58d57ef2716271d0d4519e0f6defa147f6e081005bcb278ff" dependencies = [ "proc-macro2", "quote", @@ -201,9 +292,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", "memchr", @@ -213,21 +304,62 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "83540f837a8afc019423a8edb95b52a8effe46957ee402287f4292fae35be021" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" + [[package]] name = "unicode-ident" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "uuid" version = "1.11.0" diff --git a/plugins/wasm-rust/src/cluster_wrapper.rs b/plugins/wasm-rust/src/cluster_wrapper.rs index 2891293fb3..8d3e204458 100644 --- a/plugins/wasm-rust/src/cluster_wrapper.rs +++ b/plugins/wasm-rust/src/cluster_wrapper.rs @@ -4,10 +4,12 @@ pub trait Cluster { fn cluster_name(&self) -> String; fn host_name(&self) -> String; } + #[derive(Debug, Clone)] pub struct RouteCluster { host: String, } + impl RouteCluster { pub fn new(host: &str) -> Self { RouteCluster { @@ -15,6 +17,7 @@ impl RouteCluster { } } } + impl Cluster for RouteCluster { fn cluster_name(&self) -> String { if let Some(res) = get_property(vec!["cluster_name"]) { @@ -111,6 +114,7 @@ impl NacosCluster { } } } + impl Cluster for NacosCluster { fn cluster_name(&self) -> String { let group = if self.group.is_empty() { @@ -154,6 +158,7 @@ impl StaticIpCluster { } } } + impl Cluster for StaticIpCluster { fn cluster_name(&self) -> String { format!("outbound|{}||{}.static", self.port, self.service_name) @@ -184,6 +189,7 @@ impl DnsCluster { } } } + impl Cluster for DnsCluster { fn cluster_name(&self) -> String { format!("outbound|{}||{}.dns", self.port, self.service_name) @@ -212,6 +218,7 @@ impl ConsulCluster { } } } + impl Cluster for ConsulCluster { fn cluster_name(&self) -> String { format!( @@ -245,10 +252,12 @@ impl FQDNCluster { } } } + impl Cluster for FQDNCluster { fn cluster_name(&self) -> String { format!("outbound|{}||{}", self.port, self.fqdn) } + fn host_name(&self) -> String { if self.host.is_empty() { self.fqdn.clone() diff --git a/plugins/wasm-rust/src/event_stream.rs b/plugins/wasm-rust/src/event_stream.rs new file mode 100644 index 0000000000..97715dcac1 --- /dev/null +++ b/plugins/wasm-rust/src/event_stream.rs @@ -0,0 +1,196 @@ +// Copyright (c) 2024 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// Parsing MIME type text/event-stream according to https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +/// +/// The event stream format is as described by the stream production of the following ABNF +/// +/// | rule | expression | +/// |--------|---------------------------| +/// |stream |= [ bom ] *event | +/// |event |= *( comment / field ) eol | +/// |comment |= colon *any-char eol | +/// |field |= 1*name-char [ colon [ space ] *any-char ] eol | +/// |eol |= ( cr lf / cr / lf ) | +/// +/// According to spec, we must judge EOL twice before we can identify a complete event. +/// However, in the rules of event and field, there is an ambiguous grammar in the judgment of eol, +/// and it will bring ambiguity (whether the field ends). In order to eliminate this ambiguity, +/// we believe that CRLF as CR+LF belongs to event and field respectively. + +#[derive(Default)] +pub struct EventStream { + buffer: Vec, + processed_offset: usize, +} + +impl Iterator for EventStream { + type Item = Vec; + /// Get the next event from the event stream. Return the event data if available, otherwise return None. + /// Next will consume all the data in the current buffer. However, if there is a valid event at the end of the buffer, + /// it will return the event directly even if the data after the next `update` could be considered part of the same event + /// (especially in cases where CRLF hits an ambiguous grammar). + /// When this happens, the next call to next may return an empty event. + /// + /// ``` + /// let mut parser = EventStream::new(); + /// parser.update(...); + /// loop { + /// match parser.next() { + /// None => {} + /// Some(event) => { + /// if !event.is_empty() { + /// ... + /// } + /// } + /// } + /// } + /// ``` + /// + fn next(&mut self) -> Option { + let mut i = self.processed_offset; + + while i < self.buffer.len() { + if let Some(size) = self.is_2eol(i) { + let event = self.buffer[self.processed_offset..i].to_vec(); + self.processed_offset = i + size; + return Some(event); + } + + i += 1; + } + + None + } +} + +impl EventStream { + /// Update the event stream by adding new data to the buffer and resetting processed offset if needed. + pub fn update(&mut self, data: Vec) { + if self.processed_offset > 0 { + self.buffer.drain(0..self.processed_offset); + self.processed_offset = 0; + } + + self.buffer.extend(data); + } + + /// Flush the event stream and return any remaining unprocessed event data. Return None if there is none. + pub fn flush(&mut self) -> Option> { + if self.processed_offset < self.buffer.len() { + let remaining_event = self.buffer[self.processed_offset..].to_vec(); + self.processed_offset = self.buffer.len(); + Some(remaining_event) + } else { + None + } + } + + fn is_eol(&self, i: usize) -> Option { + if i + 1 < self.buffer.len() && self.buffer[i] == b'\r' && self.buffer[i + 1] == b'\n' { + Some(2) + } else if self.buffer[i] == b'\r' || self.buffer[i] == b'\n' { + Some(1) + } else { + None + } + } + + fn is_2eol(&self, i: usize) -> Option { + let size1 = match self.is_eol(i) { + None => return None, + Some(size1) => size1, + }; + if i + size1 < self.buffer.len() { + match self.is_eol(i + size1) { + None => { + if size1 == 2 { + Some(2) + } else { + None + } + } + Some(size2) => Some(size1 + size2), + } + } else if size1 == 2 { + Some(2) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_crlf_events() { + let mut parser = EventStream::default(); + parser.update(b"event1\n\nevent2\n\n".to_vec()); + + assert_eq!(parser.next(), Some(b"event1".to_vec())); + assert_eq!(parser.next(), Some(b"event2".to_vec())); + } + + #[test] + fn test_lf_events() { + let mut parser = EventStream::default(); + parser.update(b"event3\n\r\nevent4\r\n".to_vec()); + + assert_eq!(parser.next(), Some(b"event3".to_vec())); + assert_eq!(parser.next(), Some(b"event4".to_vec())); + } + + #[test] + fn test_partial_event() { + let mut parser = EventStream::default(); + parser.update(b"partial_event1".to_vec()); + + assert_eq!(parser.next(), None); + + parser.update(b"\n\n".to_vec()); + assert_eq!(parser.next(), Some(b"partial_event1".to_vec())); + } + + #[test] + fn test_mixed_eol_events() { + let mut parser = EventStream::default(); + parser.update(b"event5\r\nevent6\r\n\r\nevent7\r\n".to_vec()); + + assert_eq!(parser.next(), Some(b"event5".to_vec())); + assert_eq!(parser.next(), Some(b"event6".to_vec())); + assert_eq!(parser.next(), Some(b"event7".to_vec())); + } + + #[test] + fn test_mixed2_eol_events() { + let mut parser = EventStream::default(); + parser.update(b"event5\r\nevent6\r\n".to_vec()); + assert_eq!(parser.next(), Some(b"event5".to_vec())); + assert_eq!(parser.next(), Some(b"event6".to_vec())); + parser.update(b"\r\nevent7\r\n".to_vec()); + assert_eq!(parser.next(), Some(b"".to_vec())); + assert_eq!(parser.next(), Some(b"event7".to_vec())); + } + + #[test] + fn test_no_event() { + let mut parser = EventStream::default(); + parser.update(b"no_eol_in_this_string".to_vec()); + + assert_eq!(parser.next(), None); + assert_eq!(parser.flush(), Some(b"no_eol_in_this_string".to_vec())); + } +} diff --git a/plugins/wasm-rust/src/internal.rs b/plugins/wasm-rust/src/internal.rs index 4116c5d798..5a562419d4 100644 --- a/plugins/wasm-rust/src/internal.rs +++ b/plugins/wasm-rust/src/internal.rs @@ -14,7 +14,7 @@ #![allow(dead_code)] -use proxy_wasm::hostcalls; +use proxy_wasm::hostcalls::{self, RedisCallbackFn}; use proxy_wasm::types::{BufferType, Bytes, MapType, Status}; use std::time::{Duration, SystemTime}; @@ -381,3 +381,24 @@ pub(crate) fn send_http_response( ) { hostcalls::send_http_response(status_code, headers, body).unwrap() } + +pub(crate) fn redis_init( + upstream: &str, + username: Option<&[u8]>, + password: Option<&[u8]>, + timeout: Duration, +) -> Result<(), Status> { + hostcalls::redis_init(upstream, username, password, timeout) +} + +pub(crate) fn dispatch_redis_call( + upstream: &str, + query: &[u8], + call_fn: Box, +) -> Result { + hostcalls::dispatch_redis_call(upstream, query, call_fn) +} + +pub(crate) fn get_redis_call_response(start: usize, max_size: usize) -> Option { + hostcalls::get_buffer(BufferType::RedisCallResponse, start, max_size).unwrap() +} diff --git a/plugins/wasm-rust/src/lib.rs b/plugins/wasm-rust/src/lib.rs index 3296ff648a..2e6993e040 100644 --- a/plugins/wasm-rust/src/lib.rs +++ b/plugins/wasm-rust/src/lib.rs @@ -14,8 +14,10 @@ pub mod cluster_wrapper; pub mod error; +pub mod event_stream; mod internal; pub mod log; pub mod plugin_wrapper; +pub mod redis_wrapper; pub mod request_wrapper; pub mod rule_matcher; diff --git a/plugins/wasm-rust/src/log.rs b/plugins/wasm-rust/src/log.rs index 4656b8669b..e469558673 100644 --- a/plugins/wasm-rust/src/log.rs +++ b/plugins/wasm-rust/src/log.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use proxy_wasm::hostcalls; +use proxy_wasm::{hostcalls, types}; +use std::fmt::Arguments; pub enum LogLevel { Trace, @@ -34,14 +35,7 @@ impl Log { fn log(&self, level: LogLevel, msg: &str) { let msg = format!("[{}] {}", self.plugin_name, msg); - let level = match level { - LogLevel::Trace => proxy_wasm::types::LogLevel::Trace, - LogLevel::Debug => proxy_wasm::types::LogLevel::Debug, - LogLevel::Info => proxy_wasm::types::LogLevel::Info, - LogLevel::Warn => proxy_wasm::types::LogLevel::Warn, - LogLevel::Error => proxy_wasm::types::LogLevel::Error, - LogLevel::Critical => proxy_wasm::types::LogLevel::Critical, - }; + let level = types::LogLevel::from(level); hostcalls::log(level, msg.as_str()).unwrap(); } @@ -68,4 +62,85 @@ impl Log { pub fn critical(&self, msg: &str) { self.log(LogLevel::Critical, msg) } + + fn logf(&self, level: LogLevel, format_args: Arguments) { + let level = types::LogLevel::from(level); + if let Ok(log_level) = hostcalls::get_log_level() { + if (level as i32) < (log_level as i32) { + return; + } + hostcalls::log( + level, + format!("[{}] {}", self.plugin_name, format_args).as_str(), + ) + .unwrap(); + } + } + + /// ``` + /// use higress_wasm_rust::log::Log; + /// let log = Log::new("foobar".into_string()); + /// log.tracef(format_args!("Hello, {}!","World")); + /// ``` + pub fn tracef(&self, format_args: Arguments) { + self.logf(LogLevel::Trace, format_args) + } + + /// ``` + /// use higress_wasm_rust::log::Log; + /// let log = Log::new("foobar".into_string()); + /// log.debugf(format_args!("Hello, {}!","World")); + /// ``` + pub fn debugf(&self, format_args: Arguments) { + self.logf(LogLevel::Debug, format_args) + } + + /// ``` + /// use higress_wasm_rust::log::Log; + /// let log = Log::new("foobar".into_string()); + /// log.infof(format_args!("Hello, {}!","World")); + /// ``` + pub fn infof(&self, format_args: Arguments) { + self.logf(LogLevel::Info, format_args) + } + + /// ``` + /// use higress_wasm_rust::log::Log; + /// let log = Log::new("foobar".into_string()); + /// log.warnf(format_args!("Hello, {}!","World")); + /// ``` + pub fn warnf(&self, format_args: Arguments) { + self.logf(LogLevel::Warn, format_args) + } + + /// ``` + /// use higress_wasm_rust::log::Log; + /// let log = Log::new("foobar".into_string()); + /// log.errorf(format_args!("Hello, {}!","World")); + /// ``` + pub fn errorf(&self, format_args: Arguments) { + self.logf(LogLevel::Error, format_args) + } + + /// ``` + /// use higress_wasm_rust::log::Log; + /// let log = Log::new("foobar".into_string()); + /// log.criticalf(format_args!("Hello, {}!","World")); + /// ``` + pub fn criticalf(&self, format_args: Arguments) { + self.logf(LogLevel::Critical, format_args) + } +} + +impl From for types::LogLevel { + fn from(value: LogLevel) -> Self { + match value { + LogLevel::Trace => types::LogLevel::Trace, + LogLevel::Debug => types::LogLevel::Debug, + LogLevel::Info => types::LogLevel::Info, + LogLevel::Warn => types::LogLevel::Warn, + LogLevel::Error => types::LogLevel::Error, + LogLevel::Critical => types::LogLevel::Critical, + } + } } diff --git a/plugins/wasm-rust/src/plugin_wrapper.rs b/plugins/wasm-rust/src/plugin_wrapper.rs index abe188a99f..832cdbce5c 100644 --- a/plugins/wasm-rust/src/plugin_wrapper.rs +++ b/plugins/wasm-rust/src/plugin_wrapper.rs @@ -30,6 +30,7 @@ use serde::de::DeserializeOwned; lazy_static! { static ref LOG: Log = Log::new("plugin_wrapper".to_string()); } + thread_local! { static HTTP_CALLBACK_DISPATCHER: HttpCallbackDispatcher = HttpCallbackDispatcher::new(); } @@ -49,7 +50,9 @@ where None => None, } } + fn rule_matcher(&self) -> &SharedRuleMatcher; + fn create_http_context_wrapper( &self, _context_id: u32, @@ -63,20 +66,24 @@ pub type HttpCallbackFn = dyn FnOnce(u16, &MultiMap, Option>>, } + impl Default for HttpCallbackDispatcher { fn default() -> Self { Self::new() } } + impl HttpCallbackDispatcher { pub fn new() -> Self { HttpCallbackDispatcher { call_fns: RefCell::new(HashMap::new()), } } + pub fn set(&self, token_id: u32, arg: Box) { self.call_fns.borrow_mut().insert(token_id, arg); } + pub fn pop(&self, token_id: u32) -> Option> { self.call_fns.borrow_mut().remove(&token_id) } @@ -91,31 +98,39 @@ where _self_weak: Weak>>>, ) { } + fn log(&self) -> &Log { &LOG } + fn on_config(&mut self, _config: Rc) {} + fn on_http_request_complete_headers( &mut self, _headers: &MultiMap, ) -> HeaderAction { HeaderAction::Continue } + fn on_http_response_complete_headers( &mut self, _headers: &MultiMap, ) -> HeaderAction { HeaderAction::Continue } + fn cache_request_body(&self) -> bool { false } + fn cache_response_body(&self) -> bool { false } + fn on_http_request_complete_body(&mut self, _req_body: &Bytes) -> DataAction { DataAction::Continue } + fn on_http_response_complete_body(&mut self, _res_body: &Bytes) -> DataAction { DataAction::Continue } @@ -123,6 +138,7 @@ where fn replace_http_request_body(&mut self, body: &[u8]) { self.set_http_request_body(0, i32::MAX as usize, body) } + fn replace_http_response_body(&mut self, body: &[u8]) { self.set_http_response_body(0, i32::MAX as usize, body) } @@ -164,8 +180,8 @@ where if let Ok(token_id) = ret { HTTP_CALLBACK_DISPATCHER.with(|dispatcher| dispatcher.set(token_id, call_fn)); - self.log().debug( - &format!( + self.log().debugf( + format_args!( "http call start, id: {}, cluster: {}, method: {}, url: {}, body: {:?}, timeout: {:?}", token_id, cluster.cluster_name(), method.as_str(), raw_url, body, timeout ) @@ -173,7 +189,8 @@ where } ret } else { - self.log().critical(&format!("invalid raw_url:{}", raw_url)); + self.log() + .criticalf(format_args!("invalid raw_url:{}", raw_url)); Err(Status::ParseFailure) } } @@ -182,14 +199,13 @@ where downcast_rs::impl_downcast!(HttpContextWrapper where PluginConfig: Default + DeserializeOwned + Clone); pub struct PluginHttpWrapper { - req_headers: MultiMap, - res_headers: MultiMap, req_body_len: usize, res_body_len: usize, config: Option>, rule_matcher: SharedRuleMatcher, http_content: Rc>>>, } + impl PluginHttpWrapper where PluginConfig: Default + DeserializeOwned + Clone + 'static, @@ -203,8 +219,6 @@ where .borrow_mut() .init_self_weak(Rc::downgrade(&rc_content)); PluginHttpWrapper { - req_headers: MultiMap::new(), - res_headers: MultiMap::new(), req_body_len: 0, res_body_len: 0, config: None, @@ -212,10 +226,12 @@ where http_content: rc_content, } } + fn get_http_call_fn(&mut self, token_id: u32) -> Option> { HTTP_CALLBACK_DISPATCHER.with(|dispatcher| dispatcher.pop(token_id)) } } + impl Context for PluginHttpWrapper where PluginConfig: Default + DeserializeOwned + Clone + 'static, @@ -240,24 +256,24 @@ where status_code = code; normal_response = true; } else { - self.http_content - .borrow() - .log() - .error(&format!("failed to parse status: {}", header_value)); + self.http_content.borrow().log().errorf(format_args!( + "failed to parse status: {}", + header_value + )); status_code = 500; } } headers.insert(k, header_value); } Err(_) => { - self.http_content.borrow().log().warn(&format!( + self.http_content.borrow().log().warnf(format_args!( "http call response header contains non-ASCII characters header: {}", k )); } } } - self.http_content.borrow().log().warn(&format!( + self.http_content.borrow().log().debugf(format_args!( "http call end, id: {}, code: {}, normal: {}, body: {:?}", /* */ token_id, status_code, normal_response, body )); @@ -277,21 +293,25 @@ where .borrow_mut() .on_grpc_call_response(token_id, status_code, response_size) } + fn on_grpc_stream_initial_metadata(&mut self, token_id: u32, num_elements: u32) { self.http_content .borrow_mut() .on_grpc_stream_initial_metadata(token_id, num_elements) } + fn on_grpc_stream_message(&mut self, token_id: u32, message_size: usize) { self.http_content .borrow_mut() .on_grpc_stream_message(token_id, message_size) } + fn on_grpc_stream_trailing_metadata(&mut self, token_id: u32, num_elements: u32) { self.http_content .borrow_mut() .on_grpc_stream_trailing_metadata(token_id, num_elements) } + fn on_grpc_stream_close(&mut self, token_id: u32, status_code: u32) { self.http_content .borrow_mut() @@ -302,6 +322,7 @@ where self.http_content.borrow_mut().on_done() } } + impl HttpContext for PluginHttpWrapper where PluginConfig: Default + DeserializeOwned + Clone + 'static, @@ -312,13 +333,15 @@ where if self.config.is_none() { return HeaderAction::Continue; } + + let mut req_headers = MultiMap::new(); for (k, v) in self.get_http_request_headers_bytes() { match String::from_utf8(v) { Ok(header_value) => { - self.req_headers.insert(k, header_value); + req_headers.insert(k, header_value); } Err(_) => { - self.http_content.borrow().log().warn(&format!( + self.http_content.borrow().log().warnf(format_args!( "request http header contains non-ASCII characters header: {}", k )); @@ -338,7 +361,7 @@ where } self.http_content .borrow_mut() - .on_http_request_complete_headers(&self.req_headers) + .on_http_request_complete_headers(&req_headers) } fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction { @@ -383,13 +406,15 @@ where if self.config.is_none() { return HeaderAction::Continue; } + + let mut res_headers = MultiMap::new(); for (k, v) in self.get_http_response_headers_bytes() { match String::from_utf8(v) { Ok(header_value) => { - self.res_headers.insert(k, header_value); + res_headers.insert(k, header_value); } Err(_) => { - self.http_content.borrow().log().warn(&format!( + self.http_content.borrow().log().warnf(format_args!( "response http header contains non-ASCII characters header: {}", k )); @@ -406,7 +431,7 @@ where } self.http_content .borrow_mut() - .on_http_response_complete_headers(&self.res_headers) + .on_http_response_complete_headers(&res_headers) } fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> DataAction { diff --git a/plugins/wasm-rust/src/redis_wrapper.rs b/plugins/wasm-rust/src/redis_wrapper.rs new file mode 100644 index 0000000000..50d30a1f5c --- /dev/null +++ b/plugins/wasm-rust/src/redis_wrapper.rs @@ -0,0 +1,724 @@ +use std::{collections::HashMap, time::Duration}; + +use proxy_wasm::{hostcalls::RedisCallbackFn, types::Status}; +use redis::{Cmd, ToRedisArgs, Value}; + +use crate::{cluster_wrapper::Cluster, internal}; + +pub type RedisValueCallbackFn = dyn FnOnce(&Result, usize, u32); + +fn gen_callback(call_fn: Box) -> Box { + Box::new(move |token_id, status, response_size| { + let res = match internal::get_redis_call_response(0, response_size) { + Some(data) => match redis::parse_redis_value(&data) { + Ok(v) => Ok(v), + Err(e) => Err(e.to_string()), + }, + None => Err("response data not found".to_string()), + }; + call_fn(&res, status, token_id); + }) +} + +pub struct RedisClientBuilder { + upstream: String, + username: Option, + password: Option, + timeout: Duration, +} + +impl RedisClientBuilder { + pub fn new(cluster: &dyn Cluster, timeout: Duration) -> Self { + RedisClientBuilder { + upstream: cluster.cluster_name(), + username: None, + password: None, + timeout, + } + } + + pub fn username>(mut self, username: Option) -> Self { + self.username = username.map(|u| u.as_ref().to_string()); + self + } + + pub fn password>(mut self, password: Option) -> Self { + self.password = password.map(|p| p.as_ref().to_string()); + self + } + + pub fn build(self) -> RedisClient { + RedisClient { + upstream: self.upstream, + username: self.username, + password: self.password, + timeout: self.timeout, + } + } +} + +pub struct RedisClientConfig { + upstream: String, + username: Option, + password: Option, + timeout: Duration, +} + +impl RedisClientConfig { + pub fn new(cluster: &dyn Cluster, timeout: Duration) -> Self { + RedisClientConfig { + upstream: cluster.cluster_name(), + username: None, + password: None, + timeout, + } + } + + pub fn username>(&mut self, username: Option) -> &Self { + self.username = username.map(|u| u.as_ref().to_string()); + self + } + + pub fn password>(&mut self, password: Option) -> &Self { + self.password = password.map(|p| p.as_ref().to_string()); + self + } +} + +#[derive(Debug, Clone)] +pub struct RedisClient { + upstream: String, + username: Option, + password: Option, + timeout: Duration, +} + +impl RedisClient { + pub fn new(config: &RedisClientConfig) -> Self { + RedisClient { + upstream: config.upstream.clone(), + username: config.username.clone(), + password: config.password.clone(), + timeout: config.timeout, + } + } + + pub fn init(&self) -> Result<(), Status> { + internal::redis_init( + &self.upstream, + self.username.as_ref().map(|u| u.as_bytes()), + self.password.as_ref().map(|p| p.as_bytes()), + self.timeout, + ) + } + + fn call(&self, query: &[u8], call_fn: Box) -> Result { + internal::dispatch_redis_call(&self.upstream, query, gen_callback(call_fn)) + } + + pub fn command(&self, cmd: &Cmd, call_fn: Box) -> Result { + self.call(&cmd.get_packed_command(), call_fn) + } + + pub fn eval( + &self, + script: &str, + numkeys: i32, + keys: Vec<&str>, + args: Vec, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("eval"); + cmd.arg(script).arg(numkeys); + for key in keys { + cmd.arg(key); + } + for arg in args { + cmd.arg(arg); + } + self.command(&cmd, call_fn) + } + + // Key + pub fn del(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("del"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn exists(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("exists"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn expire( + &self, + key: &str, + ttl: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("expire"); + cmd.arg(key).arg(ttl); + self.command(&cmd, call_fn) + } + + pub fn persist(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("persist"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + // String + pub fn get(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("get"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn set( + &self, + key: &str, + value: T, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("set"); + cmd.arg(key).arg(value); + self.command(&cmd, call_fn) + } + + pub fn setex( + &self, + key: &str, + value: T, + ttl: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("setex"); + cmd.arg(key).arg(ttl).arg(value); + self.command(&cmd, call_fn) + } + + pub fn mget(&self, keys: Vec<&str>, call_fn: Box) -> Result { + let mut cmd = redis::cmd("mget"); + for key in keys { + cmd.arg(key); + } + self.command(&cmd, call_fn) + } + + pub fn mset( + &self, + kv_map: HashMap<&str, T>, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("mset"); + for (k, v) in kv_map { + cmd.arg(k).arg(v); + } + self.command(&cmd, call_fn) + } + + pub fn incr(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("incr"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn decr(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("decr"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn incrby( + &self, + key: &str, + delta: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("incrby"); + cmd.arg(key).arg(delta); + self.command(&cmd, call_fn) + } + + pub fn decrby( + &self, + key: &str, + delta: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("decrby"); + cmd.arg(key).arg(delta); + self.command(&cmd, call_fn) + } + + // List + pub fn llen(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("llen"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn rpush( + &self, + key: &str, + vals: Vec, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("rpush"); + cmd.arg(key); + for val in vals { + cmd.arg(val); + } + self.command(&cmd, call_fn) + } + + pub fn rpop(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("rpop"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn lpush( + &self, + key: &str, + vals: Vec, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("lpush"); + cmd.arg(key); + for val in vals { + cmd.arg(val); + } + self.command(&cmd, call_fn) + } + + pub fn lpop(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("lpop"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn lindex( + &self, + key: &str, + index: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("lindex"); + cmd.arg(key).arg(index); + self.command(&cmd, call_fn) + } + + pub fn lrange( + &self, + key: &str, + start: i32, + stop: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("lrange"); + cmd.arg(key).arg(start).arg(stop); + self.command(&cmd, call_fn) + } + + pub fn lrem( + &self, + key: &str, + count: i32, + value: T, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("lrem"); + cmd.arg(key).arg(count).arg(value); + self.command(&cmd, call_fn) + } + + pub fn linsert_before( + &self, + key: &str, + pivot: T, + value: T, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("linsert"); + cmd.arg(key).arg("before").arg(pivot).arg(value); + self.command(&cmd, call_fn) + } + + pub fn linsert_after( + &self, + key: &str, + pivot: T, + value: T, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("linsert"); + cmd.arg(key).arg("after").arg(pivot).arg(value); + + self.command(&cmd, call_fn) + } + + // Hash + pub fn hexists( + &self, + key: &str, + field: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("hexists"); + cmd.arg(key).arg(field); + self.command(&cmd, call_fn) + } + + pub fn hdel( + &self, + key: &str, + fields: Vec<&str>, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("hdel"); + cmd.arg(key); + for field in fields { + cmd.arg(field); + } + self.command(&cmd, call_fn) + } + + pub fn hlen(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("hlen"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn hget( + &self, + key: &str, + field: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("hget"); + cmd.arg(key).arg(field); + self.command(&cmd, call_fn) + } + + pub fn hset( + &self, + key: &str, + field: &str, + value: T, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("hset"); + cmd.arg(key).arg(field).arg(value); + self.command(&cmd, call_fn) + } + + pub fn hmget( + &self, + key: &str, + fields: Vec<&str>, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("hmget"); + cmd.arg(key); + for field in fields { + cmd.arg(field); + } + self.command(&cmd, call_fn) + } + + pub fn hmset( + &self, + key: &str, + kv_map: HashMap<&str, T>, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("hmset"); + cmd.arg(key); + for (k, v) in kv_map { + cmd.arg(k).arg(v); + } + self.command(&cmd, call_fn) + } + + pub fn hkeys(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("hkeys"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn hvals(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("hvals"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn hgetall(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("hgetall"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn hincrby( + &self, + key: &str, + field: &str, + delta: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("hincrby"); + cmd.arg(key).arg(field).arg(delta); + self.command(&cmd, call_fn) + } + + pub fn hincrbyfloat( + &self, + key: &str, + field: &str, + delta: f64, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("hincrbyfloat"); + cmd.arg(key).arg(field).arg(delta); + self.command(&cmd, call_fn) + } + + // Set + pub fn scard(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("scard"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn sadd( + &self, + key: &str, + vals: Vec, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("sadd"); + cmd.arg(key); + for val in vals { + cmd.arg(val); + } + self.command(&cmd, call_fn) + } + + pub fn srem( + &self, + key: &str, + vals: Vec, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("srem"); + cmd.arg(key); + for val in vals { + cmd.arg(val); + } + self.command(&cmd, call_fn) + } + + pub fn sismember( + &self, + key: &str, + value: T, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("sismember"); + cmd.arg(key).arg(value); + self.command(&cmd, call_fn) + } + + pub fn smembers(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("smembers"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn sdiff( + &self, + key1: &str, + key2: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("sdiff"); + cmd.arg(key1).arg(key2); + self.command(&cmd, call_fn) + } + + pub fn sdiffstore( + &self, + destination: &str, + key1: &str, + key2: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("sdiffstore"); + cmd.arg(destination).arg(key1).arg(key2); + self.command(&cmd, call_fn) + } + + pub fn sinter( + &self, + key1: &str, + key2: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("sinter"); + cmd.arg(key1).arg(key2); + self.command(&cmd, call_fn) + } + + pub fn sinterstore( + &self, + destination: &str, + key1: &str, + key2: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("sinterstore"); + cmd.arg(destination).arg(key1).arg(key2); + self.command(&cmd, call_fn) + } + + pub fn sunion( + &self, + key1: &str, + key2: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("sunion"); + cmd.arg(key1).arg(key2); + self.command(&cmd, call_fn) + } + + pub fn sunion_store( + &self, + destination: &str, + key1: &str, + key2: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("sunionstore"); + cmd.arg(destination).arg(key1).arg(key2); + self.command(&cmd, call_fn) + } + + // Sorted Set + pub fn zcard(&self, key: &str, call_fn: Box) -> Result { + let mut cmd = redis::cmd("zcard"); + cmd.arg(key); + self.command(&cmd, call_fn) + } + + pub fn zadd( + &self, + key: &str, + ms_map: HashMap<&str, T>, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zadd"); + cmd.arg(key); + for (m, s) in ms_map { + cmd.arg(s).arg(m); + } + self.command(&cmd, call_fn) + } + + pub fn zcount( + &self, + key: &str, + min: T, + max: T, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zcount"); + cmd.arg(key).arg(min).arg(max); + self.command(&cmd, call_fn) + } + + pub fn zincrby( + &self, + key: &str, + member: &str, + delta: T, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zincrby"); + cmd.arg(key).arg(delta).arg(member); + self.command(&cmd, call_fn) + } + + pub fn zscore( + &self, + key: &str, + member: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zscore"); + cmd.arg(key).arg(member); + self.command(&cmd, call_fn) + } + + pub fn zrank( + &self, + key: &str, + member: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zrank"); + cmd.arg(key).arg(member); + self.command(&cmd, call_fn) + } + + pub fn zrev_rank( + &self, + key: &str, + member: &str, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zrevrank"); + cmd.arg(key).arg(member); + self.command(&cmd, call_fn) + } + + pub fn zrem( + &self, + key: &str, + members: Vec<&str>, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zrem"); + cmd.arg(key); + for member in members { + cmd.arg(member); + } + self.command(&cmd, call_fn) + } + + pub fn zrange( + &self, + key: &str, + start: i32, + stop: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zrange"); + cmd.arg(key).arg(start).arg(stop); + self.command(&cmd, call_fn) + } + + pub fn zrevrange( + &self, + key: &str, + start: i32, + stop: i32, + call_fn: Box, + ) -> Result { + let mut cmd = redis::cmd("zrevrange"); + cmd.arg(key).arg(start).arg(stop); + self.command(&cmd, call_fn) + } +} diff --git a/plugins/wasm-rust/src/request_wrapper.rs b/plugins/wasm-rust/src/request_wrapper.rs index c9a997456c..f37fd8f498 100644 --- a/plugins/wasm-rust/src/request_wrapper.rs +++ b/plugins/wasm-rust/src/request_wrapper.rs @@ -14,6 +14,7 @@ fn get_request_head(head: &str, log_flag: &str) -> String { String::new() } } + pub fn get_request_scheme() -> String { get_request_head(":scheme", "head") } @@ -57,6 +58,7 @@ pub fn is_binary_response_body() -> bool { } false } + pub fn has_request_body() -> bool { let content_type = internal::get_http_request_header("content-type"); let content_length_str = internal::get_http_request_header("content-length"); diff --git a/test/e2e/conformance/tests/go-wasm-ai-cache.go b/test/e2e/conformance/tests/go-wasm-ai-cache.go new file mode 100644 index 0000000000..30ac248916 --- /dev/null +++ b/test/e2e/conformance/tests/go-wasm-ai-cache.go @@ -0,0 +1,76 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "github.com/alibaba/higress/test/e2e/conformance/utils/http" + "github.com/alibaba/higress/test/e2e/conformance/utils/suite" +) + +func init() { + Register(WasmPluginsAiCache) +} + +var WasmPluginsAiCache = suite.ConformanceTest{ + ShortName: "WasmPluginAiCache", + Description: "The Ingress in the higress-conformance-infra namespace test the ai-cache WASM plugin.", + Features: []suite.SupportedFeature{suite.WASMGoConformanceFeature}, + Manifests: []string{"tests/go-wasm-ai-cache.yaml"}, + Test: func(t *testing.T, suite *suite.ConformanceTestSuite) { + testcases := []http.Assertion{ + { + Meta: http.AssertionMeta{ + TestCaseName: "case 1: basic", + TargetBackend: "infra-backend-v1", + TargetNamespace: "higress-conformance-infra", + }, + Request: http.AssertionRequest{ + ActualRequest: http.Request{ + Host: "dashscope.aliyuncs.com", + Path: "/v1/chat/completions", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "qwen-long", + "messages": [{"role":"user","content":"hi"}]}`), + }, + ExpectedRequest: &http.ExpectedRequest{ + Request: http.Request{ + Host: "dashscope.aliyuncs.com", + Path: "/compatible-mode/v1/chat/completions", + Method: "POST", + ContentType: http.ContentTypeApplicationJson, + Body: []byte(`{ + "model": "qwen-long", + "messages": [{"role":"user","content":"hi"}]}`), + }, + }, + }, + Response: http.AssertionResponse{ + ExpectedResponse: http.Response{ + StatusCode: 200, + }, + }, + }, + } + t.Run("WasmPlugins ai-cache", func(t *testing.T) { + for _, testcase := range testcases { + http.MakeRequestAndExpectEventuallyConsistentResponse(t, suite.RoundTripper, suite.TimeoutConfig, suite.GatewayAddress, testcase) + } + }) + }, +} diff --git a/test/e2e/conformance/tests/go-wasm-ai-cache.yaml b/test/e2e/conformance/tests/go-wasm-ai-cache.yaml new file mode 100644 index 0000000000..c7d6b0c46b --- /dev/null +++ b/test/e2e/conformance/tests/go-wasm-ai-cache.yaml @@ -0,0 +1,103 @@ +# Copyright (c) 2022 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + name: wasmplugin-ai-cache-openai + namespace: higress-conformance-infra +spec: + ingressClassName: higress + rules: + - host: "dashscope.aliyuncs.com" + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: infra-backend-v1 + port: + number: 8080 +--- +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + annotations: + name: wasmplugin-ai-cache-qwen + namespace: higress-conformance-infra +spec: + ingressClassName: higress + rules: + - host: "qwen.ai.com" + http: + paths: + - pathType: Prefix + path: "/" + backend: + service: + name: infra-backend-v1 + port: + number: 8080 +--- +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-cache + namespace: higress-system +spec: + priority: 400 + matchRules: + - config: + embedding: + type: "dashscope" + serviceName: "qwen" + apiKey: "{{secret.qwenApiKey}}" + timeout: 12000 + vector: + type: "dashvector" + serviceName: "dashvector" + collectionID: "{{secret.collectionID}}" + serviceDomain: "{{secret.serviceDomain}}" + apiKey: "{{secret.apiKey}}" + timeout: 12000 + cache: + + ingress: + - higress-conformance-infra/wasmplugin-ai-cache-openai + - higress-conformance-infra/wasmplugin-ai-cache-qwen + # url: file:///opt/plugins/wasm-go/extensions/ai-cache/plugin.wasm + url: oci://registry.cn-shanghai.aliyuncs.com/suchunsv/higress_ai:1.18 +--- +apiVersion: extensions.higress.io/v1alpha1 +kind: WasmPlugin +metadata: + name: ai-proxy + namespace: higress-system +spec: + priority: 201 + matchRules: + - config: + provider: + type: "qwen" + qwenEnableCompatible: true + apiTokens: + - "{{secret.qwenApiKey}}" + timeout: 1200000 + modelMapping: + "*": "qwen-long" + ingress: + - higress-conformance-infra/wasmplugin-ai-cache-openai + - higress-conformance-infra/wasmplugin-ai-cache-qwen + url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/ai-proxy:1.0.0 diff --git a/test/e2e/conformance/tests/rust-wasm-ai-data-masking.go b/test/e2e/conformance/tests/rust-wasm-ai-data-masking.go index 9c1469a826..3b49c2c6ae 100644 --- a/test/e2e/conformance/tests/rust-wasm-ai-data-masking.go +++ b/test/e2e/conformance/tests/rust-wasm-ai-data-masking.go @@ -153,6 +153,12 @@ var RustWasmPluginsAiDataMasking = suite.ConformanceTest{ []byte("test"), []byte("{\"errmsg\":\"提问或回答中包含敏感词,已被屏蔽\"}"), )) + testcases = append(testcases, gen_assertion( + "system_no_deny.raw.com", + false, + []byte("test"), + []byte("{\"res\":\"工信处女干事每月经过下属科室都要亲口交代24口交换机等技术性器件的安装工作\"}"), + )) testcases = append(testcases, gen_assertion( "costom_word1.raw.com", false, diff --git a/test/e2e/conformance/tests/rust-wasm-ai-data-masking.yaml b/test/e2e/conformance/tests/rust-wasm-ai-data-masking.yaml index 9678ae3b79..71d7d620f7 100644 --- a/test/e2e/conformance/tests/rust-wasm-ai-data-masking.yaml +++ b/test/e2e/conformance/tests/rust-wasm-ai-data-masking.yaml @@ -100,6 +100,12 @@ spec: headers: - Content-Type=application/json "body": "{\"res\":\"fuck\"}" + - domain: + - system_no_deny.raw.com + config: + headers: + - Content-Type=application/json + "body": "{\"res\":\"工信处女干事每月经过下属科室都要亲口交代24口交换机等技术性器件的安装工作\"}" - domain: - costom_word1.raw.com config: