From 9cc90063aeac9ca2e13e973bcb732fc82514b3a4 Mon Sep 17 00:00:00 2001 From: Jonas Hungershausen Date: Thu, 25 Jan 2024 10:25:22 +0100 Subject: [PATCH] chore: cr --- internal/testhelpers/selfservice_login.go | 1 + .../flow/login/extension_identifier_label.go | 2 +- selfservice/strategy/code/strategy.go | 27 +++++++++++++--- selfservice/strategy/code/strategy_test.go | 32 +++++++++++++++++++ 4 files changed, 56 insertions(+), 6 deletions(-) diff --git a/internal/testhelpers/selfservice_login.go b/internal/testhelpers/selfservice_login.go index 6469aec52ac3..2bd20f81dfd4 100644 --- a/internal/testhelpers/selfservice_login.go +++ b/internal/testhelpers/selfservice_login.go @@ -122,6 +122,7 @@ func InitFlowWithOAuth2LoginChallenge(hlc string) InitFlowWithOption { } } +// InitFlowWithVia sets the `via` query parameter which is used by the code MFA flows to determine the trait to use to send the code to the user func InitFlowWithVia(via string) InitFlowWithOption { return func(o *initFlowOptions) { o.via = via diff --git a/selfservice/flow/login/extension_identifier_label.go b/selfservice/flow/login/extension_identifier_label.go index 1774f62a9c3c..9d28dd8ecf9f 100644 --- a/selfservice/flow/login/extension_identifier_label.go +++ b/selfservice/flow/login/extension_identifier_label.go @@ -23,7 +23,7 @@ type identifierLabelExtension struct { var ( _ schema.CompileExtension = new(identifierLabelExtension) - ErrUnknownTrait = herodot.ErrBadRequest.WithReasonf("Trait does not exist in identity schema") + ErrUnknownTrait = herodot.ErrInternalServerError.WithReasonf("Trait does not exist in identity schema") ) func GetIdentifierLabelFromSchema(ctx context.Context, schemaURL string) (*text.Message, error) { diff --git a/selfservice/strategy/code/strategy.go b/selfservice/strategy/code/strategy.go index d33c5f34c388..61bca5915638 100644 --- a/selfservice/strategy/code/strategy.go +++ b/selfservice/strategy/code/strategy.go @@ -221,7 +221,7 @@ func (s *Strategy) populateChooseMethodFlow(r *http.Request, f flow.Flow) error codeMetaLabel = text.NewInfoSelfServiceLoginCodeMFA() idNode := node.NewInputField("identifier", "", node.DefaultGroup, node.InputAttributeTypeText, node.WithRequiredInputAttribute).WithMetaLabel(identifierLabel) - idNode.Messages.Add(text.NewInfoSelfServiceLoginCodeMFAHint(maskAddress(value))) + idNode.Messages.Add(text.NewInfoSelfServiceLoginCodeMFAHint(MaskAddress(value))) f.GetUI().Nodes.Upsert(idNode) } else { codeMetaLabel = text.NewInfoSelfServiceLoginCode() @@ -409,10 +409,27 @@ func GenerateCode() string { return randx.MustString(CodeLength, randx.Numeric) } -func maskAddress(input string) string { +// MaskAddress masks an address by replacing the middle part with asterisks. +// +// If the address contains an @, the part before the @ is masked by taking the first 2 characters and adding 4 * +// (if the part before the @ is less than 2 characters the full value is used). +// Otherwise, the first 3 characters and last two characters are taken and 4 * are added in between. +// +// Examples: +// - foo@bar -> fo****@bar +// - foobar -> fo****ar +// - fo@bar -> fo@bar +// - +12345678910 -> +12****10 +func MaskAddress(input string) string { if strings.Contains(input, "@") { - parts := strings.Split(input, "@") - return parts[0][:2] + strings.Repeat("*", 4) + "@" + parts[1] + pre, post, found := strings.Cut(input, "@") + if !found || len(pre) < 2 { + return input + } + return pre[:2] + strings.Repeat("*", 4) + "@" + post + } + if len(input) < 6 { + return input } - return input[:3] + strings.Repeat("*", 4) + input[len(input)-3:] + return input[:3] + strings.Repeat("*", 4) + input[len(input)-2:] } diff --git a/selfservice/strategy/code/strategy_test.go b/selfservice/strategy/code/strategy_test.go index 634605cba3f3..5a6234531d9d 100644 --- a/selfservice/strategy/code/strategy_test.go +++ b/selfservice/strategy/code/strategy_test.go @@ -38,3 +38,35 @@ func TestGenerateCode(t *testing.T) { assert.Len(t, stringslice.Unique(codes), len(codes)) } + +func TestMaskAddress(t *testing.T) { + for _, tc := range []struct { + address string + expected string + }{ + { + address: "a", + expected: "a", + }, + { + address: "fixed@ory.sh", + expected: "fi****@ory.sh", + }, + { + address: "f@ory.sh", + expected: "f@ory.sh", + }, + { + address: "+12345678910", + expected: "+12****10", + }, + { + address: "+123456", + expected: "+12****56", + }, + } { + t.Run("case="+tc.address, func(t *testing.T) { + assert.Equal(t, tc.expected, code.MaskAddress(tc.address)) + }) + } +}