Skip to content

Commit

Permalink
test: enable server-side config from context (#3954)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored Jun 18, 2024
1 parent bac030b commit e0001b0
Show file tree
Hide file tree
Showing 12 changed files with 234 additions and 108 deletions.
8 changes: 5 additions & 3 deletions cipher/cipher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"fmt"
"testing"

confighelpers "github.com/ory/kratos/driver/config/testhelpers"

"github.com/ory/x/configx"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -44,7 +46,7 @@ func TestCipher(t *testing.T) {
t.Run("case=encryption_failed", func(t *testing.T) {
t.Parallel()

ctx := config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""})
ctx := confighelpers.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""})

// secret have to be set
_, err := c.Encrypt(ctx, []byte("not-empty"))
Expand All @@ -53,7 +55,7 @@ func TestCipher(t *testing.T) {
require.ErrorAs(t, err, &hErr)
assert.Equal(t, "Unable to encrypt message because no cipher secrets were configured.", hErr.Reason())

ctx = config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{"bad-length"})
ctx = confighelpers.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{"bad-length"})

// bad secret length
_, err = c.Encrypt(ctx, []byte("not-empty"))
Expand All @@ -70,7 +72,7 @@ func TestCipher(t *testing.T) {
_, err = c.Decrypt(ctx, "not-empty")
require.Error(t, err)

_, err = c.Decrypt(config.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""}), "not-empty")
_, err = c.Decrypt(confighelpers.WithConfigValue(ctx, config.ViperKeySecretsCipher, []string{""}), "not-empty")
require.Error(t, err)
})
})
Expand Down
40 changes: 30 additions & 10 deletions driver/config/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ package config_test
import (
"context"
"io"
"net/http/httptest"
"testing"

confighelpers "github.com/ory/kratos/driver/config/testhelpers"

"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -17,35 +18,54 @@ import (
"github.com/ory/kratos/internal"
)

type configProvider struct {
cfg *config.Config
}

func (c *configProvider) Config() *config.Config {
return c.cfg
}

func TestNewConfigHashHandler(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
cfg := internal.NewConfigurationWithDefaults(t)
router := httprouter.New()
config.NewConfigHashHandler(reg, router)
ts := httptest.NewServer(router)
config.NewConfigHashHandler(&configProvider{cfg: cfg}, router)
ts := confighelpers.NewConfigurableTestServer(router)
t.Cleanup(ts.Close)
res, err := ts.Client().Get(ts.URL + "/health/config")

// first request, get baseline hash
res, err := ts.Client(ctx).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
first, err := io.ReadAll(res.Body)
require.NoError(t, err)

res, err = ts.Client().Get(ts.URL + "/health/config")
// second request, no config change
res, err = ts.Client(ctx).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
second, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, first, second)

require.NoError(t, conf.Set(ctx, config.ViperKeySessionDomain, "foobar"))
// third request, with config change
res, err = ts.Client(confighelpers.WithConfigValue(ctx, config.ViperKeySessionDomain, "foobar")).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
third, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.NotEqual(t, first, third)

res, err = ts.Client().Get(ts.URL + "/health/config")
// fourth request, no config change
res, err = ts.Client(ctx).Get(ts.URL + "/health/config")
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, 200, res.StatusCode)
second, err = io.ReadAll(res.Body)
fourth, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.NotEqual(t, first, second)
assert.Equal(t, first, fourth)
}
64 changes: 0 additions & 64 deletions driver/config/test_config.go

This file was deleted.

152 changes: 152 additions & 0 deletions driver/config/testhelpers/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package testhelpers

import (
"context"
"net/http"
"net/http/httptest"

"github.com/gofrs/uuid"

"github.com/ory/kratos/embedx"
"github.com/ory/x/configx"
"github.com/ory/x/contextx"
)

type (
TestConfigProvider struct {
contextx.Contextualizer
Options []configx.OptionModifier
}
contextKey int
)

func (t *TestConfigProvider) NewProvider(ctx context.Context, opts ...configx.OptionModifier) (*configx.Provider, error) {
return configx.New(ctx, []byte(embedx.ConfigSchema), append(t.Options, opts...)...)
}

func (t *TestConfigProvider) Config(ctx context.Context, config *configx.Provider) *configx.Provider {
config = t.Contextualizer.Config(ctx, config)
values, ok := ctx.Value(contextConfigKey).([]map[string]any)
if !ok {
return config
}
opts := make([]configx.OptionModifier, 0, len(values))
for _, v := range values {
opts = append(opts, configx.WithValues(v))
}
config, err := t.NewProvider(ctx, opts...)
if err != nil {
// This is not production code. The provider is only used in tests.
panic(err)
}
return config
}

const contextConfigKey contextKey = 1

var (
_ contextx.Contextualizer = (*TestConfigProvider)(nil)
)

func WithConfigValue(ctx context.Context, key string, value any) context.Context {
return WithConfigValues(ctx, map[string]any{key: value})
}

func WithConfigValues(ctx context.Context, setValues ...map[string]any) context.Context {
values, ok := ctx.Value(contextConfigKey).([]map[string]any)
if !ok {
values = make([]map[string]any, 0)
}
newValues := make([]map[string]any, len(values), len(values)+len(setValues))
copy(newValues, values)
newValues = append(newValues, setValues...)

return context.WithValue(ctx, contextConfigKey, newValues)
}

type ConfigurableTestHandler struct {
configs map[uuid.UUID][]map[string]any
handler http.Handler
}

func NewConfigurableTestHandler(h http.Handler) *ConfigurableTestHandler {
return &ConfigurableTestHandler{
configs: make(map[uuid.UUID][]map[string]any),
handler: h,
}
}

func (t *ConfigurableTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cID := r.Header.Get("Test-Config-Id")
if config, ok := t.configs[uuid.FromStringOrNil(cID)]; ok {
r = r.WithContext(WithConfigValues(r.Context(), config...))
}
t.handler.ServeHTTP(w, r)
}

func (t *ConfigurableTestHandler) RegisterConfig(config ...map[string]any) uuid.UUID {
id := uuid.Must(uuid.NewV4())
t.configs[id] = config
return id
}

func (t *ConfigurableTestHandler) UseConfig(r *http.Request, id uuid.UUID) *http.Request {
r.Header.Set("Test-Config-Id", id.String())
return r
}

func (t *ConfigurableTestHandler) UseConfigValues(r *http.Request, values ...map[string]any) *http.Request {
return t.UseConfig(r, t.RegisterConfig(values...))
}

type ConfigurableTestServer struct {
*httptest.Server
handler *ConfigurableTestHandler
transport http.RoundTripper
}

func NewConfigurableTestServer(h http.Handler) *ConfigurableTestServer {
handler := NewConfigurableTestHandler(h)
server := httptest.NewServer(handler)

t := server.Client().Transport
cts := &ConfigurableTestServer{
handler: handler,
Server: server,
transport: t,
}
server.Client().Transport = cts
return cts
}

func (t *ConfigurableTestServer) RoundTrip(r *http.Request) (*http.Response, error) {
config, ok := r.Context().Value(contextConfigKey).([]map[string]any)
if ok && config != nil {
r = t.handler.UseConfigValues(r, config...)
}
return t.transport.RoundTrip(r)
}

type AutoContextClient struct {
*http.Client
transport http.RoundTripper
ctx context.Context
}

func (t *ConfigurableTestServer) Client(ctx context.Context) *AutoContextClient {
baseClient := *t.Server.Client()
autoClient := &AutoContextClient{
Client: &baseClient,
transport: t,
ctx: ctx,
}
baseClient.Transport = autoClient
return autoClient
}

func (c *AutoContextClient) RoundTrip(r *http.Request) (*http.Response, error) {
return c.transport.RoundTrip(r.WithContext(c.ctx))
}
Loading

0 comments on commit e0001b0

Please sign in to comment.