Skip to content

Commit

Permalink
feat: jsonnet caching for OIDC claims mapper, webhooks, JWT session t…
Browse files Browse the repository at this point in the history
…okenizer
  • Loading branch information
alnr committed Jan 19, 2024
1 parent d93570d commit 087dc69
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 88 deletions.
3 changes: 2 additions & 1 deletion courier/http_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (c *httpChannel) Dispatch(ctx context.Context, msg Message) (err error) {
ctx, span := c.d.Tracer(ctx).Tracer().Start(ctx, "courier.httpChannel.Dispatch")
defer otelx.End(span, &err)

builder, err := request.NewBuilder(ctx, c.requestConfig, c.d)
builder, err := request.NewBuilder(ctx, c.requestConfig, c.d, nil)
if err != nil {
return errors.WithStack(err)
}
Expand All @@ -82,6 +82,7 @@ func (c *httpChannel) Dispatch(ctx context.Context, msg Message) (err error) {
if err != nil {
return errors.WithStack(err)
}
req = req.WithContext(ctx)

res, err := c.d.HTTPClient(ctx).Do(req)
if err != nil {
Expand Down
9 changes: 2 additions & 7 deletions courier/template/load_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,19 @@ func loadBuiltInTemplate(filesystem fs.FS, name string, html bool) (Template, er
return tpl, nil
}

func loadRemoteTemplate(ctx context.Context, d templateDependencies, url string, html bool) (Template, error) {
func loadRemoteTemplate(ctx context.Context, d templateDependencies, url string, html bool) (t Template, err error) {
var b []byte
var err error

// instead of creating a new request always we always cache the bytes.Buffer using the url as the key
if t, found := Cache.Get(url); found {
b = t.([]byte)
} else {
f := fetcher.NewFetcher(fetcher.WithClient(d.HTTPClient(ctx)))
bb, err := f.FetchContext(ctx, url)
b, err = f.FetchContext(ctx, url)
if err != nil {
return nil, errors.WithStack(err)
}
b = bb.Bytes()
_ = Cache.Add(url, b)
}

var t Template
if html {
t, err = htemplate.New(url).Funcs(sprig.HermeticHtmlFuncMap()).Parse(string(b))
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ func (m *RegistryDefault) Contextualizer() contextx.Contextualizer {
return m.ctxer
}

func (m *RegistryDefault) Fetcher() *jwksx.FetcherNext {
func (m *RegistryDefault) JWKSFetcher() *jwksx.FetcherNext {
if m.jwkFetcher == nil {
maxItems := int64(10000000)
cache, _ := ristretto.NewCache(&ristretto.Config{
Expand Down
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ require (
github.com/ory/jsonschema/v3 v3.0.8
github.com/ory/mail/v3 v3.0.0
github.com/ory/nosurf v1.2.7
github.com/ory/x v0.0.607
github.com/ory/x v0.0.610
github.com/peterhellberg/link v1.2.0
github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2
github.com/pkg/errors v0.9.1
Expand Down Expand Up @@ -297,8 +297,8 @@ require (
github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c // indirect
go.mongodb.org/mongo-driver v1.11.3 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.46.1 // indirect
go.opentelemetry.io/contrib/propagators/b3 v1.20.0 // indirect
go.opentelemetry.io/contrib/propagators/jaeger v1.20.0 // indirect
go.opentelemetry.io/contrib/propagators/b3 v1.21.0 // indirect
go.opentelemetry.io/contrib/propagators/jaeger v1.21.1 // indirect
go.opentelemetry.io/contrib/samplers/jaegerremote v0.15.1 // indirect
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 // indirect; / indirect
Expand Down
12 changes: 6 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -828,8 +828,8 @@ github.com/ory/nosurf v1.2.7 h1:YrHrbSensQyU6r6HT/V5+HPdVEgrOTMJiLoJABSBOp4=
github.com/ory/nosurf v1.2.7/go.mod h1:d4L3ZBa7Amv55bqxCBtCs63wSlyaiCkWVl4vKf3OUxA=
github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2 h1:zm6sDvHy/U9XrGpixwHiuAwpp0Ock6khSVHkrv6lQQU=
github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/ory/x v0.0.607 h1:qNP1gU6RWVtsEB04rPht+1rV2DqQhvOAN2sF+4eqVWo=
github.com/ory/x v0.0.607/go.mod h1:fCYvVVHo8wYrCwLyU8+9hFY3IRo4EZM3KI30ysDsDYY=
github.com/ory/x v0.0.610 h1:SY1X5jfJVq45zeZEIc1XWqlE+lna+rMjeLxMoRdx0bY=
github.com/ory/x v0.0.610/go.mod h1:FkkwV9h7U9VqDe6Xkdl7KKLCfDC1sQcO+q2iCBAA0Ho=
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
Expand Down Expand Up @@ -1061,10 +1061,10 @@ go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.
go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.46.1/go.mod h1:GnOaBaFQ2we3b9AGWJpsBa7v1S5RlQzlC3O7dRMxZhM=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo=
go.opentelemetry.io/contrib/propagators/b3 v1.20.0 h1:Yty9Vs4F3D6/liF1o6FNt0PvN85h/BJJ6DQKJ3nrcM0=
go.opentelemetry.io/contrib/propagators/b3 v1.20.0/go.mod h1:On4VgbkqYL18kbJlWsa18+cMNe6rYpBnPi1ARI/BrsU=
go.opentelemetry.io/contrib/propagators/jaeger v1.20.0 h1:iVhNKkMIpzyZqxk8jkDU2n4DFTD+FbpGacvooxEvyyc=
go.opentelemetry.io/contrib/propagators/jaeger v1.20.0/go.mod h1:cpSABr0cm/AH/HhbJjn+AudBVUMgZWdfN3Gb+ZqxSZc=
go.opentelemetry.io/contrib/propagators/b3 v1.21.0 h1:uGdgDPNzwQWRwCXJgw/7h29JaRqcq9B87Iv4hJDKAZw=
go.opentelemetry.io/contrib/propagators/b3 v1.21.0/go.mod h1:D9GQXvVGT2pzyTfp1QBOnD1rzKEWzKjjwu5q2mslCUI=
go.opentelemetry.io/contrib/propagators/jaeger v1.21.1 h1:f4beMGDKiVzg9IcX7/VuWVy+oGdjx3dNJ72YehmtY5k=
go.opentelemetry.io/contrib/propagators/jaeger v1.21.1/go.mod h1:U9jhkEl8d1LL+QXY7q3kneJWJugiN3kZJV2OWz3hkBY=
go.opentelemetry.io/contrib/samplers/jaegerremote v0.15.1 h1:Qb+5A+JbIjXwO7l4HkRUhgIn4Bzz0GNS2q+qdmSx+0c=
go.opentelemetry.io/contrib/samplers/jaegerremote v0.15.1/go.mod h1:G4vNCm7fRk0kjZ6pGNLo5SpLxAUvOfSrcaegnT8TPck=
go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc=
Expand Down
25 changes: 13 additions & 12 deletions request/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ import (
"net/url"
"reflect"
"strings"
"time"

"go.opentelemetry.io/otel/attribute"

"github.com/ory/x/otelx"

"github.com/dgraph-io/ristretto"
"github.com/google/go-jsonnet"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/attribute"

"github.com/ory/kratos/x"
"github.com/ory/x/fetcher"
"github.com/ory/x/jsonnetsecure"
"github.com/ory/x/otelx"
)

var ErrCancel = errors.New("request cancel by JsonNet")
Expand All @@ -44,10 +44,11 @@ type (
r *retryablehttp.Request
Config *Config
deps Dependencies
cache *ristretto.Cache
}
)

func NewBuilder(ctx context.Context, config json.RawMessage, deps Dependencies) (_ *Builder, err error) {
func NewBuilder(ctx context.Context, config json.RawMessage, deps Dependencies, jsonnetCache *ristretto.Cache) (_ *Builder, err error) {
_, span := deps.Tracer(ctx).Tracer().Start(ctx, "request.NewBuilder")
defer otelx.End(span, &err)

Expand All @@ -67,6 +68,7 @@ func NewBuilder(ctx context.Context, config json.RawMessage, deps Dependencies)
r: r,
Config: c,
deps: deps,
cache: jsonnetCache,
}, nil
}

Expand Down Expand Up @@ -118,7 +120,7 @@ func (b *Builder) addBody(ctx context.Context, body interface{}) error {
return nil
}

func (b *Builder) addJSONBody(ctx context.Context, template *bytes.Buffer, body interface{}) error {
func (b *Builder) addJSONBody(ctx context.Context, jsonnetSnippet []byte, body interface{}) error {
buf := new(bytes.Buffer)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
Expand All @@ -136,7 +138,7 @@ func (b *Builder) addJSONBody(ctx context.Context, template *bytes.Buffer, body

res, err := vm.EvaluateAnonymousSnippet(
b.Config.TemplateURI,
template.String(),
string(jsonnetSnippet),
)
if err != nil {
// Unfortunately we can not use errors.As / errors.Is, see:
Expand All @@ -156,7 +158,7 @@ func (b *Builder) addJSONBody(ctx context.Context, template *bytes.Buffer, body
return nil
}

func (b *Builder) addURLEncodedBody(ctx context.Context, template *bytes.Buffer, body interface{}) error {
func (b *Builder) addURLEncodedBody(ctx context.Context, jsonnetSnippet []byte, body interface{}) error {
buf := new(bytes.Buffer)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
Expand All @@ -172,7 +174,7 @@ func (b *Builder) addURLEncodedBody(ctx context.Context, template *bytes.Buffer,
}
vm.TLACode("ctx", buf.String())

res, err := vm.EvaluateAnonymousSnippet(b.Config.TemplateURI, template.String())
res, err := vm.EvaluateAnonymousSnippet(b.Config.TemplateURI, string(jsonnetSnippet))
if err != nil {
return errors.WithStack(err)
}
Expand Down Expand Up @@ -213,15 +215,14 @@ func (b *Builder) BuildRequest(ctx context.Context, body interface{}) (*retryabl
return b.r, nil
}

func (b *Builder) readTemplate(ctx context.Context) (*bytes.Buffer, error) {
func (b *Builder) readTemplate(ctx context.Context) ([]byte, error) {
templateURI := b.Config.TemplateURI

if templateURI == "" {
return nil, nil
}

f := fetcher.NewFetcher(fetcher.WithClient(b.deps.HTTPClient(ctx)))

f := fetcher.NewFetcher(fetcher.WithClient(b.deps.HTTPClient(ctx)), fetcher.WithCache(b.cache, 60*time.Minute))
tpl, err := f.FetchContext(ctx, templateURI)
if errors.Is(err, fetcher.ErrUnknownScheme) {
// legacy filepath
Expand Down
4 changes: 2 additions & 2 deletions request/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func TestBuildRequest(t *testing.T) {
} {
t.Run(
"request-type="+tc.name, func(t *testing.T) {
rb, err := NewBuilder(context.Background(), json.RawMessage(tc.rawConfig), newTestDependencyProvider(t))
rb, err := NewBuilder(context.Background(), json.RawMessage(tc.rawConfig), newTestDependencyProvider(t), nil)
require.NoError(t, err)

assert.Equal(t, tc.bodyTemplateURI, rb.Config.TemplateURI)
Expand Down Expand Up @@ -279,7 +279,7 @@ func TestBuildRequest(t *testing.T) {
"method": "POST",
"body": "file://./stub/cancel_body.jsonnet"
}`,
), newTestDependencyProvider(t))
), newTestDependencyProvider(t), nil)
require.NoError(t, err)

_, err = rb.BuildRequest(context.Background(), json.RawMessage(`{}`))
Expand Down
9 changes: 8 additions & 1 deletion selfservice/hook/web_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http/httputil"
"time"

"github.com/dgraph-io/ristretto"
"github.com/gofrs/uuid"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
Expand Down Expand Up @@ -60,6 +61,12 @@ var _ interface {
settings.PostHookPostPersistExecutor
} = (*WebHook)(nil)

var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config{
MaxCost: 100 << 20, // 100MB,
NumCounters: 1_000_000, // 1kB per snippet -> 100k snippets -> 1M counters
BufferItems: 64,
})

type (
webHookDependencies interface {
x.LoggingProvider
Expand Down Expand Up @@ -334,7 +341,7 @@ func (e *WebHook) execute(ctx context.Context, data *templateContext) error {
}
}(time.Now())

builder, err := request.NewBuilder(ctx, e.conf, e.deps)
builder, err := request.NewBuilder(ctx, e.conf, e.deps, jsonnetCache)
if err != nil {
return err
}
Expand Down
43 changes: 20 additions & 23 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,36 @@ import (
"strings"
"time"

"github.com/dgraph-io/ristretto"
"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"

"github.com/ory/x/otelx"
"github.com/ory/x/sqlxx"

"github.com/ory/herodot"

"github.com/ory/x/fetcher"

"github.com/pkg/errors"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/ory/x/decoderx"

"golang.org/x/oauth2"

"github.com/ory/kratos/selfservice/flow/login"

"github.com/ory/kratos/text"

"github.com/pkg/errors"

"github.com/ory/herodot"
"github.com/ory/kratos/continuity"

"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/selfservice/flow/registration"
"github.com/ory/kratos/text"
"github.com/ory/kratos/x"
"github.com/ory/x/decoderx"
"github.com/ory/x/fetcher"
"github.com/ory/x/otelx"
"github.com/ory/x/sqlxx"
)

var _ registration.Strategy = new(Strategy)

var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config{
MaxCost: 100 << 20, // 100MB,
NumCounters: 1_000_000, // 1kB per snippet -> 100k snippets -> 1M counters
BufferItems: 64,
})

type MetadataType string

type VerifiedAddress struct {
Expand Down Expand Up @@ -308,13 +305,13 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r
return nil, nil
}

fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(r.Context())))
jn, err := fetch.FetchContext(r.Context(), provider.Config().Mapper)
fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(r.Context())), fetcher.WithCache(jsonnetCache, 60*time.Minute))
jsonnetMapperSnippet, err := fetch.FetchContext(r.Context(), provider.Config().Mapper)
if err != nil {
return nil, s.handleError(w, r, rf, provider.Config().ID, nil, err)
}

i, va, err := s.createIdentity(w, r, rf, claims, provider, container, jn)
i, va, err := s.createIdentity(w, r, rf, claims, provider, container, jsonnetMapperSnippet)
if err != nil {
return nil, s.handleError(w, r, rf, provider.Config().ID, nil, err)
}
Expand Down Expand Up @@ -369,7 +366,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r
return nil, nil
}

func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, container *AuthCodeContainer, jn *bytes.Buffer) (*identity.Identity, []VerifiedAddress, error) {
func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, container *AuthCodeContainer, jsonnetSnippet []byte) (*identity.Identity, []VerifiedAddress, error) {
var jsonClaims bytes.Buffer
if err := json.NewEncoder(&jsonClaims).Encode(claims); err != nil {
return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
Expand All @@ -381,7 +378,7 @@ func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *reg
}

vm.ExtCode("claims", jsonClaims.String())
evaluated, err := vm.EvaluateAnonymousSnippet(provider.Config().Mapper, jn.String())
evaluated, err := vm.EvaluateAnonymousSnippet(provider.Config().Mapper, string(jsonnetSnippet))
if err != nil {
return nil, nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
}
Expand Down
Loading

0 comments on commit 087dc69

Please sign in to comment.