From 2dcd76beecccf9b66a46974216632ea708610476 Mon Sep 17 00:00:00 2001 From: Joe Eschen <126913098+squiishyy@users.noreply.github.com> Date: Wed, 24 Apr 2024 15:27:45 -0700 Subject: [PATCH] Add read replica host config and connection - Add a new field to the postgres db config struct, `readReplicaHost`. - Add a new endpoint in the `database` package to enable establishing a connection with a db without creating it if it doesn't exist Signed-off-by: Andrew Dye --- flytestdlib/database/config.go | 22 ++++++----- flytestdlib/database/db.go | 25 +++++++++++++ flytestdlib/database/dbconfig_flags.go | 1 + flytestdlib/database/dbconfig_flags_test.go | 14 +++++++ flytestdlib/database/postgres.go | 18 +++++++++ flytestdlib/database/postgres_test.go | 41 +++++++++++++++++++++ 6 files changed, 111 insertions(+), 10 deletions(-) diff --git a/flytestdlib/database/config.go b/flytestdlib/database/config.go index 16ca0e57086..f55eecda8f2 100644 --- a/flytestdlib/database/config.go +++ b/flytestdlib/database/config.go @@ -19,12 +19,13 @@ var defaultConfig = &DbConfig{ ConnMaxLifeTime: config.Duration{Duration: time.Hour}, Postgres: PostgresConfig{ // These values are suitable for local sandbox development - Host: "localhost", - Port: 30001, - DbName: postgresStr, - User: postgresStr, - Password: postgresStr, - ExtraOptions: "sslmode=disable", + Host: "localhost", + ReadReplicaHost: "localhost", + Port: 30001, + DbName: postgresStr, + User: postgresStr, + Password: postgresStr, + ExtraOptions: "sslmode=disable", }, } var configSection = config.MustRegisterSection(database, defaultConfig) @@ -64,10 +65,11 @@ type SQLiteConfig struct { // PostgresConfig includes specific config options for opening a connection to a postgres database. type PostgresConfig struct { - Host string `json:"host" pflag:",The host name of the database server"` - Port int `json:"port" pflag:",The port name of the database server"` - DbName string `json:"dbname" pflag:",The database name"` - User string `json:"username" pflag:",The database user who is connecting to the server."` + Host string `json:"host" pflag:",The host name of the database server"` + ReadReplicaHost string `json:"readReplicaHost" pflag:",The host name of the read replica database server"` + Port int `json:"port" pflag:",The port name of the database server"` + DbName string `json:"dbname" pflag:",The database name"` + User string `json:"username" pflag:",The database user who is connecting to the server."` // Either Password or PasswordPath must be set. Password string `json:"password" pflag:",The database password."` PasswordPath string `json:"passwordPath" pflag:",Points to the file containing the database password."` diff --git a/flytestdlib/database/db.go b/flytestdlib/database/db.go index 8046c7ead48..a9645185673 100644 --- a/flytestdlib/database/db.go +++ b/flytestdlib/database/db.go @@ -65,6 +65,31 @@ func GetDB(ctx context.Context, dbConfig *DbConfig, logConfig *logger.Config) ( return gormDb, setupDbConnectionPool(ctx, gormDb, dbConfig) } +// GetReadOnlyDB uses the dbConfig to create gorm DB object for the read replica passed via the config +func GetReadOnlyDB(ctx context.Context, dbConfig *DbConfig, logConfig *logger.Config) (*gorm.DB, error) { + if dbConfig == nil { + panic("Cannot initialize database repository from empty db config") + } + + if dbConfig.Postgres.IsEmpty() || dbConfig.Postgres.ReadReplicaHost == "" { + return nil, fmt.Errorf("read replica host not provided in db config") + } + + gormConfig := &gorm.Config{ + Logger: GetGormLogger(ctx, logConfig), + DisableForeignKeyConstraintWhenMigrating: false, + } + + var gormDb *gorm.DB + var err error + gormDb, err = CreatePostgresReadOnlyDbConnection(ctx, gormConfig, dbConfig.Postgres) + if err != nil { + return nil, err + } + + return gormDb, nil +} + func setupDbConnectionPool(ctx context.Context, gormDb *gorm.DB, dbConfig *DbConfig) error { genericDb, err := gormDb.DB() if err != nil { diff --git a/flytestdlib/database/dbconfig_flags.go b/flytestdlib/database/dbconfig_flags.go index c925094cc25..9fa96f9fa83 100755 --- a/flytestdlib/database/dbconfig_flags.go +++ b/flytestdlib/database/dbconfig_flags.go @@ -55,6 +55,7 @@ func (cfg DbConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "maxOpenConnections"), defaultConfig.MaxOpenConnections, "maxOpenConnections sets the maximum number of open connections to the database.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "connMaxLifeTime"), defaultConfig.ConnMaxLifeTime.String(), "sets the maximum amount of time a connection may be reused") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "postgres.host"), defaultConfig.Postgres.Host, "The host name of the database server") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "postgres.readReplicaHost"), defaultConfig.Postgres.ReadReplicaHost, "The host name of the read replica database server") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "postgres.port"), defaultConfig.Postgres.Port, "The port name of the database server") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "postgres.dbname"), defaultConfig.Postgres.DbName, "The database name") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "postgres.username"), defaultConfig.Postgres.User, "The database user who is connecting to the server.") diff --git a/flytestdlib/database/dbconfig_flags_test.go b/flytestdlib/database/dbconfig_flags_test.go index 2f0a5d53ebc..fd49e69fd89 100755 --- a/flytestdlib/database/dbconfig_flags_test.go +++ b/flytestdlib/database/dbconfig_flags_test.go @@ -169,6 +169,20 @@ func TestDbConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_postgres.readReplicaHost", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("postgres.readReplicaHost", testValue) + if vString, err := cmdFlags.GetString("postgres.readReplicaHost"); err == nil { + testDecodeJson_DbConfig(t, fmt.Sprintf("%v", vString), &actual.Postgres.ReadReplicaHost) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_postgres.port", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/flytestdlib/database/postgres.go b/flytestdlib/database/postgres.go index ec43330af13..7e73226465c 100644 --- a/flytestdlib/database/postgres.go +++ b/flytestdlib/database/postgres.go @@ -54,6 +54,18 @@ func getPostgresDsn(ctx context.Context, pgConfig PostgresConfig) string { pgConfig.Host, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions) } +// Produces the DSN (data source name) for the read replica for opening a postgres db connection. +func getPostgresReadDsn(ctx context.Context, pgConfig PostgresConfig) string { + password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath) + if len(password) == 0 { + // The password-less case is included for development environments. + return fmt.Sprintf("host=%s port=%d dbname=%s user=%s sslmode=disable", + pgConfig.ReadReplicaHost, pgConfig.Port, pgConfig.DbName, pgConfig.User) + } + return fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s %s", + pgConfig.ReadReplicaHost, pgConfig.Port, pgConfig.DbName, pgConfig.User, password, pgConfig.ExtraOptions) +} + // CreatePostgresDbIfNotExists creates DB if it doesn't exist for the passed in config func CreatePostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig PostgresConfig) (*gorm.DB, error) { dialector := postgres.Open(getPostgresDsn(ctx, pgConfig)) @@ -94,6 +106,12 @@ func CreatePostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, p return gorm.Open(dialector, gormConfig) } +// CreatePostgresDbConnection creates DB connection and returns the gorm.DB object and error +func CreatePostgresReadOnlyDbConnection(ctx context.Context, gormConfig *gorm.Config, pgConfig PostgresConfig) (*gorm.DB, error) { + dialector := postgres.Open(getPostgresReadDsn(ctx, pgConfig)) + return gorm.Open(dialector, gormConfig) +} + func IsPgErrorWithCode(err error, code string) bool { // Newer versions of the gorm postgres driver seem to use // "github.com/jackc/pgx/v5/pgconn" diff --git a/flytestdlib/database/postgres_test.go b/flytestdlib/database/postgres_test.go index 311b05c351d..eb22c6b3aa2 100644 --- a/flytestdlib/database/postgres_test.go +++ b/flytestdlib/database/postgres_test.go @@ -66,6 +66,47 @@ func TestGetPostgresDsn(t *testing.T) { }) } +func TestGetPostgresReadDsn(t *testing.T) { + pgConfig := PostgresConfig{ + Host: "localhost", + ReadReplicaHost: "readReplicaHost", + Port: 5432, + DbName: "postgres", + User: "postgres", + ExtraOptions: "sslmode=disable", + } + t.Run("no password", func(t *testing.T) { + dsn := getPostgresReadDsn(context.TODO(), pgConfig) + assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres sslmode=disable", dsn) + }) + t.Run("with password", func(t *testing.T) { + pgConfig.Password = "passw" + dsn := getPostgresReadDsn(context.TODO(), pgConfig) + assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=passw sslmode=disable", dsn) + + }) + t.Run("with password, no extra", func(t *testing.T) { + pgConfig.Password = "passwo" + pgConfig.ExtraOptions = "" + dsn := getPostgresReadDsn(context.TODO(), pgConfig) + assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=passwo ", dsn) + }) + t.Run("with password path", func(t *testing.T) { + password := "1234abc" + tmpFile, err := ioutil.TempFile("", "prefix") + if err != nil { + t.Errorf("Couldn't open temp file: %v", err) + } + defer tmpFile.Close() + if _, err = tmpFile.WriteString(password); err != nil { + t.Errorf("Couldn't write to temp file: %v", err) + } + pgConfig.PasswordPath = tmpFile.Name() + dsn := getPostgresReadDsn(context.TODO(), pgConfig) + assert.Equal(t, "host=readReplicaHost port=5432 dbname=postgres user=postgres password=1234abc ", dsn) + }) +} + type wrappedError struct { err error }