Skip to content

Commit

Permalink
feat: support MFA via SMS
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas committed Jan 6, 2024
1 parent eb8d1b9 commit 0fdfcb5
Show file tree
Hide file tree
Showing 27 changed files with 203 additions and 166 deletions.
25 changes: 24 additions & 1 deletion contrib/quickstart/kratos/phone-password/identity.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,26 @@
"traits": {
"type": "object",
"properties": {
"email": {
"type": "string",
"format": "email",
"title": "E-mail",
"minLength": 3,
"ory.sh/kratos": {
"credentials": {
"password": {
"identifier": true
},
"code": {
"identifier": true,
"via": "email"
}
},
"verification": {
"via": "email"
}
}
},
"phone": {
"type": "string",
"format": "tel",
Expand All @@ -16,6 +36,9 @@
"credentials": {
"password": {
"identifier": true
},
"code": {
"identifier": true
}
},
"verification": {
Expand All @@ -24,7 +47,7 @@
}
}
},
"required": ["phone"],
"required": ["email", "phone"],
"additionalProperties": false
}
}
Expand Down
6 changes: 6 additions & 0 deletions courier/sms_templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ func NewSMSTemplateFromMessage(d template.Dependencies, m Message) (SMSTemplate,
return nil, err
}
return sms.NewTestStub(d, &t), nil
case template.TypeLoginCodeValid:
var t sms.LoginCodeValidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
return nil, err

Check warning on line 40 in courier/sms_templates.go

View check run for this annotation

Codecov / codecov/patch

courier/sms_templates.go#L37-L40

Added lines #L37 - L40 were not covered by tests
}
return sms.NewLoginCodeValid(d, &t), nil

Check warning on line 42 in courier/sms_templates.go

View check run for this annotation

Codecov / codecov/patch

courier/sms_templates.go#L42

Added line #L42 was not covered by tests

default:
return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Your login code is: {{ .LoginCode }}
52 changes: 52 additions & 0 deletions courier/template/sms/login_code_valid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package sms

import (
"context"
"encoding/json"
"os"

"github.com/ory/kratos/courier/template"
)

type (
LoginCodeValid struct {
deps template.Dependencies
model *LoginCodeValidModel
}
LoginCodeValidModel struct {
To string
LoginCode string
Identity map[string]interface{}
}
)

func NewLoginCodeValid(d template.Dependencies, m *LoginCodeValidModel) *LoginCodeValid {
return &LoginCodeValid{deps: d, model: m}

Check warning on line 27 in courier/template/sms/login_code_valid.go

View check run for this annotation

Codecov / codecov/patch

courier/template/sms/login_code_valid.go#L26-L27

Added lines #L26 - L27 were not covered by tests
}

func (t *LoginCodeValid) PhoneNumber() (string, error) {
return t.model.To, nil

Check warning on line 31 in courier/template/sms/login_code_valid.go

View check run for this annotation

Codecov / codecov/patch

courier/template/sms/login_code_valid.go#L30-L31

Added lines #L30 - L31 were not covered by tests
}

func (t *LoginCodeValid) SMSBody(ctx context.Context) (string, error) {
return template.LoadText(
ctx,
t.deps,
os.DirFS(t.deps.CourierConfig().CourierTemplatesRoot(ctx)),
"login_code/valid/sms.body.gotmpl", // TODO:!!!
"login_code/valid/sms.body*",
t.model,
"",
)

Check warning on line 43 in courier/template/sms/login_code_valid.go

View check run for this annotation

Codecov / codecov/patch

courier/template/sms/login_code_valid.go#L34-L43

Added lines #L34 - L43 were not covered by tests
}

func (t *LoginCodeValid) MarshalJSON() ([]byte, error) {
return json.Marshal(t.model)

Check warning on line 47 in courier/template/sms/login_code_valid.go

View check run for this annotation

Codecov / codecov/patch

courier/template/sms/login_code_valid.go#L46-L47

Added lines #L46 - L47 were not covered by tests
}

func (t *LoginCodeValid) TemplateType() template.TemplateType {
return template.TypeLoginCodeValid

Check warning on line 51 in courier/template/sms/login_code_valid.go

View check run for this annotation

Codecov / codecov/patch

courier/template/sms/login_code_valid.go#L50-L51

Added lines #L50 - L51 were not covered by tests
}
7 changes: 1 addition & 6 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,7 @@ func (m *RegistryDefault) strategyRegistrationEnabled(ctx context.Context, id st
}

func (m *RegistryDefault) strategyLoginEnabled(ctx context.Context, id string) bool {
switch id {
case identity.CredentialsTypeCodeAuth.String():
return m.Config().SelfServiceCodeStrategy(ctx).PasswordlessEnabled
default:
return m.Config().SelfServiceStrategy(ctx, id).Enabled
}
return m.Config().SelfServiceStrategy(ctx, id).Enabled
}

func (m *RegistryDefault) RegistrationStrategies(ctx context.Context, filters ...registration.StrategyFilter) (registrationStrategies registration.Strategies) {
Expand Down
12 changes: 11 additions & 1 deletion driver/registry_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,39 +674,49 @@ func TestDriverDefault_Strategies(t *testing.T) {
t.Run("case=login", func(t *testing.T) {
t.Parallel()
for k, tc := range []struct {
name string
prep func(conf *config.Config)
expect []string
}{
{
name: "no strategies",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", false)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false)
},
},
{
name: "only password",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false)
},
expect: []string{"password"},
},
{
name: "oidc and password",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false)
},
expect: []string{"password", "oidc"},
},
{
name: "oidc, password and totp",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".totp.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false)
},
expect: []string{"password", "oidc", "totp"},
},
{
name: "password and code",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true)
},
expect: []string{"password", "code"},
},
Expand Down
2 changes: 1 addition & 1 deletion embedx/identity_extension.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
},
"via": {
"type": "string",
"enum": ["email"]
"enum": ["email", "sms"]
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion identity/credentials_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"database/sql"
)

type CodeAddressType string
type CodeAddressType = string

const (
CodeAddressTypeEmail CodeAddressType = AddressTypeEmail
Expand Down
2 changes: 2 additions & 0 deletions identity/extension_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ func (r *SchemaExtensionCredentials) Run(ctx jsonschema.ValidationContext, s sch
// }

// r.setIdentifier(CredentialsTypeCodeAuth, value, CredentialsIdentifierAddressTypePhone)
case f.AddCase(""):

Check warning on line 73 in identity/extension_credentials.go

View check run for this annotation

Codecov / codecov/patch

identity/extension_credentials.go#L73

Added line #L73 was not covered by tests
// continue
default:
return ctx.Error("", "credentials.code.via has unknown value %q", s.Credentials.Code.Via)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/testhelpers/selfservice_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ func InitializeLoginFlowViaBrowser(t *testing.T, client *http.Client, ts *httpte
flowID = gjson.GetBytes(body, "id").String()
}

rs, _, err := publicClient.FrontendApi.GetLoginFlow(context.Background()).Id(flowID).Execute()
rs, r, err := publicClient.FrontendApi.GetLoginFlow(context.Background()).Id(flowID).Execute()
if expectGetError {
require.Error(t, err)
require.Nil(t, rs)
} else {
require.NoError(t, err)
require.NoError(t, err, "%s", ioutilx.MustReadAll(r.Body))
assert.Empty(t, rs.Active)
}

Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ continueLogin:
sess = session.NewInactiveSession()
}

method := ss.CompletedAuthenticationMethod(r.Context())
method := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR)
sess.CompletedLoginForMethod(method)
i = interim
break
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return err
}

method := strategy.CompletedAuthenticationMethod(ctx)
method := strategy.CompletedAuthenticationMethod(ctx, sess.AMR)
sess.CompletedLoginForMethod(method)

return nil
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Strategy interface {
RegisterLoginRoutes(*x.RouterPublic)
PopulateLoginMethod(r *http.Request, requestedAAL identity.AuthenticatorAssuranceLevel, sr *Flow) error
Login(w http.ResponseWriter, r *http.Request, f *Flow, identityID uuid.UUID) (i *identity.Identity, err error)
CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod
CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods) session.AuthenticationMethod
}

type Strategies []Strategy
Expand Down
3 changes: 2 additions & 1 deletion selfservice/flow/registration/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/ory/kratos/corpx"
"github.com/ory/kratos/hydra"
"github.com/ory/x/ioutilx"
"github.com/ory/x/urlx"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -479,7 +480,7 @@ func TestOIDCStrategyOrder(t *testing.T) {
resp, err := client.Do(req)
require.NoError(t, err)

require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, http.StatusOK, resp.StatusCode, "%s", ioutilx.MustReadAll(resp.Body))

verifiableAddress, err := reg.PrivilegedIdentityPool().FindVerifiableAddressByValue(ctx, identity.VerifiableAddressTypeEmail, email)
require.NoError(t, err)
Expand Down
15 changes: 1 addition & 14 deletions selfservice/flow/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strings"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/strategy"
"github.com/ory/x/decoderx"

Expand Down Expand Up @@ -106,19 +105,7 @@ func MethodEnabledAndAllowed(ctx context.Context, flowName FlowName, expected, a
return errors.WithStack(ErrStrategyNotResponsible)
}

var ok bool
if strings.EqualFold(actual, identity.CredentialsTypeCodeAuth.String()) {
switch flowName {
case RegistrationFlow, LoginFlow:
ok = d.Config().SelfServiceCodeStrategy(ctx).PasswordlessEnabled
case VerificationFlow, RecoveryFlow:
ok = d.Config().SelfServiceCodeStrategy(ctx).Enabled
}
} else {
ok = d.Config().SelfServiceStrategy(ctx, expected).Enabled
}

if !ok {
if !d.Config().SelfServiceStrategy(ctx, expected).Enabled {
return errors.WithStack(herodot.ErrNotFound.WithReason(strategy.EndpointDisabledMessage))
}

Expand Down
71 changes: 0 additions & 71 deletions selfservice/flow/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,74 +101,3 @@ func TestMethodEnabledAndAllowed(t *testing.T) {
assert.Contains(t, string(body), "The requested resource could not be found")
})
}

func TestMethodCodeEnabledAndAllowed(t *testing.T) {
ctx := context.Background()
conf, d := internal.NewFastRegistryWithMocks(t)

currentFlow := make(chan flow.FlowName, 1)

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
f := <-currentFlow
if err := flow.MethodEnabledAndAllowedFromRequest(r, f, "code", d); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusNoContent)
}))

t.Run("login code allowed", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true)
currentFlow <- flow.LoginFlow
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusNoContent, res.StatusCode)
})

t.Run("login code not allowed", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", false)
currentFlow <- flow.LoginFlow
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
assert.Contains(t, string(body), "The requested resource could not be found")
})

t.Run("registration code allowed", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true)
currentFlow <- flow.RegistrationFlow
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusNoContent, res.StatusCode)
})

t.Run("registration code not allowed", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", false)
currentFlow <- flow.RegistrationFlow
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
assert.Contains(t, string(body), "The requested resource could not be found")
})

t.Run("recovery and verification should still be allowed if registration and login is disabled", func(t *testing.T) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", false)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true)

for _, f := range []flow.FlowName{flow.RecoveryFlow, flow.VerificationFlow} {
currentFlow <- f
res, err := ts.Client().PostForm(ts.URL, url.Values{"method": {"code"}})
require.NoError(t, err)
require.NoError(t, res.Body.Close())
assert.Equal(t, http.StatusNoContent, res.StatusCode)
}
})
}
Loading

0 comments on commit 0fdfcb5

Please sign in to comment.