diff --git a/awx/__init__.py b/awx/__init__.py index 6b2f809c3027..30b672533222 100644 --- a/awx/__init__.py +++ b/awx/__init__.py @@ -52,14 +52,6 @@ def version_file(): MODE = 'production' -try: - import django # noqa: F401 -except ImportError: - pass -else: - from django.db import connection - - def prepare_env(): # Update the default settings environment variable based on current mode. os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'awx.settings.%s' % MODE) @@ -78,14 +70,6 @@ def manage(): from django.conf import settings from django.core.management import execute_from_command_line - # enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1 - # In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that. - # The return of connection.pg_version is something like 12013 - if not os.getenv('SKIP_PG_VERSION_CHECK', False) and not MODE == 'development': - if (connection.pg_version // 10000) < 12: - sys.stderr.write("At a minimum, postgres version 12 is required\n") - sys.exit(1) - if len(sys.argv) >= 2 and sys.argv[1] in ('version', '--version'): # pragma: no cover sys.stdout.write('%s\n' % __version__) # If running as a user without permission to read settings, display an diff --git a/awx/main/apps.py b/awx/main/apps.py index 3d9896374743..0e89034bd364 100644 --- a/awx/main/apps.py +++ b/awx/main/apps.py @@ -2,9 +2,13 @@ from django.apps import AppConfig from django.utils.translation import gettext_lazy as _ +from django.core.management.base import CommandError +from django.db.models.signals import pre_migrate + from awx.main.utils.common import bypass_in_test, load_all_entry_points_for from awx.main.utils.migration import is_database_synchronized from awx.main.utils.named_url_graph import _customize_graph, generate_graph +from awx.main.utils.db import db_requirement_violations from awx.conf import register, fields from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name @@ -14,6 +18,11 @@ class MainConfig(AppConfig): name = 'awx.main' verbose_name = _('Main') + def check_db_requirement(self, *args, **kwargs): + violations = db_requirement_violations() + if violations: + raise CommandError(violations) + def load_named_url_feature(self): models = [m for m in self.get_models() if hasattr(m, 'get_absolute_url')] generate_graph(models) @@ -85,3 +94,4 @@ def ready(self): self.load_credential_types_feature() self.load_named_url_feature() self.load_inventory_plugins() + pre_migrate.connect(self.check_db_requirement, sender=self) diff --git a/awx/main/management/commands/check_db.py b/awx/main/management/commands/check_db.py index e490e7a0e1ea..0d34340f3d20 100644 --- a/awx/main/management/commands/check_db.py +++ b/awx/main/management/commands/check_db.py @@ -1,9 +1,11 @@ # Copyright (c) 2015 Ansible, Inc. # All Rights Reserved -from django.core.management.base import BaseCommand +from django.core.management.base import BaseCommand, CommandError from django.db import connection +from awx.main.utils.db import db_requirement_violations + class Command(BaseCommand): """Checks connection to the database, and prints out connection info if not connected""" @@ -13,4 +15,8 @@ def handle(self, *args, **options): cursor.execute("SELECT version()") version = str(cursor.fetchone()[0]) + violations = db_requirement_violations() + if violations: + raise CommandError(violations) + return "Database Version: {}".format(version) diff --git a/awx/main/utils/db.py b/awx/main/utils/db.py index 8cc6aacce9f2..8f549f80c229 100644 --- a/awx/main/utils/db.py +++ b/awx/main/utils/db.py @@ -1,10 +1,34 @@ # Copyright (c) 2017 Ansible by Red Hat # All Rights Reserved. +from typing import Optional from awx.settings.application_name import set_application_name +from awx import MODE + from django.conf import settings +from django.db import connection def set_connection_name(function): set_application_name(settings.DATABASES, settings.CLUSTER_HOST_ID, function=function) + + +MIN_PG_VERSION = 12 + + +def db_requirement_violations() -> Optional[str]: + if connection.vendor == 'postgresql': + + # enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1 + # In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that. + # The return of connection.pg_version is something like 12013 + major_version = connection.pg_version // 10000 + if major_version < MIN_PG_VERSION: + return f"At a minimum, postgres version {MIN_PG_VERSION} is required, found {major_version}\n" + + return None + else: + if MODE == 'production': + return f"Running server with '{connection.vendor}' type database is not supported\n" + return None