-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpgxtester.go
108 lines (96 loc) · 2.76 KB
/
pgxtester.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
//nolint:ireturn
package main
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
// DBTX is the set of database operations that Connect hands back to a test.
//
// Connect returns a value that forwards these calls to a pgx transaction,
// which it rolls back in a test cleanup.
type DBTX interface {
// Exec runs a statement and returns its command tag.
Exec(context.Context, string, ...any) (pgconn.CommandTag, error)
// Query runs a query and returns the resulting rows.
Query(context.Context, string, ...any) (pgx.Rows, error)
// QueryRow runs a query and returns a single row to scan.
QueryRow(context.Context, string, ...any) pgx.Row
// CopyFrom loads rows from rowSrc into the given columns of tableName.
// The int64 result is presumably the number of rows copied; see the pgx docs.
CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
}
// Config controls how Connect obtains a connection and opens its transaction.
type Config struct {
// URL is the connection URL used when Conn is nil.
// If URL is empty as well, Connect reads the POSTGRES_URL environment variable.
URL string
// Conn is the connection to wrap. If nil, a new connection is established using URL.
Conn *pgx.Conn
// Timeout bounds each of: establishing the connection, creating the
// transaction, and rolling it back.
//
// Zero means the default of 2 seconds.
Timeout time.Duration
// TxOptions are the options passed to BeginTx when creating the transaction.
TxOptions pgx.TxOptions
}
// Connect returns a DBTX backed by a transaction that is rolled back
// automatically in a test cleanup, so changes made through it never persist.
//
// If c.Conn is nil, Connect opens a new connection to c.URL, falling back to
// the POSTGRES_URL environment variable when c.URL is empty. A connection
// opened here is closed during test cleanup; a connection supplied by the
// caller is left open. c.Timeout (2 seconds if zero) separately bounds
// connecting, beginning the transaction, rolling it back, and closing the
// connection.
//
// Failing to connect or to begin the transaction aborts the test with
// t.Fatalf. Failures during cleanup are reported with t.Errorf.
func Connect(t *testing.T, c Config) DBTX {
	t.Helper()

	if c.Timeout == 0 {
		c.Timeout = 2 * time.Second
	}

	if c.Conn == nil {
		if c.URL == "" {
			c.URL = os.Getenv("POSTGRES_URL")
		}
		ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
		defer cancel()
		conn, err := pgx.Connect(ctx, c.URL)
		if err != nil {
			t.Fatalf("failed to connect to PostgreSQL: %v", err)
		}
		// Close the connection we opened ourselves. This cleanup is registered
		// before the rollback cleanup below; cleanups run last-in, first-out,
		// so the transaction is rolled back before the connection is closed.
		t.Cleanup(func() {
			ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
			defer cancel()
			if err := conn.Close(ctx); err != nil {
				t.Errorf("failed to close PostgreSQL connection: %v", err)
			}
		})
		c.Conn = conn
	}

	ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
	defer cancel()
	tx, err := c.Conn.BeginTx(ctx, c.TxOptions)
	if err != nil {
		t.Fatalf("failed to begin transaction: %v", err)
	}

	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
		defer cancel()
		// Errorf rather than Fatalf: the test body has already finished, and
		// the failure only needs to be recorded.
		if err := tx.Rollback(ctx); err != nil {
			t.Errorf("failed to rollback transaction: %v", err)
		}
	})

	return &blockingDB{tx: tx}
}
// blockingDB is a wrapper around a DB connection that serializes calls with a
// mutex so that it can be shared between goroutines.
//
// It is similar to pgxpool except it keeps only a single connection,
// and so it can be safely rolled back. It is slower than pgxpool
// so it must be used only in tests.
//
// NOTE(review): the mutex is held only for the duration of each method call.
// Values returned by Query and QueryRow are consumed by the caller after the
// lock has been released — confirm whether that can overlap with another
// goroutine's call on the same connection.
type blockingDB struct {
// tx is the wrapped DBTX that every call is forwarded to.
tx DBTX
// mx is locked around each forwarded call. Because it is stored by value,
// a blockingDB must be used through a pointer and never copied.
mx sync.Mutex
}
// Exec forwards the statement to the wrapped DBTX while holding the mutex,
// so concurrent callers take turns on the single underlying connection.
func (db *blockingDB) Exec(ctx context.Context, q string, args ...any) (pgconn.CommandTag, error) {
	db.mx.Lock()
	defer db.mx.Unlock()

	tag, err := db.tx.Exec(ctx, q, args...)

	return tag, err
}
// Query forwards the query to the wrapped DBTX while holding the mutex and
// returns the resulting rows. Closing the rows is the caller's responsibility.
//
// NOTE(review): the mutex is released when this method returns, but the
// caller iterates the returned rows afterwards. Confirm whether another
// goroutine using this blockingDB in the meantime can hit a busy connection.
func (db *blockingDB) Query(ctx context.Context, q string, args ...any) (pgx.Rows, error) {
db.mx.Lock()
defer db.mx.Unlock()
return db.tx.Query(ctx, q, args...) //nolint:sqlclosecheck
}
// QueryRow forwards the query to the wrapped DBTX while holding the mutex and
// returns the row for the caller to scan.
//
// NOTE(review): the mutex is released before the caller invokes Scan on the
// returned row. Confirm whether Scan can overlap with another goroutine's
// call on the same connection.
func (db *blockingDB) QueryRow(ctx context.Context, q string, args ...any) pgx.Row {
	db.mx.Lock()
	defer db.mx.Unlock()

	row := db.tx.QueryRow(ctx, q, args...)

	return row
}
// CopyFrom forwards the bulk copy to the wrapped DBTX while holding the mutex
// and returns whatever count and error the wrapped call reports.
func (db *blockingDB) CopyFrom(
	ctx context.Context,
	tableName pgx.Identifier,
	columnNames []string,
	rowSrc pgx.CopyFromSource,
) (int64, error) {
	db.mx.Lock()
	defer db.mx.Unlock()

	copied, err := db.tx.CopyFrom(ctx, tableName, columnNames, rowSrc)

	return copied, err
}