diff --git a/internal/manager/task/migrate.go b/internal/manager/task/migrate.go index 609512b2f4b..8ce478ec80e 100644 --- a/internal/manager/task/migrate.go +++ b/internal/manager/task/migrate.go @@ -50,7 +50,7 @@ func (s *MigrateJob) Execute(ctx context.Context, progress *job.Progress) error // always backup so that we can roll back to the previous version if // migration fails backupPath := s.BackupPath - if backupPath == "" { + if backupPath == "" || s.Database.DatabaseType() == sqlite.PostgresBackend { backupPath = database.DatabaseBackupPath(s.Config.GetBackupDirectoryPath()) } else { // check if backup path is a filename or path diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index 9f846c28658..aec6e681130 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -6,6 +6,7 @@ import ( "embed" "errors" "fmt" + "path/filepath" "strconv" "time" @@ -358,6 +359,28 @@ func (db *Database) Version() uint { return db.schemaVersion } +func (db *Database) Reset() error { + if err := db.Remove(); err != nil { + return err + } + + if err := db.Open(); err != nil { + return fmt.Errorf("[reset DB] unable to initialize: %w", err) + } + + return nil +} + +func (db *Database) AnonymousDatabasePath(backupDirectoryPath string) string { + fn := fmt.Sprintf("%s.anonymous.%d.%s", filepath.Base(db.DatabasePath()), db.schemaVersion, time.Now().Format("20060102_150405")) + + if backupDirectoryPath != "" { + return filepath.Join(backupDirectoryPath, fn) + } + + return fn +} + func (db *Database) Optimise(ctx context.Context) error { logger.Info("Optimising database") diff --git a/pkg/sqlite/database_postgres.go b/pkg/sqlite/database_postgres.go index 0bd601b4e1b..2f667f451d6 100644 --- a/pkg/sqlite/database_postgres.go +++ b/pkg/sqlite/database_postgres.go @@ -2,12 +2,12 @@ package sqlite import ( "fmt" + "time" "github.com/doug-martin/goqu/v9" _ "github.com/doug-martin/goqu/v9/dialect/postgres" _ "github.com/jackc/pgx/v5/stdlib" "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/logger" ) type PostgresDB struct { @@ -71,27 +71,51 @@ func (db *PostgresDB) open(disableForeignKeys bool, writable bool) (conn *sqlx.D return conn, nil } -func (db *PostgresDB) Remove() error { - logger.Warn("Postgres backend detected, ignoring Remove request") - return nil +func (db *PostgresDB) Remove() (err error) { + _, err = db.writeDB.Exec(` +DO $$ DECLARE + r RECORD; +BEGIN + -- Disable triggers to avoid foreign key constraint violations + EXECUTE 'SET session_replication_role = replica'; + + -- Drop all tables + FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP + EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE'; + END LOOP; + + -- Re-enable triggers + EXECUTE 'SET session_replication_role = DEFAULT'; +END $$; +`) + + return err } -func (db *PostgresDB) Reset() error { - logger.Warn("Postgres backend detected, ignoring Reset request") - return nil +// getDBCloneCommand returns the command to clone a database from a backup file +func getDBCloneCommand(backupPath string, dbname string) string { + return fmt.Sprintf(` +SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity +WHERE pg_stat_activity.datname = '%[2]s' AND pid <> pg_backend_pid(); +CREATE DATABASE %[1]s WITH TEMPLATE %[2]s; +`, backupPath, dbname) } +// Backup creates a backup of the database at the given path. func (db *PostgresDB) Backup(backupPath string) (err error) { - logger.Warn("Postgres backend detected, ignoring Backup request") - return nil + _, err = db.writeDB.Exec(getDBCloneCommand(backupPath, "stash")) + return err } -func (db *PostgresDB) RestoreFromBackup(backupPath string) error { - logger.Warn("Postgres backend detected, ignoring RestoreFromBackup request") - return nil +// RestoreFromBackup restores the database from a backup file at the given path. +func (db *PostgresDB) RestoreFromBackup(backupPath string) (err error) { + sqlcmd := "DROP DATABASE stash;\n" + getDBCloneCommand("stash", backupPath) + + _, err = db.writeDB.Exec(sqlcmd) + return err } +// DatabaseBackupPath returns the path to a database backup file for the given directory. func (db *PostgresDB) DatabaseBackupPath(backupDirectoryPath string) string { - logger.Warn("Postgres backend detected, ignoring DatabaseBackupPath request") - return "" + return fmt.Sprintf("stash_%d_%s", db.schemaVersion, time.Now().Format("20060102_150405")) } diff --git a/pkg/sqlite/database_sqlite.go b/pkg/sqlite/database_sqlite.go index 2a58d6c65ef..8d3451f7fb9 100644 --- a/pkg/sqlite/database_sqlite.go +++ b/pkg/sqlite/database_sqlite.go @@ -103,18 +103,6 @@ func (db *SQLiteDB) Remove() error { return nil } -func (db *SQLiteDB) Reset() error { - if err := db.Remove(); err != nil { - return err - } - - if err := db.Open(); err != nil { - return fmt.Errorf("[reset DB] unable to initialize: %w", err) - } - - return nil -} - // Backup the database. If db is nil, then uses the existing database // connection. func (db *SQLiteDB) Backup(backupPath string) (err error) { @@ -150,13 +138,3 @@ func (db *SQLiteDB) DatabaseBackupPath(backupDirectoryPath string) string { return fn } - -func (db *SQLiteDB) AnonymousDatabasePath(backupDirectoryPath string) string { - fn := fmt.Sprintf("%s.anonymous.%d.%s", filepath.Base(db.DatabasePath()), db.schemaVersion, time.Now().Format("20060102_150405")) - - if backupDirectoryPath != "" { - return filepath.Join(backupDirectoryPath, fn) - } - - return fn -}