Skip to content

Commit

Permalink
fix: return credentials in FindByCredentialsIdentifier (#4068)
Browse files Browse the repository at this point in the history
Instead of re-fetching the credentials later (expensive), we load them only once.
  • Loading branch information
aeneasr authored Aug 30, 2024
1 parent dbf7274 commit f949173
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
18 changes: 13 additions & 5 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,8 +865,12 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
// assert.EqualValues(t, expected.Credentials[CredentialsTypePassword].CreatedAt.Unix(), creds.CreatedAt.Unix())
// assert.EqualValues(t, expected.Credentials[CredentialsTypePassword].UpdatedAt.Unix(), creds.UpdatedAt.Unix())

expected.Credentials = nil
assertEqual(t, expected, actual)
require.Equal(t, expected.Traits, actual.Traits)
require.Equal(t, expected.ID, actual.ID)
require.NotNil(t, actual.Credentials[identity.CredentialsTypePassword])
assert.EqualValues(t, expected.Credentials[identity.CredentialsTypePassword].ID, actual.Credentials[identity.CredentialsTypePassword].ID)
assert.EqualValues(t, expected.Credentials[identity.CredentialsTypePassword].Identifiers, actual.Credentials[identity.CredentialsTypePassword].Identifiers)
assert.JSONEq(t, string(expected.Credentials[identity.CredentialsTypePassword].Config), string(actual.Credentials[identity.CredentialsTypePassword].Config))

t.Run("not if on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
Expand Down Expand Up @@ -1030,8 +1034,12 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
assert.EqualValues(t, []string{strings.ToLower(identifier)}, creds.Identifiers)
assert.JSONEq(t, string(expected.Credentials[identity.CredentialsTypePassword].Config), string(creds.Config))

expected.Credentials = nil
assertEqual(t, expected, actual)
require.Equal(t, expected.Traits, actual.Traits)
require.Equal(t, expected.ID, actual.ID)
require.NotNil(t, actual.Credentials[identity.CredentialsTypePassword])
assert.EqualValues(t, expected.Credentials[identity.CredentialsTypePassword].ID, actual.Credentials[identity.CredentialsTypePassword].ID)
assert.EqualValues(t, []string{strings.ToLower(identifier)}, actual.Credentials[identity.CredentialsTypePassword].Identifiers)
assert.JSONEq(t, string(expected.Credentials[identity.CredentialsTypePassword].Config), string(actual.Credentials[identity.CredentialsTypePassword].Config))

t.Run("not if on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
Expand Down Expand Up @@ -1354,7 +1362,7 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
i, c, err := p.FindByCredentialsIdentifier(ctx, m[0].Name, "nid1")
require.NoError(t, err)
assert.Equal(t, "nid1", c.Identifiers[0])
require.Len(t, i.Credentials, 0)
require.Len(t, i.Credentials, 1)

_, _, err = p.FindByCredentialsIdentifier(ctx, m[0].Name, "nid2")
require.ErrorIs(t, err, sqlcon.ErrNoRows)
Expand Down
2 changes: 1 addition & 1 deletion persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (p *IdentityPersister) FindByCredentialsIdentifier(ctx context.Context, ct
return nil, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The SQL adapter failed to return the appropriate credentials_type \"%s\". This is a bug in the code.", ct))
}

return i.CopyWithoutCredentials(), creds, nil
return i, creds, nil
}

func (p *IdentityPersister) FindIdentityByWebauthnUserHandle(ctx context.Context, userHandle []byte) (_ *identity.Identity, err error) {
Expand Down
4 changes: 2 additions & 2 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func (e *HookExecutor) PostLoginHook(
s.Token = ""

// If we detect that whoami would require a higher AAL, we redirect!
if _, err := e.requiresAAL2(r, s, f); err != nil {
if _, err := e.requiresAAL2(r, classified, f); err != nil {
if aalErr := new(session.ErrAALNotSatisfied); errors.As(err, &aalErr) {
span.SetAttributes(attribute.String("return_to", aalErr.RedirectTo), attribute.String("redirect_reason", "requires aal2"))
e.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(aalErr.RedirectTo))
Expand Down Expand Up @@ -303,7 +303,7 @@ func (e *HookExecutor) PostLoginHook(
}

// If we detect that whoami would require a higher AAL, we redirect!
if _, err := e.requiresAAL2(r, s, f); err != nil {
if _, err := e.requiresAAL2(r, classified, f); err != nil {
if aalErr := new(session.ErrAALNotSatisfied); errors.As(err, &aalErr) {
http.Redirect(w, r, aalErr.RedirectTo, http.StatusSeeOther)
return nil
Expand Down
8 changes: 4 additions & 4 deletions session/manager_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,10 @@ func (s *ManagerHTTP) ActivateSession(r *http.Request, session *Session, i *iden
return errors.WithStack(ErrIdentityDisabled.WithDetail("identity_id", i.ID))
}

if err := s.r.IdentityManager().RefreshAvailableAAL(ctx, i); err != nil {
return err
}

session.Identity = i
session.IdentityID = i.ID

Expand All @@ -454,10 +458,6 @@ func (s *ManagerHTTP) ActivateSession(r *http.Request, session *Session, i *iden
session.SetSessionDeviceInformation(r.WithContext(ctx))
session.SetAuthenticatorAssuranceLevel()

if err := s.r.IdentityManager().RefreshAvailableAAL(ctx, session.Identity); err != nil {
return err
}

span.SetAttributes(
attribute.String("identity.available_aal", session.Identity.InternalAvailableAAL.String),
)
Expand Down

0 comments on commit f949173

Please sign in to comment.