Skip to content

Commit

Permalink
case-insensitive comparisons in unit testing, base unit testing test
Browse files: browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Jan 29, 2024
1 parent 60005a0 commit f16cd66
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 11 deletions.
2 changes: 1 addition & 1 deletion dbt/include/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from pkgutil import extend_path

__path__ = extend_path(__path__, __name__)
__path__ = extend_path(__path__, __name__)
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@

{% macro default__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%}
-- Build actual result given inputs
with dbt_internal_unit_test_actual AS (
with dbt_internal_unit_test_actual as (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as actual_or_expected
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as {{ adapter.quote("actual_or_expected") }}
from (
{{ main_sql }}
) _dbt_internal_unit_test_actual
),
-- Build expected result
dbt_internal_unit_test_expected AS (
dbt_internal_unit_test_expected as (
select
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as actual_or_expected
{% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as {{ adapter.quote("actual_or_expected") }}
from (
{{ expected_fixture_sql }}
) _dbt_internal_unit_test_expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
{%- set columns_in_relation = adapter.get_columns_in_relation(temp_relation) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%}
{%- endfor -%}

{% set unit_test_sql = get_unit_test_sql(sql, get_expected_sql(expected_rows, column_name_to_data_types), tested_expected_column_names) %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{#-- This needs to be a case-insensitive comparison --#}
{%- do column_name_to_data_types.update({column.name|lower: column.data_type}) -%}
{%- endfor -%}
{%- endif -%}

Expand All @@ -23,7 +24,7 @@
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
select
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
Expand All @@ -32,7 +33,7 @@ union all

{%- if (rows | length) == 0 -%}
select
{%- for column_name, column_value in default_row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- for column_name, column_value in default_row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%},{%- endif %}
{%- endfor %}
limit 0
{%- endif -%}
Expand All @@ -48,7 +49,7 @@ union all
{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}
select
{%- for column_name, column_value in row.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- for column_name, column_value in row.items() %} {{ column_value }} as {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
Expand All @@ -64,7 +65,7 @@ union all
{%- for column_name, column_value in row.items() -%}
{% set row_update = {column_name: column_value} %}
{%- if column_value is string -%}
{%- set row_update = {column_name: safe_cast(dbt.string_literal(column_value), column_name_to_data_types[column_name]) } -%}
{%- set row_update = {column_name: safe_cast(dbt.string_literal(dbt.escape_single_quotes(column_value)), column_name_to_data_types[column_name]) } -%}
{%- elif column_value is none -%}
{%- set row_update = {column_name: safe_cast('null', column_name_to_data_types[column_name]) } -%}
{%- else -%}
Expand Down
84 changes: 84 additions & 0 deletions dbt/tests/adapter/unit_testing/test_unit_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest

from dbt.tests.util import write_file, run_dbt


# Model under test: selects `tested_column` from the upstream model.
# Written to disk as-is, so the doubled braces are a literal Jinja expression.
my_model_sql = """
select
tested_column from {{ ref('my_upstream_model')}}
"""

# Template for the upstream model. `{sql_value}` is substituted via
# `str.format` with the SQL literal of the data type being exercised.
my_upstream_model_sql = """
select
{sql_value} as tested_column
"""

# Unit-test definition template, rendered with `str.format`:
#   * `{yaml_value}` is replaced by the fixture value for the tested column;
#   * the doubled braces collapse to the single braces of a YAML flow mapping.
# YAML nesting is expressed through indentation, so the layout below is
# significant: `model`, `given` and `expect` must belong to the unit test
# entry, and each `rows` list must belong to its own `input` / `expect`.
test_my_model_yml = """
unit_tests:
  - name: test_my_model
    model: my_model
    given:
      - input: ref('my_upstream_model')
        rows:
          - {{ tested_column: {yaml_value} }}
    expect:
      rows:
        - {{ tested_column: {yaml_value} }}
"""


class BaseUnitTestingTypes:
    """Base test checking that dbt unit tests cope with several data types.

    For every ``[sql_value, yaml_value]`` pair supplied by the ``data_types``
    fixture, the upstream model is rewritten to select ``sql_value``, the
    unit test definition is rewritten to use ``yaml_value`` for both the
    given input row and the expected row, and the unit test is then run.
    """

    @pytest.fixture
    def data_types(self):
        """Return the ``[sql_value, yaml_value]`` pairs to exercise.

        ``sql_value`` is the SQL expression selected by the upstream model;
        ``yaml_value`` is how the same value is written in the unit test
        fixture rows.
        """
        # sql_value, yaml_value
        return [
            ["1", "1"],
            ["'1'", "1"],
            ["true", "true"],
            ["DATE '2020-01-02'", "2020-01-02"],
            ["TIMESTAMP '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
            ["TIMESTAMPTZ '2013-11-03 00:00:00-0'", "2013-11-03 00:00:00-0"],
            ["ARRAY['a','b','c']", """'{"a", "b", "c"}'"""],
            ["ARRAY[1,2,3]", """'{1, 2, 3}'"""],
            ["'1'::numeric", "1"],
            [
                """'{"bar": "baz", "balance": 7.77, "active": false}'::json""",
                """'{"bar": "baz", "balance": 7.77, "active": false}'""",
            ],
        ]

    @pytest.fixture(scope="class")
    def models(self):
        """Return the initial project files, keyed by file name.

        The upstream model and the schema file are unformatted templates
        here; the test overwrites both before invoking dbt.
        """
        return {
            "my_model.sql": my_model_sql,
            "my_upstream_model.sql": my_upstream_model_sql,
            "schema.yml": test_my_model_yml,
        }

    def test_unit_test_data_type(self, project, data_types):
        """Run the unit test once per data type pair and fail on any error."""
        for sql_value, yaml_value in data_types:
            # Write parametrized type value to sql files
            write_file(
                my_upstream_model_sql.format(sql_value=sql_value),
                "models",
                "my_upstream_model.sql",
            )

            # Write parametrized type value to unit test yaml definition
            write_file(
                test_my_model_yml.format(yaml_value=yaml_value),
                "models",
                "schema.yml",
            )

            results = run_dbt(["run", "--select", "my_upstream_model"])
            assert len(results) == 1

            try:
                run_dbt(["unit-test", "--select", "my_model"])
            except Exception as exc:
                # Re-raise as an assertion naming the offending SQL value, and
                # chain the original error so its traceback stays attached as
                # the explicit cause.
                raise AssertionError(
                    f"unit test failed when testing model with {sql_value}"
                ) from exc


class TestUnitTestingTypes(BaseUnitTestingTypes):
    """Concrete subclass that runs the base data-type checks unchanged."""

0 comments on commit f16cd66

Please sign in to comment.