From cee3d92138a610e92559a6772b4d69f870b9e3cc Mon Sep 17 00:00:00 2001 From: Thamara Lakshan Date: Fri, 14 Feb 2025 14:16:25 +0530 Subject: [PATCH] sementic cache init --- gateway/enforcer/cmd/main.go | 5 +- gateway/enforcer/go.mod | 15 +- gateway/enforcer/go.sum | 59 ++------ gateway/enforcer/internal/cache/core.go | 98 +++++++++++++ gateway/enforcer/internal/cache/main.go | 69 ++++++++++ .../internal/datastore/cache_store.go | 129 ++++++++++++++++++ gateway/enforcer/internal/dto/ai_cache.go | 54 ++++++++ .../external_processing_envoy_attributes.go | 1 + .../dto/external_processing_envoy_metadata.go | 1 + gateway/enforcer/internal/extproc/ext_proc.go | 37 ++++- gateway/enforcer/internal/util/json.go | 33 ++++- 11 files changed, 434 insertions(+), 67 deletions(-) create mode 100644 gateway/enforcer/internal/cache/core.go create mode 100644 gateway/enforcer/internal/cache/main.go create mode 100644 gateway/enforcer/internal/datastore/cache_store.go create mode 100644 gateway/enforcer/internal/dto/ai_cache.go diff --git a/gateway/enforcer/cmd/main.go b/gateway/enforcer/cmd/main.go index cc81b43a1e..b2dcdbf5a8 100644 --- a/gateway/enforcer/cmd/main.go +++ b/gateway/enforcer/cmd/main.go @@ -41,8 +41,11 @@ func main() { apiStore, configStore, jwtIssuerDatastore, modelBasedRoundRobinTracker := xds.CreateXDSClients(cfg) // NewJWTTransformer creates a new instance of JWTTransformer. jwtTransformer := transformer.NewJWTTransformer(jwtIssuerDatastore) + // Create new cache store and incomingstorecachekeystore + cacheStore := datastore.NewRedisCache() + incomingRequestCacheKeyStore := datastore.NewIncomingRequestCacheKeyStore() // Start the external processing server - go extproc.StartExternalProcessingServer(cfg, apiStore, subAppDatastore, jwtTransformer, modelBasedRoundRobinTracker) + go extproc.StartExternalProcessingServer(cfg, apiStore, subAppDatastore, cacheStore, incomingRequestCacheKeyStore, jwtTransformer, modelBasedRoundRobinTracker) // Wait for the config to be loaded cfg.Logger.Info("Waiting for the config to be loaded") diff --git a/gateway/enforcer/go.mod b/gateway/enforcer/go.mod index fb8ca0c6d9..f5428a13fc 100644 --- a/gateway/enforcer/go.mod +++ b/gateway/enforcer/go.mod @@ -12,6 +12,7 @@ require ( github.com/google/uuid v1.6.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/prometheus/client_golang v1.20.5 + github.com/redis/go-redis/v9 v9.7.0 github.com/stretchr/testify v1.10.0 github.com/vektah/gqlparser/v2 v2.5.17 github.com/wso2/apk/adapter v0.0.0-20241016075419-fc842057860d @@ -23,23 +24,14 @@ require ( ) require ( - github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect - github.com/Azure/go-amqp v1.3.0 // indirect - github.com/Azure/go-autorest v14.2.0+incompatible // indirect - github.com/Azure/go-autorest/autorest v0.11.30 // indirect - github.com/Azure/go-autorest/autorest/adal v0.9.24 // indirect - github.com/Azure/go-autorest/autorest/date v0.3.1 // indirect - github.com/Azure/go-autorest/autorest/to v0.4.1 // indirect - github.com/Azure/go-autorest/autorest/validation v0.3.2 // indirect - github.com/Azure/go-autorest/logger v0.2.2 // indirect - github.com/Azure/go-autorest/tracing v0.6.1 // indirect - github.com/agnivade/levenshtein v1.1.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/Azure/go-amqp v1.3.0 // indirect + github.com/agnivade/levenshtein v1.1.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect github.com/evanphx/json-patch v5.9.0+incompatible // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect @@ -51,7 +43,6 @@ require ( github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/joho/godotenv v1.5.1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.9 // indirect diff --git a/gateway/enforcer/go.sum b/gateway/enforcer/go.sum index f5c3f8c769..d778ca3d06 100644 --- a/gateway/enforcer/go.sum +++ b/gateway/enforcer/go.sum @@ -1,44 +1,3 @@ -github.com/Azure/azure-amqp-common-go/v4 v4.2.0 h1:q/jLx1KJ8xeI8XGfkOWMN9XrXzAfVTkyvCxPvHCjd2I= -github.com/Azure/azure-amqp-common-go/v4 v4.2.0/go.mod h1:GD3m/WPPma+621UaU6KNjKEo5Hl09z86viKwQjTpV0Q= -github.com/Azure/azure-event-hubs-go/v3 v3.6.2 h1:7rNj1/iqS/i3mUKokA2n2eMYO72TB7lO7OmpbKoakKY= -github.com/Azure/azure-event-hubs-go/v3 v3.6.2/go.mod h1:n+ocYr9j2JCLYqUqz9eI+lx/TEAtL/g6rZzyTFSuIpc= -github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= -github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= -github.com/Azure/go-amqp v1.3.0 h1://1rikYhoIQNXJFXyoO/Rlb4+4EkHYfJceNtLlys2/4= -github.com/Azure/go-amqp v1.3.0/go.mod h1:vZAogwdrkbyK3Mla8m/CxSc/aKdnTZ4IbPxl51Y5WZE= -github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= -github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= -github.com/Azure/go-autorest/autorest v0.11.30 h1:iaZ1RGz/ALZtN5eq4Nr1SOFSlf2E4pDI3Tcsl+dZPVE= -github.com/Azure/go-autorest/autorest v0.11.30/go.mod h1:t1kpPIOpIVX7annvothKvb0stsrXa37i7b+xpmBW8Fs= -github.com/Azure/go-autorest/autorest/adal v0.9.22/go.mod h1:XuAbAEUv2Tta//+voMI038TrJBqjKam0me7qR+L8Cmk= -github.com/Azure/go-autorest/autorest/adal v0.9.24 h1:BHZfgGsGwdkHDyZdtQRQk1WeUdW0m2WPAwuHZwUi5i4= -github.com/Azure/go-autorest/autorest/adal v0.9.24/go.mod h1:7T1+g0PYFmACYW5LlG2fcoPiPlFHjClyRGL7dRlP5c8= -github.com/Azure/go-autorest/autorest/azure/auth v0.4.2 h1:iM6UAvjR97ZIeR93qTcwpKNMpV+/FTWjwEbuPD495Tk= -github.com/Azure/go-autorest/autorest/azure/auth v0.4.2/go.mod h1:90gmfKdlmKgfjUpnCEpOJzsUEjrWDSLwHIG73tSXddM= -github.com/Azure/go-autorest/autorest/azure/cli v0.3.1 h1:LXl088ZQlP0SBppGFsRZonW6hSvwgL5gRByMbvUbx8U= -github.com/Azure/go-autorest/autorest/azure/cli v0.3.1/go.mod h1:ZG5p860J94/0kI9mNJVoIoLgXcirM2gF5i2kWloofxw= -github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= -github.com/Azure/go-autorest/autorest/date v0.3.1 h1:o9Z8Jyt+VJJTCZ/UORishuHOusBwolhjokt9s5k8I4w= -github.com/Azure/go-autorest/autorest/date v0.3.1/go.mod h1:Dz/RDmXlfiFFS/eW+b/xMUSFs1tboPVy6UjgADToWDM= -github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= -github.com/Azure/go-autorest/autorest/mocks v0.4.2 h1:PGN4EDXnuQbojHbU0UWoNvmu9AGVwYHG9/fkDYhtAfw= -github.com/Azure/go-autorest/autorest/mocks v0.4.2/go.mod h1:Vy7OitM9Kei0i1Oj+LvyAWMXJHeKH1MVlzFugfVrmyU= -github.com/Azure/go-autorest/autorest/to v0.4.1 h1:CxNHBqdzTr7rLtdrtb5CMjJcDut+WNGCVv7OmS5+lTc= -github.com/Azure/go-autorest/autorest/to v0.4.1/go.mod h1:EtaofgU4zmtvn1zT2ARsjRFdq9vXx0YWtmElwL+GZ9M= -github.com/Azure/go-autorest/autorest/validation v0.3.2 h1:myD3tcvs+Fk1bkJ1Xx7xidop4z4FWvWADiMGMXeVd2E= -github.com/Azure/go-autorest/autorest/validation v0.3.2/go.mod h1:4z7eU88lSINAB5XL8mhfPumiUdoAQo/c7qXwbsM8Zhc= -github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= -github.com/Azure/go-autorest/logger v0.2.2 h1:hYqBsEBywrrOSW24kkOCXRcKfKhK76OzLTfF+MYDE2o= -github.com/Azure/go-autorest/logger v0.2.2/go.mod h1:I5fg9K52o+iuydlWfa9T5K6WFos9XYr9dYTFzpqgibw= -github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= -github.com/Azure/go-autorest/tracing v0.6.1 h1:YUMSrC/CeD1ZnnXcNYU4a/fzsO35u2Fsful9L/2nyR0= -github.com/Azure/go-autorest/tracing v0.6.1/go.mod h1:/3EgjbsjraOqiicERAeu3m7/z0x1TzjQGAwDrJrXGkc= -github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= -github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= -github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= -github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= -github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= -github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= @@ -55,8 +14,18 @@ github.com/Azure/go-amqp v1.3.0 h1://1rikYhoIQNXJFXyoO/Rlb4+4EkHYfJceNtLlys2/4= github.com/Azure/go-amqp v1.3.0/go.mod h1:vZAogwdrkbyK3Mla8m/CxSc/aKdnTZ4IbPxl51Y5WZE= github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= +github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= +github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= +github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= +github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= +github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= +github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 h1:QVw89YDxXxEe+l8gU8ETbOasdwEV+avkR75ZzsVV9WI= @@ -65,12 +34,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/devigned/tab v0.1.1 h1:3mD6Kb1mUOYeLpJvTVSDwSg5ZsfSxfvxGRTxRsJsITA= -github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mzjeJY= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= -github.com/dimchansky/utfbom v1.1.0 h1:FcM3g+nofKgUteL8dm/UpdRXNC9KmADgTpLKsu0TRo4= -github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQvIirEdv+8= github.com/envoyproxy/gateway v1.2.3 h1:Qne11MOjNPmawTCFi35iuYvwA3kTqmBTFE7wDZkIgmo= github.com/envoyproxy/gateway v1.2.3/go.mod h1:JkrLVKpgdd3D6Umr6uw1Hu98lCCpxU2pzK32qeM67U0= github.com/envoyproxy/go-control-plane v0.13.1 h1:vPfJZCkob6yTMEgS+0TwfTUfbHjfy/6vOJ8hUWX/uXE= @@ -173,6 +140,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= diff --git a/gateway/enforcer/internal/cache/core.go b/gateway/enforcer/internal/cache/core.go new file mode 100644 index 0000000000..8a3cd83a72 --- /dev/null +++ b/gateway/enforcer/internal/cache/core.go @@ -0,0 +1,98 @@ +package cache + +import ( + "encoding/json" + "fmt" + "time" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + envoy_service_proc_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + v32 "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/wso2/apk/gateway/enforcer/internal/datastore" + "github.com/wso2/apk/gateway/enforcer/internal/dto" +) + +// CheckCacheForKey checks if the key is in the cache +func CheckCacheForKey(key string, cacheStore datastore.CacheStore) (string, error) { + + return cacheStore.Get(key) + // TODO: check vector similarity search if redis cache miss + +} + +// Caches the response value +func cacheResponse(key string, value string, cacheStore datastore.CacheStore) { + err := cacheStore.Set(key, value) + if err != nil { + fmt.Printf("[AI-CACHE] cache set failed, key: %s, error: %v", key, err) + return + } + fmt.Printf("[AI-CACHE] cache set success, key: %s, length of value: %d", key, len(value)) + +} + +// SendCachedHTTPResponse makes ext_proc response for cached value +func SendCachedHTTPResponse(cachedResponse string, resp *envoy_service_proc_v3.ProcessingResponse) { + + llmResponse := dto.LLMResponse{ + ID: "chatcmpl-123", // You may want to generate a unique ID + Object: "chat.completion", + Created: time.Now().Unix(), + Model: "gpt-3.5-turbo", + Usage: dto.Usage{ + PromptTokens: 0, + CompletionTokens: 0, + TotalTokens: 0, + }, + Choices: []dto.Choice{ + { + Index: 0, + Message: dto.Message{ + Role: "assistant", + Content: cachedResponse, + }, + Delta: []any{nil}, + FinishReason: "stop", + }, + }, + } + + httpBody, _ := json.Marshal(llmResponse) + httpBodyLength := len(httpBody) + + headers := &envoy_service_proc_v3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + { + Header: &corev3.HeaderValue{ + Key: "Content-Length", + RawValue: []byte(fmt.Sprintf("%d", httpBodyLength)), + }, + }, + { + Header: &corev3.HeaderValue{ + Key: "Content-Type", + RawValue: []byte("application/json"), + }, + }, + { + Header: &corev3.HeaderValue{ + Key: "X-Cache-Status", + RawValue: []byte("HIT"), + }, + }, + }, + } + + rbq := &envoy_service_proc_v3.ImmediateResponse{ + Status: &v32.HttpStatus{ + Code: v32.StatusCode_OK, + }, + Headers: headers, + Body: httpBody, + } + + resp.Response = &envoy_service_proc_v3.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: rbq, + } + +} diff --git a/gateway/enforcer/internal/cache/main.go b/gateway/enforcer/internal/cache/main.go new file mode 100644 index 0000000000..0dcf2357d3 --- /dev/null +++ b/gateway/enforcer/internal/cache/main.go @@ -0,0 +1,69 @@ +package cache + +import ( + "encoding/json" + "fmt" + + envoy_service_proc_v3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/wso2/apk/gateway/enforcer/internal/datastore" + "github.com/wso2/apk/gateway/enforcer/internal/dto" + "github.com/wso2/apk/gateway/enforcer/internal/util" +) + +// HandleHTTPRequestBody handles http request body +func HandleHTTPRequestBody(requestID string, cacheStore datastore.CacheStore, keyStore *datastore.IncomingRequestCacheKeyStore, req *envoy_service_proc_v3.ProcessingRequest, resp *envoy_service_proc_v3.ProcessingResponse) { + httpBody := req.GetRequestBody().Body + + var llmRequest dto.LLMRequest + if err := json.Unmarshal(httpBody, &llmRequest); err != nil { + fmt.Printf("[AI-CACHE] Error unmarshaling JSON Reuqest Body. %v", err) + return + } + + key, has := llmRequest.GetKey() + if !has { + fmt.Printf("[AI-CACHE] cache key not found in request body.") + return + } + + cachedResponse, err := CheckCacheForKey(key, cacheStore) + if err != nil { + fmt.Printf("[AI-CACHE] error retrieving key: %s from cache, error: %v", key, err) + keyStore.Set(requestID, key) // TODO: perform only if cache miss + return + } + + SendCachedHTTPResponse(cachedResponse, resp) +} + +// HandleHTTPResponseBody handles http response body +func HandleHTTPResponseBody(requestID string, cacheStore datastore.CacheStore, keyStore *datastore.IncomingRequestCacheKeyStore, req *envoy_service_proc_v3.ProcessingRequest, resp *envoy_service_proc_v3.ProcessingResponse) { + httpBody := req.GetResponseBody().Body + + var llmResponse dto.LLMResponse + + uncompressedBody, err := util.DecompressIfGzip(httpBody) + if err != nil { + fmt.Printf("[AI-CACHE] Error decompressing response body, error: %v", err) + return + } + + if err := json.Unmarshal(uncompressedBody, &llmResponse); err != nil { + fmt.Printf("[AI-CACHE] Error unmarshaling JSON Response Body, error: %v", err) + return + } + + key, hasKey := keyStore.Pop(requestID) + if !hasKey { + fmt.Printf("[AI-CACHE] cache key not found for request ID: %s", requestID) + return + } + + responseValue, hasValue := llmResponse.GetValue() + if !hasValue { + fmt.Printf("[AI-CACHE] cached value for key %s is missing or empty", key) + return + } + + cacheResponse(key, responseValue, cacheStore) +} diff --git a/gateway/enforcer/internal/datastore/cache_store.go b/gateway/enforcer/internal/datastore/cache_store.go new file mode 100644 index 0000000000..fdf8a7222b --- /dev/null +++ b/gateway/enforcer/internal/datastore/cache_store.go @@ -0,0 +1,129 @@ +package datastore + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/redis/go-redis/v9" +) + +// CacheStore defines an interface for a simple key-value cache. +type CacheStore interface { + Set(key string, value string) error + Get(key string) (string, error) +} + +// MockCache is an in-memory implementation of CacheStore for testing purposes. +type MockCache struct { + store map[string]string + mu sync.RWMutex +} + +// Set stores a key-value pair in the mock cache. +func (s *MockCache) Set(key string, value string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.store[key] = value + fmt.Printf("[MockCache] SET key=%q -> value=%q\n", key, value) + return nil +} + +// Get retrieves a value from the mock cache. +// If the key does not exist, it returns an error. +func (s *MockCache) Get(key string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + val, exists := s.store[key] + fmt.Printf("[MockCache] GET key=%q -> value=%q, exists=%v\n", key, val, exists) + if !exists { + return "", errors.New("key doesn't exist") + } + return val, nil +} + +// NewMockCache initializes and returns a new instance of MockCache. +func NewMockCache() *MockCache { + return &MockCache{ + store: make(map[string]string), + } +} + +// RedisCache implements CacheStore using Redis. +type RedisCache struct { + client *redis.Client + ctx context.Context +} + +// Set stores a key-value pair in Redis. +func (r *RedisCache) Set(key string, value string) error { + err := r.client.Set(r.ctx, key, value, 0).Err() + if err != nil { + return fmt.Errorf("failed to set key: %w", err) + } + fmt.Printf("[RedisCache] SET key=%q -> value=%q\n", key, value) + return nil +} + +// Get retrieves a value from Redis. +func (r *RedisCache) Get(key string) (string, error) { + val, err := r.client.Get(r.ctx, key).Result() + if err == redis.Nil { + fmt.Printf("[RedisCache] GET key=%q -> Not found\n", key) + return "", errors.New("key doesn't exist") + } else if err != nil { + return "", fmt.Errorf("failed to get key: %w", err) + } + fmt.Printf("[RedisCache] GET key=%q -> value=%q\n", key, val) + return val, nil +} + +// NewRedisCache initializes and returns a new RedisCache instance. +func NewRedisCache() *RedisCache { + client := redis.NewClient(&redis.Options{ + Addr: "host.docker.internal:6379", + Password: "", // No password by default + DB: 0, // Use default DB + }) + + return &RedisCache{ + client: client, + ctx: context.Background(), + } +} + +// IncomingRequestCacheKeyStore defines store cache key store +type IncomingRequestCacheKeyStore struct { + keys map[string]string + mu sync.RWMutex +} + +// NewIncomingRequestCacheKeyStore Initiate new cache key store +func NewIncomingRequestCacheKeyStore() *IncomingRequestCacheKeyStore { + return &IncomingRequestCacheKeyStore{ + keys: make(map[string]string), + } +} + +// Set Incoming request cache key +func (s *IncomingRequestCacheKeyStore) Set(requestID string, key string) bool { + s.mu.Lock() + defer s.mu.Unlock() + s.keys[requestID] = key + fmt.Printf("[IncomingRequestCacheKeyStore] SET requestID=%q -> key=%q\n", requestID, key) + return true +} + +// Pop the request cache key. +func (s *IncomingRequestCacheKeyStore) Pop(requestID string) (string, bool) { + s.mu.Lock() + defer s.mu.Unlock() + key, has := s.keys[requestID] + fmt.Printf("[IncomingRequestCacheKeyStore] POP requestID=%q -> key=%q, exists=%v\n", requestID, key, has) + if !has { + return "", false + } + delete(s.keys, requestID) + return key, true +} diff --git a/gateway/enforcer/internal/dto/ai_cache.go b/gateway/enforcer/internal/dto/ai_cache.go new file mode 100644 index 0000000000..bea40a3315 --- /dev/null +++ b/gateway/enforcer/internal/dto/ai_cache.go @@ -0,0 +1,54 @@ +package dto + +// LLMRequest defines the OpenAI request structure +type LLMRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` +} + +// Message represents a single message in the OpenAI request +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// LLMResponse defines the OpenAI response structure +type LLMResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Usage Usage `json:"usage"` + Choices []Choice `json:"choices"` +} + +// Usage represents token usage details +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Choice represents a completion choice from OpenAI +type Choice struct { + Index int `json:"index"` + Message Message `json:"message"` + Delta []any `json:"delta"` + FinishReason string `json:"finish_reason"` +} + +// GetKey extracts the last message's content (key) from the request +func (r *LLMRequest) GetKey() (string, bool) { + if len(r.Messages) == 0 || r.Messages[len(r.Messages)-1].Content == "" { + return "", false + } + return r.Messages[len(r.Messages)-1].Content, true +} + +// GetValue extracts the assistant's response content (value) from the response +func (r *LLMResponse) GetValue() (string, bool) { + if len(r.Choices) == 0 || r.Choices[0].Message.Content == "" { + return "", false + } + return r.Choices[0].Message.Content, true +} diff --git a/gateway/enforcer/internal/dto/external_processing_envoy_attributes.go b/gateway/enforcer/internal/dto/external_processing_envoy_attributes.go index 4aba3cd651..61d0656500 100644 --- a/gateway/enforcer/internal/dto/external_processing_envoy_attributes.go +++ b/gateway/enforcer/internal/dto/external_processing_envoy_attributes.go @@ -30,6 +30,7 @@ type ExternalProcessingEnvoyAttributes struct { APIName string `json:"apiNameAttribute"` ClusterName string `json:"clusterNameAttribute"` RequestMethod string `json:"requestMethodAttribute"` + RequestID string `json:"requestIdAttribute"` Organization string `json:"organizationAttribute"` ApplicationID string `json:"applicationIdAttribute"` CorrelationID string `json:"correlationIdAttribute"` diff --git a/gateway/enforcer/internal/dto/external_processing_envoy_metadata.go b/gateway/enforcer/internal/dto/external_processing_envoy_metadata.go index cd4d365718..f13d0a2216 100644 --- a/gateway/enforcer/internal/dto/external_processing_envoy_metadata.go +++ b/gateway/enforcer/internal/dto/external_processing_envoy_metadata.go @@ -23,6 +23,7 @@ type ExternalProcessingEnvoyMetadata struct { MatchedResourceIdentifier string `json:"matchedResourceIdentifier"` MatchedSubscriptionIdentifier string `json:"matchedSubscriptionIdentifier"` MatchedApplicationIdentifier string `json:"matchedApplicationIdentifier"` + RequestIdentifier string `json:"requestIdentifier"` } // JwtAuthenticationData represents the JWT authentication data. diff --git a/gateway/enforcer/internal/extproc/ext_proc.go b/gateway/enforcer/internal/extproc/ext_proc.go index d8b130638c..def8f26c6f 100644 --- a/gateway/enforcer/internal/extproc/ext_proc.go +++ b/gateway/enforcer/internal/extproc/ext_proc.go @@ -31,6 +31,7 @@ import ( "github.com/wso2/apk/common-go-libs/loggers" "github.com/wso2/apk/gateway/enforcer/internal/analytics" "github.com/wso2/apk/gateway/enforcer/internal/authorization" + "github.com/wso2/apk/gateway/enforcer/internal/cache" "github.com/wso2/apk/gateway/enforcer/internal/config" "github.com/wso2/apk/gateway/enforcer/internal/datastore" "github.com/wso2/apk/gateway/enforcer/internal/dto" @@ -60,6 +61,8 @@ type ExternalProcessingServer struct { log logging.Logger apiStore *datastore.APIStore subscriptionApplicationDatastore *datastore.SubscriptionApplicationDataStore + cacheStore datastore.CacheStore + incomingRequestCacheKeyStore *datastore.IncomingRequestCacheKeyStore ratelimitHelper *ratelimit.AIRatelimitHelper requestConfigHolder *requestconfig.Holder cfg *config.Server @@ -92,6 +95,7 @@ const ( matchedResourceMetadataKey string = "request:matchedresource" matchedSubscriptionMetadataKey string = "request:matchedsubscription" matchedApplicationMetadataKey string = "request:matchedapplication" + requestIDMetadataKey string = "request:requestid" modelMetadataKey string = "aitoken:model" ) @@ -107,7 +111,7 @@ var httpHandler requesthandler.HTTP = requesthandler.HTTP{} // public and private keys, and a logger instance. // // If there is an error during the creation of the gRPC server, the function will panic. -func StartExternalProcessingServer(cfg *config.Server, apiStore *datastore.APIStore, subAppDatastore *datastore.SubscriptionApplicationDataStore, jwtTransformer *transformer.JWTTransformer, modelBasedRoundRobinTracker *datastore.ModelBasedRoundRobinTracker) { +func StartExternalProcessingServer(cfg *config.Server, apiStore *datastore.APIStore, subAppDatastore *datastore.SubscriptionApplicationDataStore, cacheStore datastore.CacheStore, incomingRequestCacheKeyStore *datastore.IncomingRequestCacheKeyStore, jwtTransformer *transformer.JWTTransformer, modelBasedRoundRobinTracker *datastore.ModelBasedRoundRobinTracker) { kaParams := keepalive.ServerParameters{ Time: time.Duration(cfg.ExternalProcessingKeepAliveTime) * time.Hour, // Ping the client if it is idle for 2 hours Timeout: 20 * time.Second, @@ -122,7 +126,7 @@ func StartExternalProcessingServer(cfg *config.Server, apiStore *datastore.APISt } ratelimitHelper := ratelimit.NewAIRatelimitHelper(cfg) - envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, apiStore, subAppDatastore, ratelimitHelper, nil, cfg, jwtTransformer, modelBasedRoundRobinTracker}) + envoy_service_proc_v3.RegisterExternalProcessorServer(server, &ExternalProcessingServer{cfg.Logger, apiStore, subAppDatastore, cacheStore, incomingRequestCacheKeyStore, ratelimitHelper, nil, cfg, jwtTransformer, modelBasedRoundRobinTracker}) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", cfg.ExternalProcessingPort)) if err != nil { cfg.Logger.Error(err, fmt.Sprintf("Failed to listen on port: %s", cfg.ExternalProcessingPort)) @@ -190,6 +194,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro } break } + dynamicMetadataKeyValuePairs[requestIDMetadataKey] = attributes.RequestID rhq := &envoy_service_proc_v3.HeadersResponse{ Response: &envoy_service_proc_v3.CommonResponse{ HeaderMutation: &envoy_service_proc_v3.HeaderMutation{ @@ -222,7 +227,6 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro dynamicMetadataKeyValuePairs[analytics.APIOrganizationIDKey] = requestConfigHolder.MatchedAPI.OrganizationID dynamicMetadataKeyValuePairs[analytics.APICreatorTenantDomainKey] = requestConfigHolder.MatchedAPI.OrganizationID - requestConfigHolder.ExternalProcessingEnvoyAttributes = attributes if requestConfigHolder.MatchedAPI != nil && requestConfigHolder.MatchedAPI.APIDefinitionPath != "" { definitionPath := requestConfigHolder.MatchedAPI.APIDefinitionPath @@ -272,15 +276,15 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro } } s.cfg.Logger.Info(fmt.Sprintf("Metadata context : %+v", req.GetMetadataContext())) - + requestConfigHolder.MatchedResource = httpHandler.GetMatchedResource(requestConfigHolder.MatchedAPI, *requestConfigHolder.ExternalProcessingEnvoyAttributes) if requestConfigHolder.MatchedResource != nil { requestConfigHolder.MatchedResource.RouteMetadataAttributes = attributes dynamicMetadataKeyValuePairs[matchedResourceMetadataKey] = requestConfigHolder.MatchedResource.GetResourceIdentifier() dynamicMetadataKeyValuePairs[analytics.APIResourceTemplateKey] = requestConfigHolder.MatchedResource.Path s.log.Info(fmt.Sprintf("Matched Resource Endpoints: %+v", requestConfigHolder.MatchedResource.Endpoints)) - if requestConfigHolder.MatchedResource.Endpoints!= nil && len(requestConfigHolder.MatchedResource.Endpoints.URLs) > 0 { - dynamicMetadataKeyValuePairs[analytics.DestinationKey] = requestConfigHolder.MatchedResource.Endpoints.URLs[0] + if requestConfigHolder.MatchedResource.Endpoints != nil && len(requestConfigHolder.MatchedResource.Endpoints.URLs) > 0 { + dynamicMetadataKeyValuePairs[analytics.DestinationKey] = requestConfigHolder.MatchedResource.Endpoints.URLs[0] } } metadata, err := extractExternalProcessingMetadata(req.GetMetadataContext()) @@ -290,7 +294,7 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro break } requestConfigHolder.ExternalProcessingEnvoyMetadata = metadata - + // s.log.Info(fmt.Sprintf("Matched api bjc: %v", requestConfigHolder.MatchedAPI.BackendJwtConfiguration)) // s.log.Info(fmt.Sprintf("Matched Resource: %v", requestConfigHolder.MatchedResource)) // s.log.Info(fmt.Sprintf("req holderrr: %+v\n s: %+v", &requestConfigHolder, &s)) @@ -625,6 +629,11 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro } } + // HANDLE CACHE + // TODO: add cacheStore and incomingRequestCacheKeyStore in server ext_proc_server + // TODO: make sure RequestIdentifier exists + cache.HandleHTTPRequestBody(metadata.RequestIdentifier, s.cacheStore, s.incomingRequestCacheKeyStore, req, resp) + case *envoy_service_proc_v3.ProcessingRequest_ResponseHeaders: s.log.Info(fmt.Sprintf("response header %+v, ", v.ResponseHeaders)) rhq := &envoy_service_proc_v3.HeadersResponse{ @@ -877,6 +886,12 @@ func (s *ExternalProcessingServer) Process(srv envoy_service_proc_v3.ExternalPro duration := matchedResource.AIModelBasedRoundRobin.OnQuotaExceedSuspendDuration s.modelBasedRoundRobinTracker.SuspendModel(matchedAPI.UUID, matchedResource.Path, model, time.Duration(time.Duration(duration*1000*1000*1000))) } + + // HANDLE CACHE + // TODO: add cacheStore and incomingRequestCacheKeyStore in server ext_proc_server + // TODO: make sure RequestIdentifier exists + cache.HandleHTTPResponseBody(metadata.RequestIdentifier, s.cacheStore, s.incomingRequestCacheKeyStore, req, resp) + default: s.log.Info(fmt.Sprintf("Unknown Request type %v\n", v)) } @@ -951,6 +966,9 @@ func extractExternalProcessingMetadata(data *corev3.Metadata) (*dto.ExternalProc if matchedSubscriptionKey, exists := extProcMetadata.Fields[matchedSubscriptionMetadataKey]; exists { externalProcessingEnvoyMetadata.MatchedSubscriptionIdentifier = matchedSubscriptionKey.GetStringValue() } + if requestID, exists := extProcMetadata.Fields[requestIDMetadataKey]; exists { + externalProcessingEnvoyMetadata.RequestIdentifier = requestID.GetStringValue() + } } return externalProcessingEnvoyMetadata, nil @@ -999,6 +1017,11 @@ func extractExternalProcessingXDSRouteMetadataAttributes(data map[string]*struct attributes.RequestMethod = method } + if field, ok := fields["request.id"]; ok { + id := field.GetStringValue() + attributes.RequestID = id + } + // We need to navigate through the nested fields to get the actual values if field, ok := fields["xds.route_metadata"]; ok { diff --git a/gateway/enforcer/internal/util/json.go b/gateway/enforcer/internal/util/json.go index bf89a417ba..e14a92cdb8 100644 --- a/gateway/enforcer/internal/util/json.go +++ b/gateway/enforcer/internal/util/json.go @@ -1,8 +1,13 @@ package util import ( + "bytes" + "compress/gzip" "encoding/json" + "fmt" + "io" ) + // ToJSONString converts any object to a JSON string func ToJSONString(obj interface{}) (string, error) { jsonData, err := json.Marshal(obj) @@ -14,5 +19,29 @@ func ToJSONString(obj interface{}) (string, error) { // IsValidJSON checks if a string is a valid JSON func IsValidJSON(s string) bool { - return json.Valid([]byte(s)) -} \ No newline at end of file + return json.Valid([]byte(s)) +} + +// DecompressIfGzip Decompress GZIP if the response is compressed +func DecompressIfGzip(data []byte) ([]byte, error) { + if len(data) < 2 { + return data, nil // Not GZIP + } + + // GZIP magic number check + if data[0] == 0x1f && data[1] == 0x8b { + reader, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("failed to create GZIP reader: %w", err) + } + defer reader.Close() + + uncompressedData, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("failed to decompress GZIP: %w", err) + } + return uncompressedData, nil + } + + return data, nil // Not compressed, return as is +}