Skip to content

Commit

Permalink
fix: account linking with 2FA (#4188)
Browse files Browse the repository at this point in the history
This fixes some edge cases with OIDC account linking for accounts with 2FA enabled.
  • Loading branch information
hperl authored Nov 7, 2024
1 parent 215af57 commit 4a870a6
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 36 deletions.
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
2 changes: 0 additions & 2 deletions selfservice/flow/duplicate_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package flow
import (
"encoding/json"

"github.com/gofrs/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

Expand All @@ -20,7 +19,6 @@ type DuplicateCredentialsData struct {
CredentialsType identity.CredentialsType
CredentialsConfig sqlxx.JSONRawMessage
DuplicateIdentifier string
OrganizationID uuid.UUID
}

type InternalContexter interface {
Expand Down
7 changes: 2 additions & 5 deletions selfservice/flow/login/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ import (

"github.com/tidwall/gjson"

"github.com/ory/x/jsonx"
"github.com/ory/x/sqlxx"
"github.com/ory/x/uuidx"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/x/jsonx"
"github.com/ory/x/sqlxx"

"github.com/ory/kratos/internal"

Expand Down Expand Up @@ -225,7 +223,6 @@ func TestDuplicateCredentials(t *testing.T) {
CredentialsType: "foo",
CredentialsConfig: sqlxx.JSONRawMessage(`{"bar":"baz"}`),
DuplicateIdentifier: "bar",
OrganizationID: uuidx.NewV4(),
}

require.NoError(t, flow.SetDuplicateCredentials(f, dc))
Expand Down
20 changes: 19 additions & 1 deletion selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ func WithFlowReturnTo(returnTo string) FlowOption {
}
}

func WithOrganizationID(organizationID uuid.NullUUID) FlowOption {
return func(f *Flow) {
f.OrganizationID = organizationID
}
}

func WithRequestedAAL(aal identity.AuthenticatorAssuranceLevel) FlowOption {
return func(f *Flow) {
f.RequestedAAL = aal
}
}

func WithInternalContext(internalContext []byte) FlowOption {
return func(f *Flow) {
f.InternalContext = internalContext
Expand Down Expand Up @@ -217,7 +229,13 @@ preLoginHook:

if orgID.Valid {
f.OrganizationID = orgID
strategyFilters = []StrategyFilter{func(s Strategy) bool { return s.ID() == identity.CredentialsTypeOIDC }}
if f.RequestedAAL == identity.AuthenticatorAssuranceLevel1 {
// We only apply the filter on AAL1, because the OIDC strategy can only satsify
// AAL1.
strategyFilters = []StrategyFilter{func(s Strategy) bool {
return s.ID() == identity.CredentialsTypeOIDC
}}
}
}

for _, s := range h.d.LoginStrategies(r.Context(), strategyFilters...) {
Expand Down
55 changes: 49 additions & 6 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"time"

"github.com/gofrs/uuid"
Expand Down Expand Up @@ -55,6 +56,7 @@ type (
x.LoggingProvider
x.TracingProvider
sessiontokenexchange.PersistenceProvider
HandlerProvider

FlowPersistenceProvider
HooksProvider
Expand Down Expand Up @@ -273,8 +275,28 @@ func (e *HookExecutor) PostLoginHook(
// If we detect that whoami would require a higher AAL, we redirect!
if err := e.checkAAL(ctx, classified, f); err != nil {
if aalErr := new(session.ErrAALNotSatisfied); errors.As(err, &aalErr) {
span.SetAttributes(attribute.String("return_to", aalErr.RedirectTo), attribute.String("redirect_reason", "requires aal2"))
e.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(aalErr.RedirectTo))
if data, _ := flow.DuplicateCredentials(f); data == nil {
span.SetAttributes(attribute.String("return_to", aalErr.RedirectTo), attribute.String("redirect_reason", "requires aal2"))
e.d.Writer().WriteError(w, r, flow.NewBrowserLocationChangeRequiredError(aalErr.RedirectTo))
return nil
}

// Special case: If we are in a flow that wants to link credentials, we create a
// new login flow here that asks for the require AAL, but also copies over the
// internal context and the organization ID.
r.URL, err = url.Parse(aalErr.RedirectTo)
if err != nil {
return errors.WithStack(err)
}
newFlow, _, err := e.d.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser,
WithInternalContext(f.InternalContext),
WithOrganizationID(f.OrganizationID),
)
if err != nil {
return errors.WithStack(err)
}

x.AcceptToRedirectOrJSON(w, r, e.d.Writer(), newFlow, newFlow.AppendTo(e.d.Config().SelfServiceFlowLoginUI(ctx)).String())
return nil
}
return err
Expand Down Expand Up @@ -309,7 +331,27 @@ func (e *HookExecutor) PostLoginHook(
// If we detect that whoami would require a higher AAL, we redirect!
if err := e.checkAAL(ctx, classified, f); err != nil {
if aalErr := new(session.ErrAALNotSatisfied); errors.As(err, &aalErr) {
http.Redirect(w, r, aalErr.RedirectTo, http.StatusSeeOther)
if data, _ := flow.DuplicateCredentials(f); data == nil {
http.Redirect(w, r, aalErr.RedirectTo, http.StatusSeeOther)
return nil
}

// Special case: If we are in a flow that wants to link credentials, we create a
// new login flow here that asks for the require AAL, but also copies over the
// internal context and the organization ID.
r.URL, err = url.Parse(aalErr.RedirectTo)
if err != nil {
return errors.WithStack(err)
}
newFlow, _, err := e.d.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser,
WithInternalContext(f.InternalContext),
WithOrganizationID(f.OrganizationID),
)
if err != nil {
return errors.WithStack(err)
}

x.AcceptToRedirectOrJSON(w, r, e.d.Writer(), newFlow, newFlow.AppendTo(e.d.Config().SelfServiceFlowLoginUI(ctx)).String())
return nil
}
return errors.WithStack(err)
Expand Down Expand Up @@ -362,7 +404,7 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return nil
}

if err := e.checkDuplicateCredentialsIdentifierMatch(ctx, ident.ID, lc.DuplicateIdentifier); err != nil {
if err = e.checkDuplicateCredentialsIdentifierMatch(ctx, ident.ID, lc.DuplicateIdentifier); err != nil {
return err
}
strategy, err := e.d.AllLoginStrategies().Strategy(lc.CredentialsType)
Expand All @@ -380,8 +422,9 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return err
}

method := strategy.CompletedAuthenticationMethod(ctx)
sess.CompletedLoginForMethod(method)
if err = linkableStrategy.CompletedLogin(sess, lc); err != nil {
return err
}

return nil
}
Expand Down
48 changes: 39 additions & 9 deletions selfservice/flow/login/hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"
"time"

"github.com/gofrs/uuid"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -27,6 +28,7 @@ import (
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/selfservice/flow/login"
"github.com/ory/kratos/session"
"github.com/ory/kratos/ui/node"
"github.com/ory/kratos/x"
)

Expand All @@ -42,6 +44,7 @@ func TestLoginExecutor(t *testing.T) {
reg.WithHydra(hydra.NewFake())
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/login.schema.json")
conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh/")
_ = testhelpers.NewLoginUIFlowEchoServer(t, reg)

newServer := func(t *testing.T, ft flow.Type, useIdentity *identity.Identity, flowCallback ...func(*login.Flow)) *httptest.Server {
router := httprouter.New()
Expand Down Expand Up @@ -222,7 +225,6 @@ func TestLoginExecutor(t *testing.T) {

t.Run("case=work normally if AAL is satisfied", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, "aal1")
_ = testhelpers.NewLoginUIFlowEchoServer(t, reg)
t.Cleanup(testhelpers.SelfServiceHookConfigReset(t, conf))

useIdentity := &identity.Identity{Credentials: map[identity.CredentialsType]identity.Credentials{
Expand Down Expand Up @@ -255,7 +257,6 @@ func TestLoginExecutor(t *testing.T) {

t.Run("case=redirect to login if AAL is too low", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, "highest_available")
_ = testhelpers.NewLoginUIFlowEchoServer(t, reg)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, "aal1")
})
Expand Down Expand Up @@ -320,6 +321,7 @@ func TestLoginExecutor(t *testing.T) {
t.Run("case=maybe links credential", func(t *testing.T) {
t.Cleanup(testhelpers.SelfServiceHookConfigReset(t, conf))
conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, config.HighestAvailableAAL)
conf.MustSet(ctx, "selfservice.methods.totp.enabled", true)

email1, email2 := testhelpers.RandomEmail(), testhelpers.RandomEmail()
passwordOnlyIdentity := &identity.Identity{Credentials: map[identity.CredentialsType]identity.Credentials{
Expand Down Expand Up @@ -360,15 +362,43 @@ func TestLoginExecutor(t *testing.T) {
require.NoError(t, err)

t.Run("sub-case=does not link after first factor when second factor is available", func(t *testing.T) {
duplicateCredentialsData := flow.DuplicateCredentialsData{
CredentialsType: identity.CredentialsTypeOIDC,
CredentialsConfig: credsOIDC2FA.Config,
DuplicateIdentifier: email2,
}
ts := newServer(t, flow.TypeBrowser, twoFAIdentitiy, func(l *login.Flow) {
require.NoError(t, flow.SetDuplicateCredentials(l, flow.DuplicateCredentialsData{
CredentialsType: identity.CredentialsTypeOIDC,
CredentialsConfig: credsOIDC2FA.Config,
DuplicateIdentifier: email2,
}))
require.NoError(t, flow.SetDuplicateCredentials(l, duplicateCredentialsData))
})
res, body := makeRequestPost(t, ts, false, url.Values{})
assert.Equal(t, res.Request.URL.String(), ts.URL+login.RouteInitBrowserFlow+"?aal=aal2", "%s", body)
res, _ := makeRequestPost(t, ts, false, url.Values{})

assert.Equal(t, reg.Config().SelfServiceFlowLoginUI(ctx).Host, res.Request.URL.Host)
assert.Equal(t, reg.Config().SelfServiceFlowLoginUI(ctx).Path, res.Request.URL.Path)
newFlowID := res.Request.URL.Query().Get("flow")
assert.NotEmpty(t, newFlowID)

newFlow, err := reg.LoginFlowPersister().GetLoginFlow(ctx, uuid.Must(uuid.FromString(newFlowID)))
require.NoError(t, err)
newFlowDuplicateCredentialsData, err := flow.DuplicateCredentials(newFlow)
require.NoError(t, err)

// Duplicate credentials data should have been copied over
assert.Equal(t, duplicateCredentialsData.CredentialsType, newFlowDuplicateCredentialsData.CredentialsType)
assert.Equal(t, duplicateCredentialsData.DuplicateIdentifier, newFlowDuplicateCredentialsData.DuplicateIdentifier)
assert.JSONEq(t, string(duplicateCredentialsData.CredentialsConfig), string(newFlowDuplicateCredentialsData.CredentialsConfig))

// AAL should be AAL2
assert.Equal(t, identity.AuthenticatorAssuranceLevel2, newFlow.RequestedAAL)

// TOTP nodes should be present
found := false
for _, n := range newFlow.UI.Nodes {
if n.Group == node.TOTPGroup {
found = true
break
}
}
assert.True(t, found, "could not find TOTP nodes in %+v", newFlow.UI.Nodes)

ident, err := reg.Persister().GetIdentity(ctx, twoFAIdentitiy.ID, identity.ExpandCredentials)
require.NoError(t, err)
Expand Down
3 changes: 3 additions & 0 deletions selfservice/flow/login/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"

"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/session"
"github.com/ory/kratos/ui/node"
"github.com/ory/kratos/x"
Expand All @@ -28,6 +29,8 @@ type Strategies []Strategy

type LinkableStrategy interface {
Link(ctx context.Context, i *identity.Identity, credentials sqlxx.JSONRawMessage) error
CompletedLogin(sess *session.Session, data *flow.DuplicateCredentialsData) error
SetDuplicateCredentials(f flow.InternalContexter, duplicateIdentifier string, credentials identity.Credentials, provider string) error
}

func (s Strategies) Strategy(id identity.CredentialsType) (Strategy, error) {
Expand Down
17 changes: 7 additions & 10 deletions selfservice/flow/registration/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,21 +161,18 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
return err
}

if _, ok := strategy.(login.LinkableStrategy); ok {
if strategy, ok := strategy.(login.LinkableStrategy); ok {
duplicateIdentifier, err := e.getDuplicateIdentifier(ctx, i)
if err != nil {
return err
}
registrationDuplicateCredentials := flow.DuplicateCredentialsData{
CredentialsType: ct,
CredentialsConfig: i.Credentials[ct].Config,
DuplicateIdentifier: duplicateIdentifier,
}
if registrationFlow.OrganizationID.Valid {
registrationDuplicateCredentials.OrganizationID = registrationFlow.OrganizationID.UUID
}

if err := flow.SetDuplicateCredentials(registrationFlow, registrationDuplicateCredentials); err != nil {
if err := strategy.SetDuplicateCredentials(
registrationFlow,
duplicateIdentifier,
i.Credentials[ct],
provider,
); err != nil {
return err
}
}
Expand Down
1 change: 1 addition & 0 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ func (s *Strategy) processLogin(ctx context.Context, w http.ResponseWriter, r *h

sess := session.NewInactiveSession()
sess.CompletedLoginForWithProvider(s.ID(), identity.AuthenticatorAssuranceLevel1, provider.Config().ID, provider.Config().OrganizationID)

for _, c := range oidcCredentials.Providers {
if c.Subject == claims.Subject && c.Provider == provider.Config().ID {
if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID); err != nil {
Expand Down
47 changes: 46 additions & 1 deletion selfservice/strategy/oidc/strategy_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
_ "embed"
"encoding/json"
"fmt"
"net/http"
"time"

Expand Down Expand Up @@ -527,7 +528,7 @@ func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsCo
return err
}
if len(credentialsOIDCConfig.Providers) != 1 {
return errors.New("No oidc provider was set")
return errors.New("no oidc provider was set")
}
credentialsOIDCProvider := credentialsOIDCConfig.Providers[0]

Expand All @@ -550,3 +551,47 @@ func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsCo

return nil
}

func (s *Strategy) CompletedLogin(sess *session.Session, data *flow.DuplicateCredentialsData) error {
var credentialsOIDCConfig identity.CredentialsOIDC
if err := json.Unmarshal(data.CredentialsConfig, &credentialsOIDCConfig); err != nil {
return err
}
if len(credentialsOIDCConfig.Providers) != 1 {
return errors.New("no oidc provider was set")
}
credentialsOIDCProvider := credentialsOIDCConfig.Providers[0]

sess.CompletedLoginForWithProvider(
s.ID(),
identity.AuthenticatorAssuranceLevel1,
credentialsOIDCProvider.Provider,
credentialsOIDCProvider.Organization,
)

return nil
}

func (s *Strategy) SetDuplicateCredentials(f flow.InternalContexter, duplicateIdentifier string, credentials identity.Credentials, provider string) error {
var credentialsOIDCConfig identity.CredentialsOIDC
if err := json.Unmarshal(credentials.Config, &credentialsOIDCConfig); err != nil {
return err
}

// We want to only set the provider in the credentials config that was used to authenticate the user.
for _, p := range credentialsOIDCConfig.Providers {
if p.Provider == provider {
credentialsOIDCConfig.Providers = []identity.CredentialsOIDCProvider{p}
config, err := json.Marshal(credentialsOIDCConfig)
if err != nil {
return err
}
return flow.SetDuplicateCredentials(f, flow.DuplicateCredentialsData{
CredentialsType: s.ID(),
CredentialsConfig: config,
DuplicateIdentifier: duplicateIdentifier,
})
}
}
return fmt.Errorf("provider %q not found in credentials", provider)
}
Loading

0 comments on commit 4a870a6

Please sign in to comment.