Skip to content

Commit

Permalink
add boundary test to better understand get_columns method for cross d…
Browse files Browse the repository at this point in the history
…atabase refs
  • Loading branch information
mikealfare committed Sep 9, 2024
1 parent 8bbdbf2 commit a9af12d
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 0 deletions.
1 change: 1 addition & 0 deletions test.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
REDSHIFT_TEST_HOST=
REDSHIFT_TEST_PORT=
REDSHIFT_TEST_DBNAME=
REDSHIFT_TEST_DBNAME_ALT=
REDSHIFT_TEST_USER=
REDSHIFT_TEST_PASS=
REDSHIFT_TEST_REGION=
Expand Down
41 changes: 41 additions & 0 deletions tests/boundary/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from datetime import datetime
import os
import random
from typing import Any, Dict

import pytest
import redshift_connector


@pytest.fixture
def connection(connection_config) -> redshift_connector.Connection:
return redshift_connector.connect(**connection_config)


@pytest.fixture
def connection_alt(connection_config) -> redshift_connector.Connection:
config = connection_config.copy()
config.update(database=os.getenv("REDSHIFT_TEST_DBNAME_ALT"))
return redshift_connector.connect(**config)


@pytest.fixture
def connection_config() -> Dict[str, Any]:
return {
"user": os.getenv("REDSHIFT_TEST_USER"),
"password": os.getenv("REDSHIFT_TEST_PASS"),
"host": os.getenv("REDSHIFT_TEST_HOST"),
"port": int(os.getenv("REDSHIFT_TEST_PORT")),
"database": os.getenv("REDSHIFT_TEST_DBNAME"),
"region": os.getenv("REDSHIFT_TEST_REGION"),
}


@pytest.fixture
def schema_name(request) -> str:
runtime = datetime.utcnow() - datetime(1970, 1, 1, 0, 0, 0)
runtime_s = int(runtime.total_seconds())
runtime_ms = runtime.microseconds
random_int = random.randint(0, 9999)
file_name = request.module.__name__.split(".")[-1]
return f"test_{runtime_s}{runtime_ms}{random_int:04}_{file_name}"
61 changes: 61 additions & 0 deletions tests/boundary/test_redshift_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os

import pytest


@pytest.fixture(autouse=True)
def setup(connection, connection_alt, schema_name) -> str:
# create the same table in two different databases
with connection.cursor() as cursor:
cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
cursor.execute(f"CREATE TABLE {schema_name}.cross_db as select 3.14 as id")
with connection_alt.cursor() as cursor:
cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")
cursor.execute(f"CREATE TABLE {schema_name}.cross_db as select 3.14 as id")

yield schema_name

# drop both test schemas
with connection_alt.cursor() as cursor:
cursor.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE")
with connection.cursor() as cursor:
cursor.execute(f"DROP SCHEMA IF EXISTS {schema_name} CASCADE")


def test_columns_in_relation(connection, schema_name):
# we're specifically running this query from the default database
# we're expecting to get both tables, the one in the default database and the one in the alt database
with connection.cursor() as cursor:
columns = cursor.get_columns(schema_pattern=schema_name, tablename_pattern="cross_db")

# we should have the same table in both databases
assert len(columns) == 2

databases = set()
for column in columns:
(
database,
schema,
table,
name,
type_code,
type_name,
precision,
_,
scale,
*_,
) = column
databases.add(database)
assert schema_name == schema_name
assert table == "cross_db"
assert name == "id"
assert type_code == 2
assert type_name == "numeric"
assert precision == 3
assert scale == 2

# only the databases are different
assert databases == {
os.getenv("REDSHIFT_TEST_DBNAME"),
os.getenv("REDSHIFT_TEST_DBNAME_ALT"),
}
41 changes: 41 additions & 0 deletions tests/functional/test_columns_in_relation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from dbt.tests.util import get_connection, run_dbt
import pytest


MY_CROSS_DB_SOURCES = """
version: 2
sources:
- name: ci
schema: adapter
tables:
- name: cross_db
- name: ci_alt
database: ci_alt
schema: adapter
tables:
- name: cross_db
"""


class TestCrossDatabase:
"""
This addresses https://github.com/dbt-labs/dbt-redshift/issues/736
"""

@pytest.fixture(scope="class")
def models(self):
my_model = """
select '{{ adapter.get_columns_in_relation(source('ci', 'cross_db')) }}' as columns
union all
select '{{ adapter.get_columns_in_relation(source('ci_alt', 'cross_db')) }}' as columns
"""
return {
"sources.yml": MY_CROSS_DB_SOURCES,
"my_model.sql": my_model,
}

def test_columns_in_relation(self, project):
run_dbt(["run"])
with get_connection(project.adapter, "_test"):
records = project.run_sql(f"select * from {project.test_schema}.my_model", fetch=True)
assert len(records) == 2

0 comments on commit a9af12d

Please sign in to comment.