Skip to content

Commit

Permalink
Add a Relation attribute to QueryResult (for internal use) (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkajla12 authored Mar 8, 2024
1 parent 2251884 commit b4920b4
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 65 deletions.
36 changes: 19 additions & 17 deletions pkg/authz/query/resultset.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
type ResultSetNode struct {
ObjectType string
ObjectId string
Relation string
Warrant warrant.WarrantSpec
IsImplicit bool
next *ResultSetNode
Expand All @@ -47,16 +48,17 @@ func (rs *ResultSet) List() *ResultSetNode {
return rs.head
}

func (rs *ResultSet) Add(objectType string, objectId string, warrant warrant.WarrantSpec, isImplicit bool) {
func (rs *ResultSet) Add(objectType string, objectId string, relation string, warrant warrant.WarrantSpec, isImplicit bool) {
newNode := &ResultSetNode{
ObjectType: objectType,
ObjectId: objectId,
Relation: relation,
Warrant: warrant,
IsImplicit: isImplicit,
next: nil,
}

existingRes, exists := rs.m[key(objectType, objectId)]
existingRes, exists := rs.m[key(objectType, objectId, relation)]
if !exists {
// Add warrant to list
if rs.head == nil {
Expand All @@ -72,35 +74,35 @@ func (rs *ResultSet) Add(objectType string, objectId string, warrant warrant.War

if !exists || (existingRes.IsImplicit && !isImplicit) {
// Add result node to map for O(1) lookups
rs.m[key(objectType, objectId)] = newNode
rs.m[key(objectType, objectId, relation)] = newNode
}
}

func (rs *ResultSet) Len() int {
return len(rs.m)
}

func (rs *ResultSet) Get(objectType string, objectId string) *ResultSetNode {
return rs.m[key(objectType, objectId)]
func (rs *ResultSet) Get(objectType string, objectId string, relation string) *ResultSetNode {
return rs.m[key(objectType, objectId, relation)]
}

func (rs *ResultSet) Has(objectType string, objectId string) bool {
_, exists := rs.m[key(objectType, objectId)]
func (rs *ResultSet) Has(objectType string, objectId string, relation string) bool {
_, exists := rs.m[key(objectType, objectId, relation)]
return exists
}

func (rs *ResultSet) Union(other *ResultSet) *ResultSet {
resultSet := NewResultSet()
for iter := rs.List(); iter != nil; iter = iter.Next() {
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Warrant, iter.IsImplicit)
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
}

for iter := other.List(); iter != nil; iter = iter.Next() {
isImplicit := iter.IsImplicit
if resultSet.Has(iter.ObjectType, iter.ObjectId) {
isImplicit = isImplicit && resultSet.Get(iter.ObjectType, iter.ObjectId).IsImplicit
if resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
isImplicit = isImplicit && resultSet.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit
}
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Warrant, isImplicit)
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit)
}

return resultSet
Expand All @@ -110,9 +112,9 @@ func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
resultSet := NewResultSet()
for iter := rs.List(); iter != nil; iter = iter.Next() {
isImplicit := iter.IsImplicit
if other.Has(iter.ObjectType, iter.ObjectId) {
isImplicit = isImplicit || other.Get(iter.ObjectType, iter.ObjectId).IsImplicit
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Warrant, isImplicit)
if other.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
isImplicit = isImplicit || other.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit)
}
}

Expand All @@ -122,7 +124,7 @@ func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
func (rs *ResultSet) String() string {
var strs []string
for iter := rs.List(); iter != nil; iter = iter.Next() {
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId), iter.Warrant.String()))
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
}

return strings.Join(strs, ", ")
Expand All @@ -136,6 +138,6 @@ func NewResultSet() *ResultSet {
}
}

func key(objectType string, objectId string) string {
return fmt.Sprintf("%s:%s", objectType, objectId)
func key(objectType string, objectId string, relation string) string {
return fmt.Sprintf("%s:%s#%s", objectType, objectId, relation)
}
23 changes: 14 additions & 9 deletions pkg/authz/query/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
}

for _, relation := range relations {
res, err := svc.query(ctx, Query{
queryResult, err := svc.query(ctx, Query{
Expand: query.Expand,
SelectObjects: &SelectObjects{
ObjectTypes: []string{objectType.Type},
Expand All @@ -113,7 +113,9 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
return nil, nil, nil, err
}

resultSet = resultSet.Union(res)
for res := queryResult.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit)
}
}
}
case query.SelectSubjects != nil:
Expand Down Expand Up @@ -159,7 +161,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi

for _, subjectType := range subjectTypes {
for _, relation := range relations {
res, err := svc.query(ctx, Query{
queryResult, err := svc.query(ctx, Query{
Expand: query.Expand,
SelectSubjects: &SelectSubjects{
SubjectTypes: []string{subjectType.Type},
Expand All @@ -172,7 +174,9 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
return nil, nil, nil, err
}

resultSet = resultSet.Union(res)
for res := queryResult.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit)
}
}
}
default:
Expand All @@ -183,6 +187,7 @@ func (svc QueryService) Query(ctx context.Context, query Query, listParams servi
queryResults = append(queryResults, QueryResult{
ObjectType: res.ObjectType,
ObjectId: res.ObjectId,
Relation: res.Relation,
Warrant: res.Warrant,
IsImplicit: res.IsImplicit,
})
Expand Down Expand Up @@ -358,17 +363,17 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res

for _, w := range expandedWildcardWarrants {
if w.ObjectId != warrant.Wildcard {
resultSet.Add(w.ObjectType, w.ObjectId, matchedWarrant, level > 0)
resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, level > 0)
}
}
} else {
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0)
}
}
} else if query.SelectObjects.WhereSubject == nil ||
(matchedWarrant.Subject.ObjectType == query.SelectObjects.WhereSubject.Type &&
matchedWarrant.Subject.ObjectId == query.SelectObjects.WhereSubject.Id) {
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0)
}
}

Expand Down Expand Up @@ -425,10 +430,10 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
}

for res := userset.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, matchedWarrant, level > 0)
resultSet.Add(res.ObjectType, res.ObjectId, relation, matchedWarrant, level > 0)
}
} else if query.SelectSubjects.SubjectTypes[0] == matchedWarrant.Subject.ObjectType {
resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, level > 0)
}
}

Expand Down
1 change: 1 addition & 0 deletions pkg/authz/query/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ type QueryHaving struct {
type QueryResult struct {
ObjectType string `json:"objectType"`
ObjectId string `json:"objectId"`
Relation string `json:"-"`
Warrant baseWarrant.WarrantSpec `json:"warrant"`
IsImplicit bool `json:"isImplicit"`
Meta map[string]interface{} `json:"meta,omitempty"`
Expand Down
50 changes: 11 additions & 39 deletions tests/v1/query.json
Original file line number Diff line number Diff line change
Expand Up @@ -1460,10 +1460,10 @@
}
},
{
"name": "selectMembersOrViewersOfTypeUserForDocumentD5Limit3SortOrderDesc",
"name": "selectExplicitOwnersOrViewersOfTypeUserForDocumentF3Limit1SortOrderDesc",
"request": {
"method": "GET",
"url": "/v1/query?q=select%20member%2Cviewer%20of%20type%20user%20for%20document:D5&limit=3&sortOrder=DESC",
"url": "/v1/query?q=select%20explicit%20owner%2Cviewer%20of%20type%20user%20for%20document:F3&limit=1&sortOrder=DESC",
"headers": {
"Warrant-Token": "latest"
}
Expand All @@ -1485,46 +1485,18 @@
"relation": "member"
}
},
"isImplicit": true
},
{
"objectType": "user",
"objectId": "U3",
"warrant": {
"objectType": "document",
"objectId": "F3",
"relation": "viewer",
"subject": {
"objectType": "user",
"objectId": "U3"
}
},
"isImplicit": true
},
{
"objectType": "user",
"objectId": "U2",
"warrant": {
"objectType": "document",
"objectId": "F2",
"relation": "editor",
"subject": {
"objectType": "user",
"objectId": "U2"
}
},
"isImplicit": true
"isImplicit": false
}
],
"lastId": "{{ selectMembersOrViewersOfTypeUserForDocumentD5Limit3SortOrderDesc.lastId }}"
"lastId": "{{ selectExplicitOwnersOrViewersOfTypeUserForDocumentF3Limit1SortOrderDesc.lastId }}"
}
}
},
{
"name": "selectMembersOrViewersOfTypeUserLimit3SortOrderDescAfterId1",
"name": "selectExplicitOwnersOrViewersOfTypeUserForDocumentF3Limit1SortOrderDescAfterId1",
"request": {
"method": "GET",
"url": "/v1/query?q=select%20member%2Cviewer%20of%20type%20user%20for%20document:D5&limit=3&sortOrder=DESC&lastId={{ selectMembersOrViewersOfTypeUserForDocumentD5Limit3SortOrderDesc.lastId }}",
"url": "/v1/query?q=select%20explicit%20owner%2Cviewer%20of%20type%20user%20for%20document:F3&limit=3&sortOrder=DESC&lastId={{ selectExplicitOwnersOrViewersOfTypeUserForDocumentF3Limit1SortOrderDesc.lastId }}",
"headers": {
"Warrant-Token": "latest"
}
Expand All @@ -1535,17 +1507,17 @@
"results": [
{
"objectType": "user",
"objectId": "U1",
"objectId": "U3",
"warrant": {
"objectType": "document",
"objectId": "F1",
"relation": "owner",
"objectId": "F3",
"relation": "viewer",
"subject": {
"objectType": "user",
"objectId": "U1"
"objectId": "U3"
}
},
"isImplicit": true
"isImplicit": false
}
]
}
Expand Down

0 comments on commit b4920b4

Please sign in to comment.