Skip to content

Commit

Permalink
feat: list by OIDC cred (#3721)
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl authored Feb 2, 2024
1 parent 1c3eeb7 commit bff9c61
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 19 deletions.
42 changes: 30 additions & 12 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,29 +466,47 @@ func TestHandler(t *testing.T) {
})

t.Run("case=should be able to lookup the identity using identifier", func(t *testing.T) {
i1 := &identity.Identity{
ident := &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
identity.CredentialsTypePassword: {
Type: identity.CredentialsTypePassword,
Identifiers: []string{"[email protected]"},
Config: sqlxx.JSONRawMessage(`{"hashed_password":"$2a$08$.cOYmAd.vCpDOoiVJrO5B.hjTLKQQ6cAK40u8uB.FnZDyPvVvQ9Q."}`), // foobar
},
identity.CredentialsTypeOIDC: {
Type: identity.CredentialsTypeOIDC,
Identifiers: []string{"ProviderID:293b5d9b-1009-4600-a3e9-bd1845de22f2"},
Config: sqlxx.JSONRawMessage("{\"some\" : \"secret\"}"),
},
},
State: identity.StateActive,
Traits: identity.Traits(`{"username":"[email protected]"}`),
}
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), ident))

t.Run("type=password", func(t *testing.T) {
res := get(t, adminTS, "/[email protected]", http.StatusOK)
assert.EqualValues(t, ident.ID.String(), res.Get("0.id").String(), "%s", res.Raw)
assert.EqualValues(t, "[email protected]", res.Get("0.traits.username").String(), "%s", res.Raw)
assert.EqualValues(t, defaultSchemaExternalURL, res.Get("0.schema_url").String(), "%s", res.Raw)
assert.EqualValues(t, config.DefaultIdentityTraitsSchemaID, res.Get("0.schema_id").String(), "%s", res.Raw)
assert.EqualValues(t, identity.StateActive, res.Get("0.state").String(), "%s", res.Raw)
assert.EqualValues(t, "password", res.Get("0.credentials.password.type").String(), res.Raw)
assert.EqualValues(t, "1", res.Get("0.credentials.password.identifiers.#").String(), res.Raw)
assert.EqualValues(t, "[email protected]", res.Get("0.credentials.password.identifiers.0").String(), res.Raw)
})

require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i1))

res := get(t, adminTS, "/[email protected]", http.StatusOK)
assert.EqualValues(t, i1.ID.String(), res.Get("0.id").String(), "%s", res.Raw)
assert.EqualValues(t, "[email protected]", res.Get("0.traits.username").String(), "%s", res.Raw)
assert.EqualValues(t, defaultSchemaExternalURL, res.Get("0.schema_url").String(), "%s", res.Raw)
assert.EqualValues(t, config.DefaultIdentityTraitsSchemaID, res.Get("0.schema_id").String(), "%s", res.Raw)
assert.EqualValues(t, identity.StateActive, res.Get("0.state").String(), "%s", res.Raw)
assert.EqualValues(t, "password", res.Get("0.credentials.password.type").String(), res.Raw)
assert.EqualValues(t, "1", res.Get("0.credentials.password.identifiers.#").String(), res.Raw)
assert.EqualValues(t, "[email protected]", res.Get("0.credentials.password.identifiers.0").String(), res.Raw)
t.Run("type=oidc", func(t *testing.T) {
res := get(t, adminTS, "/identities?credentials_identifier=ProviderID:293b5d9b-1009-4600-a3e9-bd1845de22f2", http.StatusOK)
assert.EqualValues(t, ident.ID.String(), res.Get("0.id").String(), "%s", res.Raw)
assert.EqualValues(t, "[email protected]", res.Get("0.traits.username").String(), "%s", res.Raw)
assert.EqualValues(t, defaultSchemaExternalURL, res.Get("0.schema_url").String(), "%s", res.Raw)
assert.EqualValues(t, config.DefaultIdentityTraitsSchemaID, res.Get("0.schema_id").String(), "%s", res.Raw)
assert.EqualValues(t, identity.StateActive, res.Get("0.state").String(), "%s", res.Raw)
assert.EqualValues(t, "oidc", res.Get("0.credentials.oidc.type").String(), res.Raw)
assert.EqualValues(t, "1", res.Get("0.credentials.oidc.identifiers.#").String(), res.Raw)
assert.EqualValues(t, "ProviderID:293b5d9b-1009-4600-a3e9-bd1845de22f2", res.Get("0.credentials.oidc.identifiers.0").String(), res.Raw)
})
})

t.Run("case=should get oidc credential", func(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
Expand: identity.ExpandCredentials,
})
require.NoError(t, err)
assert.Len(t, actual, 3)
assert.Len(t, actual, 4) // webauthn, common, password, oidc

outer:
for _, e := range append(expectedIdentities[:2], create) {
Expand All @@ -789,13 +789,13 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
}
})

t.Run("only webauthn and password", func(t *testing.T) {
t.Run("find by OIDC identifier", func(t *testing.T) {
actual, next, err := p.ListIdentities(ctx, identity.ListIdentityParameters{
CredentialsIdentifier: "[email protected]",
Expand: identity.ExpandEverything,
})
require.NoError(t, err)
assert.Len(t, actual, 0)
assert.Len(t, actual, 1)
assert.True(t, next.IsLast())
})

Expand Down
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
12 changes: 8 additions & 4 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,16 +762,20 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.
if len(identifier) > 0 {
// When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore
// important to normalize the identifier before querying the database.
identifier = NormalizeIdentifier(identity.CredentialsTypePassword, identifier)

joins = `
INNER JOIN identity_credentials ic ON ic.identity_id = identities.id
INNER JOIN identity_credential_types ict ON ict.id = ic.identity_credential_type_id
INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id`
wheres += fmt.Sprintf(`
AND (ic.nid = ? AND ici.nid = ? AND ici.identifier %s ?)
AND ict.name IN (?, ?)`, identifierOperator)
args = append(args, nid, nid, identifier, identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword)
AND ic.nid = ? AND ici.nid = ?
AND ((ict.name IN (?, ?, ?) AND ici.identifier %s ?)
OR (ict.name IN (?) AND ici.identifier %s ?))
`, identifierOperator, identifierOperator)
args = append(args,
nid, nid,
identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword, identity.CredentialsTypeCodeAuth, NormalizeIdentifier(identity.CredentialsTypePassword, identifier),
identity.CredentialsTypeOIDC, identifier)
}

if params.IdsFilter != nil && len(params.IdsFilter) != 0 {
Expand Down

0 comments on commit bff9c61

Please sign in to comment.