diff --git a/docs/dev/identity-provider-tracking.md b/docs/dev/identity-provider-tracking.md new file mode 100644 index 0000000000..8276a15f90 --- /dev/null +++ b/docs/dev/identity-provider-tracking.md @@ -0,0 +1,29 @@ +# Identity Provider Tracking +Infra tracks users and groups that exist in external identity providers and maps them access. The following database diagram shows how the relation of groups and users in identity providers are mapped to users and groups within Infra. + +Our identity system has two different type of relations, one of which exists in Infra (users and groups controlled by Infra), and the other relation which exists externally (information about users and groups which exist in an identity provider). + +To get all the groups a user belongs to as a combination of direct group membership and membership in a mapped identity provider group use the `resolved_identity_groups` view. + +## Users and Groups Managed by Infra +The following tables represent users whose identity is managed by Infra and groups which are controlled by Infra. +- **Identities**: Known information about a user. This may be a an email and password hash that can be used to login to Infra directly, or information about a user that has logged via an identity provider. +- **Groups**: Groups which exist in Infra that membership to grants access to some resource. +- **Identities_Groups**: Identities that have been added as a member of a group within Infra, this membership is added through either membership in an external identity provider group or through direct assignment (represented in the database as being part of an Infra provider group). + +## Users and Groups Managed by an Identity Provider +- **Provider_User**: A user that exists in an identity provider. +- **Provider_Group**: A group that exists in an identity provider. +- **Provider_Groups_Provider_Users**: This relation shows which users of an identity provders are members of which groups in the same identity provider. + +```mermaid +erDiagram + PROVIDER_GROUPS_MAPPINGS }o--|| PROVIDER : "" + GROUP ||--o{ IDENTITIES_GROUPS : "direct membership to infra group" + IDENTITIES_GROUPS }o--|| IDENTITY : "direct membership to infra group" + IDENTITY ||--o{ PROVIDER_USER : "infra identity composed of identity provider users" + PROVIDER_GROUP }o--|| PROVIDER : "sync groups to infra" + PROVIDER_GROUP ||--o{ PROVIDER_GROUPS_PROVIDER_USERS : "sync members to infra" + PROVIDER_GROUPS_PROVIDER_USERS }o--|| PROVIDER_USER : "sync members to infra" + PROVIDER_USER }o--|| PROVIDER : "sync identities to infra" +``` \ No newline at end of file diff --git a/internal/access/access_test.go b/internal/access/access_test.go index 44d102ee72..f6329cb8f0 100644 --- a/internal/access/access_test.go +++ b/internal/access/access_test.go @@ -102,13 +102,13 @@ func TestRequireInfraRole_GrantsFromGroupMembership(t *testing.T) { err := data.CreateIdentity(db, tom) assert.NilError(t, err) - _, err = data.CreateProviderUser(db, provider, tom) + providedTom, err := data.CreateProviderUser(db, provider, tom) assert.NilError(t, err) err = data.CreateGroup(db, tomsGroup) assert.NilError(t, err) - err = data.AssignIdentityToGroups(db, tom, provider, []string{tomsGroup.Name}) + err = data.AssignUserToProviderGroups(db, providedTom, provider, []string{tomsGroup.Name}) assert.NilError(t, err) c, _ := gin.CreateTestContext(httptest.NewRecorder()) diff --git a/internal/access/group.go b/internal/access/group.go index b9d5531492..e82b5b8e99 100644 --- a/internal/access/group.go +++ b/internal/access/group.go @@ -114,7 +114,7 @@ func UpdateUsersInGroup(c *gin.Context, groupID uid.ID, uidsToAdd []uid.ID, uids return err } - _, err = data.GetGroup(db, data.ByID(groupID)) + group, err := data.GetGroup(db, data.ByID(groupID)) if err != nil { return err } @@ -130,13 +130,13 @@ func UpdateUsersInGroup(c *gin.Context, groupID uid.ID, uidsToAdd []uid.ID, uids } if len(addIDList) > 0 { - if err := data.AddUsersToGroup(db, groupID, addIDList); err != nil { + if err := data.AddUsersToGroup(db, group.ID, group.Name, data.InfraProvider(db).ID, addIDList); err != nil { return err } } if len(rmIDList) > 0 { - if err := data.RemoveUsersFromGroup(db, groupID, rmIDList); err != nil { + if err := data.RemoveUsersFromGroup(db, group.ID, rmIDList); err != nil { return err } } diff --git a/internal/access/identity.go b/internal/access/identity.go index 5d6718af13..6b2569b7c5 100644 --- a/internal/access/identity.go +++ b/internal/access/identity.go @@ -130,7 +130,7 @@ func UpdateIdentityInfoFromProvider(c RequestContext, oidc providers.OIDCClient) return fmt.Errorf("user info provider: %w", err) } - // get current identity provider groups and account status + // update current identity provider groups and account status err = data.SyncProviderUser(ctx, db, identity, provider, oidc) if err != nil { if errors.Is(err, internal.ErrBadGateway) { @@ -141,10 +141,6 @@ func UpdateIdentityInfoFromProvider(c RequestContext, oidc providers.OIDCClient) logging.Errorf("failed to revoke invalid user session: %s", nestedErr) } - if nestedErr := data.DeleteProviderUsers(db, data.DeleteProviderUsersOptions{ByIdentityID: identity.ID, ByProviderID: provider.ID}); nestedErr != nil { - logging.Errorf("failed to delete provider user: %s", nestedErr) - } - return fmt.Errorf("sync user: %w", err) } diff --git a/internal/access/identity_test.go b/internal/access/identity_test.go index ef5ff5216b..9dea721eca 100644 --- a/internal/access/identity_test.go +++ b/internal/access/identity_test.go @@ -88,7 +88,7 @@ func TestDeleteIdentityCleansUpResources(t *testing.T) { group := &models.Group{Name: "Group"} err = data.CreateGroup(db, group) assert.NilError(t, err) - err = data.AddUsersToGroup(db, group.ID, []uid.ID{identity.ID}) + err = data.AddUsersToGroup(db, group.ID, group.Name, data.InfraProvider(db).ID, []uid.ID{identity.ID}) assert.NilError(t, err) // delete the identity, and make sure all their resources are gone diff --git a/internal/server/authn/oidc_test.go b/internal/server/authn/oidc_test.go index 70f59f4f50..b9160a103f 100644 --- a/internal/server/authn/oidc_test.go +++ b/internal/server/authn/oidc_test.go @@ -71,11 +71,14 @@ func TestOIDCAuthenticate(t *testing.T) { // user should be created assert.Equal(t, authnIdentity.Identity.Name, "bruce@example.com") + syncGroups, err := data.ListGroups(db, nil, data.ByGroupMember(authnIdentity.Identity.ID)) + assert.NilError(t, err) + groups := make(map[string]bool) - for _, g := range authnIdentity.Identity.Groups { + for _, g := range syncGroups { groups[g.Name] = true } - assert.Assert(t, len(authnIdentity.Identity.Groups) == 2) + assert.Equal(t, len(syncGroups), 2) assert.Equal(t, groups["Everyone"], true) assert.Equal(t, groups["developers"], true) @@ -198,22 +201,12 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) { err := data.CreateIdentity(db, user) assert.NilError(t, err) - for _, name := range []string{"Foo", "existing3"} { - group := &models.Group{Name: name} - err = data.CreateGroup(db, group) - assert.NilError(t, err) - err = data.AddUsersToGroup(db, group.ID, []uid.ID{user.ID}) - assert.NilError(t, err) - } - - g, err := data.GetGroup(db, data.ByName("Foo")) + group := &models.Group{Name: "Foo"} + err = data.CreateGroup(db, group) assert.NilError(t, err) - assert.Assert(t, g != nil) - user, err = data.GetIdentity(db, data.GetIdentityOptions{ByID: user.ID, LoadGroups: true}) + err = data.AddUsersToGroup(db, group.ID, group.Name, data.InfraProvider(db).ID, []uid.ID{user.ID}) assert.NilError(t, err) - assert.Assert(t, user != nil) - assert.Equal(t, len(user.Groups), 2) p, err := data.GetProvider(db, data.ByName("mockoidc")) assert.NilError(t, err) @@ -221,8 +214,8 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) { pu, err := data.CreateProviderUser(db, p, user) assert.NilError(t, err) - pu.Groups = []string{"existing3"} - assert.NilError(t, data.UpdateProviderUser(db, pu)) + err = data.AssignUserToProviderGroups(db, pu, p, []string{"existing3"}) + assert.NilError(t, err) return &mockOIDCImplementation{ UserEmailResp: "eugwnw@example.com", @@ -234,7 +227,7 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) { assert.Equal(t, "mockoidc", a.Provider.Name) assert.Assert(t, a.SessionExpiry.Equal(sessionExpiry)) - assert.Assert(t, len(a.Identity.Groups) == 3) + assert.Equal(t, len(a.Identity.Groups), 3) var groupNames []string for _, g := range a.Identity.Groups { @@ -261,15 +254,12 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) { a, err := loginMethod.Authenticate(context.Background(), db, sessionExpiry) assert.NilError(t, err) - tc.expected(t, a) - if err == nil { - // make sure the associations are still set when you reload the object. - u, err := data.GetIdentity(db, data.GetIdentityOptions{ByID: a.Identity.ID, LoadGroups: true}) - assert.NilError(t, err) - a.Identity = u - tc.expected(t, a) - } + syncGroups, err := data.ListGroups(db, nil, data.ByGroupMember(a.Identity.ID)) + assert.NilError(t, err) + a.Identity.Groups = syncGroups + + tc.expected(t, a) }) } } diff --git a/internal/server/data/grant_test.go b/internal/server/data/grant_test.go index 4c26775f66..9e80766ae2 100644 --- a/internal/server/data/grant_test.go +++ b/internal/server/data/grant_test.go @@ -387,9 +387,9 @@ func TestListGrants(t *testing.T) { userID, err := uid.Parse([]byte("userchar")) assert.NilError(t, err) - assert.NilError(t, AddUsersToGroup(tx, uid.ID(111), []uid.ID{userID})) - assert.NilError(t, AddUsersToGroup(tx, uid.ID(112), []uid.ID{userID})) - assert.NilError(t, AddUsersToGroup(tx, uid.ID(113), []uid.ID{uid.ID(777)})) + assert.NilError(t, AddUsersToGroup(tx, uid.ID(111), "name-111", InfraProvider(db).ID, []uid.ID{userID})) + assert.NilError(t, AddUsersToGroup(tx, uid.ID(112), "name-111", InfraProvider(db).ID, []uid.ID{userID})) + assert.NilError(t, AddUsersToGroup(tx, uid.ID(113), "name-111", InfraProvider(db).ID, []uid.ID{uid.ID(777)})) gGrant1 := &models.Grant{ Subject: uid.NewGroupPolymorphicID(111), diff --git a/internal/server/data/group.go b/internal/server/data/group.go index ff1b68c5af..3c4023f3c3 100644 --- a/internal/server/data/group.go +++ b/internal/server/data/group.go @@ -49,7 +49,6 @@ func GetGroup(db GormTxn, selectors ...SelectorFunc) (*models.Group, error) { func ListGroups(db GormTxn, p *Pagination, selectors ...SelectorFunc) ([]models.Group, error) { groups, err := list[models.Group](db, p, selectors...) - if err != nil { return nil, err } @@ -63,7 +62,6 @@ func ListGroups(db GormTxn, p *Pagination, selectors ...SelectorFunc) ([]models. } return groups, nil - } func ByGroupMember(id uid.ID) SelectorFunc { @@ -75,7 +73,10 @@ func ByGroupMember(id uid.ID) SelectorFunc { } func groupIDsForUser(tx ReadTxn, userID uid.ID) ([]uid.ID, error) { - stmt := `SELECT DISTINCT group_id FROM identities_groups WHERE identity_id = ?` + stmt := ` + SELECT DISTINCT group_id FROM identities_groups + WHERE identity_id = ? + ` rows, err := tx.Query(stmt, userID) if err != nil { return nil, err @@ -101,7 +102,7 @@ func DeleteGroup(tx WriteTxn, id uid.ID) error { _, err = tx.Exec(`DELETE from identities_groups WHERE group_id = ?`, id) if err != nil { - return fmt.Errorf("remove useres from group: %w", err) + return fmt.Errorf("remove users from group: %w", err) } stmt := ` @@ -114,11 +115,11 @@ func DeleteGroup(tx WriteTxn, id uid.ID) error { return handleError(err) } -func AddUsersToGroup(tx WriteTxn, groupID uid.ID, idsToAdd []uid.ID) error { - query := querybuilder.New("INSERT INTO identities_groups(group_id, identity_id)") +func AddUsersToGroup(tx WriteTxn, groupID uid.ID, providerGroupName string, providerID uid.ID, idsToAdd []uid.ID) error { + query := querybuilder.New("INSERT INTO identities_groups(group_id, identity_id, provider_id, provider_group_name)") query.B("VALUES") for i, id := range idsToAdd { - query.B("(?, ?)", groupID, id) + query.B("(?, ?, ?, ?)", groupID, id, providerID, providerGroupName) if i+1 != len(idsToAdd) { query.B(",") } @@ -129,6 +130,21 @@ func AddUsersToGroup(tx WriteTxn, groupID uid.ID, idsToAdd []uid.ID) error { return handleError(err) } +func AddUserToGroups(tx WriteTxn, providerID uid.ID, identityID uid.ID, groups []models.Group) error { + query := querybuilder.New("INSERT INTO identities_groups(provider_id, identity_id, group_id, provider_group_name)") + query.B("VALUES") + for i, group := range groups { + query.B("(?, ?, ?, ?)", providerID, identityID, group.ID, group.Name) + if i+1 != len(groups) { + query.B(",") + } + } + query.B("ON CONFLICT DO NOTHING") + + _, err := tx.Exec(query.String(), query.Args...) + return handleError(err) +} + // RemoveUsersFromGroup removes any user ID listed in idsToRemove from the group // with ID groupID. // Note that DeleteGroup also removes users from the group. @@ -138,11 +154,14 @@ func RemoveUsersFromGroup(tx WriteTxn, groupID uid.ID, idsToRemove []uid.ID) err return handleError(err) } -// TODO: do this with a join in ListGroups and GetGroup func CountUsersInGroup(tx GormTxn, groupID uid.ID) (int64, error) { db := tx.GormDB() var count int64 - err := db.Table("identities_groups").Where("group_id = ?", groupID).Count(&count).Error + err := db. + Table("identities_groups"). + Where("group_id = ?", groupID). + Distinct("identity_id"). + Count(&count).Error if err != nil { return 0, err } diff --git a/internal/server/data/group_test.go b/internal/server/data/group_test.go index 551844ed39..d5d66d5e41 100644 --- a/internal/server/data/group_test.go +++ b/internal/server/data/group_test.go @@ -96,6 +96,12 @@ func TestListGroups(t *testing.T) { Groups: []models.Group{everyone, product}, } createIdentities(t, db, &firstUser, &secondUser) + err := AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{firstUser.ID, secondUser.ID}) + assert.NilError(t, err) + err = AddUsersToGroup(db, engineers.ID, engineers.Name, InfraProvider(db).ID, []uid.ID{firstUser.ID}) + assert.NilError(t, err) + err = AddUsersToGroup(db, product.ID, product.Name, InfraProvider(db).ID, []uid.ID{secondUser.ID}) + assert.NilError(t, err) t.Run("all", func(t *testing.T) { actual, err := ListGroups(db, nil) @@ -226,19 +232,15 @@ func TestAddUsersToGroup(t *testing.T) { createGroups(t, db, &everyone, &other) var ( - bond = models.Identity{ - Name: "jbond@infrahq.com", - Groups: []models.Group{everyone}, - } + bond = models.Identity{Name: "jbond@infrahq.com"} bourne = models.Identity{Name: "jbourne@infrahq.com"} bauer = models.Identity{Name: "jbauer@infrahq.com"} - forth = models.Identity{ - Name: "forth@example.com", - Groups: []models.Group{everyone}, - } + forth = models.Identity{Name: "forth@example.com"} ) createIdentities(t, db, &bond, &bourne, &bauer, &forth) + err := AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bond.ID, forth.ID}) + assert.NilError(t, err) t.Run("add identities to group", func(t *testing.T) { actual, err := ListIdentities(db, ListIdentityOptions{ByGroupID: everyone.ID}) @@ -246,7 +248,7 @@ func TestAddUsersToGroup(t *testing.T) { expected := []models.Identity{forth, bond} assert.DeepEqual(t, actual, expected, cmpModelsIdentityShallow) - err = AddUsersToGroup(db, everyone.ID, []uid.ID{bourne.ID, bauer.ID, forth.ID}) + err = AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bourne.ID, bauer.ID, forth.ID}) assert.NilError(t, err) actual, err = ListIdentities(db, ListIdentityOptions{ByGroupID: everyone.ID}) @@ -286,6 +288,10 @@ func TestRemoveUsersFromGroup(t *testing.T) { Groups: []models.Group{everyone}, } createIdentities(t, tx, &bond, &bourne, &bauer, &forth) + err := AddUsersToGroup(tx, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bond.ID, bourne.ID, bauer.ID, forth.ID}) + assert.NilError(t, err) + err = AddUsersToGroup(tx, other.ID, other.Name, InfraProvider(db).ID, []uid.ID{bond.ID, bourne.ID, bauer.ID}) + assert.NilError(t, err) users, err := ListIdentities(tx, ListIdentityOptions{ByGroupID: everyone.ID}) assert.NilError(t, err) @@ -309,3 +315,33 @@ func TestRemoveUsersFromGroup(t *testing.T) { assert.DeepEqual(t, actual, expected, cmpModelsIdentityShallow) }) } + +func TestCountUsersInGroup(t *testing.T) { + runDBTests(t, func(t *testing.T, db *DB) { + everyone := models.Group{Name: "Everyone"} + createGroups(t, db, &everyone) + + mockta := &models.Provider{Name: "mokta", Kind: models.ProviderKindOkta} + err := CreateProvider(db, mockta) + assert.NilError(t, err) + + var ( + bond = models.Identity{Name: "jbond@infrahq.com"} + bourne = models.Identity{Name: "jbourne@infrahq.com"} + bauer = models.Identity{Name: "jbauer@infrahq.com"} + forth = models.Identity{Name: "forth@example.com"} + ) + + createIdentities(t, db, &bond, &bourne, &bauer, &forth) + err = AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bond.ID, forth.ID}) + assert.NilError(t, err) + err = AddUsersToGroup(db, everyone.ID, everyone.Name, mockta.ID, []uid.ID{bond.ID, bourne.ID, bauer.ID}) + assert.NilError(t, err) + + t.Run("count users in group", func(t *testing.T) { + count, err := CountUsersInGroup(db, everyone.ID) + assert.NilError(t, err) + assert.Equal(t, int(count), 4) + }) + }) +} diff --git a/internal/server/data/identity.go b/internal/server/data/identity.go index 7a2e149066..4729acce4a 100644 --- a/internal/server/data/identity.go +++ b/internal/server/data/identity.go @@ -33,99 +33,75 @@ func (i *identitiesTable) ScanFields() []any { return []any{&i.CreatedAt, &i.CreatedBy, &i.DeletedAt, &i.ID, &i.LastSeenAt, &i.Name, &i.OrganizationID, &i.UpdatedAt, &i.VerificationToken, &i.Verified} } -func AssignIdentityToGroups(tx WriteTxn, user *models.Identity, provider *models.Provider, newGroups []string) error { - pu, err := GetProviderUser(tx, provider.ID, user.ID) +func AssignUserToProviderGroups(tx GormTxn, providerUser *models.ProviderUser, provider *models.Provider, newGroupNames []string) error { + // before applying new groups, see which groups this user is currently known to be in for this provider + previousUserGroupsForProvider, err := ListProviderGroups(tx, ListProviderGroupOptions{ByProviderID: provider.ID, ByMemberIdentityID: providerUser.IdentityID}) if err != nil { - return err + return fmt.Errorf("current user idp groups: %w", err) } - oldGroups := pu.Groups - groupsToBeRemoved := slice.Subtract(oldGroups, newGroups) - groupsToBeAdded := slice.Subtract(newGroups, oldGroups) - - pu.Groups = newGroups - pu.LastUpdate = time.Now().UTC() - if err := UpdateProviderUser(tx, pu); err != nil { - return fmt.Errorf("save: %w", err) + // get all the groups this user was previously known to have at this provider for comparison to the new group names + oldGroupNames := []string{} + for _, g := range previousUserGroupsForProvider { + oldGroupNames = append(oldGroupNames, g.Name) } - // remove user from groups + groupsToBeRemoved := slice.Subtract(oldGroupNames, newGroupNames) + groupsToBeAdded := slice.Subtract(newGroupNames, oldGroupNames) + if len(groupsToBeRemoved) > 0 { - stmt := `DELETE FROM identities_groups WHERE identity_id = ? AND group_id in ( - SELECT id FROM groups WHERE organization_id = ? AND name IN (?))` - if _, err := tx.Exec(stmt, user.ID, tx.OrganizationID(), groupsToBeRemoved); err != nil { - return err - } - for _, name := range groupsToBeRemoved { - for i, g := range user.Groups { - if g.Name == name { - // remove from list - user.Groups = append(user.Groups[:i], user.Groups[i+1:]...) - } - } + if err := removeMemberFromProviderGroups(tx, providerUser, groupsToBeRemoved); err != nil { + return fmt.Errorf("remove previous idp groups: %w", err) } } - type idNamePair struct { - ID uid.ID - Name string - } + if len(groupsToBeAdded) > 0 { + infraGroups := []models.Group{} + // make sure all the groups and provider groups exist, then add the membership relation + for _, name := range groupsToBeAdded { + // check if provider group with this name exists + _, err = GetProviderGroup(tx, provider.ID, name) + if err != nil { + if !errors.Is(err, internal.ErrNotFound) { + return err + } - stmt := `SELECT id, name FROM groups WHERE deleted_at is null AND name IN (?) AND organization_id = ?` - rows, err := tx.Query(stmt, groupsToBeAdded, tx.OrganizationID()) - if err != nil { - return err - } - addIDs, err := scanRows(rows, func(item *idNamePair) []any { - return []any{&item.ID, &item.Name} - }) - if err != nil { - return err - } + // this is the first time this provider group has been seen, create it + providerGroup := &models.ProviderGroup{ + ProviderID: provider.ID, + Name: name, + } - for _, name := range groupsToBeAdded { - // find or create group - var groupID uid.ID - found := false - for _, obj := range addIDs { - if obj.Name == name { - found = true - groupID = obj.ID - break - } - } - if !found { - group := &models.Group{ - Name: name, - CreatedByProvider: provider.ID, + if err := CreateProviderGroup(tx, providerGroup); err != nil { + return err + } } + // check if an infra group with this name exists + infraGroup, err := GetGroup(tx, ByName(name)) + if err != nil { + if !errors.Is(err, internal.ErrNotFound) { + return err + } + // this is the first time this group has been seen, create the infra version of this group + infraGroup = &models.Group{ + Name: name, + CreatedByProvider: provider.ID, + } - if err = CreateGroup(tx, group); err != nil { - return fmt.Errorf("create group: %w", err) + if err := CreateGroup(tx, infraGroup); err != nil { + return err + } } - groupID = group.ID + infraGroups = append(infraGroups, *infraGroup) } - - rows, err := tx.Query("SELECT identity_id FROM identities_groups WHERE identity_id = ? AND group_id = ?", user.ID, groupID) - if err != nil { - return err + // now that we know all the groups and provider groups exist in our database set the user as a member + if err := addMemberToProviderGroups(tx, providerUser, groupsToBeAdded); err != nil { + return fmt.Errorf("update existing provider groups: %w", err) } - ids, err := scanRows(rows, func(item *uid.ID) []any { - return []any{item} - }) - if err != nil { + // automatically map the identity provider group to the infra group with the same name + if err := AddUserToGroups(tx, providerUser.ProviderID, providerUser.IdentityID, infraGroups); err != nil { return err } - - if len(ids) == 0 { - // add user to group - _, err = tx.Exec("INSERT INTO identities_groups (identity_id, group_id) VALUES (?, ?)", user.ID, groupID) - if err != nil { - return fmt.Errorf("insert: %w", handleError(err)) - } - } - - user.Groups = append(user.Groups, models.Group{Model: models.Model{ID: groupID}, Name: name}) } return nil diff --git a/internal/server/data/identity_test.go b/internal/server/data/identity_test.go index ec61bf1632..e94a9ccfc2 100644 --- a/internal/server/data/identity_test.go +++ b/internal/server/data/identity_test.go @@ -14,10 +14,6 @@ import ( "github.com/infrahq/infra/uid" ) -var cmpModelsGroupShallow = cmp.Comparer(func(x, y models.Group) bool { - return x.Name == y.Name && x.OrganizationID == y.OrganizationID -}) - func TestCreateIdentity(t *testing.T) { runDBTests(t, func(t *testing.T, db *DB) { t.Run("success", func(t *testing.T) { @@ -47,7 +43,7 @@ func createIdentities(t *testing.T, db GormTxn, identities ...*models.Identity) assert.NilError(t, err, user.Name) for _, group := range user.Groups { - err = AddUsersToGroup(db, group.ID, []uid.ID{user.ID}) + err = AddUsersToGroup(db, group.ID, group.Name, InfraProvider(db).ID, []uid.ID{user.ID}) assert.NilError(t, err) } } @@ -148,6 +144,7 @@ func TestListIdentities(t *testing.T) { var ( bond = models.Identity{ Name: "jbond@infrahq.com", + Groups: []models.Group{everyone}, Providers: providers, } salt = models.Identity{ @@ -168,6 +165,9 @@ func TestListIdentities(t *testing.T) { ) createIdentities(t, db, &bond, &salt, &bourne, &bauer) + assert.NilError(t, AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bond.ID, bourne.ID, bauer.ID})) + assert.NilError(t, AddUsersToGroup(db, product.ID, product.Name, InfraProvider(db).ID, []uid.ID{bourne.ID})) + connector := InfraConnectorIdentity(db) t.Run("list all", func(t *testing.T) { @@ -208,7 +208,7 @@ func TestListIdentities(t *testing.T) { t.Run("filter identities by group", func(t *testing.T) { actual, err := ListIdentities(db, ListIdentityOptions{ByGroupID: everyone.ID}) assert.NilError(t, err) - expected := []models.Identity{bauer, bourne} + expected := []models.Identity{bauer, bond, bourne} assert.DeepEqual(t, actual, expected, cmpModelsIdentityShallow) }) @@ -341,7 +341,7 @@ func TestDeleteIdentities(t *testing.T) { } err = CreateGroup(tx, group) assert.NilError(t, err) - err = AddUsersToGroup(tx, group.ID, []uid.ID{bond.ID}) + err = AddUsersToGroup(tx, group.ID, group.Name, InfraProvider(tx).ID, []uid.ID{bond.ID}) assert.NilError(t, err) err = CreateGrant(tx, &models.Grant{Subject: bond.PolyID(), Privilege: "admin", Resource: "infra"}) @@ -595,7 +595,7 @@ func TestDeleteIdentityWithGroups(t *testing.T) { createIdentities(t, db, &bond, &bourne, &bauer) - err = AddUsersToGroup(db, group.ID, []uid.ID{bond.ID, bourne.ID, bauer.ID}) + err = AddUsersToGroup(db, group.ID, group.Name, InfraProvider(db).ID, []uid.ID{bond.ID, bourne.ID, bauer.ID}) assert.NilError(t, err) opts := DeleteIdentitiesOptions{ @@ -613,17 +613,17 @@ func TestDeleteIdentityWithGroups(t *testing.T) { func TestAssignIdentityToGroups(t *testing.T) { tests := []struct { - Name string - StartingGroups []string // groups identity starts with - ExistingGroups []string // groups from last provider sync - IncomingGroups []string // groups from this provider sync - ExpectedGroups []models.Group // groups identity should have at end + Name string + StartingGroups []string // groups identity starts with + ExistingProviderGroups []string // provider groups this identity is in from last provider sync + IncomingProviderGroups []string // provider groups this identity is in from this provider sync + ExpectedGroups []models.Group // groups identity should have at end }{ { - Name: "test where the provider is trying to add a group the identity doesn't have elsewhere", - StartingGroups: []string{"foo"}, - ExistingGroups: []string{}, - IncomingGroups: []string{"foo2"}, + Name: "test where the provider is trying to add a group the identity doesn't have elsewhere", + StartingGroups: []string{"foo"}, + ExistingProviderGroups: []string{}, + IncomingProviderGroups: []string{"foo2"}, ExpectedGroups: []models.Group{ { Name: "foo", @@ -640,10 +640,10 @@ func TestAssignIdentityToGroups(t *testing.T) { }, }, { - Name: "test where the provider is trying to add a group the identity has from elsewhere", - StartingGroups: []string{"foo"}, - ExistingGroups: []string{}, - IncomingGroups: []string{"foo", "foo2"}, + Name: "test where the provider is trying to add a group the identity has from elsewhere", + StartingGroups: []string{"foo"}, + ExistingProviderGroups: []string{}, + IncomingProviderGroups: []string{"foo", "foo2"}, ExpectedGroups: []models.Group{ { Name: "foo", @@ -660,10 +660,10 @@ func TestAssignIdentityToGroups(t *testing.T) { }, }, { - Name: "test where the group with the same name exists in another org", - StartingGroups: []string{}, - ExistingGroups: []string{}, - IncomingGroups: []string{"Everyone"}, + Name: "test where the group with the same name exists in another org", + StartingGroups: []string{}, + ExistingProviderGroups: []string{}, + IncomingProviderGroups: []string{"Everyone"}, ExpectedGroups: []models.Group{ { Name: "Everyone", @@ -673,6 +673,55 @@ func TestAssignIdentityToGroups(t *testing.T) { }, }, }, + { + Name: "test where the user is in no groups before the provider sync", + StartingGroups: []string{}, + ExistingProviderGroups: []string{}, + IncomingProviderGroups: []string{"foo"}, + ExpectedGroups: []models.Group{ + { + Name: "foo", + OrganizationMember: models.OrganizationMember{ + OrganizationID: 1000, + }, + }, + }, + }, + { + Name: "test where the user is in no groups after the provider sync", + StartingGroups: []string{}, + ExistingProviderGroups: []string{"foo"}, + IncomingProviderGroups: []string{}, + ExpectedGroups: []models.Group{}, + }, + { + Name: "test where the user is in no provider groups after the provider sync", + StartingGroups: []string{"foo"}, + ExistingProviderGroups: []string{"foo 1"}, + IncomingProviderGroups: []string{}, + ExpectedGroups: []models.Group{ + { + Name: "foo", + OrganizationMember: models.OrganizationMember{ + OrganizationID: 1000, + }, + }, + }, + }, + { + Name: "test when there is a comma in a group name", + StartingGroups: []string{}, + ExistingProviderGroups: []string{}, + IncomingProviderGroups: []string{"my, group"}, + ExpectedGroups: []models.Group{ + { + Name: "my, group", + OrganizationMember: models.OrganizationMember{ + OrganizationID: 1000, + }, + }, + }, + }, } runDBTests(t, func(t *testing.T, db *DB) { @@ -690,6 +739,8 @@ func TestAssignIdentityToGroups(t *testing.T) { err := CreateIdentity(db, identity) assert.NilError(t, err) + provider := InfraProvider(db) + // setup identity's groups for _, gn := range test.StartingGroups { g, err := GetGroup(db, ByName(gn)) @@ -698,19 +749,17 @@ func TestAssignIdentityToGroups(t *testing.T) { err = CreateGroup(db, g) assert.NilError(t, err) } - assert.NilError(t, AddUsersToGroup(db, g.ID, []uid.ID{identity.ID})) + assert.NilError(t, AddUsersToGroup(db, g.ID, g.Name, provider.ID, []uid.ID{identity.ID})) } // setup providerUser record - provider := InfraProvider(db) pu, err := CreateProviderUser(db, provider, identity) assert.NilError(t, err) - pu.Groups = test.ExistingGroups - err = UpdateProviderUser(db, pu) + err = AssignUserToProviderGroups(db, pu, provider, test.ExistingProviderGroups) assert.NilError(t, err) - err = AssignIdentityToGroups(db, identity, provider, test.IncomingGroups) + err = AssignUserToProviderGroups(db, pu, provider, test.IncomingProviderGroups) assert.NilError(t, err) // check the result diff --git a/internal/server/data/migrations.go b/internal/server/data/migrations.go index 060ce05b02..ddcfebacb7 100644 --- a/internal/server/data/migrations.go +++ b/internal/server/data/migrations.go @@ -69,6 +69,7 @@ func migrations() []*migrator.Migration { addUpdateIndexAndGrantNotify(), addUpdateIndexToExistingGrants(), addDeviceFlowAuthRequestTable(), + addProviderGroups(), // next one here } } @@ -838,3 +839,245 @@ func addDeviceFlowAuthRequestTable() *migrator.Migration { }, } } + +func addProviderGroups() *migrator.Migration { + return &migrator.Migration{ + ID: "2022-10-03T13:00", + Migrate: func(tx migrator.DB) error { + // setup the new tables + _, err := tx.Exec( + ` + CREATE TABLE IF NOT EXISTS provider_groups ( + organization_id bigint NOT NULL, + provider_id bigint NOT NULL, + name text NOT NULL, + created_at timestamp with time zone, + updated_at timestamp with time zone + ); + CREATE TABLE IF NOT EXISTS provider_groups_provider_users ( + provider_id bigint NOT NULL, + provider_group_name text NOT NULL, + provider_user_identity_id bigint NOT NULL + ); + `) + if err != nil { + return fmt.Errorf("create provider groups table: %w", err) + } + + if migrator.HasColumn(tx, "provider_users", "groups") { + _, err = tx.Exec("CREATE UNIQUE INDEX idx_provider_group_names ON provider_groups(name, provider_id, organization_id)") + if err != nil { + return fmt.Errorf("add provider groups index: %w", err) + } + + _, err = tx.Exec("CREATE UNIQUE INDEX idx_provider_groups_provider_users ON provider_groups_provider_users(provider_group_name, provider_id, provider_user_identity_id)") + if err != nil { + return fmt.Errorf("add provider users index: %w", err) + } + + _, err = tx.Exec("ALTER TABLE identities_groups DROP CONSTRAINT identities_groups_pkey") + if err != nil { + return fmt.Errorf("remove identity groups pkey: %w", err) + } + + _, err = tx.Exec("ALTER TABLE identities_groups ADD provider_id bigint") + if err != nil { + return fmt.Errorf("add identity group provider id: %w", err) + } + + _, err = tx.Exec("ALTER TABLE identities_groups ADD provider_group_name text") + if err != nil { + return fmt.Errorf("add identities groups provider group name: %w", err) + } + + // find which group memberships came from identity providers and create their provider groups + type groupProviderUser struct { + IdentityID uid.ID + ProviderID uid.ID + Groups models.CommaSeparatedStrings + } + stmt := ` + SELECT identity_id, provider_id, groups + FROM provider_users + WHERE groups != '' + ` + rows, err := tx.Query(stmt) + if err != nil { + return fmt.Errorf("migrate existing provider groups: %w", err) + } + + groupProviderUsers := []groupProviderUser{} + for rows.Next() { + var groupProviderUser groupProviderUser + if err := rows.Scan(&groupProviderUser.IdentityID, &groupProviderUser.ProviderID, &groupProviderUser.Groups); err != nil { + return err + } + + groupProviderUsers = append(groupProviderUsers, groupProviderUser) + } + + if rows.Err() != nil { + return err + } + + if err := rows.Close(); err != nil { + return err + } + + for _, groupProviderUser := range groupProviderUsers { + var orgID uid.ID + if err := tx.QueryRow(`SELECT organization_id FROM providers WHERE id = ?`, groupProviderUser.ProviderID).Scan(&orgID); err != nil { + return err + } + + for _, name := range groupProviderUser.Groups { + // start by, getting the corresponding infra group which already exists + var infraGroupID uid.ID + err := tx.QueryRow(`SELECT id FROM groups WHERE name = ? AND organization_id = ? AND deleted_at IS NULL`, name, orgID).Scan(&infraGroupID) + if err != nil { + return err + } + + // check if this provider user group exists yet + providerGroup := &models.ProviderGroup{} + err = tx.QueryRow(`SELECT * FROM provider_groups WHERE organization_id = ? AND name = ? AND provider_id = ?`, orgID, name, groupProviderUser.ProviderID).Scan(&providerGroup) + if err != nil { + if !strings.Contains(err.Error(), "no rows in result set") { + return err + } + // need to create the provider group + createPGStmt := ` + INSERT INTO provider_groups(organization_id, created_at, updated_at, provider_id, name) + VALUES (?, ?, ?, ?, ?) + ` + _, err = tx.Exec(createPGStmt, orgID, time.Now(), time.Now(), groupProviderUser.ProviderID, name) + if err != nil { + return err + } + } + // link this provider user as a provider group member + _, err = tx.Exec(` + INSERT INTO provider_groups_provider_users ("provider_id", "provider_group_name", "provider_user_identity_id") + VALUES (?, ?, ?) + ON CONFLICT DO NOTHING + `, + groupProviderUser.ProviderID, name, groupProviderUser.IdentityID) + if err != nil { + return fmt.Errorf("add provider group member: %w", err) + } + + // find the relation that was previously created for this user's group relation + var existing uid.ID + err = tx.QueryRow(` + SELECT group_id FROM identities_groups + WHERE identity_id = ? AND group_id = ? AND provider_id IS NULL`, + groupProviderUser.IdentityID, infraGroupID).Scan(&existing) + if err != nil { + if !strings.Contains(err.Error(), "no rows in result set") { + return err + } + // this user is a member of this group in multiple providers, create another record + _, err = tx.Exec(` + INSERT INTO identities_groups (provider_id, identity_id, group_id, provider_group_name) + VALUES (?, ?, ?, ?) + `, + groupProviderUser.ProviderID, groupProviderUser.IdentityID, infraGroupID, name) + if err != nil { + return fmt.Errorf("map provider group to infra group: %w", err) + } + } else { + _, err = tx.Exec(` + UPDATE identities_groups + SET provider_id = ?, identity_id = ?, group_id = ?, provider_group_name = ? + WHERE identity_id = ? AND group_id = ? AND provider_id IS NULL + `, + groupProviderUser.ProviderID, groupProviderUser.IdentityID, infraGroupID, name, groupProviderUser.IdentityID, infraGroupID) + if err != nil { + return fmt.Errorf("map provider group to infra group: %w", err) + } + } + } + } + + // populate the new columns on the existing identities_groups relations that came from Infra + stmt = ` + SELECT id, name, organization_id, identity_id + FROM groups + JOIN identities_groups ON id = group_id + WHERE provider_id IS NULL AND deleted_at IS NULL + ` + rows, err = tx.Query(stmt) + if err != nil { + return fmt.Errorf("migrate existing identity groups: %w", err) + } + + type infraGroupMember struct { + GroupID uid.ID + GroupName string + OrgID uid.ID + IdentityID uid.ID + } + infraGroupMembers := []infraGroupMember{} + for rows.Next() { + var g infraGroupMember + if err := rows.Scan(&g.GroupID, &g.GroupName, &g.OrgID, &g.IdentityID); err != nil { + return err + } + + infraGroupMembers = append(infraGroupMembers, g) + } + + if rows.Err() != nil { + return err + } + + if err := rows.Close(); err != nil { + return err + } + + for _, member := range infraGroupMembers { + stmt = ` + SELECT id + FROM providers + WHERE organization_id = ? AND name = 'infra'; + ` + + var orgInfraProviderID uid.ID + if err := tx.QueryRow(stmt, member.OrgID).Scan(&orgInfraProviderID); err != nil { + return err + } + + stmt = ` + UPDATE identities_groups + SET provider_id = ?, provider_group_name = ? + WHERE group_id = ? AND identity_id = ? AND provider_id IS NULL + ` + _, err = tx.Exec(stmt, orgInfraProviderID, member.GroupName, member.GroupID, member.IdentityID) + if err != nil { + return fmt.Errorf("add provider info to identity groups: %w", err) + } + } + + // we dont care about the groups column anymore + _, err = tx.Exec("ALTER TABLE provider_users DROP COLUMN IF EXISTS groups") + if err != nil { + return fmt.Errorf("remove provider_user groups: %w", err) + } + _, err = tx.Exec("ALTER TABLE identities_groups ALTER COLUMN provider_id SET NOT NULL") + if err != nil { + return fmt.Errorf("add identities groups provider: %w", err) + } + _, err = tx.Exec("ALTER TABLE identities_groups ALTER COLUMN provider_group_name SET NOT NULL") + if err != nil { + return fmt.Errorf("add identities groups provider: %w", err) + } + _, err = tx.Exec("CREATE UNIQUE INDEX idx_identities_groups ON identities_groups(identity_id, group_id, provider_id, provider_group_name)") + if err != nil { + return fmt.Errorf("add identities groups index: %w", err) + } + } + + return err + }, + } +} diff --git a/internal/server/data/migrations_test.go b/internal/server/data/migrations_test.go index b75808513f..e1399c9d2e 100644 --- a/internal/server/data/migrations_test.go +++ b/internal/server/data/migrations_test.go @@ -813,6 +813,143 @@ DELETE FROM settings WHERE id=24567; // schema changes are tested with schema comparison }, }, + { + label: testCaseLine("2022-10-03T13:00"), + setup: func(t *testing.T, tx WriteTxn) { + // setup provider, groups, and identites for multiple orgs, make sure no group membership is moved across those orgs + var originalOrgID uid.ID + err := tx.QueryRow(`SELECT id from organizations where name='Default'`).Scan(&originalOrgID) + assert.NilError(t, err) + + anotherOrgID := uid.New() + + stmt := `INSERT INTO providers(id, name, organization_id) VALUES (?, ?, ?)` + _, err = tx.Exec(stmt, 12345, "okta", originalOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 54321, "azure", originalOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 1000, "infra", originalOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 12346, "okta", anotherOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 1001, "infra", anotherOrgID) + assert.NilError(t, err) + + stmt = `INSERT INTO groups(id, name, organization_id) VALUES (?, ?, ?)` + _, err = tx.Exec(stmt, 23456, "Developers", originalOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 65432, "Everyone", originalOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 23457, "Developers", anotherOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 75432, "Ops", anotherOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 77777, "Local", anotherOrgID) + assert.NilError(t, err) + + stmt = `INSERT INTO identities(id, name, organization_id) VALUES (?, 'hello@infrahq.com', ?)` + _, err = tx.Exec(stmt, 34567, originalOrgID) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 34568, anotherOrgID) + assert.NilError(t, err) + + stmt = `INSERT INTO identities_groups(identity_id, group_id) VALUES (?, ?)` + _, err = tx.Exec(stmt, 34567, 23456) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 34567, 65432) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 34568, 23457) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 34568, 75432) + assert.NilError(t, err) + _, err = tx.Exec(stmt, 34568, 77777) + assert.NilError(t, err) + + stmt = `INSERT INTO provider_users(identity_id, provider_id, email, groups) VALUES (?, ?, 'hello@infrahq.com', ?)` + _, err = tx.Exec(stmt, 34567, 12345, "Developers,Everyone") + assert.NilError(t, err) + _, err = tx.Exec(stmt, 34567, 54321, "Developers") + assert.NilError(t, err) + _, err = tx.Exec(stmt, 34568, 12346, "Developers,Ops") + assert.NilError(t, err) + }, + cleanup: func(t *testing.T, tx WriteTxn) { + stmt := ` + DELETE FROM providers; + DELETE FROM groups; + DELETE FROM identities; + DELETE FROM provider_users; + DELETE FROM provider_groups; + DELETE FROM provider_groups_provider_users; + DELETE FROM identities_groups; + ` + _, err := tx.Exec(stmt) + assert.NilError(t, err) + }, + expected: func(t *testing.T, tx WriteTxn) { + // check that the provider groups exist for their respective providers + stmt := `SELECT provider_id, name FROM provider_groups` + pgs := []models.ProviderGroup{} + rows, err := tx.Query(stmt) + assert.NilError(t, err) + for rows.Next() { + var pg models.ProviderGroup + err = rows.Scan(&pg.ProviderID, &pg.Name) + assert.NilError(t, err) + pgs = append(pgs, pg) + } + + expectedProviderGroups := []models.ProviderGroup{ + { + Name: "Developers", + ProviderID: 12345, + }, + { + Name: "Everyone", + ProviderID: 12345, + }, + { + Name: "Developers", + ProviderID: 54321, + }, + { + Name: "Developers", + ProviderID: 12346, + }, + { + Name: "Ops", + ProviderID: 12346, + }, + } + assert.DeepEqual(t, pgs, expectedProviderGroups) + + // check the relation of group to provider group + var count int + err = tx.QueryRow(`SELECT COUNT(*) FROM identities_groups WHERE group_id=23456`).Scan(&count) + assert.NilError(t, err) + assert.Equal(t, count, 2) // 1 from okta, 1 from azure, for default org + err = tx.QueryRow(`SELECT COUNT(*) FROM identities_groups WHERE group_id=65432`).Scan(&count) + assert.NilError(t, err) + assert.Equal(t, count, 1) // 1 from okta, default org + err = tx.QueryRow(`SELECT COUNT(*) FROM identities_groups WHERE group_id=23457`).Scan(&count) + assert.NilError(t, err) + assert.Equal(t, count, 1) // 1 from okta, other org + err = tx.QueryRow(`SELECT COUNT(*) FROM identities_groups WHERE group_id=75432`).Scan(&count) + assert.NilError(t, err) + assert.Equal(t, count, 1) // 1 from okta, other org + err = tx.QueryRow(`SELECT COUNT(*) FROM identities_groups WHERE group_id=77777`).Scan(&count) + assert.NilError(t, err) + assert.Equal(t, count, 1) // 1 from infra, other org + + // check that the users of the provider groups have the correct identity relation + err = tx.QueryRow(`SELECT COUNT(*) FROM provider_groups_provider_users WHERE provider_user_identity_id=34567`).Scan(&count) + assert.NilError(t, err) + assert.Equal(t, count, 3) // 2 from okta, 1 from azure, default org + err = tx.QueryRow(`SELECT COUNT(*) FROM provider_groups_provider_users WHERE provider_user_identity_id=34568`).Scan(&count) + assert.NilError(t, err) + assert.Equal(t, count, 2) // 2 from okta, other org + }, + }, } ids := make(map[string]struct{}, len(testCases)) diff --git a/internal/server/data/providergroup.go b/internal/server/data/providergroup.go new file mode 100644 index 0000000000..cfe59f557a --- /dev/null +++ b/internal/server/data/providergroup.go @@ -0,0 +1,167 @@ +package data + +import ( + "fmt" + "time" + + "github.com/infrahq/infra/internal/server/data/querybuilder" + "github.com/infrahq/infra/internal/server/models" + "github.com/infrahq/infra/uid" +) + +type providerGroupTable models.ProviderGroup + +func (p providerGroupTable) Table() string { + return "provider_groups" +} + +func (p providerGroupTable) Columns() []string { + return []string{"organization_id", "created_at", "updated_at", "provider_id", "name"} +} + +func (p providerGroupTable) Values() []any { + return []any{p.OrganizationID, p.CreatedAt, p.UpdatedAt, p.ProviderID, p.Name} +} + +func (p *providerGroupTable) ScanFields() []any { + return []any{&p.OrganizationID, &p.CreatedAt, &p.UpdatedAt, &p.ProviderID, &p.Name} +} + +func (pg *providerGroupTable) OnInsert() error { + if pg.CreatedAt.IsZero() { + pg.CreatedAt = time.Now() + } + pg.UpdatedAt = pg.CreatedAt + return nil +} + +func (pg *providerGroupTable) OnUpdate() error { + pg.UpdatedAt = time.Now() + return nil +} + +// CreateProviderGroup adds a database entity for tracking group members at a provider +func CreateProviderGroup(db WriteTxn, providerGroup *models.ProviderGroup) error { + switch { + case providerGroup.ProviderID == 0: + return fmt.Errorf("providerID is required") + case providerGroup.Name == "": + return fmt.Errorf("name is required") + } + + providerGroup.OrganizationID = db.OrganizationID() + + return insert(db, (*providerGroupTable)(providerGroup)) +} + +// GetProviderGroup returns the group with the specified name for a provider +func GetProviderGroup(tx ReadTxn, providerID uid.ID, name string) (*models.ProviderGroup, error) { + providerGroup := &providerGroupTable{} + query := querybuilder.New("SELECT") + query.B(columnsForSelect(providerGroup)) + query.B("FROM") + query.B(providerGroup.Table()) + query.B("WHERE organization_id = ?", tx.OrganizationID()) + query.B("AND provider_id = ?", providerID) + query.B("AND name = ?", name) + + err := tx.QueryRow(query.String(), query.Args...).Scan(providerGroup.ScanFields()...) + if err != nil { + return nil, handleReadError(err) + } + pg := (*models.ProviderGroup)(providerGroup) + + return pg, nil +} + +type ListProviderGroupOptions struct { + ByProviderID uid.ID + ByMemberIdentityID uid.ID +} + +// ListProviderGroups returns all provider groups that match the specified criteria +func ListProviderGroups(tx ReadTxn, opts ListProviderGroupOptions) ([]models.ProviderGroup, error) { + table := &providerGroupTable{} + query := querybuilder.New("SELECT") + query.B(columnsForSelect(table)) + query.B("FROM") + query.B(table.Table()) + + if opts.ByMemberIdentityID != 0 { + query.B(` + JOIN provider_groups_provider_users + ON provider_groups.name = provider_groups_provider_users.provider_group_name + AND provider_groups.provider_id = provider_groups_provider_users.provider_id + `) + } + + query.B("WHERE organization_id = ?", tx.OrganizationID()) + if opts.ByProviderID != 0 { + query.B("AND provider_groups.provider_id = ?", opts.ByProviderID) + } + if opts.ByMemberIdentityID != 0 { + query.B(`AND provider_groups_provider_users.provider_user_identity_id = ?`, opts.ByMemberIdentityID) + } + + rows, err := tx.Query(query.String(), query.Args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var result []models.ProviderGroup + for rows.Next() { + var providerGroup models.ProviderGroup + fields := (*providerGroupTable)(&providerGroup).ScanFields() + + if err := rows.Scan(fields...); err != nil { + return nil, err + } + + result = append(result, providerGroup) + } + + return result, rows.Err() +} + +// addMemberToProviderGroups adds a link between a provider user and group that exists in a provider +func addMemberToProviderGroups(tx GormTxn, user *models.ProviderUser, providerGroupNames []string) error { + // the org does not need to be set here since all provider users and groups are org specific + query := querybuilder.New("INSERT INTO provider_groups_provider_users (provider_id, provider_user_identity_id, provider_group_name)") + query.B("VALUES") + + for i, grpName := range providerGroupNames { + query.B("(?, ?, ?)", user.ProviderID, user.IdentityID, grpName) + if i+1 != len(providerGroupNames) { + query.B(",") + } + } + query.B("ON CONFLICT DO NOTHING") + + _, err := tx.Exec(query.String(), query.Args...) + if err != nil { + return fmt.Errorf("add member to provider groups %w", err) + } + + return nil +} + +func removeMemberFromProviderGroups(tx GormTxn, user *models.ProviderUser, providerGroupNames []string) error { + // the org does not need to be set here since all provider users and groups are org specific + _, err := tx.Exec(` + DELETE FROM provider_groups_provider_users + WHERE provider_user_identity_id = ? AND provider_id = ? AND provider_group_name IN ? + `, + user.IdentityID, user.ProviderID, providerGroupNames) + if err != nil { + return err + } + + _, err = tx.Exec(` + DELETE FROM identities_groups + WHERE identity_id = ? AND provider_id = ? AND provider_group_name IN ? + `, + user.IdentityID, user.ProviderID, providerGroupNames) + + return err +} diff --git a/internal/server/data/providergroup_test.go b/internal/server/data/providergroup_test.go new file mode 100644 index 0000000000..93f88bc9d7 --- /dev/null +++ b/internal/server/data/providergroup_test.go @@ -0,0 +1,502 @@ +package data + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gotest.tools/v3/assert" + + "github.com/infrahq/infra/internal/server/models" + "github.com/infrahq/infra/uid" +) + +var cmpModelsGroupShallow = cmp.Comparer(func(x, y models.Group) bool { + return x.Name == y.Name && x.OrganizationID == y.OrganizationID +}) + +func TestCreateProviderGroup(t *testing.T) { + runDBTests(t, func(t *testing.T, db *DB) { + org := &models.Organization{Name: "something", Domain: "example.com"} + assert.NilError(t, CreateOrganization(db, org)) + + tx := txnForTestCase(t, db, org.ID) + + infraProviderID := InfraProvider(tx).ID + + t.Run("valid", func(t *testing.T) { + group := &models.Group{Name: "default"} + err := CreateGroup(tx, group) + assert.NilError(t, err) + + pg := &models.ProviderGroup{ + ProviderID: infraProviderID, + Name: "default", + } + err = CreateProviderGroup(tx, pg) + assert.NilError(t, err) + + // check that the provider group we fetch from the DB matches what is expected + retrieved, err := GetProviderGroup(tx, pg.ProviderID, pg.Name) + assert.NilError(t, err) + assert.DeepEqual(t, retrieved, pg, cmpTimeWithDBPrecision) + }) + t.Run("provider ID not specified fails", func(t *testing.T) { + pg := &models.ProviderGroup{ + Name: "default", + } + err := CreateProviderGroup(tx, pg) + assert.ErrorContains(t, err, "providerID is required") + }) + t.Run("name not specified fails", func(t *testing.T) { + pg := &models.ProviderGroup{ + ProviderID: 1234, + } + err := CreateProviderGroup(tx, pg) + assert.ErrorContains(t, err, "name is required") + }) + }) +} + +func TestGetProviderGroup(t *testing.T) { + type testCase struct { + name string + setup func(t *testing.T, tx *Transaction) (providerID uid.ID, name string) + checkResult func(t *testing.T, err error, tx *Transaction, result *models.ProviderGroup) + } + + testCases := []testCase{ + { + name: "get existing group", + setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, name string) { + providerID = InfraProvider(tx).ID + name = "group 1" + + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: name, + }, + } + + setupTestProviderGroups(t, tx, testSetup) + + return providerID, name + }, + checkResult: func(t *testing.T, err error, tx *Transaction, result *models.ProviderGroup) { + assert.NilError(t, err) + assert.Equal(t, result.ProviderID, InfraProvider(tx).ID) + assert.Equal(t, result.Name, "group 1") + }, + }, + { + name: "get non-existent provider ID", + setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, name string) { + providerID = 123 + name = "group 1" + + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: name, + }, + } + + setupTestProviderGroups(t, tx, testSetup) + + return providerID, name + }, + checkResult: func(t *testing.T, err error, tx *Transaction, result *models.ProviderGroup) { + assert.ErrorContains(t, err, "record not found") + }, + }, + { + name: "get non-existent provider group name", + setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, name string) { + providerID = 123 + name = "does not exist" + + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: "name", + }, + } + + setupTestProviderGroups(t, tx, testSetup) + + return providerID, name + }, + checkResult: func(t *testing.T, err error, tx *Transaction, result *models.ProviderGroup) { + assert.ErrorContains(t, err, "record not found") + }, + }, + } + + runDBTests(t, func(t *testing.T, db *DB) { + org := &models.Organization{Name: "something", Domain: "example.com"} + assert.NilError(t, CreateOrganization(db, org)) + + tx := txnForTestCase(t, db, org.ID) + + for _, tc := range testCases { + providerID, name := tc.setup(t, tx) + + result, err := GetProviderGroup(tx, providerID, name) + + tc.checkResult(t, err, tx, result) + + // clean up + _, err = tx.Exec("DELETE FROM identities") + assert.NilError(t, err) + _, err = tx.Exec("DELETE FROM groups") + assert.NilError(t, err) + _, err = tx.Exec("DELETE FROM provider_groups_provider_users") + assert.NilError(t, err) + _, err = tx.Exec("DELETE FROM provider_groups") + assert.NilError(t, err) + _, err = tx.Exec("DELETE FROM provider_users") + assert.NilError(t, err) + _, err = tx.Exec("DELETE FROM providers WHERE name != 'infra'") + assert.NilError(t, err) + } + }) +} + +func TestListProviderGroups(t *testing.T) { + type testCase struct { + name string + setup func(t *testing.T, tx *Transaction) (opts ListProviderGroupOptions, expected []models.ProviderGroup) + } + + testCases := []testCase{ + { + name: "list all groups", + setup: func(t *testing.T, tx *Transaction) (opts ListProviderGroupOptions, expected []models.ProviderGroup) { + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: "group 1", + }, + { + Provider: InfraProvider(tx), + GroupName: "group 2", + }, + } + + opts = ListProviderGroupOptions{} + expected = setupTestProviderGroups(t, tx, testSetup) + + return opts, expected + }, + }, + { + name: "list all groups for provider", + setup: func(t *testing.T, tx *Transaction) (opts ListProviderGroupOptions, expected []models.ProviderGroup) { + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: "group 1", + }, + { + Provider: InfraProvider(tx), + GroupName: "group 2", + }, + } + + opts = ListProviderGroupOptions{ByProviderID: InfraProvider(tx).ID} + expected = setupTestProviderGroups(t, tx, testSetup) + + return opts, expected + }, + }, + { + name: "list all groups for member ID", + setup: func(t *testing.T, tx *Transaction) (opts ListProviderGroupOptions, expected []models.ProviderGroup) { + user := models.Identity{ + Name: "hello@example.com", + } + err := CreateIdentity(tx, &user) + assert.NilError(t, err) + + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: "group", + Members: []models.Identity{user}, + }, + } + + opts = ListProviderGroupOptions{ByMemberIdentityID: user.ID} + expected = setupTestProviderGroups(t, tx, testSetup) + + return opts, expected + }, + }, + { + name: "list all groups for member ID and provider ID", + setup: func(t *testing.T, tx *Transaction) (opts ListProviderGroupOptions, expected []models.ProviderGroup) { + user := models.Identity{ + Name: "hello@example.com", + } + err := CreateIdentity(tx, &user) + assert.NilError(t, err) + + okta := &models.Provider{ + Name: "okta", + Kind: models.ProviderKindOkta, + } + err = CreateProvider(tx, okta) + assert.NilError(t, err) + + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: "group", + Members: []models.Identity{user}, + }, + // this provider group should not be returned in the test result + { + Provider: okta, + GroupName: "group", + Members: []models.Identity{user}, + }, + } + + opts = ListProviderGroupOptions{ByMemberIdentityID: user.ID, ByProviderID: InfraProvider(tx).ID} + + testProviderGroups := setupTestProviderGroups(t, tx, testSetup) + expected = []models.ProviderGroup{testProviderGroups[0]} // the infra provider group + + return opts, expected + }, + }, + } + + runDBTests(t, func(t *testing.T, db *DB) { + org := &models.Organization{Name: "something", Domain: "example.com"} + assert.NilError(t, CreateOrganization(db, org)) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx := txnForTestCase(t, db, org.ID) + + opts, expected := tc.setup(t, tx) + + result, err := ListProviderGroups(tx, opts) + + assert.NilError(t, err) + assert.DeepEqual(t, result, expected, cmpTimeWithDBPrecision) + }) + } + }) +} + +func TestAddProviderUserToProviderGroup(t *testing.T) { + runDBTests(t, func(t *testing.T, db *DB) { + org := &models.Organization{Name: "something", Domain: "example.com"} + assert.NilError(t, CreateOrganization(db, org)) + tx := txnForTestCase(t, db, org.ID) + + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: "Everyone", + }, + { + Provider: InfraProvider(tx), + GroupName: "Developers", + }, + } + _ = setupTestProviderGroups(t, tx, testSetup) + + spike := models.Identity{ + Name: "spike@infrahq.com", + } + + createIdentities(t, tx, &spike) + + pu, err := CreateProviderUser(tx, InfraProvider(tx), &spike) + assert.NilError(t, err) + + t.Run("add provider user to provider groups", func(t *testing.T) { + stmt := ` + SELECT provider_id FROM provider_groups_provider_users + WHERE provider_user_identity_id = ? + ` + err = tx.QueryRow(stmt, pu.IdentityID).Scan() + assert.ErrorContains(t, err, "no rows in result set") + + err = addMemberToProviderGroups(tx, pu, []string{"Everyone", "Developers"}) + assert.NilError(t, err) + + stmt = ` + SELECT provider_group_name FROM provider_groups_provider_users + WHERE provider_id = ? AND provider_user_identity_id = ? + ` + rows, err := tx.Query(stmt, pu.ProviderID, pu.IdentityID) + assert.NilError(t, err) + defer rows.Close() + + var result []string + for rows.Next() { + var name string + err = rows.Scan(&name) + assert.NilError(t, err) + + result = append(result, name) + } + + assert.DeepEqual(t, result, []string{"Developers", "Everyone"}) + }) + }) +} + +func TestRemoveProviderUserFromProviderGroup(t *testing.T) { + runDBTests(t, func(t *testing.T, db *DB) { + org := &models.Organization{Name: "something", Domain: "example.com"} + assert.NilError(t, CreateOrganization(db, org)) + tx := txnForTestCase(t, db, org.ID) + + testSetup := []TestProviderGroup{ + { + Provider: InfraProvider(tx), + GroupName: "Everyone", + }, + { + Provider: InfraProvider(tx), + GroupName: "Developers", + }, + } + _ = setupTestProviderGroups(t, tx, testSetup) + + spike := models.Identity{ + Name: "spike@infrahq.com", + } + + createIdentities(t, tx, &spike) + + pu, err := CreateProviderUser(tx, InfraProvider(tx), &spike) + assert.NilError(t, err) + + t.Run("remove provider user from a provider group", func(t *testing.T) { + err = addMemberToProviderGroups(tx, pu, []string{"Everyone", "Developers"}) + assert.NilError(t, err) + + stmt := ` + SELECT provider_group_name FROM provider_groups_provider_users + WHERE provider_id = ? AND provider_user_identity_id = ? + ` + rows, err := tx.Query(stmt, pu.ProviderID, pu.IdentityID) + assert.NilError(t, err) + defer rows.Close() + + var result []string + for rows.Next() { + var name string + err = rows.Scan(&name) + assert.NilError(t, err) + + result = append(result, name) + } + + assert.DeepEqual(t, result, []string{"Developers", "Everyone"}) + + err = removeMemberFromProviderGroups(tx, pu, []string{"Everyone"}) + assert.NilError(t, err) + + rows, err = tx.Query(stmt, pu.ProviderID, pu.IdentityID) + assert.NilError(t, err) + defer rows.Close() + + result = []string{} + for rows.Next() { + var name string + err = rows.Scan(&name) + assert.NilError(t, err) + + result = append(result, name) + } + + assert.DeepEqual(t, result, []string{"Developers"}) + }) + t.Run("remove provider user from all provider groups", func(t *testing.T) { + err = addMemberToProviderGroups(tx, pu, []string{"Everyone", "Developers"}) + assert.NilError(t, err) + + stmt := ` + SELECT provider_group_name FROM provider_groups_provider_users + WHERE provider_id = ? AND provider_user_identity_id = ? + ` + rows, err := tx.Query(stmt, pu.ProviderID, pu.IdentityID) + assert.NilError(t, err) + defer rows.Close() + + var result []string + for rows.Next() { + var name string + err = rows.Scan(&name) + assert.NilError(t, err) + + result = append(result, name) + } + + assert.DeepEqual(t, result, []string{"Developers", "Everyone"}) + + err = removeMemberFromProviderGroups(tx, pu, []string{"Everyone", "Developers"}) + assert.NilError(t, err) + + err = tx.QueryRow(stmt, pu.ProviderID, pu.IdentityID).Scan() + assert.ErrorContains(t, err, "no rows in result set") + }) + }) +} + +type TestProviderGroup struct { + Provider *models.Provider + GroupName string + Members []models.Identity +} + +func setupTestProviderGroups(t *testing.T, tx *Transaction, testProviderGroups []TestProviderGroup) []models.ProviderGroup { + // parent groups need to be create first, in case any provider groups share the same parent + parentGroups := make(map[string]*models.Group) + for _, testPg := range testProviderGroups { + if parentGroups[testPg.GroupName] == nil { + group := &models.Group{Name: testPg.GroupName} + err := CreateGroup(tx, group) + assert.NilError(t, err) + + parentGroups[group.Name] = group + } + } + + created := []models.ProviderGroup{} + for _, testPg := range testProviderGroups { + pg := &models.ProviderGroup{ + ProviderID: testPg.Provider.ID, + Name: testPg.GroupName, + } + err := CreateProviderGroup(tx, pg) + assert.NilError(t, err) + + memberIDs := []uid.ID{} + for i := range testPg.Members { + member, err := CreateProviderUser(tx, testPg.Provider, &testPg.Members[i]) + assert.NilError(t, err) + + err = addMemberToProviderGroups(tx, member, []string{testPg.GroupName}) + assert.NilError(t, err) + + memberIDs = append(memberIDs, member.IdentityID) + } + + if len(memberIDs) > 0 { + err = AddUsersToGroup(tx, parentGroups[testPg.GroupName].ID, pg.Name, pg.ProviderID, memberIDs) + assert.NilError(t, err) + } + + created = append(created, *pg) + } + + return created +} diff --git a/internal/server/data/provideruser.go b/internal/server/data/provideruser.go index 8df9992114..c1e6f66875 100644 --- a/internal/server/data/provideruser.go +++ b/internal/server/data/provideruser.go @@ -21,15 +21,15 @@ func (p providerUserTable) Table() string { } func (p providerUserTable) Columns() []string { - return []string{"identity_id", "provider_id", "email", "groups", "last_update", "redirect_url", "access_token", "refresh_token", "expires_at", "given_name", "family_name", "active"} + return []string{"identity_id", "provider_id", "email", "last_update", "redirect_url", "access_token", "refresh_token", "expires_at", "given_name", "family_name", "active"} } func (p providerUserTable) Values() []any { - return []any{p.IdentityID, p.ProviderID, p.Email, p.Groups, p.LastUpdate, p.RedirectURL, p.AccessToken, p.RefreshToken, p.ExpiresAt, p.GivenName, p.FamilyName, p.Active} + return []any{p.IdentityID, p.ProviderID, p.Email, p.LastUpdate, p.RedirectURL, p.AccessToken, p.RefreshToken, p.ExpiresAt, p.GivenName, p.FamilyName, p.Active} } func (p *providerUserTable) ScanFields() []any { - return []any{&p.IdentityID, &p.ProviderID, &p.Email, &p.Groups, &p.LastUpdate, &p.RedirectURL, &p.AccessToken, &p.RefreshToken, &p.ExpiresAt, &p.GivenName, &p.FamilyName, &p.Active} + return []any{&p.IdentityID, &p.ProviderID, &p.Email, &p.LastUpdate, &p.RedirectURL, &p.AccessToken, &p.RefreshToken, &p.ExpiresAt, &p.GivenName, &p.FamilyName, &p.Active} } func (p *providerUserTable) OnInsert() error { @@ -199,7 +199,7 @@ func GetProviderUser(tx ReadTxn, providerID, identityID uid.ID) (*models.Provide return (*models.ProviderUser)(pu), nil } -func SyncProviderUser(ctx context.Context, tx WriteTxn, user *models.Identity, provider *models.Provider, oidcClient providers.OIDCClient) error { +func SyncProviderUser(ctx context.Context, tx GormTxn, user *models.Identity, provider *models.Provider, oidcClient providers.OIDCClient) error { providerUser, err := GetProviderUser(tx, provider.ID, user.ID) if err != nil { return err @@ -216,6 +216,7 @@ func SyncProviderUser(ctx context.Context, tx WriteTxn, user *models.Identity, p providerUser.AccessToken = models.EncryptedAtRest(accessToken) providerUser.ExpiresAt = *expiry + providerUser.LastUpdate = time.Now().UTC() err = UpdateProviderUser(tx, providerUser) if err != nil { @@ -228,7 +229,7 @@ func SyncProviderUser(ctx context.Context, tx WriteTxn, user *models.Identity, p return fmt.Errorf("oidc user sync failed: %w", err) } - if err := AssignIdentityToGroups(tx, user, provider, info.Groups); err != nil { + if err := AssignUserToProviderGroups(tx, providerUser, provider, info.Groups); err != nil { return fmt.Errorf("assign identity to groups: %w", err) } diff --git a/internal/server/data/provideruser_test.go b/internal/server/data/provideruser_test.go index c2333501d9..afc5ec5503 100644 --- a/internal/server/data/provideruser_test.go +++ b/internal/server/data/provideruser_test.go @@ -99,7 +99,6 @@ func TestSyncProviderUser(t *testing.T) { expected := models.ProviderUser{ Email: "hello@example.com", - Groups: models.CommaSeparatedStrings{"Everyone", "Developers"}, ProviderID: provider.ID, IdentityID: user.ID, RedirectURL: "http://example.com", @@ -150,7 +149,7 @@ func TestSyncProviderUser(t *testing.T) { }, oidcClient: &mockOIDCImplementation{ UserEmailResp: "sync@example.com", - UserGroupsResp: []string{"Everyone", "Developers"}, + UserGroupsResp: []string{"Sync Group 1", "Sync Group 2"}, }, verifyFunc: func(t *testing.T, err error, user *models.Identity) { assert.NilError(t, err) @@ -158,9 +157,8 @@ func TestSyncProviderUser(t *testing.T) { pu, err := GetProviderUser(db, provider.ID, user.ID) assert.NilError(t, err) - expected := models.ProviderUser{ + expectedProviderUser := models.ProviderUser{ Email: "sync@example.com", - Groups: models.CommaSeparatedStrings{"Everyone", "Developers"}, ProviderID: provider.ID, IdentityID: user.ID, RedirectURL: "http://example.com", @@ -183,29 +181,34 @@ func TestSyncProviderUser(t *testing.T) { cmpEncryptedAtRestNotZero), } - assert.DeepEqual(t, *pu, expected, cmpProviderUser) + assert.DeepEqual(t, *pu, expectedProviderUser, cmpProviderUser) - assert.Assert(t, len(pu.Groups) == 2) + pgs, err := ListProviderGroups(db, ListProviderGroupOptions{ByMemberIdentityID: pu.IdentityID}) + assert.NilError(t, err) + + assert.Assert(t, len(pgs) == 2) puGroups := make(map[string]bool) - for _, g := range pu.Groups { - puGroups[g] = true + for _, g := range pgs { + puGroups[g.Name] = true } - assert.Assert(t, puGroups["Everyone"]) - assert.Assert(t, puGroups["Developers"]) + assert.Assert(t, puGroups["Sync Group 1"]) + assert.Assert(t, puGroups["Sync Group 2"]) // check that the direct user-to-group relation was updated storedGroups, err := ListGroups(db, nil, ByGroupMember(pu.IdentityID)) assert.NilError(t, err) + assert.Assert(t, len(storedGroups) == 2) + userGroups := make(map[string]bool) for _, g := range storedGroups { userGroups[g.Name] = true } - assert.Assert(t, userGroups["Everyone"]) - assert.Assert(t, userGroups["Developers"]) + assert.Assert(t, userGroups["Sync Group 1"]) + assert.Assert(t, userGroups["Sync Group 2"]) }, }, } @@ -369,7 +372,5 @@ func createTestProviderUser(t *testing.T, tx *Transaction, provider *models.Prov pu, err := CreateProviderUser(tx, provider, user) assert.NilError(t, err) - pu.Groups = models.CommaSeparatedStrings{} - return *pu } diff --git a/internal/server/data/schema.sql b/internal/server/data/schema.sql index 7a3156f353..514d32d228 100644 --- a/internal/server/data/schema.sql +++ b/internal/server/data/schema.sql @@ -197,7 +197,9 @@ CREATE TABLE identities ( CREATE TABLE identities_groups ( identity_id bigint NOT NULL, - group_id bigint NOT NULL + group_id bigint NOT NULL, + provider_id bigint NOT NULL, + provider_group_name text NOT NULL ); CREATE TABLE organizations ( @@ -218,11 +220,24 @@ CREATE TABLE password_reset_tokens ( organization_id bigint ); +CREATE TABLE provider_groups ( + organization_id bigint NOT NULL, + provider_id bigint NOT NULL, + name text NOT NULL, + created_at timestamp with time zone, + updated_at timestamp with time zone +); + +CREATE TABLE provider_groups_provider_users ( + provider_id bigint NOT NULL, + provider_group_name text NOT NULL, + provider_user_identity_id bigint NOT NULL +); + CREATE TABLE provider_users ( identity_id bigint NOT NULL, provider_id bigint NOT NULL, email text, - groups text, last_update timestamp with time zone, redirect_url text, access_token text, @@ -292,9 +307,6 @@ ALTER TABLE ONLY grants ALTER TABLE ONLY groups ADD CONSTRAINT groups_pkey PRIMARY KEY (id); -ALTER TABLE ONLY identities_groups - ADD CONSTRAINT identities_groups_pkey PRIMARY KEY (identity_id, group_id); - ALTER TABLE ONLY identities ADD CONSTRAINT identities_pkey PRIMARY KEY (id); @@ -335,6 +347,8 @@ CREATE INDEX idx_grants_update_index ON grants USING btree (organization_id, upd CREATE UNIQUE INDEX idx_groups_name ON groups USING btree (organization_id, name) WHERE (deleted_at IS NULL); +CREATE UNIQUE INDEX idx_identities_groups ON identities_groups USING btree (identity_id, group_id, provider_id, provider_group_name); + CREATE UNIQUE INDEX idx_identities_name ON identities USING btree (organization_id, name) WHERE (deleted_at IS NULL); CREATE UNIQUE INDEX idx_identities_verified ON identities USING btree (organization_id, verification_token) WHERE (deleted_at IS NULL); @@ -343,6 +357,10 @@ CREATE UNIQUE INDEX idx_organizations_domain ON organizations USING btree (domai CREATE UNIQUE INDEX idx_password_reset_tokens_token ON password_reset_tokens USING btree (token); +CREATE UNIQUE INDEX idx_provider_group_names ON provider_groups USING btree (name, provider_id, organization_id); + +CREATE UNIQUE INDEX idx_provider_groups_provider_users ON provider_groups_provider_users USING btree (provider_group_name, provider_id, provider_user_identity_id); + CREATE UNIQUE INDEX idx_providers_name ON providers USING btree (organization_id, name) WHERE (deleted_at IS NULL); CREATE UNIQUE INDEX settings_org_id ON settings USING btree (organization_id) WHERE (deleted_at IS NULL); diff --git a/internal/server/grants_test.go b/internal/server/grants_test.go index 0c1a72cefa..d40e47e27a 100644 --- a/internal/server/grants_test.go +++ b/internal/server/grants_test.go @@ -74,7 +74,7 @@ func TestAPI_ListGrants(t *testing.T) { err := data.CreateGroup(srv.DB(), group) assert.NilError(t, err) - err = data.AddUsersToGroup(srv.DB(), group.ID, users) + err = data.AddUsersToGroup(srv.DB(), group.ID, group.Name, data.InfraProvider(srv.DB()).ID, users) assert.NilError(t, err) return group.ID @@ -544,7 +544,7 @@ func TestAPI_ListGrants_InheritedGrants(t *testing.T) { err := data.CreateGroup(srv.DB(), group) assert.NilError(t, err) - err = data.AddUsersToGroup(srv.DB(), group.ID, users) + err = data.AddUsersToGroup(srv.DB(), group.ID, group.Name, data.InfraProvider(srv.DB()).ID, users) assert.NilError(t, err) return group.ID diff --git a/internal/server/groups_test.go b/internal/server/groups_test.go index 5e91a44694..b15e53e496 100644 --- a/internal/server/groups_test.go +++ b/internal/server/groups_test.go @@ -23,7 +23,7 @@ func createIdentities(t *testing.T, db data.GormTxn, identities ...*models.Ident err := data.CreateIdentity(db, identities[i]) assert.NilError(t, err, identities[i].Name) for _, g := range identities[i].Groups { - err := data.AddUsersToGroup(db, g.ID, []uid.ID{identities[i].ID}) + err := data.AddUsersToGroup(db, g.ID, g.Name, data.InfraProvider(db).ID, []uid.ID{identities[i].ID}) assert.NilError(t, err) } assert.NilError(t, err, identities[i].Name) @@ -51,18 +51,20 @@ func TestAPI_ListGroups(t *testing.T) { createGroups(t, srv.DB(), &humans, &second, &others) var ( - idInGroup = models.Identity{ - Name: "inagroup@example.com", - Groups: []models.Group{humans, second}, - } - idOther = models.Identity{ - Name: "other@example.com", - Groups: []models.Group{others}, - } + idInGroup = models.Identity{Name: "inagroup@example.com"} + idOther = models.Identity{Name: "other@example.com"} ) createIdentities(t, srv.DB(), &idInGroup, &idOther) + provider := data.InfraProvider(srv.db).ID + err := data.AddUsersToGroup(srv.db, humans.ID, humans.Name, provider, []uid.ID{idInGroup.ID}) + assert.NilError(t, err) + err = data.AddUsersToGroup(srv.db, second.ID, second.Name, provider, []uid.ID{idInGroup.ID}) + assert.NilError(t, err) + err = data.AddUsersToGroup(srv.db, others.ID, others.Name, provider, []uid.ID{idOther.ID}) + assert.NilError(t, err) + token := &models.AccessKey{ IssuedFor: idInGroup.ID, ProviderID: data.InfraProvider(srv.DB()).ID, @@ -430,7 +432,7 @@ func TestAPI_UpdateUsersInGroup(t *testing.T) { urlPath: fmt.Sprintf("/api/groups/%s/users", humans.ID.String()), setup: func(t *testing.T, req *http.Request) { req.Header.Set("Authorization", "Bearer "+adminAccessKey(srv)) - err := data.AddUsersToGroup(srv.DB(), humans.ID, []uid.ID{first.ID, second.ID}) + err := data.AddUsersToGroup(srv.DB(), humans.ID, humans.Name, data.InfraProvider(srv.DB()).ID, []uid.ID{first.ID, second.ID}) assert.NilError(t, err) }, expected: func(t *testing.T, resp *httptest.ResponseRecorder) { diff --git a/internal/server/models/providergroup.go b/internal/server/models/providergroup.go new file mode 100644 index 0000000000..35ce72b9eb --- /dev/null +++ b/internal/server/models/providergroup.go @@ -0,0 +1,22 @@ +package models + +import ( + "time" + + "github.com/infrahq/infra/uid" +) + +// ProviderGroup is a local copy of the of a group from an identity provider +// See this diagram for more details about how this model relates to a group +// https://github.com/infrahq/infra/blob/main/docs/dev/identity-provider-tracking.md +type ProviderGroup struct { + OrganizationMember + CreatedAt time.Time + UpdatedAt time.Time + + ProviderID uid.ID + Name string + + // for loading, not for saving + Members []ProviderUser +} diff --git a/internal/server/models/provideruser.go b/internal/server/models/provideruser.go index 627dcf6b18..aebb0fe22e 100644 --- a/internal/server/models/provideruser.go +++ b/internal/server/models/provideruser.go @@ -15,7 +15,6 @@ type ProviderUser struct { Email string GivenName string FamilyName string - Groups CommaSeparatedStrings LastUpdate time.Time RedirectURL string // needs to match the redirect URL specified when the token was issued for refreshing diff --git a/internal/server/users_test.go b/internal/server/users_test.go index 33db6dbf8a..3ecbe1c277 100644 --- a/internal/server/users_test.go +++ b/internal/server/users_test.go @@ -191,6 +191,8 @@ func TestAPI_ListUsers(t *testing.T) { Groups: []models.Group{humans}, } createIdentities(t, srv.DB(), &anotherID) + err := data.AddUsersToGroup(srv.DB(), humans.ID, humans.Name, data.InfraProvider(srv.DB()).ID, []uid.ID{anotherID.ID}) + assert.NilError(t, err) createID := func(t *testing.T, name string) uid.ID { t.Helper()