Skip to content

Commit

Permalink
Add WAL message heartbeat (#4)
Browse files Browse the repository at this point in the history
* Add WAL message heartbeat

This PR introduces WAL message heartbeats for PG >= 14.  Roughly once a
minute we emit a WAL message via the query connection.  The message
carries nothing of interest; it only ensures that the WAL subscriber
receives messages and that the DB has activity.

We then force report the WAL position on heartbeats.

* Fix lints

* Don't heartbeat forever you muppet
  • Loading branch information
tonyhb authored Oct 12, 2024
1 parent cd89c86 commit ef02a7d
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 21 deletions.
5 changes: 5 additions & 0 deletions pkg/changeset/changeset.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ const (
OperationUpdate Operation = "UPDATE"
OperationDelete Operation = "DELETE"
OperationTruncate Operation = "TRUNCATE"

// OperationHeartbeat represents the changeset generated for heartbeats when we
// send messages to increase the WAL LSN. This is used for updating watermarks only,
// and should not process events.
OperationHeartbeat Operation = "HEARTBEAT"
)

// WatermarkCommitter is an interface that commits a given watermark to backing datastores.
Expand Down
2 changes: 2 additions & 0 deletions pkg/consts/pgconsts/pgconsts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ const (
Username = "inngest"
SlotName = "inngest_cdc"
PublicationName = "inngest"

MessagesVersion = 14
)
20 changes: 17 additions & 3 deletions pkg/decoder/pg_logical_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)

func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger) Decoder {
// NewV1LogicalDecoder builds a Decoder for protocol version 1 of the logical
// replication stream.
//
// s and log are stored on the decoder as-is.  messages records whether the
// decoder should ask the server to include logical WAL messages (the
// "messages 'true'" plugin argument); callers should only enable it on server
// versions that support that option.
func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger, messages bool) Decoder {
	dec := v1LogicalDecoder{
		schema:   s,
		log:      log,
		messages: messages,
		// Starts empty; the map is keyed by uint32.
		relations: map[uint32]*pglogrepl.RelationMessage{},
	}
	return dec
}

// v1LogicalDecoder is the Decoder returned by NewV1LogicalDecoder.
type v1LogicalDecoder struct {
	// log is the structured logger supplied at construction time.
	log *slog.Logger
	// schema is the schema loader supplied at construction time.
	schema *schema.PGXSchemaLoader
	// relations holds relation messages keyed by a uint32 — presumably the
	// relation OID; confirm against Decode.
	relations map[uint32]*pglogrepl.RelationMessage

	// messages controls whether ReplicationPluginArgs requests logical WAL
	// messages ("messages 'true'") from the server.
	messages bool
}

func (v1LogicalDecoder) ReplicationPluginArgs() []string {
func (v v1LogicalDecoder) ReplicationPluginArgs() []string {
// https://www.postgresql.org/docs/current/protocol-logical-replication.html#PROTOCOL-LOGICAL-REPLICATION-PARAMS
//
// "Proto_version '2'" with "streaming 'true' streams transactions as they're progressing.
Expand All @@ -37,10 +39,17 @@ func (v1LogicalDecoder) ReplicationPluginArgs() []string {
//
// Version 1 only sends DML entries when the transaction commits, ensuring that any event
// generated by Inngest is for a committed transaction.
if v.messages {
return []string{
"proto_version '1'",
fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName),
"messages 'true'", // Doesn't work for <= v13
}
}

return []string{
"proto_version '1'",
fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName),
// "messages 'true'", // Doesn't work for v12 and v13.
}
}

Expand All @@ -49,6 +58,11 @@ func (v v1LogicalDecoder) Decode(in []byte, cs *changeset.Changeset) (bool, erro
msgType := pglogrepl.MessageType(in[0])

switch msgType {
case pglogrepl.MessageTypeMessage:
// This is a heartbeat (or another WAL message). Do nothing but record
// the heartbeat and updated watermark.
cs.Operation = changeset.OperationHeartbeat
return true, nil
case pglogrepl.MessageTypeRelation:
// MessageTypeRelation describes the OIDs for any relation before DML messages are sent. From the docs:
//
Expand Down
103 changes: 85 additions & 18 deletions pkg/replicator/pgreplicator/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"strings"
"sync"
"sync/atomic"
"time"

Expand All @@ -24,8 +25,9 @@ import (
)

var (
ReadTimeout = time.Second * 5
CommitInterval = time.Second * 5
ReadTimeout = time.Second * 5
CommitInterval = time.Second * 5
DefaultHeartbeatTime = time.Minute
)

// PostgresReplicator is a Replicator with added postgres functionality.
Expand Down Expand Up @@ -61,6 +63,12 @@ type Opts struct {

// New returns a new postgres replicator for a single postgres database.
func New(ctx context.Context, opts Opts) (PostgresReplicator, error) {
if opts.Log == nil {
opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
}

cfg := opts.Config

// Ensure that we add "replication": "database" as a to the replication
Expand All @@ -84,24 +92,28 @@ func New(ctx context.Context, opts Opts) (PostgresReplicator, error) {
return nil, fmt.Errorf("error connecting to postgres host for schemas: %w", err)
}

// Query for current postgres version.
var version int
row := pgxc.QueryRow(ctx, "SELECT current_setting('server_version_num')::int / 10000;")
if err := row.Scan(&version); err != nil {
opts.Log.Warn("error querying for postgres version", "error", err)
}

sl := schema.NewPGXSchemaLoader(pgxc)
// Refresh all schemas to begin with
if err := sl.Refresh(); err != nil {
return nil, err
}

if opts.Log == nil {
opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
}

return &pg{
opts: opts,
conn: replConn,
queryConn: pgxc,
decoder: decoder.NewV1LogicalDecoder(sl, opts.Log),
log: opts.Log,
opts: opts,
conn: replConn,
queryConn: pgxc,
queryLock: &sync.Mutex{},
decoder: decoder.NewV1LogicalDecoder(sl, opts.Log, version >= pgconsts.MessagesVersion),
log: opts.Log,
version: version,
heartbeatTime: DefaultHeartbeatTime,
}, nil
}

Expand All @@ -111,8 +123,14 @@ type pg struct {
// conn is the WAL streaming connection. Once replication starts, this
// conn cannot be used for any queries.
conn *pgx.Conn

// queryConn is a connection used for querying data.
queryConn *pgx.Conn

// queryLock is used to lock pgx.Conn, as it's a single connection which cannot be used
// in parallel.
queryLock *sync.Mutex

// decoder decodes the binary WAL log
decoder decoder.Decoder
// nextReportTime records the time in which we must next report the current
Expand All @@ -125,6 +143,9 @@ type pg struct {
// log is a stdlib logger for reporting debug and warn logs.
log *slog.Logger

version int
heartbeatTime time.Duration

stopped int32
}

Expand All @@ -140,14 +161,20 @@ func (p *pg) Close(ctx context.Context) error {
}

// ReplicationSlot returns the current state of the replication slot.
//
// It first checks the server's WAL mode and returns
// ErrLogicalReplicationNotSetUp unless that mode is "logical".  Any error from
// the WAL mode lookup or from the slot query is returned unchanged.
func (p *pg) ReplicationSlot(ctx context.Context) (ReplicationSlot, error) {
	walLevel, err := p.walMode(ctx)
	switch {
	case err != nil:
		return ReplicationSlot{}, err
	case walLevel != "logical":
		return ReplicationSlot{}, ErrLogicalReplicationNotSetUp
	}

	// walMode takes and releases queryLock itself, so the lock may only be
	// acquired here, after that call has returned.  Hold it while the slot
	// data is read from the shared query connection.
	p.queryLock.Lock()
	defer p.queryLock.Unlock()

	return ReplicationSlotData(ctx, p.queryConn)
}

Expand Down Expand Up @@ -218,6 +245,29 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error {
// the DML.
unwrapper := &txnUnwrapper{cc: cc}

// Emit a WAL message on every heartbeat tick so that the subscriber keeps
// receiving data even when the database is otherwise idle.  The goroutine
// exits as soon as ctx is cancelled or the replicator is marked as stopped,
// and releases its ticker on the way out.
go func() {
	if p.version < pgconsts.MessagesVersion {
		// doesn't support wal messages; ignore.
		return
	}

	t := time.NewTicker(p.heartbeatTime)
	// Release the ticker when the goroutine exits.
	defer t.Stop()

	for {
		// Wait for either the next tick or cancellation, so that shutdown
		// does not have to wait for a full heartbeat interval.
		select {
		case <-ctx.Done():
			return
		case <-t.C:
		}

		// A tick and cancellation can race in the select above; re-check
		// both stop conditions (the same ones the Pull loop uses) before
		// touching the query connection.
		if ctx.Err() != nil || atomic.LoadInt32(&p.stopped) == 1 {
			return
		}

		// Send a heartbeat.  The query connection is shared and cannot be
		// used in parallel, so hold the lock only for the Exec call.
		p.queryLock.Lock()
		_, err := p.queryConn.Exec(ctx, "SELECT pg_logical_emit_message(false, 'heartbeat', now()::varchar);")
		p.queryLock.Unlock()

		if err != nil {
			p.log.Warn("unable to emit heartbeat", "error", err, "host", p.opts.Config.Host)
		}
	}
}()

for {
if ctx.Err() != nil || atomic.LoadInt32(&p.stopped) == 1 {
// Always call Close automatically.
Expand All @@ -233,6 +283,14 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error {
continue
}

if changes.Operation == changeset.OperationHeartbeat {
p.Commit(changes.Watermark)
if err := p.forceNextReport(ctx); err != nil {
p.log.Warn("unable to report lsn on heartbeat", "error", err, "host", p.opts.Config.Host)
}
continue
}

unwrapper.Process(changes)
}
}
Expand All @@ -259,7 +317,9 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {

if err != nil {
if pgconn.Timeout(err) {
p.forceNextReport()
if err := p.forceNextReport(ctx); err != nil {
p.log.Warn("unable to report lsn on timeout", "error", err, "host", p.opts.Config.Host)
}
// We return nil as we want to keep iterating.
return nil, nil
}
Expand Down Expand Up @@ -291,7 +351,9 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {
return nil, fmt.Errorf("error parsing replication keepalive: %w", err)
}
if pkm.ReplyRequested {
p.forceNextReport()
if err := p.forceNextReport(ctx); err != nil {
p.log.Warn("unable to report lsn on request", "error", err, "host", p.opts.Config.Host)
}
}
return nil, nil
case pglogrepl.XLogDataByteID:
Expand All @@ -316,6 +378,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {
if err != nil {
return nil, fmt.Errorf("error decoding xlog data: %w", err)
}

if !ok {
return nil, nil
}
Expand Down Expand Up @@ -348,10 +411,11 @@ func (p *pg) committedWatermark() (wm changeset.Watermark) {
}
}

func (p *pg) forceNextReport() {
// forceNextReport resets the next report time and then immediately calls
// report with its second argument set to true, returning whatever error
// report produces.
func (p *pg) forceNextReport(ctx context.Context) error {
	// The zero time is before any time.Now(), so a reset nextReportTime means
	// the report is always considered due.
	var due time.Time
	p.nextReportTime = due

	err := p.report(ctx, true)
	return err
}

// report reports the current replication slot's LSN progress to the server. We can optionally
Expand Down Expand Up @@ -384,6 +448,9 @@ func (p *pg) LSN() (lsn pglogrepl.LSN) {
}

func (p *pg) walMode(ctx context.Context) (string, error) {
p.queryLock.Lock()
defer p.queryLock.Unlock()

var mode string
row := p.queryConn.QueryRow(ctx, "SHOW wal_level")
err := row.Scan(&mode)
Expand All @@ -408,8 +475,8 @@ func ReplicationSlotData(ctx context.Context, conn *pgx.Conn) (ReplicationSlot,
row := conn.QueryRow(
ctx,
fmt.Sprintf(`SELECT
active, restart_lsn, confirmed_flush_lsn
FROM pg_replication_slots WHERE slot_name = '%s';`,
active, restart_lsn, confirmed_flush_lsn
FROM pg_replication_slots WHERE slot_name = '%s';`,
pgconsts.SlotName,
),
)
Expand Down
47 changes: 47 additions & 0 deletions pkg/replicator/pgreplicator/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,53 @@ func TestInsert(t *testing.T) {
}
}

// TestLogicalEmitHeartbeat checks that, with a short heartbeat interval, the
// replication slot's confirmed flush LSN advances over time even though the
// test writes no rows itself.  It runs against every Postgres version listed
// in the test, all of which are at or above pgconsts.MessagesVersion.
func TestLogicalEmitHeartbeat(t *testing.T) {
	t.Parallel()
	versions := []int{14, 15, 16}

	for _, v1 := range versions {
		v := v1 // loop capture
		t.Run(fmt.Sprintf("EmitHeartbeat - Postgres %d", v), func(t *testing.T) {
			t.Parallel()

			ctx, cancel := context.WithCancel(context.Background())

			c, conn := test.StartPG(t, ctx, test.StartPGOpts{Version: v})
			opts := Opts{Config: conn}
			repl, err := New(ctx, opts)
			// Check the error before touching repl: on failure New returns a
			// nil replicator and the type assertion below would panic.
			require.NoError(t, err)

			// heartbeat fast in tests.
			r := repl.(*pg)
			r.heartbeatTime = 250 * time.Millisecond

			cb := eventwriter.NewCallbackWriter(ctx, 1, time.Millisecond, func(batch []*changeset.Changeset) error {
				return nil
			})
			csChan := cb.Listen(ctx, r)

			go func() {
				// FailNow (used by require) must only be called from the
				// goroutine running the test, so report failures with Errorf.
				if err := r.Pull(ctx, csChan); err != nil {
					t.Errorf("unexpected error from Pull: %v", err)
				}
			}()

			slotA, err := r.ReplicationSlot(ctx)
			require.NoError(t, err)

			// Allow several heartbeat intervals to elapse.
			<-time.After(1100 * time.Millisecond)

			slotB, err := r.ReplicationSlot(ctx)
			require.NoError(t, err)

			// The heartbeats must have moved the confirmed LSN forwards.
			require.NotEqual(t, slotA.ConfirmedFlushLSN, slotB.ConfirmedFlushLSN)
			require.True(t, int(slotB.ConfirmedFlushLSN) > int(slotA.ConfirmedFlushLSN))

			cancel()
			// NOTE(review): ctx is already cancelled when Stop is called here;
			// confirm that the container is still stopped as intended.
			_ = c.Stop(ctx, nil)
		})
	}
}

func TestUpdateMany_ReplicaIdentityFull(t *testing.T) {
t.Parallel()
versions := []int{12, 13, 14, 15, 16}
Expand Down
6 changes: 6 additions & 0 deletions pkg/replicator/pgreplicator/txn_unwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ func (t *txnUnwrapper) Process(cs *changeset.Changeset) {
}

switch cs.Operation {
case changeset.OperationHeartbeat:
// The unwrapper should never receive heartbeats as the replicator should
// handle them and short circuit. However, always transmit them immediately
// for safety in code in case someone changes something in the future.
t.cc <- cs
return
case changeset.OperationBegin:
t.begin = cs
case changeset.OperationCommit:
Expand Down

0 comments on commit ef02a7d

Please sign in to comment.