Skip to content

Commit

Permalink
feat: link oidc credentials when login (#3563)
Browse files Browse the repository at this point in the history
When user tries to login with OIDC for the first time but has already registered before with email/password a credentials identifier conflict may be detected by Kratos. In this case user needs to login with email/password first and then link OIDC credentials on a settings screen.
This PR simplifies UX and allows user to link OIDC credentials to existing account right in the login flow, without
switching to settings flow.

Closes #2727
Closes #3222
  • Loading branch information
hperl authored Nov 8, 2023
1 parent 3b75f37 commit b784949
Show file tree
Hide file tree
Showing 33 changed files with 905 additions and 176 deletions.
5 changes: 5 additions & 0 deletions cmd/clidoc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,13 @@ func init() {
"NewInfoLoginTOTPLabel": text.NewInfoLoginTOTPLabel(),
"NewInfoLoginLookupLabel": text.NewInfoLoginLookupLabel(),
"NewInfoLogin": text.NewInfoLogin(),
"NewInfoLoginAndLink": text.NewInfoLoginAndLink(),
"NewInfoLoginLinkMessage": text.NewInfoLoginLinkMessage("{duplicteIdentifier}", "{provider}", "{newLoginUrl}"),
"NewInfoLoginTOTP": text.NewInfoLoginTOTP(),
"NewInfoLoginLookup": text.NewInfoLoginLookup(),
"NewInfoLoginVerify": text.NewInfoLoginVerify(),
"NewInfoLoginWith": text.NewInfoLoginWith("{provider}"),
"NewInfoLoginWithAndLink": text.NewInfoLoginWithAndLink("{provider}"),
"NewErrorValidationLoginFlowExpired": text.NewErrorValidationLoginFlowExpired(aSecondAgo),
"NewErrorValidationLoginNoStrategyFound": text.NewErrorValidationLoginNoStrategyFound(),
"NewErrorValidationRegistrationNoStrategyFound": text.NewErrorValidationRegistrationNoStrategyFound(),
Expand All @@ -144,6 +147,7 @@ func init() {
"NewErrorValidationRecoveryStateFailure": text.NewErrorValidationRecoveryStateFailure(),
"NewInfoNodeInputEmail": text.NewInfoNodeInputEmail(),
"NewInfoNodeResendOTP": text.NewInfoNodeResendOTP(),
"NewInfoNodeLoginAndLinkCredential": text.NewInfoNodeLoginAndLinkCredential(),
"NewInfoNodeLabelContinue": text.NewInfoNodeLabelContinue(),
"NewInfoSelfServiceSettingsRegisterWebAuthn": text.NewInfoSelfServiceSettingsRegisterWebAuthn(),
"NewInfoLoginWebAuthnPasswordless": text.NewInfoLoginWebAuthnPasswordless(),
Expand All @@ -163,6 +167,7 @@ func init() {
"NewInfoSelfServiceLoginCode": text.NewInfoSelfServiceLoginCode(),
"NewErrorValidationRegistrationRetrySuccessful": text.NewErrorValidationRegistrationRetrySuccessful(),
"NewInfoSelfServiceRegistrationRegisterCode": text.NewInfoSelfServiceRegistrationRegisterCode(),
"NewErrorValidationLoginLinkedCredentialsDoNotMatch": text.NewErrorValidationLoginLinkedCredentialsDoNotMatch(),
}
}

Expand Down
79 changes: 43 additions & 36 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,7 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption
return nil
}

func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identity) (err error) {
if !m.r.Config().SelfServiceFlowRegistrationLoginHints(ctx) {
return &ErrDuplicateCredentials{error: e}
}
// First we try to find the conflict in the identifiers table. This is most likely to have a conflict.
var found *Identity
func (m *Manager) ConflictingIdentity(ctx context.Context, i *Identity) (found *Identity, foundConflictAddress string, err error) {
for ct, cred := range i.Credentials {
for _, id := range cred.Identifiers {
found, _, err = m.r.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, ct, id)
Expand All @@ -125,53 +120,65 @@ func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identi

// FindByCredentialsIdentifier does not expand identity credentials.
if err = m.r.PrivilegedIdentityPool().HydrateIdentityAssociations(ctx, found, ExpandCredentials); err != nil {
return err
return nil, "", err
}

return found, id, nil
}
}

// If the conflict is not in the identifiers table, it is coming from the verifiable or recovery address.
var foundConflictAddress string
if found == nil {
for _, va := range i.VerifiableAddresses {
conflictingAddress, err := m.r.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, va.Via, va.Value)
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return err
}
for _, va := range i.VerifiableAddresses {
conflictingAddress, err := m.r.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, va.Via, va.Value)
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return nil, "", err
}

foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return err
}
foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return nil, "", err
}

return found, foundConflictAddress, nil
}

// Last option: check the recovery address
if found == nil {
for _, va := range i.RecoveryAddresses {
conflictingAddress, err := m.r.PrivilegedIdentityPool().FindRecoveryAddressByValue(ctx, va.Via, va.Value)
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return err
}
for _, va := range i.RecoveryAddresses {
conflictingAddress, err := m.r.PrivilegedIdentityPool().FindRecoveryAddressByValue(ctx, va.Via, va.Value)
if errors.Is(err, sqlcon.ErrNoRows) {
continue
} else if err != nil {
return nil, "", err
}

foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return err
}
foundConflictAddress = conflictingAddress.Value
found, err = m.r.PrivilegedIdentityPool().GetIdentity(ctx, conflictingAddress.IdentityID, ExpandCredentials)
if err != nil {
return nil, "", err
}

return found, foundConflictAddress, nil
}

// Still not found? Return generic error.
if found == nil {
return nil, "", sqlcon.ErrNoRows
}

func (m *Manager) findExistingAuthMethod(ctx context.Context, e error, i *Identity) (err error) {
if !m.r.Config().SelfServiceFlowRegistrationLoginHints(ctx) {
return &ErrDuplicateCredentials{error: e}
}

found, foundConflictAddress, err := m.ConflictingIdentity(ctx, i)
if err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
return &ErrDuplicateCredentials{error: e}
}
return err
}

// We need to sort the credentials for the error message to be deterministic.
var creds []Credentials
for _, cred := range found.Credentials {
Expand Down
63 changes: 63 additions & 0 deletions identity/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/ory/x/pointerx"
"github.com/ory/x/sqlcon"

"github.com/gofrs/uuid"

Expand Down Expand Up @@ -556,6 +557,68 @@ func TestManager(t *testing.T) {
checkExtensionFields(fromStore, "[email protected]")(t)
})
})

t.Run("method=ConflictingIdentity", func(t *testing.T) {
ctx := context.Background()

conflicOnIdentifier := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
conflicOnIdentifier.Traits = identity.Traits(`{"email":"[email protected]"}`)
require.NoError(t, reg.IdentityManager().Create(ctx, conflicOnIdentifier))

conflicOnVerifiableAddress := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
conflicOnVerifiableAddress.Traits = identity.Traits(`{"email":"[email protected]", "email_verify":"[email protected]"}`)
require.NoError(t, reg.IdentityManager().Create(ctx, conflicOnVerifiableAddress))

conflicOnRecoveryAddress := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
conflicOnRecoveryAddress.Traits = identity.Traits(`{"email":"[email protected]", "email_recovery":"[email protected]"}`)
require.NoError(t, reg.IdentityManager().Create(ctx, conflicOnRecoveryAddress))

t.Run("case=returns not found if no conflict", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {Identifiers: []string{"[email protected]"}},
},
})
assert.ErrorIs(t, err, sqlcon.ErrNoRows)
assert.Nil(t, found)
assert.Empty(t, foundConflictAddress)
})

t.Run("case=conflict on identifier", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {Identifiers: []string{"[email protected]"}},
},
})
require.NoError(t, err)
assert.Equal(t, conflicOnIdentifier.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
})

t.Run("case=conflict on verifiable address", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
VerifiableAddresses: []identity.VerifiableAddress{{
Value: "[email protected]",
Via: "email",
}},
})
require.NoError(t, err)
assert.Equal(t, conflicOnVerifiableAddress.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
})

t.Run("case=conflict on recovery address", func(t *testing.T) {
found, foundConflictAddress, err := reg.IdentityManager().ConflictingIdentity(ctx, &identity.Identity{
RecoveryAddresses: []identity.RecoveryAddress{{
Value: "[email protected]",
Via: "email",
}},
})
require.NoError(t, err)
assert.Equal(t, conflicOnRecoveryAddress.ID, found.ID)
assert.Equal(t, "[email protected]", foundConflictAddress)
})
})
}

func TestManagerNoDefaultNamedSchema(t *testing.T) {
Expand Down
10 changes: 10 additions & 0 deletions schema/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,13 @@ func NewLoginCodeInvalid() error {
Messages: new(text.Messages).Add(text.NewErrorValidationLoginCodeInvalidOrAlreadyUsed()),
})
}

func NewLinkedCredentialsDoNotMatch() error {
return errors.WithStack(&ValidationError{
ValidationError: &jsonschema.ValidationError{
Message: `linked credentials do not match; please start a new flow`,
InstancePtr: "#/",
},
Messages: new(text.Messages).Add(text.NewErrorValidationLoginLinkedCredentialsDoNotMatch()),
})
}
3 changes: 1 addition & 2 deletions selfservice/flow/continue_with.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ type ContinueWith any
// swagger:enum ContinueWithActionSetOrySessionToken
type ContinueWithActionSetOrySessionToken string

// #nosec G101 -- only a key constant
const (
ContinueWithActionSetOrySessionTokenString ContinueWithActionSetOrySessionToken = "set_ory_session_token"
ContinueWithActionSetOrySessionTokenString ContinueWithActionSetOrySessionToken = "set_ory_session_token" // #nosec G101 -- only a key constant
)

var _ ContinueWith = new(ContinueWithSetOrySessionToken)
Expand Down
61 changes: 61 additions & 0 deletions selfservice/flow/duplicate_credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package flow

import (
"encoding/json"

"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/ory/kratos/identity"
"github.com/ory/x/sqlxx"
)

const internalContextDuplicateCredentialsPath = "registration_duplicate_credentials"

type DuplicateCredentialsData struct {
CredentialsType identity.CredentialsType
CredentialsConfig sqlxx.JSONRawMessage
DuplicateIdentifier string
}

type InternalContexter interface {
EnsureInternalContext()
GetInternalContext() sqlxx.JSONRawMessage
SetInternalContext(sqlxx.JSONRawMessage)
}

// SetDuplicateCredentials sets the duplicate credentials data in the flow's internal context.
func SetDuplicateCredentials(flow InternalContexter, creds DuplicateCredentialsData) error {
if flow.GetInternalContext() == nil {
flow.EnsureInternalContext()
}
bytes, err := sjson.SetBytes(
flow.GetInternalContext(),
internalContextDuplicateCredentialsPath,
creds,
)
if err != nil {
return err
}
flow.SetInternalContext(bytes)

return nil
}

// DuplicateCredentials returns the duplicate credentials data from the flow's internal context.
func DuplicateCredentials(flow InternalContexter) (*DuplicateCredentialsData, error) {
if flow.GetInternalContext() == nil {
flow.EnsureInternalContext()
}
raw := gjson.GetBytes(flow.GetInternalContext(), internalContextDuplicateCredentialsPath)
if !raw.IsObject() {
return nil, nil
}
var creds DuplicateCredentialsData
err := json.Unmarshal([]byte(raw.Raw), &creds)

return &creds, err
}
4 changes: 1 addition & 3 deletions selfservice/flow/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@ import (
"net/http"
"net/url"

"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/herodot"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/ui/container"
"github.com/ory/kratos/x"

"github.com/gofrs/uuid"

"github.com/ory/x/urlx"
)

Expand Down
8 changes: 8 additions & 0 deletions selfservice/flow/login/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ func (f *Flow) EnsureInternalContext() {
}
}

func (f *Flow) GetInternalContext() sqlxx.JSONRawMessage {
return f.InternalContext
}

func (f *Flow) SetInternalContext(bytes sqlxx.JSONRawMessage) {
f.InternalContext = bytes
}

func (f Flow) MarshalJSON() ([]byte, error) {
type local Flow
f.SetReturnTo()
Expand Down
Loading

0 comments on commit b784949

Please sign in to comment.