Skip to content

Commit

Permalink
feat: fast add credential type lookups (#4177)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr authored Oct 29, 2024
1 parent 8e29b68 commit eeb1355
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 39 deletions.
6 changes: 6 additions & 0 deletions identity/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ type (

// FindIdentityByWebauthnUserHandle returns an identity matching a webauthn user handle.
FindIdentityByWebauthnUserHandle(ctx context.Context, userHandle []byte) (*Identity, error)

// FindIdentityCredentialsTypeByID returns the credentials type by its id.
FindIdentityCredentialsTypeByID(ctx context.Context, id uuid.UUID) (CredentialsType, error)

// FindIdentityCredentialsTypeByName returns the credentials type by its name.
FindIdentityCredentialsTypeByName(ctx context.Context, ct CredentialsType) (uuid.UUID, error)
}
)

Expand Down
21 changes: 21 additions & 0 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -1211,6 +1211,27 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
})
})

t.Run("suite=credential-types", func(t *testing.T) {
for _, ct := range identity.AllCredentialTypes {
t.Run("type="+ct.String(), func(t *testing.T) {
id, err := p.FindIdentityCredentialsTypeByName(ctx, ct)
require.NoError(t, err)

require.NotEqual(t, uuid.Nil, id)
name, err := p.FindIdentityCredentialsTypeByID(ctx, id)
require.NoError(t, err)

assert.Equal(t, ct, name)
})
}

_, err := p.FindIdentityCredentialsTypeByName(ctx, "unknown")
require.Error(t, err)

_, err = p.FindIdentityCredentialsTypeByID(ctx, x.NewUUID())
require.Error(t, err)
})

t.Run("suite=recovery-address", func(t *testing.T) {
createIdentityWithAddresses := func(t *testing.T, email string) *identity.Identity {
var i identity.Identity
Expand Down
103 changes: 64 additions & 39 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"fmt"
"sort"
"strings"
"sync"
"time"

"github.com/ory/kratos/x/events"
Expand Down Expand Up @@ -61,12 +60,17 @@ type IdentityPersister struct {
r dependencies
c *pop.Connection
nid uuid.UUID

credentialTypesID *x.SyncMap[uuid.UUID, identity.CredentialsType]
credentialTypesName *x.SyncMap[identity.CredentialsType, uuid.UUID]
}

func NewPersister(r dependencies, c *pop.Connection) *IdentityPersister {
return &IdentityPersister{
c: c,
r: r,
c: c,
r: r,
credentialTypesID: x.NewSyncMap[uuid.UUID, identity.CredentialsType](),
credentialTypesName: x.NewSyncMap[identity.CredentialsType, uuid.UUID](),
}
}

Expand Down Expand Up @@ -282,36 +286,6 @@ LIMIT 1`, jsonPath, jsonPath),
return &id, nil
}

var credentialsTypes = struct {
sync.RWMutex
m map[identity.CredentialsType]*identity.CredentialsTypeTable
}{
m: map[identity.CredentialsType]*identity.CredentialsTypeTable{},
}

func (p *IdentityPersister) findIdentityCredentialsType(ctx context.Context, ct identity.CredentialsType) (_ *identity.CredentialsTypeTable, err error) {
credentialsTypes.RLock()
v, ok := credentialsTypes.m[ct]
credentialsTypes.RUnlock()

if ok && v != nil {
return v, nil
}

ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findIdentityCredentialsType")
defer otelx.End(span, &err)

var m identity.CredentialsTypeTable
if err := p.GetConnection(ctx).Where("name = ?", ct).First(&m); err != nil {
return nil, sqlcon.HandleError(err)
}
credentialsTypes.Lock()
credentialsTypes.m[ct] = &m
credentialsTypes.Unlock()

return &m, nil
}

func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn *pop.Connection, identities ...*identity.Identity) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createIdentityCredentials",
trace.WithAttributes(
Expand Down Expand Up @@ -339,7 +313,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn
cred.Config = sqlxx.JSONRawMessage("{}")
}

ct, err := p.findIdentityCredentialsType(ctx, cred.Type)
ct, err := p.FindIdentityCredentialsTypeByName(ctx, cred.Type)
if err != nil {
return err
}
Expand All @@ -350,7 +324,7 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn
}
cred.IdentityID = ident.ID
cred.NID = nid
cred.IdentityCredentialTypeID = ct.ID
cred.IdentityCredentialTypeID = ct
credentials = append(credentials, &cred)

ident.Credentials[k] = cred
Expand All @@ -370,15 +344,15 @@ func (p *IdentityPersister) createIdentityCredentials(ctx context.Context, conn
"Unable to create identity credentials with missing or empty identifier."))
}

ct, err := p.findIdentityCredentialsType(ctx, cred.Type)
ct, err := p.FindIdentityCredentialsTypeByName(ctx, cred.Type)
if err != nil {
return err
}

identifiers = append(identifiers, &identity.CredentialIdentifier{
Identifier: identifier,
IdentityCredentialsID: cred.ID,
IdentityCredentialsTypeID: ct.ID,
IdentityCredentialsTypeID: ct,
NID: p.NetworkID(ctx),
})
}
Expand Down Expand Up @@ -883,11 +857,11 @@ func (p *IdentityPersister) getCredentialTypeIDs(ctx context.Context, credential
result := map[identity.CredentialsType]uuid.UUID{}

for _, ct := range credentialTypes {
typeID, err := p.findIdentityCredentialsType(ctx, ct)
typeID, err := p.FindIdentityCredentialsTypeByName(ctx, ct)
if err != nil {
return nil, err
}
result[ct] = typeID.ID
result[ct] = typeID
}

return result, nil
Expand Down Expand Up @@ -1340,3 +1314,54 @@ func (p *IdentityPersister) InjectTraitsSchemaURL(ctx context.Context, i *identi
i.SchemaURL = s.SchemaURL(p.r.Config().SelfPublicURL(ctx)).String()
return nil
}

func (p *IdentityPersister) FindIdentityCredentialsTypeByID(ctx context.Context, id uuid.UUID) (identity.CredentialsType, error) {
result, found := p.credentialTypesID.Load(id)
if !found {
if err := p.loadCredentialTypes(ctx); err != nil {
return "", err
}

result, found = p.credentialTypesID.Load(id)
}

if !found {
return "", errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type for id %q. This is a bug in the code.", id))
}

return result, nil
}

func (p *IdentityPersister) FindIdentityCredentialsTypeByName(ctx context.Context, ct identity.CredentialsType) (uuid.UUID, error) {
result, found := p.credentialTypesName.Load(ct)
if !found {
if err := p.loadCredentialTypes(ctx); err != nil {
return uuid.Nil, err
}

result, found = p.credentialTypesName.Load(ct)
}

if !found {
return uuid.Nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type for nane %s. This is a bug in the code.", ct))
}

return result, nil
}

func (p *IdentityPersister) loadCredentialTypes(ctx context.Context) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.identity.loadCredentialTypes")
defer otelx.End(span, &err)

var tt []identity.CredentialsTypeTable
if err := p.GetConnection(ctx).All(&tt); err != nil {
return sqlcon.HandleError(err)
}

for _, t := range tt {
p.credentialTypesID.Store(t.ID, t.Name)
p.credentialTypesName.Store(t.Name, t.ID)
}

return nil
}
66 changes: 66 additions & 0 deletions x/sync_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package x

import (
"sync"
)

// SyncMap provides a thread-safe map with generic keys and values
type SyncMap[K comparable, V any] struct {
mu sync.RWMutex
data map[K]V
}

// NewSyncMap initializes a new SyncMap instance
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
return &SyncMap[K, V]{
data: make(map[K]V),
}
}

// Load retrieves a value for a key. It returns the value and a boolean indicating if the key exists.
func (m *SyncMap[K, V]) Load(key K) (V, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
val, ok := m.data[key]
return val, ok
}

// Store sets a value for a key, replacing any existing value.
func (m *SyncMap[K, V]) Store(key K, value V) {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = value
}

// LoadOrStore retrieves the existing value for a key if it exists, or stores and returns a given value if it doesn't.
func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (V, bool) {
m.mu.Lock()
defer m.mu.Unlock()
if existing, ok := m.data[key]; ok {
return existing, true
}
m.data[key] = value
return value, false
}

// Delete removes a key-value pair from the map.
func (m *SyncMap[K, V]) Delete(key K) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
}

// Range iterates over all entries in the map, calling the provided function for each key-value pair.
// If the function returns false, the iteration stops.
func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) {
m.mu.RLock()
defer m.mu.RUnlock()
for k, v := range m.data {
if !f(k, v) {
break
}
}
}
98 changes: 98 additions & 0 deletions x/sync_map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package x

import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSyncMapStoreAndLoad(t *testing.T) {
m := NewSyncMap[int, string]()

m.Store(1, "one")

// Test Load for an existing key
val, ok := m.Load(1)
require.True(t, ok, "Expected key 1 to exist")
assert.Equal(t, "one", val, "Expected value 'one' for key 1")

// Test Load for a non-existing key
_, ok = m.Load(2)
assert.False(t, ok, "Expected key 2 to be absent")
}

func TestSyncMapLoadOrStore(t *testing.T) {
m := NewSyncMap[int, string]()

// Store a new key-value pair
val, loaded := m.LoadOrStore(1, "one")
require.False(t, loaded, "Expected key 1 to be newly stored")
assert.Equal(t, "one", val, "Expected value 'one' for key 1 after LoadOrStore")

// Attempt to store a new value for an existing key
val, loaded = m.LoadOrStore(1, "uno")
require.True(t, loaded, "Expected key 1 to already exist")
assert.Equal(t, "one", val, "Expected existing value 'one' for key 1")
}

func TestSyncMapDelete(t *testing.T) {
m := NewSyncMap[int, string]()

m.Store(1, "one")
m.Delete(1)

_, ok := m.Load(1)
assert.False(t, ok, "Expected key 1 to be deleted")
}

func TestSyncMapRange(t *testing.T) {
m := NewSyncMap[int, string]()

m.Store(1, "one")
m.Store(2, "two")
m.Store(3, "three")

expected := map[int]string{
1: "one",
2: "two",
3: "three",
}

m.Range(func(key int, value string) bool {
expectedVal, exists := expected[key]
require.True(t, exists, "Unexpected key found in map")
assert.Equal(t, expectedVal, value, "Unexpected value for key %d", key)
delete(expected, key)
return true
})

assert.Empty(t, expected, "Not all entries were iterated over")
}

func TestSyncMapConcurrentAccess(t *testing.T) {
m := NewSyncMap[int, int]()
var wg sync.WaitGroup

// Run multiple goroutines to test concurrent access
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
m.Store(i, i)
}(i)
}

wg.Wait()

// Verify all stored values
for i := 0; i < 100; i++ {
val, ok := m.Load(i)
require.True(t, ok, "Expected key %d to exist", i)
assert.Equal(t, i, val, "Expected value %d for key %d", i, i)
}
}

0 comments on commit eeb1355

Please sign in to comment.