Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Remove intermediate states of the http request conversion to AST #629

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 108 additions & 44 deletions envoyauth/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package envoyauth

import (
"encoding/binary"
"encoding/json"
"fmt"
"github.com/open-policy-agent/opa/ast"
"io"
"mime"
"mime/multipart"
Expand All @@ -26,66 +26,67 @@ import (
var v2Info = map[string]string{"ext_authz": "v2", "encoding": "encoding/json"}
var v3Info = map[string]string{"ext_authz": "v3", "encoding": "protojson"}

// RequestToInput - Converts a CheckRequest in either protobuf 2 or 3 to an input map
func RequestToInput(req interface{}, logger logging.Logger, protoSet *protoregistry.Files, skipRequestBodyParse bool) (map[string]interface{}, error) {
var err error
var input map[string]interface{}

var bs, rawBody []byte
var path, body string
var headers, version map[string]string

// NOTE: The path/body/headers blocks look silly, but they allow us to retrieve
// the parts of the incoming request we care about, without having to convert
// the entire v2 message into v3. It's nested, each level has a different type,
// etc -- we only care for its JSON representation as fed into evaluation later.
switch req := req.(type) {
// RequestToAstValue - Converts a request to AST representation
func RequestToAstValue(req interface{}, logger logging.Logger, protoSet *protoregistry.Files, skipRequestBodyParse bool) (ast.Value, error) {
var (
headers map[string]string
body string
rawBody []byte
method string
path string
version map[string]string
parsedPath []interface{}
parsedQuery map[string]interface{}
parsedBody interface{}
isBodyTruncated bool
err error
)

// Extract fields based on request type
switch r := req.(type) {
case *ext_authz_v3.CheckRequest:
bs, err = protojson.Marshal(req)
if err != nil {
return nil, err
}
path = req.GetAttributes().GetRequest().GetHttp().GetPath()
body = req.GetAttributes().GetRequest().GetHttp().GetBody()
headers = req.GetAttributes().GetRequest().GetHttp().GetHeaders()
rawBody = req.GetAttributes().GetRequest().GetHttp().GetRawBody()
headers = r.GetAttributes().GetRequest().GetHttp().GetHeaders()
body = r.GetAttributes().GetRequest().GetHttp().GetBody()
rawBody = r.GetAttributes().GetRequest().GetHttp().GetRawBody()
method = r.GetAttributes().GetRequest().GetHttp().GetMethod()
path = r.GetAttributes().GetRequest().GetHttp().GetPath()
version = v3Info
case *ext_authz_v2.CheckRequest:
bs, err = json.Marshal(req)
if err != nil {
return nil, err
}
path = req.GetAttributes().GetRequest().GetHttp().GetPath()
body = req.GetAttributes().GetRequest().GetHttp().GetBody()
headers = req.GetAttributes().GetRequest().GetHttp().GetHeaders()
headers = r.GetAttributes().GetRequest().GetHttp().GetHeaders()
body = r.GetAttributes().GetRequest().GetHttp().GetBody()
method = r.GetAttributes().GetRequest().GetHttp().GetMethod()
path = r.GetAttributes().GetRequest().GetHttp().GetPath()
version = v2Info
default:
return nil, fmt.Errorf("unsupported request type")
}

err = util.UnmarshalJSON(bs, &input)
if err != nil {
return nil, err
}
input["version"] = version

parsedPath, parsedQuery, err := getParsedPathAndQuery(path)
parsedPath, parsedQuery, err = getParsedPathAndQuery(path)
if err != nil {
return nil, err
}

input["parsed_path"] = parsedPath
input["parsed_query"] = parsedQuery

if !skipRequestBodyParse {
parsedBody, isBodyTruncated, err := getParsedBody(logger, headers, body, rawBody, parsedPath, protoSet)
parsedBody, isBodyTruncated, err = getParsedBody(logger, headers, body, rawBody, parsedPath, protoSet)
if err != nil {
return nil, err
}
}

astObject := ast.NewObject()
createRequestHTTP(astObject, headers, method, version, parsedBody, isBodyTruncated, skipRequestBodyParse)

err = createAstParsedPath(astObject, parsedPath)
if err != nil {
return nil, err
}

input["parsed_body"] = parsedBody
input["truncated_body"] = isBodyTruncated
err = createAstParsedQuery(astObject, parsedQuery)
if err != nil {
return nil, err
}

return input, nil
return astObject, nil
}

func getParsedPathAndQuery(path string) ([]interface{}, map[string]interface{}, error) {
Expand All @@ -112,6 +113,69 @@ func getParsedPathAndQuery(path string) ([]interface{}, map[string]interface{},
return parsedPathInterface, parsedQueryInterface, nil
}

func createAstParsedPath(astObj ast.Object, parsedPath []interface{}) error {
astTerms := make([]*ast.Term, len(parsedPath))
for i, segment := range parsedPath {
term, err := ast.InterfaceToValue(segment)
if err != nil {
return fmt.Errorf("failed to convert parsed path to AST at index %d: %w", i, err)
}
astTerms[i] = ast.NewTerm(term)
}

astArray := ast.NewArray(astTerms...)
astTerm := ast.NewTerm(astArray)

astObj.Insert(ast.StringTerm("parsed_path"), astTerm)
return nil
}

func createAstParsedQuery(astObj ast.Object, parsedQuery map[string]interface{}) error {
kvs := make([][2]*ast.Term, 0, len(parsedQuery)*2)

for key, value := range parsedQuery {
termKey, err := ast.InterfaceToValue(key)
if err != nil {
return fmt.Errorf("failed to convert query param key to AST: %w", err)
}

termValue, err := ast.InterfaceToValue(value)
if err != nil {
return fmt.Errorf("failed to convert query param value to AST: %w", err)
}

kvs = append(kvs, [2]*ast.Term{ast.NewTerm(termKey), ast.NewTerm(termValue)})
}

astObject := ast.NewObject(kvs...)
astTerm := ast.NewTerm(astObject)

astObj.Insert(ast.StringTerm("parsed_query"), astTerm)

return nil
}

func createRequestHTTP(astObj ast.Object, headers map[string]string, method string, version map[string]string, parsedBody interface{}, isBodyTruncated bool, skipBody bool) {
httpObj := ast.NewObject()
addAstField(httpObj, "headers", headers)
addAstField(httpObj, "method", method)
addAstField(httpObj, "version", version)

if !skipBody {
addAstField(httpObj, "parsed_body", parsedBody)
addAstField(httpObj, "truncated_body", isBodyTruncated)
}

astObj.Insert(ast.StringTerm("request"), ast.NewTerm(httpObj))
}

func addAstField(astObject ast.Object, key string, value interface{}) {
astValue, err := ast.InterfaceToValue(value)
if err == nil {
astObject.Insert(ast.StringTerm(key), ast.NewTerm(astValue))
}
}

func getParsedBody(logger logging.Logger, headers map[string]string, body string, rawBody []byte, parsedPath []interface{}, protoSet *protoregistry.Files) (interface{}, bool, error) {
var data interface{}

Expand Down
32 changes: 1 addition & 31 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,38 +433,8 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (*
return nil
}

input, err = envoyauth.RequestToInput(req, logger, p.cfg.protoSet, p.cfg.SkipRequestBodyParse)
if err != nil {
internalErr = newInternalError(RequestParseErr, err)
return &ext_authz_v3.CheckResponse{
Status: &rpc_status.Status{
Code: int32(code.Code_PERMISSION_DENIED),
Message: internalErr.Error(),
},
HttpResponse: &ext_authz_v3.CheckResponse_DeniedResponse{
DeniedResponse: &ext_authz_v3.DeniedHttpResponse{
Status: &ext_type_v3.HttpStatus{
Code: ext_type_v3.StatusCode(ext_type_v3.StatusCode_BadRequest),
},
Body: internalErr.Error(),
},
},
DynamicMetadata: nil,
}, stop, nil
}

if ctx.Err() != nil {
err = errors.Wrap(ctx.Err(), "check request timed out before query execution")
internalErr = newInternalError(CheckRequestTimeoutErr, err)
return nil, stop, internalErr
}

var inputValue ast.Value
inputValue, err = ast.InterfaceToValue(input)
if err != nil {
internalErr = newInternalError(InputParseErr, err)
return nil, stop, internalErr
}
inputValue, err = envoyauth.RequestToAstValue(req, logger, p.cfg.protoSet, p.cfg.SkipRequestBodyParse)

if err = envoyauth.Eval(ctx, p, inputValue, result); err != nil {
evalErr = err
Expand Down