Skip to content

Commit

Permalink
Merge branch 'main' into usebingo
Browse files Browse the repository at this point in the history
  • Loading branch information
kishaningithub authored Jun 5, 2023
2 parents c9e89ac + e7647ad commit 32d2e95
Show file tree
Hide file tree
Showing 26 changed files with 452 additions and 294 deletions.
15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/jeroenrinzema/psql-wire.svg)](https://pkg.go.dev/github.com/jeroenrinzema/psql-wire) [![Latest release](https://img.shields.io/github/release/jeroenrinzema/psql-wire.svg)](https://github.com/jeroenrinzema/psql-wire/releases) [![Go Report Card](https://goreportcard.com/badge/github.com/jeroenrinzema/psql-wire)](https://goreportcard.com/report/github.com/jeroenrinzema/psql-wire)

A pure Go [PostgreSQL](https://www.postgresql.org/) server wire protocol implementation.
Build your own PostgreSQL server with 15 lines of code.
Build your own PostgreSQL server within a few lines of code.
This project attempts to make it as straight forward as possible to set-up and configure your own PSQL server.
Feel free to check out the [examples](https://github.com/jeroenrinzema/psql-wire/tree/main/examples) directory for various ways on how to configure/set-up your own server.

Expand All @@ -21,10 +21,17 @@ import (
)

func main() {
wire.ListenAndServe("127.0.0.1:5432", func(ctx context.Context, query string, writer wire.DataWriter, parameters []string) error {
fmt.Println(query)
wire.ListenAndServe("127.0.0.1:5432", handler)
}

func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) {
fmt.Println(query)

statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error {
return writer.Complete("OK")
})
}

return statement, nil, nil, nil
}
```

Expand Down
12 changes: 6 additions & 6 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ func TestDefaultHandleAuth(t *testing.T) {
sink := bytes.NewBuffer([]byte{})

ctx := context.Background()
reader := buffer.NewReader(input, buffer.DefaultBufferSize)
writer := buffer.NewWriter(sink)
reader := buffer.NewReader(zap.NewNop(), input, buffer.DefaultBufferSize)
writer := buffer.NewWriter(zap.NewNop(), sink)

server := &Server{logger: zap.NewNop()}
err := server.handleAuth(ctx, reader, writer)
if err != nil {
t.Fatal(err)
}

result := buffer.NewReader(sink, buffer.DefaultBufferSize)
result := buffer.NewReader(zap.NewNop(), sink, buffer.DefaultBufferSize)
ty, ln, err := result.ReadTypedMsg()
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -54,7 +54,7 @@ func TestClearTextPassword(t *testing.T) {
expected := "password"

input := bytes.NewBuffer([]byte{})
incoming := buffer.NewWriter(input)
incoming := buffer.NewWriter(zap.NewNop(), input)

// NOTE: we could reuse the server buffered writer to write client messages
incoming.Start(types.ServerMessage(types.ClientPassword))
Expand All @@ -73,8 +73,8 @@ func TestClearTextPassword(t *testing.T) {
sink := bytes.NewBuffer([]byte{})

ctx := context.Background()
reader := buffer.NewReader(input, buffer.DefaultBufferSize)
writer := buffer.NewWriter(sink)
reader := buffer.NewReader(zap.NewNop(), input, buffer.DefaultBufferSize)
writer := buffer.NewWriter(zap.NewNop(), sink)

server := &Server{logger: zap.NewNop(), Auth: ClearTextPassword(validate)}
err := server.handleAuth(ctx, reader, writer)
Expand Down
59 changes: 47 additions & 12 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,61 @@ package wire
import (
"context"
"sync"

"github.com/jeroenrinzema/psql-wire/internal/buffer"
"github.com/lib/pq/oid"
)

type Statement struct {
fn PreparedStatementFn
parameters []oid.Oid
columns Columns
}

type DefaultStatementCache struct {
statements map[string]PreparedStatementFn
statements map[string]*Statement
mu sync.RWMutex
}

// Set attempts to bind the given statement to the given name. Any
// previously defined statement is overridden.
func (cache *DefaultStatementCache) Set(ctx context.Context, name string, fn PreparedStatementFn) error {
func (cache *DefaultStatementCache) Set(ctx context.Context, name string, fn PreparedStatementFn, parameters []oid.Oid, columns Columns) error {
cache.mu.Lock()
defer cache.mu.Unlock()

if cache.statements == nil {
cache.statements = map[string]PreparedStatementFn{}
cache.statements = map[string]*Statement{}
}

cache.statements[name] = &Statement{
fn: fn,
parameters: parameters,
columns: columns,
}

cache.statements[name] = fn
return nil
}

// Get attempts to get the prepared statement for the given name. An error
// is returned when no statement has been found.
func (cache *DefaultStatementCache) Get(ctx context.Context, name string) (PreparedStatementFn, error) {
func (cache *DefaultStatementCache) Get(ctx context.Context, name string) (*Statement, error) {
cache.mu.RLock()
defer cache.mu.RUnlock()

if cache.statements == nil {
return nil, nil
}

return cache.statements[name], nil
stmt, has := cache.statements[name]
if !has {
return nil, nil
}

return stmt, nil
}

type portal struct {
statement PreparedStatementFn
statement *Statement
parameters []string
}

Expand All @@ -47,7 +66,7 @@ type DefaultPortalCache struct {
mu sync.RWMutex
}

func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, fn PreparedStatementFn, parametes []string) error {
func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []string) error {
cache.mu.Lock()
defer cache.mu.Unlock()

Expand All @@ -56,14 +75,30 @@ func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, fn Prepa
}

cache.portals[name] = portal{
statement: fn,
parameters: parametes,
statement: stmt,
parameters: parameters,
}

return nil
}

func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer DataWriter) error {
func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Statement, error) {
cache.mu.Lock()
defer cache.mu.Unlock()

if cache.portals == nil {
return nil, nil
}

portal, has := cache.portals[name]
if !has {
return nil, nil
}

return portal.statement, nil
}

func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer *buffer.Writer) error {
cache.mu.Lock()
defer cache.mu.Unlock()

Expand All @@ -72,5 +107,5 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, write
return nil
}

return portal.statement(ctx, writer, portal.parameters)
return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, writer), portal.parameters)
}
Loading

0 comments on commit 32d2e95

Please sign in to comment.