diff --git a/server/conn.go b/server/conn.go index e71d9d42c..03521ecb2 100644 --- a/server/conn.go +++ b/server/conn.go @@ -28,6 +28,7 @@ type Conn struct { credentialProvider CredentialProvider user string password string + db string cachingSha2FullAuth bool h Handler @@ -71,8 +72,8 @@ func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, err return c, nil } -// NewCustomizedConn: create connection with customized server settings -func NewCustomizedConn(conn net.Conn, serverConf *Server, p CredentialProvider, h Handler) (*Conn, error) { +// MakeConn creates a new server side connection without performing the handshake. +func MakeConn(conn net.Conn, serverConf *Server, p CredentialProvider, h Handler) *Conn { var packetConn *packet.Conn if serverConf.tlsConfig != nil { packetConn = packet.NewTLSConn(conn) @@ -91,6 +92,13 @@ func NewCustomizedConn(conn net.Conn, serverConf *Server, p CredentialProvider, } c.closed.Set(false) + return c +} + +// NewCustomizedConn: create connection with customized server settings +func NewCustomizedConn(conn net.Conn, serverConf *Server, p CredentialProvider, h Handler) (*Conn, error) { + c := MakeConn(conn, serverConf, p, h) + if err := c.handshake(); err != nil { c.Close() return nil, err diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 762843865..bfc7b3664 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -128,6 +128,8 @@ func (c *Conn) readDb(data []byte, pos int) (int, error) { if err := c.h.UseDB(db); err != nil { return 0, err } + + c.db = db } return pos, nil } diff --git a/server/teleport.go b/server/teleport.go new file mode 100644 index 000000000..b8bf539aa --- /dev/null +++ b/server/teleport.go @@ -0,0 +1,25 @@ +package server + +import ( + . "github.com/go-mysql-org/go-mysql/mysql" +) + +func (c *Conn) WriteInitialHandshake() error { + return c.writeInitialHandshake() +} + +func (c *Conn) ReadHandshakeResponse() error { + return c.readHandshakeResponse() +} + +func (c *Conn) GetDatabase() string { + return c.db +} + +func (c *Conn) WriteOK(r *Result) error { + return c.writeOK(r) +} + +func (c *Conn) WriteError(e error) error { + return c.writeError(e) +}