Skip to content

Commit

Permalink
pgxstore: Implement CtxStore interface
Browse files Browse the repository at this point in the history
- Followed the same patterns at the goredisstore module.
  • Loading branch information
pete-woods committed Apr 9, 2024
1 parent 7e11d57 commit b7e39e0
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 45 deletions.
47 changes: 34 additions & 13 deletions pgxstore/pgxstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgxstore

import (
"context"
"errors"
"log"
"time"

Expand Down Expand Up @@ -33,42 +34,42 @@ func NewWithCleanupInterval(pool *pgxpool.Pool, cleanupInterval time.Duration) *
return p
}

// Find returns the data for a given session token from the PostgresStore instance.
// FindCtx returns the data for a given session token from the PostgresStore instance.
// If the session token is not found or is expired, the returned exists flag will
// be set to false.
func (p *PostgresStore) Find(token string) (b []byte, exists bool, err error) {
row := p.pool.QueryRow(context.Background(), "SELECT data FROM sessions WHERE token = $1 AND current_timestamp < expiry", token)
func (p *PostgresStore) FindCtx(ctx context.Context, token string) (b []byte, found bool, err error) {
row := p.pool.QueryRow(ctx, "SELECT data FROM sessions WHERE token = $1 AND current_timestamp < expiry", token)
err = row.Scan(&b)
if err == pgx.ErrNoRows {
if errors.Is(err, pgx.ErrNoRows) {
return nil, false, nil
} else if err != nil {
return nil, false, err
}
return b, true, nil
}

// Commit adds a session token and data to the PostgresStore instance with the
// CommitCtx adds a session token and data to the PostgresStore instance with the
// given expiry time. If the session token already exists, then the data and expiry
// time are updated.
func (p *PostgresStore) Commit(token string, b []byte, expiry time.Time) error {
_, err := p.pool.Exec(context.Background(), "INSERT INTO sessions (token, data, expiry) VALUES ($1, $2, $3) ON CONFLICT (token) DO UPDATE SET data = EXCLUDED.data, expiry = EXCLUDED.expiry", token, b, expiry)
func (p *PostgresStore) CommitCtx(ctx context.Context, token string, b []byte, expiry time.Time) (err error) {
_, err = p.pool.Exec(ctx, "INSERT INTO sessions (token, data, expiry) VALUES ($1, $2, $3) ON CONFLICT (token) DO UPDATE SET data = EXCLUDED.data, expiry = EXCLUDED.expiry", token, b, expiry)
if err != nil {
return err
}
return nil
}

// Delete removes a session token and corresponding data from the PostgresStore
// DeleteCtx removes a session token and corresponding data from the PostgresStore
// instance.
func (p *PostgresStore) Delete(token string) error {
_, err := p.pool.Exec(context.Background(), "DELETE FROM sessions WHERE token = $1", token)
func (p *PostgresStore) DeleteCtx(ctx context.Context, token string) (err error) {
_, err = p.pool.Exec(ctx, "DELETE FROM sessions WHERE token = $1", token)
return err
}

// All returns a map containing the token and data for all active (i.e.
// AllCtx returns a map containing the token and data for all active (i.e.
// not expired) sessions in the PostgresStore instance.
func (p *PostgresStore) All() (map[string][]byte, error) {
rows, err := p.pool.Query(context.Background(), "SELECT token, data FROM sessions WHERE current_timestamp < expiry")
func (p *PostgresStore) AllCtx(ctx context.Context) (map[string][]byte, error) {
rows, err := p.pool.Query(ctx, "SELECT token, data FROM sessions WHERE current_timestamp < expiry")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -135,3 +136,23 @@ func (p *PostgresStore) deleteExpired() error {
_, err := p.pool.Exec(context.Background(), "DELETE FROM sessions WHERE expiry < current_timestamp")
return err
}

// We have to add the plain Store methods here to be recognized a Store
// by the go compiler. Not using a separate type makes any errors caught
// only at runtime instead of compile time.

func (p *PostgresStore) Find(token string) (b []byte, exists bool, err error) {
panic("missing context arg")
}

func (p *PostgresStore) Commit(token string, b []byte, expiry time.Time) error {
panic("missing context arg")
}

func (p *PostgresStore) Delete(token string) error {
panic("missing context arg")
}

func (p *PostgresStore) All() (map[string][]byte, error) {
panic("missing context arg")
}
79 changes: 47 additions & 32 deletions pgxstore/pgxstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@ import (
)

func TestFind(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}
_, err = pool.Exec(context.Background(), "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
_, err = pool.Exec(ctx, "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

b, found, err := p.Find("session_token")
b, found, err := p.FindCtx(ctx, "session_token")
if err != nil {
t.Fatal(err)
}
Expand All @@ -43,21 +45,23 @@ func TestFind(t *testing.T) {
}

func TestFindMissing(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

_, found, err := p.Find("missing_session_token")
_, found, err := p.FindCtx(ctx, "missing_session_token")
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
Expand All @@ -67,26 +71,28 @@ func TestFindMissing(t *testing.T) {
}

func TestSaveNew(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
err = p.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}

row := pool.QueryRow(context.Background(), "SELECT data FROM sessions WHERE token = 'session_token'")
row := pool.QueryRow(ctx, "SELECT data FROM sessions WHERE token = 'session_token'")
var data []byte
err = row.Scan(&data)
if err != nil {
Expand All @@ -98,30 +104,32 @@ func TestSaveNew(t *testing.T) {
}

func TestSaveUpdated(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}
_, err = pool.Exec(context.Background(), "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
_, err = pool.Exec(ctx, "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

err = p.Commit("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
err = p.CommitCtx(ctx, "session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}

row := pool.QueryRow(context.Background(), "SELECT data FROM sessions WHERE token = 'session_token'")
row := pool.QueryRow(ctx, "SELECT data FROM sessions WHERE token = 'session_token'")
var data []byte
err = row.Scan(&data)
if err != nil {
Expand All @@ -133,62 +141,67 @@ func TestSaveUpdated(t *testing.T) {
}

func TestExpiry(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)
p := NewWithCleanupInterval(pool, 10*time.Millisecond)
defer p.StopCleanup()

err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err = p.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}

_, found, _ := p.Find("session_token")
_, found, _ := p.FindCtx(ctx, "session_token")
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}

time.Sleep(100 * time.Millisecond)
_, found, _ = p.Find("session_token")
_, found, _ = p.FindCtx(ctx, "session_token")
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}

func TestDelete(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}
_, err = pool.Exec(context.Background(), "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
_, err = pool.Exec(ctx, "INSERT INTO sessions VALUES('session_token', 'encoded_data', current_timestamp + interval '1 minute')")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 0)

err = p.Delete("session_token")
err = p.DeleteCtx(ctx, "session_token")
if err != nil {
t.Fatal(err)
}

row := pool.QueryRow(context.Background(), "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
var count int
err = row.Scan(&count)
if err != nil {
Expand All @@ -200,27 +213,29 @@ func TestDelete(t *testing.T) {
}

func TestCleanup(t *testing.T) {
ctx := context.Background()

dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
pool, err := pgxpool.New(context.Background(), dsn)
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
t.Fatal(err)
}
defer pool.Close()

_, err = pool.Exec(context.Background(), "TRUNCATE TABLE sessions")
_, err = pool.Exec(ctx, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}

p := NewWithCleanupInterval(pool, 200*time.Millisecond)
defer p.StopCleanup()

err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err = p.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}

row := pool.QueryRow(context.Background(), "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
row := pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
var count int
err = row.Scan(&count)
if err != nil {
Expand All @@ -231,7 +246,7 @@ func TestCleanup(t *testing.T) {
}

time.Sleep(300 * time.Millisecond)
row = pool.QueryRow(context.Background(), "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
row = pool.QueryRow(ctx, "SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
err = row.Scan(&count)
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit b7e39e0

Please sign in to comment.