Skip to content

Commit

Permalink
feat: improved the handling of column format codes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeroenrinzema committed Nov 22, 2023
1 parent 133195e commit 963c984
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 116 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ lint: | $(GOLANGCI_LINT) ; $(info $(M) running golint…) @ ## Run the project l

.PHONY: test
test: ## Run all tests
$Q $(GO) test ./...
$Q $(GO) test ./... -timeout 20s

.PHONY: fmt
fmt: ; $(info $(M) running gofmt…) @ ## Run gofmt on all source files
Expand Down
18 changes: 10 additions & 8 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,33 +56,35 @@ func (cache *DefaultStatementCache) Get(ctx context.Context, name string) (*Stat
return stmt, nil
}

type portal struct {
type Portal struct {
statement *Statement
parameters []Parameter
formats []FormatCode
}

type DefaultPortalCache struct {
portals map[string]portal
portals map[string]*Portal
mu sync.RWMutex
}

func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []Parameter) error {
func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []Parameter, formats []FormatCode) error {
cache.mu.Lock()
defer cache.mu.Unlock()

if cache.portals == nil {
cache.portals = map[string]portal{}
cache.portals = map[string]*Portal{}
}

cache.portals[name] = portal{
cache.portals[name] = &Portal{
statement: stmt,
parameters: parameters,
formats: formats,
}

return nil
}

func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Statement, error) {
func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Portal, error) {
cache.mu.Lock()
defer cache.mu.Unlock()

Expand All @@ -95,7 +97,7 @@ func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Stateme
return nil, nil
}

return portal.statement, nil
return portal, nil
}

func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer *buffer.Writer) error {
Expand All @@ -111,5 +113,5 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, write
return nil
}

return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, writer), portal.parameters)
return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, writer), portal.parameters)
}
67 changes: 37 additions & 30 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,6 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli
case types.ClientParse:
return srv.handleParse(ctx, reader, writer)
case types.ClientDescribe:
// TODO: Server should return the column types that will be
// returned for the given portal or statement.
//
// The Describe message (portal variant) specifies the name of an
// existing portal (or an empty string for the unnamed portal). The
// response is a RowDescription message describing the rows that will be
Expand Down Expand Up @@ -250,12 +247,12 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
}

// NOTE: we have to define the column definitions before executing a simple query
err = statement.columns.Define(ctx, writer)
err = statement.columns.Define(ctx, writer, nil)
if err != nil {
return ErrorCode(writer, err)
}

err = statement.fn(ctx, NewDataWriter(ctx, statement.columns, writer), nil)
err = statement.fn(ctx, NewDataWriter(ctx, statement.columns, nil, writer), nil)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down Expand Up @@ -324,11 +321,9 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr
return err
}

var statement *Statement

switch d[0] {
case 'S':
statement, err = srv.Statements.Get(ctx, name)
statement, err := srv.Statements.Get(ctx, name)
if err != nil {
return err
}
Expand All @@ -341,18 +336,28 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr
if err != nil {
return err
}

// NOTE: the format codes are not yet known at this point in time.
return srv.writeColumnDescription(ctx, writer, nil, statement.columns)
case 'P':
statement, err = srv.Portals.Get(ctx, name)
portal, err := srv.Portals.Get(ctx, name)
if err != nil {
return err
}
}

if statement == nil {
return ErrorCode(writer, errors.New("unknown statement"))
if portal == nil {
return ErrorCode(writer, errors.New("unknown portal"))
}

err = srv.writeParameterDescription(writer, portal.statement.parameters)
if err != nil {
return err
}

return srv.writeColumnDescription(ctx, writer, portal.formats, portal.statement.columns)
}

return srv.writeColumnDescription(ctx, writer, statement.columns)
return ErrorCode(writer, fmt.Errorf("unknown describe command: %s", string(d[0])))
}

// https://www.postgresql.org/docs/15/protocol-message-formats.html
Expand All @@ -371,13 +376,13 @@ func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters [
// back to the writer buffer. Information about the returned columns is written
// to the client.
// https://www.postgresql.org/docs/15/protocol-message-formats.html
func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Writer, columns Columns) error {
func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error {
if len(columns) == 0 {
writer.Start(types.ServerNoData)
return writer.End()
}

return columns.Define(ctx, writer)
return columns.Define(ctx, writer, formats)
}

func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
Expand All @@ -396,6 +401,11 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
return err
}

formats, err := srv.readColumnTypes(ctx, reader)
if err != nil {
return err
}

stmt, err := srv.Statements.Get(ctx, statement)
if err != nil {
return err
Expand All @@ -405,7 +415,7 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
return NewErrUnkownStatement(statement)
}

err = srv.Portals.Bind(ctx, name, stmt, parameters)
err = srv.Portals.Bind(ctx, name, stmt, parameters, formats)
if err != nil {
return err
}
Expand All @@ -418,7 +428,6 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
// reader. The parameters are parsed and returned.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) {
// [0 1, 0 0 0 1 0 0 0 3 98 111 120 0 2 0 1 0 0]
// NOTE: read the total amount of parameter format length that will be send
// by the client. This can be zero to indicate that there are no parameters
// or that the parameters all use the default format (text); or one, in
Expand Down Expand Up @@ -480,30 +489,28 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([
parameters[i] = NewParameter(TypeMap(ctx), format, value)
}

// NOTE: Read the total amount of result-column format that will be
// send by the client.
length, err = reader.GetUint16()
return parameters, nil
}

func (srv *Server) readColumnTypes(ctx context.Context, reader *buffer.Reader) ([]FormatCode, error) {
length, err := reader.GetUint16()
if err != nil {
return nil, err
}

srv.logger.Debug("reading result-column format codes", slog.Uint64("length", uint64(length)))
srv.logger.Debug("reading column format codes", slog.Uint64("length", uint64(length)))

columns := make([]FormatCode, length)
for i := uint16(0); i < length; i++ {
// TODO: Handle incoming result-column format codes
//
// Incoming format codes are currently ignored and should be handled in
// the future. The result-column format codes. Each must presently be
// zero (text) or one (binary). These format codes should be returned
// and handled by the parent function to return the proper column formats.
// https://www.postgresql.org/docs/current/protocol-message-formats.html
_, err := reader.GetUint16()
format, err := reader.GetUint16()
if err != nil {
return nil, err
}

columns[i] = FormatCode(format)
}

return parameters, nil
return columns, nil
}

func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
Expand Down
18 changes: 8 additions & 10 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,16 @@ func TestBindMessageParameters(t *testing.T) {

columns := Columns{
{
Table: 0,
Name: "full_name",
Oid: oid.T_text,
Width: 256,
Format: TextFormat,
Table: 0,
Name: "full_name",
Oid: oid.T_text,
Width: 256,
},
{
Table: 0,
Name: "answer_to_life_the_universe_and_everything",
Oid: oid.T_text,
Width: 256,
Format: TextFormat,
Table: 0,
Name: "answer_to_life_the_universe_and_everything",
Oid: oid.T_text,
Width: 256,
},
}

Expand Down
27 changes: 12 additions & 15 deletions examples/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,22 @@ func main() {

var table = wire.Columns{
{
Table: 0,
Name: "name",
Oid: oid.T_text,
Width: 256,
Format: wire.TextFormat,
Table: 0,
Name: "name",
Oid: oid.T_text,
Width: 256,
},
{
Table: 0,
Name: "member",
Oid: oid.T_bool,
Width: 1,
Format: wire.TextFormat,
Table: 0,
Name: "member",
Oid: oid.T_bool,
Width: 1,
},
{
Table: 0,
Name: "age",
Oid: oid.T_int4,
Width: 1,
Format: wire.TextFormat,
Table: 0,
Name: "age",
Oid: oid.T_int4,
Width: 1,
},
}

Expand Down
4 changes: 2 additions & 2 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ type StatementCache interface {
// PortalCache represents a cache which could be used to bind and execute
// prepared statements with parameters.
type PortalCache interface {
Bind(ctx context.Context, name string, statement *Statement, parameters []Parameter) error
Get(ctx context.Context, name string) (*Statement, error)
Bind(ctx context.Context, name string, statement *Statement, parameters []Parameter, columns []FormatCode) error
Get(ctx context.Context, name string) (*Portal, error)
Execute(ctx context.Context, name string, writer *buffer.Writer) error
}

Expand Down
Loading

0 comments on commit 963c984

Please sign in to comment.