diff --git a/server/db.go b/server/db.go index af1fbb739..dba7980d0 100644 --- a/server/db.go +++ b/server/db.go @@ -33,7 +33,10 @@ import ( "go.uber.org/zap" ) -const dbErrorDatabaseDoesNotExist = pgerrcode.InvalidCatalogName +const ( + dbErrorDatabaseInvalidPassword = pgerrcode.InvalidPassword + dbErrorDatabaseDoesNotExist = pgerrcode.InvalidCatalogName +) var ErrDatabaseDriverMismatch = errors.New("database driver mismatch") @@ -50,7 +53,7 @@ func DbConnect(ctx context.Context, logger *zap.Logger, config Config, create bo } query := parsedURL.Query() var queryUpdated bool - if len(query.Get("sslmode")) == 0 { + if query.Get("sslmode") == "" { query.Set("sslmode", "prefer") queryUpdated = true } @@ -77,16 +80,14 @@ func DbConnect(ctx context.Context, logger *zap.Logger, config Config, create bo logger.Fatal("Failed to open database", zap.Error(err)) } + dbPing(ctx, logger, db, dbName) + if create { var nakamaDBExists bool if err = db.QueryRow("SELECT EXISTS (SELECT 1 from pg_database WHERE datname = $1)", dbName).Scan(&nakamaDBExists); err != nil { - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) && pgErr.Code == dbErrorDatabaseDoesNotExist { - nakamaDBExists = false - } else { - db.Close() - logger.Fatal("Failed to check if db exists", zap.String("db", dbName), zap.Error(err)) - } + nakamaDBExists = false + db.Close() + logger.Fatal("Failed to check if db exists", zap.String("db", dbName), zap.Error(err)) } if !nakamaDBExists { @@ -104,45 +105,32 @@ func DbConnect(ctx context.Context, logger *zap.Logger, config Config, create bo logger.Fatal("Failed to create database", zap.Error(err)) } db.Close() + parsedURL.Path = fmt.Sprintf("/%s", dbName) - db, err = sql.Open("pgx", parsedURL.String()) - if err != nil { - db.Close() - logger.Fatal("Failed to open database", zap.Error(err)) - } } } logger.Debug("Complete database connection URL", zap.String("raw_url", parsedURL.String())) db, err = sql.Open("pgx", parsedURL.String()) if err != nil { - logger.Fatal("Error connecting to database", zap.Error(err)) - } - // Limit max time allowed across database ping and version fetch to 15 seconds total. - pingCtx, pingCtxCancelFn := context.WithTimeout(ctx, 15*time.Second) - defer pingCtxCancelFn() - if err = db.PingContext(pingCtx); err != nil { - if strings.HasSuffix(err.Error(), "does not exist (SQLSTATE 3D000)") { - logger.Fatal("Database schema not found, run `nakama migrate up`", zap.Error(err)) - } - logger.Fatal("Error pinging database", zap.Error(err)) + logger.Fatal("Failed to open database", zap.Error(err)) } + dbPing(ctx, logger, db, dbName) + db.SetConnMaxLifetime(time.Millisecond * time.Duration(config.GetDatabase().ConnMaxLifetimeMs)) db.SetMaxOpenConns(config.GetDatabase().MaxOpenConns) db.SetMaxIdleConns(config.GetDatabase().MaxIdleConns) var dbVersion string - if err = db.QueryRowContext(pingCtx, "SELECT version()").Scan(&dbVersion); err != nil { - logger.Fatal("Error querying database version", zap.Error(err)) + versionCtx, versionCtxCancelFn := context.WithTimeout(ctx, 5*time.Second) + defer versionCtxCancelFn() + if err = db.QueryRowContext(versionCtx, "SELECT version()").Scan(&dbVersion); err != nil { + logger.Fatal("Failed to query database version", zap.Error(err)) } logger.Info("Database information", zap.String("version", dbVersion)) - if strings.Split(dbVersion, " ")[0] == "CockroachDB" { - isCockroach = true - } else { - isCockroach = false - } + isCockroach = strings.HasPrefix(dbVersion, "CockroachDB ") // Periodically check database hostname for underlying address changes. go func() { @@ -237,6 +225,30 @@ func DbConnect(ctx context.Context, logger *zap.Logger, config Config, create bo return db } +func dbPing(ctx context.Context, logger *zap.Logger, db *sql.DB, dbName string) { + pingCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + err := db.PingContext(pingCtx) + if err != nil { + db.Close() + + errLogger := logger.With(zap.String("db", dbName), zap.Error(err)) + + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case dbErrorDatabaseInvalidPassword: + errLogger.Fatal("Invalid credentials") + case dbErrorDatabaseDoesNotExist: + errLogger.Fatal("Database schema not found, run `nakama migrate up`") + } + } else { + errLogger.Fatal("Failed to ping database") + } + } +} + func dbResolveAddress(ctx context.Context, logger *zap.Logger, host string) ([]string, map[string]struct{}) { resolveCtx, resolveCtxCancelFn := context.WithTimeout(ctx, 15*time.Second) defer resolveCtxCancelFn()