From 83ba3bafe1973076c49ce096caeb2525811d8b1b Mon Sep 17 00:00:00 2001 From: Aleksandr Tretiakov Date: Tue, 23 Jul 2024 16:22:52 +0300 Subject: [PATCH 1/4] Added functions nesting analyzer --- .../validation/functions_nesting_analyzer.go | 72 +++++++++++ .../functions_nesting_analyzer_test.go | 115 ++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 shield/internal/validation/functions_nesting_analyzer.go create mode 100644 shield/internal/validation/functions_nesting_analyzer_test.go diff --git a/shield/internal/validation/functions_nesting_analyzer.go b/shield/internal/validation/functions_nesting_analyzer.go new file mode 100644 index 000000000..5dde50e7b --- /dev/null +++ b/shield/internal/validation/functions_nesting_analyzer.go @@ -0,0 +1,72 @@ +package validation + +import ( + "context" + "github.com/warden-protocol/wardenprotocol/shield/ast" +) + +const MAX_DEPTH = 100 // TODO AT: Move to params or configuration? + +func AnalyzeFunctionsNesting(ctx context.Context, node *ast.Expression, depth int) (int, error) { + switch n := node.Value.(type) { + case *ast.Expression_Identifier: + return depth, nil + case *ast.Expression_ArrayLiteral: + newDepth, err := analyzeElements(ctx, n.ArrayLiteral.Elements, depth) + return newDepth, err + case *ast.Expression_CallExpression: + newDepth, err := analyzeCallExpression(ctx, n.CallExpression, depth) + return newDepth, err + case *ast.Expression_PrefixExpression: + newDepth, err := analyzePrefixExpression(ctx, n.PrefixExpression, depth) + return newDepth, err + case *ast.Expression_InfixExpression: + newDepth, err := analyzeInfixExpression(ctx, n.InfixExpression, depth) + return newDepth, err + default: + return depth, nil + } +} + +func analyzeElements(ctx context.Context, elements []*ast.Expression, depth int) (int, error) { + var currentMaxDepth = depth + var possibleMaxDepth = depth + for _, elem := range elements { + var err error + possibleMaxDepth, err = AnalyzeFunctionsNesting(ctx, elem, depth) + if err != nil { + return depth, err + } + + currentMaxDepth = max(currentMaxDepth, possibleMaxDepth) + } + return currentMaxDepth, nil +} + +func analyzePrefixExpression(ctx context.Context, prefix *ast.PrefixExpression, depth int) (int, error) { + var err error + var newMaxDepth int + newMaxDepth, err = AnalyzeFunctionsNesting(ctx, prefix.Right, depth) + return newMaxDepth, err +} + +func analyzeInfixExpression(ctx context.Context, infix *ast.InfixExpression, depth int) (int, error) { + var err error + var maxDepthLeft int + maxDepthLeft, err = AnalyzeFunctionsNesting(ctx, infix.Left, depth) + if err != nil { + return depth, err + } + + var maxDepthRight int + maxDepthRight, err = AnalyzeFunctionsNesting(ctx, infix.Right, depth) + if err != nil { + return depth, err + } + + return max(maxDepthLeft, maxDepthRight), nil +} + +func analyzeCallExpression(ctx context.Context, call *ast.CallExpression, depth int) (int, error) { + return analyzeElements(ctx, call.Arguments, depth+1) +} diff --git a/shield/internal/validation/functions_nesting_analyzer_test.go b/shield/internal/validation/functions_nesting_analyzer_test.go new file mode 100644 index 000000000..f8237bcaf --- /dev/null +++ b/shield/internal/validation/functions_nesting_analyzer_test.go @@ -0,0 +1,115 @@ +package validation + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/warden-protocol/wardenprotocol/shield/ast" + "github.com/warden-protocol/wardenprotocol/shield/internal/lexer" + "github.com/warden-protocol/wardenprotocol/shield/internal/parser" +) + +func parseExpression(t *testing.T, input string) *ast.Expression { + l := lexer.New(input) + p := parser.New(l) + expression := p.Parse() + + err := p.Errors() + if len(err) != 0 { + require.FailNow(t, "Parser finished with errors", err) + } + + require.NotNil(t, expression) + return expression +} + +func TestAnalyzer(t *testing.T) { + testCases := []struct { + expression string + expectedDepth int + }{ + {"foo1(foo2([foo2(), 11, 12]), false || true, [false, [10, 11, 12]])", 3}, + {"foo2() && foo1(foo2(), false || true, [false, [10, 11, 12]])", 2}, + {"foo2() ", 1}, + {"true", 0}, + } + + ctx := context.Background() + + for _, tc := range testCases { + expression := parseExpression(t, tc.expression) + + maxDepth, err := AnalyzeFunctionsNesting(ctx, expression, 0) + if err != nil { + t.Error(err) + } + + require.Equal(t, tc.expectedDepth, maxDepth) + } +} + +// +//func TestPreprocessElements(t *testing.T) { +// ctx := context.Background() +// expression := parseExpression(t, "[10, 11, 12]") +// expander := NoopExpander{} +// +// proc, err := Preprocess(ctx, expression, expander) +// if err != nil { +// t.Error(err) +// } +// +// arr := proc.GetArrayLiteral() +// require.Len(t, arr.Elements, 3) +// +// for i, arrElem := range arr.Elements { +// intVal := arrElem.GetIntegerLiteral() +// require.Equal(t, intVal.Value, big.NewInt(int64(10+i)).String()) +// } +//} +// +//func TestPreprocessInfixExpression(t *testing.T) { +// ctx := context.Background() +// expression := parseExpression(t, "false || true && false") +// expander := NoopExpander{} +// +// proc, err := Preprocess(ctx, expression, expander) +// if err != nil { +// t.Error(err) +// } +// +// inf1 := proc.GetInfixExpression() +// require.Equal(t, inf1.Operator, "||") +// +// inf1Left := inf1.Left.GetBooleanLiteral() +// require.Equal(t, inf1Left.Value, false) +// +// inf2 := inf1.Right.GetInfixExpression() +// require.Equal(t, inf2.Operator, "&&") +// +// inf2Left := inf2.Left.GetBooleanLiteral() +// require.Equal(t, inf2Left.Value, true) +// inf2Right := inf2.Right.GetBooleanLiteral() +// require.Equal(t, inf2Right.Value, false) +//} +// +//func TestPreprocessCallExpression(t *testing.T) { +// ctx := context.Background() +// expression := parseExpression(t, "foo1(123, foo2(235))") +// expander := NoopExpander{} +// +// proc, err := Preprocess(ctx, expression, expander) +// if err != nil { +// t.Error(err) +// } +// +// call1 := proc.GetCallExpression() +// require.Equal(t, call1.Function.Value, "foo1") +// require.Equal(t, len(call1.Arguments), 2) +// require.Equal(t, call1.Arguments[0].GetIntegerLiteral().Value, big.NewInt(int64(123)).String()) +// +// call2 := call1.Arguments[1].GetCallExpression() +// require.Equal(t, len(call2.Arguments), 1) +// require.Equal(t, call2.Arguments[0].GetIntegerLiteral().Value, big.NewInt(int64(235)).String()) +//} From 52ce5879dc5a1f66a1ac4b63df4490aa7d55a4b3 Mon Sep 17 00:00:00 2001 From: Aleksandr Tretiakov Date: Wed, 24 Jul 2024 10:08:52 +0300 Subject: [PATCH 2/4] Added validation --- .../validation/functions_nesting_analyzer.go | 63 ++++++------------- shield/internal/validation/validator.go | 16 +++++ shield/shield.go | 8 +++ 3 files changed, 43 insertions(+), 44 deletions(-) create mode 100644 shield/internal/validation/validator.go diff --git a/shield/internal/validation/functions_nesting_analyzer.go b/shield/internal/validation/functions_nesting_analyzer.go index 5dde50e7b..fb8f31a88 100644 --- a/shield/internal/validation/functions_nesting_analyzer.go +++ b/shield/internal/validation/functions_nesting_analyzer.go @@ -1,72 +1,47 @@ package validation import ( - "context" "github.com/warden-protocol/wardenprotocol/shield/ast" ) -const MAX_DEPTH = 100 // TODO AT: Move to params or configuration? - -func AnalyzeFunctionsNesting(ctx context.Context, node *ast.Expression, depth int) (int, error) { +func AnalyzeFunctionsNesting(node *ast.Expression, depth int) int { switch n := node.Value.(type) { case *ast.Expression_Identifier: - return depth, nil + return depth case *ast.Expression_ArrayLiteral: - newDepth, err := analyzeElements(ctx, n.ArrayLiteral.Elements, depth) - return newDepth, err + return analyzeElements(n.ArrayLiteral.Elements, depth) case *ast.Expression_CallExpression: - newDepth, err := analyzeCallExpression(ctx, n.CallExpression, depth) - return newDepth, err + return analyzeCallExpression(n.CallExpression, depth) case *ast.Expression_PrefixExpression: - newDepth, err := analyzePrefixExpression(ctx, n.PrefixExpression, depth) - return newDepth, err + return analyzePrefixExpression(n.PrefixExpression, depth) case *ast.Expression_InfixExpression: - newDepth, err := analyzeInfixExpression(ctx, n.InfixExpression, depth) - return newDepth, err + return analyzeInfixExpression(n.InfixExpression, depth) default: - return depth, nil + return depth } } -func analyzeElements(ctx context.Context, elements []*ast.Expression, depth int) (int, error) { +func analyzeElements(elements []*ast.Expression, depth int) int { var currentMaxDepth = depth - var possibleMaxDepth = depth for _, elem := range elements { - var err error - possibleMaxDepth, err = AnalyzeFunctionsNesting(ctx, elem, depth) - if err != nil { - return depth, err - } - + possibleMaxDepth := AnalyzeFunctionsNesting(elem, depth) currentMaxDepth = max(currentMaxDepth, possibleMaxDepth) } - return currentMaxDepth, nil + return currentMaxDepth } -func analyzePrefixExpression(ctx context.Context, prefix *ast.PrefixExpression, depth int) (int, error) { - var err error - var newMaxDepth int - newMaxDepth, err = AnalyzeFunctionsNesting(ctx, prefix.Right, depth) - return newMaxDepth, err +func analyzePrefixExpression(prefix *ast.PrefixExpression, depth int) int { + newMaxDepth := AnalyzeFunctionsNesting(prefix.Right, depth) + return newMaxDepth } -func analyzeInfixExpression(ctx context.Context, infix *ast.InfixExpression, depth int) (int, error) { - var err error - var maxDepthLeft int - maxDepthLeft, err = AnalyzeFunctionsNesting(ctx, infix.Left, depth) - if err != nil { - return depth, err - } - - var maxDepthRight int - maxDepthRight, err = AnalyzeFunctionsNesting(ctx, infix.Right, depth) - if err != nil { - return depth, err - } +func analyzeInfixExpression(infix *ast.InfixExpression, depth int) int { + maxDepthLeft := AnalyzeFunctionsNesting(infix.Left, depth) + maxDepthRight := AnalyzeFunctionsNesting(infix.Right, depth) - return max(maxDepthLeft, maxDepthRight), nil + return max(maxDepthLeft, maxDepthRight) } -func analyzeCallExpression(ctx context.Context, call *ast.CallExpression, depth int) (int, error) { - return analyzeElements(ctx, call.Arguments, depth+1) +func analyzeCallExpression(call *ast.CallExpression, depth int) int { + return analyzeElements(call.Arguments, depth+1) } diff --git a/shield/internal/validation/validator.go b/shield/internal/validation/validator.go new file mode 100644 index 000000000..7e44c49b3 --- /dev/null +++ b/shield/internal/validation/validator.go @@ -0,0 +1,16 @@ +package validation + +import ( + "fmt" + "github.com/warden-protocol/wardenprotocol/shield/ast" +) + +func Validate(root *ast.Expression, maxNestingDepth int) error { + maxDepth := AnalyzeFunctionsNesting(root, 0) + + if maxDepth > maxNestingDepth { + return fmt.Errorf("max allowed functions nesting depth is %d. Got %d", maxNestingDepth, maxDepth) + } + + return nil +} diff --git a/shield/shield.go b/shield/shield.go index 032b44da5..8bc5f0a2d 100644 --- a/shield/shield.go +++ b/shield/shield.go @@ -11,11 +11,15 @@ import ( "github.com/warden-protocol/wardenprotocol/shield/internal/metadata" "github.com/warden-protocol/wardenprotocol/shield/internal/parser" "github.com/warden-protocol/wardenprotocol/shield/internal/preprocess" + "github.com/warden-protocol/wardenprotocol/shield/internal/validation" "github.com/warden-protocol/wardenprotocol/shield/object" ) type Environment = env.Environment +// TODO AT: Move to Env or Config? +const MaxNestingDepth = 100 + // Parse parses the input string and returns the root node of the AST. // In case of syntax errors, it returns an error. func Parse(input string) (*ast.Expression, error) { @@ -26,6 +30,10 @@ func Parse(input string) (*ast.Expression, error) { return nil, fmt.Errorf("parser errors: %v", p.Errors()) } + if err := validation.Validate(root, MaxNestingDepth); err != nil { + return nil, fmt.Errorf("parser validation error: %v", err) + } + return root, nil } From d2a00b008c06e0ff981859b068a37ec3503c3140 Mon Sep 17 00:00:00 2001 From: Aleksandr Tretiakov Date: Wed, 24 Jul 2024 10:23:27 +0300 Subject: [PATCH 3/4] Fixed tests --- .../functions_nesting_analyzer_test.go | 115 ------------------ .../{validator.go => validation.go} | 0 shield/internal/validation/validation_test.go | 81 ++++++++++++ 3 files changed, 81 insertions(+), 115 deletions(-) delete mode 100644 shield/internal/validation/functions_nesting_analyzer_test.go rename shield/internal/validation/{validator.go => validation.go} (100%) create mode 100644 shield/internal/validation/validation_test.go diff --git a/shield/internal/validation/functions_nesting_analyzer_test.go b/shield/internal/validation/functions_nesting_analyzer_test.go deleted file mode 100644 index f8237bcaf..000000000 --- a/shield/internal/validation/functions_nesting_analyzer_test.go +++ /dev/null @@ -1,115 +0,0 @@ -package validation - -import ( - "context" - "testing" - - "github.com/stretchr/testify/require" - "github.com/warden-protocol/wardenprotocol/shield/ast" - "github.com/warden-protocol/wardenprotocol/shield/internal/lexer" - "github.com/warden-protocol/wardenprotocol/shield/internal/parser" -) - -func parseExpression(t *testing.T, input string) *ast.Expression { - l := lexer.New(input) - p := parser.New(l) - expression := p.Parse() - - err := p.Errors() - if len(err) != 0 { - require.FailNow(t, "Parser finished with errors", err) - } - - require.NotNil(t, expression) - return expression -} - -func TestAnalyzer(t *testing.T) { - testCases := []struct { - expression string - expectedDepth int - }{ - {"foo1(foo2([foo2(), 11, 12]), false || true, [false, [10, 11, 12]])", 3}, - {"foo2() && foo1(foo2(), false || true, [false, [10, 11, 12]])", 2}, - {"foo2() ", 1}, - {"true", 0}, - } - - ctx := context.Background() - - for _, tc := range testCases { - expression := parseExpression(t, tc.expression) - - maxDepth, err := AnalyzeFunctionsNesting(ctx, expression, 0) - if err != nil { - t.Error(err) - } - - require.Equal(t, tc.expectedDepth, maxDepth) - } -} - -// -//func TestPreprocessElements(t *testing.T) { -// ctx := context.Background() -// expression := parseExpression(t, "[10, 11, 12]") -// expander := NoopExpander{} -// -// proc, err := Preprocess(ctx, expression, expander) -// if err != nil { -// t.Error(err) -// } -// -// arr := proc.GetArrayLiteral() -// require.Len(t, arr.Elements, 3) -// -// for i, arrElem := range arr.Elements { -// intVal := arrElem.GetIntegerLiteral() -// require.Equal(t, intVal.Value, big.NewInt(int64(10+i)).String()) -// } -//} -// -//func TestPreprocessInfixExpression(t *testing.T) { -// ctx := context.Background() -// expression := parseExpression(t, "false || true && false") -// expander := NoopExpander{} -// -// proc, err := Preprocess(ctx, expression, expander) -// if err != nil { -// t.Error(err) -// } -// -// inf1 := proc.GetInfixExpression() -// require.Equal(t, inf1.Operator, "||") -// -// inf1Left := inf1.Left.GetBooleanLiteral() -// require.Equal(t, inf1Left.Value, false) -// -// inf2 := inf1.Right.GetInfixExpression() -// require.Equal(t, inf2.Operator, "&&") -// -// inf2Left := inf2.Left.GetBooleanLiteral() -// require.Equal(t, inf2Left.Value, true) -// inf2Right := inf2.Right.GetBooleanLiteral() -// require.Equal(t, inf2Right.Value, false) -//} -// -//func TestPreprocessCallExpression(t *testing.T) { -// ctx := context.Background() -// expression := parseExpression(t, "foo1(123, foo2(235))") -// expander := NoopExpander{} -// -// proc, err := Preprocess(ctx, expression, expander) -// if err != nil { -// t.Error(err) -// } -// -// call1 := proc.GetCallExpression() -// require.Equal(t, call1.Function.Value, "foo1") -// require.Equal(t, len(call1.Arguments), 2) -// require.Equal(t, call1.Arguments[0].GetIntegerLiteral().Value, big.NewInt(int64(123)).String()) -// -// call2 := call1.Arguments[1].GetCallExpression() -// require.Equal(t, len(call2.Arguments), 1) -// require.Equal(t, call2.Arguments[0].GetIntegerLiteral().Value, big.NewInt(int64(235)).String()) -//} diff --git a/shield/internal/validation/validator.go b/shield/internal/validation/validation.go similarity index 100% rename from shield/internal/validation/validator.go rename to shield/internal/validation/validation.go diff --git a/shield/internal/validation/validation_test.go b/shield/internal/validation/validation_test.go new file mode 100644 index 000000000..4c6c621c9 --- /dev/null +++ b/shield/internal/validation/validation_test.go @@ -0,0 +1,81 @@ +package validation + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/warden-protocol/wardenprotocol/shield/ast" + "github.com/warden-protocol/wardenprotocol/shield/internal/lexer" + "github.com/warden-protocol/wardenprotocol/shield/internal/parser" +) + +func parseExpression(t *testing.T, input string) *ast.Expression { + l := lexer.New(input) + p := parser.New(l) + expression := p.Parse() + + err := p.Errors() + if len(err) != 0 { + require.FailNow(t, "Parser finished with errors", err) + } + + require.NotNil(t, expression) + return expression +} + +func TestDepthAnalyzer(t *testing.T) { + testCases := []struct { + expression string + expectedDepth int + }{ + {"foo1(foo2([foo2(), 11, 12]), false || true, [false, [10, 11, 12]])", 3}, + {"foo2() && foo1(foo2(), false || true, [false, [10, 11, 12]])", 2}, + {"foo2() ", 1}, + {"true", 0}, + } + + for _, tc := range testCases { + expression := parseExpression(t, tc.expression) + maxDepth := AnalyzeFunctionsNesting(expression, 0) + + require.Equal(t, tc.expectedDepth, maxDepth) + } +} + +func TestValidatorShouldFail(t *testing.T) { + testCases := []struct { + expression string + expectedDepth int + }{ + {"foo1(foo2([foo2(), 11, 12]), false || true, [false, [10, 11, 12]])", 2}, + {"foo2() && foo1(foo2(), false || true, [false, [10, 11, 12]])", 1}, + {"foo2() ", 0}, + } + + for _, tc := range testCases { + expression := parseExpression(t, tc.expression) + err := Validate(expression, tc.expectedDepth) + + require.Error(t, err) + } +} + +func TestValidatorShouldSuccess(t *testing.T) { + testCases := []struct { + expression string + expectedDepth int + }{ + {"foo1(foo2([foo2(), 11, 12]))", 3}, + {"foo1(foo2([foo2(), 11, 12]))", 4}, + {"foo2() ", 1}, + {"foo2() ", 2}, + {"true", 0}, + } + + for _, tc := range testCases { + expression := parseExpression(t, tc.expression) + err := Validate(expression, tc.expectedDepth) + + require.NoError(t, err) + } +} From b72df779c733b7a41a734d4c5647076295305834 Mon Sep 17 00:00:00 2001 From: Aleksandr Tretiakov Date: Wed, 24 Jul 2024 11:03:05 +0300 Subject: [PATCH 4/4] Added CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9292c064..e1500b634 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Consensus Breaking Changes +* (shield) The depth of nesting for functions is limited to 100 levels + ### Features (non-breaking) ### Bug Fixes