Skip to content

Commit

Permalink
maintain: track provider groups separate from users
Browse files Browse the repository at this point in the history
- join group membership from provider group relation
- unlink provider groups from groups on deletion
- migrate provider user groups to provider groups
  • Loading branch information
BruceMacD committed Oct 20, 2022
1 parent 1460d01 commit b89e82d
Show file tree
Hide file tree
Showing 23 changed files with 1,389 additions and 200 deletions.
29 changes: 29 additions & 0 deletions docs/dev/identity-provider-tracking.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Identity Provider Tracking
Infra tracks users and groups that exist in external identity providers and maps them access. The following database diagram shows how the relation of groups and users in identity providers are mapped to users and groups within Infra.

Our identity system has two different type of relations, one of which exists in Infra (users and groups controlled by Infra), and the other relation which exists externally (information about users and groups which exist in an identity provider).

To get all the groups a user belongs to as a combination of direct group membership and membership in a mapped identity provider group use the `resolved_identity_groups` view.

## Users and Groups Managed by Infra
The following tables represent users whose identity is managed by Infra and groups which are controlled by Infra.
- **Identities**: Known information about a user. This may be a an email and password hash that can be used to login to Infra directly, or information about a user that has logged via an identity provider.
- **Groups**: Groups which exist in Infra that membership to grants access to some resource.
- **Identities_Groups**: Identities that have been added as a member of a group within Infra, this membership is added through either membership in an external identity provider group or through direct assignment (represented in the database as being part of an Infra provider group).

## Users and Groups Managed by an Identity Provider
- **Provider_User**: A user that exists in an identity provider.
- **Provider_Group**: A group that exists in an identity provider.
- **Provider_Groups_Provider_Users**: This relation shows which users of an identity provders are members of which groups in the same identity provider.

```mermaid
erDiagram
PROVIDER_GROUPS_MAPPINGS }o--|| PROVIDER : ""
GROUP ||--o{ IDENTITIES_GROUPS : "direct membership to infra group"
IDENTITIES_GROUPS }o--|| IDENTITY : "direct membership to infra group"
IDENTITY ||--o{ PROVIDER_USER : "infra identity composed of identity provider users"
PROVIDER_GROUP }o--|| PROVIDER : "sync groups to infra"
PROVIDER_GROUP ||--o{ PROVIDER_GROUPS_PROVIDER_USERS : "sync members to infra"
PROVIDER_GROUPS_PROVIDER_USERS }o--|| PROVIDER_USER : "sync members to infra"
PROVIDER_USER }o--|| PROVIDER : "sync identities to infra"
```
4 changes: 2 additions & 2 deletions internal/access/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ func TestRequireInfraRole_GrantsFromGroupMembership(t *testing.T) {
err := data.CreateIdentity(db, tom)
assert.NilError(t, err)

_, err = data.CreateProviderUser(db, provider, tom)
providedTom, err := data.CreateProviderUser(db, provider, tom)
assert.NilError(t, err)

err = data.CreateGroup(db, tomsGroup)
assert.NilError(t, err)

err = data.AssignIdentityToGroups(db, tom, provider, []string{tomsGroup.Name})
err = data.AssignUserToProviderGroups(db, providedTom, provider, []string{tomsGroup.Name})
assert.NilError(t, err)

c, _ := gin.CreateTestContext(httptest.NewRecorder())
Expand Down
6 changes: 3 additions & 3 deletions internal/access/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func UpdateUsersInGroup(c *gin.Context, groupID uid.ID, uidsToAdd []uid.ID, uids
return err
}

_, err = data.GetGroup(db, data.ByID(groupID))
group, err := data.GetGroup(db, data.ByID(groupID))
if err != nil {
return err
}
Expand All @@ -130,13 +130,13 @@ func UpdateUsersInGroup(c *gin.Context, groupID uid.ID, uidsToAdd []uid.ID, uids
}

if len(addIDList) > 0 {
if err := data.AddUsersToGroup(db, groupID, addIDList); err != nil {
if err := data.AddUsersToGroup(db, group.ID, group.Name, data.InfraProvider(db).ID, addIDList); err != nil {
return err
}
}

if len(rmIDList) > 0 {
if err := data.RemoveUsersFromGroup(db, groupID, rmIDList); err != nil {
if err := data.RemoveUsersFromGroup(db, group.ID, rmIDList); err != nil {
return err
}
}
Expand Down
6 changes: 1 addition & 5 deletions internal/access/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func UpdateIdentityInfoFromProvider(c RequestContext, oidc providers.OIDCClient)
return fmt.Errorf("user info provider: %w", err)
}

// get current identity provider groups and account status
// update current identity provider groups and account status
err = data.SyncProviderUser(ctx, db, identity, provider, oidc)
if err != nil {
if errors.Is(err, internal.ErrBadGateway) {
Expand All @@ -141,10 +141,6 @@ func UpdateIdentityInfoFromProvider(c RequestContext, oidc providers.OIDCClient)
logging.Errorf("failed to revoke invalid user session: %s", nestedErr)
}

if nestedErr := data.DeleteProviderUsers(db, data.DeleteProviderUsersOptions{ByIdentityID: identity.ID, ByProviderID: provider.ID}); nestedErr != nil {
logging.Errorf("failed to delete provider user: %s", nestedErr)
}

return fmt.Errorf("sync user: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/access/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func TestDeleteIdentityCleansUpResources(t *testing.T) {
group := &models.Group{Name: "Group"}
err = data.CreateGroup(db, group)
assert.NilError(t, err)
err = data.AddUsersToGroup(db, group.ID, []uid.ID{identity.ID})
err = data.AddUsersToGroup(db, group.ID, group.Name, data.InfraProvider(db).ID, []uid.ID{identity.ID})
assert.NilError(t, err)

// delete the identity, and make sure all their resources are gone
Expand Down
42 changes: 16 additions & 26 deletions internal/server/authn/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,14 @@ func TestOIDCAuthenticate(t *testing.T) {
// user should be created
assert.Equal(t, authnIdentity.Identity.Name, "[email protected]")

syncGroups, err := data.ListGroups(db, nil, data.ByGroupMember(authnIdentity.Identity.ID))
assert.NilError(t, err)

groups := make(map[string]bool)
for _, g := range authnIdentity.Identity.Groups {
for _, g := range syncGroups {
groups[g.Name] = true
}
assert.Assert(t, len(authnIdentity.Identity.Groups) == 2)
assert.Equal(t, len(syncGroups), 2)
assert.Equal(t, groups["Everyone"], true)
assert.Equal(t, groups["developers"], true)

Expand Down Expand Up @@ -198,31 +201,21 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
err := data.CreateIdentity(db, user)
assert.NilError(t, err)

for _, name := range []string{"Foo", "existing3"} {
group := &models.Group{Name: name}
err = data.CreateGroup(db, group)
assert.NilError(t, err)
err = data.AddUsersToGroup(db, group.ID, []uid.ID{user.ID})
assert.NilError(t, err)
}

g, err := data.GetGroup(db, data.ByName("Foo"))
group := &models.Group{Name: "Foo"}
err = data.CreateGroup(db, group)
assert.NilError(t, err)
assert.Assert(t, g != nil)

user, err = data.GetIdentity(db, data.GetIdentityOptions{ByID: user.ID, LoadGroups: true})
err = data.AddUsersToGroup(db, group.ID, group.Name, data.InfraProvider(db).ID, []uid.ID{user.ID})
assert.NilError(t, err)
assert.Assert(t, user != nil)
assert.Equal(t, len(user.Groups), 2)

p, err := data.GetProvider(db, data.ByName("mockoidc"))
assert.NilError(t, err)

pu, err := data.CreateProviderUser(db, p, user)
assert.NilError(t, err)

pu.Groups = []string{"existing3"}
assert.NilError(t, data.UpdateProviderUser(db, pu))
err = data.AssignUserToProviderGroups(db, pu, p, []string{"existing3"})
assert.NilError(t, err)

return &mockOIDCImplementation{
UserEmailResp: "[email protected]",
Expand All @@ -234,7 +227,7 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {
assert.Equal(t, "mockoidc", a.Provider.Name)
assert.Assert(t, a.SessionExpiry.Equal(sessionExpiry))

assert.Assert(t, len(a.Identity.Groups) == 3)
assert.Equal(t, len(a.Identity.Groups), 3)

var groupNames []string
for _, g := range a.Identity.Groups {
Expand All @@ -261,15 +254,12 @@ func TestExchangeAuthCodeForProviderTokens(t *testing.T) {

a, err := loginMethod.Authenticate(context.Background(), db, sessionExpiry)
assert.NilError(t, err)
tc.expected(t, a)

if err == nil {
// make sure the associations are still set when you reload the object.
u, err := data.GetIdentity(db, data.GetIdentityOptions{ByID: a.Identity.ID, LoadGroups: true})
assert.NilError(t, err)
a.Identity = u
tc.expected(t, a)
}
syncGroups, err := data.ListGroups(db, nil, data.ByGroupMember(a.Identity.ID))
assert.NilError(t, err)
a.Identity.Groups = syncGroups

tc.expected(t, a)
})
}
}
6 changes: 3 additions & 3 deletions internal/server/data/grant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,9 @@ func TestListGrants(t *testing.T) {
userID, err := uid.Parse([]byte("userchar"))
assert.NilError(t, err)

assert.NilError(t, AddUsersToGroup(tx, uid.ID(111), []uid.ID{userID}))
assert.NilError(t, AddUsersToGroup(tx, uid.ID(112), []uid.ID{userID}))
assert.NilError(t, AddUsersToGroup(tx, uid.ID(113), []uid.ID{uid.ID(777)}))
assert.NilError(t, AddUsersToGroup(tx, uid.ID(111), "name-111", InfraProvider(db).ID, []uid.ID{userID}))
assert.NilError(t, AddUsersToGroup(tx, uid.ID(112), "name-111", InfraProvider(db).ID, []uid.ID{userID}))
assert.NilError(t, AddUsersToGroup(tx, uid.ID(113), "name-111", InfraProvider(db).ID, []uid.ID{uid.ID(777)}))

gGrant1 := &models.Grant{
Subject: uid.NewGroupPolymorphicID(111),
Expand Down
37 changes: 28 additions & 9 deletions internal/server/data/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ func GetGroup(db GormTxn, selectors ...SelectorFunc) (*models.Group, error) {

func ListGroups(db GormTxn, p *Pagination, selectors ...SelectorFunc) ([]models.Group, error) {
groups, err := list[models.Group](db, p, selectors...)

if err != nil {
return nil, err
}
Expand All @@ -63,7 +62,6 @@ func ListGroups(db GormTxn, p *Pagination, selectors ...SelectorFunc) ([]models.
}

return groups, nil

}

func ByGroupMember(id uid.ID) SelectorFunc {
Expand All @@ -75,7 +73,10 @@ func ByGroupMember(id uid.ID) SelectorFunc {
}

func groupIDsForUser(tx ReadTxn, userID uid.ID) ([]uid.ID, error) {
stmt := `SELECT DISTINCT group_id FROM identities_groups WHERE identity_id = ?`
stmt := `
SELECT DISTINCT group_id FROM identities_groups
WHERE identity_id = ?
`
rows, err := tx.Query(stmt, userID)
if err != nil {
return nil, err
Expand All @@ -101,7 +102,7 @@ func DeleteGroup(tx WriteTxn, id uid.ID) error {

_, err = tx.Exec(`DELETE from identities_groups WHERE group_id = ?`, id)
if err != nil {
return fmt.Errorf("remove useres from group: %w", err)
return fmt.Errorf("remove users from group: %w", err)
}

stmt := `
Expand All @@ -114,11 +115,11 @@ func DeleteGroup(tx WriteTxn, id uid.ID) error {
return handleError(err)
}

func AddUsersToGroup(tx WriteTxn, groupID uid.ID, idsToAdd []uid.ID) error {
query := querybuilder.New("INSERT INTO identities_groups(group_id, identity_id)")
func AddUsersToGroup(tx WriteTxn, groupID uid.ID, providerGroupName string, providerID uid.ID, idsToAdd []uid.ID) error {
query := querybuilder.New("INSERT INTO identities_groups(group_id, identity_id, provider_id, provider_group_name)")
query.B("VALUES")
for i, id := range idsToAdd {
query.B("(?, ?)", groupID, id)
query.B("(?, ?, ?, ?)", groupID, id, providerID, providerGroupName)
if i+1 != len(idsToAdd) {
query.B(",")
}
Expand All @@ -129,6 +130,21 @@ func AddUsersToGroup(tx WriteTxn, groupID uid.ID, idsToAdd []uid.ID) error {
return handleError(err)
}

func AddUserToGroups(tx WriteTxn, providerID uid.ID, identityID uid.ID, groups []models.Group) error {
query := querybuilder.New("INSERT INTO identities_groups(provider_id, identity_id, group_id, provider_group_name)")
query.B("VALUES")
for i, group := range groups {
query.B("(?, ?, ?, ?)", providerID, identityID, group.ID, group.Name)
if i+1 != len(groups) {
query.B(",")
}
}
query.B("ON CONFLICT DO NOTHING")

_, err := tx.Exec(query.String(), query.Args...)
return handleError(err)
}

// RemoveUsersFromGroup removes any user ID listed in idsToRemove from the group
// with ID groupID.
// Note that DeleteGroup also removes users from the group.
Expand All @@ -138,11 +154,14 @@ func RemoveUsersFromGroup(tx WriteTxn, groupID uid.ID, idsToRemove []uid.ID) err
return handleError(err)
}

// TODO: do this with a join in ListGroups and GetGroup
func CountUsersInGroup(tx GormTxn, groupID uid.ID) (int64, error) {
db := tx.GormDB()
var count int64
err := db.Table("identities_groups").Where("group_id = ?", groupID).Count(&count).Error
err := db.
Table("identities_groups").
Where("group_id = ?", groupID).
Distinct("identity_id").
Count(&count).Error
if err != nil {
return 0, err
}
Expand Down
54 changes: 45 additions & 9 deletions internal/server/data/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ func TestListGroups(t *testing.T) {
Groups: []models.Group{everyone, product},
}
createIdentities(t, db, &firstUser, &secondUser)
err := AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{firstUser.ID, secondUser.ID})
assert.NilError(t, err)
err = AddUsersToGroup(db, engineers.ID, engineers.Name, InfraProvider(db).ID, []uid.ID{firstUser.ID})
assert.NilError(t, err)
err = AddUsersToGroup(db, product.ID, product.Name, InfraProvider(db).ID, []uid.ID{secondUser.ID})
assert.NilError(t, err)

t.Run("all", func(t *testing.T) {
actual, err := ListGroups(db, nil)
Expand Down Expand Up @@ -226,27 +232,23 @@ func TestAddUsersToGroup(t *testing.T) {
createGroups(t, db, &everyone, &other)

var (
bond = models.Identity{
Name: "[email protected]",
Groups: []models.Group{everyone},
}
bond = models.Identity{Name: "[email protected]"}
bourne = models.Identity{Name: "[email protected]"}
bauer = models.Identity{Name: "[email protected]"}
forth = models.Identity{
Name: "[email protected]",
Groups: []models.Group{everyone},
}
forth = models.Identity{Name: "[email protected]"}
)

createIdentities(t, db, &bond, &bourne, &bauer, &forth)
err := AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bond.ID, forth.ID})
assert.NilError(t, err)

t.Run("add identities to group", func(t *testing.T) {
actual, err := ListIdentities(db, ListIdentityOptions{ByGroupID: everyone.ID})
assert.NilError(t, err)
expected := []models.Identity{forth, bond}
assert.DeepEqual(t, actual, expected, cmpModelsIdentityShallow)

err = AddUsersToGroup(db, everyone.ID, []uid.ID{bourne.ID, bauer.ID, forth.ID})
err = AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bourne.ID, bauer.ID, forth.ID})
assert.NilError(t, err)

actual, err = ListIdentities(db, ListIdentityOptions{ByGroupID: everyone.ID})
Expand Down Expand Up @@ -286,6 +288,10 @@ func TestRemoveUsersFromGroup(t *testing.T) {
Groups: []models.Group{everyone},
}
createIdentities(t, tx, &bond, &bourne, &bauer, &forth)
err := AddUsersToGroup(tx, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bond.ID, bourne.ID, bauer.ID, forth.ID})
assert.NilError(t, err)
err = AddUsersToGroup(tx, other.ID, other.Name, InfraProvider(db).ID, []uid.ID{bond.ID, bourne.ID, bauer.ID})
assert.NilError(t, err)

users, err := ListIdentities(tx, ListIdentityOptions{ByGroupID: everyone.ID})
assert.NilError(t, err)
Expand All @@ -309,3 +315,33 @@ func TestRemoveUsersFromGroup(t *testing.T) {
assert.DeepEqual(t, actual, expected, cmpModelsIdentityShallow)
})
}

func TestCountUsersInGroup(t *testing.T) {
runDBTests(t, func(t *testing.T, db *DB) {
everyone := models.Group{Name: "Everyone"}
createGroups(t, db, &everyone)

mockta := &models.Provider{Name: "mokta", Kind: models.ProviderKindOkta}
err := CreateProvider(db, mockta)
assert.NilError(t, err)

var (
bond = models.Identity{Name: "[email protected]"}
bourne = models.Identity{Name: "[email protected]"}
bauer = models.Identity{Name: "[email protected]"}
forth = models.Identity{Name: "[email protected]"}
)

createIdentities(t, db, &bond, &bourne, &bauer, &forth)
err = AddUsersToGroup(db, everyone.ID, everyone.Name, InfraProvider(db).ID, []uid.ID{bond.ID, forth.ID})
assert.NilError(t, err)
err = AddUsersToGroup(db, everyone.ID, everyone.Name, mockta.ID, []uid.ID{bond.ID, bourne.ID, bauer.ID})
assert.NilError(t, err)

t.Run("count users in group", func(t *testing.T) {
count, err := CountUsersInGroup(db, everyone.ID)
assert.NilError(t, err)
assert.Equal(t, int(count), 4)
})
})
}
Loading

0 comments on commit b89e82d

Please sign in to comment.