diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index 65f0bd37d12..8830ea5fc79 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1355,7 +1355,7 @@ func (cached *builtinMultiComparison) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(64) } // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr size += cached.CallExpr.CachedSize(false) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index bcb2281f1a6..e56020c6ec3 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -331,7 +331,7 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { c.asm.Convert_id(offset) case sqltypes.Uint64: c.asm.Convert_ud(offset) - case sqltypes.Datetime, sqltypes.Time: + case sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: scale = ct.Size size = ct.Size + decimalSizeBase fallthrough @@ -341,6 +341,28 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Scale: scale, Size: size} } +func (c *compiler) compileToTemporal(doct ctype, typ sqltypes.Type, offset, prec int) ctype { + switch doct.Type { + case typ: + if int(doct.Size) == prec { + return doct + } + fallthrough + default: + switch typ { + case sqltypes.Date: + c.asm.Convert_xD(offset, c.sqlmode.AllowZeroDate()) + case sqltypes.Datetime: + c.asm.Convert_xDT(offset, prec, c.sqlmode.AllowZeroDate()) + case sqltypes.Timestamp: + c.asm.Convert_xDTs(offset, prec, c.sqlmode.AllowZeroDate()) + case sqltypes.Time: + c.asm.Convert_xT(offset, prec) + } + } + return ctype{Type: typ, Col: collationBinary, Flag: flagNullable} +} + func (c *compiler) compileToDate(doct ctype, offset int) ctype { switch doct.Type { case sqltypes.Date: @@ -362,6 +384,17 @@ func (c *compiler) compileToDateTime(doct ctype, offset, prec int) ctype { return ctype{Type: sqltypes.Datetime, Size: int32(prec), Col: collationBinary, Flag: flagNullable} } +func (c *compiler) compileToTimestamp(doct ctype, offset, prec int) ctype { + switch doct.Type { + case sqltypes.Timestamp: + c.asm.Convert_tp(offset, prec) + return doct + default: + c.asm.Convert_xDTs(offset, prec, c.sqlmode.AllowZeroDate()) + } + return ctype{Type: sqltypes.Timestamp, Size: int32(prec), Col: collationBinary, Flag: flagNullable} +} + func (c *compiler) compileToTime(doct ctype, offset, prec int) ctype { switch doct.Type { case sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 6c8896bb1f4..22414a86f34 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -767,11 +767,11 @@ func (asm *assembler) CmpDates() { }, "CMP DATE(SP-2), DATE(SP-1)") } -func (asm *assembler) Collate(col collations.ID) { +func (asm *assembler) Collate(col collations.TypedCollation) { asm.emit(func(env *ExpressionEnv) int { a := env.vm.stack[env.vm.sp-1].(*evalBytes) a.tt = int16(sqltypes.VarChar) - a.col.Collation = col + a.col = col return 1 }, "COLLATE VARCHAR(SP-1), %d", col) } @@ -1170,6 +1170,21 @@ func (asm *assembler) Convert_xDT(offset, prec int, allowZero bool) { }, "CONV (SP-%d), DATETIME", offset) } +func (asm *assembler) Convert_xDTs(offset, prec int, allowZero bool) { + asm.emit(func(env *ExpressionEnv) int { + // Need to explicitly check here or we otherwise + // store a nil wrapper in an interface vs. a direct + // nil. + dt := evalToTimestamp(env.vm.stack[env.vm.sp-offset], prec, env.now, allowZero) + if dt == nil { + env.vm.stack[env.vm.sp-offset] = nil + } else { + env.vm.stack[env.vm.sp-offset] = dt + } + return 1 + }, "CONV (SP-%d), TIMESTAMP", offset) +} + func (asm *assembler) Convert_xT(offset, prec int) { asm.emit(func(env *ExpressionEnv) int { t := evalToTime(env.vm.stack[env.vm.sp-offset], prec) @@ -2670,6 +2685,40 @@ func (asm *assembler) Fn_MULTICMP_u(args int, lessThan bool) { }, "FN MULTICMP UINT64(SP-%d)...UINT64(SP-1)", args) } +func (asm *assembler) Fn_MULTICMP_temporal(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(env *ExpressionEnv) int { + var x *evalTemporal + x, _ = env.vm.stack[env.vm.sp-args].(*evalTemporal) + for sp := env.vm.sp - args + 1; sp < env.vm.sp; sp++ { + if env.vm.stack[sp] == nil { + if lessThan { + x = nil + } + continue + } + y := env.vm.stack[sp].(*evalTemporal) + if lessThan == (y.compare(x) < 0) { + x = y + } + } + env.vm.stack[env.vm.sp-args] = x + env.vm.sp -= args - 1 + return 1 + }, "FN MULTICMP TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args) +} + +func (asm *assembler) Fn_MULTICMP_temporal_fallback(f multiComparisonFunc, args int, cmp, prec int) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(env *ExpressionEnv) int { + env.vm.stack[env.vm.sp-args], env.vm.err = f(env, env.vm.stack[env.vm.sp-args:env.vm.sp], cmp, prec) + env.vm.sp -= args - 1 + return 1 + }, "FN MULTICMP_FALLBACK TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args) +} + func (asm *assembler) Fn_REPEAT() { asm.adjustStack(-1) diff --git a/go/vt/vtgate/evalengine/compiler_asm_push.go b/go/vt/vtgate/evalengine/compiler_asm_push.go index 87d2ee9af9b..74ce514c69d 100644 --- a/go/vt/vtgate/evalengine/compiler_asm_push.go +++ b/go/vt/vtgate/evalengine/compiler_asm_push.go @@ -332,6 +332,23 @@ func (asm *assembler) PushColumn_datetime(offset int) { }, "PUSH DATETIME(:%d)", offset) } +func push_timestamp(env *ExpressionEnv, raw []byte) int { + env.vm.stack[env.vm.sp], env.vm.err = parseTimestamp(raw) + env.vm.sp++ + return 1 +} + +func (asm *assembler) PushColumn_timestamp(offset int) { + asm.adjustStack(1) + asm.emit(func(env *ExpressionEnv) int { + col := env.Row[offset] + if col.IsNull() { + return push_null(env) + } + return push_timestamp(env, col.Raw()) + }, "PUSH TIMESTAMP(:%d)", offset) +} + func (asm *assembler) PushBVar_datetime(key string) { asm.adjustStack(1) asm.emit(func(env *ExpressionEnv) int { @@ -344,6 +361,18 @@ func (asm *assembler) PushBVar_datetime(key string) { }, "PUSH DATETIME(:%q)", key) } +func (asm *assembler) PushBVar_timestamp(key string) { + asm.adjustStack(1) + asm.emit(func(env *ExpressionEnv) int { + var bvar *querypb.BindVariable + bvar, env.vm.err = env.lookupBindVar(key) + if env.vm.err != nil { + return 0 + } + return push_timestamp(env, bvar.Value) + }, "PUSH TIMESTAMP(:%q)", key) +} + func push_date(env *ExpressionEnv, raw []byte) int { env.vm.stack[env.vm.sp], env.vm.err = parseDate(raw) env.vm.sp++ diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 7bb4d48df51..742aa9cc072 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -24,12 +24,12 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/olekukonko/tablewriter" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/collations/colldata" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" @@ -119,7 +119,7 @@ func TestCompilerReference(t *testing.T) { var supported, total int env := evalengine.EmptyExpressionEnv(venv) - tc.Run(func(query string, row []sqltypes.Value) { + tc.Run(func(query string, row []sqltypes.Value, _ bool) { env.Row = row total++ testCompilerCase(t, query, venv, tc.Schema, env) @@ -171,6 +171,7 @@ func testCompilerCase(t *testing.T, query string, venv *vtenv.Environment, schem eval := expected.String() comp := res.String() assert.Equalf(t, eval, comp, "bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp) + assert.Equalf(t, expected.Collation(), res.Collation(), "bad collation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, colldata.Lookup(expected.Collation()).Name(), colldata.Lookup(res.Collation()).Name()) case vmErr == nil: t.Errorf("failed evaluation from evalengine:\nSQL: %s\nError: %s", query, evalErr) case evalErr == nil: diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index d73485441c3..2adab98afc4 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -29,7 +29,7 @@ func (e *evalTemporal) ToRawBytes() []byte { switch e.t { case sqltypes.Date: return e.dt.Date.Format() - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.Format(e.prec) case sqltypes.Time: return e.dt.Time.Format(e.prec) @@ -54,7 +54,7 @@ func (e *evalTemporal) toInt64() int64 { switch e.SQLType() { case sqltypes.Date: return e.dt.Date.FormatInt64() - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatInt64() case sqltypes.Time: return e.dt.Time.FormatInt64() @@ -67,7 +67,7 @@ func (e *evalTemporal) toFloat() float64 { switch e.SQLType() { case sqltypes.Date: return float64(e.dt.Date.FormatInt64()) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatFloat64() case sqltypes.Time: return e.dt.Time.FormatFloat64() @@ -80,7 +80,7 @@ func (e *evalTemporal) toDecimal() decimal.Decimal { switch e.SQLType() { case sqltypes.Date: return decimal.NewFromInt(e.dt.Date.FormatInt64()) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatDecimal() case sqltypes.Time: return e.dt.Time.FormatDecimal() @@ -93,7 +93,7 @@ func (e *evalTemporal) toJSON() *evalJSON { switch e.SQLType() { case sqltypes.Date: return json.NewDate(hack.String(e.dt.Date.Format())) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return json.NewDateTime(hack.String(e.dt.Format(datetime.DefaultPrecision))) case sqltypes.Time: return json.NewTime(hack.String(e.dt.Time.Format(datetime.DefaultPrecision))) @@ -104,7 +104,7 @@ func (e *evalTemporal) toJSON() *evalJSON { func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { switch e.SQLType() { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)} case sqltypes.Time: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} @@ -113,9 +113,23 @@ func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { } } +func (e *evalTemporal) toTimestamp(l int, now time.Time) *evalTemporal { + switch e.SQLType() { + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: + return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Round(l), prec: uint8(l)} + case sqltypes.Time: + return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} + default: + panic("unreachable") + } +} + func (e *evalTemporal) toTime(l int) *evalTemporal { + if l == -1 { + l = int(e.prec) + } switch e.SQLType() { - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: dt := datetime.DateTime{Time: e.dt.Time.Round(l)} return &evalTemporal{t: sqltypes.Time, dt: dt, prec: uint8(l)} case sqltypes.Date: @@ -130,7 +144,7 @@ func (e *evalTemporal) toTime(l int) *evalTemporal { func (e *evalTemporal) toDate(now time.Time) *evalTemporal { switch e.SQLType() { - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: dt := datetime.DateTime{Date: e.dt.Date} return &evalTemporal{t: sqltypes.Date, dt: dt} case sqltypes.Date: @@ -148,6 +162,13 @@ func (e *evalTemporal) isZero() bool { return e.dt.IsZero() } +func (e *evalTemporal) compare(other *evalTemporal) int { + if other == nil { + return 1 + } + return e.dt.Compare(other.dt) +} + func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval { var tmp *evalTemporal var ok bool @@ -179,6 +200,13 @@ func newEvalDateTime(dt datetime.DateTime, l int, allowZero bool) *evalTemporal return &evalTemporal{t: sqltypes.Datetime, dt: dt.Round(l), prec: uint8(l)} } +func newEvalTimestamp(dt datetime.DateTime, l int, allowZero bool) *evalTemporal { + if !allowZero && dt.IsZero() { + return nil + } + return &evalTemporal{t: sqltypes.Timestamp, dt: dt.Round(l), prec: uint8(l)} +} + func newEvalDate(d datetime.Date, allowZero bool) *evalTemporal { if !allowZero && d.IsZero() { return nil @@ -210,6 +238,14 @@ func parseDateTime(s []byte) (*evalTemporal, error) { return newEvalDateTime(t, l, true), nil } +func parseTimestamp(s []byte) (*evalTemporal, error) { + t, l, ok := datetime.ParseDateTime(hack.String(s), -1) + if !ok { + return nil, errIncorrectTemporal("TIMESTAMP", s) + } + return newEvalTimestamp(t, l, true), nil +} + func parseTime(s []byte) (*evalTemporal, error) { t, l, state := datetime.ParseTime(hack.String(s), -1) if state != datetime.TimeOK { @@ -387,6 +423,53 @@ func evalToDateTime(e eval, l int, now time.Time, allowZero bool) *evalTemporal return nil } +func evalToTimestamp(e eval, l int, now time.Time, allowZero bool) *evalTemporal { + switch e := e.(type) { + case *evalTemporal: + return e.toTimestamp(precision(l, int(e.prec)), now) + case *evalBytes: + if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() { + return newEvalTimestamp(t, l, allowZero) + } + if d, _ := datetime.ParseDate(e.string()); !d.IsZero() { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalInt64: + if t, ok := datetime.ParseDateTimeInt64(e.i); ok { + return newEvalTimestamp(t, precision(l, 0), allowZero) + } + if d, ok := datetime.ParseDateInt64(e.i); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalUint64: + if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok { + return newEvalTimestamp(t, precision(l, 0), allowZero) + } + if d, ok := datetime.ParseDateInt64(int64(e.u)); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalFloat: + if t, l, ok := datetime.ParseDateTimeFloat(e.f, l); ok { + return newEvalTimestamp(t, l, allowZero) + } + if d, ok := datetime.ParseDateFloat(e.f); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalDecimal: + if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, l); ok { + return newEvalTimestamp(t, l, allowZero) + } + if d, ok := datetime.ParseDateDecimal(e.dec); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalJSON: + if dt, ok := e.DateTime(); ok { + return newEvalTimestamp(dt, precision(l, datetime.DefaultPrecision), allowZero) + } + } + return nil +} + func evalToDate(e eval, now time.Time, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 0fffe3140a2..d1bca326d01 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -157,8 +157,10 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) { c.asm.PushNull() case tt == sqltypes.TypeJSON: c.asm.PushBVar_json(bvar.Key) - case tt == sqltypes.Datetime || tt == sqltypes.Timestamp: + case tt == sqltypes.Datetime: c.asm.PushBVar_datetime(bvar.Key) + case tt == sqltypes.Timestamp: + c.asm.PushBVar_timestamp(bvar.Key) case tt == sqltypes.Date: c.asm.PushBVar_date(bvar.Key) case tt == sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/expr_collate.go b/go/vt/vtgate/evalengine/expr_collate.go index be0eb78882b..b381acf6356 100644 --- a/go/vt/vtgate/evalengine/expr_collate.go +++ b/go/vt/vtgate/evalengine/expr_collate.go @@ -118,7 +118,7 @@ func (expr *CollateExpr) compile(c *compiler) (ctype, error) { } fallthrough case sqltypes.VarBinary: - c.asm.Collate(expr.TypedCollation.Collation) + c.asm.Collate(expr.TypedCollation) default: c.asm.Convert_xc(1, sqltypes.VarChar, expr.TypedCollation.Collation, nil) } diff --git a/go/vt/vtgate/evalengine/expr_column.go b/go/vt/vtgate/evalengine/expr_column.go index d53585ceb8b..ba7c2dbcb32 100644 --- a/go/vt/vtgate/evalengine/expr_column.go +++ b/go/vt/vtgate/evalengine/expr_column.go @@ -145,8 +145,10 @@ func (column *Column) compile(c *compiler) (ctype, error) { c.asm.PushNull() case tt == sqltypes.TypeJSON: c.asm.PushColumn_json(column.Offset) - case tt == sqltypes.Datetime || tt == sqltypes.Timestamp: + case tt == sqltypes.Datetime: c.asm.PushColumn_datetime(column.Offset) + case tt == sqltypes.Timestamp: + c.asm.PushColumn_timestamp(column.Offset) case tt == sqltypes.Date: c.asm.PushColumn_date(column.Offset) case tt == sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index 1deec6752ef..1084a240bd8 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -22,6 +22,7 @@ import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/mysql/collations/colldata" + datetime2 "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -32,11 +33,12 @@ type ( CallExpr } - multiComparisonFunc func(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) + multiComparisonFunc func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) builtinMultiComparison struct { CallExpr - cmp int + cmp int + prec int } ) @@ -93,7 +95,7 @@ func (b *builtinCoalesce) compile(c *compiler) (ctype, error) { return ctype{Type: ta.result(), Flag: f, Col: ca.result()}, nil } -func getMultiComparisonFunc(args []eval) multiComparisonFunc { +func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiComparisonFunc { var ( integersI int integersU int @@ -101,6 +103,11 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { decimals int text int binary int + temporal int + datetime int + timestamp int + date int + time int ) /* @@ -114,7 +121,7 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { for _, arg := range args { if arg == nil { - return func(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) { + return func(_ *ExpressionEnv, _ []eval, _, _ int) (eval, error) { return nil, nil } } @@ -126,18 +133,86 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { integersU++ case *evalFloat: floats++ + call.prec = datetime2.DefaultPrecision case *evalDecimal: decimals++ + call.prec = max(call.prec, int(arg.length)) case *evalBytes: switch arg.SQLType() { case sqltypes.Text, sqltypes.VarChar: text++ + call.prec = max(call.prec, datetime2.DefaultPrecision) case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + if !arg.isHexOrBitLiteral() { + call.prec = max(call.prec, datetime2.DefaultPrecision) + } + } + case *evalTemporal: + temporal++ + call.prec = max(call.prec, int(arg.prec)) + switch arg.SQLType() { + case sqltypes.Datetime: + datetime++ + case sqltypes.Timestamp: + timestamp++ + case sqltypes.Date: + date++ + case sqltypes.Time: + time++ } } } + if temporal == len(args) { + switch { + case datetime > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case timestamp > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0 && time > 0: + // When all types are temporal, we convert the case + // of having a date and time all to datetime. + // This is contrary to the case where we have a non-temporal + // type in the list, since MySQL doesn't do that. + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, _ int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }) + case time > 0: + return compareAllTemporal(func(_ *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTime(arg, prec) + }) + } + } + + switch { + case datetime > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case timestamp > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, _ int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }) + case time > 0: + // So for time, there's actually no conversion and + // internal comparisons as time. So we don't pass it + // a conversion function. + return compareAllTemporalAsString(nil) + } + if integersI+integersU == len(args) { if integersI == len(args) { return compareAllInteger_i @@ -165,7 +240,93 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { panic("unexpected argument type") } -func compareAllInteger_u(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllTemporal(f func(env *ExpressionEnv, arg eval, prec int) *evalTemporal) multiComparisonFunc { + return func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) { + var x *evalTemporal + for _, arg := range args { + conv := f(env, arg, prec) + if x == nil { + x = conv + continue + } + if (cmp < 0) == (conv.compare(x) < 0) { + x = conv + } + } + return x, nil + } +} + +func compareAllTemporalAsString(f func(env *ExpressionEnv, arg eval, prec int) *evalTemporal) multiComparisonFunc { + return func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) { + validArgs := make([]*evalTemporal, 0, len(args)) + var ca collationAggregation + for _, arg := range args { + if err := ca.add(evalCollation(arg), env.collationEnv); err != nil { + return nil, err + } + if f != nil { + conv := f(env, arg, prec) + validArgs = append(validArgs, conv) + } + } + tc := ca.result() + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(sqltypes.VarChar, env.collationEnv.DefaultConnectionCharset()) + } + if f != nil { + idx := compareTemporalInternal(validArgs, cmp) + if idx >= 0 { + arg := args[idx] + if _, ok := arg.(*evalTemporal); ok { + arg = validArgs[idx] + } + return evalToVarchar(arg, tc.Collation, false) + } + } + txt, err := compareAllText(env, args, cmp, prec) + if err != nil { + return nil, err + } + return evalToVarchar(txt, tc.Collation, false) + } +} + +func compareTemporalInternal(args []*evalTemporal, cmp int) int { + if cmp < 0 { + // If we have any failed conversions and want to have the smallest value, + // we can't find that so we return -1 to indicate that. + // This will result in a fallback to do a string comparison. + for _, arg := range args { + if arg == nil { + return -1 + } + } + } + + x := 0 + for i, arg := range args[1:] { + if arg == nil { + continue + } + if (cmp < 0) == (compareTemporal(args, i+1, x) < 0) { + x = i + 1 + } + } + return x +} + +func compareTemporal(args []*evalTemporal, idx1, idx2 int) int { + if idx1 < 0 { + return 1 + } + if idx2 < 0 { + return -1 + } + return args[idx1].compare(args[idx2]) +} + +func compareAllInteger_u(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { x := args[0].(*evalUint64) for _, arg := range args[1:] { y := arg.(*evalUint64) @@ -176,7 +337,7 @@ func compareAllInteger_u(_ *collations.Environment, args []eval, cmp int) (eval, return x, nil } -func compareAllInteger_i(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllInteger_i(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { x := args[0].(*evalInt64) for _, arg := range args[1:] { y := arg.(*evalInt64) @@ -187,7 +348,7 @@ func compareAllInteger_i(_ *collations.Environment, args []eval, cmp int) (eval, return x, nil } -func compareAllFloat(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllFloat(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { candidateF, ok := evalToFloat(args[0]) if !ok { return nil, errDecimalOutOfRange @@ -212,7 +373,7 @@ func evalDecimalPrecision(e eval) int32 { return 0 } -func compareAllDecimal(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllDecimal(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { decExtreme := evalToDecimal(args[0], 0, 0).dec precExtreme := evalDecimalPrecision(args[0]) @@ -229,12 +390,12 @@ func compareAllDecimal(_ *collations.Environment, args []eval, cmp int) (eval, e return newEvalDecimalWithPrec(decExtreme, precExtreme), nil } -func compareAllText(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllText(env *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { var charsets = make([]charset.Charset, 0, len(args)) var ca collationAggregation for _, arg := range args { col := evalCollation(arg) - if err := ca.add(col, collationEnv); err != nil { + if err := ca.add(col, env.collationEnv); err != nil { return nil, err } charsets = append(charsets, colldata.Lookup(col.Collation).Charset()) @@ -262,7 +423,7 @@ func compareAllText(collationEnv *collations.Environment, args []eval, cmp int) return newEvalText(b1, tc), nil } -func compareAllBinary(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllBinary(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { candidateB := args[0].ToRawBytes() for _, arg := range args[1:] { @@ -280,7 +441,7 @@ func (call *builtinMultiComparison) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - return getMultiComparisonFunc(args)(env.collationEnv, args, call.cmp) + return call.getMultiComparisonFunc(args)(env, args, call.cmp, call.prec) } func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, error) { @@ -314,14 +475,20 @@ func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { var ( - signed int - unsigned int - floats int - decimals int - text int - binary int - args []ctype - nullable bool + signed int + unsigned int + floats int + decimals int + temporal int + date int + datetime int + timestamp int + time int + text int + binary int + args []ctype + nullable bool + prec int ) /* @@ -349,12 +516,34 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { unsigned++ case sqltypes.Float64: floats++ + prec = max(prec, datetime2.DefaultPrecision) case sqltypes.Decimal: decimals++ + prec = max(prec, int(tt.Scale)) case sqltypes.Text, sqltypes.VarChar: text++ + prec = max(prec, datetime2.DefaultPrecision) case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + if !tt.isHexOrBitLiteral() { + prec = max(prec, datetime2.DefaultPrecision) + } + case sqltypes.Date: + temporal++ + date++ + prec = max(prec, int(tt.Size)) + case sqltypes.Datetime: + temporal++ + datetime++ + prec = max(prec, int(tt.Size)) + case sqltypes.Timestamp: + temporal++ + timestamp++ + prec = max(prec, int(tt.Size)) + case sqltypes.Time: + temporal++ + time++ + prec = max(prec, int(tt.Size)) case sqltypes.Null: nullable = true default: @@ -366,6 +555,61 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { if nullable { f |= flagNullable } + if temporal == len(args) { + var typ sqltypes.Type + switch { + case datetime > 0: + typ = sqltypes.Datetime + case timestamp > 0: + typ = sqltypes.Timestamp + case date > 0 && time > 0: + // When all types are temporal, we convert the case + // of having a date and time all to datetime. + // This is contrary to the case where we have a non-temporal + // type in the list, since MySQL doesn't do that. + typ = sqltypes.Datetime + case date > 0: + typ = sqltypes.Date + case time > 0: + typ = sqltypes.Time + } + for i, tt := range args { + if tt.Type != typ || int(tt.Size) != prec { + c.compileToTemporal(tt, typ, len(args)-i, prec) + } + } + c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0) + return ctype{Type: typ, Flag: f, Col: collationBinary}, nil + } else if temporal > 0 { + var ca collationAggregation + for _, arg := range args { + if err := ca.add(arg.Col, c.env.CollationEnv()); err != nil { + return ctype{}, err + } + } + + tc := ca.result() + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(sqltypes.VarChar, c.collation) + } + switch { + case datetime > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case timestamp > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case date > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case time > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(nil), len(args), call.cmp, prec) + } + return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil + } if signed+unsigned == len(args) { if signed == len(args) { c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0) diff --git a/go/vt/vtgate/evalengine/fn_compare_test.go b/go/vt/vtgate/evalengine/fn_compare_test.go new file mode 100644 index 00000000000..def40d8365c --- /dev/null +++ b/go/vt/vtgate/evalengine/fn_compare_test.go @@ -0,0 +1,80 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "vitess.io/vitess/go/mysql/datetime" +) + +func TestCompareTemporal(t *testing.T) { + tests := []struct { + name string + val1 *evalTemporal + val2 *evalTemporal + result int + }{ + { + name: "equal values", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 0, + }, + { + name: "larger value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: -1, + }, + { + name: "smaller value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 1, + }, + { + name: "first nil value", + val1: nil, + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 1, + }, + + { + name: "second nil value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: nil, + result: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idx1 := 0 + idx2 := 1 + if tt.val1 == nil { + idx1 = -1 + } + if tt.val2 == nil { + idx2 = -1 + } + assert.Equal(t, tt.result, compareTemporal([]*evalTemporal{tt.val1, tt.val2}, idx1, idx2)) + }) + } +} diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index fc18efa80aa..0998cdb56e9 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -321,7 +321,7 @@ func (call *builtinDateFormat) compile(c *compiler) (ctype, error) { skip1 := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: default: c.asm.Convert_xDT(1, datetime.DefaultPrecision, false) } @@ -437,7 +437,7 @@ func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { var prec int32 switch n.Type { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: prec = n.Size case sqltypes.Decimal: prec = n.Scale @@ -519,7 +519,7 @@ func (call *builtinDayOfMonth) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -552,7 +552,7 @@ func (call *builtinDayOfWeek) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -585,7 +585,7 @@ func (call *builtinDayOfYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -728,7 +728,7 @@ func (call *builtinFromUnixtime) compile(c *compiler) (ctype, error) { case sqltypes.Decimal: prec = arg.Size c.asm.Fn_FROM_UNIXTIME_d() - case sqltypes.Datetime, sqltypes.Date, sqltypes.Time: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Time, sqltypes.Timestamp: prec = arg.Size if prec == 0 { c.asm.Convert_Ti(1) @@ -800,7 +800,7 @@ func (call *builtinHour) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1146,7 +1146,7 @@ func (call *builtinMicrosecond) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1179,7 +1179,7 @@ func (call *builtinMinute) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1212,7 +1212,7 @@ func (call *builtinMonth) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1250,7 +1250,7 @@ func (call *builtinMonthName) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1295,7 +1295,7 @@ func (call *builtinLastDay) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, c.sqlmode.AllowZeroDate()) } @@ -1330,7 +1330,7 @@ func (call *builtinToDays) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1467,7 +1467,7 @@ func (call *builtinTimeToSec) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1502,7 +1502,7 @@ func (call *builtinToSeconds) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xDT(1, -1, false) } @@ -1535,7 +1535,7 @@ func (call *builtinQuarter) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1568,7 +1568,7 @@ func (call *builtinSecond) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1603,7 +1603,7 @@ func (call *builtinTime) compile(c *compiler) (ctype, error) { var prec int32 switch arg.Type { case sqltypes.Time: - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: prec = arg.Size c.asm.Convert_xT(1, -1) case sqltypes.Decimal: @@ -1703,7 +1703,7 @@ func (call *builtinUnixTimestamp) compile(c *compiler) (ctype, error) { c.asm.Fn_UNIX_TIMESTAMP1() c.asm.jumpDestination(skip) switch arg.Type { - case sqltypes.Datetime, sqltypes.Time, sqltypes.Decimal: + case sqltypes.Datetime, sqltypes.Time, sqltypes.Decimal, sqltypes.Timestamp: if arg.Size == 0 { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } @@ -1768,7 +1768,7 @@ func (call *builtinWeek) compile(c *compiler) (ctype, error) { var skip2 *jump switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1813,7 +1813,7 @@ func (call *builtinWeekDay) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1849,7 +1849,7 @@ func (call *builtinWeekOfYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1884,7 +1884,7 @@ func (call *builtinYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1934,7 +1934,7 @@ func (call *builtinYearWeek) compile(c *compiler) (ctype, error) { var skip2 *jump switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index ea327601975..db47745d720 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -82,12 +82,12 @@ func normalizeValue(v sqltypes.Value, coll collations.ID) sqltypes.Value { return v } -func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, env *evalengine.ExpressionEnv, conn *mysql.Conn, expr string, fields []*querypb.Field, cmp *testcases.Comparison) { +func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, env *evalengine.ExpressionEnv, conn *mysql.Conn, expr string, fields []*querypb.Field, cmp *testcases.Comparison, skipCollationCheck bool) { t.Helper() localQuery := "SELECT " + expr remoteQuery := "SELECT " + expr - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { remoteQuery = fmt.Sprintf("SELECT %s, COLLATION(%s)", expr, expr) } if len(fields) > 0 { @@ -146,7 +146,7 @@ func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, en var localCollation, remoteCollation collations.ID if localErr == nil { v := local.Value(collations.MySQL8().DefaultConnectionCharset()) - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { if v.IsNull() { localCollation = collations.CollationBinaryID } else { @@ -166,7 +166,7 @@ func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, en } else { remoteVal = remote.Rows[0][0] } - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { if remote.Rows[0][0].IsNull() { // TODO: passthrough proper collations for nullable fields remoteCollation = collations.CollationBinaryID @@ -271,9 +271,9 @@ func TestMySQL(t *testing.T) { Username: "vt_dba", }) env := evalengine.NewExpressionEnv(ctx, nil, &vcursor{env: venv}) - tc.Run(func(query string, row []sqltypes.Value) { + tc.Run(func(query string, row []sqltypes.Value, skipCollationCheck bool) { env.Row = row - compareRemoteExprEnv(t, collationEnv, env, conn, query, tc.Schema, tc.Compare) + compareRemoteExprEnv(t, collationEnv, env, conn, query, tc.Schema, tc.Compare, skipCollationCheck) }) }) } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index 0d5f99f5c83..38a605d0ad7 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -176,18 +176,18 @@ var Cases = []TestCase{ func JSONPathOperations(yield Query) { for _, obj := range inputJSONObjects { - yield(fmt.Sprintf("JSON_KEYS('%s')", obj), nil) + yield(fmt.Sprintf("JSON_KEYS('%s')", obj), nil, false) for _, path1 := range inputJSONPaths { - yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_KEYS('%s', '%s')", obj, path1), nil) + yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_KEYS('%s', '%s')", obj, path1), nil, false) for _, path2 := range inputJSONPaths { - yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s', '%s')", obj, path1, path2), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s', '%s')", obj, path1, path2), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s', '%s')", obj, path1, path2), nil) + yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s', '%s')", obj, path1, path2), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s', '%s')", obj, path1, path2), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s', '%s')", obj, path1, path2), nil, false) } } } @@ -195,21 +195,21 @@ func JSONPathOperations(yield Query) { func JSONArray(yield Query) { for _, a := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_ARRAY(%s)", a), nil) + yield(fmt.Sprintf("JSON_ARRAY(%s)", a), nil, false) for _, b := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_ARRAY(%s, %s)", a, b), nil) + yield(fmt.Sprintf("JSON_ARRAY(%s, %s)", a, b), nil, false) } } - yield("JSON_ARRAY()", nil) + yield("JSON_ARRAY()", nil, false) } func JSONObject(yield Query) { for _, a := range inputJSONPrimitives { for _, b := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_OBJECT(%s, %s)", a, b), nil) + yield(fmt.Sprintf("JSON_OBJECT(%s, %s)", a, b), nil, false) } } - yield("JSON_OBJECT()", nil) + yield("JSON_OBJECT()", nil, false) } func CharsetConversionOperators(yield Query) { @@ -226,7 +226,7 @@ func CharsetConversionOperators(yield Query) { for _, pfx := range introducers { for _, lhs := range contents { for _, rhs := range charsets { - yield(fmt.Sprintf("HEX(CONVERT(%s %s USING %s))", pfx, lhs, rhs), nil) + yield(fmt.Sprintf("HEX(CONVERT(%s %s USING %s))", pfx, lhs, rhs), nil, false) } } } @@ -248,7 +248,7 @@ func CaseExprWithPredicate(yield Query) { for _, pred1 := range predicates { for _, val1 := range elements { for _, elseVal := range elements { - yield(fmt.Sprintf("case when %s then %s else %s end", pred1, val1, elseVal), nil) + yield(fmt.Sprintf("case when %s then %s else %s end", pred1, val1, elseVal), nil, false) } } } @@ -257,7 +257,7 @@ func CaseExprWithPredicate(yield Query) { genSubsets(elements, 3, func(values []string) { yield(fmt.Sprintf("case when %s then %s when %s then %s when %s then %s end", predicates[0], values[0], predicates[1], values[1], predicates[2], values[2], - ), nil) + ), nil, false) }) }) } @@ -277,13 +277,13 @@ func FnCeil(yield Query) { } for _, num := range ceilInputs { - yield(fmt.Sprintf("CEIL(%s)", num), nil) - yield(fmt.Sprintf("CEILING(%s)", num), nil) + yield(fmt.Sprintf("CEIL(%s)", num), nil, false) + yield(fmt.Sprintf("CEILING(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("CEIL(%s)", num), nil) - yield(fmt.Sprintf("CEILING(%s)", num), nil) + yield(fmt.Sprintf("CEIL(%s)", num), nil, false) + yield(fmt.Sprintf("CEILING(%s)", num), nil, false) } } @@ -302,11 +302,11 @@ func FnFloor(yield Query) { } for _, num := range floorInputs { - yield(fmt.Sprintf("FLOOR(%s)", num), nil) + yield(fmt.Sprintf("FLOOR(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("FLOOR(%s)", num), nil) + yield(fmt.Sprintf("FLOOR(%s)", num), nil, false) } } @@ -325,280 +325,280 @@ func FnAbs(yield Query) { } for _, num := range absInputs { - yield(fmt.Sprintf("ABS(%s)", num), nil) + yield(fmt.Sprintf("ABS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ABS(%s)", num), nil) + yield(fmt.Sprintf("ABS(%s)", num), nil, false) } } func FnPi(yield Query) { - yield("PI()+0.000000000000000000", nil) + yield("PI()+0.000000000000000000", nil, false) } func FnAcos(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ACOS(%s)", num), nil) + yield(fmt.Sprintf("ACOS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ACOS(%s)", num), nil) + yield(fmt.Sprintf("ACOS(%s)", num), nil, false) } } func FnAsin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ASIN(%s)", num), nil) + yield(fmt.Sprintf("ASIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ASIN(%s)", num), nil) + yield(fmt.Sprintf("ASIN(%s)", num), nil, false) } } func FnAtan(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ATAN(%s)", num), nil) + yield(fmt.Sprintf("ATAN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ATAN(%s)", num), nil) + yield(fmt.Sprintf("ATAN(%s)", num), nil, false) } } func FnAtan2(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil, false) } } } func FnCos(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("COS(%s)", num), nil) + yield(fmt.Sprintf("COS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("COS(%s)", num), nil) + yield(fmt.Sprintf("COS(%s)", num), nil, false) } } func FnCot(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("COT(%s)", num), nil) + yield(fmt.Sprintf("COT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("COT(%s)", num), nil) + yield(fmt.Sprintf("COT(%s)", num), nil, false) } } func FnSin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SIN(%s)", num), nil) + yield(fmt.Sprintf("SIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SIN(%s)", num), nil) + yield(fmt.Sprintf("SIN(%s)", num), nil, false) } } func FnTan(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("TAN(%s)", num), nil) + yield(fmt.Sprintf("TAN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("TAN(%s)", num), nil) + yield(fmt.Sprintf("TAN(%s)", num), nil, false) } } func FnDegrees(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("DEGREES(%s)", num), nil) + yield(fmt.Sprintf("DEGREES(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("DEGREES(%s)", num), nil) + yield(fmt.Sprintf("DEGREES(%s)", num), nil, false) } } func FnRadians(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("RADIANS(%s)", num), nil) + yield(fmt.Sprintf("RADIANS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("RADIANS(%s)", num), nil) + yield(fmt.Sprintf("RADIANS(%s)", num), nil, false) } } func FnExp(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("EXP(%s)", num), nil) + yield(fmt.Sprintf("EXP(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("EXP(%s)", num), nil) + yield(fmt.Sprintf("EXP(%s)", num), nil, false) } } func FnLn(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LN(%s)", num), nil) + yield(fmt.Sprintf("LN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LN(%s)", num), nil) + yield(fmt.Sprintf("LN(%s)", num), nil, false) } } func FnLog(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG(%s)", num), nil) + yield(fmt.Sprintf("LOG(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG(%s)", num), nil) + yield(fmt.Sprintf("LOG(%s)", num), nil, false) } for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } } } func FnLog10(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG10(%s)", num), nil) + yield(fmt.Sprintf("LOG10(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG10(%s)", num), nil) + yield(fmt.Sprintf("LOG10(%s)", num), nil, false) } } func FnMod(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } } } func FnLog2(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG2(%s)", num), nil) + yield(fmt.Sprintf("LOG2(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG2(%s)", num), nil) + yield(fmt.Sprintf("LOG2(%s)", num), nil, false) } } func FnPow(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } } } func FnSign(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SIGN(%s)", num), nil) + yield(fmt.Sprintf("SIGN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SIGN(%s)", num), nil) + yield(fmt.Sprintf("SIGN(%s)", num), nil, false) } } func FnSqrt(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SQRT(%s)", num), nil) + yield(fmt.Sprintf("SQRT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SQRT(%s)", num), nil) + yield(fmt.Sprintf("SQRT(%s)", num), nil, false) } } func FnRound(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ROUND(%s)", num), nil) + yield(fmt.Sprintf("ROUND(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s)", num), nil) + yield(fmt.Sprintf("ROUND(%s)", num), nil, false) } for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } } } @@ -606,34 +606,34 @@ func FnRound(yield Query) { func FnTruncate(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } } } func FnCrc32(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } } @@ -641,10 +641,10 @@ func FnConv(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -652,10 +652,10 @@ func FnConv(yield Query) { for _, num1 := range radianInputs { for _, num2 := range inputBitwise { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -663,10 +663,10 @@ func FnConv(yield Query) { for _, num1 := range inputBitwise { for _, num2 := range inputBitwise { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -674,50 +674,50 @@ func FnConv(yield Query) { func FnBin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("BIN(%s)", num), nil) + yield(fmt.Sprintf("BIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("BIN(%s)", num), nil) + yield(fmt.Sprintf("BIN(%s)", num), nil, false) } } func FnOct(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("OCT(%s)", num), nil) + yield(fmt.Sprintf("OCT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("OCT(%s)", num), nil) + yield(fmt.Sprintf("OCT(%s)", num), nil, false) } } func FnMD5(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } } func FnSHA1(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } } @@ -725,28 +725,28 @@ func FnSHA2(yield Query) { bitLengths := []string{"0", "224", "256", "384", "512", "1", "0.1", "256.1e0", "1-1", "128+128"} for _, bits := range bitLengths { for _, num := range radianInputs { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } } } func FnRandomBytes(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil) - yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil) + yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil, false) + yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil) - yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil) + yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil, false) + yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil, false) } } @@ -760,7 +760,7 @@ func CaseExprWithValue(yield Query) { if !(bugs{}).CanCompare(cmpbase, val1) { continue } - yield(fmt.Sprintf("case %s when %s then 1 else 0 end", cmpbase, val1), nil) + yield(fmt.Sprintf("case %s when %s then 1 else 0 end", cmpbase, val1), nil, false) } } } @@ -773,7 +773,7 @@ func If(yield Query) { for _, cmpbase := range elements { for _, val1 := range elements { for _, val2 := range elements { - yield(fmt.Sprintf("if(%s, %s, %s)", cmpbase, val1, val2), nil) + yield(fmt.Sprintf("if(%s, %s, %s)", cmpbase, val1, val2), nil, false) } } } @@ -794,17 +794,17 @@ func Base64(yield Query) { } for _, lhs := range inputs { - yield(fmt.Sprintf("FROM_BASE64(%s)", lhs), nil) - yield(fmt.Sprintf("TO_BASE64(%s)", lhs), nil) + yield(fmt.Sprintf("FROM_BASE64(%s)", lhs), nil, false) + yield(fmt.Sprintf("TO_BASE64(%s)", lhs), nil, false) } } func Conversion(yield Query) { for _, lhs := range inputConversions { for _, rhs := range inputConversionTypes { - yield(fmt.Sprintf("CAST(%s AS %s)", lhs, rhs), nil) - yield(fmt.Sprintf("CONVERT(%s, %s)", lhs, rhs), nil) - yield(fmt.Sprintf("CAST(CAST(%s AS JSON) AS %s)", lhs, rhs), nil) + yield(fmt.Sprintf("CAST(%s AS %s)", lhs, rhs), nil, false) + yield(fmt.Sprintf("CONVERT(%s, %s)", lhs, rhs), nil, false) + yield(fmt.Sprintf("CAST(CAST(%s AS JSON) AS %s)", lhs, rhs), nil, false) } } } @@ -813,8 +813,8 @@ func LargeDecimals(yield Query) { var largepi = inputPi + inputPi for pos := 0; pos < len(largepi); pos++ { - yield(fmt.Sprintf("%s.%s", largepi[:pos], largepi[pos:]), nil) - yield(fmt.Sprintf("-%s.%s", largepi[:pos], largepi[pos:]), nil) + yield(fmt.Sprintf("%s.%s", largepi[:pos], largepi[pos:]), nil, false) + yield(fmt.Sprintf("-%s.%s", largepi[:pos], largepi[pos:]), nil, false) } } @@ -822,8 +822,8 @@ func LargeIntegers(yield Query) { var largepi = inputPi + inputPi for pos := 1; pos < len(largepi); pos++ { - yield(largepi[:pos], nil) - yield(fmt.Sprintf("-%s", largepi[:pos]), nil) + yield(largepi[:pos], nil, false) + yield(fmt.Sprintf("-%s", largepi[:pos]), nil, false) } } @@ -831,7 +831,7 @@ func DecimalClamping(yield Query) { for pos := 0; pos < len(inputPi); pos++ { for m := 0; m < min(len(inputPi), 67); m += 2 { for d := 0; d <= min(m, 33); d += 2 { - yield(fmt.Sprintf("CAST(%s.%s AS DECIMAL(%d, %d))", inputPi[:pos], inputPi[pos:], m, d), nil) + yield(fmt.Sprintf("CAST(%s.%s AS DECIMAL(%d, %d))", inputPi[:pos], inputPi[pos:], m, d), nil, false) } } } @@ -840,7 +840,7 @@ func DecimalClamping(yield Query) { func BitwiseOperatorsUnary(yield Query) { for _, op := range []string{"~", "BIT_COUNT"} { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s(%s)", op, rhs), nil) + yield(fmt.Sprintf("%s(%s)", op, rhs), nil, false) } } } @@ -849,13 +849,13 @@ func BitwiseOperators(yield Query) { for _, op := range []string{"&", "|", "^", "<<", ">>"} { for _, lhs := range inputBitwise { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } for _, lhs := range inputConversions { for _, rhs := range inputConversions { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -908,7 +908,7 @@ func WeightString(yield Query) { } for _, i := range inputs { - yield(fmt.Sprintf("WEIGHT_STRING(%s)", i), nil) + yield(fmt.Sprintf("WEIGHT_STRING(%s)", i), nil, false) } } @@ -925,18 +925,18 @@ func FloatFormatting(yield Query) { } for _, f := range floats { - yield(fmt.Sprintf("%s + 0.0e0", f), nil) - yield(fmt.Sprintf("-%s", f), nil) + yield(fmt.Sprintf("%s + 0.0e0", f), nil, false) + yield(fmt.Sprintf("-%s", f), nil, false) } for i := 0; i < 64; i++ { v := uint64(1) << i - yield(fmt.Sprintf("%d + 0.0e0", v), nil) - yield(fmt.Sprintf("%d + 0.0e0", v+1), nil) - yield(fmt.Sprintf("%d + 0.0e0", ^v), nil) - yield(fmt.Sprintf("-%de0", v), nil) - yield(fmt.Sprintf("-%de0", v+1), nil) - yield(fmt.Sprintf("-%de0", ^v), nil) + yield(fmt.Sprintf("%d + 0.0e0", v), nil, false) + yield(fmt.Sprintf("%d + 0.0e0", v+1), nil, false) + yield(fmt.Sprintf("%d + 0.0e0", ^v), nil, false) + yield(fmt.Sprintf("-%de0", v), nil, false) + yield(fmt.Sprintf("-%de0", v+1), nil, false) + yield(fmt.Sprintf("-%de0", ^v), nil, false) } } @@ -960,7 +960,7 @@ func UnderscoreAndPercentage(yield Query) { `'poke\_mon' = 'poke\_mon'`, } for _, query := range queries { - yield(query, nil) + yield(query, nil, false) } } @@ -991,7 +991,7 @@ func Types(yield Query) { } for _, query := range queries { - yield(query, nil) + yield(query, nil, false) } } @@ -1001,13 +1001,13 @@ func Arithmetic(yield Query) { for _, op := range operators { for _, lhs := range inputConversions { for _, rhs := range inputConversions { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } for _, lhs := range inputBitwise { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -1023,9 +1023,9 @@ func HexArithmetic(yield Query) { for _, lhs := range cases { for _, rhs := range cases { - yield(fmt.Sprintf("%s + %s", lhs, rhs), nil) + yield(fmt.Sprintf("%s + %s", lhs, rhs), nil, false) // compare with negative values too - yield(fmt.Sprintf("-%s + -%s", lhs, rhs), nil) + yield(fmt.Sprintf("-%s + -%s", lhs, rhs), nil, false) } } } @@ -1053,7 +1053,7 @@ func NumericTypes(yield Query) { } for _, rhs := range numbers { - yield(rhs, nil) + yield(rhs, nil, false) } } @@ -1070,13 +1070,13 @@ func NegateArithmetic(yield Query) { } for _, rhs := range cases { - yield(fmt.Sprintf("- %s", rhs), nil) - yield(fmt.Sprintf("-%s", rhs), nil) + yield(fmt.Sprintf("- %s", rhs), nil, false) + yield(fmt.Sprintf("-%s", rhs), nil, false) } for _, rhs := range inputConversions { - yield(fmt.Sprintf("- %s", rhs), nil) - yield(fmt.Sprintf("-%s", rhs), nil) + yield(fmt.Sprintf("- %s", rhs), nil, false) + yield(fmt.Sprintf("-%s", rhs), nil, false) } } @@ -1090,7 +1090,7 @@ func CollationOperations(yield Query) { } for _, expr := range cases { - yield(expr, nil) + yield(expr, nil, false) } } @@ -1113,7 +1113,7 @@ func LikeComparison(yield Query) { for _, lhs := range left { for _, rhs := range right { for _, op := range []string{"LIKE", "NOT LIKE"} { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -1147,7 +1147,7 @@ func StrcmpComparison(yield Query) { for _, lhs := range inputs { for _, rhs := range inputs { - yield(fmt.Sprintf("STRCMP(%s, %s)", lhs, rhs), nil) + yield(fmt.Sprintf("STRCMP(%s, %s)", lhs, rhs), nil, false) } } } @@ -1166,7 +1166,7 @@ func MultiComparisons(yield Query) { `"0"`, `"-1"`, `"1"`, `_utf8mb4 'foobar'`, `_utf8mb4 'FOOBAR'`, `_binary '0'`, `_binary '-1'`, `_binary '1'`, - `0x0`, `0x1`, `-0x0`, `-0x1`, + `0x0`, `0x1`, "_utf8mb4 'Abc' COLLATE utf8mb4_0900_as_ci", "_utf8mb4 'aBC' COLLATE utf8mb4_0900_as_ci", "_utf8mb4 'ǍḄÇ' COLLATE utf8mb4_0900_as_ci", @@ -1181,17 +1181,37 @@ func MultiComparisons(yield Query) { "_utf8mb4 'ノ東京の' COLLATE utf8mb4_ja_0900_as_cs", "_utf8mb4 'の東京ノ' COLLATE utf8mb4_ja_0900_as_cs_ks", "_utf8mb4 'ノ東京の' COLLATE utf8mb4_ja_0900_as_cs_ks", + `date'2024-02-18'`, + `date'2023-02-01'`, + `date'2100-02-01'`, + `timestamp'2020-12-31 23:59:59'`, + `timestamp'2025-01-01 00:00:00.123456'`, + `time'23:59:59.5432'`, + `time'120:59:59'`, } for _, method := range []string{"LEAST", "GREATEST"} { + skip := func(arg []string) bool { + skipCollations := false + for _, a := range arg { + if strings.Contains(a, "date'") || strings.Contains(a, "time'") || strings.Contains(a, "timestamp'") { + skipCollations = true + break + } + } + return skipCollations + } + genSubsets(numbers, 2, func(arg []string) { - yield(fmt.Sprintf("%s(%s, %s)", method, arg[0], arg[1]), nil) - yield(fmt.Sprintf("%s(%s, %s)", method, arg[1], arg[0]), nil) + skipCollations := skip(arg) + yield(fmt.Sprintf("%s(%s, %s)", method, arg[0], arg[1]), nil, skipCollations) + yield(fmt.Sprintf("%s(%s, %s)", method, arg[1], arg[0]), nil, skipCollations) }) genSubsets(numbers, 3, func(arg []string) { - yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[0], arg[1], arg[2]), nil) - yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[2], arg[1], arg[0]), nil) + skipCollations := skip(arg) + yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[0], arg[1], arg[2]), nil, skipCollations) + yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[2], arg[1], arg[0]), nil, skipCollations) }) } } @@ -1211,7 +1231,7 @@ func IntervalStatement(yield Query) { for _, arg1 := range inputs { for _, arg2 := range inputs { for _, arg3 := range inputs { - yield(fmt.Sprintf("INTERVAL(%s, %s, %s, %s)", base, arg1, arg2, arg3), nil) + yield(fmt.Sprintf("INTERVAL(%s, %s, %s, %s)", base, arg1, arg2, arg3), nil, false) } } } @@ -1236,7 +1256,7 @@ func IsStatement(yield Query) { for _, l := range left { for _, r := range right { - yield(fmt.Sprintf("%s IS %s", l, r), nil) + yield(fmt.Sprintf("%s IS %s", l, r), nil, false) } } } @@ -1245,7 +1265,7 @@ func NotStatement(yield Query) { var ops = []string{"NOT", "!"} for _, op := range ops { for _, i := range inputConversions { - yield(fmt.Sprintf("%s %s", op, i), nil) + yield(fmt.Sprintf("%s %s", op, i), nil, false) } } } @@ -1255,7 +1275,7 @@ func LogicalStatement(yield Query) { for _, op := range ops { for _, l := range inputConversions { for _, r := range inputConversions { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } } @@ -1273,7 +1293,7 @@ func TupleComparisons(yield Query) { for _, op := range operators { for i := 0; i < len(tuples); i++ { for j := 0; j < len(tuples); j++ { - yield(fmt.Sprintf("%s %s %s", tuples[i], op, tuples[j]), nil) + yield(fmt.Sprintf("%s %s %s", tuples[i], op, tuples[j]), nil, false) } } } @@ -1284,13 +1304,13 @@ func Comparisons(yield Query) { for _, op := range operators { for _, l := range inputComparisonElement { for _, r := range inputComparisonElement { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } for _, l := range inputConversions { for _, r := range inputConversions { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } } @@ -1329,9 +1349,9 @@ func JSONExtract(yield Query) { expr2 := fmt.Sprintf("cast(%s as char) <=> %s", expr0, expr1) for _, row := range rows { - yield(expr0, []sqltypes.Value{row}) - yield(expr1, []sqltypes.Value{row}) - yield(expr2, []sqltypes.Value{row}) + yield(expr0, []sqltypes.Value{row}, false) + yield(expr1, []sqltypes.Value{row}, false) + yield(expr2, []sqltypes.Value{row}, false) } } } @@ -1348,7 +1368,7 @@ func FnField(yield Query) { for _, s1 := range inputStrings { for _, s2 := range inputStrings { for _, s3 := range inputStrings { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1356,7 +1376,7 @@ func FnField(yield Query) { for _, s1 := range radianInputs { for _, s2 := range radianInputs { for _, s3 := range radianInputs { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1365,7 +1385,7 @@ func FnField(yield Query) { for _, s1 := range inputStrings { for _, s2 := range radianInputs { for _, s3 := range inputStrings { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1374,7 +1394,7 @@ func FnField(yield Query) { for _, s1 := range inputBitwise { for _, s2 := range inputBitwise { for _, s3 := range inputBitwise { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1384,21 +1404,21 @@ func FnField(yield Query) { "FIELD('Gg', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnElt(yield Query) { for _, s1 := range inputStrings { for _, n := range inputBitwise { - yield(fmt.Sprintf("ELT(%s, %s)", n, s1), nil) + yield(fmt.Sprintf("ELT(%s, %s)", n, s1), nil, false) } } for _, s1 := range inputStrings { for _, s2 := range inputStrings { for _, n := range inputBitwise { - yield(fmt.Sprintf("ELT(%s, %s, %s)", n, s1, s2), nil) + yield(fmt.Sprintf("ELT(%s, %s, %s)", n, s1, s2), nil, false) } } } @@ -1412,7 +1432,7 @@ func FnElt(yield Query) { for _, s2 := range inputStrings { for _, s3 := range inputStrings { for _, n := range validIndex { - yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil) + yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil, false) } } } @@ -1424,7 +1444,7 @@ func FnElt(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -1433,7 +1453,7 @@ func FnInsert(yield Query) { for _, ns := range insertStrings { for _, l := range inputBitwise { for _, p := range inputBitwise { - yield(fmt.Sprintf("INSERT(%s, %s, %s, %s)", s, p, l, ns), nil) + yield(fmt.Sprintf("INSERT(%s, %s, %s, %s)", s, p, l, ns), nil, false) } } } @@ -1446,53 +1466,53 @@ func FnInsert(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnLower(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("LOWER(%s)", str), nil) - yield(fmt.Sprintf("LCASE(%s)", str), nil) + yield(fmt.Sprintf("LOWER(%s)", str), nil, false) + yield(fmt.Sprintf("LCASE(%s)", str), nil, false) } } func FnUpper(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("UPPER(%s)", str), nil) - yield(fmt.Sprintf("UCASE(%s)", str), nil) + yield(fmt.Sprintf("UPPER(%s)", str), nil, false) + yield(fmt.Sprintf("UCASE(%s)", str), nil, false) } } func FnCharLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CHAR_LENGTH(%s)", str), nil) - yield(fmt.Sprintf("CHARACTER_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("CHAR_LENGTH(%s)", str), nil, false) + yield(fmt.Sprintf("CHARACTER_LENGTH(%s)", str), nil, false) } } func FnLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("LENGTH(%s)", str), nil) - yield(fmt.Sprintf("OCTET_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("LENGTH(%s)", str), nil, false) + yield(fmt.Sprintf("OCTET_LENGTH(%s)", str), nil, false) } } func FnBitLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("BIT_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("BIT_LENGTH(%s)", str), nil, false) } } func FnAscii(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("ASCII(%s)", str), nil) + yield(fmt.Sprintf("ASCII(%s)", str), nil, false) } } func FnReverse(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("REVERSE(%s)", str), nil) + yield(fmt.Sprintf("REVERSE(%s)", str), nil, false) } } @@ -1514,13 +1534,13 @@ func FnSpace(yield Query) { } for _, c := range counts { - yield(fmt.Sprintf("SPACE(%s)", c), nil) + yield(fmt.Sprintf("SPACE(%s)", c), nil, false) } } func FnOrd(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("ORD(%s)", str), nil) + yield(fmt.Sprintf("ORD(%s)", str), nil, false) } } @@ -1528,7 +1548,7 @@ func FnRepeat(yield Query) { counts := []string{"-1", "1.9", "3", "1073741825", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("REPEAT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("REPEAT(%s, %s)", str, cnt), nil, false) } } } @@ -1537,7 +1557,7 @@ func FnLeft(yield Query) { counts := []string{"-1", "1.9", "3", "10", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("LEFT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("LEFT(%s, %s)", str, cnt), nil, false) } } } @@ -1547,7 +1567,7 @@ func FnLpad(yield Query) { for _, str := range inputStrings { for _, cnt := range counts { for _, pad := range inputStrings { - yield(fmt.Sprintf("LPAD(%s, %s, %s)", str, cnt, pad), nil) + yield(fmt.Sprintf("LPAD(%s, %s, %s)", str, cnt, pad), nil, false) } } } @@ -1557,7 +1577,7 @@ func FnRight(yield Query) { counts := []string{"-1", "1.9", "3", "10", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("RIGHT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("RIGHT(%s, %s)", str, cnt), nil, false) } } } @@ -1567,7 +1587,7 @@ func FnRpad(yield Query) { for _, str := range inputStrings { for _, cnt := range counts { for _, pad := range inputStrings { - yield(fmt.Sprintf("RPAD(%s, %s, %s)", str, cnt, pad), nil) + yield(fmt.Sprintf("RPAD(%s, %s, %s)", str, cnt, pad), nil, false) } } } @@ -1575,33 +1595,33 @@ func FnRpad(yield Query) { func FnLTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("LTRIM(%s)", str), nil) + yield(fmt.Sprintf("LTRIM(%s)", str), nil, false) } } func FnRTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("RTRIM(%s)", str), nil) + yield(fmt.Sprintf("RTRIM(%s)", str), nil, false) } } func FnTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("TRIM(%s)", str), nil) + yield(fmt.Sprintf("TRIM(%s)", str), nil, false) } modes := []string{"LEADING", "TRAILING", "BOTH"} for _, str := range inputTrimStrings { for _, mode := range modes { - yield(fmt.Sprintf("TRIM(%s FROM %s)", mode, str), nil) + yield(fmt.Sprintf("TRIM(%s FROM %s)", mode, str), nil, false) } } for _, str := range inputTrimStrings { for _, pat := range inputTrimStrings { - yield(fmt.Sprintf("TRIM(%s FROM %s)", pat, str), nil) + yield(fmt.Sprintf("TRIM(%s FROM %s)", pat, str), nil, false) for _, mode := range modes { - yield(fmt.Sprintf("TRIM(%s %s FROM %s)", mode, pat, str), nil) + yield(fmt.Sprintf("TRIM(%s %s FROM %s)", mode, pat, str), nil, false) } } } @@ -1626,15 +1646,15 @@ func FnSubstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, str := range inputStrings { for _, i := range radianInputs { - yield(fmt.Sprintf("SUBSTRING(%s, %s)", str, i), nil) + yield(fmt.Sprintf("SUBSTRING(%s, %s)", str, i), nil, false) for _, j := range radianInputs { - yield(fmt.Sprintf("SUBSTRING(%s, %s, %s)", str, i, j), nil) + yield(fmt.Sprintf("SUBSTRING(%s, %s, %s)", str, i, j), nil, false) } } } @@ -1652,17 +1672,17 @@ func FnLocate(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, substr := range locateStrings { for _, str := range locateStrings { - yield(fmt.Sprintf("LOCATE(%s, %s)", substr, str), nil) - yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil) - yield(fmt.Sprintf("POSITION(%s IN %s)", str, substr), nil) + yield(fmt.Sprintf("LOCATE(%s, %s)", substr, str), nil, false) + yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil, false) + yield(fmt.Sprintf("POSITION(%s IN %s)", str, substr), nil, false) for _, i := range radianInputs { - yield(fmt.Sprintf("LOCATE(%s, %s, %s)", substr, str, i), nil) + yield(fmt.Sprintf("LOCATE(%s, %s, %s)", substr, str, i), nil, false) } } } @@ -1683,13 +1703,13 @@ func FnReplace(yield Query) { } for _, q := range cases { - yield(q, nil) + yield(q, nil, false) } for _, substr := range inputStrings { for _, str := range inputStrings { for _, i := range inputStrings { - yield(fmt.Sprintf("REPLACE(%s, %s, %s)", substr, str, i), nil) + yield(fmt.Sprintf("REPLACE(%s, %s, %s)", substr, str, i), nil, false) } } } @@ -1697,19 +1717,19 @@ func FnReplace(yield Query) { func FnConcat(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CONCAT(%s)", str), nil) + yield(fmt.Sprintf("CONCAT(%s)", str), nil, false) } for _, str1 := range inputConversions { for _, str2 := range inputConversions { - yield(fmt.Sprintf("CONCAT(%s, %s)", str1, str2), nil) + yield(fmt.Sprintf("CONCAT(%s, %s)", str1, str2), nil, false) } } for _, str1 := range inputStrings { for _, str2 := range inputStrings { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1717,13 +1737,13 @@ func FnConcat(yield Query) { func FnConcatWs(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, NULL)", str), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, NULL)", str), nil, false) } for _, str1 := range inputConversions { for _, str2 := range inputStrings { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1731,7 +1751,7 @@ func FnConcatWs(yield Query) { for _, str1 := range inputStrings { for _, str2 := range inputConversions { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1739,7 +1759,7 @@ func FnConcatWs(yield Query) { for _, str1 := range inputStrings { for _, str2 := range inputStrings { for _, str3 := range inputConversions { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1758,13 +1778,13 @@ func FnChar(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, i1 := range radianInputs { for _, i2 := range inputBitwise { for _, i3 := range inputConversions { - yield(fmt.Sprintf("CHAR(%s, %s, %s)", i1, i2, i3), nil) + yield(fmt.Sprintf("CHAR(%s, %s, %s)", i1, i2, i3), nil, false) } } } @@ -1772,15 +1792,15 @@ func FnChar(yield Query) { func FnHex(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } for _, str := range inputConversions { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } for _, str := range inputBitwise { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } } @@ -1800,7 +1820,7 @@ func FnUnhex(yield Query) { } for _, lhs := range inputs { - yield(fmt.Sprintf("UNHEX(%s)", lhs), nil) + yield(fmt.Sprintf("UNHEX(%s)", lhs), nil, false) } } @@ -1812,15 +1832,15 @@ func InStatement(yield Query) { if !(bugs{}).CanCompare(inputs...) { return } - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) - yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil, false) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil, false) }) } @@ -1843,7 +1863,7 @@ func FnNow(yield Query) { "SYSDATE(1)", "SYSDATE(2)", "SYSDATE(3)", "SYSDATE(4)", "SYSDATE(5)", } for _, fn := range fns { - yield(fn, nil) + yield(fn, nil, false) } } @@ -1855,7 +1875,7 @@ func FnInfo(yield Query) { "VERSION()", } for _, fn := range fns { - yield(fn, nil) + yield(fn, nil, false) } } @@ -1869,7 +1889,7 @@ func FnDateFormat(yield Query) { format := buf.String() for _, d := range inputConversions { - yield(fmt.Sprintf("DATE_FORMAT(%s, %q)", d, format), nil) + yield(fmt.Sprintf("DATE_FORMAT(%s, %q)", d, format), nil, false) } } @@ -1895,7 +1915,7 @@ func FnConvertTz(yield Query) { for _, tzFrom := range timezoneInputs { for _, tzTo := range timezoneInputs { q := fmt.Sprintf("CONVERT_TZ(%s, '%s', '%s')", num1, tzFrom, tzTo) - yield(q, nil) + yield(q, nil, false) } } } @@ -1903,26 +1923,26 @@ func FnConvertTz(yield Query) { func FnDate(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DATE(%s)", d), nil) + yield(fmt.Sprintf("DATE(%s)", d), nil, false) } } func FnDayOfMonth(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFMONTH(%s)", d), nil) - yield(fmt.Sprintf("DAY(%s)", d), nil) + yield(fmt.Sprintf("DAYOFMONTH(%s)", d), nil, false) + yield(fmt.Sprintf("DAY(%s)", d), nil, false) } } func FnDayOfWeek(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFWEEK(%s)", d), nil) + yield(fmt.Sprintf("DAYOFWEEK(%s)", d), nil, false) } } func FnDayOfYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFYEAR(%s)", d), nil) + yield(fmt.Sprintf("DAYOFYEAR(%s)", d), nil, false) } } @@ -1936,21 +1956,21 @@ func FnFromUnixtime(yield Query) { format := buf.String() for _, d := range inputConversions { - yield(fmt.Sprintf("FROM_UNIXTIME(%s)", d), nil) - yield(fmt.Sprintf("FROM_UNIXTIME(%s, %q)", d, format), nil) + yield(fmt.Sprintf("FROM_UNIXTIME(%s)", d), nil, false) + yield(fmt.Sprintf("FROM_UNIXTIME(%s, %q)", d, format), nil, false) } } func FnHour(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("HOUR(%s)", d), nil) + yield(fmt.Sprintf("HOUR(%s)", d), nil, false) } } func FnMakedate(yield Query) { for _, y := range inputConversions { for _, d := range inputConversions { - yield(fmt.Sprintf("MAKEDATE(%s, %s)", y, d), nil) + yield(fmt.Sprintf("MAKEDATE(%s, %s)", y, d), nil, false) } } } @@ -1964,7 +1984,7 @@ func FnMaketime(yield Query) { for _, h := range inputConversions { for _, m := range minutes { for _, s := range inputConversions { - yield(fmt.Sprintf("MAKETIME(%s, %s, %s)", h, m, s), nil) + yield(fmt.Sprintf("MAKETIME(%s, %s, %s)", h, m, s), nil, false) } } } @@ -1972,31 +1992,31 @@ func FnMaketime(yield Query) { func FnMicroSecond(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MICROSECOND(%s)", d), nil) + yield(fmt.Sprintf("MICROSECOND(%s)", d), nil, false) } } func FnMinute(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MINUTE(%s)", d), nil) + yield(fmt.Sprintf("MINUTE(%s)", d), nil, false) } } func FnMonth(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MONTH(%s)", d), nil) + yield(fmt.Sprintf("MONTH(%s)", d), nil, false) } } func FnMonthName(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MONTHNAME(%s)", d), nil) + yield(fmt.Sprintf("MONTHNAME(%s)", d), nil, false) } } func FnLastDay(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("LAST_DAY(%s)", d), nil) + yield(fmt.Sprintf("LAST_DAY(%s)", d), nil, false) } dates := []string{ @@ -2013,13 +2033,13 @@ func FnLastDay(yield Query) { } for _, d := range dates { - yield(fmt.Sprintf("LAST_DAY(%s)", d), nil) + yield(fmt.Sprintf("LAST_DAY(%s)", d), nil, false) } } func FnToDays(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TO_DAYS(%s)", d), nil) + yield(fmt.Sprintf("TO_DAYS(%s)", d), nil, false) } dates := []string{ @@ -2037,13 +2057,13 @@ func FnToDays(yield Query) { } for _, d := range dates { - yield(fmt.Sprintf("TO_DAYS(%s)", d), nil) + yield(fmt.Sprintf("TO_DAYS(%s)", d), nil, false) } } func FnFromDays(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil) + yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil, false) } days := []string{ @@ -2059,13 +2079,13 @@ func FnFromDays(yield Query) { } for _, d := range days { - yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil) + yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil, false) } } func FnSecToTime(yield Query) { for _, s := range inputConversions { - yield(fmt.Sprintf("SEC_TO_TIME(%s)", s), nil) + yield(fmt.Sprintf("SEC_TO_TIME(%s)", s), nil, false) } mysqlDocSamples := []string{ @@ -2074,13 +2094,13 @@ func FnSecToTime(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnTimeToSec(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TIME_TO_SEC(%s)", d), nil) + yield(fmt.Sprintf("TIME_TO_SEC(%s)", d), nil, false) } time := []string{ @@ -2098,13 +2118,13 @@ func FnTimeToSec(yield Query) { } for _, t := range time { - yield(fmt.Sprintf("TIME_TO_SEC(%s)", t), nil) + yield(fmt.Sprintf("TIME_TO_SEC(%s)", t), nil, false) } } func FnToSeconds(yield Query) { for _, t := range inputConversions { - yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil) + yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil, false) } timeInputs := []string{ @@ -2122,7 +2142,7 @@ func FnToSeconds(yield Query) { } for _, t := range timeInputs { - yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil) + yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil, false) } mysqlDocSamples := []string{ @@ -2132,25 +2152,25 @@ func FnToSeconds(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnQuarter(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("QUARTER(%s)", d), nil) + yield(fmt.Sprintf("QUARTER(%s)", d), nil, false) } } func FnSecond(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("SECOND(%s)", d), nil) + yield(fmt.Sprintf("SECOND(%s)", d), nil, false) } } func FnTime(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TIME(%s)", d), nil) + yield(fmt.Sprintf("TIME(%s)", d), nil, false) } times := []string{ "'00:00:00'", @@ -2169,108 +2189,108 @@ func FnTime(yield Query) { } for _, d := range times { - yield(fmt.Sprintf("TIME(%s)", d), nil) + yield(fmt.Sprintf("TIME(%s)", d), nil, false) } } func FnUnixTimestamp(yield Query) { - yield("UNIX_TIMESTAMP()", nil) + yield("UNIX_TIMESTAMP()", nil, false) for _, d := range inputConversions { - yield(fmt.Sprintf("UNIX_TIMESTAMP(%s)", d), nil) - yield(fmt.Sprintf("UNIX_TIMESTAMP(%s) + 1", d), nil) + yield(fmt.Sprintf("UNIX_TIMESTAMP(%s)", d), nil, false) + yield(fmt.Sprintf("UNIX_TIMESTAMP(%s) + 1", d), nil, false) } } func FnWeek(yield Query) { for i := 0; i < 16; i++ { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEK(%s, %d)", d, i), nil) + yield(fmt.Sprintf("WEEK(%s, %d)", d, i), nil, false) } } for _, d := range inputConversions { - yield(fmt.Sprintf("WEEK(%s)", d), nil) + yield(fmt.Sprintf("WEEK(%s)", d), nil, false) } } func FnWeekDay(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEKDAY(%s)", d), nil) + yield(fmt.Sprintf("WEEKDAY(%s)", d), nil, false) } } func FnWeekOfYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEKOFYEAR(%s)", d), nil) + yield(fmt.Sprintf("WEEKOFYEAR(%s)", d), nil, false) } } func FnYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("YEAR(%s)", d), nil) + yield(fmt.Sprintf("YEAR(%s)", d), nil, false) } } func FnYearWeek(yield Query) { for i := 0; i < 8; i++ { for _, d := range inputConversions { - yield(fmt.Sprintf("YEARWEEK(%s, %d)", d, i), nil) + yield(fmt.Sprintf("YEARWEEK(%s, %d)", d, i), nil, false) } } for _, d := range inputConversions { - yield(fmt.Sprintf("YEARWEEK(%s)", d), nil) + yield(fmt.Sprintf("YEARWEEK(%s)", d), nil, false) } } func FnInetAton(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET_ATON(%s)", d), nil) + yield(fmt.Sprintf("INET_ATON(%s)", d), nil, false) } } func FnInetNtoa(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET_NTOA(%s)", d), nil) - yield(fmt.Sprintf("INET_NTOA(INET_ATON(%s))", d), nil) + yield(fmt.Sprintf("INET_NTOA(%s)", d), nil, false) + yield(fmt.Sprintf("INET_NTOA(INET_ATON(%s))", d), nil, false) } } func FnInet6Aton(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET6_ATON(%s)", d), nil) + yield(fmt.Sprintf("INET6_ATON(%s)", d), nil, false) } } func FnInet6Ntoa(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET6_NTOA(%s)", d), nil) - yield(fmt.Sprintf("INET6_NTOA(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("INET6_NTOA(%s)", d), nil, false) + yield(fmt.Sprintf("INET6_NTOA(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv4(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4(%s)", d), nil) + yield(fmt.Sprintf("IS_IPV4(%s)", d), nil, false) } } func FnIsIPv4Compat(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4_COMPAT(%s)", d), nil) - yield(fmt.Sprintf("IS_IPV4_COMPAT(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("IS_IPV4_COMPAT(%s)", d), nil, false) + yield(fmt.Sprintf("IS_IPV4_COMPAT(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv4Mapped(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4_MAPPED(%s)", d), nil) - yield(fmt.Sprintf("IS_IPV4_MAPPED(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("IS_IPV4_MAPPED(%s)", d), nil, false) + yield(fmt.Sprintf("IS_IPV4_MAPPED(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv6(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV6(%s)", d), nil) + yield(fmt.Sprintf("IS_IPV6(%s)", d), nil, false) } } @@ -2288,27 +2308,27 @@ func FnBinToUUID(yield Query) { "'2'", } for _, d := range uuidInputs { - yield(fmt.Sprintf("BIN_TO_UUID(%s)", d), nil) + yield(fmt.Sprintf("BIN_TO_UUID(%s)", d), nil, false) } for _, d := range uuidInputs { for _, a := range args { - yield(fmt.Sprintf("BIN_TO_UUID(%s, %s)", d, a), nil) + yield(fmt.Sprintf("BIN_TO_UUID(%s, %s)", d, a), nil, false) } } } func FnIsUUID(yield Query) { for _, d := range uuidInputs { - yield(fmt.Sprintf("IS_UUID(%s)", d), nil) + yield(fmt.Sprintf("IS_UUID(%s)", d), nil, false) } } func FnUUID(yield Query) { - yield("LENGTH(UUID())", nil) - yield("COLLATION(UUID())", nil) - yield("IS_UUID(UUID())", nil) - yield("LENGTH(UUID_TO_BIN(UUID())", nil) + yield("LENGTH(UUID())", nil, false) + yield("COLLATION(UUID())", nil, false) + yield("IS_UUID(UUID())", nil, false) + yield("LENGTH(UUID_TO_BIN(UUID())", nil, false) } func FnUUIDToBin(yield Query) { @@ -2325,12 +2345,12 @@ func FnUUIDToBin(yield Query) { "'2'", } for _, d := range uuidInputs { - yield(fmt.Sprintf("UUID_TO_BIN(%s)", d), nil) + yield(fmt.Sprintf("UUID_TO_BIN(%s)", d), nil, false) } for _, d := range uuidInputs { for _, a := range args { - yield(fmt.Sprintf("UUID_TO_BIN(%s, %s)", d, a), nil) + yield(fmt.Sprintf("UUID_TO_BIN(%s, %s)", d, a), nil, false) } } } @@ -2371,15 +2391,15 @@ func DateMath(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, d := range dates { for _, i := range inputIntervals { for _, v := range intervalValues { - yield(fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", d, v, i), nil) - yield(fmt.Sprintf("DATE_SUB(%s, INTERVAL %s %s)", d, v, i), nil) - yield(fmt.Sprintf("TIMESTAMPADD(%v, %s, %s)", i, v, d), nil) + yield(fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", d, v, i), nil, false) + yield(fmt.Sprintf("DATE_SUB(%s, INTERVAL %s %s)", d, v, i), nil, false) + yield(fmt.Sprintf("TIMESTAMPADD(%v, %s, %s)", i, v, d), nil, false) } } } @@ -2434,15 +2454,15 @@ func RegexpLike(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, i := range regexInputs { for _, p := range regexInputs { - yield(fmt.Sprintf("%s REGEXP %s", i, p), nil) - yield(fmt.Sprintf("%s NOT REGEXP %s", i, p), nil) + yield(fmt.Sprintf("%s REGEXP %s", i, p), nil, false) + yield(fmt.Sprintf("%s NOT REGEXP %s", i, p), nil, false) for _, m := range regexMatchStrings { - yield(fmt.Sprintf("REGEXP_LIKE(%s, %s, %s)", i, p, m), nil) + yield(fmt.Sprintf("REGEXP_LIKE(%s, %s, %s)", i, p, m), nil, false) } } } @@ -2518,7 +2538,7 @@ func RegexpInstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -2585,7 +2605,7 @@ func RegexpSubstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -2665,6 +2685,6 @@ func RegexpReplace(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } diff --git a/go/vt/vtgate/evalengine/testcases/helpers.go b/go/vt/vtgate/evalengine/testcases/helpers.go index a908b8196c8..1c6e92f767e 100644 --- a/go/vt/vtgate/evalengine/testcases/helpers.go +++ b/go/vt/vtgate/evalengine/testcases/helpers.go @@ -29,7 +29,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" ) -type Query func(query string, row []sqltypes.Value) +type Query func(query string, row []sqltypes.Value, skipCollationCheck bool) type Runner func(yield Query) type TestCase struct { Run Runner