diff --git a/tests/unit/context/test_context.py b/tests/unit/context/test_context.py
index 10e591093ee..0c4de211a09 100644
--- a/tests/unit/context/test_context.py
+++ b/tests/unit/context/test_context.py
@@ -1,14 +1,20 @@
+import importlib
 import os
-from typing import Any, Dict, Set
+import re
+from argparse import Namespace
+from copy import deepcopy
+from typing import Any, Dict, Mapping, Optional, Set
 from unittest import mock
 
 import pytest
+import pytz
 
 import dbt_common.exceptions
 from dbt.adapters import factory, postgres
 from dbt.clients.jinja import MacroStack
 from dbt.config.project import VarProvider
 from dbt.context import base, docs, macros, providers, query_header
+from dbt.context.base import Var
 from dbt.contracts.files import FileHash
 from dbt.contracts.graph.nodes import (
     DependsOn,
@@ -20,6 +26,7 @@
 )
 from dbt.node_types import NodeType
 from dbt_common.events.functions import reset_metadata_vars
+from dbt_common.helper_types import WarnErrorOptions
 from tests.unit.mock_adapter import adapter_factory
 from tests.unit.utils import clear_plugin, config_from_parts_or_dicts, inject_adapter
 
@@ -277,6 +284,352 @@ def assert_has_keys(required_keys: Set[str], maybe_keys: Set[str], ctx: Dict[str
 }
 
 
+def clean_value(value):
+    """Normalize a context value so it can be compared against a literal.
+
+    Sets are copied, ``Namespace`` and ``Var`` objects become plain dicts, and
+    ``None``/``bool``/``int`` values pass through unchanged. Anything else is
+    rendered with ``str()`` and has `` at 0x...>`` address suffixes and
+    `` id='...'>`` suffixes removed, as those are not stable across runs.
+    """
+    if isinstance(value, set):
+        return set(value)
+    if isinstance(value, Namespace):
+        # Return a copy: handing out the live ``__dict__`` would let a caller
+        # that edits the result mutate the flags object under test.
+        return dict(vars(value))
+    if isinstance(value, Var):
+        return dict(value._merged)
+    if value is None or isinstance(value, int):
+        # ``bool`` is a subclass of ``int``, so True/False are covered here too.
+        return value
+
+    value_str = str(value)
+    value_str = re.sub(r" at 0x[0-9a-fA-F]+>", ">", value_str)
+    value_str = re.sub(r" id='[0-9]+'>", ">", value_str)
+    return value_str
+
+
+def walk_dict(dictionary):
+    """Yield ``(path, cleaned_value)`` for each non-mapping leaf of a nested mapping.
+
+    ``path`` is the tuple of keys leading to the leaf. Traversal is breadth-first
+    and each mapping object is expanded once (tracked by ``id``), so a mapping
+    reachable under several keys is reported only under the first path reached.
+    Paths listed in ``skip_paths`` are omitted, and ``INVOCATION_COMMAND`` is
+    removed from the cleaned ``flags`` dicts.
+    """
+    skip_paths = {
+        ("invocation_id",),
+        ("builtins", "invocation_id"),
+        ("dbt_version",),
+        ("builtins", "dbt_version"),
+        ("invocation_args_dict", "invocation_command"),
+        ("run_started_at",),
+        ("builtins", "run_started_at"),
+        ("selected_resources",),
+        ("builtins", "selected_resources"),
+    }
+    # clean_value() returns a fresh dict per call, so each flags entry has to
+    # be stripped on its own rather than relying on a shared, mutated dict.
+    flags_paths = {("flags",), ("builtins", "flags")}
+
+    # FIFO order matters: it decides which path a shared mapping is reported under.
+    queue = [(dictionary, ())]
+    visited = set()  # ids of mappings that have already been expanded
+
+    while queue:
+        current_dict, path = queue.pop(0)
+
+        if id(current_dict) in visited:
+            continue
+
+        visited.add(id(current_dict))
+
+        for key, value in current_dict.items():
+            current_path = path + (key,)
+
+            if isinstance(value, Mapping):
+                queue.append((value, current_path))
+            elif current_path not in skip_paths:
+                cv = clean_value(value)
+                if current_path in flags_paths:
+                    cv.pop("INVOCATION_COMMAND", None)
+
+                yield (current_path, cv)
+
+
+def add_prefix(path_dict, prefix):
+    """Return a copy of ``path_dict`` with the ``prefix`` tuple prepended to every key."""
+    return {prefix + k: v for k, v in path_dict.items()}
+
+
+def get_module_exports(module_name: str, filter_set: Optional[Set[str]] = None):
+    """Map ``("modules", module_name, name)`` to the cleaned value of each export.
+
+    Uses the names in ``filter_set`` when it is given and non-empty, otherwise
+    the module's ``__all__``.
+    """
+    module = importlib.import_module(module_name)
+    export_names = filter_set or module.__all__
+
+    return {
+        ("modules", module_name, export): clean_value(getattr(module, export))
+        for export in export_names
+    }
+
+
+PYTZ_COUNTRY_TIMEZONES = {
+    ("modules", "pytz", "country_timezones", country_code): str(timezones)
+    for country_code, timezones in pytz.country_timezones.items()
+}
+
+PYTZ_COUNTRY_NAMES = {
+    ("modules", "pytz", "country_names", country_code): country_name
+    for country_code, country_name in pytz.country_names.items()
+}
+
+COMMON_FLAGS_INVOCATION_ARGS = {
+    "CACHE_SELECTED_ONLY": False,
+    "LOG_FORMAT": "default",
+    "LOG_PATH": "logs",
+    "SEND_ANONYMOUS_USAGE_STATS": True,
+    "INDIRECT_SELECTION": "eager",
+    "INTROSPECT": True,
+    "PARTIAL_PARSE": True,
+    "PRINTER_WIDTH": 80,
+    "QUIET": False,
+    "STATIC_PARSER": True,
+    "USE_COLORS": True,
+    "VERSION_CHECK": True,
+    "WRITE_JSON": True,
+}
+
+COMMON_FLAGS = {
+    **COMMON_FLAGS_INVOCATION_ARGS,
+    "LOG_CACHE_EVENTS": False,
+    "FAIL_FAST": False,
+    "DEBUG": False,
+    "WARN_ERROR": None,
+    "WARN_ERROR_OPTIONS": WarnErrorOptions(include=[], exclude=[]),
+    "USE_EXPERIMENTAL_PARSER": False,
+    "NO_PRINT": None,
+    "PROFILES_DIR": None,
+    "TARGET_PATH": None,
+    "EMPTY": None,
+    "FULL_REFRESH": False,
+    "STORE_FAILURES": False,
+    "WHICH": "run",
+}
+
+
+COMMON_BUILTINS = {
+    ("diff_of_two_dicts",): "",
+    ("flags",): COMMON_FLAGS,
+    ("fromjson",): "",
+    ("fromyaml",): "",
+    ("local_md5",): "",
+    ("log",): "",
+    ("print",): "",
+    ("project_name",): "root",
+    ("return",): "",
+    ("set",): "",
+    ("set_strict",): "",
+    ("thread_id",): "MainThread",
+    ("tojson",): "",
+    ("toyaml",): "",
+    ("var",): {},
+    ("zip",): "",
+    ("zip_strict",): "",
+}
+
+COMMON_RUNTIME_CONTEXT = {
+    **COMMON_BUILTINS,
+    **add_prefix(COMMON_BUILTINS, ("builtins",)),
+    ("target", "host"): "localhost",
+    ("target", "port"): 1,
+    ("target", "user"): "test",
+    ("target", "database"): "test",
+    ("target", "schema"): "analytics",
+    ("target", "connect_timeout"): 10,
+    ("target", "role"): None,
+    ("target", "search_path"): None,
+    ("target", "keepalives_idle"): 0,
+    ("target", "sslmode"): None,
+    ("target", "sslcert"): None,
+    ("target", "sslkey"): None,
+    ("target", "sslrootcert"): None,
+    ("target", "application_name"): "dbt",
+    ("target", "retries"): 1,
+    ("target", "dbname"): "test",
+    ("target", "type"): "postgres",
+    ("target", "threads"): 1,
+    ("target", "name"): "test",
+    ("target", "target_name"): "test",
+    ("target", "profile_name"): "test",
+    **get_module_exports("datetime", {"date", "datetime", "time", "timedelta", "tzinfo"}),
+    **get_module_exports("re"),
+    **get_module_exports(
+        "itertools",
+        {
+            "count",
+            "cycle",
+            "repeat",
+            "accumulate",
+            "chain",
+            "compress",
+            "islice",
+            "starmap",
+            "tee",
+            "zip_longest",
+            "product",
+            "permutations",
+            "combinations",
+            "combinations_with_replacement",
+        },
+    ),
+    ("modules", "pytz", "timezone"): "",
+    ("modules", "pytz", "utc"): "UTC",
+    ("modules", "pytz", "AmbiguousTimeError"): "",
+    ("modules", "pytz", "InvalidTimeError"): "",
+    ("modules", "pytz", "NonExistentTimeError"): "",
+    ("modules", "pytz", "UnknownTimeZoneError"): "",
+    ("modules", "pytz", "all_timezones"): str(pytz.all_timezones),
+    ("modules", "pytz", "all_timezones_set"): set(pytz.all_timezones_set),
+    ("modules", "pytz", "common_timezones"): str(pytz.common_timezones),
+    ("modules", "pytz", "common_timezones_set"): set(),
+    ("modules", "pytz", "BaseTzInfo"): "",
+    ("modules", "pytz", "FixedOffset"): "",
+    **PYTZ_COUNTRY_TIMEZONES,
+    **PYTZ_COUNTRY_NAMES,
+}
+
+MODEL_BUILTINS = {
+    ("adapter",): "",
+    (
+        "adapter_macro",
+    ): ">",
+    ("column",): "",
+    ("compiled_code",): "",
+    ("config",): "",
+    ("context_macro_stack",): "",
+    ("database",): "dbt",
+    ("defer_relation",): "",
+    (
+        "env_var",
+    ): ">",
+    ("execute",): True,
+    ("graph",): "",
+    (
+        "load_agate_table",
+    ): ">",
+    (
+        "load_result",
+    ): ">",
+    ("metric",): "",
+    ("model",): "",
+    ("post_hooks",): "[]",
+    ("pre_hooks",): "[]",
+    ("ref",): "",
+    (
+        "render",
+    ): ">",
+    ("schema",): "analytics",
+    ("source",): "",
+    ("sql",): "",
+    ("sql_now",): "",
+    (
+        "store_raw_result",
+    ): ">",
+    (
+        "store_result",
+    ): ">",
+    (
+        "submit_python_job",
+    ): ">",
+    ("this",): "",
+    (
+        "try_or_compiler_error",
+    ): ">",
+    (
+        "write",
+    ): ">",
+}
+
+MODEL_RUNTIME_BUILTINS = {
+    **MODEL_BUILTINS,
+}
+
+MODEL_EXCEPTIONS = {
+    ("exceptions", "warn"): "",
+    ("exceptions", "missing_config"): "",
+    ("exceptions", "missing_materialization"): "",
+    ("exceptions", "missing_relation"): "",
+    ("exceptions", "raise_ambiguous_alias"): "",
+    ("exceptions", "raise_ambiguous_catalog_match"): "",
+    ("exceptions", "raise_cache_inconsistent"): "",
+    ("exceptions", "raise_dataclass_not_dict"): "",
+    ("exceptions", "raise_compiler_error"): "",
+    ("exceptions", "raise_database_error"): "",
+    ("exceptions", "raise_dep_not_found"): "",
+    ("exceptions", "raise_dependency_error"): "",
+    ("exceptions", "raise_duplicate_patch_name"): "",
+    ("exceptions", "raise_duplicate_resource_name"): "",
+    (
+        "exceptions",
+        "raise_invalid_property_yml_version",
+    ): "",
+    ("exceptions", "raise_not_implemented"): "",
+    ("exceptions", "relation_wrong_type"): "",
+    ("exceptions", "raise_contract_error"): "",
+    ("exceptions", "column_type_missing"): "",
+    ("exceptions", "raise_fail_fast_error"): "",
+    (
+        "exceptions",
+        "warn_snapshot_timestamp_data_types",
+    ): "",
+}
+
+MODEL_MACROS = {
+    ("macro_a",): "",
+    ("macro_b",): "",
+}
+
+EXPECTED_MODEL_RUNTIME_CONTEXT = deepcopy(
+    {
+        **COMMON_RUNTIME_CONTEXT,
+        **MODEL_RUNTIME_BUILTINS,
+        **add_prefix(MODEL_RUNTIME_BUILTINS, ("builtins",)),
+        **MODEL_MACROS,
+        **add_prefix(MODEL_MACROS, ("root",)),
+        **add_prefix(
+            {(k.lower(),): v for k, v in COMMON_FLAGS_INVOCATION_ARGS.items()},
+            ("invocation_args_dict",),
+        ),
+        ("invocation_args_dict", "profile_dir"): "/dev/null",
+        ("invocation_args_dict", "warn_error_options", "include"): "[]",
+        ("invocation_args_dict", "warn_error_options", "exclude"): "[]",
+        **MODEL_EXCEPTIONS,
+        ("api", "Column"): "",
+        ("api", "Relation"): "",
+        ("validation", "any"): ".validate_any>",
+    }
+)
+
+DOCS_BUILTINS = {
+    ("doc",): ">",
+    ("env_var",): ">",
+}
+
+EXPECTED_DOCS_RUNTIME_CONTEXT = deepcopy(
+    {
+        **COMMON_RUNTIME_CONTEXT,
+        **DOCS_BUILTINS,
+        **add_prefix(DOCS_BUILTINS, ("builtins",)),
+    }
+)
+
+
 def model():
     return ModelNode(
         alias="model_one",
@@ -475,7 +828,8 @@ def test_model_parse_context(config_postgres, manifest_fx, get_adapter, get_incl
         manifest=manifest_fx,
         context_config=mock.MagicMock(),
     )
-    assert_has_keys(REQUIRED_MODEL_KEYS, MAYBE_KEYS, ctx)
+    actual_model_context = dict(walk_dict(ctx))
+    assert actual_model_context == EXPECTED_MODEL_RUNTIME_CONTEXT
 
 
 def test_model_runtime_context(config_postgres, manifest_fx, get_adapter, get_include_paths):
@@ -484,12 +838,14 @@ def test_model_runtime_context(config_postgres, manifest_fx, get_adapter, get_in
         config=config_postgres,
         manifest=manifest_fx,
     )
-    assert_has_keys(REQUIRED_MODEL_KEYS, MAYBE_KEYS, ctx)
+    actual_model_context = dict(walk_dict(ctx))
+    assert actual_model_context == EXPECTED_MODEL_RUNTIME_CONTEXT
 
 
 def test_docs_runtime_context(config_postgres):
     ctx = docs.generate_runtime_docs_context(config_postgres, mock_model(), [], "root")
-    assert_has_keys(REQUIRED_DOCS_KEYS, MAYBE_KEYS, ctx)
+    actual_docs_runtime_context = dict(walk_dict(ctx))
+    assert actual_docs_runtime_context == EXPECTED_DOCS_RUNTIME_CONTEXT
 
 
 def test_macro_namespace_duplicates(config_postgres, manifest_fx):