From 0ba11baddcb9adfefc7cd68694397a4bab72dbdc Mon Sep 17 00:00:00 2001 From: Brett Lawson Date: Fri, 3 Aug 2018 04:15:07 -0700 Subject: [PATCH] Fixes #5, #32: Rewrote transformer to calculate after nodes correctly. In a number of different cases, the compiler was incapable of figuring out the correct place for after nodes. This also made it impossible for there to be fields on the RHS of an expression as the compiler would get confused and incorrectly compile them. --- expression.go | 157 +----------- expression_utils.go | 105 ++++++++ matcher.go | 569 ++++++++++++++++++++++++++++---------------- matcher_test.go | 26 ++ matcherdef.go | 142 +++++++---- transform.go | 414 ++++++++++++++++++++++---------- 6 files changed, 874 insertions(+), 539 deletions(-) create mode 100644 expression_utils.go diff --git a/expression.go b/expression.go index 1b4fc3c..96bab41 100644 --- a/expression.go +++ b/expression.go @@ -9,44 +9,16 @@ import ( type VariableID int -func fieldExprCompare(lhs FieldExpr, rhs FieldExpr) bool { - if rhs.Root != lhs.Root { - return false +func (id VariableID) String() string { + if id == 0 { + return "$doc" } - if len(rhs.Path) != len(lhs.Path) { - return false - } - - for i, path := range rhs.Path { - if lhs.Path[i] != path { - return false - } - } - - return true -} - -func rootSetAddOne(set []FieldExpr, item FieldExpr) []FieldExpr { - for _, oitem := range set { - if fieldExprCompare(oitem, item) { - return set - } - } - return append(set, item) -} - -func rootSetAdd(set []FieldExpr, items ...FieldExpr) []FieldExpr { - out := set - for _, item := range items { - out = rootSetAddOne(out, item) - } - return out + return fmt.Sprintf("$%d", id) } type Expression interface { String() string - RootRefs() []FieldExpr } type TrueExpr struct { @@ -56,10 +28,6 @@ func (expr TrueExpr) String() string { return "True" } -func (expr TrueExpr) RootRefs() []FieldExpr { - return nil -} - type FalseExpr struct { } @@ -67,10 +35,6 @@ func (expr FalseExpr) String() 
string { return "False" } -func (expr FalseExpr) RootRefs() []FieldExpr { - return nil -} - type ValueExpr struct { Value interface{} } @@ -79,10 +43,6 @@ func (expr ValueExpr) String() string { return fmt.Sprintf("%v", expr.Value) } -func (expr ValueExpr) RootRefs() []FieldExpr { - return nil -} - type RegexExpr struct { Regex interface{} } @@ -91,10 +51,6 @@ func (expr RegexExpr) String() string { return fmt.Sprintf("/%v/", expr.Regex) } -func (expr RegexExpr) RootRefs() []FieldExpr { - return nil -} - type NotExpr struct { SubExpr Expression } @@ -103,10 +59,6 @@ func (expr NotExpr) String() string { return "NOT " + expr.SubExpr.String() } -func (expr NotExpr) RootRefs() []FieldExpr { - return expr.SubExpr.RootRefs() -} - type AndExpr []Expression func (expr AndExpr) String() string { @@ -124,14 +76,6 @@ func (expr AndExpr) String() string { } } -func (expr AndExpr) RootRefs() []FieldExpr { - var out []FieldExpr - for _, subexpr := range expr { - out = rootSetAdd(out, subexpr.RootRefs()...) - } - return out -} - type OrExpr []Expression func (expr OrExpr) String() string { @@ -149,14 +93,6 @@ func (expr OrExpr) String() string { } } -func (expr OrExpr) RootRefs() []FieldExpr { - var out []FieldExpr - for _, subexpr := range expr { - out = rootSetAdd(out, subexpr.RootRefs()...) - } - return out -} - type FieldExpr struct { Root VariableID Path []string @@ -175,13 +111,6 @@ func (expr FieldExpr) String() string { } } -func (expr FieldExpr) RootRefs() []FieldExpr { - if expr.Root != 0 { - return nil - } - return []FieldExpr{expr} -} - type AnyInExpr struct { VarId VariableID InExpr Expression @@ -193,13 +122,6 @@ func (expr AnyInExpr) String() string { return fmt.Sprintf("any $%d in %s\n%s\nend", expr.VarId, expr.InExpr, exprStr) } -func (expr AnyInExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.InExpr.RootRefs()...) - out = rootSetAdd(out, expr.SubExpr.RootRefs()...) 
- return out -} - type EveryInExpr struct { VarId VariableID InExpr Expression @@ -211,13 +133,6 @@ func (expr EveryInExpr) String() string { return fmt.Sprintf("every $%d in %s\n%s\nend", expr.VarId, expr.InExpr, exprStr) } -func (expr EveryInExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.InExpr.RootRefs()...) - out = rootSetAdd(out, expr.SubExpr.RootRefs()...) - return out -} - type AnyEveryInExpr struct { VarId VariableID InExpr Expression @@ -229,13 +144,6 @@ func (expr AnyEveryInExpr) String() string { return fmt.Sprintf("any and every $%d in %s\n%s\nend", expr.VarId, expr.InExpr, exprStr) } -func (expr AnyEveryInExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.InExpr.RootRefs()...) - out = rootSetAdd(out, expr.SubExpr.RootRefs()...) - return out -} - type ExistsExpr struct { SubExpr Expression } @@ -244,10 +152,6 @@ func (expr ExistsExpr) String() string { return fmt.Sprintf("%s EXISTS", expr.SubExpr) } -func (expr ExistsExpr) RootRefs() []FieldExpr { - return expr.SubExpr.RootRefs() -} - type NotExistsExpr struct { SubExpr Expression } @@ -256,10 +160,6 @@ func (expr NotExistsExpr) String() string { return fmt.Sprintf("%s IS MISSING", expr.SubExpr) } -func (expr NotExistsExpr) RootRefs() []FieldExpr { - return expr.SubExpr.RootRefs() -} - type EqualsExpr struct { Lhs Expression Rhs Expression @@ -269,13 +169,6 @@ func (expr EqualsExpr) String() string { return fmt.Sprintf("%s = %s", expr.Lhs, expr.Rhs) } -func (expr EqualsExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.Lhs.RootRefs()...) - out = rootSetAdd(out, expr.Rhs.RootRefs()...) - return out -} - type NotEqualsExpr struct { Lhs Expression Rhs Expression @@ -285,13 +178,6 @@ func (expr NotEqualsExpr) String() string { return fmt.Sprintf("%s != %s", expr.Lhs, expr.Rhs) } -func (expr NotEqualsExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.Lhs.RootRefs()...) 
- out = rootSetAdd(out, expr.Rhs.RootRefs()...) - return out -} - type LessThanExpr struct { Lhs Expression Rhs Expression @@ -301,13 +187,6 @@ func (expr LessThanExpr) String() string { return fmt.Sprintf("%s < %s", expr.Lhs, expr.Rhs) } -func (expr LessThanExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.Lhs.RootRefs()...) - out = rootSetAdd(out, expr.Rhs.RootRefs()...) - return out -} - type LessEqualsExpr struct { Lhs Expression Rhs Expression @@ -317,13 +196,6 @@ func (expr LessEqualsExpr) String() string { return fmt.Sprintf("%s <= %s", expr.Lhs, expr.Rhs) } -func (expr LessEqualsExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.Lhs.RootRefs()...) - out = rootSetAdd(out, expr.Rhs.RootRefs()...) - return out -} - type GreaterThanExpr struct { Lhs Expression Rhs Expression @@ -333,13 +205,6 @@ func (expr GreaterThanExpr) String() string { return fmt.Sprintf("%s > %s", expr.Lhs, expr.Rhs) } -func (expr GreaterThanExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.Lhs.RootRefs()...) - out = rootSetAdd(out, expr.Rhs.RootRefs()...) - return out -} - type GreaterEqualsExpr struct { Lhs Expression Rhs Expression @@ -349,13 +214,6 @@ func (expr GreaterEqualsExpr) String() string { return fmt.Sprintf("%s >= %s", expr.Lhs, expr.Rhs) } -func (expr GreaterEqualsExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.Lhs.RootRefs()...) - out = rootSetAdd(out, expr.Rhs.RootRefs()...) - return out -} - type LikeExpr struct { Lhs Expression Rhs Expression @@ -364,10 +222,3 @@ type LikeExpr struct { func (expr LikeExpr) String() string { return fmt.Sprintf("%s =~ %s", expr.Lhs, expr.Rhs) } - -func (expr LikeExpr) RootRefs() []FieldExpr { - var out []FieldExpr - out = rootSetAdd(out, expr.Lhs.RootRefs()...) - out = rootSetAdd(out, expr.Rhs.RootRefs()...) 
- return out -} diff --git a/expression_utils.go b/expression_utils.go new file mode 100644 index 0000000..94e2b84 --- /dev/null +++ b/expression_utils.go @@ -0,0 +1,105 @@ +package gojsonsm + +import "fmt" + +func fieldExprMatches(lhs FieldExpr, rhs FieldExpr) bool { + if lhs.Root != rhs.Root { + return false + } + if len(lhs.Path) != len(rhs.Path) { + return false + } + for pathIdx := range lhs.Path { + if lhs.Path[pathIdx] != rhs.Path[pathIdx] { + return false + } + } + return true +} + +func fetchExprFieldRefsRecurse(expr Expression, loopVars []VariableID, fields []FieldExpr) []FieldExpr { + switch expr := expr.(type) { + case FieldExpr: + isLoopVarRef := false + for _, loopVar := range loopVars { + if expr.Root == loopVar { + isLoopVarRef = true + break + } + } + if isLoopVarRef { + break + } + + fieldAlreadyExists := false + for _, oexpr := range fields { + if fieldExprMatches(expr, oexpr) { + fieldAlreadyExists = true + break + } + } + if fieldAlreadyExists { + break + } + + fields = append(fields, expr) + case ValueExpr: + case RegexExpr: + case NotExpr: + fields = fetchExprFieldRefsRecurse(expr.SubExpr, loopVars, fields) + case AndExpr: + for _, subexpr := range expr { + fields = fetchExprFieldRefsRecurse(subexpr, loopVars, fields) + } + case OrExpr: + for _, subexpr := range expr { + fields = fetchExprFieldRefsRecurse(subexpr, loopVars, fields) + } + case AnyInExpr: + fields = fetchExprFieldRefsRecurse(expr.InExpr, loopVars, fields) + loopVars = append(loopVars, expr.VarId) + fields = fetchExprFieldRefsRecurse(expr.SubExpr, loopVars, fields) + loopVars = loopVars[0 : len(loopVars)-1] + case EveryInExpr: + fields = fetchExprFieldRefsRecurse(expr.InExpr, loopVars, fields) + loopVars = append(loopVars, expr.VarId) + fields = fetchExprFieldRefsRecurse(expr.SubExpr, loopVars, fields) + loopVars = loopVars[0 : len(loopVars)-1] + case AnyEveryInExpr: + fields = fetchExprFieldRefsRecurse(expr.InExpr, loopVars, fields) + loopVars = append(loopVars, expr.VarId) + 
fields = fetchExprFieldRefsRecurse(expr.SubExpr, loopVars, fields) + loopVars = loopVars[0 : len(loopVars)-1] + case EqualsExpr: + fields = fetchExprFieldRefsRecurse(expr.Lhs, loopVars, fields) + fields = fetchExprFieldRefsRecurse(expr.Rhs, loopVars, fields) + case NotEqualsExpr: + fields = fetchExprFieldRefsRecurse(expr.Lhs, loopVars, fields) + fields = fetchExprFieldRefsRecurse(expr.Rhs, loopVars, fields) + case LessThanExpr: + fields = fetchExprFieldRefsRecurse(expr.Lhs, loopVars, fields) + fields = fetchExprFieldRefsRecurse(expr.Rhs, loopVars, fields) + case LessEqualsExpr: + fields = fetchExprFieldRefsRecurse(expr.Lhs, loopVars, fields) + fields = fetchExprFieldRefsRecurse(expr.Rhs, loopVars, fields) + case GreaterThanExpr: + fields = fetchExprFieldRefsRecurse(expr.Lhs, loopVars, fields) + fields = fetchExprFieldRefsRecurse(expr.Rhs, loopVars, fields) + case GreaterEqualsExpr: + fields = fetchExprFieldRefsRecurse(expr.Lhs, loopVars, fields) + fields = fetchExprFieldRefsRecurse(expr.Rhs, loopVars, fields) + case ExistsExpr: + fields = fetchExprFieldRefsRecurse(expr.SubExpr, loopVars, fields) + case LikeExpr: + fields = fetchExprFieldRefsRecurse(expr.Lhs, loopVars, fields) + fields = fetchExprFieldRefsRecurse(expr.Rhs, loopVars, fields) + default: + panic(fmt.Sprintf("unexpected expression type %T", expr)) + } + + return fields +} + +func fetchExprFieldRefs(expr Expression) []FieldExpr { + return fetchExprFieldRefsRecurse(expr, nil, nil) +} diff --git a/matcher.go b/matcher.go index 95d7a06..4b50504 100644 --- a/matcher.go +++ b/matcher.go @@ -86,141 +86,384 @@ func (m *Matcher) skipValue(token tokenType) error { panic("unexpected value") } -func (m *Matcher) resolveParam(in interface{}) FastVal { - if opVal, ok := in.(SlotRef); ok { - panic(fmt.Sprintf("Cannot read %d", opVal.Slot)) +func (m *Matcher) literalFromSlot(slot SlotID) FastVal { + value := NewMissingFastVal() + + savePos := m.tokens.Position() + + slotInfo := m.slots[slot-1] + 
m.tokens.Seek(slotInfo.start) + token, tokenData, _ := m.tokens.Step() + + if isLiteralToken(token) { + var parser fastLitParser + value = parser.Parse(token, tokenData) } - if opValV, ok := in.(FastVal); ok { - return opValV - } else { + + m.tokens.Seek(savePos) + + return value +} + +func (m *Matcher) resolveParam(in interface{}, activeLit *FastVal) FastVal { + switch opVal := in.(type) { + case FastVal: + return opVal + case activeLitRef: + if activeLit == nil { + panic("cannot resolve active literal without having an active context") + } + + return *activeLit + case SlotRef: + return m.literalFromSlot(opVal.Slot) + default: panic(fmt.Sprintf("unexpected op value: %#v", in)) } } -func (m *Matcher) matchExec(token tokenType, tokenData []byte, node *ExecNode) error { - startPos := m.tokens.pos - endPos := -1 +func (m *Matcher) matchOp(op *OpNode, litVal *FastVal) error { + bucketIdx := int(op.BucketIdx) - if isLiteralToken(token) { - var litParse fastLitParser + if m.buckets.IsResolved(bucketIdx) { + // If the bucket for this op is already resolved in the binary tree, + // we don't need to perform the op and can just skip it. + return nil + } - // TODO(brett19): Move the litVal generation to be lazy-evaluated by the - // op execution below so we avoid performing any translations when the op - // is already resolved by something else. + lhsVal := NewMissingFastVal() + if op.Lhs != nil { + lhsVal = m.resolveParam(op.Lhs, litVal) + } else if litVal != nil { + lhsVal = *litVal + } - // Parse the literal token from the tokenizer into a FastVal value - // to be used for op execution below. 
- litVal := litParse.Parse(token, tokenData) + rhsVal := NewMissingFastVal() + if op.Rhs != nil { + rhsVal = m.resolveParam(op.Rhs, litVal) + } else if litVal != nil { + rhsVal = *litVal + } - for _, op := range node.Ops { - if m.buckets.IsResolved(int(op.BucketIdx)) { - // If the bucket for this op is already resolved in the binary tree, - // we don't need to perform the op and can just skip it. - continue - } else { - var opVal FastVal - if op.Rhs != nil { - opVal = m.resolveParam(op.Rhs) - } + var opRes bool + switch op.Op { + case OpTypeEquals: + opRes = lhsVal.Equals(rhsVal) + case OpTypeNotEquals: + opRes = !lhsVal.Equals(rhsVal) + case OpTypeLessThan: + opRes = lhsVal.Compare(rhsVal) < 0 + case OpTypeLessEquals: + opRes = lhsVal.Compare(rhsVal) <= 0 + case OpTypeGreaterThan: + opRes = lhsVal.Compare(rhsVal) > 0 + case OpTypeGreaterEquals: + opRes = lhsVal.Compare(rhsVal) >= 0 + case OpTypeMatches: + opRes = lhsVal.Matches(rhsVal) + case OpTypeExists: + opRes = true + default: + panic("invalid op type") + } - var opRes bool - switch op.Op { - case OpTypeEquals: - opRes = litVal.Equals(opVal) - case OpTypeNotEquals: - opRes = !litVal.Equals(opVal) - case OpTypeLessThan: - opRes = litVal.Compare(opVal) < 0 - case OpTypeLessEquals: - opRes = litVal.Compare(opVal) <= 0 - case OpTypeGreaterThan: - opRes = litVal.Compare(opVal) > 0 - case OpTypeGreaterEquals: - opRes = litVal.Compare(opVal) >= 0 - case OpTypeMatches: - opRes = litVal.Matches(opVal) - case OpTypeExists: - opRes = true - default: - panic("invalid op type") - } + // Mark the result of this operation + m.buckets.MarkNode(bucketIdx, opRes) - // Mark the result of this operation - m.buckets.MarkNode(int(op.BucketIdx), opRes) + // Check if running this values ops has resolved the entirety + // of the expression, if so we can leave immediately. 
+ if m.buckets.IsResolved(0) { + return nil + } - // Check if running this values ops has resolved the entirety - // of the expression, if so we can leave immediately. - if m.buckets.IsResolved(0) { - return nil - } + return nil +} + +func (m *Matcher) matchElems(token tokenType, tokenData []byte, elems map[string]*ExecNode) error { + // Note that this assumes that the tokenizer has already been placed at the target + // that referenced the elements themselves... + + // Check that the token that we started with is an object that we can scan over, + // if it is not, we need to exit early as these elements do not apply. + if token != tknObjectStart { + return nil + } + + var keyLitParse fastLitParser + + for i := 0; ; i++ { + // If this is not the first entry in the object, there should be a + // list delimiter ('c') that shows up in the input first. + if i != 0 { + token, _, err := m.tokens.Step() + if err != nil { + return err + } + + if token == tknObjectEnd { + return nil + } + if token != tknListDelim { + panic("expected object field element delimiter") } } - return nil - } else if token == tknObjectStart { - var keyLitParse fastLitParser + token, tokenData, err := m.tokens.Step() + if err != nil { + return err + } + if token == tknObjectEnd { + return nil + } - for i := 0; ; i++ { - // If this is not the first entry in the object, there should be a - // list delimiter ('c') that shows up in the input first. 
- if i != 0 { - token, _, err := m.tokens.Step() - if err != nil { - return err - } + // TODO(brett19): These byte-string conversion pieces are a bit weird + var keyBytes []byte + if token == tknString { + keyBytes = keyLitParse.ParseString(tokenData) + } else if token == tknEscString { + keyBytes = keyLitParse.ParseEscString(tokenData) + } else { + panic("expected literal") + } - if token == tknObjectEnd { - return nil - } - if token != tknListDelim { - panic("expected object field element delimiter") - } + token, _, err = m.tokens.Step() + if err != nil { + return err + } + if token != tknObjectKeyDelim { + panic("expected object key delimiter") + } + + token, tokenData, err = m.tokens.Step() + if err != nil { + return err + } + + if keyElem, ok := elems[string(keyBytes)]; ok { + // Run the execution node that applies to this particular + // key of the object. + m.matchExec(token, tokenData, keyElem) + + // Check if running this key's execution has resolved the entirety + // of the expression, if so we can leave immediately. + if m.buckets.IsResolved(0) { + return nil } + } else { + // If we don't have any parse requirements for this key in + // the object, we can just skip its value and continue + m.skipValue(token) + } + } +} - token, tokenData, err := m.tokens.Step() +func (m *Matcher) matchLoop(token tokenType, tokenData []byte, loop *LoopNode) error { + // Note that this assumes that the tokenizer has already been placed at the target + // that referenced the loop node itself... + + // Check that the token that we started with is an array that we can loop over, + // if it is not, we need to exit early as this LoopNode does not apply. + if token != tknArrayStart { + return nil + } + + // We need to keep track of the overall loop result value while the bin tree + // is being iterated on, reset, etc... 
+ var loopState bool + if loop.Mode == LoopTypeAny { + loopState = false + } else if loop.Mode == LoopTypeEvery { + loopState = true + } else if loop.Mode == LoopTypeAnyEvery { + loopState = false + } else { + panic("invalid loop mode") + } + + loopBucketIdx := int(loop.BucketIdx) + + // We need to mark the stall index on our binary tree so that + // resolution of a loop iteration does not propagate up the tree + // and cause resolution of the entire expression. + previousStallIndex := m.buckets.SetStallIndex(loopBucketIdx) + + // Scan through all the values in the loop + for i := 0; ; i++ { + // If this is not the first entry in the array, there should be a + // list delimiter (',') that shows up in the input first. + if i != 0 { + token, _, err := m.tokens.Step() if err != nil { return err } - if token == tknObjectEnd { - return nil + + if token == tknArrayEnd { + break + } + if token != tknListDelim { + panic(fmt.Sprintf("expected array element delimiter got %s", tokenToText(token))) } + } + + token, tokenData, err := m.tokens.Step() + if err != nil { + return err + } + if token == tknArrayEnd { + break + } - var keyBytes []byte - if token == tknString { - keyBytes = keyLitParse.ParseString(tokenData) - } else if token == tknEscString { - keyBytes = keyLitParse.ParseEscString(tokenData) + // Reset the looping node in the binary tree so that previous iterations + // of the loop do not impact the results of this iteration + m.buckets.ResetNode(loopBucketIdx) + + // Run the execution node for this element of the array. 
+ err = m.matchExec(token, tokenData, loop.Node) + if err != nil { + return err + } + + iterationMatched := m.buckets.IsTrue(loopBucketIdx) + if loop.Mode == LoopTypeAny { + if iterationMatched { + // If any element of the array matches, we know that + // this loop is successful + loopState = true + + // Skip the remainder of the array and leave the loop + m.leaveValue() + break + } + } else if loop.Mode == LoopTypeEvery { + if !iterationMatched { + // If any element of the array does not match, we know that + // this loop will never match + loopState = false + + // Skip the remainder of the array and leave the loop + m.leaveValue() + break + } + } else if loop.Mode == LoopTypeAnyEvery { + if !iterationMatched { + // If any element of the array does not match, we know that + // this loop will never match the `every` semantic. + loopState = false + + // Skip the remainder of the array and leave the loop + m.leaveValue() + break } else { - panic("expected literal") + // If we encounter a truthy value, we have satisfied the 'any' + // semantics of this loop and should mark it as such. + loopState = true + + // We must continue looping to satisfy the 'every' portion. } + } + } + + // We have to reset the node before we can mark it or our double-marking + // protection on the binary tree will trigger, this helpfully also marks + // the children of the loop to undefined resolution, which makes more sense + // than it having the state of the last iteration of the loop. + m.buckets.ResetNode(loopBucketIdx) + + // Reset the stall index to whatever it used to be to exit the 'context' + // of this particular loop. This acts as a stack in case there are + // multiple nested loops being processed. 
+ m.buckets.SetStallIndex(previousStallIndex) + + // Apply the overall loop result to the binary tree + m.buckets.MarkNode(loopBucketIdx, loopState) - token, _, err = m.tokens.Step() + return nil +} + +func (m *Matcher) matchAfter(node *AfterNode) error { + savePos := m.tokens.Position() + + // Run loop matching + for _, loop := range node.Loops { + if slot, ok := loop.Target.(SlotRef); ok { + slotInfo := m.slots[slot.Slot-1] + + m.tokens.Seek(slotInfo.start) + token, tokenData, err := m.tokens.Step() + + // run the loop matcher + err = m.matchLoop(token, tokenData, &loop) if err != nil { return err } - if token != tknObjectKeyDelim { - panic("expected object key delimiter") + + if m.buckets.IsResolved(0) { + return nil } + } else { + panic("encountered after loop with non-slot target") + } + } + + // Run op matching + for _, op := range node.Ops { + err := m.matchOp(&op, nil) + if err != nil { + return err + } + + if m.buckets.IsResolved(0) { + return nil + } + } + + m.tokens.Seek(savePos) + + return nil +} + +func (m *Matcher) matchExec(token tokenType, tokenData []byte, node *ExecNode) error { + startPos := m.tokens.Position() + endPos := -1 + + // The start position needs to include the token we already parsed, so lets + // back up our position based on how long that is... + // TODO(brett19): We should probably find a more optimal way to handle this... + startPos -= len(tokenData) + + if isLiteralToken(token) { + var litParse fastLitParser + + // TODO(brett19): Move the litVal generation to be lazy-evaluated by the + // op execution below so we avoid performing any translations when the op + // is already resolved by something else. - token, tokenData, err = m.tokens.Step() + // Parse the literal token from the tokenizer into a FastVal value + // to be used for op execution below. 
+ litVal := litParse.Parse(token, tokenData) + + for _, op := range node.Ops { + err := m.matchOp(&op, &litVal) if err != nil { return err } - if keyElem, ok := node.Elems[string(keyBytes)]; ok { - // Run the execution node that applies to this particular - // key of the object. - m.matchExec(token, tokenData, keyElem) + if m.buckets.IsResolved(0) { + return nil + } + } + } else if token == tknObjectStart { + if len(node.Elems) == 0 { + // If we have no element handlers, we can just skip the whole thing... + m.skipValue(token) + } else { + err := m.matchElems(token, tokenData, node.Elems) + if err != nil { + return nil + } - // Check if running this keys execution has resolved the entirety - // of the expression, if so we can leave immediately. - if m.buckets.IsResolved(0) { - return nil - } - } else { - // If we don't have any parse requirements for this key in - // the object, we can just skip its value and continue - m.skipValue(token) + if m.buckets.IsResolved(0) { + return nil } } } else if token == tknArrayStart { @@ -234,6 +477,10 @@ func (m *Matcher) matchExec(token tokenType, tokenData []byte, node *ExecNode) e savePos := m.tokens.Position() for loopIdx, loop := range node.Loops { + if loop.Target != nil { + panic("loops must always target the active state") + + } if loopIdx != 0 { // If this is not the first loop, we will need to reset back to the // begining of the array the loops are scanning. In the future, perhaps @@ -241,113 +488,12 @@ func (m *Matcher) matchExec(token tokenType, tokenData []byte, node *ExecNode) e m.tokens.Seek(savePos) } - // We need to keep track of the overall loop result value while the bin tree - // is being iterated on, reset, etc... 
- var loopState bool - if loop.Mode == LoopTypeAny { - loopState = false - } else if loop.Mode == LoopTypeEvery { - loopState = true - } else if loop.Mode == LoopTypeAnyEvery { - loopState = false - } else { - panic("invalid loop mode") - } - - loopBucketIdx := int(loop.BucketIdx) - - // We need to mark the stall index on our binary tree so that - // resolution of a loop iteration does not propagate up the tree - // and cause resolution of the entire expression. - previousStallIndex := m.buckets.SetStallIndex(loopBucketIdx) - - // Scan through all the values in the loop - for i := 0; ; i++ { - // If this is not the first entry in the array, there should be a - // list delimiter ('c') that shows up in the input first. - if i != 0 { - token, _, err := m.tokens.Step() - if err != nil { - return err - } - - if token == tknArrayEnd { - break - } - if token != tknListDelim { - panic("expected array element delimiter") - } - } - - token, tokenData, err := m.tokens.Step() - if err != nil { - return err - } - if token == tknArrayEnd { - break - } - - // Reset the looping node in the binary tree so that previous iterations - // of the loop do not impact the results of this iteration - m.buckets.ResetNode(loopBucketIdx) - - // Run the execution node for this element of the array. 
- m.matchExec(token, tokenData, loop.Node) - - iterationMatched := m.buckets.IsTrue(loopBucketIdx) - if loop.Mode == LoopTypeAny { - if iterationMatched { - // If any element of the array matches, we know that - // this loop is successful - loopState = true - - // Skip the remainder of the array and leave the loop - m.leaveValue() - break - } - } else if loop.Mode == LoopTypeEvery { - if !iterationMatched { - // If any element of the array does not match, we know that - // this loop will never match - loopState = false - - // Skip the remainder of the array and leave the loop - m.leaveValue() - break - } - } else if loop.Mode == LoopTypeAnyEvery { - if !iterationMatched { - // If any element of the array does not match, we know that - // this loop will never match the `every` semantic. - loopState = false - - // Skip the remainder of the array and leave the loop - m.leaveValue() - break - } else { - // If we encounter a truthy value, we have satisfied the 'any' - // semantics of this loop and should mark it as such. - loopState = true - - // We must continue looping to satisfy the 'every' portion. - } - } + // Run the loop matching logic + err := m.matchLoop(token, tokenData, &loop) + if err != nil { + return err } - // We have to reset the node before we can mark it or our double-marking - // protection on the binary tree will trigger, this helpfully also marks - // the children of the loop to undefined resolution, which makes more sense - // then it having the state of the last iteration of the loop. - m.buckets.ResetNode(loopBucketIdx) - - // Reset the stall index to whatever it used to be to exit the 'context' - // of this particular loop. This acts as a stack in case there are - // multiple nested loops being processed. 
- m.buckets.SetStallIndex(previousStallIndex) - - // Apply the overall loop result to the binary tree - m.buckets.MarkNode(loopBucketIdx, loopState) - // Check if the entire expression has been resolved, if so we can simply // exit the entire set of looping if m.buckets.IsResolved(0) { @@ -359,11 +505,20 @@ func (m *Matcher) matchExec(token tokenType, tokenData []byte, node *ExecNode) e panic("invalid token read") } - endPos = m.tokens.pos + if node.After != nil { + m.matchAfter(node.After) + + if m.buckets.IsResolved(0) { + return nil + } + } + + endPos = m.tokens.Position() + if node.StoreId > 0 { - varData := &m.slots[node.StoreId-1] - varData.start = startPos - varData.size = endPos - startPos + slotData := &m.slots[node.StoreId-1] + slotData.start = startPos + slotData.size = endPos - startPos } return nil diff --git a/matcher_test.go b/matcher_test.go index b5113ae..1af930d 100644 --- a/matcher_test.go +++ b/matcher_test.go @@ -337,3 +337,29 @@ func TestMatcherAnyEveryInEquals(t *testing.T) { "5b47eb0936ff92a567a0307e", }) } + +func TestMatcherCrossScopeLoop(t *testing.T) { + // TODO(brett19): There is a bug workaround for an issue that + // if loops use bucket 0, they are incorrectly exited early. 
+ + runJSONExprMatchTest(t, ` + ["or", + ["equals", + ["field", "name"], + ["value", "This is a bug workaround!"] + ], + ["anyin", + 1, + ["field", "friends"], + ["equals", + ["field", 1, "id"], + ["field", "index"] + ] + ] + ] + `, []string{ + "5b47eb0936ff92a567a0307e", + "5b47eb096b1d911c0b9492fb", + "5b47eb0950e9076fc0aecd52", + }) +} diff --git a/matcherdef.go b/matcherdef.go index e22ddd6..46879f0 100644 --- a/matcherdef.go +++ b/matcherdef.go @@ -7,8 +7,35 @@ import ( ) type SlotID int + +func (id SlotID) String() string { + return fmt.Sprintf("#%d", id) +} + type BucketID int +func (id BucketID) String() string { + return fmt.Sprintf("%%%d", id) +} + +type DataRef interface { + String() string +} + +func dataRefToString(ref DataRef) string { + if ref == nil { + return activeLitRef{}.String() + } + return ref.String() +} + +type activeLitRef struct { +} + +func (activeLitRef) String() string { + return "@" +} + type SlotRef struct { Slot SlotID } @@ -31,7 +58,7 @@ const ( OpTypeMatches ) -func opTypeToString(value OpType) string { +func (value OpType) String() string { switch value { case OpTypeEquals: return "eq" @@ -57,18 +84,16 @@ func opTypeToString(value OpType) string { type OpNode struct { BucketIdx BucketID Op OpType - Rhs interface{} + Lhs DataRef + Rhs DataRef } func (op OpNode) String() string { - var out string - out += fmt.Sprintf("[%d] %s", op.BucketIdx, opTypeToString(op.Op)) - - if op.Rhs != nil { - out += " " + fmt.Sprintf("%v", op.Rhs) - } - - return out + return fmt.Sprintf("[%d] %s %s %s", + op.BucketIdx, + dataRefToString(op.Lhs), + op.Op, + dataRefToString(op.Rhs)) } type LoopType int @@ -79,7 +104,7 @@ const ( LoopTypeAnyEvery ) -func loopTypeToString(value LoopType) string { +func (value LoopType) String() string { switch value { case LoopTypeAny: return "any" @@ -95,15 +120,54 @@ func loopTypeToString(value LoopType) string { type LoopNode struct { BucketIdx BucketID Mode LoopType + Target DataRef Node *ExecNode } +func (node 
*LoopNode) String() string { + out := "" + out += fmt.Sprintf("[%d] :%s in %s:\n", node.BucketIdx, node.Mode, dataRefToString(node.Target)) + out += reindentString(node.Node.String(), " ") + return out +} + +type AfterNode struct { + Ops []OpNode + Loops []LoopNode +} + type ExecNode struct { StoreId SlotID - Ops []*OpNode Elems map[string]*ExecNode + Ops []OpNode Loops []LoopNode - After map[SlotID]*ExecNode + After *AfterNode +} + +type MatchDef struct { + ParseNode *ExecNode + MatchTree binTree + MatchBuckets []int + NumBuckets int + NumSlots int +} + +func (def MatchDef) String() string { + var out string + out += "match tree:\n" + out += " $doc:\n" + out += reindentString(def.ParseNode.String(), " ") + out += "\n" + out += "bin tree:\n" + out += reindentString(def.MatchTree.String(), " ") + out += "\n" + out += "match buckets:\n" + for i, bucketID := range def.MatchBuckets { + out += fmt.Sprintf(" %d: %d\n", i, bucketID) + } + out += fmt.Sprintf("num buckets: %d\n", def.NumBuckets) + out += fmt.Sprintf("num slots: %d\n", def.NumSlots) + return strings.TrimRight(out, "\n") } func (node ExecNode) String() string { @@ -137,51 +201,31 @@ func (node ExecNode) String() string { } } - if node.Loops != nil { + if len(node.Loops) > 0 { out += fmt.Sprintf(":loops\n") for _, loop := range node.Loops { - out += fmt.Sprintf("[%d] :%s:\n", loop.BucketIdx, loopTypeToString(loop.Mode)) - - out += reindentString(loop.Node.String(), " ") + out += reindentString(loop.String(), " ") out += "\n" } } if node.After != nil { - out += fmt.Sprintf(":after:\n") - for varId, anode := range node.After { - out += fmt.Sprintf(" #with $%d:\n", varId) - out += reindentString(anode.String(), " ") - out += "\n" + if len(node.After.Ops) > 0 { + out += fmt.Sprintf(":after-ops:\n") + for _, anode := range node.After.Ops { + out += reindentString(anode.String(), " ") + out += "\n" + } } - } - return strings.TrimRight(out, "\n") -} - -type MatchDef struct { - ParseNode *ExecNode - MatchTree binTree 
- MatchBuckets []int - NumBuckets int - NumSlots int - MaxDepth int -} - -func (def MatchDef) String() string { - var out string - out += "match tree:\n" - out += reindentString(def.ParseNode.String(), " ") - out += "\n" - out += "bin tree:\n" - out += reindentString(def.MatchTree.String(), " ") - out += "\n" - out += "match buckets:\n" - for i, bucketID := range def.MatchBuckets { - out += fmt.Sprintf(" %d: %d\n", i, bucketID) + if len(node.After.Loops) > 0 { + out += fmt.Sprintf(":after-loops:\n") + for _, loop := range node.After.Loops { + out += reindentString(loop.String(), " ") + out += "\n" + } + } } - out += fmt.Sprintf("num buckets: %d\n", def.NumBuckets) - out += fmt.Sprintf("num fetches: %d\n", def.NumSlots) - out += fmt.Sprintf("max depth: %d\n", def.MaxDepth) + return strings.TrimRight(out, "\n") } diff --git a/transform.go b/transform.go index b32bae0..8b709c9 100644 --- a/transform.go +++ b/transform.go @@ -3,9 +3,30 @@ package gojsonsm import ( + "errors" + "fmt" "regexp" + "strings" ) +type resolvedFieldRef struct { + Context *compileContext + Path []string +} + +func (ref resolvedFieldRef) String() string { + outStr := "$ROOT" + if ref.Context != nil { + outStr = fmt.Sprintf("%s", ref.Context) + } + + if len(ref.Path) > 0 { + return outStr + "." + strings.Join(ref.Path, ".") + } + + return outStr +} + type mergeExpr struct { exprs []Expression bucketIDs []BucketID @@ -26,35 +47,30 @@ func (expr mergeExpr) String() string { } } -func (expr mergeExpr) RootRefs() []FieldExpr { - var out []FieldExpr - for _, subexpr := range expr.exprs { - out = rootSetAdd(out, subexpr.RootRefs()...) 
- } - return out +type compileContext struct { + Depth int + Var VariableID + Node *ExecNode +} + +func (ctx *compileContext) String() string { + return fmt.Sprintf("$%d@%d", ctx.Var, ctx.Depth) } type Transformer struct { - SlotIdx SlotID - BucketIdx BucketID - RootExec *ExecNode - RootTree binTree - NodeMap map[VariableID]*ExecNode - ActiveExec *ExecNode + SlotIdx SlotID + BucketIdx BucketID + RootExec *ExecNode + RootTree binTree + + ContextStack []*compileContext ActiveBucketIdx BucketID - MaxDepth int - CurDepth int } -func (t *Transformer) getExecNode(field FieldExpr) *ExecNode { +func (t *Transformer) getExecNode(field resolvedFieldRef) *ExecNode { node := t.RootExec - - if field.Root != 0 { - node = t.NodeMap[field.Root] - if node == nil { - // TODO - panic("invalid field reference") - } + if field.Context != nil { + node = field.Context.Node } for _, entry := range field.Path { @@ -72,26 +88,19 @@ func (t *Transformer) getExecNode(field FieldExpr) *ExecNode { return node } -func (t *Transformer) storeNode(node *ExecNode) SlotID { +func (t *Transformer) storeExecNode(node *ExecNode) SlotID { if node.StoreId == 0 { node.StoreId = t.newSlot() } return node.StoreId } -func (t *Transformer) makeAfterNode(node *ExecNode, slot SlotID) *ExecNode { +func (t *Transformer) getAfterNode(node *ExecNode) *AfterNode { if node.After == nil { - node.After = make(map[SlotID]*ExecNode) - } else { - foundNode := node.After[slot] - if foundNode != nil { - return foundNode - } + node.After = &AfterNode{} } - newNode := &ExecNode{} - node.After[slot] = newNode - return newNode + return node.After } func (t *Transformer) newBucket() BucketID { @@ -114,6 +123,190 @@ func (t *Transformer) newSlot() SlotID { return newSlotID + 1 } +func (t *Transformer) pushContext(varID VariableID, execNode *ExecNode) { + t.ContextStack = append(t.ContextStack, &compileContext{ + Depth: len(t.ContextStack) + 1, + Var: varID, + Node: execNode, + }) +} + +func (t *Transformer) popContext(execNode 
*ExecNode) {
+	topContext := t.ContextStack[len(t.ContextStack)-1]
+	if topContext.Node != execNode {
+		panic("unexpected context in the stack")
+	}
+
+	t.ContextStack = t.ContextStack[0 : len(t.ContextStack)-1]
+}
+
+func (t *Transformer) gatherResolvedFieldRefs(expr Expression) []resolvedFieldRef {
+	fieldRefs := fetchExprFieldRefs(expr)
+
+	var resolvedFieldRefs []resolvedFieldRef
+	for _, fieldRef := range fieldRefs {
+		resolvedFieldRefs = append(resolvedFieldRefs, t.resolveRef(fieldRef))
+	}
+	return resolvedFieldRefs
+}
+
+func (t *Transformer) getContext(varID VariableID) *compileContext {
+	if varID == 0 {
+		return nil
+	}
+
+	for i := len(t.ContextStack) - 1; i >= 0; i-- { // innermost context first; counting down keeps i in range
+		if t.ContextStack[i].Var == varID {
+			return t.ContextStack[i]
+		}
+	}
+
+	panic("reference to out-of-context variable was encountered")
+}
+
+func (t *Transformer) resolveRef(fieldExpr FieldExpr) resolvedFieldRef {
+	return resolvedFieldRef{
+		Context: t.getContext(fieldExpr.Root),
+		Path:    fieldExpr.Path,
+	}
+}
+
+func (t *Transformer) findFieldRefsBestRoot(fieldRefs []resolvedFieldRef) (resolvedFieldRef, bool) {
+	var currentContext *compileContext
+	if len(t.ContextStack) > 0 {
+		currentContext = t.ContextStack[len(t.ContextStack)-1]
+	}
+
+	var contextFields []resolvedFieldRef
+	for _, fieldRef := range fieldRefs {
+		if fieldRef.Context == currentContext {
+			contextFields = append(contextFields, fieldRef)
+		}
+	}
+
+	if len(contextFields) == 0 {
+		return resolvedFieldRef{
+			Context: currentContext,
+			Path:    []string{},
+		}, false
+	}
+
+	// Pick the base path as being the longest of all the paths
+	basePath := contextFields[0].Path
+	for i := 1; i < len(contextFields); i++ {
+		if len(contextFields[i].Path) > len(basePath) {
+			basePath = contextFields[i].Path
+		}
+	}
+
+	var commonPath []string
+
+PathLoop:
+	for j := 0; j < len(basePath); j++ {
+		for i := 0; i < len(contextFields); i++ {
+			deepField := contextFields[i] // a path with no element j ends the common prefix
+			if len(deepField.Path) <= j || deepField.Path[j] != 
basePath[j] { + break PathLoop + } + } + commonPath = append(commonPath, basePath[j]) + } + + needsAfter := len(commonPath) < len(basePath) + + return resolvedFieldRef{ + Context: currentContext, + Path: commonPath, + }, needsAfter +} + +type nodeRef struct { + node *ExecNode + after *AfterNode +} + +func (ref *nodeRef) AddOp(op OpNode) { + if ref.node != nil { + ref.node.Ops = append(ref.node.Ops, op) + } else if ref.after != nil { + ref.after.Ops = append(ref.after.Ops, op) + } else { + panic("cannot add an op to a null node reference") + } +} + +func (ref *nodeRef) AddLoop(loop LoopNode) { + // TODO(brett19): This function currently validates that there + // is only 1 valid possible loop target used depending on which + // loop type its going into. Someday we may implement function + // support which will invalidate this error check. We do this + // here to ensure that the error is caught at compilation rather + // than at match time. + + if ref.node != nil { + if loop.Target != nil { + panic("loops must always target the active state") + } + + ref.node.Loops = append(ref.node.Loops, loop) + } else if ref.after != nil { + if _, ok := loop.Target.(SlotRef); !ok { + panic("after-loops must always target a slot") + } + + ref.after.Loops = append(ref.after.Loops, loop) + } else { + panic("cannot add a loop to a null node reference") + } +} + +func (t *Transformer) pickBaseNode(expr Expression) nodeRef { + fieldRefs := t.gatherResolvedFieldRefs(expr) + bestBase, needsAfter := t.findFieldRefsBestRoot(fieldRefs) + baseNode := t.getExecNode(bestBase) + + if !needsAfter { + return nodeRef{ + node: baseNode, + after: nil, + } + } + + afterNode := t.getAfterNode(baseNode) + return nodeRef{ + node: nil, + after: afterNode, + } +} + +func (t *Transformer) makeDataRef(expr Expression, context nodeRef) (DataRef, error) { + switch expr := expr.(type) { + case FieldExpr: + resField := t.resolveRef(expr) + fieldNode := t.getExecNode(resField) + if context.node == fieldNode { + 
return nil, nil + } + + slot := t.storeExecNode(fieldNode) + return SlotRef{slot}, nil + case ValueExpr: + val := NewFastVal(expr.Value) + if val.IsStringLike() { + val, _ = val.AsJsonString() + } + return val, nil + case RegexExpr: + regex, err := regexp.Compile(expr.Regex.(string)) + if err != nil { + return nil, errors.New("failed to compile RegexExpr: " + err.Error()) + } + return NewFastVal(regex), nil + } + + return nil, errors.New("unsupported expression in parameter") +} + func (t *Transformer) transformMergePiece(expr mergeExpr, i int) *ExecNode { if i == len(expr.exprs)-1 { expr.bucketIDs[i] = t.ActiveBucketIdx @@ -191,96 +384,62 @@ func (t *Transformer) transformAnd(expr AndExpr) *ExecNode { return nil } -func (t *Transformer) transformLoop(loopType LoopType, varID VariableID, inExpr, subExpr Expression) *ExecNode { - if rhsField, ok := inExpr.(FieldExpr); ok { - newNode := &ExecNode{} - execNode := t.getExecNode(rhsField) - - // If the sub-expression of this loop access data that - // is not whole contained within the loop variables, we - // need to pull the whole loop out to the after block - // to guarentee that all data dependencies have been - // resolved and are available. 
- subRootRefs := subExpr.RootRefs() - if len(subRootRefs) > 0 { - storeId := t.storeNode(execNode) - execNode = t.makeAfterNode(t.ActiveExec, storeId) - } +func (t *Transformer) transformLoop(expr Expression, loopType LoopType, varID VariableID, inExpr, subExpr Expression) *ExecNode { + baseNode := t.pickBaseNode(expr) - execNode.Loops = append(execNode.Loops, LoopNode{ - t.ActiveBucketIdx, - loopType, - newNode, - }) + newNode := &ExecNode{} - oldActiveExec := t.ActiveExec - t.ActiveExec = newNode - t.CurDepth++ + loopTarget, err := t.makeDataRef(inExpr, baseNode) + if err != nil { + panic(err) + } - if t.CurDepth > t.MaxDepth { - t.MaxDepth = t.CurDepth - } + baseNode.AddLoop(LoopNode{ + t.ActiveBucketIdx, + loopType, + loopTarget, + newNode, + }) - t.NodeMap[varID] = newNode + // Push this context to the stack + t.pushContext(varID, newNode) - t.transformOne(subExpr) + // Transform the loops expression body + t.transformOne(subExpr) - t.CurDepth-- - t.ActiveExec = oldActiveExec - } else { - panic("RHS of AnyIn must be a FieldExpr") - } + // Pop from the context stack + t.popContext(newNode) return nil } func (t *Transformer) transformAnyIn(expr AnyInExpr) *ExecNode { - return t.transformLoop(LoopTypeAny, expr.VarId, expr.InExpr, expr.SubExpr) + return t.transformLoop(expr, LoopTypeAny, expr.VarId, expr.InExpr, expr.SubExpr) } func (t *Transformer) transformEveryIn(expr EveryInExpr) *ExecNode { - return t.transformLoop(LoopTypeEvery, expr.VarId, expr.InExpr, expr.SubExpr) + return t.transformLoop(expr, LoopTypeEvery, expr.VarId, expr.InExpr, expr.SubExpr) } func (t *Transformer) transformAnyEveryIn(expr AnyEveryInExpr) *ExecNode { - return t.transformLoop(LoopTypeAnyEvery, expr.VarId, expr.InExpr, expr.SubExpr) -} - -func (t *Transformer) makeRhsParam(expr Expression) interface{} { - if rhsField, ok := expr.(FieldExpr); ok { - rhsNode := t.getExecNode(rhsField) - rhsStoreId := t.storeNode(rhsNode) - return SlotRef{rhsStoreId} - } else if rhsValue, ok := 
expr.(ValueExpr); ok { - val := NewFastVal(rhsValue.Value) - if val.IsStringLike() { - val, _ = val.AsJsonString() - } - return val - } else if rhsValue, ok := expr.(RegexExpr); ok { - regex, err := regexp.Compile(rhsValue.Regex.(string)) - if err != nil { - return "??ERROR??" - } - return NewFastVal(regex) - } else { - return "??ERROR??" - } + return t.transformLoop(expr, LoopTypeAnyEvery, expr.VarId, expr.InExpr, expr.SubExpr) } func (t *Transformer) transformExists(expr ExistsExpr) *ExecNode { - if lhsField, ok := expr.SubExpr.(FieldExpr); ok { - execNode := t.getExecNode(lhsField) - - execNode.Ops = append(execNode.Ops, &OpNode{ - t.ActiveBucketIdx, - OpTypeExists, - nil, - }) - } else { - panic("LHS of a comparison expression must be a FieldExpr") + baseNode := t.pickBaseNode(expr) + + lhsDataRef, err := t.makeDataRef(expr.SubExpr, baseNode) + if err != nil { + panic(err) } + baseNode.AddOp(OpNode{ + t.ActiveBucketIdx, + OpTypeExists, + lhsDataRef, + nil, + }) + return nil } @@ -292,54 +451,55 @@ func (t *Transformer) transformNotExists(expr NotExistsExpr) *ExecNode { }) } -func (t *Transformer) transformComparison(op OpType, lhs, rhs Expression) *ExecNode { - if lhsField, ok := lhs.(FieldExpr); ok { - execNode := t.getExecNode(lhsField) +func (t *Transformer) transformComparison(expr Expression, op OpType, lhs, rhs Expression) *ExecNode { + baseNode := t.pickBaseNode(expr) - lhsRootRefs := rhs.RootRefs() - if len(lhsRootRefs) > 0 { - storeId := t.storeNode(execNode) - execNode = t.makeAfterNode(t.ActiveExec, storeId) - } + lhsRef, err := t.makeDataRef(lhs, baseNode) + if err != nil { + panic(err) + } - execNode.Ops = append(execNode.Ops, &OpNode{ - t.ActiveBucketIdx, - op, - t.makeRhsParam(rhs), - }) - } else { - panic("LHS of a comparison expression must be a FieldExpr") + rhsRef, err := t.makeDataRef(rhs, baseNode) + if err != nil { + panic(err) } + baseNode.AddOp(OpNode{ + t.ActiveBucketIdx, + op, + lhsRef, + rhsRef, + }) + return nil } func (t 
*Transformer) transformEquals(expr EqualsExpr) *ExecNode { - return t.transformComparison(OpTypeEquals, expr.Lhs, expr.Rhs) + return t.transformComparison(expr, OpTypeEquals, expr.Lhs, expr.Rhs) } func (t *Transformer) transformNotEquals(expr NotEqualsExpr) *ExecNode { - return t.transformComparison(OpTypeNotEquals, expr.Lhs, expr.Rhs) + return t.transformComparison(expr, OpTypeNotEquals, expr.Lhs, expr.Rhs) } func (t *Transformer) transformLessThan(expr LessThanExpr) *ExecNode { - return t.transformComparison(OpTypeLessThan, expr.Lhs, expr.Rhs) + return t.transformComparison(expr, OpTypeLessThan, expr.Lhs, expr.Rhs) } func (t *Transformer) transformLessEquals(expr LessEqualsExpr) *ExecNode { - return t.transformComparison(OpTypeLessEquals, expr.Lhs, expr.Rhs) + return t.transformComparison(expr, OpTypeLessEquals, expr.Lhs, expr.Rhs) } func (t *Transformer) transformGreaterThan(expr GreaterThanExpr) *ExecNode { - return t.transformComparison(OpTypeGreaterThan, expr.Lhs, expr.Rhs) + return t.transformComparison(expr, OpTypeGreaterThan, expr.Lhs, expr.Rhs) } func (t *Transformer) transformGreaterEquals(expr GreaterEqualsExpr) *ExecNode { - return t.transformComparison(OpTypeGreaterEquals, expr.Lhs, expr.Rhs) + return t.transformComparison(expr, OpTypeGreaterEquals, expr.Lhs, expr.Rhs) } func (t *Transformer) transformLike(expr LikeExpr) *ExecNode { - return t.transformComparison(OpTypeMatches, expr.Lhs, expr.Rhs) + return t.transformComparison(expr, OpTypeMatches, expr.Lhs, expr.Rhs) } func (t *Transformer) transformOne(expr Expression) *ExecNode { @@ -385,11 +545,7 @@ var AlwaysFalseIdent = -2 func (t *Transformer) Transform(exprs []Expression) *MatchDef { t.RootExec = &ExecNode{} - t.ActiveExec = t.RootExec - t.NodeMap = make(map[VariableID]*ExecNode) - - t.CurDepth = 1 - t.MaxDepth = t.CurDepth + t.ContextStack = nil t.BucketIdx = 1 t.ActiveBucketIdx = 0 t.RootTree = binTree{[]binTreeNode{ @@ -435,7 +591,6 @@ func (t *Transformer) Transform(exprs []Expression) 
*MatchDef { t.RootTree = binTree{} t.BucketIdx = 0 t.SlotIdx = 0 - t.MaxDepth = 0 } if t.RootExec != nil { @@ -455,6 +610,5 @@ func (t *Transformer) Transform(exprs []Expression) *MatchDef { MatchBuckets: exprBucketIDs, NumBuckets: int(t.BucketIdx), NumSlots: int(t.SlotIdx), - MaxDepth: t.MaxDepth, } }