diff --git a/yb-voyager/src/srcdb/mysql.go b/yb-voyager/src/srcdb/mysql.go index 3090682d18..3a5e63b2b4 100644 --- a/yb-voyager/src/srcdb/mysql.go +++ b/yb-voyager/src/srcdb/mysql.go @@ -46,6 +46,18 @@ func newMySQL(s *Source) *MySQL { } func (ms *MySQL) Connect() error { + if ms.db != nil { + err := ms.db.Ping() + if err == nil { + log.Infof("Already connected to the source database") + return nil + } else { + log.Infof("Failed to ping the source database: %s", err) + ms.Disconnect() + } + log.Info("Reconnecting to the source database") + } + db, err := sql.Open("mysql", ms.getConnectionUri()) db.SetMaxOpenConns(ms.source.NumConnections) db.SetConnMaxIdleTime(5 * time.Minute) diff --git a/yb-voyager/src/srcdb/mysql_test.go b/yb-voyager/src/srcdb/mysql_test.go index db74724162..d1dce03f74 100644 --- a/yb-voyager/src/srcdb/mysql_test.go +++ b/yb-voyager/src/srcdb/mysql_test.go @@ -46,6 +46,7 @@ func TestMysqlGetAllTableNames(t *testing.T) { testMySQLSource.Source.DBName = "test" // used in query of GetAllTableNames() // Test GetAllTableNames + _ = testMySQLSource.DB().Connect() actualTables := testMySQLSource.DB().GetAllTableNames() expectedTables := []*sqlname.SourceName{ sqlname.NewSourceName("test", "foo"), diff --git a/yb-voyager/src/srcdb/oracle.go b/yb-voyager/src/srcdb/oracle.go index 32de02bfa7..0515b9e095 100644 --- a/yb-voyager/src/srcdb/oracle.go +++ b/yb-voyager/src/srcdb/oracle.go @@ -46,6 +46,18 @@ func newOracle(s *Source) *Oracle { } func (ora *Oracle) Connect() error { + if ora.db != nil { + err := ora.db.Ping() + if err == nil { + log.Infof("Already connected to the source database") + return nil + } else { + log.Infof("Failed to ping the source database: %s", err) + ora.Disconnect() + } + log.Info("Reconnecting to the source database") + } + db, err := sql.Open("godror", ora.getConnectionUri()) db.SetMaxOpenConns(ora.source.NumConnections) db.SetConnMaxIdleTime(5 * time.Minute) diff --git a/yb-voyager/src/srcdb/oracle_test.go b/yb-voyager/src/srcdb/oracle_test.go index 9369b82575..35d97fb6d8 100644 --- a/yb-voyager/src/srcdb/oracle_test.go +++ b/yb-voyager/src/srcdb/oracle_test.go @@ -51,6 +51,7 @@ func TestOracleGetTableToUniqueKeyColumnsMap(t *testing.T) { tableList := []sqlname.NameTuple{ {CurrentName: objectName}, } + _ = testOracleSource.DB().Connect() uniqueKeys, err := testOracleSource.DB().GetTableToUniqueKeyColumnsMap(tableList) if err != nil { t.Fatalf("Error retrieving unique keys: %v", err) @@ -72,6 +73,7 @@ func TestOracleGetTableToUniqueKeyColumnsMap(t *testing.T) { } func TestOracleGetNonPKTables(t *testing.T) { + _ = testOracleSource.DB().Connect() actualTables, err := testOracleSource.DB().GetNonPKTables() assert.NilError(t, err, "Expected nil but non nil error: %v", err) diff --git a/yb-voyager/src/srcdb/postgres.go b/yb-voyager/src/srcdb/postgres.go index 48406e6abb..917a64f34e 100644 --- a/yb-voyager/src/srcdb/postgres.go +++ b/yb-voyager/src/srcdb/postgres.go @@ -109,6 +109,18 @@ func newPostgreSQL(s *Source) *PostgreSQL { } func (pg *PostgreSQL) Connect() error { + if pg.db != nil { + err := pg.db.Ping() + if err == nil { + log.Infof("Already connected to the source database") + log.Infof("Already connected to the source database") + return nil + } else { + log.Infof("Failed to ping the source database: %s", err) + pg.Disconnect() + } + log.Info("Reconnecting to the source database") + } db, err := sql.Open("pgx", pg.getConnectionUri()) db.SetMaxOpenConns(pg.source.NumConnections) db.SetConnMaxIdleTime(5 * time.Minute) diff --git a/yb-voyager/src/srcdb/postgres_test.go b/yb-voyager/src/srcdb/postgres_test.go index 41ac55d5ac..5d6b63fbff 100644 --- a/yb-voyager/src/srcdb/postgres_test.go +++ b/yb-voyager/src/srcdb/postgres_test.go @@ -48,6 +48,7 @@ func TestPostgresGetAllTableNames(t *testing.T) { testPostgresSource.Source.Schema = "test_schema" // used in query of GetAllTableNames() // Test GetAllTableNames + _ = testPostgresSource.DB().Connect() actualTables := testPostgresSource.DB().GetAllTableNames() expectedTables := []*sqlname.SourceName{ sqlname.NewSourceName("test_schema", "foo"), @@ -86,6 +87,8 @@ func TestPostgresGetTableToUniqueKeyColumnsMap(t *testing.T) { {CurrentName: sqlname.NewObjectName("postgresql", "test_schema", "test_schema", "another_unique_table")}, } + // Test GetTableToUniqueKeyColumnsMap + _ = testPostgresSource.DB().Connect() actualUniqKeys, err := testPostgresSource.DB().GetTableToUniqueKeyColumnsMap(uniqueTablesList) if err != nil { t.Fatalf("Error retrieving unique keys: %v", err) @@ -128,6 +131,8 @@ func TestPostgresGetNonPKTables(t *testing.T) { );`) defer testPostgresSource.TestContainer.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + // Test GetNonPKTables + _ = testPostgresSource.DB().Connect() actualTables, err := testPostgresSource.DB().GetNonPKTables() assert.NilError(t, err, "Expected nil but non nil error: %v", err) diff --git a/yb-voyager/src/srcdb/yugabytedb.go b/yb-voyager/src/srcdb/yugabytedb.go index 3af954ae91..a9bd0c3fa6 100644 --- a/yb-voyager/src/srcdb/yugabytedb.go +++ b/yb-voyager/src/srcdb/yugabytedb.go @@ -51,6 +51,17 @@ func newYugabyteDB(s *Source) *YugabyteDB { } func (yb *YugabyteDB) Connect() error { + if yb.db != nil { + err := yb.db.Ping() + if err == nil { + log.Infof("Already connected to the source database") + return nil + } else { + log.Infof("Failed to ping the source database: %s", err) + yb.Disconnect() + } + log.Info("Reconnecting to the source database") + } db, err := sql.Open("pgx", yb.getConnectionUri()) db.SetMaxOpenConns(yb.source.NumConnections) db.SetConnMaxIdleTime(5 * time.Minute) diff --git a/yb-voyager/src/srcdb/yugbaytedb_test.go b/yb-voyager/src/srcdb/yugbaytedb_test.go index 4a7cf8e6f9..fa8e073f00 100644 --- a/yb-voyager/src/srcdb/yugbaytedb_test.go +++ b/yb-voyager/src/srcdb/yugbaytedb_test.go @@ -48,6 +48,7 @@ func TestYugabyteGetAllTableNames(t *testing.T) { testYugabyteDBSource.Source.Schema = "test_schema" // Test GetAllTableNames + _ = testYugabyteDBSource.DB().Connect() actualTables := testYugabyteDBSource.DB().GetAllTableNames() expectedTables := []*sqlname.SourceName{ sqlname.NewSourceName("test_schema", "foo"), @@ -86,6 +87,8 @@ func TestYugabyteGetTableToUniqueKeyColumnsMap(t *testing.T) { {CurrentName: sqlname.NewObjectName("postgresql", "test_schema", "test_schema", "another_unique_table")}, } + // Test GetTableToUniqueKeyColumnsMap + _ = testYugabyteDBSource.DB().Connect() actualUniqKeys, err := testYugabyteDBSource.DB().GetTableToUniqueKeyColumnsMap(uniqueTablesList) if err != nil { t.Fatalf("Error retrieving unique keys: %v", err) @@ -128,6 +131,8 @@ func TestYugabyteGetNonPKTables(t *testing.T) { );`) defer testYugabyteDBSource.TestContainer.ExecuteSqls(`DROP SCHEMA test_schema CASCADE;`) + // Test GetNonPKTables + _ = testYugabyteDBSource.DB().Connect() actualTables, err := testYugabyteDBSource.DB().GetNonPKTables() assert.NilError(t, err, "Expected nil but non nil error: %v", err) diff --git a/yb-voyager/test/containers/helpers.go b/yb-voyager/test/containers/helpers.go index cd3afee88d..b2fd66a33a 100644 --- a/yb-voyager/test/containers/helpers.go +++ b/yb-voyager/test/containers/helpers.go @@ -2,9 +2,11 @@ package testcontainers import ( "context" + "database/sql" _ "embed" "fmt" "io" + "time" log "github.com/sirupsen/logrus" @@ -58,3 +60,32 @@ func printContainerLogs(container testcontainers.Container) { fmt.Printf("=== Logs for container %s ===\n%s\n=== End of Logs for container %s ===\n", containerID, string(logData), containerID) } + +// pingDatabase tries to connect to the database using the driver and connection string. +// It retries for a few times with a delay before giving up. +func pingDatabase(driver string, connStr string) error { + var err error + maxRetries := 3 + retryDelay := 5 * time.Second + + for i := 0; i < maxRetries; i++ { + db, openErr := sql.Open(driver, connStr) + if openErr != nil { + err = openErr + } else { + pingErr := db.Ping() + closeErr := db.Close() + if pingErr == nil && closeErr == nil { + return nil // success + } + + if pingErr != nil { + err = pingErr + } else { + err = closeErr + } + } + time.Sleep(retryDelay) + } + return fmt.Errorf("pingDatabase failed even after '%d' retries: %w", maxRetries, err) +} diff --git a/yb-voyager/test/containers/mysql_container.go b/yb-voyager/test/containers/mysql_container.go index c2d8930cff..9fb7583aca 100644 --- a/yb-voyager/test/containers/mysql_container.go +++ b/yb-voyager/test/containers/mysql_container.go @@ -17,11 +17,10 @@ import ( type MysqlContainer struct { ContainerConfig container testcontainers.Container - db *sql.DB } func (ms *MysqlContainer) Start(ctx context.Context) (err error) { - if ms.container != nil { + if ms.container != nil && ms.container.IsRunning() { utils.PrintAndLog("Mysql-%s container already running", ms.DBVersion) return nil } @@ -62,24 +61,15 @@ func (ms *MysqlContainer) Start(ctx context.Context) (err error) { Started: true, }) printContainerLogs(ms.container) - if err != nil { - return err - } - dsn := ms.GetConnectionString() - db, err := sql.Open("mysql", dsn) if err != nil { - return fmt.Errorf("failed to open mysql connection: %w", err) + return fmt.Errorf("failed to start mysql container: %w", err) } - if err = db.Ping(); err != nil { - db.Close() - return fmt.Errorf("failed to ping mysql after connection: %w", err) + err = pingDatabase("mysql", ms.GetConnectionString()) + if err != nil { + return fmt.Errorf("failed to ping mysql container: %w", err) } - - // Store the DB connection for reuse - ms.db = db - return nil } @@ -88,13 +78,6 @@ func (ms *MysqlContainer) Terminate(ctx context.Context) { return } - // Close the DB connection if it exists - if ms.db != nil { - if err := ms.db.Close(); err != nil { - log.Errorf("failed to close mysql db connection: %v", err) - } - } - err := ms.container.Terminate(ctx) if err != nil { log.Errorf("failed to terminate mysql container: %v", err) @@ -136,12 +119,17 @@ func (ms *MysqlContainer) GetConnectionString() string { } func (ms *MysqlContainer) ExecuteSqls(sqls ...string) { - if ms.db == nil { - utils.ErrExit("db connection not initialized for mysql container") + if ms == nil { + utils.ErrExit("mysql container is not started: nil") + } + + db, err := sql.Open("mysql", ms.GetConnectionString()) + if err != nil { + utils.ErrExit("failed to connect to mysql for executing sqls: %w", err) } for _, sqlStmt := range sqls { - _, err := ms.db.Exec(sqlStmt) + _, err := db.Exec(sqlStmt) if err != nil { utils.ErrExit("failed to execute sql '%s': %w", sqlStmt, err) } diff --git a/yb-voyager/test/containers/oracle_container.go b/yb-voyager/test/containers/oracle_container.go index 8fb8218de2..c02f2774d6 100644 --- a/yb-voyager/test/containers/oracle_container.go +++ b/yb-voyager/test/containers/oracle_container.go @@ -19,7 +19,7 @@ type OracleContainer struct { } func (ora *OracleContainer) Start(ctx context.Context) (err error) { - if ora.container != nil { + if ora.container != nil && ora.container.IsRunning() { utils.PrintAndLog("Oracle-%s container already running", ora.DBVersion) return nil } @@ -61,7 +61,15 @@ func (ora *OracleContainer) Start(ctx context.Context) (err error) { Started: true, }) printContainerLogs(ora.container) - return err + if err != nil { + return fmt.Errorf("failed to start oracle container: %w", err) + } + + err = pingDatabase("godror", ora.GetConnectionString()) + if err != nil { + return fmt.Errorf("failed to ping oracle container: %w", err) + } + return nil } func (ora *OracleContainer) Terminate(ctx context.Context) { @@ -99,7 +107,15 @@ func (ora *OracleContainer) GetConfig() ContainerConfig { } func (ora *OracleContainer) GetConnectionString() string { - panic("GetConnectionString() not implemented yet for oracle") + config := ora.GetConfig() + host, port, err := ora.GetHostPort() + if err != nil { + utils.ErrExit("failed to get host port for oracle connection string: %v", err) + } + + connectString := fmt.Sprintf(`(DESCRIPTION = (ADDRESS = (PROTOCOL = TCP)(HOST = %s)(PORT = %d))(CONNECT_DATA = (SERVICE_NAME = %s)))`, + host, port, config.DBName) + return fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, config.User, config.Password, connectString) } func (ora *OracleContainer) ExecuteSqls(sqls ...string) { diff --git a/yb-voyager/test/containers/postgres_container.go b/yb-voyager/test/containers/postgres_container.go index 539763f7de..204eb89782 100644 --- a/yb-voyager/test/containers/postgres_container.go +++ b/yb-voyager/test/containers/postgres_container.go @@ -2,12 +2,12 @@ package testcontainers import ( "context" - "database/sql" "fmt" "os" "time" "github.com/docker/go-connections/nat" + "github.com/jackc/pgx/v5" log "github.com/sirupsen/logrus" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" @@ -17,11 +17,10 @@ import ( type PostgresContainer struct { ContainerConfig container testcontainers.Container - db *sql.DB } func (pg *PostgresContainer) Start(ctx context.Context) (err error) { - if pg.container != nil { + if pg.container != nil && pg.container.IsRunning() { utils.PrintAndLog("Postgres-%s container already running", pg.DBVersion) return nil } @@ -66,23 +65,13 @@ func (pg *PostgresContainer) Start(ctx context.Context) (err error) { printContainerLogs(pg.container) if err != nil { - return err + return fmt.Errorf("failed to start postgres container: %w", err) } - dsn := pg.GetConnectionString() - db, err := sql.Open("pgx", dsn) + err = pingDatabase("pgx", pg.GetConnectionString()) if err != nil { - return fmt.Errorf("failed to open postgres connection: %w", err) + return fmt.Errorf("failed to ping postgres container: %w", err) } - - if err := db.Ping(); err != nil { - db.Close() - pg.container.Terminate(ctx) - return fmt.Errorf("failed to ping postgres after connection: %w", err) - } - - // Store the DB connection for reuse - pg.db = db return nil } @@ -91,13 +80,6 @@ func (pg *PostgresContainer) Terminate(ctx context.Context) { return } - // Close the DB connection if it exists - if pg.db != nil { - if err := pg.db.Close(); err != nil { - log.Errorf("failed to close postgres db connection: %v", err) - } - } - err := pg.container.Terminate(ctx) if err != nil { log.Errorf("failed to terminate postgres container: %v", err) @@ -134,16 +116,23 @@ func (pg *PostgresContainer) GetConnectionString() string { utils.ErrExit("failed to get host port for postgres connection string: %v", err) } - return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s", config.User, config.Password, host, port, config.DBName) + return fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable", config.User, config.Password, host, port, config.DBName) } func (pg *PostgresContainer) ExecuteSqls(sqls ...string) { - if pg.db == nil { - utils.ErrExit("db connection not initialized for postgres container") + if pg == nil { + utils.ErrExit("postgres container is not started: nil") + } + + connStr := pg.GetConnectionString() + conn, err := pgx.Connect(context.Background(), connStr) + if err != nil { + utils.ErrExit("failed to connect to postgres for executing sqls: %w", err) } + defer conn.Close(context.Background()) for _, sqlStmt := range sqls { - _, err := pg.db.Exec(sqlStmt) + _, err := conn.Exec(context.Background(), sqlStmt) if err != nil { utils.ErrExit("failed to execute sql '%s': %w", sqlStmt, err) } diff --git a/yb-voyager/test/containers/yugabytedb_container.go b/yb-voyager/test/containers/yugabytedb_container.go index c2900d12b9..82e3677a64 100644 --- a/yb-voyager/test/containers/yugabytedb_container.go +++ b/yb-voyager/test/containers/yugabytedb_container.go @@ -22,7 +22,7 @@ type YugabyteDBContainer struct { } func (yb *YugabyteDBContainer) Start(ctx context.Context) (err error) { - if yb.container != nil { + if yb.container != nil && yb.container.IsRunning() { utils.PrintAndLog("YugabyteDB-%s container already running", yb.DBVersion) return nil } @@ -39,6 +39,7 @@ func (yb *YugabyteDBContainer) Start(ctx context.Context) (err error) { } // this will create a 1 Node RF-1 cluster + // TODO: Ideally we should test with 3 Node RF-3 cluster req := testcontainers.ContainerRequest{ Image: fmt.Sprintf("yugabytedb/yugabyte:%s", yb.DBVersion), ExposedPorts: []string{"5433/tcp", "15433/tcp", "7000/tcp", "9000/tcp", "9042/tcp"}, @@ -67,7 +68,15 @@ func (yb *YugabyteDBContainer) Start(ctx context.Context) (err error) { Started: true, }) printContainerLogs(yb.container) - return err + if err != nil { + return fmt.Errorf("failed to start yugabytedb container: %w", err) + } + + err = pingDatabase("pgx", yb.GetConnectionString()) + if err != nil { + return fmt.Errorf("failed to ping yugabytedb container: %w", err) + } + return nil } func (yb *YugabyteDBContainer) Terminate(ctx context.Context) { @@ -115,10 +124,14 @@ func (yb *YugabyteDBContainer) GetConnectionString() string { } func (yb *YugabyteDBContainer) ExecuteSqls(sqls ...string) { + if yb == nil { + utils.ErrExit("yugabytedb container is not started: nil") + } + connStr := yb.GetConnectionString() conn, err := pgx.Connect(context.Background(), connStr) if err != nil { - utils.ErrExit("failed to connect postgres for executing sqls: %w", err) + utils.ErrExit("failed to connect to yugabytedb for executing sqls: %w", err) } defer conn.Close(context.Background())