Skip to content

Commit

Permalink
chore: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas committed Jan 11, 2024
1 parent 0fdfcb5 commit 434d08c
Show file tree
Hide file tree
Showing 18 changed files with 333 additions and 232 deletions.
2 changes: 2 additions & 0 deletions driver/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ type (
SelfServiceStrategyCode struct {
*SelfServiceStrategy
PasswordlessEnabled bool `json:"passwordless_enabled"`
MFAEnabled bool `json:"mfa_enabled"`
}
Schema struct {
ID string `json:"id" koanf:"id"`
Expand Down Expand Up @@ -782,6 +783,7 @@ func (p *Config) SelfServiceCodeStrategy(ctx context.Context) *SelfServiceStrate
Config: config,
},
PasswordlessEnabled: pp.BoolF(basePath+".passwordless_enabled", false),
MFAEnabled: pp.BoolF(basePath+".mfa_enabled", false),
}
}

Expand Down
7 changes: 1 addition & 6 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,7 @@ func (m *RegistryDefault) selfServiceStrategies() []any {
}

func (m *RegistryDefault) strategyRegistrationEnabled(ctx context.Context, id string) bool {
switch id {
case identity.CredentialsTypeCodeAuth.String():
return m.Config().SelfServiceCodeStrategy(ctx).PasswordlessEnabled
default:
return m.Config().SelfServiceStrategy(ctx, id).Enabled
}
return m.Config().SelfServiceStrategy(ctx, id).Enabled
}

func (m *RegistryDefault) strategyLoginEnabled(ctx context.Context, id string) bool {
Expand Down
16 changes: 13 additions & 3 deletions driver/registry_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,39 +621,49 @@ func TestDriverDefault_Strategies(t *testing.T) {
t.Run("case=registration", func(t *testing.T) {
t.Parallel()
for k, tc := range []struct {
name string
prep func(conf *config.Config)
expect []string
}{
{
name: "no strategies",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", false)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false)
},
},
{
name: "only password",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false)
},
expect: []string{"password"},
},
{
name: "oidc and password",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false)
},
expect: []string{"password", "oidc"},
},
{
name: "oidc, password and totp",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".oidc.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".totp.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", false)
},
expect: []string{"password", "oidc"},
},
{
name: "password and code",
prep: func(conf *config.Config) {
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".password.enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.passwordless_enabled", true)
conf.MustSet(ctx, config.ViperKeySelfServiceStrategyConfig+".code.enabled", true)
},
expect: []string{"password", "code"},
},
Expand All @@ -673,7 +683,7 @@ func TestDriverDefault_Strategies(t *testing.T) {

t.Run("case=login", func(t *testing.T) {
t.Parallel()
for k, tc := range []struct {
for _, tc := range []struct {
name string
prep func(conf *config.Config)
expect []string
Expand Down Expand Up @@ -721,7 +731,7 @@ func TestDriverDefault_Strategies(t *testing.T) {
expect: []string{"password", "code"},
},
} {
t.Run(fmt.Sprintf("run=%d", k), func(t *testing.T) {
t.Run(fmt.Sprintf("run=%s", tc.name), func(t *testing.T) {
conf, reg := internal.NewVeryFastRegistryWithoutDB(t)
tc.prep(conf)

Expand Down
5 changes: 5 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,11 @@
"title": "Enables Login and Registration with the Code Method",
"default": false
},
"mfa_enabled": {
"type": "boolean",
"title": "Enables Login flows Code Method to fulfil MFA requests",
"default": false
},
"passwordless_login_fallback_enabled": {
"type": "boolean",
"title": "Passwordless Login Fallback Enabled",
Expand Down
16 changes: 8 additions & 8 deletions internal/registrationhelpers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func AssertSchemDoesNotExist(t *testing.T, reg *driver.RegistryDefault, flows []
reset()

t.Run("case=should fail because schema does not exist", func(t *testing.T) {
var check = func(t *testing.T, actual string) {
check := func(t *testing.T, actual string) {
assert.Equal(t, int64(http.StatusInternalServerError), gjson.Get(actual, "code").Int(), "%s", actual)
assert.Equal(t, "Internal Server Error", gjson.Get(actual, "status").String(), "%s", actual)
assert.Contains(t, gjson.Get(actual, "reason").String(), "no such file or directory", "%s", actual)
Expand Down Expand Up @@ -164,7 +164,7 @@ func AssertCSRFFailures(t *testing.T, reg *driver.RegistryDefault, flows []strin
apiClient := testhelpers.NewDebugClient(t)
_ = testhelpers.NewErrorTestServer(t, reg)

var values = url.Values{
values := url.Values{
"csrf_token": {"invalid_token"},
"traits.username": {testhelpers.RandomEmail()},
"traits.foobar": {"bar"},
Expand Down Expand Up @@ -253,15 +253,15 @@ func AssertRegistrationRespectsValidation(t *testing.T, reg *driver.RegistryDefa

t.Run("case=should return an error because not passing validation", func(t *testing.T) {
email := testhelpers.RandomEmail()
var check = func(t *testing.T, actual string) {
check := func(t *testing.T, actual string) {
assert.NotEmpty(t, gjson.Get(actual, "id").String(), "%s", actual)
assert.Contains(t, gjson.Get(actual, "ui.action").String(), publicTS.URL+registration.RouteSubmitFlow, "%s", actual)
CheckFormContent(t, []byte(actual), "password", "csrf_token", "traits.username", "traits.foobar")
assert.Contains(t, gjson.Get(actual, "ui.nodes.#(attributes.name==traits.foobar).messages.0").String(), `Property foobar is missing`, "%s", actual)
assert.Equal(t, email, gjson.Get(actual, "ui.nodes.#(attributes.name==traits.username).attributes.value").String(), "%s", actual)
}

var values = func(v url.Values) {
values := func(v url.Values) {
v.Set("traits.username", email)
v.Del("traits.foobar")
payload(v)
Expand Down Expand Up @@ -411,7 +411,7 @@ func AssertCommonErrorCases(t *testing.T, flows []string) {
})

t.Run("case=should show the error ui because the request id is missing", func(t *testing.T) {
var check = func(t *testing.T, actual string) {
check := func(t *testing.T, actual string) {
assert.Equal(t, int64(http.StatusNotFound), gjson.Get(actual, "code").Int(), "%s", actual)
assert.Equal(t, "Not Found", gjson.Get(actual, "status").String(), "%s", actual)
assert.Contains(t, gjson.Get(actual, "message").String(), "Unable to locate the resource", "%s", actual)
Expand Down Expand Up @@ -481,14 +481,14 @@ func AssertCommonErrorCases(t *testing.T, flows []string) {
})
})

t.Run("case=should fail because the return_to url is not allowed", func(t *testing.T) {
t.Run("case=should fail because the password was used in databreaches", func(t *testing.T) {
testhelpers.SetDefaultIdentitySchemaFromRaw(conf, multifieldSchema)
t.Cleanup(func() {
testhelpers.SetDefaultIdentitySchemaFromRaw(conf, basicSchema)
})

email := testhelpers.RandomEmail()
var check = func(t *testing.T, actual string) {
check := func(t *testing.T, actual string) {
assert.NotEmpty(t, gjson.Get(actual, "id").String(), "%s", actual)
assert.Contains(t, gjson.Get(actual, "ui.action").String(), publicTS.URL+registration.RouteSubmitFlow, "%s", actual)
CheckFormContent(t, []byte(actual), "password", "csrf_token", "traits.username", "traits.foobar")
Expand All @@ -498,7 +498,7 @@ func AssertCommonErrorCases(t *testing.T, flows []string) {
assert.Equal(t, "password", gjson.Get(actual, "ui.nodes.#(attributes.name==method).attributes.value").String(), "%s", actual)
}

var values = func(v url.Values) {
values := func(v url.Values) {
v.Set("traits.username", email)
v.Set("password", "password")
v.Set("traits.foobar", "bar")
Expand Down
2 changes: 1 addition & 1 deletion internal/testhelpers/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func CourierExpectMessage(ctx context.Context, t *testing.T, reg interface {
}
}

require.Failf(t, "could not find courier messages with recipient %s and subject %s", recipient, subject)
require.Failf(t, "could not find courier messages", "could not find courier messages with recipient %s and subject %s", recipient, subject)
return nil
}

Expand Down
3 changes: 2 additions & 1 deletion internal/testhelpers/selfservice_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ func InitializeLoginFlowViaBrowser(t *testing.T, client *http.Client, ts *httpte
require.NoError(t, err)
body := x.MustReadAll(res.Body)
require.NoError(t, res.Body.Close())
require.Equal(t, 200, res.StatusCode, "%s", body)
if expectInitError {
require.Equal(t, 200, res.StatusCode)
require.NotNil(t, res.Request.URL)
require.Contains(t, res.Request.URL.String(), "error-ts")
}
Expand All @@ -142,6 +142,7 @@ func InitializeLoginFlowViaBrowser(t *testing.T, client *http.Client, ts *httpte
if isSPA {
flowID = gjson.GetBytes(body, "id").String()
}
require.NotEmpty(t, flowID)

rs, r, err := publicClient.FrontendApi.GetLoginFlow(context.Background()).Id(flowID).Execute()
if expectGetError {
Expand Down
14 changes: 14 additions & 0 deletions internal/testhelpers/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ func NewHTTPClientWithIdentitySessionCookie(t *testing.T, reg *driver.RegistryDe
return NewHTTPClientWithSessionCookie(t, reg, s)
}

func NewHTTPClientWithIdentitySessionCookieLocalhost(t *testing.T, reg *driver.RegistryDefault, id *identity.Identity) *http.Client {
req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
s, err := session.NewActiveSession(req,
id,
NewSessionLifespanProvider(time.Hour),
time.Now(),
identity.CredentialsTypePassword,
identity.AuthenticatorAssuranceLevel1,
)
require.NoError(t, err, "Could not initialize session from identity.")

return NewHTTPClientWithSessionCookieLocalhost(t, reg, s)
}

func NewHTTPClientWithIdentitySessionToken(t *testing.T, reg *driver.RegistryDefault, id *identity.Identity) *http.Client {
req := NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
s, err := session.NewActiveSession(req,
Expand Down
18 changes: 9 additions & 9 deletions selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,6 @@ func (h *Handler) NewLoginFlow(w http.ResponseWriter, r *http.Request, ft flow.T
}

preLoginHook:
if f.Refresh {
f.UI.Messages.Set(text.NewInfoLoginReAuth())
}

if sess != nil && f.RequestedAAL > sess.AuthenticatorAssuranceLevel && f.RequestedAAL > identity.AuthenticatorAssuranceLevel1 {
f.UI.Messages.Add(text.NewInfoLoginMFA())
}

var strategyFilters []StrategyFilter
orgID := uuid.NullUUID{
Valid: false,
Expand Down Expand Up @@ -222,6 +214,14 @@ preLoginHook:
}
}

if f.Refresh {
f.UI.Messages.Set(text.NewInfoLoginReAuth())
}

if sess != nil && f.RequestedAAL > sess.AuthenticatorAssuranceLevel && f.RequestedAAL > identity.AuthenticatorAssuranceLevel1 {
f.UI.Messages.Add(text.NewInfoLoginMFA())
}

if err := sortNodes(r.Context(), f.UI.Nodes); err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -781,7 +781,7 @@ continueLogin:
var i *identity.Identity
var group node.UiNodeGroup
for _, ss := range h.d.AllLoginStrategies() {
interim, err := ss.Login(w, r, f, sess.IdentityID)
interim, err := ss.Login(w, r, f, sess)
group = ss.NodeGroup()
if errors.Is(err, flow.ErrStrategyNotResponsible) {
continue
Expand Down
3 changes: 1 addition & 2 deletions selfservice/flow/login/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"context"
"net/http"

"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/kratos/identity"
Expand All @@ -22,7 +21,7 @@ type Strategy interface {
NodeGroup() node.UiNodeGroup
RegisterLoginRoutes(*x.RouterPublic)
PopulateLoginMethod(r *http.Request, requestedAAL identity.AuthenticatorAssuranceLevel, sr *Flow) error
Login(w http.ResponseWriter, r *http.Request, f *Flow, identityID uuid.UUID) (i *identity.Identity, err error)
Login(w http.ResponseWriter, r *http.Request, f *Flow, sess *session.Session) (i *identity.Identity, err error)
CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods) session.AuthenticationMethod
}

Expand Down
Loading

0 comments on commit 434d08c

Please sign in to comment.