From 0fdfcb52eac1213484af389e03f73bb9e34172cf Mon Sep 17 00:00:00 2001 From: Jonas Hungershausen Date: Sat, 6 Jan 2024 13:59:46 +0100 Subject: [PATCH] feat: support MFA via SMS --- .../phone-password/identity.schema.json | 25 ++++++- courier/sms_templates.go | 6 ++ .../login_code/valid/sms.body.gotmpl | 1 + courier/template/sms/login_code_valid.go | 52 ++++++++++++++ driver/registry_default.go | 7 +- driver/registry_default_test.go | 12 +++- embedx/identity_extension.schema.json | 2 +- identity/credentials_code.go | 2 +- identity/extension_credentials.go | 2 + internal/testhelpers/selfservice_login.go | 4 +- selfservice/flow/login/handler.go | 2 +- selfservice/flow/login/hook.go | 2 +- selfservice/flow/login/strategy.go | 2 +- selfservice/flow/registration/handler_test.go | 3 +- selfservice/flow/request.go | 15 +--- selfservice/flow/request_test.go | 71 ------------------- selfservice/strategy/code/code_sender.go | 28 +++++--- selfservice/strategy/code/strategy.go | 26 +++++-- selfservice/strategy/code/strategy_login.go | 39 ++++++---- .../strategy/code/strategy_login_test.go | 27 +++---- .../code/strategy_registration_test.go | 5 +- selfservice/strategy/lookup/strategy.go | 8 ++- selfservice/strategy/oidc/strategy.go | 2 +- selfservice/strategy/password/strategy.go | 2 +- selfservice/strategy/totp/strategy.go | 10 +-- selfservice/strategy/webauthn/strategy.go | 10 +-- .../strategy/webauthn/strategy_test.go | 4 +- 27 files changed, 203 insertions(+), 166 deletions(-) create mode 100644 courier/template/courier/builtin/templates/login_code/valid/sms.body.gotmpl create mode 100644 courier/template/sms/login_code_valid.go diff --git a/contrib/quickstart/kratos/phone-password/identity.schema.json b/contrib/quickstart/kratos/phone-password/identity.schema.json index 0d986aeb672e..7f757b0a1690 100644 --- a/contrib/quickstart/kratos/phone-password/identity.schema.json +++ b/contrib/quickstart/kratos/phone-password/identity.schema.json @@ -7,6 +7,26 @@ "traits": { "type": "object", "properties": { + "email": { + "type": "string", + "format": "email", + "title": "E-mail", + "minLength": 3, + "ory.sh/kratos": { + "credentials": { + "password": { + "identifier": true + }, + "code": { + "identifier": true, + "via": "email" + } + }, + "verification": { + "via": "email" + } + } + }, "phone": { "type": "string", "format": "tel", @@ -16,6 +36,9 @@ "credentials": { "password": { "identifier": true + }, + "code": { + "identifier": true } }, "verification": { @@ -24,7 +47,7 @@ } } }, - "required": ["phone"], + "required": ["email", "phone"], "additionalProperties": false } } diff --git a/courier/sms_templates.go b/courier/sms_templates.go index 683ba7d98ca3..b560542d53c9 100644 --- a/courier/sms_templates.go +++ b/courier/sms_templates.go @@ -34,6 +34,12 @@ func NewSMSTemplateFromMessage(d template.Dependencies, m Message) (SMSTemplate, return nil, err } return sms.NewTestStub(d, &t), nil + case template.TypeLoginCodeValid: + var t sms.LoginCodeValidModel + if err := json.Unmarshal(m.TemplateData, &t); err != nil { + return nil, err + } + return sms.NewLoginCodeValid(d, &t), nil default: return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType) diff --git a/courier/template/courier/builtin/templates/login_code/valid/sms.body.gotmpl b/courier/template/courier/builtin/templates/login_code/valid/sms.body.gotmpl new file mode 100644 index 000000000000..5b88dde9a382 --- /dev/null +++ b/courier/template/courier/builtin/templates/login_code/valid/sms.body.gotmpl @@ -0,0 +1 @@ +Your login code is: {{ .LoginCode }} diff --git a/courier/template/sms/login_code_valid.go b/courier/template/sms/login_code_valid.go new file mode 100644 index 000000000000..c72189bc5f3d --- /dev/null +++ b/courier/template/sms/login_code_valid.go @@ -0,0 +1,52 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package sms + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + LoginCodeValid struct { + deps template.Dependencies + model *LoginCodeValidModel + } + LoginCodeValidModel struct { + To string + LoginCode string + Identity map[string]interface{} + } +) + +func NewLoginCodeValid(d template.Dependencies, m *LoginCodeValidModel) *LoginCodeValid { + return &LoginCodeValid{deps: d, model: m} +} + +func (t *LoginCodeValid) PhoneNumber() (string, error) { + return t.model.To, nil +} + +func (t *LoginCodeValid) SMSBody(ctx context.Context) (string, error) { + return template.LoadText( + ctx, + t.deps, + os.DirFS(t.deps.CourierConfig().CourierTemplatesRoot(ctx)), + "login_code/valid/sms.body.gotmpl", // TODO:!!! + "login_code/valid/sms.body*", + t.model, + "", + ) +} + +func (t *LoginCodeValid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.model) +} + +func (t *LoginCodeValid) TemplateType() template.TemplateType { + return template.TypeLoginCodeValid +} diff --git a/driver/registry_default.go b/driver/registry_default.go index f9c47bb30739..9e5568b3ee1b 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -353,12 +353,7 @@ func (m *RegistryDefault) strategyRegistrationEnabled(ctx context.Context, id st } func (m *RegistryDefault) strategyLoginEnabled(ctx context.Context, id string) bool { - switch id { - case identity.CredentialsTypeCodeAuth.String(): - return m.Config().SelfServiceCodeStrategy(ctx).PasswordlessEnabled - default: - return m.Config().SelfServiceStrategy(ctx, id).Enabled - } + return m.Config().SelfServiceStrategy(ctx, id).Enabled } func (m *RegistryDefault) RegistrationStrategies(ctx context.Context, filters ...registration.StrategyFilter) (registrationStrategies registration.Strategies) { diff --git a/driver/registry_default_test.go b/driver/registry_default_test.go index 533cdd621a9a..729ee41879e1 100644 --- a/driver/registry_default_test.go +++ b/driver/registry_default_test.go @@ -674,39 +674,49 @@ func TestDriverDefault_Strategies(t *testing.T) { t.Run("case=login", func(t *testing.T) { t.Parallel() for k, tc := range []struct { + name string prep func(conf *config.Config) expect []string }{ { + name: "no strategies", prep: func(conf *config.Config) { conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", false) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) }, }, { + name: "only password", prep: func(conf *config.Config) { conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) }, expect: []string{"password"}, }, { + name: "oidc and password", prep: func(conf *config.Config) { conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true) conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) }, expect: []string{"password", "oidc"}, }, { + name: "oidc, password and totp", prep: func(conf *config.Config) { conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true) conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".totp.enabled", true) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false) }, expect: []string{"password", "oidc", "totp"}, }, { + name: "password and code", prep: func(conf *config.Config) { conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true) + conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true) }, expect: []string{"password", "code"}, }, diff --git a/embedx/identity_extension.schema.json b/embedx/identity_extension.schema.json index c105cbe42f95..6ad7765ef839 100644 --- a/embedx/identity_extension.schema.json +++ b/embedx/identity_extension.schema.json @@ -48,7 +48,7 @@ }, "via": { "type": "string", - "enum": ["email"] + "enum": ["email", "sms"] } } } diff --git a/identity/credentials_code.go b/identity/credentials_code.go index abc9decbbf46..b7f828ae09a1 100644 --- a/identity/credentials_code.go +++ b/identity/credentials_code.go @@ -7,7 +7,7 @@ import ( "database/sql" ) -type CodeAddressType string +type CodeAddressType = string const ( CodeAddressTypeEmail CodeAddressType = AddressTypeEmail diff --git a/identity/extension_credentials.go b/identity/extension_credentials.go index 3baa826b2e9c..709a47cb769c 100644 --- a/identity/extension_credentials.go +++ b/identity/extension_credentials.go @@ -70,6 +70,8 @@ func (r *SchemaExtensionCredentials) Run(ctx jsonschema.ValidationContext, s sch // } // r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressTypePhone) + case f.AddCase(""): + // continue default: return ctx.Error("", "credentials.code.via has unknown value %q", s.Credentials.Code.Via) } diff --git a/internal/testhelpers/selfservice_login.go b/internal/testhelpers/selfservice_login.go index bedba03fee16..ec8f14a53bf5 100644 --- a/internal/testhelpers/selfservice_login.go +++ b/internal/testhelpers/selfservice_login.go @@ -143,12 +143,12 @@ func InitializeLoginFlowViaBrowser(t *testing.T, client *http.Client, ts *httpte flowID = gjson.GetBytes(body, "id").String() } - rs, _, err := publicClient.FrontendApi.GetLoginFlow(context.Background()).Id(flowID).Execute() + rs, r, err := publicClient.FrontendApi.GetLoginFlow(context.Background()).Id(flowID).Execute() if expectGetError { require.Error(t, err) require.Nil(t, rs) } else { - require.NoError(t, err) + require.NoError(t, err, "%s", ioutilx.MustReadAll(r.Body)) assert.Empty(t, rs.Active) } diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index 096a657c0869..95fe1e321b30 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -798,7 +798,7 @@ continueLogin: sess = session.NewInactiveSession() } - method := ss.CompletedAuthenticationMethod(r.Context()) + method := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR) sess.CompletedLoginForMethod(method) i = interim break diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index 2af0c3daf1fc..e34ad32c2433 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -356,7 +356,7 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S return err } - method := strategy.CompletedAuthenticationMethod(ctx) + method := strategy.CompletedAuthenticationMethod(ctx, sess.AMR) sess.CompletedLoginForMethod(method) return nil diff --git a/selfservice/flow/login/strategy.go b/selfservice/flow/login/strategy.go index c8ad84986a55..4b802a0ae0cf 100644 --- a/selfservice/flow/login/strategy.go +++ b/selfservice/flow/login/strategy.go @@ -23,7 +23,7 @@ type Strategy interface { RegisterLoginRoutes(*x.RouterPublic) PopulateLoginMethod(r *http.Request, requestedAAL identity.AuthenticatorAssuranceLevel, sr *Flow) error Login(w http.ResponseWriter, r *http.Request, f *Flow, identityID uuid.UUID) (i *identity.Identity, err error) - CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod + CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods) session.AuthenticationMethod } type Strategies []Strategy diff --git a/selfservice/flow/registration/handler_test.go b/selfservice/flow/registration/handler_test.go index eae66fb720f8..a9e7b842718c 100644 --- a/selfservice/flow/registration/handler_test.go +++ b/selfservice/flow/registration/handler_test.go @@ -21,6 +21,7 @@ import ( "github.com/ory/kratos/corpx" "github.com/ory/kratos/hydra" + "github.com/ory/x/ioutilx" "github.com/ory/x/urlx" "github.com/stretchr/testify/assert" @@ -479,7 +480,7 @@ func TestOIDCStrategyOrder(t *testing.T) { resp, err := client.Do(req) require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, http.StatusOK, resp.StatusCode, "%s", ioutilx.MustReadAll(resp.Body)) verifiableAddress, err := reg.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, identity.VerifiableAddressTypeEmail, email) require.NoError(t, err) diff --git a/selfservice/flow/request.go b/selfservice/flow/request.go index 6db872f43199..bb140eae24a2 100644 --- a/selfservice/flow/request.go +++ b/selfservice/flow/request.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/ory/kratos/driver/config" - "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/strategy" "github.com/ory/x/decoderx" @@ -106,19 +105,7 @@ func MethodEnabledAndAllowed(ctx context.Context, flowName FlowName, expected, a return errors.WithStack(ErrStrategyNotResponsible) } - var ok bool - if strings.EqualFold(actual, identity.CredentialsTypeCodeAuth.String()) { - switch flowName { - case RegistrationFlow, LoginFlow: - ok = d.Config().SelfServiceCodeStrategy(ctx).PasswordlessEnabled - case VerificationFlow, RecoveryFlow: - ok = d.Config().SelfServiceCodeStrategy(ctx).Enabled - } - } else { - ok = d.Config().SelfServiceStrategy(ctx, expected).Enabled - } - - if !ok { + if !d.Config().SelfServiceStrategy(ctx, expected).Enabled { return errors.WithStack(herodot.ErrNotFound.WithReason(strategy.EndpointDisabledMessage)) } diff --git a/selfservice/flow/request_test.go b/selfservice/flow/request_test.go index adc47f6149c9..3728748f8044 100644 --- a/selfservice/flow/request_test.go +++ b/selfservice/flow/request_test.go @@ -101,74 +101,3 @@ func TestMethodEnabledAndAllowed(t *testing.T) { assert.Contains(t, string(body), "The requested resource could not be found") }) } - -func TestMethodCodeEnabledAndAllowed(t *testing.T) { - ctx := context.Background() - conf, d := internal.NewFastRegistryWithMocks(t) - - currentFlow := make(chan flow.FlowName, 1) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - f := <-currentFlow - if err := flow.MethodEnabledAndAllowedFromRequest(r, f, "code", d); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusNoContent) - })) - - t.Run("login code allowed", func(t *testing.T) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true) - currentFlow <- flow.LoginFlow - res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - assert.Equal(t, http.StatusNoContent, res.StatusCode) - }) - - t.Run("login code not allowed", func(t *testing.T) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", false) - currentFlow <- flow.LoginFlow - res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) - require.NoError(t, err) - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - assert.Contains(t, string(body), "The requested resource could not be found") - }) - - t.Run("registration code allowed", func(t *testing.T) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true) - currentFlow <- flow.RegistrationFlow - res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - assert.Equal(t, http.StatusNoContent, res.StatusCode) - }) - - t.Run("registration code not allowed", func(t *testing.T) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", false) - currentFlow <- flow.RegistrationFlow - res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) - require.NoError(t, err) - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - assert.Contains(t, string(body), "The requested resource could not be found") - }) - - t.Run("recovery and verification should still be allowed if registration and login is disabled", func(t *testing.T) { - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", false) - conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true) - - for _, f := range []flow.FlowName{flow.RecoveryFlow, flow.VerificationFlow} { - currentFlow <- f - res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}}) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - assert.Equal(t, http.StatusNoContent, res.StatusCode) - } - }) -} diff --git a/selfservice/strategy/code/code_sender.go b/selfservice/strategy/code/code_sender.go index d41beef463ed..ccc3330f9ba8 100644 --- a/selfservice/strategy/code/code_sender.go +++ b/selfservice/strategy/code/code_sender.go @@ -84,7 +84,7 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit code, err := s.deps. RegistrationCodePersister(). CreateRegistrationCode(ctx, &CreateRegistrationCodeParams{ - AddressType: address.Via, + AddressType: identity.CodeAddressType(address.Via), RawCode: rawCode, ExpiresIn: s.deps.Config().SelfServiceCodeMethodLifespan(ctx), FlowID: f.GetID(), @@ -118,7 +118,7 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit code, err := s.deps. LoginCodePersister(). CreateLoginCode(ctx, &CreateLoginCodeParams{ - AddressType: address.Via, + AddressType: identity.CodeAddressType(address.Via), Address: address.To, RawCode: rawCode, ExpiresIn: s.deps.Config().SelfServiceCodeMethodLifespan(ctx), @@ -133,19 +133,29 @@ func (s *Sender) SendCode(ctx context.Context, f flow.Flow, id *identity.Identit if err != nil { return err } - - emailModel := email.LoginCodeValidModel{ - To: address.To, - LoginCode: rawCode, - Identity: model, - } s.deps.Audit(). WithField("login_flow_id", code.FlowID). WithField("login_code_id", code.ID). WithSensitiveField("login_code", rawCode). Info("Sending out login email with code.") - if err := s.send(ctx, string(address.Via), email.NewLoginCodeValid(s.deps, &emailModel)); err != nil { + var t courier.Template + switch address.Via { + case identity.ChannelTypeEmail: + t = email.NewLoginCodeValid(s.deps, &email.LoginCodeValidModel{ + To: address.To, + LoginCode: rawCode, + Identity: model, + }) + case identity.ChannelTypeSMS: + t = sms.NewLoginCodeValid(s.deps, &sms.LoginCodeValidModel{ + To: address.To, + LoginCode: rawCode, + Identity: model, + }) + } + + if err := s.send(ctx, string(address.Via), t); err != nil { return errors.WithStack(err) } diff --git a/selfservice/strategy/code/strategy.go b/selfservice/strategy/code/strategy.go index a5d0c83703ec..de6cf64fb290 100644 --- a/selfservice/strategy/code/strategy.go +++ b/selfservice/strategy/code/strategy.go @@ -132,6 +132,15 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { } func (s *Strategy) PopulateMethod(r *http.Request, f flow.Flow) error { + switch f.GetFlowName() { + case flow.LoginFlow: + fallthrough + case flow.RegistrationFlow: + if !s.deps.Config().SelfServiceCodeStrategy(r.Context()).PasswordlessEnabled { + return nil + } + } + if string(f.GetState()) == "" { f.SetState(flow.StateChooseMethod) } @@ -153,13 +162,20 @@ func (s *Strategy) PopulateMethod(r *http.Request, f flow.Flow) error { if err != nil { return err } - - identifierLabel, err := login.GetIdentifierLabelFromSchema(r.Context(), ds.String()) - if err != nil { - return err + lf, ok := f.(*login.Flow) + if !ok { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected login.Flow but got: %T", f)) } + if lf.RequestedAAL == identity.AuthenticatorAssuranceLevel2 { + nodes.Upsert(node.NewInputField("identifier", "", node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeLabelID())) + } else { + identifierLabel, err := login.GetIdentifierLabelFromSchema(r.Context(), ds.String()) + if err != nil { + return err + } - nodes.Upsert(node.NewInputField("identifier", "", node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).WithMetaLabel(identifierLabel)) + nodes.Upsert(node.NewInputField("identifier", "", node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).WithMetaLabel(identifierLabel)) + } } else if f.GetFlowName() == flow.RegistrationFlow { ds, err := s.deps.Config().DefaultIdentityTraitsSchemaURL(r.Context()) if err != nil { diff --git a/selfservice/strategy/code/strategy_login.go b/selfservice/strategy/code/strategy_login.go index 1a0170d36648..3bf5f1890c37 100644 --- a/selfservice/strategy/code/strategy_login.go +++ b/selfservice/strategy/code/strategy_login.go @@ -17,6 +17,8 @@ import ( "github.com/ory/herodot" "github.com/ory/x/otelx" + "github.com/samber/lo" + "github.com/ory/kratos/identity" "github.com/ory/kratos/schema" "github.com/ory/kratos/selfservice/flow" @@ -60,7 +62,16 @@ type updateLoginFlowWithCodeMethod struct { func (s *Strategy) RegisterLoginRoutes(*x.RouterPublic) {} -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, amr session.AuthenticationMethods) session.AuthenticationMethod { + aal1Satisfied := lo.ContainsBy(amr, func(am session.AuthenticationMethod) bool { + return am.Method != identity.CredentialsTypeCodeAuth && am.AAL == identity.AuthenticatorAssuranceLevel1 + }) + if aal1Satisfied { + return session.AuthenticationMethod{ + Method: identity.CredentialsTypeCodeAuth, + AAL: identity.AuthenticatorAssuranceLevel2, + } + } return session.AuthenticationMethod{ Method: identity.CredentialsTypeCodeAuth, AAL: identity.AuthenticatorAssuranceLevel1, @@ -97,9 +108,6 @@ func (s *Strategy) HandleLoginError(r *http.Request, f *login.Flow, body *update } func (s *Strategy) PopulateLoginMethod(r *http.Request, requestedAAL identity.AuthenticatorAssuranceLevel, lf *login.Flow) error { - if requestedAAL > identity.AuthenticatorAssuranceLevel1 { - return nil - } return s.PopulateMethod(r, lf) } @@ -144,10 +152,6 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, err } - if err := login.CheckAAL(f, identity.AuthenticatorAssuranceLevel1); err != nil { - return nil, err - } - var p updateLoginFlowWithCodeMethod if err := s.dx.Decode(r, &p, decoderx.HTTPDecoderSetValidatePayloads(true), @@ -167,7 +171,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, switch f.GetState() { case flow.StateChooseMethod: - if err := s.loginSendEmail(ctx, w, r, f, &p); err != nil { + if err := s.loginSendCode(ctx, w, r, f, &p); err != nil { return nil, s.HandleLoginError(r, f, &p, err) } return nil, nil @@ -184,8 +188,8 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, return nil, s.HandleLoginError(r, f, &p, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unexpected flow state: %s", f.GetState()))) } -func (s *Strategy) loginSendEmail(ctx context.Context, w http.ResponseWriter, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod) (err error) { - ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginSendEmail") +func (s *Strategy) loginSendCode(ctx context.Context, w http.ResponseWriter, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod) (err error) { + ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginSendCode") defer otelx.End(span, &err) if len(p.Identifier) == 0 { @@ -205,10 +209,15 @@ func (s *Strategy) loginSendEmail(ctx context.Context, w http.ResponseWriter, r return errors.WithStack(err) } - addresses := []Address{{ - To: p.Identifier, - Via: identity.CodeAddressType(identity.AddressTypeEmail), - }} + matches := lo.Filter(i.VerifiableAddresses, func(va identity.VerifiableAddress, _ int) bool { + return va.Value == maybeNormalizeEmail(p.Identifier) + }) + addresses := lo.Map(matches, func(va identity.VerifiableAddress, _ int) Address { + return Address{ + To: va.Value, + Via: va.Via, + } + }) // kratos only supports `email` identifiers at the moment with the code method // this is validated in the identity validation step above diff --git a/selfservice/strategy/code/strategy_login_test.go b/selfservice/strategy/code/strategy_login_test.go index cda8ef3c05a9..093f12f1271c 100644 --- a/selfservice/strategy/code/strategy_login_test.go +++ b/selfservice/strategy/code/strategy_login_test.go @@ -12,6 +12,7 @@ import ( "net/url" "testing" + "github.com/ory/x/ioutilx" "github.com/ory/x/sqlcon" "github.com/ory/x/stringsx" @@ -33,7 +34,7 @@ func TestLoginCodeStrategy(t *testing.T) { ctx := context.Background() conf, reg := internal.NewFastRegistryWithMocks(t) testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") - conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth.String()), false) + conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth.String()), true) conf.MustSet(ctx, fmt.Sprintf("%s.%s.passwordless_enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth.String()), true) conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh") conf.MustSet(ctx, config.ViperKeyURLsAllowedReturnToDomains, []string{"https://www.ory.sh"}) @@ -173,13 +174,14 @@ func TestLoginCodeStrategy(t *testing.T) { resp, err = s.client.Do(req) require.NoError(t, err) require.EqualValues(t, http.StatusOK, resp.StatusCode) + body = string(ioutilx.MustReadAll(resp.Body)) } else { // SPAs need to be informed that the login has not yet completed using status 400. // Browser clients will redirect back to the login URL. if apiType == ApiTypeBrowser { - require.EqualValues(t, http.StatusOK, resp.StatusCode) + require.EqualValues(t, http.StatusOK, resp.StatusCode, "%s", body) } else { - require.EqualValues(t, http.StatusBadRequest, resp.StatusCode) + require.EqualValues(t, http.StatusBadRequest, resp.StatusCode, "%s", body) } } @@ -570,22 +572,13 @@ func TestLoginCodeStrategy(t *testing.T) { require.True(t, va.Verified) }) - t.Run("case=should not populate on 2FA request", func(t *testing.T) { + t.Run("case=should populate on 2FA request", func(t *testing.T) { if tc.apiType == ApiTypeNative { t.Skip("skipping test since it is not applicable to native clients") } ctx := context.Background() - // enable webauthn 2FA method - conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, "webauthn"), true) - conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, config.HighestAvailableAAL) - - t.Cleanup(func() { - conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, "webauthn"), false) - conf.MustSet(ctx, config.ViperKeySessionWhoAmIAAL, "aal1") - }) - s := createLoginFlow(ctx, t, public, tc.apiType, false) // submit email @@ -603,17 +596,13 @@ func TestLoginCodeStrategy(t *testing.T) { s = submitLogin(ctx, t, s, tc.apiType, func(v *url.Values) { v.Set("code", loginCode) }, false, func(t *testing.T, s *state, body string, res *http.Response) { - if tc.apiType == ApiTypeSPA { - require.EqualValues(t, http.StatusOK, res.StatusCode) - } else { - require.EqualValues(t, http.StatusOK, res.StatusCode) - } + require.EqualValues(t, http.StatusOK, res.StatusCode) }) clientInit := testhelpers.InitializeLoginFlowViaBrowser(t, s.client, public, false, tc.apiType == ApiTypeSPA, false, false, testhelpers.InitFlowWithAAL("aal2")) body, err := json.Marshal(clientInit) require.NoError(t, err) - require.Len(t, gjson.GetBytes(body, "ui.nodes.#(group==code)").Array(), 0, "should not populate code field on 2fa request") + require.Len(t, gjson.GetBytes(body, "ui.nodes.#(group==code)").Array(), 1) }) }) } diff --git a/selfservice/strategy/code/strategy_registration_test.go b/selfservice/strategy/code/strategy_registration_test.go index e591e627f5ef..1ca150b5f0ba 100644 --- a/selfservice/strategy/code/strategy_registration_test.go +++ b/selfservice/strategy/code/strategy_registration_test.go @@ -96,7 +96,7 @@ func TestRegistrationCodeStrategy(t *testing.T) { conf, reg := internal.NewFastRegistryWithMocks(t) testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/code.identity.schema.json") conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypePassword.String()), false) - conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth.String()), false) + conf.MustSet(ctx, fmt.Sprintf("%s.%s.enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth.String()), true) conf.MustSet(ctx, fmt.Sprintf("%s.%s.passwordless_enabled", config.ViperKeySelfServiceStrategyConfig, identity.CredentialsTypeCodeAuth), true) conf.MustSet(ctx, config.ViperKeySelfServiceBrowserDefaultReturnTo, "https://www.ory.sh") conf.MustSet(ctx, config.ViperKeyURLsAllowedReturnToDomains, []string{"https://www.ory.sh"}) @@ -176,6 +176,7 @@ func TestRegistrationCodeStrategy(t *testing.T) { submitAssertion(ctx, t, s, body, resp) return s } + t.Logf("%v", body) if apiType == ApiTypeBrowser { require.EqualValues(t, http.StatusOK, resp.StatusCode) @@ -532,7 +533,7 @@ func TestRegistrationCodeStrategy(t *testing.T) { require.Contains(t, gjson.GetBytes(body, "ui.messages").String(), "Could not find any login identifiers") } else { - require.Equal(t, http.StatusBadRequest, resp.StatusCode) + require.Equal(t, http.StatusBadRequest, resp.StatusCode, "%v", body) require.Contains(t, gjson.Get(body, "ui.messages").String(), "Could not find any login identifiers") } }) diff --git a/selfservice/strategy/lookup/strategy.go b/selfservice/strategy/lookup/strategy.go index 13c233913f18..e8f12cac9948 100644 --- a/selfservice/strategy/lookup/strategy.go +++ b/selfservice/strategy/lookup/strategy.go @@ -24,8 +24,10 @@ import ( ) // var _ login.Strategy = new(Strategy) -var _ settings.Strategy = new(Strategy) -var _ identity.ActiveCredentialsCounter = new(Strategy) +var ( + _ settings.Strategy = new(Strategy) + _ identity.ActiveCredentialsCounter = new(Strategy) +) type lookupStrategyDependencies interface { x.LoggingProvider @@ -104,7 +106,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.LookupGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel2, diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index e67a5826b57e..4f5848e8a431 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -663,7 +663,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.OpenIDConnectGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel1, diff --git a/selfservice/strategy/password/strategy.go b/selfservice/strategy/password/strategy.go index bb750bc9ef80..911ad619cd15 100644 --- a/selfservice/strategy/password/strategy.go +++ b/selfservice/strategy/password/strategy.go @@ -109,7 +109,7 @@ func (s *Strategy) ID() identity.CredentialsType { return identity.CredentialsTypePassword } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel1, diff --git a/selfservice/strategy/totp/strategy.go b/selfservice/strategy/totp/strategy.go index bcaf09069053..6c3205abd9ac 100644 --- a/selfservice/strategy/totp/strategy.go +++ b/selfservice/strategy/totp/strategy.go @@ -24,9 +24,11 @@ import ( "github.com/ory/x/decoderx" ) -var _ login.Strategy = new(Strategy) -var _ settings.Strategy = new(Strategy) -var _ identity.ActiveCredentialsCounter = new(Strategy) +var ( + _ login.Strategy = new(Strategy) + _ settings.Strategy = new(Strategy) + _ identity.ActiveCredentialsCounter = new(Strategy) +) type totpStrategyDependencies interface { x.LoggingProvider @@ -107,7 +109,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.TOTPGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel2, diff --git a/selfservice/strategy/webauthn/strategy.go b/selfservice/strategy/webauthn/strategy.go index aa4529cda620..998490055996 100644 --- a/selfservice/strategy/webauthn/strategy.go +++ b/selfservice/strategy/webauthn/strategy.go @@ -23,9 +23,11 @@ import ( "github.com/ory/x/decoderx" ) -var _ login.Strategy = new(Strategy) -var _ settings.Strategy = new(Strategy) -var _ identity.ActiveCredentialsCounter = new(Strategy) +var ( + _ login.Strategy = new(Strategy) + _ settings.Strategy = new(Strategy) + _ identity.ActiveCredentialsCounter = new(Strategy) +) type webauthnStrategyDependencies interface { x.LoggingProvider @@ -112,7 +114,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.WebAuthnGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod { aal := identity.AuthenticatorAssuranceLevel1 if !s.d.Config().WebAuthnForPasswordless(ctx) { aal = identity.AuthenticatorAssuranceLevel2 diff --git a/selfservice/strategy/webauthn/strategy_test.go b/selfservice/strategy/webauthn/strategy_test.go index 699a36ab6cbb..cc5f6fafb475 100644 --- a/selfservice/strategy/webauthn/strategy_test.go +++ b/selfservice/strategy/webauthn/strategy_test.go @@ -26,13 +26,13 @@ func TestCompletedAuthenticationMethod(t *testing.T) { assert.Equal(t, session.AuthenticationMethod{ Method: strategy.ID(), AAL: identity.AuthenticatorAssuranceLevel2, - }, strategy.CompletedAuthenticationMethod(context.Background())) + }, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{})) conf.MustSet(ctx, config.ViperKeyWebAuthnPasswordless, true) assert.Equal(t, session.AuthenticationMethod{ Method: strategy.ID(), AAL: identity.AuthenticatorAssuranceLevel1, - }, strategy.CompletedAuthenticationMethod(context.Background())) + }, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{})) } func TestCountActiveFirstFactorCredentials(t *testing.T) {