From 176df247da874269c1a4f3fb7b85a3ff42fabd95 Mon Sep 17 00:00:00 2001 From: Tony Holdstock-Brown Date: Sat, 12 Oct 2024 00:35:56 +0000 Subject: [PATCH] Add WAL message heartbeat This PR introduces WAL message heartbeats for PG >= 14. The idea is, every ~minute we push a wal message via the query connection. This does nothing but ensure that the wal subscriber receives messages and the DB has activity. We then force report the WAL position on heartbeats. --- pkg/changeset/changeset.go | 5 + pkg/consts/pgconsts/pgconsts.go | 2 + pkg/decoder/pg_logical_v1.go | 20 +++- pkg/replicator/pgreplicator/pg.go | 106 +++++++++++++++---- pkg/replicator/pgreplicator/pg_test.go | 47 ++++++++ pkg/replicator/pgreplicator/txn_unwrapper.go | 6 ++ 6 files changed, 163 insertions(+), 23 deletions(-) diff --git a/pkg/changeset/changeset.go b/pkg/changeset/changeset.go index 833f2f4..33caf1c 100644 --- a/pkg/changeset/changeset.go +++ b/pkg/changeset/changeset.go @@ -16,6 +16,11 @@ const ( OperationUpdate Operation = "UPDATE" OperationDelete Operation = "DELETE" OperationTruncate Operation = "TRUNCATE" + + // OperationHeartbeat represents the changeset generated for heartbeats when we + // send messages to increase the WAL LSN. This is used for updating watermarks only, + // and should not process events. + OperationHeartbeat Operation = "HEARTBEAT" ) // WatermarkCommitter is an interface that commits a given watermark to backing datastores. diff --git a/pkg/consts/pgconsts/pgconsts.go b/pkg/consts/pgconsts/pgconsts.go index 14e0ebd..7d0def3 100644 --- a/pkg/consts/pgconsts/pgconsts.go +++ b/pkg/consts/pgconsts/pgconsts.go @@ -4,4 +4,6 @@ const ( Username = "inngest" SlotName = "inngest_cdc" PublicationName = "inngest" + + MessagesVersion = 14 ) diff --git a/pkg/decoder/pg_logical_v1.go b/pkg/decoder/pg_logical_v1.go index b978776..7224b84 100644 --- a/pkg/decoder/pg_logical_v1.go +++ b/pkg/decoder/pg_logical_v1.go @@ -12,10 +12,11 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger) Decoder { +func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger, messages bool) Decoder { return v1LogicalDecoder{ log: log, schema: s, + messages: messages, relations: make(map[uint32]*pglogrepl.RelationMessage), } } @@ -23,11 +24,12 @@ func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger) Decoder { type v1LogicalDecoder struct { log *slog.Logger + messages bool schema *schema.PGXSchemaLoader relations map[uint32]*pglogrepl.RelationMessage } -func (v1LogicalDecoder) ReplicationPluginArgs() []string { +func (v v1LogicalDecoder) ReplicationPluginArgs() []string { // https://www.postgresql.org/docs/current/protocol-logical-replication.html#PROTOCOL-LOGICAL-REPLICATION-PARAMS // // "Proto_version '2'" with "streaming 'true' streams transactions as they're progressing. @@ -37,10 +39,17 @@ func (v1LogicalDecoder) ReplicationPluginArgs() []string { // // Version 1 only sends DML entries when the transaction commits, ensuring that any event // generated by Inngest is for a committed transaction. + if v.messages { + return []string{ + "proto_version '1'", + fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName), + "messages 'true'", // Doesn't work for <= v13 + } + } + return []string{ "proto_version '1'", fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName), - // "messages 'true'", // Doesn't work for v12 and v13. } } @@ -49,6 +58,11 @@ func (v v1LogicalDecoder) Decode(in []byte, cs *changeset.Changeset) (bool, erro msgType := pglogrepl.MessageType(in[0]) switch msgType { + case pglogrepl.MessageTypeMessage: + // This is a heartbeat (or another WAL message). Do nothing but record + // the heartbeat and updated watermark. + cs.Operation = changeset.OperationHeartbeat + return true, nil case pglogrepl.MessageTypeRelation: // MessageTypeRelation describes the OIDs for any relation before DML messages are sent. From the docs: // diff --git a/pkg/replicator/pgreplicator/pg.go b/pkg/replicator/pgreplicator/pg.go index 4735a27..3f47225 100644 --- a/pkg/replicator/pgreplicator/pg.go +++ b/pkg/replicator/pgreplicator/pg.go @@ -9,6 +9,7 @@ import ( "log/slog" "os" "strings" + "sync" "sync/atomic" "time" @@ -24,8 +25,9 @@ import ( ) var ( - ReadTimeout = time.Second * 5 - CommitInterval = time.Second * 5 + ReadTimeout = time.Second * 5 + CommitInterval = time.Second * 5 + DefaultHeartbeatTime = time.Minute ) // PostgresReplicator is a Replicator with added postgres functionality. @@ -61,6 +63,12 @@ type Opts struct { // New returns a new postgres replicator for a single postgres database. func New(ctx context.Context, opts Opts) (PostgresReplicator, error) { + if opts.Log == nil { + opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + } + cfg := opts.Config // Ensure that we add "replication": "database" as a to the replication @@ -84,24 +92,28 @@ func New(ctx context.Context, opts Opts) (PostgresReplicator, error) { return nil, fmt.Errorf("error connecting to postgres host for schemas: %w", err) } + // Query for current postgres version. + var version int + row := pgxc.QueryRow(ctx, "SELECT current_setting('server_version_num')::int / 10000;") + if err := row.Scan(&version); err != nil { + opts.Log.Warn("error querying for postgres version", "error", err) + } + sl := schema.NewPGXSchemaLoader(pgxc) // Refresh all schemas to begin with if err := sl.Refresh(); err != nil { return nil, err } - if opts.Log == nil { - opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ - Level: slog.LevelInfo, - })) - } - return &pg{ - opts: opts, - conn: replConn, - queryConn: pgxc, - decoder: decoder.NewV1LogicalDecoder(sl, opts.Log), - log: opts.Log, + opts: opts, + conn: replConn, + queryConn: pgxc, + queryLock: &sync.Mutex{}, + decoder: decoder.NewV1LogicalDecoder(sl, opts.Log, version >= pgconsts.MessagesVersion), + log: opts.Log, + version: version, + heartbeatTime: DefaultHeartbeatTime, }, nil } @@ -111,8 +123,14 @@ type pg struct { // conn is the WAL streaming connection. Once replication starts, this // conn cannot be used for any queries. conn *pgx.Conn + // queryCon is a conn for querying data. queryConn *pgx.Conn + + // queryLock is used to lock pgx.Conn, as it's a single connection which cannot be used + // in parallel. + queryLock *sync.Mutex + // decoder decodes the binary WAL log decoder decoder.Decoder // nextReportTime records the time in which we must next report the current @@ -125,6 +143,9 @@ type pg struct { // log is a stdlib logger for reporting debug and warn logs. log *slog.Logger + version int + heartbeatTime time.Duration + stopped int32 } @@ -140,14 +161,20 @@ func (p *pg) Close(ctx context.Context) error { } func (p *pg) ReplicationSlot(ctx context.Context) (ReplicationSlot, error) { + mode, err := p.walMode(ctx) if err != nil { return ReplicationSlot{}, err } + if mode != "logical" { return ReplicationSlot{}, ErrLogicalReplicationNotSetUp } + // Lock when querying repl slot data. + p.queryLock.Lock() + defer p.queryLock.Unlock() + return ReplicationSlotData(ctx, p.queryConn) } @@ -218,6 +245,25 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error { // the DML. unwrapper := &txnUnwrapper{cc: cc} + go func() { + if p.version < pgconsts.MessagesVersion { + // doesn't support wal messages; ignore. + return + } + + t := time.NewTicker(p.heartbeatTime) + for range t.C { + // Send a hearbeat every minute + p.queryLock.Lock() + _, err := p.queryConn.Exec(ctx, "SELECT pg_logical_emit_message(false, 'heartbeat', now()::varchar);") + p.queryLock.Unlock() + + if err != nil { + p.log.Warn("unable to emit heartbeat", "error", err, "host", p.opts.Config.Host) + } + } + }() + for { if ctx.Err() != nil || atomic.LoadInt32(&p.stopped) == 1 { // Always call Close automatically. @@ -233,6 +279,12 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error { continue } + if changes.Operation == changeset.OperationHeartbeat { + p.Commit(changes.Watermark) + p.forceNextReport(ctx) + continue + } + unwrapper.Process(changes) } } @@ -259,7 +311,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) { if err != nil { if pgconn.Timeout(err) { - p.forceNextReport() + p.forceNextReport(ctx) // We return nil as we want to keep iterating. return nil, nil } @@ -291,7 +343,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) { return nil, fmt.Errorf("error parsing replication keepalive: %w", err) } if pkm.ReplyRequested { - p.forceNextReport() + p.forceNextReport(ctx) } return nil, nil case pglogrepl.XLogDataByteID: @@ -316,6 +368,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) { if err != nil { return nil, fmt.Errorf("error decoding xlog data: %w", err) } + if !ok { return nil, nil } @@ -348,10 +401,11 @@ func (p *pg) committedWatermark() (wm changeset.Watermark) { } } -func (p *pg) forceNextReport() { +func (p *pg) forceNextReport(ctx context.Context) { // Updating the next report time to a zero time always reports the LSN, // as time.Now() is always after the empty time. p.nextReportTime = time.Time{} + p.report(ctx, true) } // report reports the current replication slot's LSN progress to the server. We can optionally @@ -384,6 +438,9 @@ func (p *pg) LSN() (lsn pglogrepl.LSN) { } func (p *pg) walMode(ctx context.Context) (string, error) { + p.queryLock.Lock() + defer p.queryLock.Unlock() + var mode string row := p.queryConn.QueryRow(ctx, "SHOW wal_level") err := row.Scan(&mode) @@ -405,15 +462,24 @@ type ReplicationSlot struct { func ReplicationSlotData(ctx context.Context, conn *pgx.Conn) (ReplicationSlot, error) { ret := ReplicationSlot{} - row := conn.QueryRow( + rows, err := conn.Query( ctx, fmt.Sprintf(`SELECT - active, restart_lsn, confirmed_flush_lsn - FROM pg_replication_slots WHERE slot_name = '%s';`, + active, restart_lsn, confirmed_flush_lsn + FROM pg_replication_slots WHERE slot_name = '%s';`, pgconsts.SlotName, ), ) - err := row.Scan(&ret.Active, &ret.RestartLSN, &ret.ConfirmedFlushLSN) + defer rows.Close() + if err != nil { + return ReplicationSlot{}, err + } + + if !rows.Next() { + return ReplicationSlot{}, ErrReplicationSlotNotFound + } + + err = rows.Scan(&ret.Active, &ret.RestartLSN, &ret.ConfirmedFlushLSN) // pgx has its own ErrNoRows :( if errors.Is(err, sql.ErrNoRows) || errors.Is(err, pgx.ErrNoRows) { return ret, ErrReplicationSlotNotFound diff --git a/pkg/replicator/pgreplicator/pg_test.go b/pkg/replicator/pgreplicator/pg_test.go index 900d071..60004bc 100644 --- a/pkg/replicator/pgreplicator/pg_test.go +++ b/pkg/replicator/pgreplicator/pg_test.go @@ -214,6 +214,53 @@ func TestInsert(t *testing.T) { } } +func TestLogicalEmitHeartbeat(t *testing.T) { + t.Parallel() + versions := []int{14, 15, 16} + + for _, v1 := range versions { + v := v1 // loop capture + t.Run(fmt.Sprintf("EmitHeartbeat - Postgres %d", v), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + + c, conn := test.StartPG(t, ctx, test.StartPGOpts{Version: v}) + opts := Opts{Config: conn} + repl, err := New(ctx, opts) + + // heartbeat fast in tests. + r := repl.(*pg) + r.heartbeatTime = 250 * time.Millisecond + require.NoError(t, err) + + cb := eventwriter.NewCallbackWriter(ctx, 1, time.Millisecond, func(batch []*changeset.Changeset) error { + return nil + }) + csChan := cb.Listen(ctx, r) + + go func() { + err := r.Pull(ctx, csChan) + require.NoError(t, err) + }() + + slotA, err := r.ReplicationSlot(ctx) + require.NoError(t, err) + + <-time.After(1100 * time.Millisecond) + + slotB, err := r.ReplicationSlot(ctx) + require.NoError(t, err) + + require.NotEqual(t, slotA.ConfirmedFlushLSN, slotB.ConfirmedFlushLSN) + require.True(t, int(slotB.ConfirmedFlushLSN) > int(slotA.ConfirmedFlushLSN)) + + cancel() + _ = c.Stop(ctx, nil) + }) + } +} + func TestUpdateMany_ReplicaIdentityFull(t *testing.T) { t.Parallel() versions := []int{12, 13, 14, 15, 16} diff --git a/pkg/replicator/pgreplicator/txn_unwrapper.go b/pkg/replicator/pgreplicator/txn_unwrapper.go index 60062f5..1906d2d 100644 --- a/pkg/replicator/pgreplicator/txn_unwrapper.go +++ b/pkg/replicator/pgreplicator/txn_unwrapper.go @@ -30,6 +30,12 @@ func (t *txnUnwrapper) Process(cs *changeset.Changeset) { } switch cs.Operation { + case changeset.OperationHeartbeat: + // The unwrapper should never receive heartbeats as the replicator should + // handle them and short circuit. However, always transmit them immediately + // for safety in code in case someone changes something in the future. + t.cc <- cs + return case changeset.OperationBegin: t.begin = cs case changeset.OperationCommit: