diff --git a/pkg/types/query/query.go b/pkg/types/query/query.go index 589c7aa7f..1246a0bd7 100644 --- a/pkg/types/query/query.go +++ b/pkg/types/query/query.go @@ -82,12 +82,18 @@ func TxHash(txHash string) Expression { } func And(expressions ...Expression) Expression { + if len(expressions) == 1 { + return expressions[0] + } return Expression{ BoolExpression: BoolExpression{Expressions: expressions, BoolOperator: AND}, } } func Or(expressions ...Expression) Expression { + if len(expressions) == 1 { + return expressions[0] + } return Expression{ BoolExpression: BoolExpression{Expressions: expressions, BoolOperator: OR}, } diff --git a/pkg/types/query/query_test.go b/pkg/types/query/query_test.go new file mode 100644 index 000000000..7773c36a5 --- /dev/null +++ b/pkg/types/query/query_test.go @@ -0,0 +1,66 @@ +package query + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" +) + +func Test_AndOrEdgeCases(t *testing.T) { + tests := []struct { + name string + expressions []Expression + constructor func(expressions ...Expression) Expression + expected Expression + }{ + { + name: "And with no expressions", + constructor: And, + expected: And(), + }, + { + name: "Or with no expressions", + constructor: Or, + expected: Or(), + }, + { + name: "And with one expression", + expressions: []Expression{TxHash("txHash")}, + constructor: And, + expected: TxHash("txHash"), + }, + { + name: "Or with one expression", + expressions: []Expression{TxHash("txHash")}, + constructor: Or, + expected: TxHash("txHash"), + }, + { + name: "And with multiple expressions", + expressions: []Expression{TxHash("txHash"), Block(123, primitives.Eq)}, + constructor: And, + expected: And( + TxHash("txHash"), + Block(123, primitives.Eq), + ), + }, + { + name: "Or with multiple expressions", + expressions: []Expression{TxHash("txHash"), Block(123, primitives.Eq)}, + constructor: Or, + expected: Or( + TxHash("txHash"), + Block(123, primitives.Eq), + ), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.constructor(tt.expressions...) + require.Equal(t, tt.expected, got) + }) + } +}