Skip to content

Commit

Permalink
initial direct conversion from req to ast
Browse files Browse the repository at this point in the history
Signed-off-by: Pushpalanka Jayawardhana <[email protected]>
  • Loading branch information
Pushpalanka committed Jan 12, 2025
1 parent 1a1d372 commit 22515a9
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 90 deletions.
169 changes: 110 additions & 59 deletions envoyauth/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package envoyauth
import (
"encoding/binary"
"fmt"
"github.com/open-policy-agent/opa/ast"
"io"
"mime"
"mime/multipart"
Expand All @@ -25,80 +26,67 @@ import (
var v2Info = map[string]string{"ext_authz": "v2", "encoding": "encoding/json"}
var v3Info = map[string]string{"ext_authz": "v3", "encoding": "protojson"}

// RequestToInput - Converts a CheckRequest in either protobuf 2 or 3 to an input map
func RequestToInput(req interface{}, logger logging.Logger, protoSet *protoregistry.Files, skipRequestBodyParse bool) (map[string]interface{}, error) {
var err error
var input = make(map[string]interface{})

var rawBody []byte
var path, body, method, host string
var headers map[string]string
var version map[string]string

switch req := req.(type) {
// RequestToAstValue - Converts a request to AST representation
func RequestToAstValue(req interface{}, logger logging.Logger, protoSet *protoregistry.Files, skipRequestBodyParse bool) (ast.Value, error) {
var (
headers map[string]string
body string
rawBody []byte
method string
path string
version map[string]string
parsedPath []interface{}
parsedQuery map[string]interface{}
parsedBody interface{}
isBodyTruncated bool
err error
)

// Extract fields based on request type
switch r := req.(type) {
case *ext_authz_v3.CheckRequest:
attrs := req.GetAttributes()
if attrs == nil || attrs.GetRequest() == nil || attrs.GetRequest().GetHttp() == nil {
return nil, fmt.Errorf("missing required attributes in v3 CheckRequest")
}
httpReq := attrs.GetRequest().GetHttp()

path = httpReq.GetPath()
body = httpReq.GetBody()
headers = httpReq.GetHeaders()
rawBody = httpReq.GetRawBody()
method = httpReq.GetMethod()
host = httpReq.GetHost()
headers = r.GetAttributes().GetRequest().GetHttp().GetHeaders()
body = r.GetAttributes().GetRequest().GetHttp().GetBody()
rawBody = r.GetAttributes().GetRequest().GetHttp().GetRawBody()
method = r.GetAttributes().GetRequest().GetHttp().GetMethod()
path = r.GetAttributes().GetRequest().GetHttp().GetPath()
version = v3Info

case *ext_authz_v2.CheckRequest:
attrs := req.GetAttributes()
if attrs == nil || attrs.GetRequest() == nil || attrs.GetRequest().GetHttp() == nil {
return nil, fmt.Errorf("missing required attributes in v2 CheckRequest")
}
httpReq := attrs.GetRequest().GetHttp()

path = httpReq.GetPath()
body = httpReq.GetBody()
headers = httpReq.GetHeaders()
method = httpReq.GetMethod()
host = httpReq.GetHost()
headers = r.GetAttributes().GetRequest().GetHttp().GetHeaders()
body = r.GetAttributes().GetRequest().GetHttp().GetBody()
method = r.GetAttributes().GetRequest().GetHttp().GetMethod()
path = r.GetAttributes().GetRequest().GetHttp().GetPath()
version = v2Info

default:
return nil, fmt.Errorf("unsupported request type: %T", req)
}

input["attributes"] = map[string]interface{}{
"request": map[string]interface{}{
"http": map[string]interface{}{
"path": path,
"body": body,
"headers": headers,
"method": method,
"host": host,
},
},
return nil, fmt.Errorf("unsupported request type")
}
input["version"] = version

parsedPath, parsedQuery, err := getParsedPathAndQuery(path)
parsedPath, parsedQuery, err = getParsedPathAndQuery(path)
if err != nil {
return nil, fmt.Errorf("error parsing path and query: %w", err)
return nil, err
}
input["parsed_path"] = parsedPath
input["parsed_query"] = parsedQuery

if !skipRequestBodyParse {
parsedBody, isBodyTruncated, err := getParsedBody(logger, headers, body, rawBody, parsedPath, protoSet)
parsedBody, isBodyTruncated, err = getParsedBody(logger, headers, body, rawBody, parsedPath, protoSet)
if err != nil {
return nil, fmt.Errorf("error parsing request body: %w", err)
return nil, err
}
input["parsed_body"] = parsedBody
input["truncated_body"] = isBodyTruncated
}

return input, nil
astObject := ast.NewObject()
createRequestHTTP(astObject, headers, method, version, parsedBody, isBodyTruncated, skipRequestBodyParse)

err = createAstParsedPath(astObject, parsedPath)
if err != nil {
return nil, err
}

err = createAstParsedQuery(astObject, parsedQuery)
if err != nil {
return nil, err
}

return astObject, nil
}

func getParsedPathAndQuery(path string) ([]interface{}, map[string]interface{}, error) {
Expand All @@ -125,6 +113,69 @@ func getParsedPathAndQuery(path string) ([]interface{}, map[string]interface{},
return parsedPathInterface, parsedQueryInterface, nil
}

func createAstParsedPath(astObj ast.Object, parsedPath []interface{}) error {
astTerms := make([]*ast.Term, len(parsedPath))
for i, segment := range parsedPath {
term, err := ast.InterfaceToValue(segment)
if err != nil {
return fmt.Errorf("failed to convert parsed path to AST at index %d: %w", i, err)
}
astTerms[i] = ast.NewTerm(term)
}

astArray := ast.NewArray(astTerms...)
astTerm := ast.NewTerm(astArray)

astObj.Insert(ast.StringTerm("parsed_path"), astTerm)
return nil
}

func createAstParsedQuery(astObj ast.Object, parsedQuery map[string]interface{}) error {
kvs := make([][2]*ast.Term, 0, len(parsedQuery)*2)

for key, value := range parsedQuery {
termKey, err := ast.InterfaceToValue(key)
if err != nil {
return fmt.Errorf("failed to convert query param key to AST: %w", err)
}

termValue, err := ast.InterfaceToValue(value)
if err != nil {
return fmt.Errorf("failed to convert query param value to AST: %w", err)
}

kvs = append(kvs, [2]*ast.Term{ast.NewTerm(termKey), ast.NewTerm(termValue)})
}

astObject := ast.NewObject(kvs...)
astTerm := ast.NewTerm(astObject)

astObj.Insert(ast.StringTerm("parsed_query"), astTerm)

return nil
}

func createRequestHTTP(astObj ast.Object, headers map[string]string, method string, version map[string]string, parsedBody interface{}, isBodyTruncated bool, skipBody bool) {
httpObj := ast.NewObject()
addAstField(httpObj, "headers", headers)
addAstField(httpObj, "method", method)
addAstField(httpObj, "version", version)

if !skipBody {
addAstField(httpObj, "parsed_body", parsedBody)
addAstField(httpObj, "truncated_body", isBodyTruncated)
}

astObj.Insert(ast.StringTerm("request"), ast.NewTerm(httpObj))
}

func addAstField(astObject ast.Object, key string, value interface{}) {
astValue, err := ast.InterfaceToValue(value)
if err == nil {
astObject.Insert(ast.StringTerm(key), ast.NewTerm(astValue))
}
}

func getParsedBody(logger logging.Logger, headers map[string]string, body string, rawBody []byte, parsedPath []interface{}, protoSet *protoregistry.Files) (interface{}, bool, error) {
var data interface{}

Expand Down
32 changes: 1 addition & 31 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,38 +433,8 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (*
return nil
}

input, err = envoyauth.RequestToInput(req, logger, p.cfg.protoSet, p.cfg.SkipRequestBodyParse)
if err != nil {
internalErr = newInternalError(RequestParseErr, err)
return &ext_authz_v3.CheckResponse{
Status: &rpc_status.Status{
Code: int32(code.Code_PERMISSION_DENIED),
Message: internalErr.Error(),
},
HttpResponse: &ext_authz_v3.CheckResponse_DeniedResponse{
DeniedResponse: &ext_authz_v3.DeniedHttpResponse{
Status: &ext_type_v3.HttpStatus{
Code: ext_type_v3.StatusCode(ext_type_v3.StatusCode_BadRequest),
},
Body: internalErr.Error(),
},
},
DynamicMetadata: nil,
}, stop, nil
}

if ctx.Err() != nil {
err = errors.Wrap(ctx.Err(), "check request timed out before query execution")
internalErr = newInternalError(CheckRequestTimeoutErr, err)
return nil, stop, internalErr
}

var inputValue ast.Value
inputValue, err = ast.InterfaceToValue(input)
if err != nil {
internalErr = newInternalError(InputParseErr, err)
return nil, stop, internalErr
}
inputValue, err = envoyauth.RequestToAstValue(req, logger, p.cfg.protoSet, p.cfg.SkipRequestBodyParse)

if err = envoyauth.Eval(ctx, p, inputValue, result); err != nil {
evalErr = err
Expand Down

0 comments on commit 22515a9

Please sign in to comment.