From 24be6e30bea84b254ae602aed54f490794b0cd4f Mon Sep 17 00:00:00 2001 From: Shayon Mukherjee Date: Sat, 16 Nov 2024 12:06:43 -0500 Subject: [PATCH] Simple factory pattern (#51) --- cmd/root.go | 17 ++- internal/scripts/e2e_test_local.sh | 6 +- pkg/replicator/base_replicator.go | 11 +- pkg/replicator/factory.go | 111 ++++++++++++++++++ pkg/replicator/replicator.go | 72 ------------ pkg/replicator/tests/base_replicator_test.go | 37 ++++-- .../tests/copy_and_stream_replicator_test.go | 6 +- pkg/replicator/tests/ddl_replicator_test.go | 7 +- pkg/sinks/postgres.go | 32 ++--- pkg/utils/shared_types.go | 23 ++++ pkg/utils/zerolog_logger.go | 95 +++++++++++++++ pkg/worker/worker.go | 4 +- 12 files changed, 297 insertions(+), 124 deletions(-) create mode 100644 pkg/replicator/factory.go delete mode 100644 pkg/replicator/replicator.go create mode 100644 pkg/utils/zerolog_logger.go diff --git a/cmd/root.go b/cmd/root.go index 3e5ca86..285a877 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -280,7 +280,22 @@ func runReplicator(_ *cobra.Command, _ []string) { maxCopyWorkersPerTable := viper.GetInt("max-copy-workers-per-table") - rep, err := replicator.NewReplicator(config, natsClient, copyAndStream, copyMode, maxCopyWorkersPerTable) + var factory replicator.Factory + + if copyMode { + factory = &replicator.CopyAndStreamReplicatorFactory{ + MaxCopyWorkersPerTable: maxCopyWorkersPerTable, + CopyOnly: true, + } + } else if copyAndStream { + factory = &replicator.CopyAndStreamReplicatorFactory{ + MaxCopyWorkersPerTable: maxCopyWorkersPerTable, + } + } else { + factory = &replicator.StreamReplicatorFactory{} + } + + rep, err := factory.CreateReplicator(config, natsClient) if err != nil { log.Fatal().Err(err).Msg("Failed to create replicator") } diff --git a/internal/scripts/e2e_test_local.sh b/internal/scripts/e2e_test_local.sh index 8cd3bae..596e858 100755 --- a/internal/scripts/e2e_test_local.sh +++ b/internal/scripts/e2e_test_local.sh @@ -33,9 +33,9 @@ make build setup_docker -log "Running e2e transform filter tests..." -if CI=false ./internal/scripts/e2e_transform_filter.sh; then - success "e2e transform filter tests completed successfully" +log "Running e2e postgres tests..." +if CI=false ./internal/scripts/e2e_postgres.sh; then + success "e2e postgres tests completed successfully" else error "Original e2e tests failed" exit 1 diff --git a/pkg/replicator/base_replicator.go b/pkg/replicator/base_replicator.go index b38f737..6dade52 100644 --- a/pkg/replicator/base_replicator.go +++ b/pkg/replicator/base_replicator.go @@ -3,7 +3,6 @@ package replicator import ( "context" "fmt" - "os" "strings" "time" @@ -13,15 +12,9 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgproto3" "github.com/pgflo/pg_flo/pkg/utils" - "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) -func init() { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05.000"}) - zerolog.TimeFieldFormat = "2006-01-02T15:04:05.000Z07:00" -} - // GeneratePublicationName generates a deterministic publication name based on the group name func GeneratePublicationName(group string) string { group = strings.ReplaceAll(group, "-", "_") @@ -34,7 +27,7 @@ type BaseReplicator struct { ReplicationConn ReplicationConnection StandardConn StandardConnection Relations map[uint32]*pglogrepl.RelationMessage - Logger zerolog.Logger + Logger utils.Logger TableDetails map[string][]string LastLSN pglogrepl.LSN NATSClient NATSClient @@ -47,7 +40,7 @@ func NewBaseReplicator(config Config, replicationConn ReplicationConnection, sta config.Schema = "public" } - logger := log.With().Str("component", "replicator").Logger() + logger := utils.NewZerologLogger(log.With().Str("component", "replicator").Logger()) br := &BaseReplicator{ Config: config, diff --git a/pkg/replicator/factory.go b/pkg/replicator/factory.go new file mode 100644 index 0000000..974a7b0 --- /dev/null +++ b/pkg/replicator/factory.go @@ -0,0 +1,111 @@ +package replicator + +import ( + "context" + "fmt" +) + +// ReplicatorFactory defines the interface for creating replicators +type Factory interface { + CreateReplicator(config Config, natsClient NATSClient) (Replicator, error) +} + +// BaseFactory provides common functionality for factories +type BaseFactory struct{} + +// CreateConnections creates replication and standard connections +func (f *BaseFactory) CreateConnections(config Config) (ReplicationConnection, StandardConnection, error) { + replicationConn := NewReplicationConnection(config) + if err := replicationConn.Connect(context.Background()); err != nil { + return nil, nil, fmt.Errorf("failed to connect for replication: %v", err) + } + + standardConn, err := NewStandardConnection(config) + if err != nil { + return nil, nil, fmt.Errorf("failed to create standard connection: %v", err) + } + + return replicationConn, standardConn, nil +} + +// StreamReplicatorFactory creates `StreamReplicator` instances +type StreamReplicatorFactory struct { + BaseFactory +} + +// CreateReplicator creates a new `StreamReplicator` +func (f *StreamReplicatorFactory) CreateReplicator(config Config, natsClient NATSClient) (Replicator, error) { + replicationConn, standardConn, err := f.CreateConnections(config) + if err != nil { + return nil, err + } + + baseReplicator := NewBaseReplicator(config, replicationConn, standardConn, natsClient) + + var ddlReplicator *DDLReplicator + if config.TrackDDL { + ddlConn, err := NewStandardConnection(config) + if err != nil { + return nil, fmt.Errorf("failed to create DDL connection: %v", err) + } + ddlReplicator, err = NewDDLReplicator(config, baseReplicator, ddlConn) + if err != nil { + return nil, fmt.Errorf("failed to create DDL replicator: %v", err) + } + } + + streamReplicator := &StreamReplicator{ + BaseReplicator: *baseReplicator, + } + + if ddlReplicator != nil { + streamReplicator.DDLReplicator = *ddlReplicator + } + + return streamReplicator, nil +} + +// CopyAndStreamReplicatorFactory creates `CopyAndStreamReplicator` instances +type CopyAndStreamReplicatorFactory struct { + BaseFactory + MaxCopyWorkersPerTable int + CopyOnly bool +} + +// CreateReplicator creates a new `CopyAndStreamReplicator` +func (f *CopyAndStreamReplicatorFactory) CreateReplicator(config Config, natsClient NATSClient) (Replicator, error) { + replicationConn, standardConn, err := f.CreateConnections(config) + if err != nil { + return nil, err + } + + baseReplicator := NewBaseReplicator(config, replicationConn, standardConn, natsClient) + + var ddlReplicator *DDLReplicator + if config.TrackDDL { + ddlConn, err := NewStandardConnection(config) + if err != nil { + return nil, fmt.Errorf("failed to create DDL connection: %v", err) + } + ddlReplicator, err = NewDDLReplicator(config, baseReplicator, ddlConn) + if err != nil { + return nil, fmt.Errorf("failed to create DDL replicator: %v", err) + } + } + + if f.MaxCopyWorkersPerTable <= 0 { + f.MaxCopyWorkersPerTable = 4 + } + + copyAndStreamReplicator := &CopyAndStreamReplicator{ + BaseReplicator: *baseReplicator, + MaxCopyWorkersPerTable: f.MaxCopyWorkersPerTable, + CopyOnly: f.CopyOnly, + } + + if ddlReplicator != nil { + copyAndStreamReplicator.DDLReplicator = *ddlReplicator + } + + return copyAndStreamReplicator, nil +} diff --git a/pkg/replicator/replicator.go b/pkg/replicator/replicator.go deleted file mode 100644 index b2a7abd..0000000 --- a/pkg/replicator/replicator.go +++ /dev/null @@ -1,72 +0,0 @@ -package replicator - -import ( - "context" - "fmt" - - "github.com/pgflo/pg_flo/pkg/pgflonats" -) - -// NewReplicator creates a new Replicator based on the configuration -func NewReplicator(config Config, natsClient *pgflonats.NATSClient, copyAndStream bool, copyOnly bool, maxCopyWorkersPerTable int) (Replicator, error) { - - replicationConn := NewReplicationConnection(config) - if err := replicationConn.Connect(context.Background()); err != nil { - return nil, fmt.Errorf("failed to connect to database for replication: %v", err) - } - - standardConn, err := NewStandardConnection(config) - if err != nil { - return nil, fmt.Errorf("failed to create standard connection: %v", err) - } - - baseReplicator := NewBaseReplicator(config, replicationConn, standardConn, natsClient) - - var ddlReplicator *DDLReplicator - if config.TrackDDL { - ddlConn, err := NewStandardConnection(config) - if err != nil { - return nil, fmt.Errorf("failed to create DDL connection: %v", err) - } - ddlReplicator, err = NewDDLReplicator(config, baseReplicator, ddlConn) - if err != nil { - return nil, fmt.Errorf("failed to create DDL replicator: %v", err) - } - } - - if copyOnly { - copyOnlyReplicator := &CopyAndStreamReplicator{ - BaseReplicator: *baseReplicator, - MaxCopyWorkersPerTable: maxCopyWorkersPerTable, - CopyOnly: true, - } - if ddlReplicator != nil { - copyOnlyReplicator.DDLReplicator = *ddlReplicator - } - return copyOnlyReplicator, nil - } - - if copyAndStream { - if maxCopyWorkersPerTable <= 0 { - maxCopyWorkersPerTable = 4 - } - copyAndStreamReplicator := &CopyAndStreamReplicator{ - BaseReplicator: *baseReplicator, - MaxCopyWorkersPerTable: maxCopyWorkersPerTable, - } - if ddlReplicator != nil { - copyAndStreamReplicator.DDLReplicator = *ddlReplicator - } - return copyAndStreamReplicator, nil - } - - streamReplicator := &StreamReplicator{ - BaseReplicator: *baseReplicator, - } - - if ddlReplicator != nil { - streamReplicator.DDLReplicator = *ddlReplicator - } - - return streamReplicator, nil -} diff --git a/pkg/replicator/tests/base_replicator_test.go b/pkg/replicator/tests/base_replicator_test.go index ceeb04b..bbe46a4 100644 --- a/pkg/replicator/tests/base_replicator_test.go +++ b/pkg/replicator/tests/base_replicator_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io/ioutil" "strconv" "strings" "testing" @@ -90,7 +89,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ Config: replicator.Config{Group: "existing_pub"}, StandardConn: mockStandardConn, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } err := br.CreatePublication() @@ -121,7 +120,7 @@ func TestBaseReplicator(t *testing.T) { Tables: []string{"users", "orders"}, }, StandardConn: mockStandardConn, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } err := br.CreatePublication() @@ -141,7 +140,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ Config: replicator.Config{Group: "error_pub"}, StandardConn: mockStandardConn, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } err := br.CreatePublication() @@ -194,7 +193,7 @@ func TestBaseReplicator(t *testing.T) { Schema: "public", }, StandardConn: mockStandardConn, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } err := br.CreatePublication() @@ -255,7 +254,7 @@ func TestBaseReplicator(t *testing.T) { Schema: "public", }, StandardConn: mockStandardConn, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } err := br.CreatePublication() @@ -305,7 +304,7 @@ func TestBaseReplicator(t *testing.T) { StandardConn: mockStandardConn, NATSClient: mockNATSClient, Config: replicator.Config{Group: "test_pub"}, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) @@ -329,7 +328,7 @@ func TestBaseReplicator(t *testing.T) { ReplicationConn: mockReplicationConn, StandardConn: mockStandardConn, Config: replicator.Config{Group: "test_pub"}, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } stopChan := make(chan struct{}) @@ -346,6 +345,7 @@ func TestBaseReplicator(t *testing.T) { mockNATSClient := new(MockNATSClient) br := &replicator.BaseReplicator{ NATSClient: mockNATSClient, + Logger: utils.NewZerologLogger(zerolog.New(nil)), Relations: map[uint32]*pglogrepl.RelationMessage{ 1: { RelationID: 1, @@ -397,6 +397,7 @@ func TestBaseReplicator(t *testing.T) { t.Run("Unknown relation ID", func(t *testing.T) { br := &replicator.BaseReplicator{ Relations: make(map[uint32]*pglogrepl.RelationMessage), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } msg := &pglogrepl.InsertMessage{RelationID: 999} @@ -514,6 +515,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ NATSClient: mockNATSClient, + Logger: utils.NewZerologLogger(zerolog.New(nil)), Relations: map[uint32]*pglogrepl.RelationMessage{ 1: { RelationID: 1, @@ -606,6 +608,7 @@ func TestBaseReplicator(t *testing.T) { mockNATSClient := new(MockNATSClient) br := &replicator.BaseReplicator{ NATSClient: mockNATSClient, + Logger: utils.NewZerologLogger(zerolog.New(nil)), Relations: map[uint32]*pglogrepl.RelationMessage{ 1: { RelationID: 1, @@ -661,6 +664,7 @@ func TestBaseReplicator(t *testing.T) { t.Run("Unknown relation ID", func(t *testing.T) { br := &replicator.BaseReplicator{ + Logger: utils.NewZerologLogger(zerolog.New(nil)), Relations: make(map[uint32]*pglogrepl.RelationMessage), } @@ -676,6 +680,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ NATSClient: mockNATSClient, + Logger: utils.NewZerologLogger(zerolog.New(nil)), Relations: map[uint32]*pglogrepl.RelationMessage{ 1: { RelationID: 1, @@ -735,6 +740,7 @@ func TestBaseReplicator(t *testing.T) { mockNATSClient := new(MockNATSClient) br := &replicator.BaseReplicator{ NATSClient: mockNATSClient, + Logger: utils.NewZerologLogger(zerolog.New(nil)), Relations: map[uint32]*pglogrepl.RelationMessage{ 1: { RelationID: 1, @@ -801,7 +807,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ NATSClient: mockNATSClient, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } msg := &pglogrepl.CommitMessage{ @@ -826,6 +832,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ NATSClient: mockNATSClient, + Logger: utils.NewZerologLogger(zerolog.New(nil)), Config: replicator.Config{ Group: "test_group", }, @@ -884,6 +891,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ NATSClient: mockNATSClient, + Logger: utils.NewZerologLogger(zerolog.New(nil)), Config: replicator.Config{ Group: "test_group", }, @@ -972,6 +980,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ ReplicationConn: mockReplicationConn, + Logger: utils.NewZerologLogger(zerolog.New(nil)), } err := br.SendStandbyStatusUpdate(context.Background()) @@ -985,8 +994,8 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ ReplicationConn: mockReplicationConn, + Logger: utils.NewZerologLogger(zerolog.New(nil)), LastLSN: pglogrepl.LSN(100), - Logger: zerolog.New(ioutil.Discard), } mockReplicationConn.On("SendStandbyStatusUpdate", @@ -1017,6 +1026,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ StandardConn: mockStandardConn, ReplicationConn: mockReplicationConn, + Logger: utils.NewZerologLogger(zerolog.New(nil)), Config: replicator.Config{ Group: "test_group", }, @@ -1054,7 +1064,7 @@ func TestBaseReplicator(t *testing.T) { Config: replicator.Config{ Group: "test_group", }, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } err := br.CreateReplicationSlot(context.Background()) @@ -1085,7 +1095,7 @@ func TestBaseReplicator(t *testing.T) { Config: replicator.Config{ Group: "test_group", }, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), } err := br.CreateReplicationSlot(context.Background()) @@ -1109,6 +1119,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ StandardConn: mockStandardConn, + Logger: utils.NewZerologLogger(zerolog.New(nil)), } exists, err := br.CheckReplicationSlotExists("test_slot") @@ -1129,6 +1140,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ StandardConn: mockStandardConn, + Logger: utils.NewZerologLogger(zerolog.New(nil)), } exists, err := br.CheckReplicationSlotExists("test_slot") @@ -1148,6 +1160,7 @@ func TestBaseReplicator(t *testing.T) { br := &replicator.BaseReplicator{ StandardConn: mockStandardConn, + Logger: utils.NewZerologLogger(zerolog.New(nil)), } _, err := br.CheckReplicationSlotExists("test_slot") diff --git a/pkg/replicator/tests/copy_and_stream_replicator_test.go b/pkg/replicator/tests/copy_and_stream_replicator_test.go index 4368279..9609ab0 100644 --- a/pkg/replicator/tests/copy_and_stream_replicator_test.go +++ b/pkg/replicator/tests/copy_and_stream_replicator_test.go @@ -94,7 +94,7 @@ func TestCopyAndStreamReplicator(t *testing.T) { BaseReplicator: replicator.BaseReplicator{ StandardConn: mockStandardConn, NATSClient: mockNATSClient, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), Config: replicator.Config{ Tables: []string{"users"}, Schema: "public", @@ -165,7 +165,7 @@ func TestCopyAndStreamReplicator(t *testing.T) { BaseReplicator: replicator.BaseReplicator{ StandardConn: mockStandardConn, NATSClient: mockNATSClient, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), Config: replicator.Config{ Tables: []string{"users"}, Schema: "public", @@ -339,7 +339,7 @@ func TestCopyAndStreamReplicator(t *testing.T) { BaseReplicator: replicator.BaseReplicator{ StandardConn: mockStandardConn, NATSClient: mockNATSClient, - Logger: zerolog.Nop(), + Logger: utils.NewZerologLogger(zerolog.New(nil)), Config: replicator.Config{ Tables: []string{"test_table"}, Schema: "public", diff --git a/pkg/replicator/tests/ddl_replicator_test.go b/pkg/replicator/tests/ddl_replicator_test.go index 2ec7fd7..c62a1d5 100644 --- a/pkg/replicator/tests/ddl_replicator_test.go +++ b/pkg/replicator/tests/ddl_replicator_test.go @@ -8,6 +8,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/pgflo/pg_flo/pkg/replicator" + "github.com/pgflo/pg_flo/pkg/utils" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -16,7 +17,7 @@ import ( func TestDDLReplicator(t *testing.T) { t.Run("NewDDLReplicator", func(t *testing.T) { mockBaseReplicator := &replicator.BaseReplicator{ - Logger: zerolog.Logger{}, + Logger: utils.NewZerologLogger(zerolog.New(nil)), } mockStandardConn := &MockStandardConnection{} config := replicator.Config{} @@ -32,7 +33,7 @@ func TestDDLReplicator(t *testing.T) { t.Run("SetupDDLTracking", func(t *testing.T) { mockStandardConn := &MockStandardConnection{} mockBaseRepl := &replicator.BaseReplicator{ - Logger: zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger(), + Logger: utils.NewZerologLogger(zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger()), StandardConn: mockStandardConn, Config: replicator.Config{ Schema: "public", @@ -65,7 +66,7 @@ func TestDDLReplicator(t *testing.T) { t.Run("StartDDLReplication", func(t *testing.T) { mockStandardConn := &MockStandardConnection{} mockBaseReplicator := &replicator.BaseReplicator{ - Logger: zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger(), + Logger: utils.NewZerologLogger(zerolog.New(zerolog.NewConsoleWriter()).With().Timestamp().Logger()), } ddlReplicator := &replicator.DDLReplicator{ DDLConn: mockStandardConn, diff --git a/pkg/sinks/postgres.go b/pkg/sinks/postgres.go index 9fe799a..c6a8534 100644 --- a/pkg/sinks/postgres.go +++ b/pkg/sinks/postgres.go @@ -13,23 +13,16 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/pgflo/pg_flo/pkg/utils" - "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) -func init() { - log.Logger = log.Output(zerolog.ConsoleWriter{ - Out: os.Stderr, - TimeFormat: "15:04:05.000", - }) -} - // PostgresSink represents a sink for PostgreSQL database type PostgresSink struct { conn *pgx.Conn disableForeignKeyChecks bool connConfig *pgx.ConnConfig retryConfig utils.RetryConfig + logger utils.Logger } // NewPostgresSink creates a new PostgresSink instance @@ -42,6 +35,7 @@ func NewPostgresSink(targetHost string, targetPort int, targetDBName, targetUser sink := &PostgresSink{ connConfig: connConfig, disableForeignKeyChecks: disableForeignKeyChecks, + logger: utils.NewZerologLogger(log.With().Str("component", "postgres_sink").Logger()), retryConfig: utils.RetryConfig{ MaxAttempts: 5, InitialWait: 1 * time.Second, @@ -69,7 +63,7 @@ func (s *PostgresSink) connect(ctx context.Context) error { return utils.WithRetry(ctx, s.retryConfig, func() error { conn, err := pgx.ConnectConfig(ctx, s.connConfig) if err != nil { - log.Error().Err(err).Msg("Failed to connect to database, will retry") + s.logger.Error().Err(err).Msg("Failed to connect to database, will retry") return err } connMutex.Lock() @@ -91,7 +85,7 @@ func (s *PostgresSink) syncSchema(sourceHost string, sourcePort int, sourceDBNam ) schemaDump, err := dumpCmd.Output() if err != nil { - log.Error().Err(err).Msg("Failed to dump schema from source database") + s.logger.Error().Err(err).Msg("Failed to dump schema from source database") return fmt.Errorf("failed to dump schema from source database: %v", err) } @@ -220,7 +214,7 @@ func (s *PostgresSink) handleDelete(tx pgx.Tx, message *utils.CDCMessage) error } if result.RowsAffected() == 0 { - log.Warn(). + s.logger.Warn(). Str("table", message.Table). Str("query", query). Interface("values", whereValues). @@ -262,7 +256,7 @@ func (s *PostgresSink) handleUpdate(tx pgx.Tx, message *utils.CDCMessage) error } if len(setClauses) == 0 { - log.Debug().Msg("No columns to update, skipping") + s.logger.Debug().Msg("No columns to update, skipping") return nil } @@ -287,7 +281,7 @@ func (s *PostgresSink) handleUpdate(tx pgx.Tx, message *utils.CDCMessage) error } if result.RowsAffected() == 0 { - log.Warn(). + s.logger.Warn(). Str("table", message.Table). Str("query", query). Interface("values", values). @@ -320,7 +314,7 @@ func (s *PostgresSink) handleDDL(tx pgx.Tx, message *utils.CDCMessage) (pgx.Tx, return tx, fmt.Errorf("DDL command is not a string") } - log.Debug().Msgf("Executing DDL: %s", ddlString) + s.logger.Debug().Msgf("Executing DDL: %s", ddlString) if strings.Contains(strings.ToUpper(ddlString), "CONCURRENTLY") { if err := tx.Commit(context.Background()); err != nil { @@ -329,7 +323,7 @@ func (s *PostgresSink) handleDDL(tx pgx.Tx, message *utils.CDCMessage) (pgx.Tx, if _, err := s.conn.Exec(context.Background(), ddlString); err != nil { if strings.Contains(err.Error(), "does not exist") { - log.Warn().Msgf("Ignoring DDL for non-existent object: %s", ddlString) + s.logger.Warn().Msgf("Ignoring DDL for non-existent object: %s", ddlString) return s.conn.Begin(context.Background()) } return nil, fmt.Errorf("failed to execute concurrent DDL: %v", err) @@ -341,7 +335,7 @@ func (s *PostgresSink) handleDDL(tx pgx.Tx, message *utils.CDCMessage) (pgx.Tx, _, err = tx.Exec(context.Background(), ddlString) if err != nil { if strings.Contains(err.Error(), "does not exist") { - log.Warn().Msgf("Ignoring DDL for non-existent object: %s", ddlString) + s.logger.Warn().Msgf("Ignoring DDL for non-existent object: %s", ddlString) return tx, nil } return tx, fmt.Errorf("failed to execute DDL: %v", err) @@ -385,7 +379,7 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils defer func() { if tx != nil { if err := tx.Rollback(ctx); err != nil { - log.Error().Err(err).Msg("failed to rollback transaction") + s.logger.Error().Err(err).Msg("failed to rollback transaction") } } }() @@ -396,7 +390,7 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils } defer func() { if err := s.enableForeignKeys(ctx); err != nil { - log.Error().Err(err).Msg("failed to re-enable foreign key checks") + s.logger.Error().Err(err).Msg("failed to re-enable foreign key checks") } }() } @@ -442,7 +436,7 @@ func (s *PostgresSink) writeBatchInternal(ctx context.Context, messages []*utils if err != nil || operationErr != nil { if tx != nil { if rollbackErr := tx.Rollback(ctx); rollbackErr != nil { - log.Error().Err(rollbackErr).Msg("failed to rollback transaction") + s.logger.Error().Err(rollbackErr).Msg("failed to rollback transaction") } } tx = nil diff --git a/pkg/utils/shared_types.go b/pkg/utils/shared_types.go index 71caba8..aee6dc0 100644 --- a/pkg/utils/shared_types.go +++ b/pkg/utils/shared_types.go @@ -24,3 +24,26 @@ type ReplicationKey struct { Type ReplicationKeyType Columns []string } + +type Logger interface { + Debug() LogEvent + Info() LogEvent + Warn() LogEvent + Error() LogEvent + Err(err error) LogEvent +} + +type LogEvent interface { + Str(key, val string) LogEvent + Int(key string, val int) LogEvent + Int64(key string, val int64) LogEvent + Uint8(key string, val uint8) LogEvent + Uint32(key string, val uint32) LogEvent + Interface(key string, val interface{}) LogEvent + Err(err error) LogEvent + Strs(key string, vals []string) LogEvent + Any(key string, val interface{}) LogEvent + Type(key string, val interface{}) LogEvent + Msg(msg string) + Msgf(format string, v ...interface{}) +} diff --git a/pkg/utils/zerolog_logger.go b/pkg/utils/zerolog_logger.go new file mode 100644 index 0000000..b5a843a --- /dev/null +++ b/pkg/utils/zerolog_logger.go @@ -0,0 +1,95 @@ +package utils + +import ( + "github.com/rs/zerolog" +) + +type ZerologLogger struct { + logger zerolog.Logger +} + +func NewZerologLogger(logger zerolog.Logger) Logger { + return &ZerologLogger{logger: logger} +} + +type ZerologLogEvent struct { + event *zerolog.Event +} + +func (z *ZerologLogger) Debug() LogEvent { + return &ZerologLogEvent{event: z.logger.Debug()} +} + +func (z *ZerologLogger) Info() LogEvent { + return &ZerologLogEvent{event: z.logger.Info()} +} + +func (z *ZerologLogger) Warn() LogEvent { + return &ZerologLogEvent{event: z.logger.Warn()} +} + +func (z *ZerologLogger) Error() LogEvent { + return &ZerologLogEvent{event: z.logger.Error()} +} + +func (z *ZerologLogger) Err(err error) LogEvent { + return &ZerologLogEvent{event: z.logger.Err(err)} +} + +func (e *ZerologLogEvent) Str(key, val string) LogEvent { + e.event = e.event.Str(key, val) + return e +} + +func (e *ZerologLogEvent) Int(key string, val int) LogEvent { + e.event = e.event.Int(key, val) + return e +} + +func (e *ZerologLogEvent) Int64(key string, val int64) LogEvent { + e.event = e.event.Int64(key, val) + return e +} + +func (e *ZerologLogEvent) Uint32(key string, val uint32) LogEvent { + e.event = e.event.Uint32(key, val) + return e +} + +func (e *ZerologLogEvent) Interface(key string, val interface{}) LogEvent { + e.event = e.event.Interface(key, val) + return e +} + +func (e *ZerologLogEvent) Err(err error) LogEvent { + e.event = e.event.Err(err) + return e +} + +func (e *ZerologLogEvent) Msg(msg string) { + e.event.Msg(msg) +} + +func (e *ZerologLogEvent) Msgf(format string, v ...interface{}) { + e.event.Msgf(format, v...) +} + +func (e *ZerologLogEvent) Strs(key string, vals []string) LogEvent { + e.event = e.event.Strs(key, vals) + return e +} + +func (e *ZerologLogEvent) Any(key string, val interface{}) LogEvent { + e.event = e.event.Interface(key, val) + return e +} + +func (e *ZerologLogEvent) Uint8(key string, val uint8) LogEvent { + e.event = e.event.Uint8(key, val) + return e +} + +func (e *ZerologLogEvent) Type(key string, val interface{}) LogEvent { + e.event = e.event.Type(key, val) + return e +} diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go index 5c10d45..ad95e27 100644 --- a/pkg/worker/worker.go +++ b/pkg/worker/worker.go @@ -25,7 +25,7 @@ type Worker struct { router *routing.Router sink sinks.Sink group string - logger zerolog.Logger + logger utils.Logger batchSize int buffer []*utils.CDCMessage lastSavedState uint64 @@ -51,7 +51,7 @@ func init() { // NewWorker creates and returns a new Worker instance with the provided NATS client func NewWorker(natsClient *pgflonats.NATSClient, ruleEngine *rules.RuleEngine, router *routing.Router, sink sinks.Sink, group string, opts ...Option) *Worker { - logger := log.With().Str("component", "worker").Logger() + logger := utils.NewZerologLogger(log.With().Str("component", "worker").Logger()) w := &Worker{ natsClient: natsClient,