From 79f8ec6fcd981f413747ce6b904131d492f880ce Mon Sep 17 00:00:00 2001 From: Karan Kajla Date: Sun, 17 Mar 2024 17:48:41 -0700 Subject: [PATCH] Update query cursor to include relation + fix implicit result logic (#311) --- pkg/authz/query/resultset.go | 21 +-- pkg/authz/query/service.go | 107 ++++++++++----- tests/v2/query.json | 248 +++++++++++++++++++++++++++++++++++ 3 files changed, 333 insertions(+), 43 deletions(-) diff --git a/pkg/authz/query/resultset.go b/pkg/authz/query/resultset.go index f82263f..c58f361 100644 --- a/pkg/authz/query/resultset.go +++ b/pkg/authz/query/resultset.go @@ -99,11 +99,9 @@ func (rs *ResultSet) Union(other *ResultSet) *ResultSet { } for iter := other.List(); iter != nil; iter = iter.Next() { - isImplicit := iter.IsImplicit - if resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) { - isImplicit = isImplicit && resultSet.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit + if !resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) || !iter.IsImplicit { + resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit) } - resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit) } return resultSet @@ -112,10 +110,13 @@ func (rs *ResultSet) Union(other *ResultSet) *ResultSet { func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet { resultSet := NewResultSet() for iter := rs.List(); iter != nil; iter = iter.Next() { - isImplicit := iter.IsImplicit if other.Has(iter.ObjectType, iter.ObjectId, iter.Relation) { - isImplicit = isImplicit || other.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit - resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit) + otherRes := other.Get(iter.ObjectType, iter.ObjectId, iter.Relation) + if !otherRes.IsImplicit { + resultSet.Add(otherRes.ObjectType, otherRes.ObjectId, otherRes.Relation, otherRes.Warrant, otherRes.IsImplicit) + } else { + resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit) + } } } @@ -125,7 +126,11 @@ func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet { func (rs *ResultSet) String() string { var strs []string for iter := rs.List(); iter != nil; iter = iter.Next() { - strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String())) + if iter.IsImplicit { + strs = append(strs, fmt.Sprintf("%s => %s [implicit]", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String())) + } else { + strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String())) + } } return strings.Join(strs, ", ") diff --git a/pkg/authz/query/service.go b/pkg/authz/query/service.go index 9d69554..d3d2837 100644 --- a/pkg/authz/query/service.go +++ b/pkg/authz/query/service.go @@ -222,25 +222,25 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi paginatedQueryResults := make([]QueryResult, 0) //nolint:gocritic if listParams.NextCursor != nil { // seek forward if NextCursor passed in - lastObjectType, lastObjectId, err := objectTypeAndObjectIdFromCursor(listParams.NextCursor) + lastObjectType, lastObjectId, lastRelation, err := objectTypeAndObjectIdAndRelationFromCursor(listParams.NextCursor) if err != nil { return nil, nil, nil, service.NewInvalidParameterError("nextCursor", "invalid cursor") } start = 0 - for start < len(queryResults) && (queryResults[start].ObjectType != lastObjectType || queryResults[start].ObjectId != lastObjectId) { + for start < len(queryResults) && (queryResults[start].ObjectType != lastObjectType || queryResults[start].ObjectId != lastObjectId || queryResults[start].Relation != lastRelation) { start++ } end = start + listParams.Limit } else if listParams.PrevCursor != nil { // seek backward if PrevCursor passed in - lastObjectType, lastObjectId, err := objectTypeAndObjectIdFromCursor(listParams.PrevCursor) + lastObjectType, lastObjectId, lastRelation, err := objectTypeAndObjectIdAndRelationFromCursor(listParams.PrevCursor) if err != nil { return nil, nil, nil, service.NewInvalidParameterError("prevCursor", "invalid cursor") } end = len(queryResults) - 1 - for end > 0 && (queryResults[end].ObjectType != lastObjectType || queryResults[end].ObjectId != lastObjectId) { + for end > 0 && (queryResults[end].ObjectType != lastObjectType || queryResults[end].ObjectId != lastObjectId || queryResults[end].Relation != lastRelation) { end-- } @@ -262,7 +262,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi value = queryResults[start].Meta[listParams.SortBy] } - prevCursor = service.NewCursor(objectKey(queryResults[start].ObjectType, queryResults[start].ObjectId), value) + prevCursor = service.NewCursor(objectRelationKey(queryResults[start].ObjectType, queryResults[start].ObjectId, queryResults[start].Relation), value) } // if there are more results forward @@ -277,7 +277,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi value = queryResults[end].Meta[listParams.SortBy] } - nextCursor = service.NewCursor(objectKey(queryResults[end].ObjectType, queryResults[end].ObjectId), value) + nextCursor = service.NewCursor(objectRelationKey(queryResults[end].ObjectType, queryResults[end].ObjectId, queryResults[end].Relation), value) } for start < end && start < len(queryResults) { @@ -336,7 +336,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res for _, matchedWarrant := range matchedWarrants { if matchedWarrant.Subject.Relation != "" { // handle group warrants - userset, err := svc.query(ctx, Query{ + subset, err := svc.query(ctx, Query{ Expand: query.Expand, SelectSubjects: &SelectSubjects{ Relations: []string{matchedWarrant.Subject.Relation}, @@ -352,8 +352,8 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res return nil, err } - for res := userset.List(); res != nil; res = res.Next() { - if res.ObjectType != query.SelectObjects.WhereSubject.Type || res.ObjectId != query.SelectObjects.WhereSubject.Id { + for sub := subset.List(); sub != nil; sub = sub.Next() { + if sub.ObjectType != query.SelectObjects.WhereSubject.Type || sub.ObjectId != query.SelectObjects.WhereSubject.Id { continue } @@ -367,27 +367,29 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res for _, w := range expandedWildcardWarrants { if w.ObjectId != warrant.Wildcard { - resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, level > 0) + resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0) } } } else { - resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0) + resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0) } } } else if query.SelectObjects.WhereSubject == nil || (matchedWarrant.Subject.ObjectType == query.SelectObjects.WhereSubject.Type && matchedWarrant.Subject.ObjectId == query.SelectObjects.WhereSubject.Id) { - resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0) + resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, false) } } if query.Expand { - implicitResultSet, err := svc.queryRule(ctx, query, level, objectTypeDef.Relations[relation]) + implicitResultSet, err := svc.queryRule(ctx, query, level+1, relation, objectTypeDef.Relations[relation]) if err != nil { return nil, err } - resultSet = resultSet.Union(implicitResultSet) + for res := implicitResultSet.List(); res != nil; res = res.Next() { + resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0) + } } return resultSet, nil @@ -417,7 +419,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res for _, matchedWarrant := range matchedWarrants { if matchedWarrant.Subject.Relation != "" { // handle group warrants - userset, err := svc.query(ctx, Query{ + subset, err := svc.query(ctx, Query{ Expand: query.Expand, SelectSubjects: &SelectSubjects{ Relations: []string{matchedWarrant.Subject.Relation}, @@ -433,21 +435,23 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res return nil, err } - for res := userset.List(); res != nil; res = res.Next() { - resultSet.Add(res.ObjectType, res.ObjectId, relation, matchedWarrant, level > 0) + for sub := subset.List(); sub != nil; sub = sub.Next() { + resultSet.Add(sub.ObjectType, sub.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0) } } else if query.SelectSubjects.SubjectTypes[0] == matchedWarrant.Subject.ObjectType { - resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, level > 0) + resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, false) } } if query.Expand { - implicitResultSet, err := svc.queryRule(ctx, query, level, objectTypeDef.Relations[relation]) + implicitResultSet, err := svc.queryRule(ctx, query, level+1, relation, objectTypeDef.Relations[relation]) if err != nil { return nil, err } - return resultSet.Union(implicitResultSet), nil + for res := implicitResultSet.List(); res != nil; res = res.Next() { + resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0) + } } return resultSet, nil @@ -456,14 +460,14 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res } } -func (svc QueryService) queryRule(ctx context.Context, query Query, level int, rule objecttype.RelationRule) (*ResultSet, error) { +func (svc QueryService) queryRule(ctx context.Context, query Query, level int, relation string, rule objecttype.RelationRule) (*ResultSet, error) { switch rule.InheritIf { case "": return NewResultSet(), nil case objecttype.InheritIfAllOf: var resultSet *ResultSet for _, r := range rule.Rules { - res, err := svc.queryRule(ctx, query, level, r) + res, err := svc.queryRule(ctx, query, level, relation, r) if err != nil { return nil, err } @@ -479,7 +483,7 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r case objecttype.InheritIfAnyOf: var resultSet *ResultSet for _, r := range rule.Rules { - res, err := svc.queryRule(ctx, query, level, r) + res, err := svc.queryRule(ctx, query, level, relation, r) if err != nil { return nil, err } @@ -498,7 +502,7 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r switch { case query.SelectObjects != nil: if rule.OfType == "" && rule.WithRelation == "" { - return svc.query(ctx, Query{ + results, err := svc.query(ctx, Query{ Expand: true, SelectObjects: &SelectObjects{ ObjectTypes: query.SelectObjects.ObjectTypes, @@ -506,7 +510,17 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r Relations: []string{rule.InheritIf}, }, Context: query.Context, - }, level+1) + }, 0) + if err != nil { + return nil, err + } + + resultSet := NewResultSet() + for res := results.List(); res != nil; res = res.Next() { + resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0) + } + + return resultSet, nil } else { indirectWarrants, err := svc.listWarrants(ctx, warrant.FilterParams{ ObjectType: rule.OfType, @@ -535,19 +549,21 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r Relations: []string{rule.WithRelation}, }, Context: query.Context, - }, level+1) + }, 0) if err != nil { return nil, err } - resultSet = resultSet.Union(inheritedResults) + for res := inheritedResults.List(); res != nil; res = res.Next() { + resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0) + } } return resultSet, nil } case query.SelectSubjects != nil: if rule.OfType == "" && rule.WithRelation == "" { - return svc.query(ctx, Query{ + results, err := svc.query(ctx, Query{ Expand: true, SelectSubjects: &SelectSubjects{ SubjectTypes: query.SelectSubjects.SubjectTypes, @@ -555,7 +571,17 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r ForObject: query.SelectSubjects.ForObject, }, Context: query.Context, - }, level+1) + }, 0) + if err != nil { + return nil, err + } + + resultSet := NewResultSet() + for res := results.List(); res != nil; res = res.Next() { + resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0) + } + + return resultSet, nil } else { userset, err := svc.listWarrants(ctx, warrant.FilterParams{ ObjectType: query.SelectSubjects.ForObject.Type, @@ -584,12 +610,14 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r }, }, Context: query.Context, - }, level+1) + }, 0) if err != nil { return nil, err } - resultSet = resultSet.Union(subset) + for res := subset.List(); res != nil; res = res.Next() { + resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0) + } } return resultSet, nil @@ -628,11 +656,20 @@ func objectKey(objectType string, objectId string) string { return fmt.Sprintf("%s:%s", objectType, objectId) } -func objectTypeAndObjectIdFromCursor(cursor *service.Cursor) (string, string, error) { - objectType, objectId, found := strings.Cut(cursor.ID(), ":") +func objectRelationKey(objectType string, objectId string, relation string) string { + return fmt.Sprintf("%s:%s#%s", objectType, objectId, relation) +} + +func objectTypeAndObjectIdAndRelationFromCursor(cursor *service.Cursor) (string, string, string, error) { + objectType, objectIdRelation, found := strings.Cut(cursor.ID(), ":") + if !found { + return "", "", "", errors.New("invalid cursor") + } + + objectId, relation, found := strings.Cut(objectIdRelation, "#") if !found { - return "", "", errors.New("invalid cursor") + return "", "", "", errors.New("invalid cursor") } - return objectType, objectId, nil + return objectType, objectId, relation, nil } diff --git a/tests/v2/query.json b/tests/v2/query.json index b7ac7a8..249f9ba 100644 --- a/tests/v2/query.json +++ b/tests/v2/query.json @@ -3076,6 +3076,254 @@ "expectedResponse": { "statusCode": 200 } + }, + { + "name": "assignUserRoadrunnerAsAdminOfTenantAcme", + "request": { + "method": "POST", + "url": "/v2/warrants", + "body": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + } + }, + "expectedResponse": { + "statusCode": 200, + "body": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + } + } + }, + { + "name": "assignUserWileyAsMemberOfTenantAcme", + "request": { + "method": "POST", + "url": "/v2/warrants", + "body": { + "objectType": "tenant", + "objectId": "acme", + "relation": "member", + "subject": { + "objectType": "user", + "objectId": "wiley" + } + } + }, + "expectedResponse": { + "statusCode": 200, + "body": { + "objectType": "tenant", + "objectId": "acme", + "relation": "member", + "subject": { + "objectType": "user", + "objectId": "wiley" + } + } + } + }, + { + "name": "selectTenantWhereUserRoadrunnerIsAnything", + "request": { + "method": "GET", + "url": "/v2/query?q=select%20tenant%20where%20user:road-runner%20is%20*" + }, + "expectedResponse": { + "statusCode": 200, + "body": { + "results": [ + { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "warrant": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + }, + "isImplicit": false + }, + { + "objectType": "tenant", + "objectId": "acme", + "relation": "manager", + "warrant": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + }, + "isImplicit": true + }, + { + "objectType": "tenant", + "objectId": "acme", + "relation": "member", + "warrant": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + }, + "isImplicit": true + } + ] + } + } + }, + { + "name": "selectTenantWhereUserRoadrunnerIsManager", + "request": { + "method": "GET", + "url": "/v2/query?q=select%20tenant%20where%20user:road-runner%20is%20manager" + }, + "expectedResponse": { + "statusCode": 200, + "body": { + "results": [ + { + "objectType": "tenant", + "objectId": "acme", + "relation": "manager", + "warrant": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + }, + "isImplicit": true + } + ] + } + } + }, + { + "name": "selectAnythingOfTypeUserForTenantAcme", + "request": { + "method": "GET", + "url": "/v2/query?q=select%20*%20of%20type%20user%20for%20tenant:acme" + }, + "expectedResponse": { + "statusCode": 200, + "body": { + "results": [ + { + "objectType": "user", + "objectId": "road-runner", + "relation": "admin", + "warrant": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + }, + "isImplicit": false + }, + { + "objectType": "user", + "objectId": "road-runner", + "relation": "manager", + "warrant": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + }, + "isImplicit": true + }, + { + "objectType": "user", + "objectId": "road-runner", + "relation": "member", + "warrant": { + "objectType": "tenant", + "objectId": "acme", + "relation": "admin", + "subject": { + "objectType": "user", + "objectId": "road-runner" + } + }, + "isImplicit": true + }, + { + "objectType": "user", + "objectId": "wiley", + "relation": "member", + "warrant": { + "objectType": "tenant", + "objectId": "acme", + "relation": "member", + "subject": { + "objectType": "user", + "objectId": "wiley" + } + }, + "isImplicit": false + } + ] + } + } + }, + { + "name": "deleteTenantAcme", + "request": { + "method": "DELETE", + "url": "/v2/objects/tenant/acme" + }, + "expectedResponse": { + "statusCode": 200 + } + }, + { + "name": "deleteUserRoadrunner", + "request": { + "method": "DELETE", + "url": "/v2/objects/user/road-runner" + }, + "expectedResponse": { + "statusCode": 200 + } + }, + { + "name": "deleteUserWiley", + "request": { + "method": "DELETE", + "url": "/v2/objects/user/wiley" + }, + "expectedResponse": { + "statusCode": 200 + } } ] }