Skip to content

Commit

Permalink
test(FIR-12467): improve unit test coverage (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiurin authored Sep 13, 2024
1 parent 674e356 commit f434826
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 7 deletions.
1 change: 0 additions & 1 deletion dbt/adapters/firebolt/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ def get_rows_different_sql(
columns: the number of rows that are different between the two
relations and the number of mismatched rows.
"""
# This method only really exists for test reasons.
names: List[str]
if column_names is None:
columns = self.get_columns_in_relation(relation_a)
Expand Down
162 changes: 156 additions & 6 deletions tests/unit/test_firebolt_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from multiprocessing import get_context
from unittest.mock import MagicMock, patch

import agate
from dbt.adapters.contracts.connection import Connection
from dbt_common.exceptions import DbtRuntimeError
from firebolt.client.auth import ClientCredentials, UsernamePassword
from firebolt.db import ARRAY, DECIMAL
from firebolt.db.connection import Connection as SDKConnection
from firebolt.utils.exception import InterfaceError
from pytest import fixture, mark
from pytest import fixture, mark, raises

from dbt.adapters.firebolt import (
FireboltAdapter,
Expand All @@ -16,6 +18,7 @@
)
from dbt.adapters.firebolt.column import FireboltColumn
from dbt.adapters.firebolt.connections import _determine_auth
from dbt.adapters.firebolt.relation import FireboltRelation
from tests.functional.adapter.test_basic import AnySpecifiedType


Expand All @@ -32,6 +35,11 @@ def connection():
return connection


@fixture
def adapter():
return FireboltAdapter(MagicMock(), get_context('spawn'))


def test_open(connection):
successful_attempt = MagicMock(spec=SDKConnection)

Expand All @@ -54,11 +62,6 @@ def test_open(connection):
assert connection.handle == successful_attempt


@fixture(scope='module')
def adapter():
return FireboltAdapter(MagicMock(), get_context('spawn'))


@mark.parametrize(
'column,expected',
[
Expand Down Expand Up @@ -176,3 +179,150 @@ def test_determine_auth_with_id_and_secret():
assert isinstance(auth, ClientCredentials)
assert auth.client_id == 'your_user_id'
assert auth.client_secret == 'your_user_secret'


def test_make_field_partition_pairs_valid(adapter):
columns = [
{'name': 'col1', 'data_type': 'INT'},
{'name': 'col2', 'data_type': 'STRING'},
]
partitions = [
{'name': 'part1', 'data_type': 'DATE', 'regex': 'regex1'},
{'name': 'part2', 'data_type': 'DATE', 'regex': 'regex2'},
]
result = adapter.make_field_partition_pairs(columns, partitions)
expected = [
'"col1" INT',
'"col2" STRING',
'"part1" DATE PARTITION (\'regex1\')',
'"part2" DATE PARTITION (\'regex2\')',
]
assert result == expected


def test_make_field_partition_pairs_missing_column_data_type(adapter):
columns = [
{'name': 'col1', 'data_type': 'INT'},
{'name': 'col2'}, # Missing data_type
]
partitions = []
with raises(DbtRuntimeError, match='Data type is missing for column `col2`.'):
adapter.make_field_partition_pairs(columns, partitions)


def test_make_field_partition_pairs_missing_partition_data_type(adapter):
columns = [{'name': 'col1', 'data_type': 'INT'}]
partitions = [{'name': 'part1', 'regex': 'regex1'}] # Missing data_type
with raises(DbtRuntimeError, match='Data type is missing for partition `part1`.'):
adapter.make_field_partition_pairs(columns, partitions)


def test_make_field_partition_pairs_missing_partition_regex(adapter):
columns = [{'name': 'col1', 'data_type': 'INT'}]
partitions = [{'name': 'part1', 'data_type': 'DATE'}] # Missing regex
with raises(DbtRuntimeError, match='Regex is missing for partition `part1`.'):
adapter.make_field_partition_pairs(columns, partitions)


def test_make_field_partition_pairs_empty_partitions(adapter):
columns = [
{'name': 'col1', 'data_type': 'INT'},
{'name': 'col2', 'data_type': 'STRING'},
]
partitions = []
result = adapter.make_field_partition_pairs(columns, partitions)
expected = ['"col1" INT', '"col2" STRING']
assert result == expected


def test_stack_tables_valid(adapter):
table1 = agate.Table.from_object(
[{'col1': 1, 'col2': 'a'}, {'col1': 2, 'col2': 'b'}]
)
table2 = agate.Table.from_object(
[{'col1': 3, 'col2': 'c'}, {'col1': 4, 'col2': 'd'}]
)
result = adapter.stack_tables([table1, table2])
expected = agate.Table.from_object(
[
{'col1': 1, 'col2': 'a'},
{'col1': 2, 'col2': 'b'},
{'col1': 3, 'col2': 'c'},
{'col1': 4, 'col2': 'd'},
]
)
assert str(result) == str(expected)


def test_stack_tables_empty(adapter):
result = adapter.stack_tables([agate.Table.from_object([])])
expected = agate.Table.from_object([])
assert str(result) == str(expected)


@mark.parametrize(
'table1,table2',
[
({'col1': 1, 'col2': 'a'}, {'col3': 2, 'col4': 'b'}),
({'col1': 1, 'col2': 'a'}, {'col1': 2}),
],
)
def test_stack_tables_different_schemas(table1, table2, adapter):
table1 = agate.Table.from_object([table1])
table2 = agate.Table.from_object([table2])
with raises(ValueError, match='Not all tables have the same column types!'):
adapter.stack_tables([table1, table2])


def test_get_rows_different_sql_valid(adapter):
relation_a = FireboltRelation.create(
database='db', schema='public', identifier='table_a'
)
relation_b = FireboltRelation.create(
database='db', schema='public', identifier='table_b'
)
column_names = ['col1', 'col2']
sql = adapter.get_rows_different_sql(relation_a, relation_b, column_names)
assert 'SELECT "col1", "col2" FROM "table_a"' in sql
assert 'SELECT "col1", "col2" FROM "table_b"' in sql
assert 'WHERE "table_a"."col1" = "table_b"."col1"' in sql
assert 'AND "table_a"."col2" = "table_b"."col2"' in sql


def test_get_rows_different_sql_empty_columns(adapter):
def mock_column(name: str) -> MagicMock:
col = MagicMock()
col.name = name
return col

relation_a = FireboltRelation.create(
database='db', schema='public', identifier='table_a'
)
relation_b = FireboltRelation.create(
database='db', schema='public', identifier='table_b'
)
column_names = None
adapter.get_columns_in_relation = MagicMock(
return_value=[mock_column('col1'), mock_column('col2')]
)
sql = adapter.get_rows_different_sql(relation_a, relation_b, column_names)
print(sql)
assert 'SELECT "col1", "col2" FROM "table_a"' in sql
assert 'SELECT "col1", "col2" FROM "table_b"' in sql
assert 'WHERE "table_a"."col1" = "table_b"."col1"' in sql
assert 'AND "table_a"."col2" = "table_b"."col2"' in sql


def test_annotate_date_columns_for_partitions_valid(adapter):
vals = '1,2,3'
cols = ['col1', 'col2', 'col3']
col_types = [
FireboltColumn('col1', 'int'),
FireboltColumn('col2', 'date'),
FireboltColumn('col3', 'date'),
]
expected = '1,2,3'
# It is expected that there's no cast to DATE here, as the logic is incorrect
# I'm not sure we even need it anymore
result = adapter.annotate_date_columns_for_partitions(vals, cols, col_types)
assert result == expected

0 comments on commit f434826

Please sign in to comment.