Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add postgresql schema support. #507

Merged
merged 9 commits into from
Mar 6, 2024
Merged
37 changes: 31 additions & 6 deletions dbbackup/db/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import List, Optional
from urllib.parse import quote

from .base import BaseCommandDBConnector
Expand All @@ -21,7 +22,7 @@ def create_postgres_uri(self):
else:
host = "@" + host

port = ":{}".format(self.settings.get("PORT")) if self.settings.get("PORT") else ""
port = f":{self.settings.get('PORT')}" if self.settings.get("PORT") else ""
dbname = f"--dbname=postgresql://{user}{password}{host}{port}/{dbname}"
return dbname

Expand All @@ -37,16 +38,23 @@ class PgDumpConnector(BaseCommandDBConnector):
restore_cmd = "psql"
single_transaction = True
drop = True
schemas: Optional[List[str]] = []

def _create_dump(self):
cmd = f"{self.dump_cmd} "
cmd = cmd + create_postgres_uri(self)

for table in self.exclude:
cmd += f" --exclude-table-data={table}"

if self.drop:
cmd += " --clean"

if self.schemas:
# First schema is not prefixed with -n
# when using join function so add it manually.
cmd += " -n " + " -n ".join(self.schemas)

cmd = f"{self.dump_prefix} {cmd} {self.dump_suffix}"
stdout, stderr = self.run_command(cmd, env=self.dump_env)
return stdout
Expand All @@ -57,9 +65,14 @@ def _restore_dump(self, dump):

# without this, psql terminates with an exit value of 0 regardless of errors
cmd += " --set ON_ERROR_STOP=on"

if self.schemas:
cmd += " -n " + " -n ".join(self.schemas)

if self.single_transaction:
cmd += " --single-transaction"
cmd += " {}".format(self.settings["NAME"])

cmd += f" {self.settings['NAME']}"
Archmonger marked this conversation as resolved.
Show resolved Hide resolved
cmd = f"{self.restore_prefix} {cmd} {self.restore_suffix}"
stdout, stderr = self.run_command(cmd, stdin=dump, env=self.restore_env)
return stdout, stderr
Expand All @@ -75,12 +88,15 @@ class PgDumpGisConnector(PgDumpConnector):

def _enable_postgis(self):
cmd = f'{self.psql_cmd} -c "CREATE EXTENSION IF NOT EXISTS postgis;"'
cmd += " --username={}".format(self.settings["ADMIN_USER"])
cmd += f" --username={self.settings['ADMIN_USER']}"
cmd += " --no-password"

if self.settings.get("HOST"):
cmd += " --host={}".format(self.settings["HOST"])
cmd += f" --host={self.settings['HOST']}"

if self.settings.get("PORT"):
cmd += " --port={}".format(self.settings["PORT"])
cmd += f" --port={self.settings['PORT']}"

return self.run_command(cmd)

def _restore_dump(self, dump):
Expand Down Expand Up @@ -108,8 +124,12 @@ def _create_dump(self):
cmd += " --format=custom"
for table in self.exclude:
cmd += f" --exclude-table-data={table}"

if self.schemas:
cmd += " -n " + " -n ".join(self.schemas)

cmd = f"{self.dump_prefix} {cmd} {self.dump_suffix}"
stdout, stderr = self.run_command(cmd, env=self.dump_env)
stdout, _ = self.run_command(cmd, env=self.dump_env)
return stdout

def _restore_dump(self, dump):
Expand All @@ -118,8 +138,13 @@ def _restore_dump(self, dump):

if self.single_transaction:
cmd += " --single-transaction"

if self.drop:
cmd += " --clean"

if self.schemas:
cmd += " -n " + " -n ".join(self.schemas)

cmd = f"{self.restore_prefix} {cmd} {self.restore_suffix}"
stdout, stderr = self.run_command(cmd, stdin=dump, env=self.restore_env)
return stdout, stderr
36 changes: 30 additions & 6 deletions dbbackup/management/commands/dbbackup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class Command(BaseDbBackupCommand):
help = "Backup a database, encrypt and/or compress and write to " "storage." ""
help = "Backup a database, encrypt and/or compress and write to 'storage.'"
Archmonger marked this conversation as resolved.
Show resolved Hide resolved
content_type = "db"

option_list = BaseDbBackupCommand.option_list + (
Expand Down Expand Up @@ -49,7 +49,10 @@ class Command(BaseDbBackupCommand):
help="Encrypt the backup files",
),
make_option(
"-o", "--output-filename", default=None, help="Specify filename on storage"
"-o",
"--output-filename",
default=None,
help="Specify filename on storage",
Archmonger marked this conversation as resolved.
Show resolved Hide resolved
),
make_option(
"-O",
Expand All @@ -58,7 +61,17 @@ class Command(BaseDbBackupCommand):
help="Specify where to store on local filesystem",
),
make_option(
"-x", "--exclude-tables", default=None, help="Exclude tables from backup"
"-x",
"--exclude-tables",
default=None,
help="Exclude tables from backup",
),
make_option(
"-n",
"--schema",
action="append",
default=[],
help="Specify schema(s) to backup. Can be used multiple times.",
Archmonger marked this conversation as resolved.
Show resolved Hide resolved
),
)

Expand All @@ -78,6 +91,7 @@ def handle(self, **options):
self.path = options.get("output_path")
self.exclude_tables = options.get("exclude_tables")
self.storage = get_storage()
self.schemas = options.get("schema")

self.database = options.get("database") or ""
database_keys = self.database.split(",") or settings.DATABASES
Expand All @@ -100,23 +114,33 @@ def _save_new_backup(self, database):
"""
Save a new backup file.
"""
self.logger.info("Backing Up Database: %s", database["NAME"])
# Get backup and name
self.logger.info(f"Backing Up Database: {database['NAME']}")
# Get backup, schema and name
filename = self.connector.generate_filename(self.servername)

if self.schemas:
self.connector.schemas = self.schemas

outputfile = self.connector.create_dump()

# Apply trans
if self.compress:
compressed_file, filename = utils.compress_file(outputfile, filename)
outputfile = compressed_file

if self.encrypt:
encrypted_file, filename = utils.encrypt_file(outputfile, filename)
outputfile = encrypted_file

# Set file name
filename = self.filename or filename
self.logger.debug("Backup size: %s", utils.handle_size(outputfile))
self.logger.info(f"Backup tempfile created: {utils.handle_size(outputfile)}")

# Store backup
outputfile.seek(0)

if self.path is None:
self.write_to_storage(outputfile, filename)

else:
self.write_local_file(outputfile, self.path)
42 changes: 33 additions & 9 deletions dbbackup/management/commands/dbrestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,24 @@


class Command(BaseDbBackupCommand):
help = """Restore a database backup from storage, encrypted and/or
compressed."""
help = "Restore a database backup from storage, encrypted and/or compressed."
content_type = "db"

option_list = BaseDbBackupCommand.option_list + (
make_option("-d", "--database", help="Database to restore"),
make_option("-i", "--input-filename", help="Specify filename to backup from"),
make_option(
"-I", "--input-path", help="Specify path on local filesystem to backup from"
"-d",
"--database",
help="Database to restore",
),
make_option(
"-i",
"--input-filename",
help="Specify filename to backup from",
),
make_option(
"-I",
"--input-path",
help="Specify path on local filesystem to backup from",
),
make_option(
"-s",
Expand All @@ -46,6 +55,13 @@ class Command(BaseDbBackupCommand):
default=False,
help="Uncompress gzip data before restoring",
),
make_option(
"-n",
"--schema",
action="append",
default=[],
help="Specify schema(s) to restore. Can be used multiple times.",
),
)

def handle(self, *args, **options):
Expand All @@ -68,6 +84,7 @@ def handle(self, *args, **options):
self.input_database_name
)
self.storage = get_storage()
self.schemas = options.get("schema")
self._restore_backup()
except StorageError as err:
raise CommandError(err) from err
Expand All @@ -91,11 +108,14 @@ def _restore_backup(self):
input_filename, input_file = self._get_backup_file(
database=self.input_database_name, servername=self.servername
)

self.logger.info(
"Restoring backup for database '%s' and server '%s'",
self.database_name,
self.servername,
f"Restoring backup for database '{self.database_name}' and server '{self.servername}'"
)

if self.schemas:
self.logger.info(f"Restoring schemas: {self.schemas}")

self.logger.info(f"Restoring: {input_filename}")

if self.decrypt:
Expand All @@ -111,10 +131,14 @@ def _restore_backup(self):
input_file.close()
input_file = uncompressed_file

self.logger.info("Restore tempfile created: %s", utils.handle_size(input_file))
self.logger.info(f"Restore tempfile created: {utils.handle_size(input_file)}")
if self.interactive:
self._ask_confirmation()

input_file.seek(0)
self.connector = get_connector(self.database_name)

if self.schemas:
self.connector.schemas = self.schemas

self.connector.restore_dump(input_file)
10 changes: 9 additions & 1 deletion dbbackup/tests/commands/test_dbbackup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Tests for dbbackup command.
"""
import os
from unittest.mock import patch

from django.test import TestCase
from mock import patch

from dbbackup.db.base import get_connector
from dbbackup.management.commands.dbbackup import Command as DbbackupCommand
Expand All @@ -26,6 +26,7 @@ def setUp(self):
self.command.stdout = DEV_NULL
self.command.filename = None
self.command.path = None
self.command.schemas = []

def tearDown(self):
clean_gpg_keys()
Expand All @@ -49,6 +50,12 @@ def test_path(self):
# tearDown
os.remove(self.command.path)

def test_schema(self):
self.command.schemas = ["public"]
result = self.command._save_new_backup(TEST_DATABASE)

self.assertIsNone(result)


@patch("dbbackup.settings.GPG_RECIPIENT", "test@test")
@patch("sys.stdout", DEV_NULL)
Expand All @@ -65,6 +72,7 @@ def setUp(self):
self.command.filename = None
self.command.path = None
self.command.connector = get_connector("default")
self.command.schemas = []

def tearDown(self):
clean_gpg_keys()
Expand Down
6 changes: 5 additions & 1 deletion dbbackup/tests/commands/test_dbrestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
"""
from shutil import copyfileobj
from tempfile import mktemp
from unittest.mock import patch

from django.conf import settings
from django.core.files import File
from django.core.management.base import CommandError
from django.test import TestCase
from mock import patch

from dbbackup import utils
from dbbackup.db.base import get_connector
Expand Down Expand Up @@ -46,6 +46,8 @@ def setUp(self):
self.command.input_database_name = None
self.command.database_name = "default"
self.command.connector = get_connector("default")
self.command.schemas = []
self.command.no_owner = False
Archmonger marked this conversation as resolved.
Show resolved Hide resolved
HANDLED_FILES.clean()

def tearDown(self):
Expand Down Expand Up @@ -146,6 +148,8 @@ def setUp(self):
self.command.database_name = "mongo"
self.command.input_database_name = None
self.command.servername = HOSTNAME
self.command.schemas = []
self.command.no_owner = False
HANDLED_FILES.clean()
add_private_gpg()

Expand Down
Loading