diff --git a/identity/handler.go b/identity/handler.go index 2749396632b3..9eb5f11cba99 100644 --- a/identity/handler.go +++ b/identity/handler.go @@ -138,6 +138,13 @@ type listIdentitiesResponse struct { type listIdentitiesParameters struct { migrationpagination.RequestParameters + // IdsFilter is list of ids used to filter identities. + // If this list is empty, then no filter will be applied. + // + // required: false + // in: query + IdsFilter []string `json:"ids_filter"` + // CredentialsIdentifier is the identifier (username, email) of the credentials to look up using exact match. // Only one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used. // @@ -180,6 +187,7 @@ func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Para err error params = ListIdentityParameters{ Expand: ExpandDefault, + IdsFilter: r.URL.Query()["ids"], CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"), CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"), ConsistencyLevel: crdbx.ConsistencyLevelFromRequest(r), diff --git a/identity/handler_test.go b/identity/handler_test.go index 2e38de8db126..0d5c031cbfad 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -339,6 +339,34 @@ func TestHandler(t *testing.T) { } }) + t.Run("suite=create and batch list", func(t *testing.T) { + var ids []uuid.UUID + identitiesAmount := 5 + listAmount := 3 + t.Run("case= create multiple identities", func(t *testing.T) { + for i := 0; i < identitiesAmount; i++ { + res := send(t, adminTS, "POST", "/identities", http.StatusCreated, json.RawMessage(`{"traits": {"bar":"baz"}}`)) + assert.NotEmpty(t, res.Get("id").String(), "%s", res.Raw) + + id := x.ParseUUID(res.Get("id").String()) + ids = append(ids, id) + } + require.Equal(t, len(ids), identitiesAmount) + }) + + t.Run("case= list few identities", func(t *testing.T) { + url := "/identities?ids=" + ids[0].String() + for i := 1; i < listAmount; i++ { + url += "&ids=" + ids[i].String() + } + res := get(t, adminTS, url, 200) + + identities := res.Array() + require.Equal(t, len(identities), listAmount) + }) + + }) + t.Run("suite=create and update", func(t *testing.T) { var i identity.Identity createOidcIdentity := func(t *testing.T, identifier, accessToken, refreshToken, idToken string, encrypt bool) string { diff --git a/identity/pool.go b/identity/pool.go index 46083d3d525b..1da5181b796a 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -18,6 +18,7 @@ import ( type ( ListIdentityParameters struct { Expand Expandables + IdsFilter []string CredentialsIdentifier string CredentialsIdentifierSimilar string KeySetPagination []keysetpagination.Option diff --git a/identity/test/pool.go b/identity/test/pool.go index 371d4a2cf6b1..8acf9dbfd078 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -657,6 +657,17 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, assert.Len(t, is, 0) }) + t.Run("list some using ids filter", func(t *testing.T) { + var filterIds []string + for _, id := range createdIDs[:2] { + filterIds = append(filterIds, id.String()) + } + + is, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault, IdsFilter: filterIds}) + require.NoError(t, err) + assert.Len(t, is, len(filterIds)) + }) + t.Run("eventually consistent", func(t *testing.T) { if dbname != "cockroach" { t.Skipf("Test only works with cockroachdb") diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 6e3566b7721c..82d7b035e295 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -774,6 +774,13 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. args = append(args, nid, nid, identifier, identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword) } + if params.IdsFilter != nil && len(params.IdsFilter) != 0 { + wheres += ` + AND identities.id in (?) + ` + args = append(args, params.IdsFilter) + } + query := fmt.Sprintf(` SELECT DISTINCT identities.* FROM identities AS identities