From bff9c61b147648ab139e7e86cda4336b5d1cfd39 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Fri, 2 Feb 2024 09:59:28 +0100 Subject: [PATCH] feat: list by OIDC cred (#3721) --- identity/handler_test.go | 42 +++++++++++++------ identity/test/pool.go | 6 +-- internal/client-go/go.sum | 1 + .../sql/identity/persister_identity.go | 12 ++++-- 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/identity/handler_test.go b/identity/handler_test.go index 0d5c031cbfad..ab20e8780ca6 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -466,29 +466,47 @@ func TestHandler(t *testing.T) { }) t.Run("case=should be able to lookup the identity using identifier", func(t *testing.T) { - i1 := &identity.Identity{ + ident := &identity.Identity{ Credentials: map[identity.CredentialsType]identity.Credentials{ identity.CredentialsTypePassword: { Type: identity.CredentialsTypePassword, Identifiers: []string{"find.by.identifier@bar.com"}, Config: sqlxx.JSONRawMessage(`{"hashed_password":"$2a$08$.cOYmAd.vCpDOoiVJrO5B.hjTLKQQ6cAK40u8uB.FnZDyPvVvQ9Q."}`), // foobar }, + identity.CredentialsTypeOIDC: { + Type: identity.CredentialsTypeOIDC, + Identifiers: []string{"ProviderID:293b5d9b-1009-4600-a3e9-bd1845de22f2"}, + Config: sqlxx.JSONRawMessage("{\"some\" : \"secret\"}"), + }, }, State: identity.StateActive, Traits: identity.Traits(`{"username":"find.by.identifier@bar.com"}`), } + require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), ident)) + + t.Run("type=password", func(t *testing.T) { + res := get(t, adminTS, "/identities?credentials_identifier=FIND.BY.IDENTIFIER@bar.com", http.StatusOK) + assert.EqualValues(t, ident.ID.String(), res.Get("0.id").String(), "%s", res.Raw) + assert.EqualValues(t, "find.by.identifier@bar.com", res.Get("0.traits.username").String(), "%s", res.Raw) + assert.EqualValues(t, defaultSchemaExternalURL, res.Get("0.schema_url").String(), "%s", res.Raw) + assert.EqualValues(t, config.DefaultIdentityTraitsSchemaID, res.Get("0.schema_id").String(), "%s", res.Raw) + assert.EqualValues(t, identity.StateActive, res.Get("0.state").String(), "%s", res.Raw) + assert.EqualValues(t, "password", res.Get("0.credentials.password.type").String(), res.Raw) + assert.EqualValues(t, "1", res.Get("0.credentials.password.identifiers.#").String(), res.Raw) + assert.EqualValues(t, "find.by.identifier@bar.com", res.Get("0.credentials.password.identifiers.0").String(), res.Raw) + }) - require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i1)) - - res := get(t, adminTS, "/identities?credentials_identifier=find.by.identifier@bar.com", http.StatusOK) - assert.EqualValues(t, i1.ID.String(), res.Get("0.id").String(), "%s", res.Raw) - assert.EqualValues(t, "find.by.identifier@bar.com", res.Get("0.traits.username").String(), "%s", res.Raw) - assert.EqualValues(t, defaultSchemaExternalURL, res.Get("0.schema_url").String(), "%s", res.Raw) - assert.EqualValues(t, config.DefaultIdentityTraitsSchemaID, res.Get("0.schema_id").String(), "%s", res.Raw) - assert.EqualValues(t, identity.StateActive, res.Get("0.state").String(), "%s", res.Raw) - assert.EqualValues(t, "password", res.Get("0.credentials.password.type").String(), res.Raw) - assert.EqualValues(t, "1", res.Get("0.credentials.password.identifiers.#").String(), res.Raw) - assert.EqualValues(t, "find.by.identifier@bar.com", res.Get("0.credentials.password.identifiers.0").String(), res.Raw) + t.Run("type=oidc", func(t *testing.T) { + res := get(t, adminTS, "/identities?credentials_identifier=ProviderID:293b5d9b-1009-4600-a3e9-bd1845de22f2", http.StatusOK) + assert.EqualValues(t, ident.ID.String(), res.Get("0.id").String(), "%s", res.Raw) + assert.EqualValues(t, "find.by.identifier@bar.com", res.Get("0.traits.username").String(), "%s", res.Raw) + assert.EqualValues(t, defaultSchemaExternalURL, res.Get("0.schema_url").String(), "%s", res.Raw) + assert.EqualValues(t, config.DefaultIdentityTraitsSchemaID, res.Get("0.schema_id").String(), "%s", res.Raw) + assert.EqualValues(t, identity.StateActive, res.Get("0.state").String(), "%s", res.Raw) + assert.EqualValues(t, "oidc", res.Get("0.credentials.oidc.type").String(), res.Raw) + assert.EqualValues(t, "1", res.Get("0.credentials.oidc.identifiers.#").String(), res.Raw) + assert.EqualValues(t, "ProviderID:293b5d9b-1009-4600-a3e9-bd1845de22f2", res.Get("0.credentials.oidc.identifiers.0").String(), res.Raw) + }) }) t.Run("case=should get oidc credential", func(t *testing.T) { diff --git a/identity/test/pool.go b/identity/test/pool.go index 8acf9dbfd078..c351c57c6160 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -771,7 +771,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, Expand: identity.ExpandCredentials, }) require.NoError(t, err) - assert.Len(t, actual, 3) + assert.Len(t, actual, 4) // webauthn, common, password, oidc outer: for _, e := range append(expectedIdentities[:2], create) { @@ -789,13 +789,13 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, } }) - t.Run("only webauthn and password", func(t *testing.T) { + t.Run("find by OIDC identifier", func(t *testing.T) { actual, next, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ CredentialsIdentifier: "find-identity-by-identifier-oidc@ory.sh", Expand: identity.ExpandEverything, }) require.NoError(t, err) - assert.Len(t, actual, 0) + assert.Len(t, actual, 1) assert.True(t, next.IsLast()) }) diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 1bfbf107ee0a..00f3a22d38dd 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -762,16 +762,20 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. if len(identifier) > 0 { // When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore // important to normalize the identifier before querying the database. - identifier = NormalizeIdentifier(identity.CredentialsTypePassword, identifier) joins = ` INNER JOIN identity_credentials ic ON ic.identity_id = identities.id INNER JOIN identity_credential_types ict ON ict.id = ic.identity_credential_type_id INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id` wheres += fmt.Sprintf(` - AND (ic.nid = ? AND ici.nid = ? AND ici.identifier %s ?) - AND ict.name IN (?, ?)`, identifierOperator) - args = append(args, nid, nid, identifier, identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword) + AND ic.nid = ? AND ici.nid = ? + AND ((ict.name IN (?, ?, ?) AND ici.identifier %s ?) + OR (ict.name IN (?) AND ici.identifier %s ?)) + `, identifierOperator, identifierOperator) + args = append(args, + nid, nid, + identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword, identity.CredentialsTypeCodeAuth, NormalizeIdentifier(identity.CredentialsTypePassword, identifier), + identity.CredentialsTypeOIDC, identifier) } if params.IdsFilter != nil && len(params.IdsFilter) != 0 {