From eeb13552118504f17b48f2c7e002e777f5ee73f4 Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:24:53 +0100 Subject: [PATCH] feat: fast add credential type lookups (#4177) --- identity/pool.go | 6 + identity/test/pool.go | 21 ++++ .../sql/identity/persister_identity.go | 103 +++++++++++------- x/sync_map.go | 66 +++++++++++ x/sync_map_test.go | 98 +++++++++++++++++ 5 files changed, 255 insertions(+), 39 deletions(-) create mode 100644 x/sync_map.go create mode 100644 x/sync_map_test.go diff --git a/identity/pool.go b/identity/pool.go index f57cd87ca475..340b73a9b98d 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -115,6 +115,12 @@ type ( // FindIdentityByWebauthnUserHandle returns an identity matching a webauthn user handle. FindIdentityByWebauthnUserHandle(ctx context.Context, userHandle []byte) (*Identity, error) + + // FindIdentityCredentialsTypeByID returns the credentials type by its id. + FindIdentityCredentialsTypeByID(ctx context.Context, id uuid.UUID) (CredentialsType, error) + + // FindIdentityCredentialsTypeByName returns the credentials type by its name. + FindIdentityCredentialsTypeByName(ctx context.Context, ct CredentialsType) (uuid.UUID, error) } ) diff --git a/identity/test/pool.go b/identity/test/pool.go index 4d9f4c440910..14b41f4bf4d8 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -1211,6 +1211,27 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, }) }) + t.Run("suite=credential-types", func(t *testing.T) { + for _, ct := range identity.AllCredentialTypes { + t.Run("type="+ct.String(), func(t *testing.T) { + id, err := p.FindIdentityCredentialsTypeByName(ctx, ct) + require.NoError(t, err) + + require.NotEqual(t, uuid.Nil, id) + name, err := p.FindIdentityCredentialsTypeByID(ctx, id) + require.NoError(t, err) + + assert.Equal(t, ct, name) + }) + } + + _, err := p.FindIdentityCredentialsTypeByName(ctx, "unknown") + require.Error(t, err) + + _, err = p.FindIdentityCredentialsTypeByID(ctx, x.NewUUID()) + require.Error(t, err) + }) + t.Run("suite=recovery-address", func(t *testing.T) { createIdentityWithAddresses := func(t *testing.T, email string) *identity.Identity { var i identity.Identity diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index ed46af2924ea..026c40687946 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -10,7 +10,6 @@ import ( "fmt" "sort" "strings" - "sync" "time" "github.com/ory/kratos/x/events" @@ -61,12 +60,17 @@ type IdentityPersister struct { r dependencies c *pop.Connection nid uuid.UUID + + credentialTypesID *x.SyncMap[uuid.UUID, identity.CredentialsType] + credentialTypesName *x.SyncMap[identity.CredentialsType, uuid.UUID] } func NewPersister(r dependencies, c *pop.Connection) *IdentityPersister { return &IdentityPersister{ - c: c, - r: r, + c: c, + r: r, + credentialTypesID: x.NewSyncMap[uuid.UUID, identity.CredentialsType](), + credentialTypesName: x.NewSyncMap[identity.CredentialsType, uuid.UUID](), } } @@ -282,36 +286,6 @@ LIMIT 1`, jsonPath, jsonPath), return &id, nil } -var credentialsTypes = struct { - sync.RWMutex - m map[identity.CredentialsType]*identity.CredentialsTypeTable -}{ - m: map[identity.CredentialsType]*identity.CredentialsTypeTable{}, -} - -func (p *IdentityPersister) findIdentityCredentialsType(ctx context.Context, ct identity.CredentialsType) (_ *identity.CredentialsTypeTable, err error) { - credentialsTypes.RLock() - v, ok := credentialsTypes.m[ct] - credentialsTypes.RUnlock() - - if ok && v != nil { - return v, nil - } - - ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findIdentityCredentialsType") - defer otelx.End(span, &err) - - var m identity.CredentialsTypeTable - if err := p.GetConnection(ctx).Where("name = ?", ct).First(&m); err != nil { - return nil, sqlcon.HandleError(err) - } - credentialsTypes.Lock() - credentialsTypes.m[ct] = &m - credentialsTypes.Unlock() - - return &m, nil -} - func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn *pop.Connection, identities ...*identity.Identity) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createIdentityCredentials", trace.WithAttributes( @@ -339,7 +313,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn cred.Config = sqlxx.JSONRawMessage("{}") } - ct, err := p.findIdentityCredentialsType(ctx, cred.Type) + ct, err := p.FindIdentityCredentialsTypeByName(ctx, cred.Type) if err != nil { return err } @@ -350,7 +324,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn } cred.IdentityID = ident.ID cred.NID = nid - cred.IdentityCredentialTypeID = ct.ID + cred.IdentityCredentialTypeID = ct credentials = append(credentials, &cred) ident.Credentials[k] = cred @@ -370,7 +344,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn "Unable to create identity credentials with missing or empty identifier.")) } - ct, err := p.findIdentityCredentialsType(ctx, cred.Type) + ct, err := p.FindIdentityCredentialsTypeByName(ctx, cred.Type) if err != nil { return err } @@ -378,7 +352,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn identifiers = append(identifiers, &identity.CredentialIdentifier{ Identifier: identifier, IdentityCredentialsID: cred.ID, - IdentityCredentialsTypeID: ct.ID, + IdentityCredentialsTypeID: ct, NID: p.NetworkID(ctx), }) } @@ -883,11 +857,11 @@ func (p *IdentityPersister) getCredentialTypeIDs(ctx context.Context, credential result := map[identity.CredentialsType]uuid.UUID{} for _, ct := range credentialTypes { - typeID, err := p.findIdentityCredentialsType(ctx, ct) + typeID, err := p.FindIdentityCredentialsTypeByName(ctx, ct) if err != nil { return nil, err } - result[ct] = typeID.ID + result[ct] = typeID } return result, nil @@ -1340,3 +1314,54 @@ func (p *IdentityPersister) InjectTraitsSchemaURL(ctx context.Context, i *identi i.SchemaURL = s.SchemaURL(p.r.Config().SelfPublicURL(ctx)).String() return nil } + +func (p *IdentityPersister) FindIdentityCredentialsTypeByID(ctx context.Context, id uuid.UUID) (identity.CredentialsType, error) { + result, found := p.credentialTypesID.Load(id) + if !found { + if err := p.loadCredentialTypes(ctx); err != nil { + return "", err + } + + result, found = p.credentialTypesID.Load(id) + } + + if !found { + return "", errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type for id %q. This is a bug in the code.", id)) + } + + return result, nil +} + +func (p *IdentityPersister) FindIdentityCredentialsTypeByName(ctx context.Context, ct identity.CredentialsType) (uuid.UUID, error) { + result, found := p.credentialTypesName.Load(ct) + if !found { + if err := p.loadCredentialTypes(ctx); err != nil { + return uuid.Nil, err + } + + result, found = p.credentialTypesName.Load(ct) + } + + if !found { + return uuid.Nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type for nane %s. This is a bug in the code.", ct)) + } + + return result, nil +} + +func (p *IdentityPersister) loadCredentialTypes(ctx context.Context) (err error) { + ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.identity.loadCredentialTypes") + defer otelx.End(span, &err) + + var tt []identity.CredentialsTypeTable + if err := p.GetConnection(ctx).All(&tt); err != nil { + return sqlcon.HandleError(err) + } + + for _, t := range tt { + p.credentialTypesID.Store(t.ID, t.Name) + p.credentialTypesName.Store(t.Name, t.ID) + } + + return nil +} diff --git a/x/sync_map.go b/x/sync_map.go new file mode 100644 index 000000000000..4a343e932eb6 --- /dev/null +++ b/x/sync_map.go @@ -0,0 +1,66 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package x + +import ( + "sync" +) + +// SyncMap provides a thread-safe map with generic keys and values +type SyncMap[K comparable, V any] struct { + mu sync.RWMutex + data map[K]V +} + +// NewSyncMap initializes a new SyncMap instance +func NewSyncMap[K comparable, V any]() *SyncMap[K, V] { + return &SyncMap[K, V]{ + data: make(map[K]V), + } +} + +// Load retrieves a value for a key. It returns the value and a boolean indicating if the key exists. +func (m *SyncMap[K, V]) Load(key K) (V, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + val, ok := m.data[key] + return val, ok +} + +// Store sets a value for a key, replacing any existing value. +func (m *SyncMap[K, V]) Store(key K, value V) { + m.mu.Lock() + defer m.mu.Unlock() + m.data[key] = value +} + +// LoadOrStore retrieves the existing value for a key if it exists, or stores and returns a given value if it doesn't. +func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (V, bool) { + m.mu.Lock() + defer m.mu.Unlock() + if existing, ok := m.data[key]; ok { + return existing, true + } + m.data[key] = value + return value, false +} + +// Delete removes a key-value pair from the map. +func (m *SyncMap[K, V]) Delete(key K) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.data, key) +} + +// Range iterates over all entries in the map, calling the provided function for each key-value pair. +// If the function returns false, the iteration stops. +func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + for k, v := range m.data { + if !f(k, v) { + break + } + } +} diff --git a/x/sync_map_test.go b/x/sync_map_test.go new file mode 100644 index 000000000000..6da97c63c801 --- /dev/null +++ b/x/sync_map_test.go @@ -0,0 +1,98 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package x + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSyncMapStoreAndLoad(t *testing.T) { + m := NewSyncMap[int, string]() + + m.Store(1, "one") + + // Test Load for an existing key + val, ok := m.Load(1) + require.True(t, ok, "Expected key 1 to exist") + assert.Equal(t, "one", val, "Expected value 'one' for key 1") + + // Test Load for a non-existing key + _, ok = m.Load(2) + assert.False(t, ok, "Expected key 2 to be absent") +} + +func TestSyncMapLoadOrStore(t *testing.T) { + m := NewSyncMap[int, string]() + + // Store a new key-value pair + val, loaded := m.LoadOrStore(1, "one") + require.False(t, loaded, "Expected key 1 to be newly stored") + assert.Equal(t, "one", val, "Expected value 'one' for key 1 after LoadOrStore") + + // Attempt to store a new value for an existing key + val, loaded = m.LoadOrStore(1, "uno") + require.True(t, loaded, "Expected key 1 to already exist") + assert.Equal(t, "one", val, "Expected existing value 'one' for key 1") +} + +func TestSyncMapDelete(t *testing.T) { + m := NewSyncMap[int, string]() + + m.Store(1, "one") + m.Delete(1) + + _, ok := m.Load(1) + assert.False(t, ok, "Expected key 1 to be deleted") +} + +func TestSyncMapRange(t *testing.T) { + m := NewSyncMap[int, string]() + + m.Store(1, "one") + m.Store(2, "two") + m.Store(3, "three") + + expected := map[int]string{ + 1: "one", + 2: "two", + 3: "three", + } + + m.Range(func(key int, value string) bool { + expectedVal, exists := expected[key] + require.True(t, exists, "Unexpected key found in map") + assert.Equal(t, expectedVal, value, "Unexpected value for key %d", key) + delete(expected, key) + return true + }) + + assert.Empty(t, expected, "Not all entries were iterated over") +} + +func TestSyncMapConcurrentAccess(t *testing.T) { + m := NewSyncMap[int, int]() + var wg sync.WaitGroup + + // Run multiple goroutines to test concurrent access + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + m.Store(i, i) + }(i) + } + + wg.Wait() + + // Verify all stored values + for i := 0; i < 100; i++ { + val, ok := m.Load(i) + require.True(t, ok, "Expected key %d to exist", i) + assert.Equal(t, i, val, "Expected value %d for key %d", i, i) + } +}