From cc9b991c0798b86a794621066c2ef78b62a0f18c Mon Sep 17 00:00:00 2001 From: Karol Nowak Date: Fri, 15 Dec 2023 11:57:27 +0100 Subject: [PATCH] chore: cr fixes --- batchMiddleware.go | 36 ++++-------------------------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/batchMiddleware.go b/batchMiddleware.go index cac6af1..2af6230 100644 --- a/batchMiddleware.go +++ b/batchMiddleware.go @@ -15,18 +15,14 @@ const ( func BatchMiddleware( cfg *struct { SameOperationsThreshold int `inject:"config:graphql.batchMiddleware.sameOperationsThreshold,optional"` - AllOperationsThreshold int `inject:"config:graphql.batchMiddleware.sameOperationsThreshold,optional"` + AllOperationsThreshold int `inject:"config:graphql.batchMiddleware.allOperationsThreshold,optional"` }, ) func(ctx context.Context, next gql.OperationHandler) gql.ResponseHandler { return func(ctx context.Context, next gql.OperationHandler) gql.ResponseHandler { - var sameOperationsThreshold int + sameOperationsThreshold := sameOperationsDefaultThreshold + allOperationsThreshold := allOperationsDefaultThreshold - var allOperationsThreshold int - - if cfg == nil { - sameOperationsThreshold = sameOperationsDefaultThreshold - allOperationsThreshold = allOperationsDefaultThreshold - } else { + if cfg != nil { sameOperationsThreshold = cfg.SameOperationsThreshold allOperationsThreshold = cfg.AllOperationsThreshold } @@ -35,8 +31,6 @@ func BatchMiddleware( occurrences := countTopLevelGraphQLOperations(req.Operation.SelectionSet) - countGraphqlFunctionsCalled(occurrences, req.Operation.SelectionSet) - if isAboveThreshold(sameOperationsThreshold, allOperationsThreshold, occurrences) { return func(ctx context.Context) *gql.Response { return gql.ErrorResponse(ctx, "request not allowed") @@ -66,28 +60,6 @@ func countTopLevelGraphQLOperations(definition []ast.Selection) map[string]int { return mapOfOccurrences } -func countGraphqlFunctionsCalled(mapOfOccurrences map[string]int, definition []ast.Selection) { - for _, set := range definition { - field, ok := set.(*ast.Field) - if !ok { - continue - } - - // counting arguments is the only way to tell if this field is a function call - if len(field.Arguments) != 0 { - if _, exists := mapOfOccurrences[field.Name]; !exists { - mapOfOccurrences[field.Name] = 0 - } - - mapOfOccurrences[field.Name]++ - } - - if len(field.SelectionSet) > 0 { - countGraphqlFunctionsCalled(mapOfOccurrences, field.SelectionSet) - } - } -} - func isAboveThreshold(threshold, operationsThreshold int, operations map[string]int) bool { if len(operations) > operationsThreshold { return true