From 22515a97091ac66aae3c1b194bf60a2288839ced Mon Sep 17 00:00:00 2001 From: Pushpalanka Jayawardhana Date: Sun, 12 Jan 2025 22:22:28 +0100 Subject: [PATCH] initial direct conversion from req to ast Signed-off-by: Pushpalanka Jayawardhana --- envoyauth/request.go | 169 ++++++++++++++++++++++++++++--------------- internal/internal.go | 32 +------- 2 files changed, 111 insertions(+), 90 deletions(-) diff --git a/envoyauth/request.go b/envoyauth/request.go index 4e29291a6..cf01ba6fa 100644 --- a/envoyauth/request.go +++ b/envoyauth/request.go @@ -3,6 +3,7 @@ package envoyauth import ( "encoding/binary" "fmt" + "github.com/open-policy-agent/opa/ast" "io" "mime" "mime/multipart" @@ -25,80 +26,67 @@ import ( var v2Info = map[string]string{"ext_authz": "v2", "encoding": "encoding/json"} var v3Info = map[string]string{"ext_authz": "v3", "encoding": "protojson"} -// RequestToInput - Converts a CheckRequest in either protobuf 2 or 3 to an input map -func RequestToInput(req interface{}, logger logging.Logger, protoSet *protoregistry.Files, skipRequestBodyParse bool) (map[string]interface{}, error) { - var err error - var input = make(map[string]interface{}) - - var rawBody []byte - var path, body, method, host string - var headers map[string]string - var version map[string]string - - switch req := req.(type) { +// RequestToAstValue - Converts a request to AST representation +func RequestToAstValue(req interface{}, logger logging.Logger, protoSet *protoregistry.Files, skipRequestBodyParse bool) (ast.Value, error) { + var ( + headers map[string]string + body string + rawBody []byte + method string + path string + version map[string]string + parsedPath []interface{} + parsedQuery map[string]interface{} + parsedBody interface{} + isBodyTruncated bool + err error + ) + + // Extract fields based on request type + switch r := req.(type) { case *ext_authz_v3.CheckRequest: - attrs := req.GetAttributes() - if attrs == nil || attrs.GetRequest() == nil || attrs.GetRequest().GetHttp() == nil { - return nil, fmt.Errorf("missing required attributes in v3 CheckRequest") - } - httpReq := attrs.GetRequest().GetHttp() - - path = httpReq.GetPath() - body = httpReq.GetBody() - headers = httpReq.GetHeaders() - rawBody = httpReq.GetRawBody() - method = httpReq.GetMethod() - host = httpReq.GetHost() + headers = r.GetAttributes().GetRequest().GetHttp().GetHeaders() + body = r.GetAttributes().GetRequest().GetHttp().GetBody() + rawBody = r.GetAttributes().GetRequest().GetHttp().GetRawBody() + method = r.GetAttributes().GetRequest().GetHttp().GetMethod() + path = r.GetAttributes().GetRequest().GetHttp().GetPath() version = v3Info - case *ext_authz_v2.CheckRequest: - attrs := req.GetAttributes() - if attrs == nil || attrs.GetRequest() == nil || attrs.GetRequest().GetHttp() == nil { - return nil, fmt.Errorf("missing required attributes in v2 CheckRequest") - } - httpReq := attrs.GetRequest().GetHttp() - - path = httpReq.GetPath() - body = httpReq.GetBody() - headers = httpReq.GetHeaders() - method = httpReq.GetMethod() - host = httpReq.GetHost() + headers = r.GetAttributes().GetRequest().GetHttp().GetHeaders() + body = r.GetAttributes().GetRequest().GetHttp().GetBody() + method = r.GetAttributes().GetRequest().GetHttp().GetMethod() + path = r.GetAttributes().GetRequest().GetHttp().GetPath() version = v2Info - default: - return nil, fmt.Errorf("unsupported request type: %T", req) - } - - input["attributes"] = map[string]interface{}{ - "request": map[string]interface{}{ - "http": map[string]interface{}{ - "path": path, - "body": body, - "headers": headers, - "method": method, - "host": host, - }, - }, + return nil, fmt.Errorf("unsupported request type") } - input["version"] = version - parsedPath, parsedQuery, err := getParsedPathAndQuery(path) + parsedPath, parsedQuery, err = getParsedPathAndQuery(path) if err != nil { - return nil, fmt.Errorf("error parsing path and query: %w", err) + return nil, err } - input["parsed_path"] = parsedPath - input["parsed_query"] = parsedQuery if !skipRequestBodyParse { - parsedBody, isBodyTruncated, err := getParsedBody(logger, headers, body, rawBody, parsedPath, protoSet) + parsedBody, isBodyTruncated, err = getParsedBody(logger, headers, body, rawBody, parsedPath, protoSet) if err != nil { - return nil, fmt.Errorf("error parsing request body: %w", err) + return nil, err } - input["parsed_body"] = parsedBody - input["truncated_body"] = isBodyTruncated } - return input, nil + astObject := ast.NewObject() + createRequestHTTP(astObject, headers, method, version, parsedBody, isBodyTruncated, skipRequestBodyParse) + + err = createAstParsedPath(astObject, parsedPath) + if err != nil { + return nil, err + } + + err = createAstParsedQuery(astObject, parsedQuery) + if err != nil { + return nil, err + } + + return astObject, nil } func getParsedPathAndQuery(path string) ([]interface{}, map[string]interface{}, error) { @@ -125,6 +113,69 @@ func getParsedPathAndQuery(path string) ([]interface{}, map[string]interface{}, return parsedPathInterface, parsedQueryInterface, nil } +func createAstParsedPath(astObj ast.Object, parsedPath []interface{}) error { + astTerms := make([]*ast.Term, len(parsedPath)) + for i, segment := range parsedPath { + term, err := ast.InterfaceToValue(segment) + if err != nil { + return fmt.Errorf("failed to convert parsed path to AST at index %d: %w", i, err) + } + astTerms[i] = ast.NewTerm(term) + } + + astArray := ast.NewArray(astTerms...) + astTerm := ast.NewTerm(astArray) + + astObj.Insert(ast.StringTerm("parsed_path"), astTerm) + return nil +} + +func createAstParsedQuery(astObj ast.Object, parsedQuery map[string]interface{}) error { + kvs := make([][2]*ast.Term, 0, len(parsedQuery)*2) + + for key, value := range parsedQuery { + termKey, err := ast.InterfaceToValue(key) + if err != nil { + return fmt.Errorf("failed to convert query param key to AST: %w", err) + } + + termValue, err := ast.InterfaceToValue(value) + if err != nil { + return fmt.Errorf("failed to convert query param value to AST: %w", err) + } + + kvs = append(kvs, [2]*ast.Term{ast.NewTerm(termKey), ast.NewTerm(termValue)}) + } + + astObject := ast.NewObject(kvs...) + astTerm := ast.NewTerm(astObject) + + astObj.Insert(ast.StringTerm("parsed_query"), astTerm) + + return nil +} + +func createRequestHTTP(astObj ast.Object, headers map[string]string, method string, version map[string]string, parsedBody interface{}, isBodyTruncated bool, skipBody bool) { + httpObj := ast.NewObject() + addAstField(httpObj, "headers", headers) + addAstField(httpObj, "method", method) + addAstField(httpObj, "version", version) + + if !skipBody { + addAstField(httpObj, "parsed_body", parsedBody) + addAstField(httpObj, "truncated_body", isBodyTruncated) + } + + astObj.Insert(ast.StringTerm("request"), ast.NewTerm(httpObj)) +} + +func addAstField(astObject ast.Object, key string, value interface{}) { + astValue, err := ast.InterfaceToValue(value) + if err == nil { + astObject.Insert(ast.StringTerm(key), ast.NewTerm(astValue)) + } +} + func getParsedBody(logger logging.Logger, headers map[string]string, body string, rawBody []byte, parsedPath []interface{}, protoSet *protoregistry.Files) (interface{}, bool, error) { var data interface{} diff --git a/internal/internal.go b/internal/internal.go index 0fcbaf91a..3dbc38d07 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -433,38 +433,8 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (* return nil } - input, err = envoyauth.RequestToInput(req, logger, p.cfg.protoSet, p.cfg.SkipRequestBodyParse) - if err != nil { - internalErr = newInternalError(RequestParseErr, err) - return &ext_authz_v3.CheckResponse{ - Status: &rpc_status.Status{ - Code: int32(code.Code_PERMISSION_DENIED), - Message: internalErr.Error(), - }, - HttpResponse: &ext_authz_v3.CheckResponse_DeniedResponse{ - DeniedResponse: &ext_authz_v3.DeniedHttpResponse{ - Status: &ext_type_v3.HttpStatus{ - Code: ext_type_v3.StatusCode(ext_type_v3.StatusCode_BadRequest), - }, - Body: internalErr.Error(), - }, - }, - DynamicMetadata: nil, - }, stop, nil - } - - if ctx.Err() != nil { - err = errors.Wrap(ctx.Err(), "check request timed out before query execution") - internalErr = newInternalError(CheckRequestTimeoutErr, err) - return nil, stop, internalErr - } - var inputValue ast.Value - inputValue, err = ast.InterfaceToValue(input) - if err != nil { - internalErr = newInternalError(InputParseErr, err) - return nil, stop, internalErr - } + inputValue, err = envoyauth.RequestToAstValue(req, logger, p.cfg.protoSet, p.cfg.SkipRequestBodyParse) if err = envoyauth.Eval(ctx, p, inputValue, result); err != nil { evalErr = err