Skip to content

Commit

Permalink
Update query cursor to include relation + fix implicit result logic (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kkajla12 authored Mar 18, 2024
1 parent 9542553 commit 79f8ec6
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 43 deletions.
21 changes: 13 additions & 8 deletions pkg/authz/query/resultset.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,9 @@ func (rs *ResultSet) Union(other *ResultSet) *ResultSet {
}

for iter := other.List(); iter != nil; iter = iter.Next() {
isImplicit := iter.IsImplicit
if resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
isImplicit = isImplicit && resultSet.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit
if !resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) || !iter.IsImplicit {
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
}
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit)
}

return resultSet
Expand All @@ -112,10 +110,13 @@ func (rs *ResultSet) Union(other *ResultSet) *ResultSet {
func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
resultSet := NewResultSet()
for iter := rs.List(); iter != nil; iter = iter.Next() {
isImplicit := iter.IsImplicit
if other.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
isImplicit = isImplicit || other.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit)
otherRes := other.Get(iter.ObjectType, iter.ObjectId, iter.Relation)
if !otherRes.IsImplicit {
resultSet.Add(otherRes.ObjectType, otherRes.ObjectId, otherRes.Relation, otherRes.Warrant, otherRes.IsImplicit)
} else {
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
}
}
}

Expand All @@ -125,7 +126,11 @@ func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
func (rs *ResultSet) String() string {
var strs []string
for iter := rs.List(); iter != nil; iter = iter.Next() {
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
if iter.IsImplicit {
strs = append(strs, fmt.Sprintf("%s => %s [implicit]", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
} else {
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
}
}

return strings.Join(strs, ", ")
Expand Down
107 changes: 72 additions & 35 deletions pkg/authz/query/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,25 +222,25 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
paginatedQueryResults := make([]QueryResult, 0)
//nolint:gocritic
if listParams.NextCursor != nil { // seek forward if NextCursor passed in
lastObjectType, lastObjectId, err := objectTypeAndObjectIdFromCursor(listParams.NextCursor)
lastObjectType, lastObjectId, lastRelation, err := objectTypeAndObjectIdAndRelationFromCursor(listParams.NextCursor)
if err != nil {
return nil, nil, nil, service.NewInvalidParameterError("nextCursor", "invalid cursor")
}

start = 0
for start < len(queryResults) && (queryResults[start].ObjectType != lastObjectType || queryResults[start].ObjectId != lastObjectId) {
for start < len(queryResults) && (queryResults[start].ObjectType != lastObjectType || queryResults[start].ObjectId != lastObjectId || queryResults[start].Relation != lastRelation) {
start++
}

end = start + listParams.Limit
} else if listParams.PrevCursor != nil { // seek backward if PrevCursor passed in
lastObjectType, lastObjectId, err := objectTypeAndObjectIdFromCursor(listParams.PrevCursor)
lastObjectType, lastObjectId, lastRelation, err := objectTypeAndObjectIdAndRelationFromCursor(listParams.PrevCursor)
if err != nil {
return nil, nil, nil, service.NewInvalidParameterError("prevCursor", "invalid cursor")
}

end = len(queryResults) - 1
for end > 0 && (queryResults[end].ObjectType != lastObjectType || queryResults[end].ObjectId != lastObjectId) {
for end > 0 && (queryResults[end].ObjectType != lastObjectType || queryResults[end].ObjectId != lastObjectId || queryResults[end].Relation != lastRelation) {
end--
}

Expand All @@ -262,7 +262,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
value = queryResults[start].Meta[listParams.SortBy]
}

prevCursor = service.NewCursor(objectKey(queryResults[start].ObjectType, queryResults[start].ObjectId), value)
prevCursor = service.NewCursor(objectRelationKey(queryResults[start].ObjectType, queryResults[start].ObjectId, queryResults[start].Relation), value)
}

// if there are more results forward
Expand All @@ -277,7 +277,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
value = queryResults[end].Meta[listParams.SortBy]
}

nextCursor = service.NewCursor(objectKey(queryResults[end].ObjectType, queryResults[end].ObjectId), value)
nextCursor = service.NewCursor(objectRelationKey(queryResults[end].ObjectType, queryResults[end].ObjectId, queryResults[end].Relation), value)
}

for start < end && start < len(queryResults) {
Expand Down Expand Up @@ -336,7 +336,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
for _, matchedWarrant := range matchedWarrants {
if matchedWarrant.Subject.Relation != "" {
// handle group warrants
userset, err := svc.query(ctx, Query{
subset, err := svc.query(ctx, Query{
Expand: query.Expand,
SelectSubjects: &SelectSubjects{
Relations: []string{matchedWarrant.Subject.Relation},
Expand All @@ -352,8 +352,8 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
return nil, err
}

for res := userset.List(); res != nil; res = res.Next() {
if res.ObjectType != query.SelectObjects.WhereSubject.Type || res.ObjectId != query.SelectObjects.WhereSubject.Id {
for sub := subset.List(); sub != nil; sub = sub.Next() {
if sub.ObjectType != query.SelectObjects.WhereSubject.Type || sub.ObjectId != query.SelectObjects.WhereSubject.Id {
continue
}

Expand All @@ -367,27 +367,29 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res

for _, w := range expandedWildcardWarrants {
if w.ObjectId != warrant.Wildcard {
resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, level > 0)
resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
}
}
} else {
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
}
}
} else if query.SelectObjects.WhereSubject == nil ||
(matchedWarrant.Subject.ObjectType == query.SelectObjects.WhereSubject.Type &&
matchedWarrant.Subject.ObjectId == query.SelectObjects.WhereSubject.Id) {
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, false)
}
}

if query.Expand {
implicitResultSet, err := svc.queryRule(ctx, query, level, objectTypeDef.Relations[relation])
implicitResultSet, err := svc.queryRule(ctx, query, level+1, relation, objectTypeDef.Relations[relation])
if err != nil {
return nil, err
}

resultSet = resultSet.Union(implicitResultSet)
for res := implicitResultSet.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}
}

return resultSet, nil
Expand Down Expand Up @@ -417,7 +419,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
for _, matchedWarrant := range matchedWarrants {
if matchedWarrant.Subject.Relation != "" {
// handle group warrants
userset, err := svc.query(ctx, Query{
subset, err := svc.query(ctx, Query{
Expand: query.Expand,
SelectSubjects: &SelectSubjects{
Relations: []string{matchedWarrant.Subject.Relation},
Expand All @@ -433,21 +435,23 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
return nil, err
}

for res := userset.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, matchedWarrant, level > 0)
for sub := subset.List(); sub != nil; sub = sub.Next() {
resultSet.Add(sub.ObjectType, sub.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
}
} else if query.SelectSubjects.SubjectTypes[0] == matchedWarrant.Subject.ObjectType {
resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, false)
}
}

if query.Expand {
implicitResultSet, err := svc.queryRule(ctx, query, level, objectTypeDef.Relations[relation])
implicitResultSet, err := svc.queryRule(ctx, query, level+1, relation, objectTypeDef.Relations[relation])
if err != nil {
return nil, err
}

return resultSet.Union(implicitResultSet), nil
for res := implicitResultSet.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}
}

return resultSet, nil
Expand All @@ -456,14 +460,14 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
}
}

func (svc QueryService) queryRule(ctx context.Context, query Query, level int, rule objecttype.RelationRule) (*ResultSet, error) {
func (svc QueryService) queryRule(ctx context.Context, query Query, level int, relation string, rule objecttype.RelationRule) (*ResultSet, error) {
switch rule.InheritIf {
case "":
return NewResultSet(), nil
case objecttype.InheritIfAllOf:
var resultSet *ResultSet
for _, r := range rule.Rules {
res, err := svc.queryRule(ctx, query, level, r)
res, err := svc.queryRule(ctx, query, level, relation, r)
if err != nil {
return nil, err
}
Expand All @@ -479,7 +483,7 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
case objecttype.InheritIfAnyOf:
var resultSet *ResultSet
for _, r := range rule.Rules {
res, err := svc.queryRule(ctx, query, level, r)
res, err := svc.queryRule(ctx, query, level, relation, r)
if err != nil {
return nil, err
}
Expand All @@ -498,15 +502,25 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
switch {
case query.SelectObjects != nil:
if rule.OfType == "" && rule.WithRelation == "" {
return svc.query(ctx, Query{
results, err := svc.query(ctx, Query{
Expand: true,
SelectObjects: &SelectObjects{
ObjectTypes: query.SelectObjects.ObjectTypes,
WhereSubject: query.SelectObjects.WhereSubject,
Relations: []string{rule.InheritIf},
},
Context: query.Context,
}, level+1)
}, 0)
if err != nil {
return nil, err
}

resultSet := NewResultSet()
for res := results.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}

return resultSet, nil
} else {
indirectWarrants, err := svc.listWarrants(ctx, warrant.FilterParams{
ObjectType: rule.OfType,
Expand Down Expand Up @@ -535,27 +549,39 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
Relations: []string{rule.WithRelation},
},
Context: query.Context,
}, level+1)
}, 0)
if err != nil {
return nil, err
}

resultSet = resultSet.Union(inheritedResults)
for res := inheritedResults.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}
}

return resultSet, nil
}
case query.SelectSubjects != nil:
if rule.OfType == "" && rule.WithRelation == "" {
return svc.query(ctx, Query{
results, err := svc.query(ctx, Query{
Expand: true,
SelectSubjects: &SelectSubjects{
SubjectTypes: query.SelectSubjects.SubjectTypes,
Relations: []string{rule.InheritIf},
ForObject: query.SelectSubjects.ForObject,
},
Context: query.Context,
}, level+1)
}, 0)
if err != nil {
return nil, err
}

resultSet := NewResultSet()
for res := results.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}

return resultSet, nil
} else {
userset, err := svc.listWarrants(ctx, warrant.FilterParams{
ObjectType: query.SelectSubjects.ForObject.Type,
Expand Down Expand Up @@ -584,12 +610,14 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
},
},
Context: query.Context,
}, level+1)
}, 0)
if err != nil {
return nil, err
}

resultSet = resultSet.Union(subset)
for res := subset.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}
}

return resultSet, nil
Expand Down Expand Up @@ -628,11 +656,20 @@ func objectKey(objectType string, objectId string) string {
return fmt.Sprintf("%s:%s", objectType, objectId)
}

func objectTypeAndObjectIdFromCursor(cursor *service.Cursor) (string, string, error) {
objectType, objectId, found := strings.Cut(cursor.ID(), ":")
func objectRelationKey(objectType string, objectId string, relation string) string {
return fmt.Sprintf("%s:%s#%s", objectType, objectId, relation)
}

func objectTypeAndObjectIdAndRelationFromCursor(cursor *service.Cursor) (string, string, string, error) {
objectType, objectIdRelation, found := strings.Cut(cursor.ID(), ":")
if !found {
return "", "", "", errors.New("invalid cursor")
}

objectId, relation, found := strings.Cut(objectIdRelation, "#")
if !found {
return "", "", errors.New("invalid cursor")
return "", "", "", errors.New("invalid cursor")
}

return objectType, objectId, nil
return objectType, objectId, relation, nil
}
Loading

0 comments on commit 79f8ec6

Please sign in to comment.