Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
systay committed Nov 8, 2024
1 parent 6313ead commit 8c5edf6
Show file tree
Hide file tree
Showing 13 changed files with 821 additions and 621 deletions.
5 changes: 4 additions & 1 deletion go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,13 @@ func (c *Conn) readPacketAsMemBuffer() (mem.Buffer, error) {
return mem.SliceBuffer(data), nil
}

const RawPacketsPos = 20

func updateProtoHeader(b []byte, v int) {
b[0] = byte(protowire.EncodeTag(1, protowire.BytesType))
b[0] = byte(protowire.EncodeTag(RawPacketsPos, protowire.BytesType))
switch {
case v < 1<<28:
// Proto packet data size is 4 bytes.
b[1] = byte((v>>0)&0x7f | 0x80)
b[2] = byte((v>>7)&0x7f | 0x80)
b[3] = byte((v>>14)&0x7f | 0x80)
Expand Down
73 changes: 52 additions & 21 deletions go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,15 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
var packetOk PacketOK

// Get the result.
colNumber, err := c.readComQueryResponse(&packetOk)
first, colNumber, err := c.readComQueryResponseAsMemBuf(&packetOk)
if err != nil {
return nil, false, 0, err
}
more := packetOk.statusFlags&ServerMoreResultsExists != 0
warnings := packetOk.warnings
if colNumber == 0 {
// OK packet, means no results. Just use the numbers.
first.Free()
return &sqltypes.Result{
RowsAffected: packetOk.affectedRows,
InsertID: packetOk.lastInsertID,
Expand All @@ -449,7 +450,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
}, more, warnings, nil
}

var rawPackets []mem.Buffer
rawPackets := []mem.Buffer{first}
var data mem.Buffer

defer func() {
Expand All @@ -465,8 +466,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
for i := 0; i < colNumber; i++ {
data, err = c.readPacketAsMemBuffer()
if err != nil {
err = sqlerror.NewSQLError(sqlerror.CRMalformedPacket, "", "")
return nil, false, 0, err
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, "", "")
}
rawPackets = append(rawPackets, data)
}
Expand All @@ -475,18 +475,16 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
// EOF is only present here if it's not deprecated.
data, err = c.readPacketAsMemBuffer()
if err != nil {
err = sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, err.Error())
return nil, false, 0, err
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, err.Error())
}
rawPackets = append(rawPackets, data)
defer data.Free()

if c.isEOFPacket(data.ReadOnlyData()) {
rawPackets = rawPackets[:len(rawPackets)-1]
// empty by design
} else if isErrorPacket(data.ReadOnlyData()) {
err = ParseErrorPacket(data.ReadOnlyData())
return nil, false, 0, err
return nil, false, 0, ParseErrorPacket(data.ReadOnlyData())
} else {
err = vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data)
return nil, false, 0, err
return nil, false, 0, vterrors.Errorf(vtrpc.Code_INTERNAL, "unexpected packet after fields: %v", data)
}
}

Expand All @@ -496,8 +494,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
for {
data, err = c.readPacketAsMemBuffer()
if err != nil {
err = sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, err.Error())
return nil, false, 0, err
return nil, false, 0, sqlerror.NewSQLError(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, err.Error())
}
rawPackets = append(rawPackets, data)

Expand All @@ -514,8 +511,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
}
more = (statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = statusFlags

rawPackets = rawPackets[:len(rawPackets)-1]
// rawPackets = rawPackets[:len(rawPackets)-1]
} else {
var packetEof PacketOK
if err = c.parseOKPacket(&packetEof, data.ReadOnlyData()); err != nil {
Expand All @@ -525,7 +521,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
more = (packetEof.statusFlags & ServerMoreResultsExists) != 0
result.StatusFlags = packetEof.statusFlags

rawPackets = rawPackets[:len(rawPackets)-1]
// rawPackets = rawPackets[:len(rawPackets)-1]
result.SessionStateChanges = packetEof.sessionStateData
result.Info = packetEof.info
}
Expand All @@ -536,8 +532,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool

} else if isErrorPacket(data.ReadOnlyData()) {
// Error packet.
err = ParseErrorPacket(data.ReadOnlyData())
return nil, false, 0, err
return nil, false, 0, ParseErrorPacket(data.ReadOnlyData())
}

if maxrows == FETCH_NO_ROWS {
Expand All @@ -549,8 +544,7 @@ func (c *Conn) ReadQueryResultAsSliceBuffer(maxrows int) (*sqltypes.Result, bool
if err = c.drainResults(); err != nil {
return nil, false, 0, err
}
err = vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows)
return nil, false, 0, err
return nil, false, 0, vterrors.Errorf(vtrpc.Code_ABORTED, "Row count exceeded %d", maxrows)
}

rowcount++
Expand Down Expand Up @@ -738,6 +732,43 @@ func (c *Conn) readComQueryResponse(packetOk *PacketOK) (int, error) {
return int(n), nil
}

func (c *Conn) readComQueryResponseAsMemBuf(packetOk *PacketOK) (buf mem.Buffer, res int, err error) {
defer func() {
if buf != nil && err != nil {
buf.Free()
buf = nil
}
}()
buf, err = c.readPacketAsMemBuffer()
if err != nil {
return buf, 0, sqlerror.NewSQLErrorf(sqlerror.CRServerLost, sqlerror.SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()
data := buf.ReadOnlyData()
if len(data) == 0 {
return buf, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "invalid empty COM_QUERY response packet")
}

switch data[0] {
case OKPacket:
return buf, 0, c.parseOKPacket(packetOk, data)
case ErrPacket:
// Error
return buf, 0, ParseErrorPacket(data)
case 0xfb:
// Local infile
return buf, 0, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "not implemented")
}
n, pos, ok := readLenEncInt(data, 0)
if !ok {
return buf, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number")
}
if pos != len(data) {
return buf, 0, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "extra data in COM_QUERY response")
}
return buf, int(n), nil
}

//
// Server side methods.
//
Expand Down
17 changes: 7 additions & 10 deletions go/mysql/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,19 @@ limitations under the License.
package mysql

import (
"vitess.io/vitess/go/mysql/sqlerror"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
)

// ParseResult converts the raw packets in a QueryResult to a sqltypes.Result.
func ParseResult(qr *querypb.QueryResult, wantfields bool) (*sqltypes.Result, error) {
if qr.RawPackets == nil {
return sqltypes.Proto3ToResult(qr), nil
func ParseResultFoo(qr *querypb.ExecuteResponse, wantfields bool) (*sqltypes.Result, error) {
if len(qr.RawPackets) == 0 {
return sqltypes.Proto3ToResult(qr.Result), nil
}

var colcount int
for i, p := range qr.RawPackets {
if len(p) == 0 {
colcount = i
break
}
colcount, _, ok := readLenEncInt(qr.RawPackets[0], 0)
if !ok {
return nil, sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number")
}

var err error
Expand Down
3 changes: 0 additions & 3 deletions go/sqltypes/proto3.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,6 @@ func Proto3ToResult(qr *querypb.QueryResult) *Result {
if qr == nil {
return nil
}
if qr.RawPackets != nil {
panic("Proto3ToResult with raw mysql packets")
}
return &Result{
Fields: qr.Fields,
RowsAffected: qr.RowsAffected,
Expand Down
Loading

0 comments on commit 8c5edf6

Please sign in to comment.