From a85d11b78c26bc34c0161fda532f4172acfa4d46 Mon Sep 17 00:00:00 2001 From: Karan Kajla Date: Wed, 22 May 2024 14:10:55 -0700 Subject: [PATCH] Support passing in a connection string for postgres databases (#325) --- .github/workflows/mysql.yaml | 7 ++---- .github/workflows/postgres.yaml | 8 ++----- cmd/warrant/main.go | 2 +- go.mod | 2 +- pkg/config/config.go | 2 ++ pkg/database/mysql.go | 6 ++--- pkg/database/postgres.go | 42 +++++++++++++-------------------- pkg/database/sqlite.go | 6 ++--- 8 files changed, 30 insertions(+), 45 deletions(-) diff --git a/.github/workflows/mysql.yaml b/.github/workflows/mysql.yaml index 5fcb4b1a..d7a45858 100644 --- a/.github/workflows/mysql.yaml +++ b/.github/workflows/mysql.yaml @@ -47,15 +47,12 @@ jobs: WARRANT_CHECK_MAXCONCURRENCY: 1000 WARRANT_CHECK_TIMEOUT: 1m WARRANT_DATASTORE: mysql - WARRANT_DATASTORE_MYSQL_USERNAME: root - WARRANT_DATASTORE_MYSQL_PASSWORD: root - WARRANT_DATASTORE_MYSQL_HOSTNAME: 127.0.0.1 - WARRANT_DATASTORE_MYSQL_DATABASE: warrant + WARRANT_DATASTORE_MYSQL_DSN: root:root@tcp(127.0.0.1:3306)/warrant?parseTime=true WARRANT_DATASTORE_MYSQL_MAXIDLECONNECTIONS: 5 WARRANT_DATASTORE_MYSQL_MAXOPENCONNECTIONS: 5 WARRANT_DATASTORE_MYSQL_CONNMAXIDLETIME: 4h WARRANT_DATASTORE_MYSQL_CONNMAXLIFETIME: 6h - WARRANT_DATASTORE_MYSQL_READERHOSTNAME: 127.0.0.1 + WARRANT_DATASTORE_MYSQL_READERDSN: root:root@tcp(127.0.0.1:3306)/warrant?parseTime=true WARRANT_DATASTORE_MYSQL_READERMAXIDLECONNECTIONS: 5 WARRANT_DATASTORE_MYSQL_READERMAXOPENCONNECTIONS: 5 - name: Run apirunner tests diff --git a/.github/workflows/postgres.yaml b/.github/workflows/postgres.yaml index 9d88f050..cfa64262 100644 --- a/.github/workflows/postgres.yaml +++ b/.github/workflows/postgres.yaml @@ -52,16 +52,12 @@ jobs: WARRANT_CHECK_MAXCONCURRENCY: 1000 WARRANT_CHECK_TIMEOUT: 1m WARRANT_DATASTORE: postgres - WARRANT_DATASTORE_POSTGRES_USERNAME: warrant_user - WARRANT_DATASTORE_POSTGRES_PASSWORD: db_password - WARRANT_DATASTORE_POSTGRES_HOSTNAME: localhost - WARRANT_DATASTORE_POSTGRES_DATABASE: warrant - WARRANT_DATASTORE_POSTGRES_SSLMODE: disable + WARRANT_DATASTORE_POSTGRES_DSN: postgresql://warrant_user:db_password@localhost:5432/warrant?sslmode=disable WARRANT_DATASTORE_POSTGRES_MAXIDLECONNECTIONS: 5 WARRANT_DATASTORE_POSTGRES_MAXOPENCONNECTIONS: 5 WARRANT_DATASTORE_POSTGRES_CONNMAXIDLETIME: 4h WARRANT_DATASTORE_POSTGRES_CONNMAXLIFETIME: 6h - WARRANT_DATASTORE_POSTGRES_READERHOSTNAME: localhost + WARRANT_DATASTORE_POSTGRES_READERDSN: postgresql://warrant_user:db_password@localhost:5432/warrant?sslmode=disable WARRANT_DATASTORE_POSTGRES_READERMAXIDLECONNECTIONS: 5 WARRANT_DATASTORE_POSTGRES_READERMAXOPENCONNECTIONS: 5 - name: Run apirunner tests diff --git a/cmd/warrant/main.go b/cmd/warrant/main.go index 5a7e09b8..bc6d45c6 100644 --- a/cmd/warrant/main.go +++ b/cmd/warrant/main.go @@ -74,7 +74,7 @@ func (env *ServiceEnv) InitDB(cfg config.Config) error { return nil } - if cfg.GetDatastore().GetPostgres().Hostname != "" { + if cfg.GetDatastore().GetPostgres().Hostname != "" || cfg.GetDatastore().GetPostgres().DSN != "" { db := database.NewPostgres(*cfg.GetDatastore().GetPostgres()) err := db.Connect(ctx) if err != nil { diff --git a/go.mod b/go.mod index a477588e..adadab9b 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/jmoiron/sqlx v1.4.0 - github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.32.0 @@ -33,6 +32,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect diff --git a/pkg/config/config.go b/pkg/config/config.go index 1f281c4d..39d87cab 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -137,6 +137,8 @@ type PostgresConfig struct { ReaderHostname string `mapstructure:"readerHostname"` ReaderMaxIdleConnections int `mapstructure:"readerMaxIdleConnections"` ReaderMaxOpenConnections int `mapstructure:"readerMaxOpenConnections"` + DSN string `mapstructure:"dsn"` + ReaderDSN string `mapstructure:"readerDsn"` } type SQLiteConfig struct { diff --git a/pkg/database/mysql.go b/pkg/database/mysql.go index a95bfd0f..5012b4ee 100644 --- a/pkg/database/mysql.go +++ b/pkg/database/mysql.go @@ -42,7 +42,7 @@ func NewMySQL(config config.MySQLConfig) *MySQL { } } -func (ds MySQL) Type() string { +func (ds *MySQL) Type() string { return TypeMySQL } @@ -122,7 +122,7 @@ func (ds *MySQL) Connect(ctx context.Context) error { return nil } -func (ds MySQL) Migrate(ctx context.Context, toVersion uint) error { +func (ds *MySQL) Migrate(ctx context.Context, toVersion uint) error { log.Info().Msgf("init: migrating mysql database %s", ds.Config.Database) // migrate database to latest schema mig, err := migrate.New( @@ -159,7 +159,7 @@ func (ds MySQL) Migrate(ctx context.Context, toVersion uint) error { return nil } -func (ds MySQL) Ping(ctx context.Context) error { +func (ds *MySQL) Ping(ctx context.Context) error { err := ds.Writer.PingContext(ctx) if err != nil { return errors.Wrap(err, "Error while attempting to ping mysql database") diff --git a/pkg/database/postgres.go b/pkg/database/postgres.go index b8a16aac..d4c26041 100644 --- a/pkg/database/postgres.go +++ b/pkg/database/postgres.go @@ -28,7 +28,6 @@ import ( _ "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/golang-migrate/migrate/v4/source/github" - "github.com/lib/pq" "github.com/warrant-dev/warrant/pkg/config" ) @@ -44,7 +43,7 @@ func NewPostgres(config config.PostgresConfig) *Postgres { } } -func (ds Postgres) Type() string { +func (ds *Postgres) Type() string { return TypePostgres } @@ -52,26 +51,12 @@ func (ds *Postgres) Connect(ctx context.Context) error { var db *sqlx.DB var err error - // open new database connection without specifying the database name - usernamePassword := url.UserPassword(ds.Config.Username, ds.Config.Password).String() - db, err = sqlx.Open("postgres", fmt.Sprintf("postgres://%s@%s/?sslmode=%s", usernamePassword, ds.Config.Hostname, ds.Config.SSLMode)) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("Unable to establish connection to postgres database %s. Shutting down server.", ds.Config.Database)) - } - - // create database if it does not already exist - _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", ds.Config.Database)) - if err != nil { - pgErr, ok := err.(*pq.Error) - if ok && pgErr.Code.Name() != "duplicate_database" { - return errors.Wrap(err, fmt.Sprintf("Unable to create postgres database %s", ds.Config.Database)) - } + if ds.Config.DSN != "" { + db, err = sqlx.Open("postgres", ds.Config.DSN) + } else { + usernamePassword := url.UserPassword(ds.Config.Username, ds.Config.Password).String() + db, err = sqlx.Open("postgres", fmt.Sprintf("postgres://%s@%s/%s?sslmode=%s", usernamePassword, ds.Config.Hostname, ds.Config.Database, ds.Config.SSLMode)) } - - db.Close() - - // open new database connection, this time specifying the database name - db, err = sqlx.Open("postgres", fmt.Sprintf("postgres://%s@%s/%s?sslmode=%s", usernamePassword, ds.Config.Hostname, ds.Config.Database, ds.Config.SSLMode)) if err != nil { return errors.Wrap(err, fmt.Sprintf("Unable to establish connection to postgres database %s. Shutting down server.", ds.Config.Database)) } @@ -101,8 +86,14 @@ func (ds *Postgres) Connect(ctx context.Context) error { ds.Config.Database, ds.Config.MaxIdleConnections, ds.Config.ConnMaxIdleTime, ds.Config.MaxOpenConnections, ds.Config.ConnMaxLifetime) // connect to reader if provided - if ds.Config.ReaderHostname != "" { - reader, err := sqlx.Open("postgres", fmt.Sprintf("postgres://%s@%s/%s?sslmode=%s", usernamePassword, ds.Config.ReaderHostname, ds.Config.Database, ds.Config.SSLMode)) + if ds.Config.ReaderHostname != "" || ds.Config.ReaderDSN != "" { + var reader *sqlx.DB + if ds.Config.ReaderDSN != "" { + reader, err = sqlx.Open("postgres", ds.Config.ReaderDSN) + } else { + usernamePassword := url.UserPassword(ds.Config.Username, ds.Config.Password).String() + reader, err = sqlx.Open("postgres", fmt.Sprintf("postgres://%s@%s/%s?sslmode=%s", usernamePassword, ds.Config.ReaderHostname, ds.Config.Database, ds.Config.SSLMode)) + } if err != nil { return errors.Wrap(err, fmt.Sprintf("Unable to establish connection to postgres reader %s. Shutting down server.", ds.Config.Database)) } @@ -126,7 +117,6 @@ func (ds *Postgres) Connect(ctx context.Context) error { // map struct attributes to db column names reader.Mapper = reflectx.NewMapperFunc("postgres", func(s string) string { return s }) - ds.Reader = reader log.Info().Msgf("init: connected to postgres reader database %s [maxIdleConns: %d, connMaxIdleTime: %s, maxOpenConns: %d, connMaxLifetime: %s]", ds.Config.Database, ds.Config.ReaderMaxIdleConnections, ds.Config.ConnMaxIdleTime, ds.Config.ReaderMaxOpenConnections, ds.Config.ConnMaxLifetime) @@ -135,7 +125,7 @@ func (ds *Postgres) Connect(ctx context.Context) error { return nil } -func (ds Postgres) Migrate(ctx context.Context, toVersion uint) error { +func (ds *Postgres) Migrate(ctx context.Context, toVersion uint) error { log.Info().Msgf("init: migrating postgres database %s", ds.Config.Database) // migrate database to latest schema usernamePassword := url.UserPassword(ds.Config.Username, ds.Config.Password).String() @@ -173,7 +163,7 @@ func (ds Postgres) Migrate(ctx context.Context, toVersion uint) error { return nil } -func (ds Postgres) Ping(ctx context.Context) error { +func (ds *Postgres) Ping(ctx context.Context) error { err := ds.Writer.PingContext(ctx) if err != nil { return errors.Wrap(err, "Error while attempting to ping postgres database") diff --git a/pkg/database/sqlite.go b/pkg/database/sqlite.go index f5cc1324..cf3fcdca 100644 --- a/pkg/database/sqlite.go +++ b/pkg/database/sqlite.go @@ -46,7 +46,7 @@ func NewSQLite(config config.SQLiteConfig) *SQLite { } } -func (ds SQLite) Type() string { +func (ds *SQLite) Type() string { return TypeSQLite } @@ -94,7 +94,7 @@ func (ds *SQLite) Connect(ctx context.Context) error { return nil } -func (ds SQLite) Migrate(ctx context.Context, toVersion uint) error { +func (ds *SQLite) Migrate(ctx context.Context, toVersion uint) error { log.Info().Msgf("init: migrating sqlite database %s", ds.Config.Database) // migrate database to latest schema instance, err := sqlite3.WithInstance(ds.Writer.DB, &sqlite3.Config{}) @@ -136,6 +136,6 @@ func (ds SQLite) Migrate(ctx context.Context, toVersion uint) error { return nil } -func (ds SQLite) Ping(ctx context.Context) error { +func (ds *SQLite) Ping(ctx context.Context) error { return ds.Writer.PingContext(ctx) }