From f434826b1cccd7e25b9cfad88895d586ae4dbc6d Mon Sep 17 00:00:00 2001
From: Petro Tiurin <93913847+ptiurin@users.noreply.github.com>
Date: Fri, 13 Sep 2024 15:10:13 +0100
Subject: [PATCH] test(FIR-12467): improve unit test coverage (#148)

---
 dbt/adapters/firebolt/impl.py       |   1 -
 tests/unit/test_firebolt_adapter.py | 162 ++++++++++++++++++++++++++--
 2 files changed, 156 insertions(+), 7 deletions(-)

diff --git a/dbt/adapters/firebolt/impl.py b/dbt/adapters/firebolt/impl.py
index c2b61d8c..67fa58e3 100644
--- a/dbt/adapters/firebolt/impl.py
+++ b/dbt/adapters/firebolt/impl.py
@@ -274,7 +274,6 @@ def get_rows_different_sql(
         columns: the number of rows that are different between the two
         relations and the number of mismatched rows.
         """
-        # This method only really exists for test reasons.
         names: List[str]
         if column_names is None:
             columns = self.get_columns_in_relation(relation_a)
diff --git a/tests/unit/test_firebolt_adapter.py b/tests/unit/test_firebolt_adapter.py
index 54164f68..38931987 100644
--- a/tests/unit/test_firebolt_adapter.py
+++ b/tests/unit/test_firebolt_adapter.py
@@ -2,12 +2,14 @@
 from multiprocessing import get_context
 from unittest.mock import MagicMock, patch
 
+import agate
 from dbt.adapters.contracts.connection import Connection
+from dbt_common.exceptions import DbtRuntimeError
 from firebolt.client.auth import ClientCredentials, UsernamePassword
 from firebolt.db import ARRAY, DECIMAL
 from firebolt.db.connection import Connection as SDKConnection
 from firebolt.utils.exception import InterfaceError
-from pytest import fixture, mark
+from pytest import fixture, mark, raises
 
 from dbt.adapters.firebolt import (
     FireboltAdapter,
@@ -16,6 +18,7 @@
 )
 from dbt.adapters.firebolt.column import FireboltColumn
 from dbt.adapters.firebolt.connections import _determine_auth
+from dbt.adapters.firebolt.relation import FireboltRelation
 from tests.functional.adapter.test_basic import AnySpecifiedType
 
 
@@ -32,6 +35,11 @@ def connection():
     return connection
 
 
+@fixture
+def adapter():
+    return FireboltAdapter(MagicMock(), get_context('spawn'))
+
+
 def test_open(connection):
     successful_attempt = MagicMock(spec=SDKConnection)
 
@@ -54,11 +62,6 @@
     assert connection.handle == successful_attempt
 
 
-@fixture(scope='module')
-def adapter():
-    return FireboltAdapter(MagicMock(), get_context('spawn'))
-
-
 @mark.parametrize(
     'column,expected',
     [
@@ -176,3 +179,150 @@ def test_determine_auth_with_id_and_secret():
     assert isinstance(auth, ClientCredentials)
     assert auth.client_id == 'your_user_id'
     assert auth.client_secret == 'your_user_secret'
+
+
+def test_make_field_partition_pairs_valid(adapter):
+    columns = [
+        {'name': 'col1', 'data_type': 'INT'},
+        {'name': 'col2', 'data_type': 'STRING'},
+    ]
+    partitions = [
+        {'name': 'part1', 'data_type': 'DATE', 'regex': 'regex1'},
+        {'name': 'part2', 'data_type': 'DATE', 'regex': 'regex2'},
+    ]
+    result = adapter.make_field_partition_pairs(columns, partitions)
+    expected = [
+        '"col1" INT',
+        '"col2" STRING',
+        '"part1" DATE PARTITION (\'regex1\')',
+        '"part2" DATE PARTITION (\'regex2\')',
+    ]
+    assert result == expected
+
+
+def test_make_field_partition_pairs_missing_column_data_type(adapter):
+    columns = [
+        {'name': 'col1', 'data_type': 'INT'},
+        {'name': 'col2'},  # Missing data_type
+    ]
+    partitions = []
+    with raises(DbtRuntimeError, match='Data type is missing for column `col2`.'):
+        adapter.make_field_partition_pairs(columns, partitions)
+
+
+def test_make_field_partition_pairs_missing_partition_data_type(adapter):
+    columns = [{'name': 'col1', 'data_type': 'INT'}]
+    partitions = [{'name': 'part1', 'regex': 'regex1'}]  # Missing data_type
+    with raises(DbtRuntimeError, match='Data type is missing for partition `part1`.'):
+        adapter.make_field_partition_pairs(columns, partitions)
+
+
+def test_make_field_partition_pairs_missing_partition_regex(adapter):
+    columns = [{'name': 'col1', 'data_type': 'INT'}]
+    partitions = [{'name': 'part1', 'data_type': 'DATE'}]  # Missing regex
+    with raises(DbtRuntimeError, match='Regex is missing for partition `part1`.'):
+        adapter.make_field_partition_pairs(columns, partitions)
+
+
+def test_make_field_partition_pairs_empty_partitions(adapter):
+    columns = [
+        {'name': 'col1', 'data_type': 'INT'},
+        {'name': 'col2', 'data_type': 'STRING'},
+    ]
+    partitions = []
+    result = adapter.make_field_partition_pairs(columns, partitions)
+    expected = ['"col1" INT', '"col2" STRING']
+    assert result == expected
+
+
+def test_stack_tables_valid(adapter):
+    table1 = agate.Table.from_object(
+        [{'col1': 1, 'col2': 'a'}, {'col1': 2, 'col2': 'b'}]
+    )
+    table2 = agate.Table.from_object(
+        [{'col1': 3, 'col2': 'c'}, {'col1': 4, 'col2': 'd'}]
+    )
+    result = adapter.stack_tables([table1, table2])
+    expected = agate.Table.from_object(
+        [
+            {'col1': 1, 'col2': 'a'},
+            {'col1': 2, 'col2': 'b'},
+            {'col1': 3, 'col2': 'c'},
+            {'col1': 4, 'col2': 'd'},
+        ]
+    )
+    assert str(result) == str(expected)
+
+
+def test_stack_tables_empty(adapter):
+    result = adapter.stack_tables([agate.Table.from_object([])])
+    expected = agate.Table.from_object([])
+    assert str(result) == str(expected)
+
+
+@mark.parametrize(
+    'table1,table2',
+    [
+        ({'col1': 1, 'col2': 'a'}, {'col3': 2, 'col4': 'b'}),
+        ({'col1': 1, 'col2': 'a'}, {'col1': 2}),
+    ],
+)
+def test_stack_tables_different_schemas(table1, table2, adapter):
+    table1 = agate.Table.from_object([table1])
+    table2 = agate.Table.from_object([table2])
+    with raises(ValueError, match='Not all tables have the same column types!'):
+        adapter.stack_tables([table1, table2])
+
+
+def test_get_rows_different_sql_valid(adapter):
+    relation_a = FireboltRelation.create(
+        database='db', schema='public', identifier='table_a'
+    )
+    relation_b = FireboltRelation.create(
+        database='db', schema='public', identifier='table_b'
+    )
+    column_names = ['col1', 'col2']
+    sql = adapter.get_rows_different_sql(relation_a, relation_b, column_names)
+    assert 'SELECT "col1", "col2" FROM "table_a"' in sql
+    assert 'SELECT "col1", "col2" FROM "table_b"' in sql
+    assert 'WHERE "table_a"."col1" = "table_b"."col1"' in sql
+    assert 'AND "table_a"."col2" = "table_b"."col2"' in sql
+
+
+def test_get_rows_different_sql_empty_columns(adapter):
+    def mock_column(name: str) -> MagicMock:
+        col = MagicMock()
+        col.name = name
+        return col
+
+    relation_a = FireboltRelation.create(
+        database='db', schema='public', identifier='table_a'
+    )
+    relation_b = FireboltRelation.create(
+        database='db', schema='public', identifier='table_b'
+    )
+    column_names = None
+    adapter.get_columns_in_relation = MagicMock(
+        return_value=[mock_column('col1'), mock_column('col2')]
+    )
+    sql = adapter.get_rows_different_sql(relation_a, relation_b, column_names)
+    print(sql)
+    assert 'SELECT "col1", "col2" FROM "table_a"' in sql
+    assert 'SELECT "col1", "col2" FROM "table_b"' in sql
+    assert 'WHERE "table_a"."col1" = "table_b"."col1"' in sql
+    assert 'AND "table_a"."col2" = "table_b"."col2"' in sql
+
+
+def test_annotate_date_columns_for_partitions_valid(adapter):
+    vals = '1,2,3'
+    cols = ['col1', 'col2', 'col3']
+    col_types = [
+        FireboltColumn('col1', 'int'),
+        FireboltColumn('col2', 'date'),
+        FireboltColumn('col3', 'date'),
+    ]
+    expected = '1,2,3'
+    # It is expected that there's no cast to DATE here, as the logic is incorrect
+    # I'm not sure we even need it anymore
+    result = adapter.annotate_date_columns_for_partitions(vals, cols, col_types)
+    assert result == expected