Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for COPY IN protocol #72

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package wire
import (
"context"
"fmt"
"io"
"sync"

"github.com/jeroenrinzema/psql-wire/pkg/buffer"
Expand Down Expand Up @@ -102,6 +103,10 @@ func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Portal,
}

func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer *buffer.Writer) (err error) {
return cache.ExecuteCopyIn(ctx, name, writer, nil)
}

func (cache *DefaultPortalCache) ExecuteCopyIn(ctx context.Context, name string, writer *buffer.Writer, copyData io.Reader) (err error) {
defer func() {
r := recover()
if r != nil {
Expand All @@ -121,5 +126,5 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, write
return nil
}

return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, writer), portal.parameters)
return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, writer, copyData), portal.parameters)
}
332 changes: 207 additions & 125 deletions command.go

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"log/slog"
"regexp"
"strconv"
Expand Down Expand Up @@ -90,12 +91,22 @@ type StatementCache interface {

// PortalCache represents a cache which could be used to bind and execute
// prepared statements with parameters.
//
// Deprecated: Use [PortalCacheCopyIn] instead.
type PortalCache interface {
Bind(ctx context.Context, name string, statement *Statement, parameters []Parameter, columns []FormatCode) error
Get(ctx context.Context, name string) (*Portal, error)
Execute(ctx context.Context, name string, writer *buffer.Writer) error
}

// PortalCacheCopyIn extends [PortalCache] to support the COPY IN protocol.
type PortalCacheCopyIn interface {
PortalCache
// ExecuteCopyIn executes the named cached statement. copyData is ignored and
// may be nil for queries that do not use the COPY IN protocol.
ExecuteCopyIn(ctx context.Context, name string, writer *buffer.Writer, copyData io.Reader) error
}

type CloseFn func(ctx context.Context) error

// OptionFn options pattern used to define and set options for the given
Expand Down
2 changes: 1 addition & 1 deletion pkg/buffer/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (reader *Reader) Slurp(size int) error {
}

// ReadUntypedMsg reads a length-prefixed message. It is only used directly
// during the authentication phase of the protocol; ReadTypedMsg is used at all
// during the authentication phase of the protocol; [ReadTypedMsg] is used at all
// other times. This returns the number of bytes read and an error, if there
// was one. The number of bytes returned can be non-zero even with an error
// (e.g. if data was read but didn't validate) so that we can more accurately
Expand Down
1 change: 0 additions & 1 deletion pkg/buffer/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/binary"
"io"

"log/slog"

"github.com/jeroenrinzema/psql-wire/pkg/types"
Expand Down
17 changes: 17 additions & 0 deletions pkg/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ type ClientMessage byte
// ServerMessage represents a server pgwire message.
type ServerMessage byte

// DescribeMessage represents a client describe message type.
type DescribeMessage byte

// http://www.postgresql.org/docs/9.4/static/protocol-message-formats.html
const (
ClientBind ClientMessage = 'B'
Expand Down Expand Up @@ -38,6 +41,9 @@ const (
ServerPortalSuspended ServerMessage = 's'
ServerReady ServerMessage = 'Z'
ServerRowDescription ServerMessage = 'T'

DescribePortal DescribeMessage = 'P'
DescribeStatement DescribeMessage = 'S'
)

func (m ClientMessage) String() string {
Expand Down Expand Up @@ -111,3 +117,14 @@ func (m ServerMessage) String() string {
return "Unknown"
}
}

func (m DescribeMessage) String() string {
switch m {
case DescribePortal:
return "Portal"
case DescribeStatement:
return "Statement"
default:
return "Unknown"
}
}
3 changes: 1 addition & 2 deletions wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ import (
"crypto/x509"
"errors"
"fmt"
"log/slog"
"net"
"sync"
"sync/atomic"

"log/slog"

"github.com/jackc/pgx/v5/pgtype"
"github.com/jeroenrinzema/psql-wire/pkg/buffer"
"github.com/jeroenrinzema/psql-wire/pkg/types"
Expand Down
134 changes: 133 additions & 1 deletion wire_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package wire

import (
"bytes"
"context"
"database/sql"
"fmt"
"io"
"net"
"testing"

"github.com/jackc/pgx/v5"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jeroenrinzema/psql-wire/pkg/mock"
_ "github.com/lib/pq"
"github.com/lib/pq"
"github.com/lib/pq/oid"
"github.com/neilotoole/slogt"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -494,3 +496,133 @@ func TestServerNULLValues(t *testing.T) {
}
})
}

type stubStatementCache struct {
SetFn func(context.Context, string, *PreparedStatement) error
GetFn func(context.Context, string) (*Statement, error)
}

func (c *stubStatementCache) Set(ctx context.Context, name string, fn *PreparedStatement) error {
return c.SetFn(ctx, name, fn)
}

func (c *stubStatementCache) Get(ctx context.Context, name string) (*Statement, error) {
return c.GetFn(ctx, name)
}

func TestServerCopyIn(t *testing.T) {
t.Parallel()

handler := func(ctx context.Context, query string) (PreparedStatements, error) {
t.Log("preparing QUERY", query)
handle := func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
t.Log("executing QUERY", query)
switch query {
case "BEGIN READ WRITE":
return writer.Complete("BEGIN")
}
r, err := writer.CopyIn(BinaryFormat, []FormatCode{BinaryFormat, BinaryFormat, BinaryFormat})
if err != nil {
return err
}
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, r); err != nil {
return err
}
return writer.Complete("COPY 2")
}

return Prepared(NewStatement(handle)), nil
}

defStatement := DefaultStatementCache{}

opts := []OptionFn{
Logger(slogt.New(t)),
Statements(
&stubStatementCache{
SetFn: func(ctx context.Context, name string, fn *PreparedStatement) error {
fn.columns = Columns{
{}, {}, {},
}
return defStatement.Set(ctx, name, fn)
},
GetFn: defStatement.Get,
},
),
}

server, err := NewServer(handler, opts...)
if err != nil {
t.Fatal(err)
}

address := TListenAndServe(t, server)

rows := [][]any{
{196, "My Posse In Effect", nil},
{181, "Almost KISS", nil},
}

t.Run("lib/pq", func(t *testing.T) {
t.Skip()
connstr := fmt.Sprintf("host=%s port=%d sslmode=disable", address.IP, address.Port)
conn, err := sql.Open("postgres", connstr)
if err != nil {
t.Fatal(err)
}

txn, err := conn.Begin()
if err != nil {
t.Fatal(err)
}

stmt, err := txn.Prepare(pq.CopyIn("id", "name", "spotify_id"))
if err != nil {
t.Fatal(err)
}

for _, row := range rows {
_, err := stmt.Exec(row...)
if err != nil {
t.Fatal(err)
}
}
if err != nil {
t.Fatal(err)
}

if err := stmt.Close(); err != nil {
t.Fatal(err)
}
if err := txn.Commit(); err != nil {
t.Fatal(err)
}

if err := conn.Close(); err != nil {
t.Fatal(err)
}
})

t.Run("jackc/pgx", func(t *testing.T) {
ctx := context.Background()
connstr := fmt.Sprintf("postgres://%s:%d", address.IP, address.Port)
conn, err := pgx.Connect(ctx, connstr)
if err != nil {
t.Fatal(err)
}

n, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"id", "name", "spotify_id"}, pgx.CopyFromRows(rows))
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("unexpected number of rows copied: %d", n)
}

err = conn.Close(ctx)
if err != nil {
t.Fatal(err)
}
})
}
66 changes: 55 additions & 11 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package wire
import (
"context"
"errors"
"io"

"github.com/jeroenrinzema/psql-wire/pkg/buffer"
"github.com/jeroenrinzema/psql-wire/pkg/types"
Expand Down Expand Up @@ -32,6 +33,14 @@ type DataWriter interface {
//
// [CommandComplete]: https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-COMMANDCOMPLETE
Complete(description string) error

// CopyIn sends a [CopyInResponse] to the client, to initiate a CopyIn
// operation. All format values must be either [TextFormat] or [BinaryFormat].
// When overallFormat is [TextFormat], all columnFormats must be [TextFormat]. When
// overallFormat is BinaryFormat, columnFormats may be either [TextFormat] or
// [BinaryFormat]. You must provide one columnFormat value for each column
// expected by the CopyIn operation.
CopyIn(overallFormat FormatCode, columnFormats []FormatCode) (io.Reader, error)
}

// ErrDataWritten is returned when an empty result is attempted to be sent to the
Expand All @@ -45,23 +54,25 @@ var ErrClosedWriter = errors.New("closed writer")
// buffer. The returned writer should be handled with caution as it is not safe
// for concurrent use. Concurrent access to the same data without proper
// synchronization can result in unexpected behavior and data corruption.
func NewDataWriter(ctx context.Context, columns Columns, formats []FormatCode, writer *buffer.Writer) DataWriter {
func NewDataWriter(ctx context.Context, columns Columns, formats []FormatCode, writer *buffer.Writer, copyData io.Reader) DataWriter {
return &dataWriter{
ctx: ctx,
columns: columns,
formats: formats,
client: writer,
ctx: ctx,
columns: columns,
formats: formats,
client: writer,
copyData: copyData,
}
}

// dataWriter is a implementation of the DataWriter interface.
type dataWriter struct {
ctx context.Context
columns Columns
formats []FormatCode
client *buffer.Writer
closed bool
written uint64
ctx context.Context
columns Columns
formats []FormatCode
client *buffer.Writer
closed bool
written uint64
copyData io.Reader
}

func (writer *dataWriter) Define(columns Columns) error {
Expand All @@ -83,6 +94,39 @@ func (writer *dataWriter) Row(values []any) error {
return writer.columns.Write(writer.ctx, writer.formats, writer.client, values)
}

func (writer *dataWriter) CopyIn(overallFormat FormatCode, columnFormats []FormatCode) (io.Reader, error) {
if writer.closed {
return nil, ErrClosedWriter
}
if writer.copyData == nil {
return nil, errors.New("DataCopyFn is nil; use PortalCacheCopy to execute CopyIn")
}
if len(columnFormats) == 0 {
return nil, errors.New("CopyIn must have at least one column")
}

if err := writer.sendCopyInResponse(overallFormat, columnFormats); err != nil {
return nil, err
}

return writer.copyData, nil
}

// sendCopyInResponse sends a [CopyInResponse] to the client, to initiate a
// CopyIn operation. format must be either [TextFormat] or [BinaryFormat], and
// columnCount must be >= 1.
//
// [CopyInResponse]: https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-COPYINRESPONSE
func (writer *dataWriter) sendCopyInResponse(format FormatCode, columnFormats []FormatCode) error {
writer.client.Start(types.ServerCopyInResponse)
writer.client.AddByte(byte(format))
writer.client.AddInt16(int16(len(columnFormats)))
for _, columnFormat := range columnFormats {
writer.client.AddInt16(int16(columnFormat))
}
return writer.client.End()
}

func (writer *dataWriter) Empty() error {
if writer.closed {
return ErrClosedWriter
Expand Down
Loading