Skip to content

Commit

Permalink
refactor: move scale and size calculation inside type aggregation
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 committed Apr 24, 2024
1 parent e65487d commit 4786466
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 15 deletions.
8 changes: 8 additions & 0 deletions go/mysql/json/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,14 @@ type Value struct {
n NumberType
}

func (v *Value) Size() int32 {
return 0
}

func (v *Value) Scale() int32 {
return 0
}

func (v *Value) MarshalDate() string {
if d, ok := v.Date(); ok {
return d.ToStdTime(time.Local).Format("2006-01-02")
Expand Down
23 changes: 17 additions & 6 deletions go/vt/vtgate/evalengine/api_type_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ type typeAggregation struct {
blob uint16
total uint16

nullable bool
nullable bool
scale, size int32
}

type TypeAggregator struct {
Expand All @@ -63,7 +64,7 @@ func (ta *TypeAggregator) Add(typ Type, env *collations.Environment) error {
return nil
}

ta.types.addNullable(typ.typ, typ.nullable)
ta.types.addNullable(typ.typ, typ.nullable, typ.size, typ.scale)
if err := ta.collations.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil {
return err
}
Expand Down Expand Up @@ -105,10 +106,10 @@ func (ta *typeAggregation) addEval(e eval) {
default:
t = e.SQLType()
}
ta.add(t, f)
ta.add(t, f, e.Size(), e.Scale())
}

func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) {
func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool, size, scale int32) {
var flag typeFlag
if typ == sqltypes.HexVal || typ == sqltypes.HexNum {
typ = sqltypes.Binary
Expand All @@ -117,13 +118,15 @@ func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) {
if nullable {
flag |= flagNullable
}
ta.add(typ, flag)
ta.add(typ, flag, size, scale)
}

func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) {
func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag, size, scale int32) {
if f&flagNullable != 0 {
ta.nullable = true
}
ta.size = max(ta.size, size)
ta.scale = max(ta.scale, scale)
switch tt {
case sqltypes.Float32, sqltypes.Float64:
ta.double++
Expand Down Expand Up @@ -190,6 +193,14 @@ func nextSignedTypeForUnsigned(t sqltypes.Type) sqltypes.Type {
}
}

func (ta *typeAggregation) Size() int32 {
return ta.size
}

func (ta *typeAggregation) Scale() int32 {
return ta.scale
}

func (ta *typeAggregation) result() sqltypes.Type {
/*
If all types are numeric, the aggregated type is also numeric:
Expand Down
2 changes: 2 additions & 0 deletions go/vt/vtgate/evalengine/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ func (f typeFlag) Nullable() bool {
type eval interface {
ToRawBytes() []byte
SQLType() sqltypes.Type
Size() int32
Scale() int32
}

type hashable interface {
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ func (e *evalBytes) SQLType() sqltypes.Type {
return sqltypes.Type(e.tt)
}

func (e *evalBytes) Size() int32 {
return 0
}

func (e *evalBytes) Scale() int32 {
return 0
}

func (e *evalBytes) ToRawBytes() []byte {
return e.bytes
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ func (e *evalEnum) SQLType() sqltypes.Type {
return sqltypes.Enum
}

func (e *evalEnum) Size() int32 {
return 0
}

func (e *evalEnum) Scale() int32 {
return 0
}

func valueIdx(values *EnumSetValues, value string) int {
if values == nil {
return -1
Expand Down
32 changes: 32 additions & 0 deletions go/vt/vtgate/evalengine/eval_numeric.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,14 @@ func (e *evalInt64) SQLType() sqltypes.Type {
return sqltypes.Int64
}

func (e *evalInt64) Size() int32 {
return 0
}

func (e *evalInt64) Scale() int32 {
return 0
}

func (e *evalInt64) ToRawBytes() []byte {
return strconv.AppendInt(nil, e.i, 10)
}
Expand Down Expand Up @@ -409,6 +417,14 @@ func (e *evalUint64) SQLType() sqltypes.Type {
return sqltypes.Uint64
}

func (e *evalUint64) Size() int32 {
return 0
}

func (e *evalUint64) Scale() int32 {
return 0
}

func (e *evalUint64) ToRawBytes() []byte {
return strconv.AppendUint(nil, e.u, 10)
}
Expand Down Expand Up @@ -452,6 +468,14 @@ func (e *evalFloat) SQLType() sqltypes.Type {
return sqltypes.Float64
}

func (e *evalFloat) Size() int32 {
return 0
}

func (e *evalFloat) Scale() int32 {
return 0
}

func (e *evalFloat) ToRawBytes() []byte {
return format.FormatFloat(e.f)
}
Expand Down Expand Up @@ -528,6 +552,14 @@ func (e *evalDecimal) SQLType() sqltypes.Type {
return sqltypes.Decimal
}

func (e *evalDecimal) Size() int32 {
return e.length
}

func (e *evalDecimal) Scale() int32 {
return -e.dec.Exponent()
}

func (e *evalDecimal) ToRawBytes() []byte {
return e.dec.FormatMySQL(e.length)
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ func (e *evalSet) SQLType() sqltypes.Type {
return sqltypes.Set
}

func (e *evalSet) Size() int32 {
return 0
}

func (e *evalSet) Scale() int32 {
return 0
}

func evalSetBits(values *EnumSetValues, value string) uint64 {
if values != nil && len(*values) > 64 {
// This never would happen as MySQL limits SET
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ func (e *evalTemporal) SQLType() sqltypes.Type {
return e.t
}

func (e *evalTemporal) Size() int32 {
return 0
}

func (e *evalTemporal) Scale() int32 {
return 0
}

func (e *evalTemporal) toInt64() int64 {
switch e.SQLType() {
case sqltypes.Date:
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ func (e *evalTuple) ToRawBytes() []byte {
func (e *evalTuple) SQLType() sqltypes.Type {
return sqltypes.Tuple
}

func (e *evalTuple) Size() int32 {
return 0
}

func (e *evalTuple) Scale() int32 {
return 0
}
11 changes: 3 additions & 8 deletions go/vt/vtgate/evalengine/expr_logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ func (c *CaseExpr) simplify(env *ExpressionEnv) error {
func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
var ca collationAggregation
var ta typeAggregation
var scale, size int32

for _, wt := range cs.cases {
when, err := wt.when.compile(c)
Expand All @@ -691,9 +690,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
return ctype{}, err
}

ta.add(then.Type, then.Flag)
scale = max(scale, then.Scale)
size = max(size, then.Size)
ta.add(then.Type, then.Flag, then.Size, then.Scale)
if err := ca.add(then.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand All @@ -705,9 +702,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
return ctype{}, err
}

ta.add(els.Type, els.Flag)
scale = max(scale, els.Scale)
size = max(size, els.Size)
ta.add(els.Type, els.Flag, els.Size, els.Scale)
if err := ca.add(els.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand All @@ -717,7 +712,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
if ta.nullable {
f |= flagNullable
}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: scale, Size: size}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: ta.Scale(), Size: ta.Size()}
c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate())
return ct, nil
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/fn_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (b *builtinCoalesce) compile(c *compiler) (ctype, error) {
if !tt.nullable() {
f = 0
}
ta.add(tt.Type, tt.Flag)
ta.add(tt.Type, tt.Flag, tt.Size, tt.Scale)
if err := ca.add(tt.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand Down

0 comments on commit 4786466

Please sign in to comment.