diff --git a/README.md b/README.md index 1c5af84..ac0dc15 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/jeroenrinzema/psql-wire.svg)](https://pkg.go.dev/github.com/jeroenrinzema/psql-wire) [![Latest release](https://img.shields.io/github/release/jeroenrinzema/psql-wire.svg)](https://github.com/jeroenrinzema/psql-wire/releases) [![Go Report Card](https://goreportcard.com/badge/github.com/jeroenrinzema/psql-wire)](https://goreportcard.com/report/github.com/jeroenrinzema/psql-wire) A pure Go [PostgreSQL](https://www.postgresql.org/) server wire protocol implementation. -Build your own PostgreSQL server with 15 lines of code. +Build your own PostgreSQL server within a few lines of code. This project attempts to make it as straight forward as possible to set-up and configure your own PSQL server. Feel free to check out the [examples](https://github.com/jeroenrinzema/psql-wire/tree/main/examples) directory for various ways on how to configure/set-up your own server. @@ -21,10 +21,17 @@ import ( ) func main() { - wire.ListenAndServe("127.0.0.1:5432", func(ctx context.Context, query string, writer wire.DataWriter, parameters []string) error { - fmt.Println(query) + wire.ListenAndServe("127.0.0.1:5432", handler) +} + +func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) { + fmt.Println(query) + + statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error { return writer.Complete("OK") - }) + } + + return statement, nil, nil, nil } ``` diff --git a/auth_test.go b/auth_test.go index 2fac01d..3cc487d 100644 --- a/auth_test.go +++ b/auth_test.go @@ -17,8 +17,8 @@ func TestDefaultHandleAuth(t *testing.T) { sink := bytes.NewBuffer([]byte{}) ctx := context.Background() - reader := buffer.NewReader(input, buffer.DefaultBufferSize) - writer := buffer.NewWriter(sink) + reader := buffer.NewReader(zap.NewNop(), input, buffer.DefaultBufferSize) + writer := buffer.NewWriter(zap.NewNop(), sink) server := &Server{logger: zap.NewNop()} err := server.handleAuth(ctx, reader, writer) @@ -26,7 +26,7 @@ func TestDefaultHandleAuth(t *testing.T) { t.Fatal(err) } - result := buffer.NewReader(sink, buffer.DefaultBufferSize) + result := buffer.NewReader(zap.NewNop(), sink, buffer.DefaultBufferSize) ty, ln, err := result.ReadTypedMsg() if err != nil { t.Fatal(err) @@ -54,7 +54,7 @@ func TestClearTextPassword(t *testing.T) { expected := "password" input := bytes.NewBuffer([]byte{}) - incoming := buffer.NewWriter(input) + incoming := buffer.NewWriter(zap.NewNop(), input) // NOTE: we could reuse the server buffered writer to write client messages incoming.Start(types.ServerMessage(types.ClientPassword)) @@ -73,8 +73,8 @@ func TestClearTextPassword(t *testing.T) { sink := bytes.NewBuffer([]byte{}) ctx := context.Background() - reader := buffer.NewReader(input, buffer.DefaultBufferSize) - writer := buffer.NewWriter(sink) + reader := buffer.NewReader(zap.NewNop(), input, buffer.DefaultBufferSize) + writer := buffer.NewWriter(zap.NewNop(), sink) server := &Server{logger: zap.NewNop(), Auth: ClearTextPassword(validate)} err := server.handleAuth(ctx, reader, writer) diff --git a/cache.go b/cache.go index 1d58abf..44b4527 100644 --- a/cache.go +++ b/cache.go @@ -3,30 +3,44 @@ package wire import ( "context" "sync" + + "github.com/jeroenrinzema/psql-wire/internal/buffer" + "github.com/lib/pq/oid" ) +type Statement struct { + fn PreparedStatementFn + parameters []oid.Oid + columns Columns +} + type DefaultStatementCache struct { - statements map[string]PreparedStatementFn + statements map[string]*Statement mu sync.RWMutex } // Set attempts to bind the given statement to the given name. Any // previously defined statement is overridden. -func (cache *DefaultStatementCache) Set(ctx context.Context, name string, fn PreparedStatementFn) error { +func (cache *DefaultStatementCache) Set(ctx context.Context, name string, fn PreparedStatementFn, parameters []oid.Oid, columns Columns) error { cache.mu.Lock() defer cache.mu.Unlock() if cache.statements == nil { - cache.statements = map[string]PreparedStatementFn{} + cache.statements = map[string]*Statement{} + } + + cache.statements[name] = &Statement{ + fn: fn, + parameters: parameters, + columns: columns, } - cache.statements[name] = fn return nil } // Get attempts to get the prepared statement for the given name. An error // is returned when no statement has been found. -func (cache *DefaultStatementCache) Get(ctx context.Context, name string) (PreparedStatementFn, error) { +func (cache *DefaultStatementCache) Get(ctx context.Context, name string) (*Statement, error) { cache.mu.RLock() defer cache.mu.RUnlock() @@ -34,11 +48,16 @@ func (cache *DefaultStatementCache) Get(ctx context.Context, name string) (Prepa return nil, nil } - return cache.statements[name], nil + stmt, has := cache.statements[name] + if !has { + return nil, nil + } + + return stmt, nil } type portal struct { - statement PreparedStatementFn + statement *Statement parameters []string } @@ -47,7 +66,7 @@ type DefaultPortalCache struct { mu sync.RWMutex } -func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, fn PreparedStatementFn, parametes []string) error { +func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []string) error { cache.mu.Lock() defer cache.mu.Unlock() @@ -56,14 +75,30 @@ func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, fn Prepa } cache.portals[name] = portal{ - statement: fn, - parameters: parametes, + statement: stmt, + parameters: parameters, } return nil } -func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer DataWriter) error { +func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Statement, error) { + cache.mu.Lock() + defer cache.mu.Unlock() + + if cache.portals == nil { + return nil, nil + } + + portal, has := cache.portals[name] + if !has { + return nil, nil + } + + return portal.statement, nil +} + +func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer *buffer.Writer) error { cache.mu.Lock() defer cache.mu.Unlock() @@ -72,5 +107,5 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, write return nil } - return portal.statement(ctx, writer, portal.parameters) + return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, writer), portal.parameters) } diff --git a/command.go b/command.go index 2cdff5b..6664eca 100644 --- a/command.go +++ b/command.go @@ -143,6 +143,7 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli // to the backend; the format code fields in the RowDescription message // will be zeroes in this case. // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY + return srv.handleDescribe(ctx, reader, writer) case types.ClientSync: // TODO: Include the ability to catch sync messages in order to // close the current transaction. @@ -185,12 +186,10 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli // https://github.com/postgres/postgres/blob/6e1dd2773eb60a6ab87b27b8d9391b756e904ac3/src/backend/tcop/postgres.c#L4295 return readyForQuery(writer, types.ServerIdle) case types.ClientClose: - err = srv.handleConnClose(ctx) - if err != nil { - return err - } - - return conn.Close() + // TODO: close the statement or portal + writer.Start(types.ServerCloseComplete) //nolint:errcheck + writer.End() //nolint:errcheck + return readyForQuery(writer, types.ServerIdle) case types.ClientTerminate: err = srv.handleConnTerminate(ctx) if err != nil { @@ -206,12 +205,10 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli default: return ErrorCode(writer, NewErrUnimplementedMessageType(t)) } - - return nil } func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { - if srv.Parse == nil { + if srv.parse == nil { return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientSimpleQuery)) } @@ -236,16 +233,22 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, return readyForQuery(writer, types.ServerIdle) } - statement, _, err := srv.Parse(ctx, query) + statement, _, columns, err := srv.parse(ctx, query) if err != nil { - return err + return ErrorCode(writer, err) } if err != nil { return ErrorCode(writer, err) } - err = statement(ctx, NewDataWriter(ctx, writer), nil) + // NOTE: we have to define the column definitions before executing a simple query + err = columns.Define(ctx, writer) + if err != nil { + return ErrorCode(writer, err) + } + + err = statement(ctx, NewDataWriter(ctx, columns, writer), nil) if err != nil { return ErrorCode(writer, err) } @@ -254,7 +257,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, } func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { - if srv.Parse == nil || srv.Statements == nil { + if srv.parse == nil || srv.Statements == nil { return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientParse)) } @@ -285,19 +288,14 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write // `reader.GetUint32()` } - statement, descriptions, err := srv.Parse(ctx, query) + statement, params, columns, err := srv.parse(ctx, query) if err != nil { return ErrorCode(writer, err) } - srv.logger.Debug("incoming extended query", zap.String("query", query), zap.String("name", name), zap.Int("parameters", len(descriptions))) - - err = srv.writeParameterDescriptions(writer, descriptions) - if err != nil { - return err - } + srv.logger.Debug("incoming extended query", zap.String("query", query), zap.String("name", name), zap.Int("parameters", len(params))) - err = srv.Statements.Set(ctx, name, statement) + err = srv.Statements.Set(ctx, name, statement, params, columns) if err != nil { return ErrorCode(writer, err) } @@ -306,11 +304,46 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write return writer.End() } -func (srv *Server) writeParameterDescriptions(writer *buffer.Writer, parameters []oid.Oid) error { - if len(parameters) == 0 { - return nil +func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { + d, err := reader.GetBytes(1) + if err != nil { + return err + } + + name, err := reader.GetString() + if err != nil { + return err } + var statement *Statement + + switch d[0] { + case 'S': + statement, err = srv.Statements.Get(ctx, name) + if err != nil { + return err + } + case 'P': + statement, err = srv.Portals.Get(ctx, name) + if err != nil { + return err + } + } + + if statement == nil { + return ErrorCode(writer, errors.New("unknown statement")) + } + + err = srv.writeParameterDescription(writer, statement.parameters) + if err != nil { + return err + } + + return srv.writeColumnDescription(writer, statement.columns) +} + +// https://www.postgresql.org/docs/15/protocol-message-formats.html +func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error { writer.Start(types.ServerParameterDescription) writer.AddInt16(int16(len(parameters))) @@ -321,6 +354,33 @@ func (srv *Server) writeParameterDescriptions(writer *buffer.Writer, parameters return writer.End() } +// writeColumnDescription attempts to write the statement column descriptions +// back to the writer buffer. Information about the returned columns is written +// to the client. +// https://www.postgresql.org/docs/15/protocol-message-formats.html +func (srv *Server) writeColumnDescription(writer *buffer.Writer, columns Columns) error { + if len(columns) == 0 { + writer.Start(types.ServerNoData) + return writer.End() + } + + writer.Start(types.ServerRowDescription) + writer.AddInt16(int16(len(columns))) + + for _, column := range columns { + writer.AddString(column.Name) + writer.AddNullTerminate() + writer.AddInt32(column.ID) + writer.AddInt16(column.Attr) + writer.AddInt32(int32(column.Oid)) + writer.AddInt16(column.Width) + writer.AddInt32(column.TypeModifier) + writer.AddInt16(0) // NOTE: the format code is not known yet and will always be zero + } + + return writer.End() +} + func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { name, err := reader.GetString() if err != nil { @@ -337,12 +397,12 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer return err } - fn, err := srv.Statements.Get(ctx, statement) + stmt, err := srv.Statements.Get(ctx, statement) if err != nil { return err } - err = srv.Portals.Bind(ctx, name, fn, parameters) + err = srv.Portals.Bind(ctx, name, stmt, parameters) if err != nil { return err } @@ -454,7 +514,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri } srv.logger.Debug("executing", zap.String("name", name), zap.Uint32("limit", limit)) - err = srv.Portals.Execute(ctx, name, NewDataWriter(ctx, writer)) + err = srv.Portals.Execute(ctx, name, writer) if err != nil { return ErrorCode(writer, err) } @@ -462,14 +522,6 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri return nil } -func (srv *Server) handleConnClose(ctx context.Context) error { - if srv.CloseConn == nil { - return nil - } - - return srv.CloseConn(ctx) -} - func (srv *Server) handleConnTerminate(ctx context.Context) error { if srv.TerminateConn == nil { return nil diff --git a/command_test.go b/command_test.go index cc504c8..6c185f2 100644 --- a/command_test.go +++ b/command_test.go @@ -12,11 +12,11 @@ import ( "github.com/jeroenrinzema/psql-wire/internal/types" "github.com/lib/pq/oid" "github.com/stretchr/testify/assert" - "go.uber.org/zap" + "go.uber.org/zap/zaptest" ) func TestMessageSizeExceeded(t *testing.T) { - server, err := NewServer() + server, err := NewServer(nil, Logger(zaptest.NewLogger(t))) if err != nil { t.Fatal(err) } @@ -50,36 +50,39 @@ func TestMessageSizeExceeded(t *testing.T) { func TestBindMessageParameters(t *testing.T) { t.Parallel() - handler := func(ctx context.Context, query string, writer DataWriter, parameters []string) error { - t.Log("serving query") - - writer.Define(Columns{ //nolint:errcheck - { - Table: 0, - Name: "full_name", - Oid: oid.T_text, - Width: 256, - Format: TextFormat, - }, - { - Table: 0, - Name: "answer_to_life_the_universe_and_everything", - Oid: oid.T_text, - Width: 256, - Format: TextFormat, - }, - }) - - if len(parameters) != 2 { - return fmt.Errorf("unexpected amount of parameters %d, expected 2", len(parameters)) + columns := Columns{ + { + Table: 0, + Name: "full_name", + Oid: oid.T_text, + Width: 256, + Format: TextFormat, + }, + { + Table: 0, + Name: "answer_to_life_the_universe_and_everything", + Oid: oid.T_text, + Width: 256, + Format: TextFormat, + }, + } + + handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) { + statement := func(ctx context.Context, writer DataWriter, parameters []string) error { + t.Log("serving query") + + if len(parameters) != 2 { + return fmt.Errorf("unexpected amount of parameters %d, expected 2", len(parameters)) + } + + writer.Row([]any{parameters[0], parameters[1]}) //nolint:errcheck + return writer.Complete("SELECT 1") } - writer.Row([]any{parameters[0], parameters[1]}) //nolint:errcheck - return writer.Complete("OK") + return statement, ParseParameters(query), columns, nil } - d, _ := zap.NewDevelopment() - server, err := NewServer(SimpleQuery(handler), Logger(d)) + server, err := NewServer(handler, Logger(zaptest.NewLogger(t))) if err != nil { t.Fatal(err) } diff --git a/error_test.go b/error_test.go index 4bbb80b..a42ecd6 100644 --- a/error_test.go +++ b/error_test.go @@ -10,15 +10,20 @@ import ( "github.com/jackc/pgx/v5" "github.com/jeroenrinzema/psql-wire/codes" psqlerr "github.com/jeroenrinzema/psql-wire/errors" + "github.com/lib/pq/oid" "github.com/stretchr/testify/assert" + "go.uber.org/zap/zaptest" ) func TestErrorCode(t *testing.T) { - handler := func(ctx context.Context, query string, writer DataWriter, parameters []string) error { - return psqlerr.WithSeverity(psqlerr.WithCode(errors.New("unimplemented feature"), codes.FeatureNotSupported), psqlerr.LevelFatal) + handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) { + statement := func(ctx context.Context, writer DataWriter, parameters []string) error { + return psqlerr.WithSeverity(psqlerr.WithCode(errors.New("unimplemented feature"), codes.FeatureNotSupported), psqlerr.LevelFatal) + } + return statement, nil, nil, nil } - server, err := NewServer(SimpleQuery(handler)) + server, err := NewServer(handler, Logger(zaptest.NewLogger(t))) assert.NoError(t, err) address := TListenAndServe(t, server) diff --git a/examples/error/main.go b/examples/error/main.go index 97cf396..c8335db 100644 --- a/examples/error/main.go +++ b/examples/error/main.go @@ -8,18 +8,19 @@ import ( wire "github.com/jeroenrinzema/psql-wire" "github.com/jeroenrinzema/psql-wire/codes" psqlerr "github.com/jeroenrinzema/psql-wire/errors" + "github.com/lib/pq/oid" ) func main() { log.Println("PostgreSQL server is up and running at [127.0.0.1:5432]") - wire.ListenAndServe("127.0.0.1:5432", handle) + wire.ListenAndServe("127.0.0.1:5432", handler) } -func handle(ctx context.Context, query string, writer wire.DataWriter, parameters []string) error { +func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) { log.Println("incoming SQL query:", query) err := errors.New("unimplemented feature") err = psqlerr.WithCode(err, codes.FeatureNotSupported) err = psqlerr.WithSeverity(err, psqlerr.LevelFatal) - return err + return nil, nil, nil, err } diff --git a/examples/numeric/main.go b/examples/numeric/main.go index ad01b32..33dd344 100644 --- a/examples/numeric/main.go +++ b/examples/numeric/main.go @@ -20,7 +20,7 @@ func main() { }) }) - srv, err := wire.NewServer(types, wire.SimpleQuery(handle)) + srv, err := wire.NewServer(handler, types) if err != nil { panic(err) } @@ -39,15 +39,18 @@ var table = wire.Columns{ }, } -func handle(ctx context.Context, query string, writer wire.DataWriter, parameters []string) error { +func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) { log.Println("incoming SQL query:", query) - balance, err := decimal.NewFromString("256.23") - if err != nil { - return err + statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error { + balance, err := decimal.NewFromString("256.23") + if err != nil { + return err + } + + writer.Row([]any{balance}) + return writer.Complete("SELECT 1") } - writer.Define(table) - writer.Row([]any{balance}) - return writer.Complete("OK") + return statement, wire.ParseParameters(query), table, nil } diff --git a/examples/session/main.go b/examples/session/main.go index 11102c8..71f2fd3 100644 --- a/examples/session/main.go +++ b/examples/session/main.go @@ -7,10 +7,11 @@ import ( "sync" wire "github.com/jeroenrinzema/psql-wire" + "github.com/lib/pq/oid" ) func main() { - srv, err := wire.NewServer(wire.Session(session), wire.SimpleQuery(handle)) + srv, err := wire.NewServer(handler, wire.Session(session)) if err != nil { panic(err) } @@ -32,7 +33,13 @@ func session(ctx context.Context) (context.Context, error) { return context.WithValue(ctx, id, counter), nil } -func handle(ctx context.Context, query string, writer wire.DataWriter, parameters []string) error { - session := ctx.Value(id).(int) - return writer.Complete(fmt.Sprintf("OK, session: %d", session)) +func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) { + log.Println("incoming SQL query:", query) + + statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error { + session := ctx.Value(id).(int) + return writer.Complete(fmt.Sprintf("OK, session: %d", session)) + } + + return statement, wire.ParseParameters(query), nil, nil } diff --git a/examples/simple/main.go b/examples/simple/main.go index ea8626a..a45a395 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -10,7 +10,7 @@ import ( func main() { log.Println("PostgreSQL server is up and running at [127.0.0.1:5432]") - wire.ListenAndServe("127.0.0.1:5432", handle) + wire.ListenAndServe("127.0.0.1:5432", handler) } var table = wire.Columns{ @@ -37,11 +37,14 @@ var table = wire.Columns{ }, } -func handle(ctx context.Context, query string, writer wire.DataWriter, parameters []string) error { +func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) { log.Println("incoming SQL query:", query) - writer.Define(table) - writer.Row([]any{"John", true, 29}) - writer.Row([]any{"Marry", false, 21}) - return writer.Complete("OK") + statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error { + writer.Row([]any{"John", true, 29}) + writer.Row([]any{"Marry", false, 21}) + return writer.Complete("SELECT 2") + } + + return statement, wire.ParseParameters(query), table, nil } diff --git a/examples/tls/main.go b/examples/tls/main.go index 74058f4..52479ba 100644 --- a/examples/tls/main.go +++ b/examples/tls/main.go @@ -3,8 +3,10 @@ package main import ( "context" "crypto/tls" + "log" wire "github.com/jeroenrinzema/psql-wire" + "github.com/lib/pq/oid" "go.uber.org/zap" ) @@ -27,7 +29,7 @@ func run() error { } certs := []tls.Certificate{cert} - server, err := wire.NewServer(wire.SimpleQuery(handle), wire.Certificates(certs), wire.Logger(logger), wire.MessageBufferSize(100)) + server, err := wire.NewServer(handler, wire.Certificates(certs), wire.Logger(logger), wire.MessageBufferSize(100)) if err != nil { return err } @@ -36,6 +38,12 @@ func run() error { return server.ListenAndServe("127.0.0.1:5432") } -func handle(ctx context.Context, query string, writer wire.DataWriter, parameters []string) error { - return writer.Complete("OK") +func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) { + log.Println("incoming SQL query:", query) + + statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error { + return writer.Complete("OK") + } + + return statement, wire.ParseParameters(query), nil, nil } diff --git a/handshake.go b/handshake.go index 05d4636..20cc687 100644 --- a/handshake.go +++ b/handshake.go @@ -14,7 +14,7 @@ import ( // Handshake performs the connection handshake and returns the connection // version and a buffered reader to read incoming messages send by the client. func (srv *Server) Handshake(conn net.Conn) (_ net.Conn, version types.Version, reader *buffer.Reader, err error) { - reader = buffer.NewReader(conn, srv.BufferedMsgSize) + reader = buffer.NewReader(srv.logger, conn, srv.BufferedMsgSize) version, err = srv.readVersion(reader) if err != nil { return conn, version, reader, err @@ -162,7 +162,7 @@ func (srv *Server) potentialConnUpgrade(conn net.Conn, reader *buffer.Reader, ve // NOTE: initialize the TLS connection and construct a new buffered // reader for the constructed TLS connection. conn = tls.Server(conn, &tlsConfig) - reader = buffer.NewReader(conn, srv.BufferedMsgSize) + reader = buffer.NewReader(srv.logger, conn, srv.BufferedMsgSize) version, err = srv.readVersion(reader) if err != nil { diff --git a/internal/buffer/error_test.go b/internal/buffer/error_test.go index d806593..426c34c 100644 --- a/internal/buffer/error_test.go +++ b/internal/buffer/error_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestMessageSizeExceeded(t *testing.T) { +func TestErrMessageSizeExceeded(t *testing.T) { max := DefaultBufferSize size := max + 1024 diff --git a/internal/buffer/reader.go b/internal/buffer/reader.go index 218f3ab..c8dc134 100644 --- a/internal/buffer/reader.go +++ b/internal/buffer/reader.go @@ -8,6 +8,7 @@ import ( "unsafe" "github.com/jeroenrinzema/psql-wire/internal/types" + "go.uber.org/zap" ) // DefaultBufferSize represents the default buffer size whenever the buffer size @@ -23,6 +24,7 @@ type BufferedReader interface { // Reader provides a convenient way to read pgwire protocol messages type Reader struct { + logger *zap.Logger Buffer BufferedReader Msg []byte MaxMessageSize int @@ -30,7 +32,7 @@ type Reader struct { } // NewReader constructs a new Postgres wire buffer for the given io.Reader -func NewReader(reader io.Reader, bufferSize int) *Reader { +func NewReader(logger *zap.Logger, reader io.Reader, bufferSize int) *Reader { if reader == nil { return nil } @@ -40,6 +42,7 @@ func NewReader(reader io.Reader, bufferSize int) *Reader { } return &Reader{ + logger: logger, Buffer: bufio.NewReaderSize(reader, bufferSize), MaxMessageSize: bufferSize, } @@ -78,6 +81,7 @@ func (reader *Reader) ReadTypedMsg() (types.ClientMessage, int, error) { return 0, 0, err } + reader.logger.Debug("reading typed message", zap.String("type", string(b))) return types.ClientMessage(b), n, nil } diff --git a/internal/buffer/reader_test.go b/internal/buffer/reader_test.go index dc0c589..296a8d3 100644 --- a/internal/buffer/reader_test.go +++ b/internal/buffer/reader_test.go @@ -9,10 +9,11 @@ import ( "testing" "github.com/jeroenrinzema/psql-wire/internal/types" + "go.uber.org/zap" ) func TestNewReaderNil(t *testing.T) { - reader := NewReader(nil, 0) + reader := NewReader(zap.NewNop(), nil, 0) if reader != nil { t.Fatalf("unexpected result, expected reader to be nil %+v", reader) } @@ -31,7 +32,7 @@ func TestReadTypedMsg(t *testing.T) { buffer.Write(size) buffer.Write(_text) - reader := NewReader(buffer, DefaultBufferSize) + reader := NewReader(zap.NewNop(), buffer, DefaultBufferSize) ty, ln, err := reader.ReadTypedMsg() if err != nil { @@ -57,7 +58,7 @@ func TestReadUntypedMsg(t *testing.T) { buffer.Write(size) buffer.Write(_text) - reader := NewReader(buffer, DefaultBufferSize) + reader := NewReader(zap.NewNop(), buffer, DefaultBufferSize) ln, err := reader.ReadUntypedMsg() if err != nil { @@ -89,7 +90,7 @@ func TestReadUntypedMsgParameters(t *testing.T) { buffer := msg.Bytes() binary.BigEndian.PutUint32(buffer, uint32(msg.Len())) - reader := NewReader(bytes.NewReader(buffer), DefaultBufferSize) + reader := NewReader(zap.NewNop(), bytes.NewReader(buffer), DefaultBufferSize) ln, err := reader.ReadUntypedMsg() if err != nil { t.Fatal(err) diff --git a/internal/buffer/writer.go b/internal/buffer/writer.go index 025693e..d5057cb 100644 --- a/internal/buffer/writer.go +++ b/internal/buffer/writer.go @@ -6,19 +6,22 @@ import ( "io" "github.com/jeroenrinzema/psql-wire/internal/types" + "go.uber.org/zap" ) // Writer provides a convenient way to write pgwire protocol messages type Writer struct { io.Writer + logger *zap.Logger frame bytes.Buffer putbuf [64]byte // buffer used to construct messages which could be written to the writer frame buffer err error } // NewWriter constructs a new Postgres buffered message writer for the given io.Writer -func NewWriter(writer io.Writer) *Writer { +func NewWriter(logger *zap.Logger, writer io.Writer) *Writer { return &Writer{ + logger: logger, Writer: writer, } } @@ -131,6 +134,8 @@ func (writer *Writer) End() error { length := uint32(writer.frame.Len() - 1) // total message length minus the message type byte binary.BigEndian.PutUint32(bytes[1:5], length) _, err := writer.Writer.Write(bytes) + + writer.logger.Debug("writing message", zap.String("type", string(bytes[0]))) return err } diff --git a/internal/buffer/writer_test.go b/internal/buffer/writer_test.go index 73ac7b8..3adca9f 100644 --- a/internal/buffer/writer_test.go +++ b/internal/buffer/writer_test.go @@ -7,15 +7,16 @@ import ( "testing" "github.com/jeroenrinzema/psql-wire/internal/types" + "go.uber.org/zap" ) func TestNewWriterNil(t *testing.T) { - NewWriter(nil) + NewWriter(zap.NewNop(), nil) } func TestWriteMsg(t *testing.T) { buffer := bytes.NewBuffer([]byte{}) - writer := NewWriter(buffer) + writer := NewWriter(zap.NewNop(), buffer) writer.Start(types.ServerDataRow) writer.AddString("John Doe") @@ -38,7 +39,7 @@ func TestWriteMsgErr(t *testing.T) { expected := errors.New("unexpected error") buffer := bytes.NewBuffer([]byte{}) - writer := NewWriter(buffer) + writer := NewWriter(zap.NewNop(), buffer) writer.Start(types.ServerDataRow) writer.err = expected @@ -61,7 +62,7 @@ func TestWriteMsgErr(t *testing.T) { func TestWriteTypes(t *testing.T) { buffer := bytes.NewBuffer([]byte{}) - writer := NewWriter(buffer) + writer := NewWriter(zap.NewNop(), buffer) t.Run("byte", func(t *testing.T) { writer.AddByte(byte(types.ServerAuth)) @@ -104,7 +105,7 @@ func TestWriteTypesErr(t *testing.T) { expected := errors.New("unexpected error") buffer := bytes.NewBuffer([]byte{}) - writer := NewWriter(buffer) + writer := NewWriter(zap.NewNop(), buffer) writer.err = expected t.Run("byte", func(t *testing.T) { diff --git a/internal/mock/buffer.go b/internal/mock/buffer.go index f4b4f98..7572572 100644 --- a/internal/mock/buffer.go +++ b/internal/mock/buffer.go @@ -5,11 +5,12 @@ import ( "github.com/jeroenrinzema/psql-wire/internal/buffer" "github.com/jeroenrinzema/psql-wire/internal/types" + "go.uber.org/zap" ) // NewWriter constructs a new PostgreSQL wire protocol writer. func NewWriter(writer io.Writer) *Writer { - return &Writer{buffer.NewWriter(writer)} + return &Writer{buffer.NewWriter(zap.NewNop(), writer)} } // Writer represents a low level PostgreSQL client writer allowing a user to @@ -29,7 +30,7 @@ func (buffer *Writer) Start(t types.ClientMessage) { // NewReader constructs a new PostgreSQL wire protocol reader using the default // buffer size. func NewReader(reader io.Reader) *Reader { - return &Reader{buffer.NewReader(reader, buffer.DefaultBufferSize)} + return &Reader{buffer.NewReader(zap.NewNop(), reader, buffer.DefaultBufferSize)} } // Reader represents a low level PostgreSQL client reader allowing a user to diff --git a/internal/mock/client.go b/internal/mock/client.go index 7e73c0b..525a53d 100644 --- a/internal/mock/client.go +++ b/internal/mock/client.go @@ -134,7 +134,7 @@ func (client *Client) Close(t *testing.T) { t.Log("closing the client!") defer t.Log("client closed") - client.Start(types.ClientClose) + client.Start(types.ClientTerminate) err := client.End() if err != nil { t.Fatal(err) diff --git a/internal/types/types.go b/internal/types/types.go index b5bedaf..c1b17ba 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -1,9 +1,9 @@ package types -//ClientMessage represents a client pgwire message. +// ClientMessage represents a client pgwire message. type ClientMessage byte -//ServerMessage represents a server pgwire message. +// ServerMessage represents a server pgwire message. type ServerMessage byte // http://www.postgresql.org/docs/9.4/static/protocol-message-formats.html diff --git a/options.go b/options.go index 2eecea4..9dba241 100644 --- a/options.go +++ b/options.go @@ -4,30 +4,18 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "regexp" "strconv" "github.com/jackc/pgtype" + "github.com/jeroenrinzema/psql-wire/internal/buffer" "github.com/lib/pq/oid" "go.uber.org/zap" ) -// QueryParameters represents a regex which could be used to identify and lookup -// parameters defined inside a given query. Parameters could be defined as -// positional parameters and un-positional parameters. -// https://www.postgresql.org/docs/8.1/sql-syntax.html#:~:text=A%20dollar%20sign%20(%24)%20followed,a%20dollar%2Dquoted%20string%20constant. -var QueryParameters = regexp.MustCompile(`\$(\d+)|\?`) - -// SimpleQueryFn represents a callback function called whenever an incoming -// query is executed. The passed writer should be handled with caution as it is -// not safe for concurrent use. Concurrent access to the same data without -// proper synchronization can result in unexpected behavior and data corruption. -type SimpleQueryFn func(ctx context.Context, query string, writer DataWriter, parameters []string) error - // ParseFn parses the given query and returns a prepared statement which could // be used to execute at a later point in time. -type ParseFn func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, error) +type ParseFn func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) // PreparedStatementFn represents a query of which a statement has been // prepared. The statement could be executed at any point in time with the given @@ -44,17 +32,18 @@ type SessionHandler func(ctx context.Context) (context.Context, error) type StatementCache interface { // Set attempts to bind the given statement to the given name. Any // previously defined statement is overridden. - Set(ctx context.Context, name string, fn PreparedStatementFn) error + Set(ctx context.Context, name string, fn PreparedStatementFn, params []oid.Oid, columns Columns) error // Get attempts to get the prepared statement for the given name. An error // is returned when no statement has been found. - Get(ctx context.Context, name string) (PreparedStatementFn, error) + Get(ctx context.Context, name string) (*Statement, error) } // PortalCache represents a cache which could be used to bind and execute // prepared statements with parameters. type PortalCache interface { - Bind(ctx context.Context, name string, statement PreparedStatementFn, parameters []string) error - Execute(ctx context.Context, name string, writer DataWriter) error + Bind(ctx context.Context, name string, statement *Statement, parameters []string) error + Get(ctx context.Context, name string) (*Statement, error) + Execute(ctx context.Context, name string, writer *buffer.Writer) error } type CloseFn func(ctx context.Context) error @@ -63,60 +52,6 @@ type CloseFn func(ctx context.Context) error // PostgreSQL server. type OptionFn func(*Server) error -// SimpleQuery sets the simple query handle inside the given server instance. -func SimpleQuery(fn SimpleQueryFn) OptionFn { - return func(srv *Server) error { - if srv.Parse != nil { - return errors.New("simple query handler could not set if a query parser is set") - } - - srv.Parse = func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, error) { - statement := func(ctx context.Context, writer DataWriter, parameters []string) error { - return fn(ctx, query, writer, parameters) - } - - // NOTE: we have to lookup all parameters within the given query. - // Parameters could represent positional parameters or anonymous - // parameters. We return a zero parameter oid for each parameter - // indicating that the given parameters could contain any type. We - // could safely ignore the err check while converting given - // parameters since ony matches are returned by the positional - // parameter regex. - matches := QueryParameters.FindAllStringSubmatch(query, -1) - parameters := make([]oid.Oid, 0, len(matches)) - for _, match := range matches { - // NOTE: we have to check whether the returned match is a - // positional parameter or an un-positional parameter. - // SELECT * FROM users WHERE id = ? - if match[1] == "" { - parameters = append(parameters, 0) - } - - position, _ := strconv.Atoi(match[1]) //nolint:errcheck - if position > len(parameters) { - parameters = parameters[:position] - } - } - - return statement, parameters, nil - } - - return nil - } -} - -// Parse sets the given parse function used to parse queries into prepared statements. -func Parse(fn ParseFn) OptionFn { - return func(srv *Server) error { - if srv.Parse != nil { - return errors.New("parser could not set if a simple query handler is set") - } - - srv.Parse = fn - return nil - } -} - // Statements sets the statement cache used to cache statements for later use. By // default is the DefaultStatementCache used to cache prepared statements. func Statements(cache StatementCache) OptionFn { @@ -258,3 +193,39 @@ func Session(fn SessionHandler) OptionFn { return nil } } + +// QueryParameters represents a regex which could be used to identify and lookup +// parameters defined inside a given query. Parameters could be defined as +// positional parameters and un-positional parameters. +// https://www.postgresql.org/docs/15/sql-expressions.html#SQL-EXPRESSIONS-PARAMETERS-POSITIONAL +var QueryParameters = regexp.MustCompile(`\$(\d+)|\?`) + +// ParseParameters attempts ot parse the parameters in the given string and +// returns the expected parameters. This is necessary for the query protocol +// where the parameter types are expected to be defined in the extended query protocol. +func ParseParameters(query string) []oid.Oid { + // NOTE: we have to lookup all parameters within the given query. + // Parameters could represent positional parameters or anonymous + // parameters. We return a zero parameter oid for each parameter + // indicating that the given parameters could contain any type. We + // could safely ignore the err check while converting given + // parameters since ony matches are returned by the positional + // parameter regex. + matches := QueryParameters.FindAllStringSubmatch(query, -1) + parameters := make([]oid.Oid, 0, len(matches)) + for _, match := range matches { + // NOTE: we have to check whether the returned match is a + // positional parameter or an un-positional parameter. + // SELECT * FROM users WHERE id = ? + if match[1] == "" { + parameters = append(parameters, 0) + } + + position, _ := strconv.Atoi(match[1]) //nolint:errcheck + if position > len(parameters) { + parameters = parameters[:position] + } + } + + return parameters +} diff --git a/options_test.go b/options_test.go index 311276e..128e228 100644 --- a/options_test.go +++ b/options_test.go @@ -2,37 +2,14 @@ package wire import ( "context" - "strconv" "testing" "github.com/lib/pq/oid" "github.com/stretchr/testify/assert" + "go.uber.org/zap/zaptest" ) -func TestInvalidOptions(t *testing.T) { - tests := [][]OptionFn{ - { - Parse(func(context.Context, string) (PreparedStatementFn, []oid.Oid, error) { return nil, nil, nil }), - SimpleQuery(func(context.Context, string, DataWriter, []string) error { return nil }), - }, - } - - for index, test := range tests { - t.Run(strconv.Itoa(index), func(t *testing.T) { - srv := &Server{} - for _, option := range test { - err := option(srv) - if err != nil { - return - } - } - - t.Error("unexpected pass") - }) - } -} - -func TestSimpleQueryParameters(t *testing.T) { +func TestParseParameters(t *testing.T) { type test struct { query string parameters []oid.Oid @@ -51,22 +28,14 @@ func TestSimpleQueryParameters(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - option := SimpleQuery(nil) - - srv := &Server{} - err := option(srv) - assert.NoError(t, err) - - statement, parameters, err := srv.Parse(context.Background(), test.query) - assert.NoError(t, err) - assert.NotNil(t, statement) + parameters := ParseParameters(test.query) assert.Equal(t, test.parameters, parameters) }) } } func TestNilSessionHandler(t *testing.T) { - srv, err := NewServer() + srv, err := NewServer(nil, Logger(zaptest.NewLogger(t))) assert.NoError(t, err) assert.NotNil(t, srv) @@ -103,7 +72,8 @@ func TestSessionHandler(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - srv, err := NewServer(test...) + test = append(test, Logger(zaptest.NewLogger(t))) + srv, err := NewServer(nil, test...) assert.NoError(t, err) assert.NotNil(t, srv) diff --git a/row.go b/row.go index 5a60841..f43217d 100644 --- a/row.go +++ b/row.go @@ -56,6 +56,8 @@ func (columns Columns) Write(ctx context.Context, writer *buffer.Writer, srcs [] // https://www.postgresql.org/docs/8.3/catalog-pg-attribute.html type Column struct { Table int32 // table id + ID int32 // column identifier + Attr int16 // column attribute number Name string // column name AttrNo int16 // column attribute no (optional) Oid oid.Oid diff --git a/wire.go b/wire.go index 44367d2..678797e 100644 --- a/wire.go +++ b/wire.go @@ -18,8 +18,8 @@ import ( // default configurations. The given handler function is used to handle simple // queries. This method should be used to construct a simple Postgres server for // testing purposes or simple use cases. -func ListenAndServe(address string, handler SimpleQueryFn) error { - server, err := NewServer(SimpleQuery(handler)) +func ListenAndServe(address string, handler ParseFn) error { + server, err := NewServer(handler) if err != nil { return err } @@ -28,8 +28,9 @@ func ListenAndServe(address string, handler SimpleQueryFn) error { } // NewServer constructs a new Postgres server using the given address and server options. -func NewServer(options ...OptionFn) (*Server, error) { +func NewServer(parse ParseFn, options ...OptionFn) (*Server, error) { srv := &Server{ + parse: parse, logger: zap.NewNop(), closer: make(chan struct{}), types: pgtype.NewConnInfo(), @@ -59,7 +60,7 @@ type Server struct { Certificates []tls.Certificate ClientCAs *x509.CertPool ClientAuth tls.ClientAuthType - Parse ParseFn + parse ParseFn Session SessionHandler Statements StatementCache Portals PortalCache @@ -138,7 +139,7 @@ func (srv *Server) serve(ctx context.Context, conn net.Conn) error { srv.logger.Debug("handshake successfull, validating authentication") - writer := buffer.NewWriter(conn) + writer := buffer.NewWriter(srv.logger, conn) ctx, err = srv.readClientParameters(ctx, reader) if err != nil { return err diff --git a/wire_test.go b/wire_test.go index 69523fd..c19b6a8 100644 --- a/wire_test.go +++ b/wire_test.go @@ -4,16 +4,16 @@ import ( "context" "database/sql" "fmt" + "net" + "testing" + "github.com/jackc/pgx/v5" _ "github.com/jackc/pgx/v5/stdlib" "github.com/jeroenrinzema/psql-wire/internal/mock" _ "github.com/lib/pq" "github.com/lib/pq/oid" "github.com/stretchr/testify/require" - "go.uber.org/zap" "go.uber.org/zap/zaptest" - "net" - "testing" ) // TListenAndServe will open a new TCP listener on a unallocated port inside @@ -40,11 +40,16 @@ func TListenAndServe(t *testing.T, server *Server) *net.TCPAddr { func TestClientConnect(t *testing.T) { t.Parallel() - pong := func(ctx context.Context, query string, writer DataWriter, parameters []string) error { - return writer.Complete("OK") + handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) { + statement := func(ctx context.Context, writer DataWriter, parameters []string) error { + t.Log("serving query") + return writer.Complete("OK") + } + + return statement, nil, nil, nil } - server, err := NewServer(SimpleQuery(pong)) + server, err := NewServer(handler, Logger(zaptest.NewLogger(t))) if err != nil { t.Fatal(err) } @@ -102,13 +107,94 @@ func TestClientConnect(t *testing.T) { }) } +func TestClientParameters(t *testing.T) { + t.Parallel() + + handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) { + statement := func(ctx context.Context, writer DataWriter, parameters []string) error { + writer.Row([]any{"John Doe"}) //nolint:errcheck + return writer.Complete("SELECT 1") + } + + parameters := ParseParameters(query) + columns := Columns{ + { + Table: 0, + Name: "full_name", + Oid: oid.T_text, + Width: 256, + Format: TextFormat, + }, + } + + return statement, parameters, columns, nil + } + + server, err := NewServer(handler, Logger(zaptest.NewLogger(t))) + if err != nil { + t.Fatal(err) + } + + address := TListenAndServe(t, server) + + t.Run("lib/pq", func(t *testing.T) { + connstr := fmt.Sprintf("host=%s port=%d sslmode=disable", address.IP, address.Port) + conn, err := sql.Open("postgres", connstr) + if err != nil { + t.Fatal(err) + } + + rows, err := conn.Query("SELECT * FROM users WHERE age > ?", 50) + if err != nil { + t.Fatal(err) + } + + err = rows.Close() + if err != nil { + t.Fatal(err) + } + + err = conn.Close() + if err != nil { + t.Fatal(err) + } + }) + + t.Run("jackc/pgx", func(t *testing.T) { + ctx := context.Background() + connstr := fmt.Sprintf("postgres://%s:%d", address.IP, address.Port) + conn, err := pgx.Connect(ctx, connstr) + if err != nil { + t.Fatal(err) + } + + rows, err := conn.Query(ctx, "SELECT * FROM users WHERE age > ?", 50) + if err != nil { + t.Fatal(err) + } + + rows.Close() + + err = conn.Close(ctx) + if err != nil { + t.Fatal(err) + } + }) +} + func TestServerWritingResult(t *testing.T) { t.Parallel() - handler := func(ctx context.Context, query string, writer DataWriter, parameters []string) error { - t.Log("serving query") + handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) { + statement := func(ctx context.Context, writer DataWriter, parameters []string) error { + t.Log("serving query") + writer.Row([]any{"John", true, 28}) //nolint:errcheck + writer.Row([]any{"Marry", false, 21}) //nolint:errcheck + return writer.Complete("SELECT 2") + } - writer.Define(Columns{ //nolint:errcheck + parameters := ParseParameters(query) + columns := Columns{ //nolint:errcheck { Table: 0, Name: "name", @@ -130,15 +216,12 @@ func TestServerWritingResult(t *testing.T) { Width: 1, Format: TextFormat, }, - }) + } - writer.Row([]any{"John", true, 28}) //nolint:errcheck - writer.Row([]any{"Marry", false, 21}) //nolint:errcheck - return writer.Complete("OK") + return statement, parameters, columns, nil } - d, _ := zap.NewDevelopment() - server, err := NewServer(SimpleQuery(handler), Logger(d)) + server, err := NewServer(handler, Logger(zaptest.NewLogger(t))) if err != nil { t.Fatal(err) } @@ -256,9 +339,15 @@ func TestServerHandlingMultipleConnections(t *testing.T) { func TOpenMockServer(t *testing.T) *net.TCPAddr { t.Helper() - handler := func(ctx context.Context, query string, writer DataWriter, parameters []string) error { - t.Log("serving query") - writer.Define(Columns{ //nolint:errcheck + handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) { + statement := func(ctx context.Context, writer DataWriter, parameters []string) error { + t.Log("serving query") + writer.Row([]any{20}) //nolint:errcheck + return writer.Complete("SELECT 1") + } + + parameters := ParseParameters(query) + columns := Columns{ { Table: 0, Name: "age", @@ -266,11 +355,12 @@ func TOpenMockServer(t *testing.T) *net.TCPAddr { Width: 1, Format: TextFormat, }, - }) - writer.Row([]any{20}) //nolint:errcheck - return writer.Complete("OK") + } + + return statement, parameters, columns, nil } - server, err := NewServer(SimpleQuery(handler), Logger(zaptest.NewLogger(t))) + + server, err := NewServer(handler, Logger(zaptest.NewLogger(t))) require.NoError(t, err) address := TListenAndServe(t, server) return address @@ -285,10 +375,16 @@ func TestServerNULLValues(t *testing.T) { nil, } - handler := func(ctx context.Context, query string, writer DataWriter, parameters []string) error { - t.Log("serving query") + handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) { + statement := func(ctx context.Context, writer DataWriter, parameters []string) error { + t.Log("serving query") + writer.Row([]any{"John"}) //nolint:errcheck + writer.Row([]any{nil}) //nolint:errcheck + return writer.Complete("SELECT 2") + } - writer.Define(Columns{ //nolint:errcheck + parameters := ParseParameters(query) + columns := Columns{ { Table: 0, Name: "name", @@ -296,14 +392,12 @@ func TestServerNULLValues(t *testing.T) { Width: 256, Format: TextFormat, }, - }) + } - writer.Row([]any{"John"}) //nolint:errcheck - writer.Row([]any{nil}) //nolint:errcheck - return writer.Complete("OK") + return statement, parameters, columns, nil } - server, err := NewServer(SimpleQuery(handler)) + server, err := NewServer(handler, Logger(zaptest.NewLogger(t))) if err != nil { t.Fatal(err) } diff --git a/writer.go b/writer.go index f6e2ef2..a5084e4 100644 --- a/writer.go +++ b/writer.go @@ -11,11 +11,6 @@ import ( // DataWriter represents a writer interface for writing columns and data rows // using the Postgres wire to the connected client. type DataWriter interface { - // Define writes the column headers containing their type definitions, width - // type oid, etc. to the underlaying Postgres client. The column headers - // could only be written once. An error will be returned whenever this - // method is called twice. - Define(Columns) error // Row writes a single data row containing the values inside the given slice to // the underlaying Postgres client. The column headers have to be written before // sending rows. Each item inside the slice represents a single column value. @@ -35,10 +30,6 @@ type DataWriter interface { Complete(description string) error } -// ErrUndefinedColumns is thrown when the columns inside the data writer have not -// yet been defined. -var ErrUndefinedColumns = errors.New("columns have not been defined") - // ErrDataWritten is thrown when an empty result is attempted to be send to the // client while data has already been written. var ErrDataWritten = errors.New("data has already been written") @@ -50,10 +41,11 @@ var ErrClosedWriter = errors.New("closed writer") // buffer. The returned writer should be handled with caution as it is not safe // for concurrent use. Concurrent access to the same data without proper // synchronization can result in unexpected behavior and data corruption. -func NewDataWriter(ctx context.Context, writer *buffer.Writer) DataWriter { +func NewDataWriter(ctx context.Context, columns Columns, writer *buffer.Writer) DataWriter { return &dataWriter{ - ctx: ctx, - client: writer, + ctx: ctx, + columns: columns, + client: writer, } } @@ -80,10 +72,6 @@ func (writer *dataWriter) Row(values []any) error { return ErrClosedWriter } - if writer.columns == nil { - return ErrUndefinedColumns - } - writer.written++ return writer.columns.Write(writer.ctx, writer.client, values) @@ -94,10 +82,6 @@ func (writer *dataWriter) Empty() error { return ErrClosedWriter } - if writer.columns == nil { - return ErrUndefinedColumns - } - if writer.written != 0 { return ErrDataWritten }