From dc11843d5157d0269996bed5bc4b2e5014410abe Mon Sep 17 00:00:00 2001 From: colin-rogers-dbt <111200756+colin-rogers-dbt@users.noreply.github.com> Date: Tue, 23 Jan 2024 09:32:39 -0800 Subject: [PATCH 1/3] update dbt-adapters to include unit test feature updates (#43) --- dbt/adapters/base/relation.py | 2 +- .../macros/materializations/tests/helpers.sql | 30 ++++++++ .../macros/materializations/tests/unit.sql | 29 +++++++ .../macros/unit_test_sql/get_fixture_sql.sql | 76 +++++++++++++++++++ 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 dbt/include/global_project/macros/materializations/tests/unit.sql create mode 100644 dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql diff --git a/dbt/adapters/base/relation.py b/dbt/adapters/base/relation.py index 13f48977..ea03b067 100644 --- a/dbt/adapters/base/relation.py +++ b/dbt/adapters/base/relation.py @@ -228,7 +228,7 @@ def add_ephemeral_prefix(name: str): def create_ephemeral_from( cls: Type[Self], relation_config: RelationConfig, - limit: Optional[int], + limit: Optional[int] = None, ) -> Self: # Note that ephemeral models are based on the name. identifier = cls.add_ephemeral_prefix(relation_config.name) diff --git a/dbt/include/global_project/macros/materializations/tests/helpers.sql b/dbt/include/global_project/macros/materializations/tests/helpers.sql index efc55288..13e640c2 100644 --- a/dbt/include/global_project/macros/materializations/tests/helpers.sql +++ b/dbt/include/global_project/macros/materializations/tests/helpers.sql @@ -12,3 +12,33 @@ {{ "limit " ~ limit if limit != none }} ) dbt_internal_test {%- endmacro %} + + + + +{% macro get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%} + {{ adapter.dispatch('get_unit_test_sql', 'dbt')(main_sql, expected_fixture_sql, expected_column_names) }} +{%- endmacro %} + +{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%} +-- Build actual result given inputs +with dbt_internal_unit_test_actual AS ( + select + {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as actual_or_expected + from ( + {{ main_sql }} + ) _dbt_internal_unit_test_actual +), +-- Build expected result +dbt_internal_unit_test_expected AS ( + select + {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as actual_or_expected + from ( + {{ expected_fixture_sql }} + ) _dbt_internal_unit_test_expected +) +-- Union actual and expected results +select * from dbt_internal_unit_test_actual +union all +select * from dbt_internal_unit_test_expected +{%- endmacro %} \ No newline at end of file diff --git a/dbt/include/global_project/macros/materializations/tests/unit.sql b/dbt/include/global_project/macros/materializations/tests/unit.sql new file mode 100644 index 00000000..79d5631b --- /dev/null +++ b/dbt/include/global_project/macros/materializations/tests/unit.sql @@ -0,0 +1,29 @@ +{%- materialization unit, default -%} + + {% set relations = [] %} + + {% set expected_rows = config.get('expected_rows') %} + {% set tested_expected_column_names = expected_rows[0].keys() if (expected_rows | length ) > 0 else get_columns_in_query(sql) %} %} + + {%- set target_relation = this.incorporate(type='table') -%} + {%- set temp_relation = make_temp_relation(target_relation)-%} + {% do 
run_query(get_create_table_as_sql(True, temp_relation, get_empty_subquery_sql(sql))) %} + {%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%} + {%- set column_name_to_data_types = {} -%} + {%- for column in columns_in_relation -%} + {%- do column_name_to_data_types.update({column.name: column.dtype}) -%} + {%- endfor -%} + + {% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %} + + {% call statement('main', fetch_result=True) -%} + + {{ unit_test_sql }} + + {%- endcall %} + + {% do adapter.drop_relation(temp_relation) %} + + {{ return({'relations': relations}) }} + +{%- endmaterialization -%} diff --git a/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql b/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql new file mode 100644 index 00000000..2f90a561 --- /dev/null +++ b/dbt/include/global_project/macros/unit_test_sql/get_fixture_sql.sql @@ -0,0 +1,76 @@ +{% macro get_fixture_sql(rows, column_name_to_data_types) %} +-- Fixture for {{ model.name }} +{% set default_row = {} %} + +{%- if not column_name_to_data_types -%} +{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%} +{%- set column_name_to_data_types = {} -%} +{%- for column in columns_in_relation -%} +{%- do column_name_to_data_types.update({column.name: column.dtype}) -%} +{%- endfor -%} +{%- endif -%} + +{%- if not column_name_to_data_types -%} + {{ exceptions.raise_compiler_error("Not able to get columns for unit test '" ~ model.name ~ "' from relation " ~ this) }} +{%- endif -%} + +{%- for column_name, column_type in column_name_to_data_types.items() -%} + {%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%} +{%- endfor -%} + +{%- for row in rows -%} +{%- do format_row(row, column_name_to_data_types) -%} +{%- set default_row_copy = default_row.copy() -%} +{%- do default_row_copy.update(row) -%} +select +{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %} +{%- endfor %} +{%- if not loop.last %} +union all +{% endif %} +{%- endfor -%} + +{%- if (rows | length) == 0 -%} + select + {%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %} + {%- endfor %} + limit 0 +{%- endif -%} +{% endmacro %} + + +{% macro get_expected_sql(rows, column_name_to_data_types) %} + +{%- if (rows | length) == 0 -%} + select * FROM dbt_internal_unit_test_actual + limit 0 +{%- else -%} +{%- for row in rows -%} +{%- do format_row(row, column_name_to_data_types) -%} +select +{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %} +{%- endfor %} +{%- if not loop.last %} +union all +{% endif %} +{%- endfor -%} +{%- endif -%} + +{% endmacro %} + +{%- macro format_row(row, column_name_to_data_types) -%} + +{#-- wrap yaml strings in quotes, apply cast --#} +{%- for column_name, column_value in row.items() -%} +{% set row_update = {column_name: column_value} %} +{%- if column_value is string -%} +{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%} +{%- elif column_value is none -%} +{%- set row_update = {column_name: safe_cast('null', column_name_to_data_types[column_name]) } -%} +{%- else -%} +{%- set row_update = {column_name: safe_cast(column_value, 
column_name_to_data_types[column_name]) } -%} +{%- endif -%} +{%- do row.update(row_update) -%} +{%- endfor -%} + +{%- endmacro -%} From 88065dc499bbfc99c60a57fd18a85e8b915c3c74 Mon Sep 17 00:00:00 2001 From: Mike Alfare <13974384+mikealfare@users.noreply.github.com> Date: Tue, 23 Jan 2024 16:54:03 -0500 Subject: [PATCH 2/3] Allow adapter version to be specified in `__about__.py` for hatch support (#44) --- .../Under the Hood-20240123-121220.yaml | 6 ++++++ dbt/adapters/factory.py | 17 ++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20240123-121220.yaml diff --git a/.changes/unreleased/Under the Hood-20240123-121220.yaml b/.changes/unreleased/Under the Hood-20240123-121220.yaml new file mode 100644 index 00000000..8d01f256 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240123-121220.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Allow version to be specified in either __version__.py or __about__.py +time: 2024-01-23T12:12:20.529147-05:00 +custom: + Author: mikealfare + Issue: "44" diff --git a/dbt/adapters/factory.py b/dbt/adapters/factory.py index dfe2bc01..e5c7be78 100644 --- a/dbt/adapters/factory.py +++ b/dbt/adapters/factory.py @@ -99,12 +99,9 @@ def load_plugin(self, name: str) -> Type[Credentials]: def register_adapter(self, config: AdapterRequiredConfig, mp_context: SpawnContext) -> None: adapter_name = config.credentials.type adapter_type = self.get_adapter_class_by_name(adapter_name) - adapter_version = import_module(f".{adapter_name}.__version__", "dbt.adapters").version - adapter_version_specifier = VersionSpecifier.from_version_string( - adapter_version - ).to_version_string() + adapter_version = self._adapter_version(adapter_name) fire_event( - AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version_specifier) + AdapterRegistered(adapter_name=adapter_name, adapter_version=adapter_version) ) with self.lock: if adapter_name in self.adapters: @@ -114,6 +111,16 @@ def register_adapter(self, config: AdapterRequiredConfig, mp_context: SpawnConte adapter: Adapter = adapter_type(config, mp_context) # type: ignore self.adapters[adapter_name] = adapter + def _adapter_version(self, adapter_name: str) -> str: + try: + raw_version = import_module(f".{adapter_name}.__about__", "dbt.adapters").version + except ModuleNotFoundError: + raw_version = import_module(f".{adapter_name}.__version__", "dbt.adapters").version + return self._validate_version(raw_version) + + def _validate_version(self, raw_version: str) -> str: + return VersionSpecifier.from_version_string(raw_version).to_version_string() + def lookup_adapter(self, adapter_name: str) -> Adapter: return self.adapters[adapter_name] From 9a5bd537ab20c8924da8e11542b0bdec5bd32c55 Mon Sep 17 00:00:00 2001 From: Mike Alfare <13974384+mikealfare@users.noreply.github.com> Date: Wed, 24 Jan 2024 12:37:36 -0500 Subject: [PATCH 3/3] Remove core test fixtures from dbt-tests-adapter (#45) --- dbt/tests/fixtures/__init__.py | 0 dbt/tests/fixtures/project.py | 572 ----------------------------- dbt/tests/util.py | 640 --------------------------------- 3 files changed, 1212 deletions(-) delete mode 100644 dbt/tests/fixtures/__init__.py delete mode 100644 dbt/tests/fixtures/project.py delete mode 100644 dbt/tests/util.py diff --git a/dbt/tests/fixtures/__init__.py b/dbt/tests/fixtures/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/dbt/tests/fixtures/project.py b/dbt/tests/fixtures/project.py deleted file mode 100644 
index cf99587b..00000000 --- a/dbt/tests/fixtures/project.py +++ /dev/null @@ -1,572 +0,0 @@ -# These are the fixtures that are used in dbt core functional tests -# -# The main functional test fixture is the 'project' fixture, which combines -# other fixtures, writes out a dbt project in a temporary directory, creates a temp -# schema in the testing database, and returns a `TestProjInfo` object that -# contains information from the other fixtures for convenience. -# -# The models, macros, seeds, snapshots, tests, and analyses fixtures all -# represent directories in a dbt project, and are all dictionaries with -# file name keys and file contents values. -# -# The other commonly used fixture is 'project_config_update'. Other -# occasionally used fixtures are 'profiles_config_update', 'packages', -# and 'selectors'. -# -# Most test cases have fairly small files which are best included in -# the test case file itself as string variables, to make it easy to -# understand what is happening in the test. Files which are used -# in multiple test case files can be included in a common file, such as -# files.py or fixtures.py. Large files, such as seed files, which would -# just clutter the test file can be pulled in from 'data' subdirectories -# in the test directory. -# -# Test logs are written in the 'logs' directory in the root of the repo. -# Every test case writes to a log directory with the same 'prefix' as the -# test's unique schema. -# -# These fixture have "class" scope. Class scope fixtures can be used both -# in classes and in single test functions (which act as classes for this -# purpose). Pytest will collect all classes starting with 'Test', so if -# you have a class that you want to be subclassed, it's generally best to -# not start the class name with 'Test'. All standalone functions starting with -# 'test_' and methods in classes starting with 'test_' (in classes starting -# with 'Test') will be collected. 
-# -# Please see the pytest docs for further information: -# https://docs.pytest.org -from argparse import Namespace -from datetime import datetime -import os -from pathlib import Path -import random -import warnings - -# TODO: replace this runner to avoid a dependency on dbt-core -from dbt.parser.manifest import ManifestLoader -from dbt.context.providers import generate_runtime_macro_context -from dbt.config.runtime import RuntimeConfig -from dbt.events.logging import setup_event_logger -from dbt.mp_context import get_mp_context -from dbt_common.events.event_manager_client import cleanup_event_logger -from dbt_common.exceptions import CompilationError, DbtDatabaseError -import pytest -import yaml - -from dbt.adapters.factory import ( - get_adapter, - get_adapter_by_type, - register_adapter, - reset_adapters, -) - -from dbt.tests.util import ( - TestProcessingException, - get_connection, - run_sql_with_adapter, - write_file, -) - - -# Used in constructing the unique_schema and logs_dir -@pytest.fixture(scope="class") -def prefix(): - # create a directory name that will be unique per test session - _randint = random.randint(0, 9999) - _runtime_timedelta = datetime.utcnow() - datetime(1970, 1, 1, 0, 0, 0) - _runtime = (int(_runtime_timedelta.total_seconds() * 1e6)) + _runtime_timedelta.microseconds - prefix = f"test{_runtime}{_randint:04}" - return prefix - - -# Every test has a unique schema -@pytest.fixture(scope="class") -def unique_schema(request, prefix) -> str: - test_file = request.module.__name__ - # We only want the last part of the name - test_file = test_file.split(".")[-1] - unique_schema = f"{prefix}_{test_file}" - return unique_schema - - -# Create a directory for the profile using tmpdir fixture -@pytest.fixture(scope="class") -def profiles_root(tmpdir_factory): - return tmpdir_factory.mktemp("profile") - - -# Create a directory for the project using tmpdir fixture -@pytest.fixture(scope="class") -def project_root(tmpdir_factory): - # tmpdir docs - https://docs.pytest.org/en/6.2.x/tmpdir.html - project_root = tmpdir_factory.mktemp("project") - print(f"\n=== Test project_root: {project_root}") - return project_root - - -# This is for data used by multiple tests, in the 'tests/data' directory -@pytest.fixture(scope="session") -def shared_data_dir(request): - return os.path.join(request.config.rootdir, "tests", "data") - - -# This is for data for a specific test directory, i.e. tests/basic/data -@pytest.fixture(scope="module") -def test_data_dir(request): - return os.path.join(request.fspath.dirname, "data") - - -# This contains the profile target information, for simplicity in setting -# up different profiles, particularly in the adapter repos. -# Note: because we load the profile to create the adapter, this -# fixture can't be used to test vars and env_vars or errors. The -# profile must be written out after the test starts. -@pytest.fixture(scope="class") -def dbt_profile_target(): - return { - "type": "postgres", - "threads": 4, - "host": "localhost", - "port": int(os.getenv("POSTGRES_TEST_PORT", 5432)), - "user": os.getenv("POSTGRES_TEST_USER", "root"), - "pass": os.getenv("POSTGRES_TEST_PASS", "password"), - "dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"), - } - - -@pytest.fixture(scope="class") -def profile_user(dbt_profile_target): - return dbt_profile_target["user"] - - -# This fixture can be overridden in a project. The data provided in this -# fixture will be merged into the default project dictionary via a python 'update'. 
-@pytest.fixture(scope="class") -def profiles_config_update(): - return {} - - -# The profile dictionary, used to write out profiles.yml. It will pull in updates -# from two separate sources, the 'profile_target' and 'profiles_config_update'. -# The second one is useful when using alternative targets, etc. -@pytest.fixture(scope="class") -def dbt_profile_data(unique_schema, dbt_profile_target, profiles_config_update): - profile = { - "test": { - "outputs": { - "default": {}, - }, - "target": "default", - }, - } - target = dbt_profile_target - target["schema"] = unique_schema - profile["test"]["outputs"]["default"] = target - - if profiles_config_update: - profile.update(profiles_config_update) - return profile - - -# Write out the profile data as a yaml file -@pytest.fixture(scope="class") -def profiles_yml(profiles_root, dbt_profile_data): - os.environ["DBT_PROFILES_DIR"] = str(profiles_root) - write_file(yaml.safe_dump(dbt_profile_data), profiles_root, "profiles.yml") - yield dbt_profile_data - del os.environ["DBT_PROFILES_DIR"] - - -# Data used to update the dbt_project config data. -@pytest.fixture(scope="class") -def project_config_update(): - return {} - - -# Combines the project_config_update dictionary with project_config defaults to -# produce a project_yml config and write it out as dbt_project.yml -@pytest.fixture(scope="class") -def dbt_project_yml(project_root, project_config_update): - project_config = { - "name": "test", - "profile": "test", - "flags": {"send_anonymous_usage_stats": False}, - } - if project_config_update: - if isinstance(project_config_update, dict): - project_config.update(project_config_update) - elif isinstance(project_config_update, str): - updates = yaml.safe_load(project_config_update) - project_config.update(updates) - write_file(yaml.safe_dump(project_config), project_root, "dbt_project.yml") - return project_config - - -# Fixture to provide dependencies -@pytest.fixture(scope="class") -def dependencies(): - return {} - - -# Write out the dependencies.yml file -# Write out the packages.yml file -@pytest.fixture(scope="class") -def dependencies_yml(project_root, dependencies): - if dependencies: - if isinstance(dependencies, str): - data = dependencies - else: - data = yaml.safe_dump(dependencies) - write_file(data, project_root, "dependencies.yml") - - -# Fixture to provide packages as either yaml or dictionary -@pytest.fixture(scope="class") -def packages(): - return {} - - -# Write out the packages.yml file -@pytest.fixture(scope="class") -def packages_yml(project_root, packages): - if packages: - if isinstance(packages, str): - data = packages - else: - data = yaml.safe_dump(packages) - write_file(data, project_root, "packages.yml") - - -# Fixture to provide selectors as either yaml or dictionary -@pytest.fixture(scope="class") -def selectors(): - return {} - - -# Write out the selectors.yml file -@pytest.fixture(scope="class") -def selectors_yml(project_root, selectors): - if selectors: - if isinstance(selectors, str): - data = selectors - else: - data = yaml.safe_dump(selectors) - write_file(data, project_root, "selectors.yml") - - -# This fixture ensures that the logging infrastructure does not accidentally -# reuse streams configured on previous test runs, which might now be closed. -# It should be run before (and so included as a parameter by) any other fixture -# which runs dbt-core functions that might fire events. 
-@pytest.fixture(scope="class") -def clean_up_logging(): - cleanup_event_logger() - - -# This creates an adapter that is used for running test setup, such as creating -# the test schema, and sql commands that are run in tests prior to the first -# dbt command. After a dbt command is run, the project.adapter property will -# return the current adapter (for this adapter type) from the adapter factory. -# The adapter produced by this fixture will contain the "base" macros (not including -# macros from dependencies). -# -# Anything used here must be actually working (dbt_project, profile, project and internal macros), -# otherwise this will fail. So to test errors in those areas, you need to copy the files -# into the project in the tests instead of putting them in the fixtures. -@pytest.fixture(scope="class") -def adapter( - logs_dir, - unique_schema, - project_root, - profiles_root, - profiles_yml, - dbt_project_yml, - clean_up_logging, -): - # The profiles.yml and dbt_project.yml should already be written out - args = Namespace( - profiles_dir=str(profiles_root), - project_dir=str(project_root), - target=None, - profile=None, - threads=None, - ) - runtime_config = RuntimeConfig.from_args(args) - register_adapter(runtime_config, get_mp_context()) - adapter = get_adapter(runtime_config) - # We only need the base macros, not macros from dependencies, and don't want - # to run 'dbt deps' here. - manifest = ManifestLoader.load_macros( - runtime_config, - adapter.connections.set_query_header, - base_macros_only=True, - ) - - adapter.set_macro_resolver(manifest) - adapter.set_macro_context_generator(generate_runtime_macro_context) - yield adapter - adapter.cleanup_connections() - reset_adapters() - - -# Start at directory level. -def write_project_files(project_root, dir_name, file_dict): - path = project_root.mkdir(dir_name) - if file_dict: - write_project_files_recursively(path, file_dict) - - -# Write files out from file_dict. Can be nested directories... -def write_project_files_recursively(path, file_dict): - if type(file_dict) is not dict: - raise TestProcessingException(f"File dict is not a dict: '{file_dict}' for path '{path}'") - suffix_list = [".sql", ".csv", ".md", ".txt", ".py"] - for name, value in file_dict.items(): - if name.endswith(".yml") or name.endswith(".yaml"): - if isinstance(value, str): - data = value - else: - data = yaml.safe_dump(value) - write_file(data, path, name) - elif name.endswith(tuple(suffix_list)): - write_file(value, path, name) - else: - write_project_files_recursively(path.mkdir(name), value) - - -# models, macros, seeds, snapshots, tests, analyses -# Provide a dictionary of file names to contents. Nested directories -# are handle by nested dictionaries. 
- - -# models directory -@pytest.fixture(scope="class") -def models(): - return {} - - -# macros directory -@pytest.fixture(scope="class") -def macros(): - return {} - - -# properties directory -@pytest.fixture(scope="class") -def properties(): - return {} - - -# seeds directory -@pytest.fixture(scope="class") -def seeds(): - return {} - - -# snapshots directory -@pytest.fixture(scope="class") -def snapshots(): - return {} - - -# tests directory -@pytest.fixture(scope="class") -def tests(): - return {} - - -# analyses directory -@pytest.fixture(scope="class") -def analyses(): - return {} - - -# Write out the files provided by models, macros, properties, snapshots, seeds, tests, analyses -@pytest.fixture(scope="class") -def project_files(project_root, models, macros, snapshots, properties, seeds, tests, analyses): - write_project_files(project_root, "models", {**models, **properties}) - write_project_files(project_root, "macros", macros) - write_project_files(project_root, "snapshots", snapshots) - write_project_files(project_root, "seeds", seeds) - write_project_files(project_root, "tests", tests) - write_project_files(project_root, "analyses", analyses) - - -# We have a separate logs dir for every test -@pytest.fixture(scope="class") -def logs_dir(request, prefix): - dbt_log_dir = os.path.join(request.config.rootdir, "logs", prefix) - os.environ["DBT_LOG_PATH"] = str(dbt_log_dir) - yield str(Path(dbt_log_dir)) - del os.environ["DBT_LOG_PATH"] - - -# This fixture is for customizing tests that need overrides in adapter -# repos. Example in dbt.tests.adapter.basic.test_base. -@pytest.fixture(scope="class") -def test_config(): - return {} - - -# This class is returned from the 'project' fixture, and contains information -# from the pytest fixtures that may be needed in the test functions, including -# a 'run_sql' method. -class TestProjInfo: - def __init__( - self, - project_root, - profiles_dir, - adapter_type, - test_dir, - shared_data_dir, - test_data_dir, - test_schema, - database, - test_config, - ): - self.project_root = project_root - self.profiles_dir = profiles_dir - self.adapter_type = adapter_type - self.test_dir = test_dir - self.shared_data_dir = shared_data_dir - self.test_data_dir = test_data_dir - self.test_schema = test_schema - self.database = database - self.test_config = test_config - self.created_schemas = [] - - @property - def adapter(self): - # This returns the last created "adapter" from the adapter factory. Each - # dbt command will create a new one. This allows us to avoid patching the - # providers 'get_adapter' function. - return get_adapter_by_type(self.adapter_type) - - # Run sql from a path - def run_sql_file(self, sql_path, fetch=None): - with open(sql_path, "r") as f: - statements = f.read().split(";") - for statement in statements: - self.run_sql(statement, fetch) - - # Run sql from a string, using adapter saved at test startup - def run_sql(self, sql, fetch=None): - return run_sql_with_adapter(self.adapter, sql, fetch=fetch) - - # Create the unique test schema. Used in test setup, so that we're - # ready for initial sql prior to a run_dbt command. 
- def create_test_schema(self, schema_name=None): - if schema_name is None: - schema_name = self.test_schema - with get_connection(self.adapter): - relation = self.adapter.Relation.create(database=self.database, schema=schema_name) - self.adapter.create_schema(relation) - self.created_schemas.append(schema_name) - - # Drop the unique test schema, usually called in test cleanup - def drop_test_schema(self): - if self.adapter.get_macro_resolver() is None: - manifest = ManifestLoader.load_macros( - self.adapter.config, - self.adapter.connections.set_query_header, - base_macros_only=True, - ) - self.adapter.set_macro_resolver(manifest) - - with get_connection(self.adapter): - for schema_name in self.created_schemas: - relation = self.adapter.Relation.create(database=self.database, schema=schema_name) - self.adapter.drop_schema(relation) - self.created_schemas = [] - - # This return a dictionary of table names to 'view' or 'table' values. - def get_tables_in_schema(self): - sql = """ - select table_name, - case when table_type = 'BASE TABLE' then 'table' - when table_type = 'VIEW' then 'view' - else table_type - end as materialization - from information_schema.tables - where {} - order by table_name - """ - sql = sql.format("{} ilike '{}'".format("table_schema", self.test_schema)) - result = self.run_sql(sql, fetch="all") - return {model_name: materialization for (model_name, materialization) in result} - - -# This is the main fixture that is used in all functional tests. It pulls in the other -# fixtures that are necessary to set up a dbt project, and saves some of the information -# in a TestProjInfo class, which it returns, so that individual test cases do not have -# to pull in the other fixtures individually to access their information. -@pytest.fixture(scope="class") -def project( - clean_up_logging, - project_root, - profiles_root, - request, - unique_schema, - profiles_yml, - dbt_project_yml, - packages_yml, - dependencies_yml, - selectors_yml, - adapter, - project_files, - shared_data_dir, - test_data_dir, - logs_dir, - test_config, -): - # Logbook warnings are ignored so we don't have to fork logbook to support python 3.10. - # This _only_ works for tests in `tests/` that use the project fixture. - warnings.filterwarnings("ignore", category=DeprecationWarning, module="logbook") - log_flags = Namespace( - LOG_PATH=logs_dir, - LOG_FORMAT="json", - LOG_FORMAT_FILE="json", - USE_COLORS=False, - USE_COLORS_FILE=False, - LOG_LEVEL="info", - LOG_LEVEL_FILE="debug", - DEBUG=False, - LOG_CACHE_EVENTS=False, - QUIET=False, - LOG_FILE_MAX_BYTES=1000000, - ) - setup_event_logger(log_flags) - orig_cwd = os.getcwd() - os.chdir(project_root) - # Return whatever is needed later in tests but can only come from fixtures, so we can keep - # the signatures in the test signature to a minimum. - project = TestProjInfo( - project_root=project_root, - profiles_dir=profiles_root, - adapter_type=adapter.type(), - test_dir=request.fspath.dirname, - shared_data_dir=shared_data_dir, - test_data_dir=test_data_dir, - test_schema=unique_schema, - database=adapter.config.credentials.database, - test_config=test_config, - ) - project.drop_test_schema() - project.create_test_schema() - - yield project - - # deps, debug and clean commands will not have an installed adapter when running and will raise - # a KeyError here. Just pass for now. - # See https://github.com/dbt-labs/dbt-core/issues/5041 - # The debug command also results in an AttributeError since `Profile` doesn't have - # a `load_dependencies` method. 
- # Macros gets executed as part of drop_scheme in core/dbt/adapters/sql/impl.py. When - # the macros have errors (which is what we're actually testing for...) they end up - # throwing CompilationErrorss or DatabaseErrors - try: - project.drop_test_schema() - except (KeyError, AttributeError, CompilationError, DbtDatabaseError): - pass - os.chdir(orig_cwd) - cleanup_event_logger() diff --git a/dbt/tests/util.py b/dbt/tests/util.py deleted file mode 100644 index f04305de..00000000 --- a/dbt/tests/util.py +++ /dev/null @@ -1,640 +0,0 @@ -# ============================================================================= -# Test utilities -# run_dbt -# run_dbt_and_capture -# get_manifest -# copy_file -# rm_file -# write_file -# read_file -# mkdir -# rm_dir -# get_artifact -# update_config_file -# write_config_file -# get_unique_ids_in_results -# check_result_nodes_by_name -# check_result_nodes_by_unique_id - -# SQL related utilities that use the adapter -# run_sql_with_adapter -# relation_from_name -# check_relation_types (table/view) -# check_relations_equal -# check_relation_has_expected_schema -# check_relations_equal_with_relations -# check_table_does_exist -# check_table_does_not_exist -# get_relation_columns -# update_rows -# generate_update_clause -# -# Classes for comparing fields in dictionaries -# AnyFloat -# AnyInteger -# AnyString -# AnyStringWith -# ============================================================================= -from contextlib import contextmanager -from datetime import datetime -from io import StringIO -import json -import os -import shutil -from typing import Any, Dict, List, Optional -import warnings - -# TODO: replace this runner to avoid a dependency on dbt-core -from dbt.contracts.graph.manifest import Manifest -from dbt.cli.main import dbtRunner -from dbt.logger import log_manager -from dbt_common.events.functions import ( - capture_stdout_logs, - fire_event, - reset_metadata_vars, - stop_capture_stdout_logs, -) -from dbt_common.events.base_types import EventLevel -from dbt_common.events.types import Note -import yaml - -from dbt.adapters.base.relation import BaseRelation -from dbt.adapters.factory import Adapter - - -def run_dbt( - args: Optional[List[str]] = None, - expect_pass: bool = True, -): - """ - 'run_dbt' is used in pytest tests to run dbt commands. - It will return different objects depending on the command that is executed. - For a run command (and most other commands) it will return a list of results. - For the 'docs generate' command it returns a CatalogArtifact. - - The first parameter is a list of dbt command line arguments, such as - run_dbt(["run", "--vars", "seed_name: base"]) - - If the command is expected to fail, pass in "expect_pass=False"): - run_dbt(["test"], expect_pass=False) - """ - - # Ignore logbook warnings - warnings.filterwarnings("ignore", category=DeprecationWarning, module="logbook") - - # reset global vars - reset_metadata_vars() - - # The logger will complain about already being initialized if - # we don't do this. 
- log_manager.reset_handlers() - if args is None: - args = ["run"] - - print("\n\nInvoking dbt with {}".format(args)) - from dbt.flags import get_flags - - flags = get_flags() - project_dir = getattr(flags, "PROJECT_DIR", None) - profiles_dir = getattr(flags, "PROFILES_DIR", None) - if project_dir and "--project-dir" not in args: - args.extend(["--project-dir", project_dir]) - if profiles_dir and "--profiles-dir" not in args: - args.extend(["--profiles-dir", profiles_dir]) - - dbt = dbtRunner() - res = dbt.invoke(args) - - # the exception is immediately raised to be caught in tests - # using a pattern like `with pytest.raises(SomeException):` - if res.exception is not None: - raise res.exception - - if expect_pass is not None: - assert res.success == expect_pass, "dbt exit state did not match expected" - - return res.result - - -# Use this if you need to capture the command logs in a test. -# If you want the logs that are normally written to a file, you must -# start with the "--debug" flag. The structured schema log CI test -# will turn the logs into json, so you have to be prepared for that. -def run_dbt_and_capture( - args: Optional[List[str]] = None, - expect_pass: bool = True, -): - try: - stringbuf = StringIO() - capture_stdout_logs(stringbuf) - res = run_dbt(args, expect_pass=expect_pass) - stdout = stringbuf.getvalue() - - finally: - stop_capture_stdout_logs() - - return res, stdout - - -def get_logging_events(log_output, event_name): - logging_events = [] - for log_line in log_output.split("\n"): - # skip empty lines - if len(log_line) == 0: - continue - # The adapter logging also shows up, so skip non-json lines - if not log_line.startswith("{"): - continue - if event_name in log_line: - log_dct = json.loads(log_line) - if log_dct["info"]["name"] == event_name: - logging_events.append(log_dct) - return logging_events - - -# Used in test cases to get the manifest from the partial parsing file -# Note: this uses an internal version of the manifest, and in the future -# parts of it will not be supported for external use. -def get_manifest(project_root) -> Optional[Manifest]: - path = os.path.join(project_root, "target", "partial_parse.msgpack") - if os.path.exists(path): - with open(path, "rb") as fp: - manifest_mp = fp.read() - manifest: Manifest = Manifest.from_msgpack(manifest_mp) - return manifest - else: - return None - - -# Used in test cases to get the run_results.json file. -def get_run_results(project_root) -> Any: - path = os.path.join(project_root, "target", "run_results.json") - if os.path.exists(path): - with open(path) as run_result_text: - return json.load(run_result_text) - else: - return None - - -# Used in tests to copy a file, usually from a data directory to the project directory -def copy_file(src_path, src, dest_path, dest) -> None: - # dest is a list, so that we can provide nested directories, like 'models' etc. - # copy files from the data_dir to appropriate project directory - shutil.copyfile( - os.path.join(src_path, src), - os.path.join(dest_path, *dest), - ) - - -# Used in tests when you want to remove a file from the project directory -def rm_file(*paths) -> None: - # remove files from proj_path - os.remove(os.path.join(*paths)) - - -# Used in tests to write out the string contents of a file to a -# file in the project directory. 
-# We need to explicitly use encoding="utf-8" because otherwise on -# Windows we'll get codepage 1252 and things might break -def write_file(contents, *paths): - with open(os.path.join(*paths), "w", encoding="utf-8") as fp: - fp.write(contents) - - -def file_exists(*paths): - """Check if file exists at path""" - return os.path.exists(os.path.join(*paths)) - - -# Used in test utilities -def read_file(*paths): - contents = "" - with open(os.path.join(*paths), "r") as fp: - contents = fp.read() - return contents - - -# To create a directory -def mkdir(directory_path): - try: - os.makedirs(directory_path) - except FileExistsError: - raise FileExistsError(f"{directory_path} already exists.") - - -# To remove a directory -def rm_dir(directory_path): - try: - shutil.rmtree(directory_path) - except FileNotFoundError: - raise FileNotFoundError(f"{directory_path} does not exist.") - - -def rename_dir(src_directory_path, dest_directory_path): - os.rename(src_directory_path, dest_directory_path) - - -# Get an artifact (usually from the target directory) such as -# manifest.json or catalog.json to use in a test -def get_artifact(*paths): - contents = read_file(*paths) - dct = json.loads(contents) - return dct - - -def write_artifact(dct, *paths): - json_output = json.dumps(dct) - write_file(json_output, *paths) - - -# For updating yaml config files -def update_config_file(updates, *paths): - current_yaml = read_file(*paths) - config = yaml.safe_load(current_yaml) - config.update(updates) - new_yaml = yaml.safe_dump(config) - write_file(new_yaml, *paths) - - -# Write new config file -def write_config_file(data, *paths): - if type(data) is dict: - data = yaml.safe_dump(data) - write_file(data, *paths) - - -# Get the unique_ids in dbt command results -def get_unique_ids_in_results(results): - unique_ids = [] - for result in results: - unique_ids.append(result.node.unique_id) - return unique_ids - - -# Check the nodes in the results returned by a dbt run command -def check_result_nodes_by_name(results, names): - result_names = [] - for result in results: - result_names.append(result.node.name) - assert set(names) == set(result_names) - - -# Check the nodes in the results returned by a dbt run command -def check_result_nodes_by_unique_id(results, unique_ids): - result_unique_ids = [] - for result in results: - result_unique_ids.append(result.node.unique_id) - assert set(unique_ids) == set(result_unique_ids) - - -# Check datetime is between start and end/now -def check_datetime_between(timestr, start, end=None): - datefmt = "%Y-%m-%dT%H:%M:%S.%fZ" - if end is None: - end = datetime.utcnow() - parsed = datetime.strptime(timestr, datefmt) - assert start <= parsed - assert end >= parsed - - -class TestProcessingException(Exception): - pass - - -# Testing utilities that use adapter code - - -# Uses: -# adapter.config.credentials -# adapter.quote -# adapter.run_sql_for_tests -def run_sql_with_adapter(adapter, sql, fetch=None): - if sql.strip() == "": - return - - # substitute schema and database in sql - kwargs = { - "schema": adapter.config.credentials.schema, - "database": adapter.quote(adapter.config.credentials.database), - } - sql = sql.format(**kwargs) - - msg = f'test connection "__test" executing: {sql}' - fire_event(Note(msg=msg), level=EventLevel.DEBUG) - with get_connection(adapter) as conn: - return adapter.run_sql_for_tests(sql, fetch, conn) - - -# Get a Relation object from the identifier (name of table/view). -# Uses the default database and schema. 
If you need a relation -# with a different schema, it should be constructed in the test. -# Uses: -# adapter.Relation -# adapter.config.credentials -# Relation.get_default_quote_policy -# Relation.get_default_include_policy -def relation_from_name(adapter, name: str): - """reverse-engineer a relation from a given name and - the adapter. The relation name is split by the '.' character. - """ - - # Different adapters have different Relation classes - cls = adapter.Relation - credentials = adapter.config.credentials - quote_policy = cls.get_default_quote_policy().to_dict() - include_policy = cls.get_default_include_policy().to_dict() - - # Make sure we have database/schema/identifier parts, even if - # only identifier was supplied. - relation_parts = name.split(".") - if len(relation_parts) == 1: - relation_parts.insert(0, credentials.schema) - if len(relation_parts) == 2: - relation_parts.insert(0, credentials.database) - kwargs = { - "database": relation_parts[0], - "schema": relation_parts[1], - "identifier": relation_parts[2], - } - - relation = cls.create( - include_policy=include_policy, - quote_policy=quote_policy, - **kwargs, - ) - return relation - - -# Ensure that models with different materialiations have the -# current table/view. -# Uses: -# adapter.list_relations_without_caching -def check_relation_types(adapter, relation_to_type): - """ - Relation name to table/view - { - "base": "table", - "other": "view", - } - """ - - expected_relation_values = {} - found_relations = [] - schemas = set() - - for key, value in relation_to_type.items(): - relation = relation_from_name(adapter, key) - expected_relation_values[relation] = value - schemas.add(relation.without_identifier()) - - with get_connection(adapter): - for schema in schemas: - found_relations.extend(adapter.list_relations_without_caching(schema)) - - for key, value in relation_to_type.items(): - for relation in found_relations: - # this might be too broad - if relation.identifier == key: - assert relation.type == value, ( - f"Got an unexpected relation type of {relation.type} " - f"for relation {key}, expected {value}" - ) - - -# Replaces assertTablesEqual. assertManyTablesEqual can be replaced -# by doing a separate call for each set of tables/relations. -# Wraps check_relations_equal_with_relations by creating relations -# from the list of names passed in. -def check_relations_equal(adapter, relation_names: List, compare_snapshot_cols=False): - if len(relation_names) < 2: - raise TestProcessingException( - "Not enough relations to compare", - ) - relations = [relation_from_name(adapter, name) for name in relation_names] - return check_relations_equal_with_relations( - adapter, relations, compare_snapshot_cols=compare_snapshot_cols - ) - - -# Used to check that a particular relation has an expected schema -# expected_schema should look like {"column_name": "expected datatype"} -def check_relation_has_expected_schema(adapter, relation_name, expected_schema: Dict): - relation = relation_from_name(adapter, relation_name) - with get_connection(adapter): - actual_columns = {c.name: c.data_type for c in adapter.get_columns_in_relation(relation)} - assert ( - actual_columns == expected_schema - ), f"Actual schema did not match expected, actual: {json.dumps(actual_columns)}" - - -# This can be used when checking relations in different schemas, by supplying -# a list of relations. Called by 'check_relations_equal'. 
-# Uses: -# adapter.get_columns_in_relation -# adapter.get_rows_different_sql -# adapter.execute -def check_relations_equal_with_relations( - adapter: Adapter, relations: List, compare_snapshot_cols=False -): - with get_connection(adapter): - basis, compares = relations[0], relations[1:] - # Skip columns starting with "dbt_" because we don't want to - # compare those, since they are time sensitive - # (unless comparing "dbt_" snapshot columns is explicitly enabled) - column_names = [ - c.name - for c in adapter.get_columns_in_relation(basis) # type: ignore - if not c.name.lower().startswith("dbt_") or compare_snapshot_cols - ] - - for relation in compares: - sql = adapter.get_rows_different_sql(basis, relation, column_names=column_names) # type: ignore - _, tbl = adapter.execute(sql, fetch=True) - num_rows = len(tbl) - assert ( - num_rows == 1 - ), f"Invalid sql query from get_rows_different_sql: incorrect number of rows ({num_rows})" - num_cols = len(tbl[0]) - assert ( - num_cols == 2 - ), f"Invalid sql query from get_rows_different_sql: incorrect number of cols ({num_cols})" - row_count_difference = tbl[0][0] - assert ( - row_count_difference == 0 - ), f"Got {row_count_difference} difference in row count betwen {basis} and {relation}" - rows_mismatched = tbl[0][1] - assert ( - rows_mismatched == 0 - ), f"Got {rows_mismatched} different rows between {basis} and {relation}" - - -# Uses: -# adapter.update_column_sql -# adapter.execute -# adapter.commit_if_has_connection -def update_rows(adapter, update_rows_config): - """ - { - "name": "base", - "dst_col": "some_date" - "clause": { - "type": "add_timestamp", - "src_col": "some_date", - "where" "id > 10" - } - """ - for key in ["name", "dst_col", "clause"]: - if key not in update_rows_config: - raise TestProcessingException(f"Invalid update_rows: no {key}") - - clause = update_rows_config["clause"] - clause = generate_update_clause(adapter, clause) - - where = None - if "where" in update_rows_config: - where = update_rows_config["where"] - - name = update_rows_config["name"] - dst_col = update_rows_config["dst_col"] - relation = relation_from_name(adapter, name) - - with get_connection(adapter): - sql = adapter.update_column_sql( - dst_name=str(relation), - dst_column=dst_col, - clause=clause, - where_clause=where, - ) - adapter.execute(sql, auto_begin=True) - adapter.commit_if_has_connection() - - -# This is called by the 'update_rows' function. -# Uses: -# adapter.timestamp_add_sql -# adapter.string_add_sql -def generate_update_clause(adapter, clause) -> str: - """ - Called by update_rows function. Expects the "clause" dictionary - documented in 'update_rows. 
- """ - - if "type" not in clause or clause["type"] not in ["add_timestamp", "add_string"]: - raise TestProcessingException("invalid update_rows clause: type missing or incorrect") - clause_type = clause["type"] - - if clause_type == "add_timestamp": - if "src_col" not in clause: - raise TestProcessingException("Invalid update_rows clause: no src_col") - add_to = clause["src_col"] - kwargs = {k: v for k, v in clause.items() if k in ("interval", "number")} - with get_connection(adapter): - return adapter.timestamp_add_sql(add_to=add_to, **kwargs) - elif clause_type == "add_string": - for key in ["src_col", "value"]: - if key not in clause: - raise TestProcessingException(f"Invalid update_rows clause: no {key}") - src_col = clause["src_col"] - value = clause["value"] - location = clause.get("location", "append") - with get_connection(adapter): - return adapter.string_add_sql(src_col, value, location) - return "" - - -@contextmanager -def get_connection(adapter, name="_test"): - with adapter.connection_named(name): - conn = adapter.connections.get_thread_connection() - yield conn - - -# Uses: -# adapter.get_columns_in_relation -def get_relation_columns(adapter, name): - relation = relation_from_name(adapter, name) - with get_connection(adapter): - columns = adapter.get_columns_in_relation(relation) - return sorted(((c.name, c.dtype, c.char_size) for c in columns), key=lambda x: x[0]) - - -def check_table_does_not_exist(adapter, name): - columns = get_relation_columns(adapter, name) - assert len(columns) == 0 - - -def check_table_does_exist(adapter, name): - columns = get_relation_columns(adapter, name) - assert len(columns) > 0 - - -# Utility classes for enabling comparison of dictionaries - - -class AnyFloat: - """Any float. Use this in assert calls""" - - def __eq__(self, other): - return isinstance(other, float) - - -class AnyInteger: - """Any Integer. Use this in assert calls""" - - def __eq__(self, other): - return isinstance(other, int) - - -class AnyString: - """Any string. Use this in assert calls""" - - def __eq__(self, other): - return isinstance(other, str) - - -class AnyStringWith: - """AnyStringWith("AUTO")""" - - def __init__(self, contains=None): - self.contains = contains - - def __eq__(self, other): - if not isinstance(other, str): - return False - - if self.contains is None: - return True - - return self.contains in other - - def __repr__(self): - return "AnyStringWith<{!r}>".format(self.contains) - - -def assert_message_in_logs(message: str, logs: str, expected_pass: bool = True): - # if the logs are json strings, then 'jsonify' the message because of things like escape quotes - if os.environ.get("DBT_LOG_FORMAT", "") == "json": - message = message.replace(r'"', r"\"") - - if expected_pass: - assert message in logs - else: - assert message not in logs - - -def get_project_config(project): - file_yaml = read_file(project.project_root, "dbt_project.yml") - return yaml.safe_load(file_yaml) - - -def set_project_config(project, config): - config_yaml = yaml.safe_dump(config) - write_file(config_yaml, project.project_root, "dbt_project.yml") - - -def get_model_file(project, relation: BaseRelation) -> str: - return read_file(project.project_root, "models", f"{relation.name}.sql") - - -def set_model_file(project, relation: BaseRelation, model_sql: str): - write_file(model_sql, project.project_root, "models", f"{relation.name}.sql")
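
For reference, this is roughly the SQL that `default__get_unit_test_sql` (added to helpers.sql in the first patch) should render once the `unit` materialization has resolved the tested column names and the expected fixture SQL. The model name, column names, data types, and literal values below are purely illustrative, and the casts assume the default `safe_cast` and `string_literal` implementations:

-- Build actual result given inputs
with dbt_internal_unit_test_actual as (
    select
        id, status, 'actual' as actual_or_expected
    from (
        -- main_sql: compiled SQL of the model under test (illustrative)
        select id, status from some_upstream_input
    ) _dbt_internal_unit_test_actual
),
-- Build expected result
dbt_internal_unit_test_expected as (
    select
        id, status, 'expected' as actual_or_expected
    from (
        -- expected_fixture_sql: output of get_expected_sql for one expected row (illustrative)
        select cast(1 as integer) as id, cast('open' as text) as status
    ) _dbt_internal_unit_test_expected
)
-- Union actual and expected results
select * from dbt_internal_unit_test_actual
union all
select * from dbt_internal_unit_test_expected

The materialization executes this as the `main` statement with `fetch_result=True`, so the rows tagged 'actual' and 'expected' can then be compared by the calling unit-test logic.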
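
Similarly, a minimal sketch of what `get_fixture_sql` would render for a hypothetical fixture against a relation with columns `id` (integer) and `name` (text), where the fixture rows supply only `id`; columns missing from a row fall back to the null-cast default row, and an empty fixture falls through to a zero-row select:

-- Fixture for some_model (two fixture rows, illustrative)
select cast(1 as integer) as id, cast(null as text) as name
union all
select cast(2 as integer) as id, cast(null as text) as name

-- Fixture for some_model (no fixture rows, illustrative)
select cast(null as integer) as id, cast(null as text) as name
limit 0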