Skip to content

Commit

Permalink
Add policy support to query endpoint (#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkajla12 authored Mar 31, 2024
1 parent 2a10400 commit c0a9763
Show file tree
Hide file tree
Showing 6 changed files with 1,553 additions and 70 deletions.
20 changes: 18 additions & 2 deletions pkg/authz/query/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,20 @@ func (svc QueryService) Routes() ([]service.Route, error) {
}

func queryV1(svc QueryService, w http.ResponseWriter, r *http.Request) error {
queryString := r.URL.Query().Get("q")
queryParams := r.URL.Query()
queryString := queryParams.Get("q")
query, err := NewQueryFromString(queryString)
if err != nil {
return err
}

if queryParams.Has("context") {
err = query.WithContext(queryParams.Get("context"))
if err != nil {
return service.NewInvalidParameterError("context", "invalid")
}
}

listParams := service.GetListParamsFromContext[QueryListParamParser](r.Context())
// create next cursor from lastId or afterId param
if r.URL.Query().Has("lastId") {
Expand Down Expand Up @@ -88,12 +96,20 @@ func queryV1(svc QueryService, w http.ResponseWriter, r *http.Request) error {
}

func queryV2(svc QueryService, w http.ResponseWriter, r *http.Request) error {
queryString := r.URL.Query().Get("q")
queryParams := r.URL.Query()
queryString := queryParams.Get("q")
query, err := NewQueryFromString(queryString)
if err != nil {
return err
}

if queryParams.Has("context") {
err = query.WithContext(queryParams.Get("context"))
if err != nil {
return service.NewInvalidParameterError("context", "invalid")
}
}

listParams := service.GetListParamsFromContext[QueryListParamParser](r.Context())
results, prevCursor, nextCursor, err := svc.Query(r.Context(), query, listParams)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion pkg/authz/query/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,5 @@ func NewQueryFromString(queryString string) (Query, error) {
}
}

query.rawString = queryString
return query, nil
}
44 changes: 28 additions & 16 deletions pkg/authz/query/resultset.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type ResultSetNode struct {
ObjectId string
Relation string
Warrant warrant.WarrantSpec
Policy warrant.Policy
IsImplicit bool
next *ResultSetNode
}
Expand All @@ -48,14 +49,15 @@ func (rs *ResultSet) List() *ResultSetNode {
return rs.head
}

func (rs *ResultSet) Add(objectType string, objectId string, relation string, warrant warrant.WarrantSpec, isImplicit bool) {
func (rs *ResultSet) Add(objectType string, objectId string, relation string, warrant warrant.WarrantSpec, policy warrant.Policy, isImplicit bool) {
existingRes, exists := rs.m[key(objectType, objectId, relation)]
if !exists {
newNode := ResultSetNode{
ObjectType: objectType,
ObjectId: objectId,
Relation: relation,
Warrant: warrant,
Policy: policy,
IsImplicit: isImplicit,
next: nil,
}
Expand All @@ -73,9 +75,15 @@ func (rs *ResultSet) Add(objectType string, objectId string, relation string, wa

// Add result node to map for O(1) lookups
rs.m[key(objectType, objectId, relation)] = &newNode
} else if existingRes.IsImplicit && !isImplicit { // favor explicit results
existingRes.IsImplicit = isImplicit
existingRes.Warrant = warrant
} else {
// favor explicit results
if existingRes.IsImplicit && !isImplicit {
existingRes.IsImplicit = isImplicit
existingRes.Warrant = warrant
existingRes.Policy = policy
}

existingRes.Policy = existingRes.Policy.Or(policy)
}
}

Expand All @@ -95,13 +103,11 @@ func (rs *ResultSet) Has(objectType string, objectId string, relation string) bo
func (rs *ResultSet) Union(other *ResultSet) *ResultSet {
resultSet := NewResultSet()
for iter := rs.List(); iter != nil; iter = iter.Next() {
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.Policy, iter.IsImplicit)
}

for iter := other.List(); iter != nil; iter = iter.Next() {
if !resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) || !iter.IsImplicit {
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
}
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.Policy, iter.IsImplicit)
}

return resultSet
Expand All @@ -121,11 +127,14 @@ func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
for iter := a.List(); iter != nil; iter = iter.Next() {
if b.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
bRes := b.Get(iter.ObjectType, iter.ObjectId, iter.Relation)
if !bRes.IsImplicit {
result.Add(bRes.ObjectType, bRes.ObjectId, bRes.Relation, bRes.Warrant, bRes.IsImplicit)
} else {
result.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
}
result.Add(
iter.ObjectType,
iter.ObjectId,
iter.Relation,
iter.Warrant,
iter.Policy.And(bRes.Policy),
bRes.IsImplicit || iter.IsImplicit,
)
}
}

Expand All @@ -135,11 +144,14 @@ func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
func (rs *ResultSet) String() string {
var strs []string
for iter := rs.List(); iter != nil; iter = iter.Next() {
resStr := fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String())
if iter.Policy != "" {
resStr += fmt.Sprintf("[%s]", iter.Policy)
}
if iter.IsImplicit {
strs = append(strs, fmt.Sprintf("%s => %s [implicit]", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
} else {
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
resStr += "[implicit]"
}
strs = append(strs, resStr)
}

return strings.Join(strs, ", ")
Expand Down
Loading

0 comments on commit c0a9763

Please sign in to comment.