Skip to content

Commit

Permalink
fix: don't show oidc subject in login hints (#4264)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas authored Jan 10, 2025
1 parent 23f3232 commit b95fd3f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 39 deletions.
62 changes: 29 additions & 33 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"reflect"
"slices"
"sort"
"strings"

"github.com/ory/kratos/schema"
"github.com/ory/x/sqlcon"
Expand Down Expand Up @@ -102,7 +103,7 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption
return nil
}

func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *Identity, foundConflictAddress string, err error) {
func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *Identity, foundConflictAddress string, conflictAddressType string, err error) {
for ct, cred := range i.Credentials {
for _, id := range cred.Identifiers {
found, _, err = m.r.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, ct, id)
Expand All @@ -112,10 +113,10 @@ func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *

// FindByCredentialsIdentifier does not expand identity credentials.
if err = m.r.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, found, ExpandCredentials); err != nil {
return nil, "", err
return nil, "", "", err
}

return found, id, nil
return found, id, ct.String(), nil
}
}

Expand All @@ -125,16 +126,16 @@ func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return nil, "", err
return nil, "", "", err
}

foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return nil, "", err
return nil, "", "", err
}

return found, foundConflictAddress, nil
return found, foundConflictAddress, va.Via, nil
}

// Last option: check the recovery address
Expand All @@ -143,27 +144,27 @@ func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return nil, "", err
return nil, "", "", err
}

foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return nil, "", err
return nil, "", "", err
}

return found, foundConflictAddress, nil
return found, foundConflictAddress, string(va.Via), nil
}

return nil, "", sqlcon.ErrNoRows
return nil, "", "", sqlcon.ErrNoRows
}

func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identity) (err error) {
if !m.r.Config().SelfServiceFlowRegistrationLoginHints(ctx) {
return &ErrDuplicateCredentials{error: e}
}

found, foundConflictAddress, err := m.ConflictingIdentity(ctx, i)
found, foundConflictAddress, conflictingAddressType, err := m.ConflictingIdentity(ctx, i)
if err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
return &ErrDuplicateCredentials{error: e}
Expand All @@ -181,6 +182,11 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi
})

duplicateCredErr := &ErrDuplicateCredentials{error: e}
// OIDC credentials are not email addresses but the sub claim from the OIDC provider.
// This is useless for the user, so in that case, we don't set the identifier hint.
if conflictingAddressType != CredentialsTypeOIDC.String() {
duplicateCredErr.SetIdentifierHint(strings.Trim(foundConflictAddress, " "))
}

for _, cred := range creds {
if cred.Config == nil {
Expand All @@ -192,11 +198,9 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi
// in to the first factor (obviously).
switch cred.Type {
case CredentialsTypePassword:
identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 {
duplicateCredErr.SetIdentifierHint(cred.Identifiers[0])
}
duplicateCredErr.SetIdentifierHint(identifierHint)

var cfg CredentialsPassword
if err := json.Unmarshal(cred.Config, &cfg); err != nil {
Expand All @@ -209,14 +213,7 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi
}

duplicateCredErr.AddCredentialsType(cred.Type)

case CredentialsTypeCodeAuth:
identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
}

duplicateCredErr.SetIdentifierHint(identifierHint)
duplicateCredErr.AddCredentialsType(cred.Type)
case CredentialsTypeOIDC:
var cfg CredentialsOIDC
Expand All @@ -230,23 +227,19 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi
}

duplicateCredErr.AddCredentialsType(cred.Type)
duplicateCredErr.SetIdentifierHint(foundConflictAddress)
duplicateCredErr.availableOIDCProviders = available
case CredentialsTypeWebAuthn:
var cfg CredentialsWebAuthnConfig
if err := json.Unmarshal(cred.Config, &cfg); err != nil {
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to JSON decode identity credentials %s for identity %s.", cred.Type, found.ID))
}

identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 {
duplicateCredErr.SetIdentifierHint(cred.Identifiers[0])
}

for _, webauthn := range cfg.Credentials {
if webauthn.IsPasswordless {
duplicateCredErr.AddCredentialsType(cred.Type)
duplicateCredErr.SetIdentifierHint(identifierHint)
break
}
}
Expand All @@ -256,15 +249,12 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi
return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to JSON decode identity credentials %s for identity %s.", cred.Type, found.ID))
}

identifierHint := foundConflictAddress
if len(cred.Identifiers) > 0 {
identifierHint = cred.Identifiers[0]
if duplicateCredErr.IdentifierHint() == "" && len(cred.Identifiers) == 1 {
duplicateCredErr.SetIdentifierHint(cred.Identifiers[0])
}

for _, webauthn := range cfg.Credentials {
if webauthn.IsPasswordless {
duplicateCredErr.AddCredentialsType(cred.Type)
duplicateCredErr.SetIdentifierHint(identifierHint)
break
}
}
Expand Down Expand Up @@ -343,6 +333,7 @@ func (e *CreateIdentitiesError) Error() string {
e.init()
return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities))
}

func (e *CreateIdentitiesError) Unwrap() []error {
e.init()
var errs []error
Expand All @@ -356,17 +347,20 @@ func (e *CreateIdentitiesError) AddFailedIdentity(ident *Identity, err *herodot.
e.init()
e.failedIdentities[ident] = err
}

func (e *CreateIdentitiesError) Merge(other *CreateIdentitiesError) {
e.init()
for k, v := range other.failedIdentities {
e.failedIdentities[k] = v
}
}

func (e *CreateIdentitiesError) Contains(ident *Identity) bool {
e.init()
_, found := e.failedIdentities[ident]
return found
}

func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {
e.init()
if err, found := e.failedIdentities[ident]; found {
Expand All @@ -375,12 +369,14 @@ func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {

return nil
}

func (e *CreateIdentitiesError) ErrOrNil() error {
if e == nil || len(e.failedIdentities) == 0 {
return nil
}
return e
}

func (e *CreateIdentitiesError) init() {
if e.failedIdentities == nil {
e.failedIdentities = map[*Identity]*herodot.DefaultError{}
Expand Down
14 changes: 9 additions & 5 deletions identity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ func TestManager(t *testing.T) {
assert.ErrorAs(t, err, &verr)
assert.ElementsMatch(t, []string{"oidc"}, verr.AvailableCredentials())
assert.ElementsMatch(t, []string{"google", "github"}, verr.AvailableOIDCProviders())
// The conflicting identifier is the oidc subject, which is not useful for the user
assert.Equal(t, email, verr.IdentifierHint())
})

Expand Down Expand Up @@ -756,29 +757,31 @@ func TestManager(t *testing.T) {
require.NoError(t, reg.IdentityManager().Create(ctx, conflicOnRecoveryAddress))

t.Run("case=returns not found if no conflict", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {Identifiers: []string{"[email protected]"}},
},
})
assert.ErrorIs(t, err, sqlcon.ErrNoRows)
assert.Nil(t, found)
assert.Empty(t, foundConflictAddress)
assert.Empty(t, addressType)
})

t.Run("case=conflict on identifier", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {Identifiers: []string{"[email protected]"}},
},
})
require.NoError(t, err)
assert.Equal(t, conflicOnIdentifier.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
assert.EqualValues(t, string(identity.CredentialsTypePassword), addressType)
})

t.Run("case=conflict on verifiable address", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
VerifiableAddresses: []identity.VerifiableAddress{{
Value: "[email protected]",
Via: "email",
Expand All @@ -787,10 +790,10 @@ func TestManager(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, conflicOnVerifiableAddress.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
assert.Equal(t, "email", addressType)
})

t.Run("case=conflict on recovery address", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
found, foundConflictAddress, addressType, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
RecoveryAddresses: []identity.RecoveryAddress{{
Value: "[email protected]",
Via: "email",
Expand All @@ -799,6 +802,7 @@ func TestManager(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, conflicOnRecoveryAddress.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
assert.Equal(t, "email", addressType)
})
})
}
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/registration/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
}

func (e *HookExecutor) getDuplicateIdentifier(ctx context.Context, i *identity.Identity) (string, error) {
_, id, err := e.d.IdentityManager().ConflictingIdentity(ctx, i)
_, id, _, err := e.d.IdentityManager().ConflictingIdentity(ctx, i)
if err != nil {
return "", err
}
Expand Down

0 comments on commit b95fd3f

Please sign in to comment.