From 087dc6924f1b2120857669ce5003918290c6cbe8 Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Thu, 18 Jan 2024 14:33:03 +0100 Subject: [PATCH] feat: jsonnet caching for OIDC claims mapper, webhooks, JWT session tokenizer --- courier/http_channel.go | 3 +- courier/template/load_template.go | 9 +--- driver/registry_default.go | 2 +- go.mod | 6 +-- go.sum | 12 ++--- request/builder.go | 25 ++++++----- request/builder_test.go | 4 +- selfservice/hook/web_hook.go | 9 +++- .../strategy/oidc/strategy_registration.go | 43 +++++++++--------- session/tokenizer.go | 45 +++++++------------ x/fetcher.go | 4 +- 11 files changed, 74 insertions(+), 88 deletions(-) diff --git a/courier/http_channel.go b/courier/http_channel.go index 41315f33e067..e9015dca13fd 100644 --- a/courier/http_channel.go +++ b/courier/http_channel.go @@ -59,7 +59,7 @@ func (c *httpChannel) Dispatch(ctx context.Context, msg Message) (err error) { ctx, span := c.d.Tracer(ctx).Tracer().Start(ctx, "courier.httpChannel.Dispatch") defer otelx.End(span, &err) - builder, err := request.NewBuilder(ctx, c.requestConfig, c.d) + builder, err := request.NewBuilder(ctx, c.requestConfig, c.d, nil) if err != nil { return errors.WithStack(err) } @@ -82,6 +82,7 @@ func (c *httpChannel) Dispatch(ctx context.Context, msg Message) (err error) { if err != nil { return errors.WithStack(err) } + req = req.WithContext(ctx) res, err := c.d.HTTPClient(ctx).Do(req) if err != nil { diff --git a/courier/template/load_template.go b/courier/template/load_template.go index 58b46f898b9d..47341b8f35af 100644 --- a/courier/template/load_template.go +++ b/courier/template/load_template.go @@ -78,24 +78,19 @@ func loadBuiltInTemplate(filesystem fs.FS, name string, html bool) (Template, er return tpl, nil } -func loadRemoteTemplate(ctx context.Context, d templateDependencies, url string, html bool) (Template, error) { +func loadRemoteTemplate(ctx context.Context, d templateDependencies, url string, html bool) (t Template, err error) { var b []byte - var err error - - // instead of creating a new request always we always cache the bytes.Buffer using the url as the key if t, found := Cache.Get(url); found { b = t.([]byte) } else { f := fetcher.NewFetcher(fetcher.WithClient(d.HTTPClient(ctx))) - bb, err := f.FetchContext(ctx, url) + b, err = f.FetchContext(ctx, url) if err != nil { return nil, errors.WithStack(err) } - b = bb.Bytes() _ = Cache.Add(url, b) } - var t Template if html { t, err = htemplate.New(url).Funcs(sprig.HermeticHtmlFuncMap()).Parse(string(b)) if err != nil { diff --git a/driver/registry_default.go b/driver/registry_default.go index f9c47bb30739..a267d9cb637d 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -875,7 +875,7 @@ func (m *RegistryDefault) Contextualizer() contextx.Contextualizer { return m.ctxer } -func (m *RegistryDefault) Fetcher() *jwksx.FetcherNext { +func (m *RegistryDefault) JWKSFetcher() *jwksx.FetcherNext { if m.jwkFetcher == nil { maxItems := int64(10000000) cache, _ := ristretto.NewCache(&ristretto.Config{ diff --git a/go.mod b/go.mod index 6e64290855d1..42b2bad25942 100644 --- a/go.mod +++ b/go.mod @@ -75,7 +75,7 @@ require ( github.com/ory/jsonschema/v3 v3.0.8 github.com/ory/mail/v3 v3.0.0 github.com/ory/nosurf v1.2.7 - github.com/ory/x v0.0.607 + github.com/ory/x v0.0.610 github.com/peterhellberg/link v1.2.0 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/pkg/errors v0.9.1 @@ -297,8 +297,8 @@ require ( github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c // indirect go.mongodb.org/mongo-driver v1.11.3 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.46.1 // indirect - go.opentelemetry.io/contrib/propagators/b3 v1.20.0 // indirect - go.opentelemetry.io/contrib/propagators/jaeger v1.20.0 // indirect + go.opentelemetry.io/contrib/propagators/b3 v1.21.0 // indirect + go.opentelemetry.io/contrib/propagators/jaeger v1.21.1 // indirect go.opentelemetry.io/contrib/samplers/jaegerremote v0.15.1 // indirect go.opentelemetry.io/otel/exporters/jaeger v1.17.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 // indirect; / indirect diff --git a/go.sum b/go.sum index bc6051c55224..b3b463884b74 100644 --- a/go.sum +++ b/go.sum @@ -828,8 +828,8 @@ github.com/ory/nosurf v1.2.7 h1:YrHrbSensQyU6r6HT/V5+HPdVEgrOTMJiLoJABSBOp4= github.com/ory/nosurf v1.2.7/go.mod h1:d4L3ZBa7Amv55bqxCBtCs63wSlyaiCkWVl4vKf3OUxA= github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2 h1:zm6sDvHy/U9XrGpixwHiuAwpp0Ock6khSVHkrv6lQQU= github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/ory/x v0.0.607 h1:qNP1gU6RWVtsEB04rPht+1rV2DqQhvOAN2sF+4eqVWo= -github.com/ory/x v0.0.607/go.mod h1:fCYvVVHo8wYrCwLyU8+9hFY3IRo4EZM3KI30ysDsDYY= +github.com/ory/x v0.0.610 h1:SY1X5jfJVq45zeZEIc1XWqlE+lna+rMjeLxMoRdx0bY= +github.com/ory/x v0.0.610/go.mod h1:FkkwV9h7U9VqDe6Xkdl7KKLCfDC1sQcO+q2iCBAA0Ho= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= @@ -1061,10 +1061,10 @@ go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0. go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.46.1/go.mod h1:GnOaBaFQ2we3b9AGWJpsBa7v1S5RlQzlC3O7dRMxZhM= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo= -go.opentelemetry.io/contrib/propagators/b3 v1.20.0 h1:Yty9Vs4F3D6/liF1o6FNt0PvN85h/BJJ6DQKJ3nrcM0= -go.opentelemetry.io/contrib/propagators/b3 v1.20.0/go.mod h1:On4VgbkqYL18kbJlWsa18+cMNe6rYpBnPi1ARI/BrsU= -go.opentelemetry.io/contrib/propagators/jaeger v1.20.0 h1:iVhNKkMIpzyZqxk8jkDU2n4DFTD+FbpGacvooxEvyyc= -go.opentelemetry.io/contrib/propagators/jaeger v1.20.0/go.mod h1:cpSABr0cm/AH/HhbJjn+AudBVUMgZWdfN3Gb+ZqxSZc= +go.opentelemetry.io/contrib/propagators/b3 v1.21.0 h1:uGdgDPNzwQWRwCXJgw/7h29JaRqcq9B87Iv4hJDKAZw= +go.opentelemetry.io/contrib/propagators/b3 v1.21.0/go.mod h1:D9GQXvVGT2pzyTfp1QBOnD1rzKEWzKjjwu5q2mslCUI= +go.opentelemetry.io/contrib/propagators/jaeger v1.21.1 h1:f4beMGDKiVzg9IcX7/VuWVy+oGdjx3dNJ72YehmtY5k= +go.opentelemetry.io/contrib/propagators/jaeger v1.21.1/go.mod h1:U9jhkEl8d1LL+QXY7q3kneJWJugiN3kZJV2OWz3hkBY= go.opentelemetry.io/contrib/samplers/jaegerremote v0.15.1 h1:Qb+5A+JbIjXwO7l4HkRUhgIn4Bzz0GNS2q+qdmSx+0c= go.opentelemetry.io/contrib/samplers/jaegerremote v0.15.1/go.mod h1:G4vNCm7fRk0kjZ6pGNLo5SpLxAUvOfSrcaegnT8TPck= go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc= diff --git a/request/builder.go b/request/builder.go index f930bc089aae..f407c8c5c763 100644 --- a/request/builder.go +++ b/request/builder.go @@ -12,18 +12,18 @@ import ( "net/url" "reflect" "strings" + "time" - "go.opentelemetry.io/otel/attribute" - - "github.com/ory/x/otelx" - + "github.com/dgraph-io/ristretto" "github.com/google/go-jsonnet" "github.com/hashicorp/go-retryablehttp" "github.com/pkg/errors" + "go.opentelemetry.io/otel/attribute" "github.com/ory/kratos/x" "github.com/ory/x/fetcher" "github.com/ory/x/jsonnetsecure" + "github.com/ory/x/otelx" ) var ErrCancel = errors.New("request cancel by JsonNet") @@ -44,10 +44,11 @@ type ( r *retryablehttp.Request Config *Config deps Dependencies + cache *ristretto.Cache } ) -func NewBuilder(ctx context.Context, config json.RawMessage, deps Dependencies) (_ *Builder, err error) { +func NewBuilder(ctx context.Context, config json.RawMessage, deps Dependencies, jsonnetCache *ristretto.Cache) (_ *Builder, err error) { _, span := deps.Tracer(ctx).Tracer().Start(ctx, "request.NewBuilder") defer otelx.End(span, &err) @@ -67,6 +68,7 @@ func NewBuilder(ctx context.Context, config json.RawMessage, deps Dependencies) r: r, Config: c, deps: deps, + cache: jsonnetCache, }, nil } @@ -118,7 +120,7 @@ func (b *Builder) addBody(ctx context.Context, body interface{}) error { return nil } -func (b *Builder) addJSONBody(ctx context.Context, template *bytes.Buffer, body interface{}) error { +func (b *Builder) addJSONBody(ctx context.Context, jsonnetSnippet []byte, body interface{}) error { buf := new(bytes.Buffer) enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) @@ -136,7 +138,7 @@ func (b *Builder) addJSONBody(ctx context.Context, template *bytes.Buffer, body res, err := vm.EvaluateAnonymousSnippet( b.Config.TemplateURI, - template.String(), + string(jsonnetSnippet), ) if err != nil { // Unfortunately we can not use errors.As / errors.Is, see: @@ -156,7 +158,7 @@ func (b *Builder) addJSONBody(ctx context.Context, template *bytes.Buffer, body return nil } -func (b *Builder) addURLEncodedBody(ctx context.Context, template *bytes.Buffer, body interface{}) error { +func (b *Builder) addURLEncodedBody(ctx context.Context, jsonnetSnippet []byte, body interface{}) error { buf := new(bytes.Buffer) enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) @@ -172,7 +174,7 @@ func (b *Builder) addURLEncodedBody(ctx context.Context, template *bytes.Buffer, } vm.TLACode("ctx", buf.String()) - res, err := vm.EvaluateAnonymousSnippet(b.Config.TemplateURI, template.String()) + res, err := vm.EvaluateAnonymousSnippet(b.Config.TemplateURI, string(jsonnetSnippet)) if err != nil { return errors.WithStack(err) } @@ -213,15 +215,14 @@ func (b *Builder) BuildRequest(ctx context.Context, body interface{}) (*retryabl return b.r, nil } -func (b *Builder) readTemplate(ctx context.Context) (*bytes.Buffer, error) { +func (b *Builder) readTemplate(ctx context.Context) ([]byte, error) { templateURI := b.Config.TemplateURI if templateURI == "" { return nil, nil } - f := fetcher.NewFetcher(fetcher.WithClient(b.deps.HTTPClient(ctx))) - + f := fetcher.NewFetcher(fetcher.WithClient(b.deps.HTTPClient(ctx)), fetcher.WithCache(b.cache, 60*time.Minute)) tpl, err := f.FetchContext(ctx, templateURI) if errors.Is(err, fetcher.ErrUnknownScheme) { // legacy filepath diff --git a/request/builder_test.go b/request/builder_test.go index cffa032cc2f9..8269b5133346 100644 --- a/request/builder_test.go +++ b/request/builder_test.go @@ -245,7 +245,7 @@ func TestBuildRequest(t *testing.T) { } { t.Run( "request-type="+tc.name, func(t *testing.T) { - rb, err := NewBuilder(context.Background(), json.RawMessage(tc.rawConfig), newTestDependencyProvider(t)) + rb, err := NewBuilder(context.Background(), json.RawMessage(tc.rawConfig), newTestDependencyProvider(t), nil) require.NoError(t, err) assert.Equal(t, tc.bodyTemplateURI, rb.Config.TemplateURI) @@ -279,7 +279,7 @@ func TestBuildRequest(t *testing.T) { "method": "POST", "body": "file://./stub/cancel_body.jsonnet" }`, - ), newTestDependencyProvider(t)) + ), newTestDependencyProvider(t), nil) require.NoError(t, err) _, err = rb.BuildRequest(context.Background(), json.RawMessage(`{}`)) diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index c06d47a216ba..5c9aa31e92fa 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -12,6 +12,7 @@ import ( "net/http/httputil" "time" + "github.com/dgraph-io/ristretto" "github.com/gofrs/uuid" "github.com/hashicorp/go-retryablehttp" "github.com/pkg/errors" @@ -60,6 +61,12 @@ var _ interface { settings.PostHookPostPersistExecutor } = (*WebHook)(nil) +var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config{ + MaxCost: 100 << 20, // 100MB, + NumCounters: 1_000_000, // 1kB per snippet -> 100k snippets -> 1M counters + BufferItems: 64, +}) + type ( webHookDependencies interface { x.LoggingProvider @@ -334,7 +341,7 @@ func (e *WebHook) execute(ctx context.Context, data *templateContext) error { } }(time.Now()) - builder, err := request.NewBuilder(ctx, e.conf, e.deps) + builder, err := request.NewBuilder(ctx, e.conf, e.deps, jsonnetCache) if err != nil { return err } diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index d3f3b217f760..27b4a74eb9f5 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -10,39 +10,36 @@ import ( "strings" "time" + "github.com/dgraph-io/ristretto" "github.com/gofrs/uuid" "github.com/julienschmidt/httprouter" - - "github.com/ory/x/otelx" - "github.com/ory/x/sqlxx" - - "github.com/ory/herodot" - - "github.com/ory/x/fetcher" - + "github.com/pkg/errors" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - - "github.com/ory/x/decoderx" - "golang.org/x/oauth2" - "github.com/ory/kratos/selfservice/flow/login" - - "github.com/ory/kratos/text" - - "github.com/pkg/errors" - + "github.com/ory/herodot" "github.com/ory/kratos/continuity" - "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow" + "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/registration" + "github.com/ory/kratos/text" "github.com/ory/kratos/x" + "github.com/ory/x/decoderx" + "github.com/ory/x/fetcher" + "github.com/ory/x/otelx" + "github.com/ory/x/sqlxx" ) var _ registration.Strategy = new(Strategy) +var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config{ + MaxCost: 100 << 20, // 100MB, + NumCounters: 1_000_000, // 1kB per snippet -> 100k snippets -> 1M counters + BufferItems: 64, +}) + type MetadataType string type VerifiedAddress struct { @@ -308,13 +305,13 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r return nil, nil } - fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(r.Context()))) - jn, err := fetch.FetchContext(r.Context(), provider.Config().Mapper) + fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(r.Context())), fetcher.WithCache(jsonnetCache, 60*time.Minute)) + jsonnetMapperSnippet, err := fetch.FetchContext(r.Context(), provider.Config().Mapper) if err != nil { return nil, s.handleError(w, r, rf, provider.Config().ID, nil, err) } - i, va, err := s.createIdentity(w, r, rf, claims, provider, container, jn) + i, va, err := s.createIdentity(w, r, rf, claims, provider, container, jsonnetMapperSnippet) if err != nil { return nil, s.handleError(w, r, rf, provider.Config().ID, nil, err) } @@ -369,7 +366,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r return nil, nil } -func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, container *AuthCodeContainer, jn *bytes.Buffer) (*identity.Identity, []VerifiedAddress, error) { +func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, container *AuthCodeContainer, jsonnetSnippet []byte) (*identity.Identity, []VerifiedAddress, error) { var jsonClaims bytes.Buffer if err := json.NewEncoder(&jsonClaims).Encode(claims); err != nil { return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err) @@ -381,7 +378,7 @@ func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *reg } vm.ExtCode("claims", jsonClaims.String()) - evaluated, err := vm.EvaluateAnonymousSnippet(provider.Config().Mapper, jn.String()) + evaluated, err := vm.EvaluateAnonymousSnippet(provider.Config().Mapper, string(jsonnetSnippet)) if err != nil { return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err) } diff --git a/session/tokenizer.go b/session/tokenizer.go index 2d1decb352e5..f3c47ef22e0e 100644 --- a/session/tokenizer.go +++ b/session/tokenizer.go @@ -5,7 +5,6 @@ package session import ( "context" - "crypto/sha256" "encoding/json" "time" @@ -32,11 +31,12 @@ type ( x.TracingProvider x.HTTPClientProvider config.Provider - x.JWKFetchProvider + x.JWKSFetchProvider } Tokenizer struct { r tokenizerDependencies nowFunc func() time.Time + cache *ristretto.Cache } TokenizerProvider interface { SessionTokenizer() *Tokenizer @@ -44,23 +44,18 @@ type ( ) func NewTokenizer(r tokenizerDependencies) *Tokenizer { - return &Tokenizer{r: r, nowFunc: time.Now} + cache, _ := ristretto.NewCache(&ristretto.Config{ + MaxCost: 50 << 20, // 50MB, + NumCounters: 500_000, // 1kB per snippet -> 50k snippets -> 500k counters + BufferItems: 64, + }) + return &Tokenizer{r: r, nowFunc: time.Now, cache: cache} } func (s *Tokenizer) SetNowFunc(t func() time.Time) { s.nowFunc = t } -var cache, _ = ristretto.NewCache(&ristretto.Config{ - NumCounters: 100000000, - MaxCost: 10000000, - BufferItems: 64, - IgnoreInternalCost: true, - Cost: func(value interface{}) int64 { - return 1 - }, -}) - func (s *Tokenizer) TokenizeSession(ctx context.Context, template string, session *Session) (err error) { ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.TokenizeSession") defer otelx.End(span, &err) @@ -71,7 +66,7 @@ func (s *Tokenizer) TokenizeSession(ctx context.Context, template string, sessio } httpClient := s.r.HTTPClient(ctx) - key, err := s.r.Fetcher().ResolveKey( + key, err := s.r.JWKSFetcher().ResolveKey( ctx, tpl.JWKSURL, jwksx.WithCacheEnabled(), @@ -94,8 +89,6 @@ func (s *Tokenizer) TokenizeSession(ctx context.Context, template string, sessio return err } - fetch := fetcher.NewFetcher(fetcher.WithClient(httpClient)) - now := s.nowFunc() token := jwt.New(alg) token.Header["kid"] = key.KeyID() @@ -110,19 +103,6 @@ func (s *Tokenizer) TokenizeSession(ctx context.Context, template string, sessio } if mapper := tpl.ClaimsMapperURL; len(mapper) > 0 { - var jsonnet string - cacheKey := sha256.Sum256([]byte(mapper)) - if result, found := cache.Get(cacheKey[:]); found { - jsonnet = result.(string) - } else { - jn, err := fetch.FetchContext(ctx, mapper) - if err != nil { - return err - } - jsonnet = jn.String() - cache.SetWithTTL(cacheKey[:], jsonnet, 1, time.Hour) - } - sessionRaw, err := json.Marshal(session) if err != nil { return errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReasonf("Unable to encode session to JSON.")) @@ -136,7 +116,12 @@ func (s *Tokenizer) TokenizeSession(ctx context.Context, template string, sessio vm.ExtCode("session", string(sessionRaw)) vm.ExtCode("claims", string(claimsRaw)) - evaluated, err := vm.EvaluateAnonymousSnippet(tpl.ClaimsMapperURL, jsonnet) + fetcher := fetcher.NewFetcher(fetcher.WithClient(httpClient), fetcher.WithCache(s.cache, 60*time.Minute)) + jsonnet, err := fetcher.FetchContext(ctx, mapper) + if err != nil { + return err + } + evaluated, err := vm.EvaluateAnonymousSnippet(tpl.ClaimsMapperURL, string(jsonnet)) if err != nil { return errors.WithStack(herodot.ErrBadRequest.WithWrap(err).WithDebug(err.Error()).WithReasonf("Unable to execute tokenizer JsonNet.")) } diff --git a/x/fetcher.go b/x/fetcher.go index 77f3dbf31099..bb6aeaeaa3fb 100644 --- a/x/fetcher.go +++ b/x/fetcher.go @@ -5,6 +5,6 @@ package x import "github.com/ory/x/jwksx" -type JWKFetchProvider interface { - Fetcher() *jwksx.FetcherNext +type JWKSFetchProvider interface { + JWKSFetcher() *jwksx.FetcherNext }