Skip to content

Commit

Permalink
feat: extract identifier label for login from the default identity sc…
Browse files Browse the repository at this point in the history
…hema
  • Loading branch information
zepatrik committed Nov 28, 2023
1 parent c25ddff commit e30425a
Show file tree
Hide file tree
Showing 12 changed files with 287 additions and 26 deletions.
2 changes: 1 addition & 1 deletion cmd/clidoc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func init() {
"NewInfoNodeLabelGenerated": text.NewInfoNodeLabelGenerated("{title}"),
"NewInfoNodeLabelSave": text.NewInfoNodeLabelSave(),
"NewInfoNodeLabelSubmit": text.NewInfoNodeLabelSubmit(),
"NewInfoNodeLabelID": text.NewInfoNodeLabelID(),
"NewInfoNodeLabelID": text.NewInfoNodeLabelID(""),
"NewErrorValidationSettingsFlowExpired": text.NewErrorValidationSettingsFlowExpired(aSecondAgo),
"NewInfoSelfServiceSettingsTOTPQRCode": text.NewInfoSelfServiceSettingsTOTPQRCode(),
"NewInfoSelfServiceSettingsTOTPSecret": text.NewInfoSelfServiceSettingsTOTPSecret("{secret}"),
Expand Down
4 changes: 2 additions & 2 deletions identity/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func NewValidator(d validatorDependencies) *Validator {
return &Validator{v: schema.NewValidator(), d: d}
}

func (v *Validator) ValidateWithRunner(ctx context.Context, i *Identity, runners ...schema.Extension) error {
runner, err := schema.NewExtensionRunner(ctx, runners...)
func (v *Validator) ValidateWithRunner(ctx context.Context, i *Identity, runners ...schema.ValidateExtension) error {
runner, err := schema.NewExtensionRunner(ctx, schema.WithValidateRunners(runners...))
if err != nil {
return err
}
Expand Down
50 changes: 35 additions & 15 deletions schema/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,30 +41,41 @@ type (
Recovery struct {
Via string `json:"via"`
} `json:"recovery"`
Mappings struct {
Identity struct {
Traits []struct {
Path string `json:"path"`
} `json:"traits"`
} `json:"identity"`
} `json:"mappings"`
}

Extension interface {
ValidateExtension interface {
Run(ctx jsonschema.ValidationContext, config ExtensionConfig, value interface{}) error
Finish() error
}
CompileExtension interface {
Run(ctx jsonschema.CompilerContext, config ExtensionConfig, rawSchema map[string]interface{}) error
}

ExtensionRunner struct {
meta *jsonschema.Schema
compile func(ctx jsonschema.CompilerContext, m map[string]interface{}) (interface{}, error)
validate func(ctx jsonschema.ValidationContext, s interface{}, v interface{}) error

runners []Extension
validateRunners []ValidateExtension
compileRunners []CompileExtension
}

ExtensionRunnerOption func(*ExtensionRunner)
)

func NewExtensionRunner(ctx context.Context, runners ...Extension) (*ExtensionRunner, error) {
func WithValidateRunners(runners ...ValidateExtension) ExtensionRunnerOption {
return func(r *ExtensionRunner) {
r.validateRunners = append(r.validateRunners, runners...)
}
}

func WithCompileRunners(runners ...CompileExtension) ExtensionRunnerOption {
return func(r *ExtensionRunner) {
r.compileRunners = append(r.compileRunners, runners...)
}
}

func NewExtensionRunner(ctx context.Context, opts ...ExtensionRunnerOption) (*ExtensionRunner, error) {
var err error
r := new(ExtensionRunner)
c := jsonschema.NewCompiler()
Expand All @@ -90,6 +101,12 @@ func NewExtensionRunner(ctx context.Context, runners ...Extension) (*ExtensionRu
return nil, errors.WithStack(err)
}

for _, runner := range r.compileRunners {
if err := runner.Run(ctx, e, m); err != nil {
return nil, err

Check warning on line 106 in schema/extension.go

View check run for this annotation

Codecov / codecov/patch

schema/extension.go#L106

Added line #L106 was not covered by tests
}
}

return &e, nil
}
return nil, nil
Expand All @@ -101,15 +118,18 @@ func NewExtensionRunner(ctx context.Context, runners ...Extension) (*ExtensionRu
return nil
}

for _, runner := range r.runners {
for _, runner := range r.validateRunners {
if err := runner.Run(ctx, *c, v); err != nil {
return err
}
}
return nil
}

r.runners = runners
for _, opt := range opts {
opt(r)
}

return r, nil
}

Expand All @@ -126,13 +146,13 @@ func (r *ExtensionRunner) Extension() jsonschema.Extension {
}
}

func (r *ExtensionRunner) AddRunner(run Extension) *ExtensionRunner {
r.runners = append(r.runners, run)
func (r *ExtensionRunner) AddRunner(run ValidateExtension) *ExtensionRunner {
r.validateRunners = append(r.validateRunners, run)
return r
}

func (r *ExtensionRunner) Finish() error {
for _, runner := range r.runners {
for _, runner := range r.validateRunners {
if err := runner.Finish(); err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions selfservice/flow/login/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ func TestHandleError(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
public, _ := testhelpers.NewKratosServer(t, reg)

testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/password.schema.json")

router := httprouter.New()
ts := httptest.NewServer(router)
t.Cleanup(ts.Close)
Expand Down
64 changes: 64 additions & 0 deletions selfservice/flow/login/extension_identifier_label.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package login

import (
"context"
"sort"

"github.com/ory/jsonschema/v3"
"github.com/ory/kratos/schema"
)

type identifierLabelExtension struct {
identifierLabelCandidates []string
}

var _ schema.CompileExtension = new(identifierLabelExtension)

func GetIdentifierLabelFromSchema(ctx context.Context, schemaURL string) (string, error) {
ext := &identifierLabelExtension{}

runner, err := schema.NewExtensionRunner(ctx, schema.WithCompileRunners(ext))
if err != nil {
return "", err

Check warning on line 25 in selfservice/flow/login/extension_identifier_label.go

View check run for this annotation

Codecov / codecov/patch

selfservice/flow/login/extension_identifier_label.go#L25

Added line #L25 was not covered by tests
}
c := jsonschema.NewCompiler()
runner.Register(c)

_, err = c.Compile(ctx, schemaURL)
if err != nil {
return "", err

Check warning on line 32 in selfservice/flow/login/extension_identifier_label.go

View check run for this annotation

Codecov / codecov/patch

selfservice/flow/login/extension_identifier_label.go#L32

Added line #L32 was not covered by tests
}
return ext.getLabel(), nil
}

func (i *identifierLabelExtension) Run(_ jsonschema.CompilerContext, config schema.ExtensionConfig, rawSchema map[string]interface{}) error {
if config.Credentials.Password.Identifier ||
config.Credentials.WebAuthn.Identifier ||
config.Credentials.TOTP.AccountName ||
config.Credentials.Code.Identifier {
if title, ok := rawSchema["title"]; ok {
// The jsonschema compiler validates the title to be a string, so this should always work.
switch t := title.(type) {
case string:
if t != "" {
i.identifierLabelCandidates = append(i.identifierLabelCandidates, t)
}
}
}
}
return nil
}

func (i *identifierLabelExtension) getLabel() string {
if len(i.identifierLabelCandidates) == 0 {
// sane default is set elsewhere
return ""
}
// sort the candidates to get a deterministic result
sort.Strings(i.identifierLabelCandidates)
// just take the first, no good way to decide which one is the best
return i.identifierLabelCandidates[0]
}
137 changes: 137 additions & 0 deletions selfservice/flow/login/extension_identifier_label_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package login

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"

"github.com/ory/kratos/schema"
)

func constructSchema(t *testing.T, ecModifier, ucModifier func(*schema.ExtensionConfig)) string {
var emailConfig, usernameConfig schema.ExtensionConfig

if ecModifier != nil {
ecModifier(&emailConfig)
}
if ucModifier != nil {
ucModifier(&usernameConfig)
}

ec, err := json.Marshal(&emailConfig)
require.NoError(t, err)
uc, err := json.Marshal(&usernameConfig)
require.NoError(t, err)

ec, err = sjson.DeleteBytes(ec, "verification")
require.NoError(t, err)
ec, err = sjson.DeleteBytes(ec, "recovery")
require.NoError(t, err)
ec, err = sjson.DeleteBytes(ec, "credentials.code.via")
require.NoError(t, err)
uc, err = sjson.DeleteBytes(uc, "verification")
require.NoError(t, err)
uc, err = sjson.DeleteBytes(uc, "recovery")
require.NoError(t, err)
uc, err = sjson.DeleteBytes(uc, "credentials.code.via")
require.NoError(t, err)

return "base64://" + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(`
{
"properties": {
"traits": {
"properties": {
"email": {
"title": "Email",
"ory.sh/kratos": %s
},
"username": {
"title": "Username",
"ory.sh/kratos": %s
}
}
}
}
}`, ec, uc)))
}

func TestGetIdentifierLabelFromSchema(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
name string
emailConfig, usernameConfig func(*schema.ExtensionConfig)
expected string
}{
{
name: "email for password",
emailConfig: func(c *schema.ExtensionConfig) {
c.Credentials.Password.Identifier = true
},
expected: "Email",
},
{
name: "email for webauthn",
emailConfig: func(c *schema.ExtensionConfig) {
c.Credentials.WebAuthn.Identifier = true
},
expected: "Email",
},
{
name: "email for totp",
emailConfig: func(c *schema.ExtensionConfig) {
c.Credentials.TOTP.AccountName = true
},
expected: "Email",
},
{
name: "email for code",
emailConfig: func(c *schema.ExtensionConfig) {
c.Credentials.Code.Identifier = true
},
expected: "Email",
},
{
name: "email for all",
emailConfig: func(c *schema.ExtensionConfig) {
c.Credentials.Password.Identifier = true
c.Credentials.WebAuthn.Identifier = true
c.Credentials.TOTP.AccountName = true
c.Credentials.Code.Identifier = true
},
expected: "Email",
},
{
name: "username works as well",
usernameConfig: func(c *schema.ExtensionConfig) {
c.Credentials.Password.Identifier = true
},
expected: "Username",
},
{
name: "multiple identifiers",
emailConfig: func(c *schema.ExtensionConfig) {
c.Credentials.Password.Identifier = true
},
usernameConfig: func(c *schema.ExtensionConfig) {
c.Credentials.Password.Identifier = true
},
expected: "Email",
},
} {
t.Run(tc.name, func(t *testing.T) {
label, err := GetIdentifierLabelFromSchema(ctx, constructSchema(t, tc.emailConfig, tc.usernameConfig))
require.NoError(t, err)
assert.Equal(t, tc.expected, label)
})
}
}
2 changes: 2 additions & 0 deletions selfservice/flow/login/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,8 @@ func TestGetFlow(t *testing.T) {
_ = testhelpers.NewErrorTestServer(t, reg)
_ = testhelpers.NewRedirTS(t, "", conf)

testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/password.schema.json")

setupLoginUI := func(t *testing.T, c *http.Client) *httptest.Server {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// It is important that we use a HTTP request to fetch the flow because that will show us if CSRF works or not
Expand Down
14 changes: 11 additions & 3 deletions selfservice/strategy/code/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,17 @@ func (s *Strategy) PopulateMethod(r *http.Request, f flow.Flow) error {
WithMetaLabel(text.NewInfoNodeInputEmail()),
)
} else if f.GetFlowName() == flow.LoginFlow {
// we use the identifier label here since we don't know what
// type of field the identifier is
nodes.Upsert(node.NewInputField("identifier", "", node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeLabelID()))
ds, err := s.deps.Config().DefaultIdentityTraitsSchemaURL(r.Context())
if err != nil {
return err

Check warning on line 154 in selfservice/strategy/code/strategy.go

View check run for this annotation

Codecov / codecov/patch

selfservice/strategy/code/strategy.go#L154

Added line #L154 was not covered by tests
}

identifierLabel, err := login.GetIdentifierLabelFromSchema(r.Context(), ds.String())
if err != nil {
return err

Check warning on line 159 in selfservice/strategy/code/strategy.go

View check run for this annotation

Codecov / codecov/patch

selfservice/strategy/code/strategy.go#L159

Added line #L159 was not covered by tests
}

nodes.Upsert(node.NewInputField("identifier", "", node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).WithMetaLabel(text.NewInfoNodeLabelID(identifierLabel)))
} else if f.GetFlowName() == flow.RegistrationFlow {
ds, err := s.deps.Config().DefaultIdentityTraitsSchemaURL(r.Context())
if err != nil {
Expand Down
10 changes: 9 additions & 1 deletion selfservice/strategy/code/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,18 @@ func (s *Strategy) HandleLoginError(r *http.Request, f *login.Flow, body *update
email = body.Identifier
}

ds, err := s.deps.Config().DefaultIdentityTraitsSchemaURL(r.Context())
if err != nil {
return err

Check warning on line 84 in selfservice/strategy/code/strategy_login.go

View check run for this annotation

Codecov / codecov/patch

selfservice/strategy/code/strategy_login.go#L84

Added line #L84 was not covered by tests
}
identifierLabel, err := login.GetIdentifierLabelFromSchema(r.Context(), ds.String())
if err != nil {
return err

Check warning on line 88 in selfservice/strategy/code/strategy_login.go

View check run for this annotation

Codecov / codecov/patch

selfservice/strategy/code/strategy_login.go#L88

Added line #L88 was not covered by tests
}
f.UI.SetCSRF(s.deps.GenerateCSRFToken(r))
f.UI.GetNodes().Upsert(
node.NewInputField("identifier", email, node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).
WithMetaLabel(text.NewInfoNodeLabelID()),
WithMetaLabel(text.NewInfoNodeLabelID(identifierLabel)),
)
}

Expand Down
Loading

0 comments on commit e30425a

Please sign in to comment.