From 8f922749f8776c0d5888ef184d1eb21f37408ac4 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Fri, 7 Apr 2023 10:19:37 -0700 Subject: [PATCH 01/21] add support for datashared objects and external tables for redshift list_relation_without_caching --- dbt/include/redshift/macros/adapters.sql | 76 +++++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/dbt/include/redshift/macros/adapters.sql b/dbt/include/redshift/macros/adapters.sql index ebf2e16a5..c0ac896e0 100644 --- a/dbt/include/redshift/macros/adapters.sql +++ b/dbt/include/redshift/macros/adapters.sql @@ -196,12 +196,57 @@ ), + data_share as ( + select + a.ordinal_position as columnnum, + a.schema_name as schemaname, + a.column_name as columnname, + case + when a.data_type ilike 'character varying%' then + 'character varying' + when a.data_type ilike 'numeric%' then 'numeric' + else a.data_type + end as col_type, + case + when a.data_type like 'character%' + then nullif(REGEXP_SUBSTR(a.character_maximum_length, '[0-9]+'), '')::int + else null + end as character_maximum_length, + case + when a.data_type like 'numeric%' or a.data_type ilike 'integer%' + then nullif( + SPLIT_PART(REGEXP_SUBSTR(a.numeric_precision, '[0-9,]+'), ',', 1), + '')::int + end as numeric_precision, + case + when a.data_type like 'numeric%' or a.data_type ilike 'integer%' + then nullif( + SPLIT_PART(REGEXP_SUBSTR(a.numeric_scale, '[0-9,]+'), ',', 2), + '')::int + else null + end as numeric_scale + from svv_all_columns a + inner join ( + select object_name + from svv_datashare_objects + where share_type = 'INBOUND' + and object_type in ('table', 'view', 'materialized view', 'late binding view') + and object_name = '{{ relation.schema }}' || '.' || '{{ relation.identifier }}' + ) b on a.schema_name || '.' || a.table_name = b.object_name + inner join ( + select consumer_database from SVV_DATASHARES + where share_type = 'INBOUND' + ) c on c.consumer_database = a.database_name + ), + unioned as ( select * from bound_views union all select * from unbound_views union all select * from external_views + union all + select * from data_share ) select @@ -223,7 +268,36 @@ {% macro redshift__list_relations_without_caching(schema_relation) %} - {{ return(postgres__list_relations_without_caching(schema_relation)) }} + {% call statement('list_relations_without_caching', fetch_result=True) -%} + select + database_name as database, + table_name as name, + schema_name as schema, + 'table' as type + from SVV_REDSHIFT_TABLES + where schema_name ilike '{{ schema_relation.schema }}' + and database_name ilike '{{ schema_relation.database }}' + and table_type = 'TABLE' + union all + select + database_name as database, + table_name as name, + schema_name as schema, + 'view' as type + from SVV_REDSHIFT_TABLES + where schema_name ilike '{{ schema_relation.schema }}' + and database_name ilike '{{ schema_relation.database }}' + and table_type = 'VIEW' + union all + select + '{{ schema_relation.database }}' as database, + tablename as name, + schemaname as schema, + 'external_tables' as type + from SVV_EXTERNAL_TABLES + where schemaname ilike '{{ schema_relation.schema }}' + {% endcall %} + {{ return(load_result('list_relations_without_caching').table) }} {% endmacro %} From 29e5a5c7199ec07fa96bb59bb43e2ffd5efc0930 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Fri, 7 Apr 2023 10:47:48 -0700 Subject: [PATCH 02/21] changie --- .changes/unreleased/Features-20230407-104723.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changes/unreleased/Features-20230407-104723.yaml diff --git a/.changes/unreleased/Features-20230407-104723.yaml b/.changes/unreleased/Features-20230407-104723.yaml new file mode 100644 index 000000000..6f57a9f18 --- /dev/null +++ b/.changes/unreleased/Features-20230407-104723.yaml @@ -0,0 +1,5 @@ +kind: Features +time: 2023-04-07T10:47:23.105369-07:00 +custom: + Author: jiezhen-chen + Issue: 17 179 217 From dc98d126ddc837ed28297b72b07429072199683f Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Thu, 13 Apr 2023 12:12:57 -0700 Subject: [PATCH 03/21] refactor inner joins to cte --- dbt/include/redshift/macros/adapters.sql | 53 ++++++++++++------------ 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/dbt/include/redshift/macros/adapters.sql b/dbt/include/redshift/macros/adapters.sql index 82c356dc7..25c8cbcc6 100644 --- a/dbt/include/redshift/macros/adapters.sql +++ b/dbt/include/redshift/macros/adapters.sql @@ -198,47 +198,48 @@ ), - data_share as ( + inbound_datashare as ( + select object_name + from svv_datashare_objects + where share_type = 'INBOUND' + and object_type in ('table', 'view', 'materialized view', 'late binding view') + and object_name = '{{ relation.schema }}' || '.' || '{{ relation.identifier }}' + ), + + data_share_columns as ( select - a.ordinal_position as columnnum, - a.schema_name as schemaname, - a.column_name as columnname, + ordinal_position as columnnum, + schema_name as schemaname, + column_name as columnname, case - when a.data_type ilike 'character varying%' then + when data_type ilike 'character varying%' then 'character varying' - when a.data_type ilike 'numeric%' then 'numeric' - else a.data_type + when data_type ilike 'numeric%' then 'numeric' + else data_type end as col_type, case - when a.data_type like 'character%' - then nullif(REGEXP_SUBSTR(a.character_maximum_length, '[0-9]+'), '')::int + when data_type like 'character%' + then nullif(REGEXP_SUBSTR(character_maximum_length, '[0-9]+'), '')::int else null end as character_maximum_length, case - when a.data_type like 'numeric%' or a.data_type ilike 'integer%' + when data_type like 'numeric%' or data_type ilike 'integer%' then nullif( - SPLIT_PART(REGEXP_SUBSTR(a.numeric_precision, '[0-9,]+'), ',', 1), + SPLIT_PART(REGEXP_SUBSTR(numeric_precision, '[0-9,]+'), ',', 1), '')::int end as numeric_precision, case - when a.data_type like 'numeric%' or a.data_type ilike 'integer%' + when data_type like 'numeric%' or data_type ilike 'integer%' then nullif( - SPLIT_PART(REGEXP_SUBSTR(a.numeric_scale, '[0-9,]+'), ',', 2), + SPLIT_PART(REGEXP_SUBSTR(numeric_scale, '[0-9,]+'), ',', 2), '')::int else null end as numeric_scale - from svv_all_columns a - inner join ( - select object_name - from svv_datashare_objects - where share_type = 'INBOUND' - and object_type in ('table', 'view', 'materialized view', 'late binding view') - and object_name = '{{ relation.schema }}' || '.' || '{{ relation.identifier }}' - ) b on a.schema_name || '.' || a.table_name = b.object_name - inner join ( - select consumer_database from SVV_DATASHARES - where share_type = 'INBOUND' - ) c on c.consumer_database = a.database_name + from svv_all_columns + inner join inbound_datashare on + inbound_datashare.object_name = svv_all_columns.schema_name || '.' || svv_all_columns.table_name + where svv_all_columns.table_name = '{{ relation.identifier }}' + and svv_all_columns.schema_name = '{{ relation.schema }}' ), unioned as ( @@ -248,7 +249,7 @@ union all select * from external_views union all - select * from data_share + select * from data_share_columns ) select From 7bf1978bd20617e2d0ffcd0f982cc325610113e1 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Thu, 13 Apr 2023 12:23:02 -0700 Subject: [PATCH 04/21] shorten list_relations_without_caching query --- dbt/include/redshift/macros/adapters.sql | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/dbt/include/redshift/macros/adapters.sql b/dbt/include/redshift/macros/adapters.sql index 25c8cbcc6..ae7555a10 100644 --- a/dbt/include/redshift/macros/adapters.sql +++ b/dbt/include/redshift/macros/adapters.sql @@ -276,21 +276,11 @@ database_name as database, table_name as name, schema_name as schema, - 'table' as type + LOWER(table_type) as type from SVV_REDSHIFT_TABLES where schema_name ilike '{{ schema_relation.schema }}' and database_name ilike '{{ schema_relation.database }}' - and table_type = 'TABLE' - union all - select - database_name as database, - table_name as name, - schema_name as schema, - 'view' as type - from SVV_REDSHIFT_TABLES - where schema_name ilike '{{ schema_relation.schema }}' - and database_name ilike '{{ schema_relation.database }}' - and table_type = 'VIEW' + and table_type IN ('TABLE','VIEW') union all select '{{ schema_relation.database }}' as database, From c69e651e23edc1b193336c66b4ef2ad3d1ed92fc Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Thu, 13 Apr 2023 14:35:37 -0700 Subject: [PATCH 05/21] add database filter to get_columns_in_relation to only query inbound datashared objects of a connected database --- dbt/include/redshift/macros/adapters.sql | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dbt/include/redshift/macros/adapters.sql b/dbt/include/redshift/macros/adapters.sql index 616dc5669..20d13daf0 100644 --- a/dbt/include/redshift/macros/adapters.sql +++ b/dbt/include/redshift/macros/adapters.sql @@ -206,6 +206,8 @@ and object_name = '{{ relation.schema }}' || '.' || '{{ relation.identifier }}' ), + + data_share_columns as ( select ordinal_position as columnnum, @@ -240,6 +242,7 @@ inbound_datashare.object_name = svv_all_columns.schema_name || '.' || svv_all_columns.table_name where svv_all_columns.table_name = '{{ relation.identifier }}' and svv_all_columns.schema_name = '{{ relation.schema }}' + and svv_all_columns.database_name = '{{ relation.database }}' ), unioned as ( From bce3e889b04d886d6423ee2cee69a6e82ccf1cf1 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Fri, 14 Apr 2023 11:49:05 -0700 Subject: [PATCH 06/21] leverage redshift_conn method for list_schemas --- dbt/include/redshift/macros/adapters.sql | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dbt/include/redshift/macros/adapters.sql b/dbt/include/redshift/macros/adapters.sql index 20d13daf0..51261e83b 100644 --- a/dbt/include/redshift/macros/adapters.sql +++ b/dbt/include/redshift/macros/adapters.sql @@ -303,7 +303,12 @@ {% macro redshift__list_schemas(database) -%} - {{ return(postgres__list_schemas(database)) }} + {% call statement('list_schemas', fetch_result=True, auto_begin=False) %} + select database_name as database + from SVV_REDSHIFT_TABLES + + {% endcall %} + {{ return(load_result('list_schemas').table) }} {%- endmacro %} From 451c60d88330c062bfec28951ef58af5ee6555ce Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Fri, 14 Apr 2023 11:58:25 -0700 Subject: [PATCH 07/21] add list_schemas method --- dbt/adapters/redshift/__init__.py | 1 + dbt/adapters/redshift/impl.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/redshift/__init__.py b/dbt/adapters/redshift/__init__.py index 64ac384fe..09f816260 100644 --- a/dbt/adapters/redshift/__init__.py +++ b/dbt/adapters/redshift/__init__.py @@ -1,5 +1,6 @@ from dbt.adapters.redshift.connections import RedshiftConnectionManager # noqa from dbt.adapters.redshift.connections import RedshiftCredentials + from dbt.adapters.redshift.column import RedshiftColumn # noqa from dbt.adapters.redshift.relation import RedshiftRelation # noqa: F401 from dbt.adapters.redshift.impl import RedshiftAdapter diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 54fdd7dcf..4cf17140d 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -from typing import Optional, Set, Any, Dict, Type +from typing import Optional, Set, Any, Dict, Type, List from collections import namedtuple - +import redshift_connector from dbt.adapters.base import PythonJobHelper from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport from dbt.adapters.sql import SQLAdapter @@ -114,6 +114,16 @@ def valid_incremental_strategies(self): def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str: return f"{add_to} + interval '{number} {interval}'" + @available + def list_schemas(self, database: str) -> List[str]: + # results = self.execute_macro("redshift__list_schemas", kwargs={"database": database}) + # return results + con = redshift_connector.Connection + # c = self.connections.get_thread_connection() + # conn = self.connections.open(connection=c) + results = con.cursor(self).get_schemas(catalog="dev") + return results + def _link_cached_database_relations(self, schemas: Set[str]): """ :param schemas: The set of schemas that should have links added. From 4353a38386aa1caf3beebc3f1fd2c1b3a4976d90 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Fri, 28 Apr 2023 12:12:06 -0700 Subject: [PATCH 08/21] migrate from macros to redshift driver --- dbt/adapters/redshift/connections.py | 4 +- dbt/adapters/redshift/impl.py | 85 +++++++++++++++++++++--- dbt/include/redshift/macros/adapters.sql | 7 +- tests/unit/test_redshift_adapter.py | 67 +++++++++++++++++++ 4 files changed, 147 insertions(+), 16 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 9e830cb53..4a35debe4 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -55,11 +55,12 @@ class RedshiftCredentials(Credentials): iam_profile: Optional[str] = None autocreate: bool = False db_groups: List[str] = field(default_factory=list) - ra3_node: Optional[bool] = False connect_timeout: int = 30 + ra3_node: Optional[bool] = False role: Optional[str] = None sslmode: Optional[str] = None retries: int = 1 + current_db_only: Optional[bool] = False _ALIASES = {"dbname": "database", "pass": "password"} @@ -101,6 +102,7 @@ def get_connect_method(self): "db_groups": self.credentials.db_groups, "region": self.credentials.host.split(".")[2], "timeout": self.credentials.connect_timeout, + "database_metadata_current_db_only": self.credentials.current_db_only, } if self.credentials.sslmode: kwargs["sslmode"] = self.credentials.sslmode diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 4cf17140d..6fe06b5a2 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Optional, Set, Any, Dict, Type, List from collections import namedtuple -import redshift_connector from dbt.adapters.base import PythonJobHelper from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport from dbt.adapters.sql import SQLAdapter @@ -9,6 +8,8 @@ from dbt.contracts.connection import AdapterResponse from dbt.contracts.graph.nodes import ConstraintType from dbt.events import AdapterLogger +from dbt.adapters.base.relation import BaseRelation + import dbt.exceptions @@ -30,7 +31,7 @@ class RedshiftConfig(AdapterConfig): class RedshiftAdapter(SQLAdapter): - Relation = RedshiftRelation + Relation = RedshiftRelation # type: ignore ConnectionManager = RedshiftConnectionManager connections: RedshiftConnectionManager Column = RedshiftColumn # type: ignore @@ -114,16 +115,82 @@ def valid_incremental_strategies(self): def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str: return f"{add_to} + interval '{number} {interval}'" + def _get_cursor(self): + return self.connections.get_thread_connection().handle.cursor() + + def list_schemas(self, database: str, schema: Optional[str] = None) -> List[str]: + cursor = self._get_cursor() + results = [] + for s in cursor.get_schemas(catalog=database, schema_pattern=schema): + results.append(s[0]) + return results + @available - def list_schemas(self, database: str) -> List[str]: - # results = self.execute_macro("redshift__list_schemas", kwargs={"database": database}) - # return results - con = redshift_connector.Connection - # c = self.connections.get_thread_connection() - # conn = self.connections.open(connection=c) - results = con.cursor(self).get_schemas(catalog="dev") + def check_schema_exists(self, database: str, schema: str) -> bool: + results = self.list_schemas(database=database, schema=schema) + return len(results) > 0 + + def get_columns_in_relation(self, relation): + # TODO handle cases where identifier not provided + cursor = self._get_cursor() + results = [] + if relation.identifier: + columns = cursor.get_columns( + catalog=relation.database, + schema_pattern=relation.schema, + tablename_pattern=relation.identifier, + ) + else: + columns = cursor.get_columns(catalog=relation.database, schema_pattern=relation.schema) + if columns is not None and len(columns) > 0: + for column in columns: + if column[4] == 1 or column[4] == 12: # if column type is character + results.append(RedshiftColumn(column[3], column[5], column[6], None, None)) + # elif column[4] == 5 or column[4] == 4 or column[4] == -5 or column[4] == 3 or column[4] == 7\ + # or column[4] == 8 or column[4] == 6 or column[4] == 2 or column[4] == 2003:#if column type is numeric + elif any( + column[4] == type_int for type_int in [5, 4, -5, 3, 7, 8, 6, 2, 2003] + ): # if column type is numeric + results.append( + RedshiftColumn(column[3], column[5], None, column[6], column[8]) + ) return results + def _get_tables(self, database: Optional[str], schema: Optional[str]): + cursor = self._get_cursor() + results = [] + for table in cursor.get_tables( + catalog=database, + schema_pattern=schema, + table_name_pattern=None, + types=["VIEW", "TABLE"], + ): + results.append([table[0], table[1], table[2], table[3]]) + # kinda feel like it might not be a good thing to do this by index + return results + + def list_relations_without_caching( # type: ignore + self, schema_relation: BaseRelation + ) -> List[RedshiftRelation]: + results = self._get_tables(schema_relation.database, schema_relation.schema) + relations = [] + quote_policy = {"database": True, "schema": True, "identifier": True} + for _database, _schema, name, _type in results: + try: + _type = self.Relation.get_relation_type(_type) + except ValueError: + _type = self.Relation.External + relations.append( + self.Relation.create( + database=_database, + schema=_schema, + identifier=name, + quote_policy=quote_policy, + type=_type, + ) + ) + return relations + def _link_cached_database_relations(self, schemas: Set[str]): """ :param schemas: The set of schemas that should have links added. diff --git a/dbt/include/redshift/macros/adapters.sql b/dbt/include/redshift/macros/adapters.sql index 51261e83b..20d13daf0 100644 --- a/dbt/include/redshift/macros/adapters.sql +++ b/dbt/include/redshift/macros/adapters.sql @@ -303,12 +303,7 @@ {% macro redshift__list_schemas(database) -%} - {% call statement('list_schemas', fetch_result=True, auto_begin=False) %} - select database_name as database - from SVV_REDSHIFT_TABLES - - {% endcall %} - {{ return(load_result('list_schemas').table) }} + {{ return(postgres__list_schemas(database)) }} {%- endmacro %} diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 27bcd98f8..751e9752a 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -359,6 +359,73 @@ def test_add_query_success(self): "select * from test3", True, bindings=None, abridge_sql_log=False ) + @mock.patch.object( + dbt.adapters.redshift.connections.SQLConnectionManager, "get_thread_connection" + ) + def mock_cursor(self, mock_get_thread_conn): + conn = mock.MagicMock + mock_get_thread_conn.return_value = conn + mock_handle = mock.MagicMock + conn.return_value = mock_handle + mock_cursor = mock.MagicMock + mock_handle.return_value = mock_cursor + return mock_cursor + + @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") + def test_get_tables(self, mock_cursor): + mock_cursor.return_value.get_tables.return_value = [ + ("apple", "banana", "cherry", "orange") + ] + results = self.adapter._get_tables(database="somedb", schema="someschema") + self.assertTrue(results == [["apple", "banana", "cherry", "orange"]]) + + @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_tables") + def test_list_relations_without_caching(self, mock_get_tables): + mock_get_tables.return_value = [["somedb", "someschema", "sometb", "VIEW"]] + mock_schema = mock.MagicMock(database="somedb", schema="someschema") + results = self.adapter.list_relations_without_caching(mock_schema) + self.assertTrue(results[0].database == "somedb") + self.assertTrue(results[0].schema == "someschema") + self.assertTrue(results[0].identifier == "sometb") + + @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") + def test_list_schemas(self, mock_cursor): + mock_cursor.return_value.get_schemas.return_value = ["schema1"], ["schema2"] + results = self.adapter.list_schemas(database="somedb", schema="someschema") + self.assertTrue(results == ["schema1", "schema2"]) + + @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter.list_schemas") + def test_check_schema_exists(self, mock_list_schemas): + mock_list_schemas.return_value = ["schema1", "schema2"] + results = self.adapter.check_schema_exists(database="somedb", schema="someschema") + self.assertTrue(results is True) + + @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") + def test_get_columns_in_relation_char(self, mock_cursor): + mock_relation = mock.MagicMock(database="somedb", schema="someschema", identifier="iden") + mock_cursor.return_value.get_columns.return_value = [ + ("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "char_column", 12, "char", 10, 10, None) + ] + result = self.adapter.get_columns_in_relation(mock_relation) + self.assertTrue(result[0].column == "char_column") + self.assertTrue(result[0].dtype == "char") + self.assertTrue(result[0].char_size == 10) + self.assertTrue(result[0].numeric_scale is None) + self.assertTrue(result[0].numeric_precision is None) + + @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") + def test_get_columns_in_relation_int(self, mock_cursor): + mock_relation = mock.MagicMock(database="somedb", schema="someschema", identifier="iden") + mock_cursor.return_value.get_columns.return_value = [ + ("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "int_column", 4, "integer", 10, 10, 20) + ] + result = self.adapter.get_columns_in_relation(mock_relation) + self.assertTrue(result[0].column == "int_column") + self.assertTrue(result[0].dtype == "integer") + self.assertTrue(result[0].char_size is None) + self.assertTrue(result[0].numeric_scale == 20) + self.assertTrue(result[0].numeric_precision == 10) + class TestRedshiftAdapterConversions(TestAdapterConversions): def test_convert_text_type(self): From d408c06a824a89ab0f8efc0c595724c27d210b53 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Fri, 28 Apr 2023 18:29:13 -0700 Subject: [PATCH 09/21] add database_metadata_current_db_only to unit tests --- tests/unit/test_redshift_adapter.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 04c19534c..d3411f4fa 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -74,6 +74,7 @@ def test_implicit_database_conn(self): db_groups=[], timeout=30, region="us-east-1", + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -92,6 +93,7 @@ def test_explicit_database_conn(self): db_groups=[], region="us-east-1", timeout=30, + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -117,6 +119,7 @@ def test_explicit_iam_conn_without_profile(self): profile=None, timeout=30, port=5439, + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -145,6 +148,7 @@ def test_explicit_iam_conn_with_profile(self): profile="test", timeout=30, port=5439, + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -171,6 +175,7 @@ def test_explicit_iam_serverless_with_profile(self): profile="test", timeout=30, port=5439, + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -199,6 +204,7 @@ def test_explicit_region(self): profile="test", timeout=30, port=5439, + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -228,6 +234,7 @@ def test_explicit_region_failure(self): profile="test", timeout=30, port=5439, + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -257,6 +264,7 @@ def test_explicit_invalid_region(self): profile="test", timeout=30, port=5439, + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -284,6 +292,7 @@ def test_serverless_iam_failure(self): profile="test", port=5439, timeout=30, + database_metadata_current_db_only=False, ) self.assertTrue("'host' must be provided" in context.exception.msg) From 98ed15170a25d8b99f4998d7a2fd6ce40584c77a Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Mon, 1 May 2023 15:09:02 -0700 Subject: [PATCH 10/21] override check_schema_exists --- dbt/adapters/redshift/impl.py | 63 ----------------------------- tests/unit/test_redshift_adapter.py | 43 -------------------- 2 files changed, 106 deletions(-) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 6fe06b5a2..0ac247686 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -8,8 +8,6 @@ from dbt.contracts.connection import AdapterResponse from dbt.contracts.graph.nodes import ConstraintType from dbt.events import AdapterLogger -from dbt.adapters.base.relation import BaseRelation - import dbt.exceptions @@ -130,67 +128,6 @@ def check_schema_exists(self, database: str, schema: str) -> bool: results = self.list_schemas(database=database, schema=schema) return len(results) > 0 - def get_columns_in_relation(self, relation): - # TODO handle cases where identifier not provided - cursor = self._get_cursor() - results = [] - if relation.identifier: - columns = cursor.get_columns( - catalog=relation.database, - schema_pattern=relation.schema, - tablename_pattern=relation.identifier, - ) - else: - columns = cursor.get_columns(catalog=relation.database, schema_pattern=relation.schema) - if columns is not None and len(columns) > 0: - for column in columns: - if column[4] == 1 or column[4] == 12: # if column type is character - results.append(RedshiftColumn(column[3], column[5], column[6], None, None)) - # elif column[4] == 5 or column[4] == 4 or column[4] == -5 or column[4] == 3 or column[4] == 7\ - # or column[4] == 8 or column[4] == 6 or column[4] == 2 or column[4] == 2003:#if column type is numeric - elif any( - column[4] == type_int for type_int in [5, 4, -5, 3, 7, 8, 6, 2, 2003] - ): # if column type is numeric - results.append( - RedshiftColumn(column[3], column[5], None, column[6], column[8]) - ) - return results - - def _get_tables(self, database: Optional[str], schema: Optional[str]): - cursor = self._get_cursor() - results = [] - for table in cursor.get_tables( - catalog=database, - schema_pattern=schema, - table_name_pattern=None, - types=["VIEW", "TABLE"], - ): - results.append([table[0], table[1], table[2], table[3]]) - # kinda feel like it might not be a good thing to do this by index - return results - - def list_relations_without_caching( # type: ignore - self, schema_relation: BaseRelation - ) -> List[RedshiftRelation]: - results = self._get_tables(schema_relation.database, schema_relation.schema) - relations = [] - quote_policy = {"database": True, "schema": True, "identifier": True} - for _database, _schema, name, _type in results: - try: - _type = self.Relation.get_relation_type(_type) - except ValueError: - _type = self.Relation.External - relations.append( - self.Relation.create( - database=_database, - schema=_schema, - identifier=name, - quote_policy=quote_policy, - type=_type, - ) - ) - return relations - def _link_cached_database_relations(self, schemas: Set[str]): """ :param schemas: The set of schemas that should have links added. diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index d3411f4fa..d1991ca80 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -471,23 +471,6 @@ def mock_cursor(self, mock_get_thread_conn): mock_handle.return_value = mock_cursor return mock_cursor - @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") - def test_get_tables(self, mock_cursor): - mock_cursor.return_value.get_tables.return_value = [ - ("apple", "banana", "cherry", "orange") - ] - results = self.adapter._get_tables(database="somedb", schema="someschema") - self.assertTrue(results == [["apple", "banana", "cherry", "orange"]]) - - @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_tables") - def test_list_relations_without_caching(self, mock_get_tables): - mock_get_tables.return_value = [["somedb", "someschema", "sometb", "VIEW"]] - mock_schema = mock.MagicMock(database="somedb", schema="someschema") - results = self.adapter.list_relations_without_caching(mock_schema) - self.assertTrue(results[0].database == "somedb") - self.assertTrue(results[0].schema == "someschema") - self.assertTrue(results[0].identifier == "sometb") - @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") def test_list_schemas(self, mock_cursor): mock_cursor.return_value.get_schemas.return_value = ["schema1"], ["schema2"] @@ -500,32 +483,6 @@ def test_check_schema_exists(self, mock_list_schemas): results = self.adapter.check_schema_exists(database="somedb", schema="someschema") self.assertTrue(results is True) - @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") - def test_get_columns_in_relation_char(self, mock_cursor): - mock_relation = mock.MagicMock(database="somedb", schema="someschema", identifier="iden") - mock_cursor.return_value.get_columns.return_value = [ - ("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "char_column", 12, "char", 10, 10, None) - ] - result = self.adapter.get_columns_in_relation(mock_relation) - self.assertTrue(result[0].column == "char_column") - self.assertTrue(result[0].dtype == "char") - self.assertTrue(result[0].char_size == 10) - self.assertTrue(result[0].numeric_scale is None) - self.assertTrue(result[0].numeric_precision is None) - - @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") - def test_get_columns_in_relation_int(self, mock_cursor): - mock_relation = mock.MagicMock(database="somedb", schema="someschema", identifier="iden") - mock_cursor.return_value.get_columns.return_value = [ - ("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "int_column", 4, "integer", 10, 10, 20) - ] - result = self.adapter.get_columns_in_relation(mock_relation) - self.assertTrue(result[0].column == "int_column") - self.assertTrue(result[0].dtype == "integer") - self.assertTrue(result[0].char_size is None) - self.assertTrue(result[0].numeric_scale == 20) - self.assertTrue(result[0].numeric_precision == 10) - class TestRedshiftAdapterConversions(TestAdapterConversions): def test_convert_text_type(self): From 2a35c33442cdd41cb9a2fe51dacb829c54f39a98 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Fri, 12 May 2023 17:42:40 -0700 Subject: [PATCH 11/21] override_get_schema_exists --- dbt/adapters/redshift/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 15e410f8c..79d33e0c1 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -75,7 +75,7 @@ class RedshiftCredentials(Credentials): autocreate: bool = False db_groups: List[str] = field(default_factory=list) ra3_node: Optional[bool] = False - connect_timeout: int = 30 + connect_timeout: Optional[int] = None role: Optional[str] = None sslmode: Optional[str] = None retries: int = 1 From c4e030b11a23639afb74fd7f3ee8ba3b34466c22 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Tue, 16 May 2023 09:10:02 -0700 Subject: [PATCH 12/21] functional test check_schema_exists --- dbt/adapters/redshift/impl.py | 1 + .../adapter/test_adapter_methods.py | 87 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 tests/functional/adapter/test_adapter_methods.py diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 0ac247686..6fe4b4726 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -116,6 +116,7 @@ def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour" def _get_cursor(self): return self.connections.get_thread_connection().handle.cursor() + @available def list_schemas(self, database: str, schema: Optional[str] = None) -> List[str]: cursor = self._get_cursor() results = [] diff --git a/tests/functional/adapter/test_adapter_methods.py b/tests/functional/adapter/test_adapter_methods.py new file mode 100644 index 000000000..d35787132 --- /dev/null +++ b/tests/functional/adapter/test_adapter_methods.py @@ -0,0 +1,87 @@ +import pytest + +from dbt.tests.util import run_dbt, check_relations_equal +from dbt.tests.fixtures.project import write_project_files + +models__upstream_sql = """ +select 1 as id + +""" + +models__expected_sql = """ + +select 2 as id + +""" + +models__invalid_schema_model = """ + +{% set upstream = ref('upstream') %} + +{% set existing = adapter.check_schema_exists(upstream.database, "doesnotexist") %} +{% if existing == False %} +select 2 as id +{% else %} +select 1 as id +{% endif %} + +""" + +models__valid_schema_model = """ + +{% set upstream = ref('upstream') %} + +{% set existing = adapter.check_schema_exists(upstream.database, upstream.schema) %} +{% if existing == True %} +select 2 as id +{% else %} +select 1 as id +{% endif %} + +""" + + +class BaseAdapterMethod: + @pytest.fixture(scope="class") + def models(self): + return { + "upstream.sql": models__upstream_sql, + "expected.sql": models__expected_sql, + "invalid_schema.sql": models__invalid_schema_model, + "valid_schema.sql": models__valid_schema_model, + } + + @pytest.fixture(scope="class") + def project_files( + self, + project_root, + tests, + models, + ): + write_project_files(project_root, "models", models) + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "name": "adapter_methods", + } + + # snowflake need all tables in CAP name + @pytest.fixture(scope="class") + def equal_tables(self): + return ["invalid_schema", "expected"] + + @pytest.fixture(scope="class") + def equal_tables2(self): + return ["valid_schema", "expected"] + + def test_adapter_methods(self, project, equal_tables): + run_dbt(["compile"]) # trigger any compile-time issues + result = run_dbt() + assert len(result) == 4 + check_relations_equal(project.adapter, equal_tables) + check_relations_equal(project.adapter, ["valid_schema", "expected"]) + + +class TestBaseCaching(BaseAdapterMethod): + pass From b43b24aded13991e1616195ed73caf19c7db9cb4 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Wed, 17 May 2023 10:33:26 -0700 Subject: [PATCH 13/21] override check_schema_Exists --- dbt/adapters/redshift/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 79d33e0c1..d777fd77e 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -75,7 +75,7 @@ class RedshiftCredentials(Credentials): autocreate: bool = False db_groups: List[str] = field(default_factory=list) ra3_node: Optional[bool] = False - connect_timeout: Optional[int] = None + connect_timeout: Optional[int] = 30 role: Optional[str] = None sslmode: Optional[str] = None retries: int = 1 From 8e1e02836efd2521164209305a52442c500821b0 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Wed, 17 May 2023 10:37:43 -0700 Subject: [PATCH 14/21] override check_schema_Exists --- dbt/adapters/redshift/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index d777fd77e..15e410f8c 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -75,7 +75,7 @@ class RedshiftCredentials(Credentials): autocreate: bool = False db_groups: List[str] = field(default_factory=list) ra3_node: Optional[bool] = False - connect_timeout: Optional[int] = 30 + connect_timeout: int = 30 role: Optional[str] = None sslmode: Optional[str] = None retries: int = 1 From 99ef2f01825ea04f3d3fae2ac697a2fd2433e88e Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Wed, 17 May 2023 11:41:13 -0700 Subject: [PATCH 15/21] revert changes in adapters.sql --- dbt/include/redshift/macros/adapters.sql | 70 +----------------------- 1 file changed, 1 insertion(+), 69 deletions(-) diff --git a/dbt/include/redshift/macros/adapters.sql b/dbt/include/redshift/macros/adapters.sql index 20d13daf0..7adf3a077 100644 --- a/dbt/include/redshift/macros/adapters.sql +++ b/dbt/include/redshift/macros/adapters.sql @@ -198,61 +198,12 @@ ), - inbound_datashare as ( - select object_name - from svv_datashare_objects - where share_type = 'INBOUND' - and object_type in ('table', 'view', 'materialized view', 'late binding view') - and object_name = '{{ relation.schema }}' || '.' || '{{ relation.identifier }}' - ), - - - - data_share_columns as ( - select - ordinal_position as columnnum, - schema_name as schemaname, - column_name as columnname, - case - when data_type ilike 'character varying%' then - 'character varying' - when data_type ilike 'numeric%' then 'numeric' - else data_type - end as col_type, - case - when data_type like 'character%' - then nullif(REGEXP_SUBSTR(character_maximum_length, '[0-9]+'), '')::int - else null - end as character_maximum_length, - case - when data_type like 'numeric%' or data_type ilike 'integer%' - then nullif( - SPLIT_PART(REGEXP_SUBSTR(numeric_precision, '[0-9,]+'), ',', 1), - '')::int - end as numeric_precision, - case - when data_type like 'numeric%' or data_type ilike 'integer%' - then nullif( - SPLIT_PART(REGEXP_SUBSTR(numeric_scale, '[0-9,]+'), ',', 2), - '')::int - else null - end as numeric_scale - from svv_all_columns - inner join inbound_datashare on - inbound_datashare.object_name = svv_all_columns.schema_name || '.' || svv_all_columns.table_name - where svv_all_columns.table_name = '{{ relation.identifier }}' - and svv_all_columns.schema_name = '{{ relation.schema }}' - and svv_all_columns.database_name = '{{ relation.database }}' - ), - unioned as ( select * from bound_views union all select * from unbound_views union all select * from external_views - union all - select * from data_share_columns ) select @@ -274,26 +225,7 @@ {% macro redshift__list_relations_without_caching(schema_relation) %} - {% call statement('list_relations_without_caching', fetch_result=True) -%} - select - database_name as database, - table_name as name, - schema_name as schema, - LOWER(table_type) as type - from SVV_REDSHIFT_TABLES - where schema_name ilike '{{ schema_relation.schema }}' - and database_name ilike '{{ schema_relation.database }}' - and table_type IN ('TABLE','VIEW') - union all - select - '{{ schema_relation.database }}' as database, - tablename as name, - schemaname as schema, - 'external_tables' as type - from SVV_EXTERNAL_TABLES - where schemaname ilike '{{ schema_relation.schema }}' - {% endcall %} - {{ return(load_result('list_relations_without_caching').table) }} + {{ return(postgres__list_relations_without_caching(schema_relation)) }} {% endmacro %} From 5c16f0c6c0949b063117ac9ba69de293c747a2b9 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Wed, 17 May 2023 11:53:27 -0700 Subject: [PATCH 16/21] remove available decorator for list_schemas --- dbt/adapters/redshift/impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 6fe4b4726..0ac247686 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -116,7 +116,6 @@ def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour" def _get_cursor(self): return self.connections.get_thread_connection().handle.cursor() - @available def list_schemas(self, database: str, schema: Optional[str] = None) -> List[str]: cursor = self._get_cursor() results = [] From 579c21a9c9dedbaada180f5dd62177c2decec8d7 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Mon, 24 Jul 2023 12:19:13 -0700 Subject: [PATCH 17/21] merge with main manually --- dbt/adapters/redshift/__init__.py | 13 +- dbt/adapters/redshift/connections.py | 148 ++++++++++++------ .../adapter/test_adapter_methods.py | 111 +++++++++---- 3 files changed, 188 insertions(+), 84 deletions(-) diff --git a/dbt/adapters/redshift/__init__.py b/dbt/adapters/redshift/__init__.py index 09f816260..0fdfac524 100644 --- a/dbt/adapters/redshift/__init__.py +++ b/dbt/adapters/redshift/__init__.py @@ -1,14 +1,15 @@ -from dbt.adapters.redshift.connections import RedshiftConnectionManager # noqa -from dbt.adapters.redshift.connections import RedshiftCredentials - from dbt.adapters.redshift.column import RedshiftColumn # noqa +from dbt.adapters.base import AdapterPlugin + +from dbt.adapters.redshift.connections import ( # noqa: F401 + RedshiftConnectionManager, + RedshiftCredentials, +) from dbt.adapters.redshift.relation import RedshiftRelation # noqa: F401 from dbt.adapters.redshift.impl import RedshiftAdapter +from dbt.include import redshift -from dbt.adapters.base import AdapterPlugin # type: ignore -from dbt.include import redshift # type: ignore - Plugin: AdapterPlugin = AdapterPlugin( adapter=RedshiftAdapter, # type: ignore credentials=RedshiftCredentials, diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 15e410f8c..7bbcffe5f 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -7,19 +7,29 @@ import agate import sqlparse import redshift_connector -import urllib.request -import json from redshift_connector.utils.oids import get_datatype_name from dbt.adapters.sql import SQLConnectionManager from dbt.contracts.connection import AdapterResponse, Connection, Credentials +from dbt.contracts.util import Replaceable +from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum, ValidationError from dbt.events import AdapterLogger -import dbt.exceptions +from dbt.exceptions import DbtRuntimeError, CompilationError import dbt.flags -from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum from dbt.helper_types import Port +class SSLConfigError(CompilationError): + def __init__(self, exc: ValidationError): + self.exc = exc + super().__init__(msg=self.get_message()) + + def get_message(self) -> str: + validator_msg = self.validator_error_message(self.exc) + msg = f"Could not parse SSL config: {validator_msg}" + return msg + + logger = AdapterLogger("Redshift") @@ -38,26 +48,69 @@ def json_schema(self): dbtClassMixin.register_field_encoders({IAMDuration: IAMDurationEncoder()}) -def _get_aws_regions(): - # Extract the prefixes from the AWS IP ranges JSON to determine the available regions - url = "https://ip-ranges.amazonaws.com/ip-ranges.json" - response = urllib.request.urlopen(url) - data = json.loads(response.read().decode()) - regions = set() +class RedshiftConnectionMethod(StrEnum): + DATABASE = "database" + IAM = "iam" - for prefix in data["prefixes"]: - if prefix["service"] == "AMAZON": - regions.add(prefix["region"]) - return regions +class UserSSLMode(StrEnum): + disable = "disable" + allow = "allow" + prefer = "prefer" + require = "require" + verify_ca = "verify-ca" + verify_full = "verify-full" + @classmethod + def default(cls) -> "UserSSLMode": + # default for `psycopg2`, which aligns with dbt-redshift 1.4 and provides backwards compatibility + return cls.prefer -_AVAILABLE_AWS_REGIONS = _get_aws_regions() +class RedshiftSSLMode(StrEnum): + verify_ca = "verify-ca" + verify_full = "verify-full" -class RedshiftConnectionMethod(StrEnum): - DATABASE = "database" - IAM = "iam" + +SSL_MODE_TRANSLATION = { + UserSSLMode.disable: None, + UserSSLMode.allow: RedshiftSSLMode.verify_ca, + UserSSLMode.prefer: RedshiftSSLMode.verify_ca, + UserSSLMode.require: RedshiftSSLMode.verify_ca, + UserSSLMode.verify_ca: RedshiftSSLMode.verify_ca, + UserSSLMode.verify_full: RedshiftSSLMode.verify_full, +} + + +@dataclass +class RedshiftSSLConfig(dbtClassMixin, Replaceable): # type: ignore + ssl: bool = True + sslmode: Optional[RedshiftSSLMode] = SSL_MODE_TRANSLATION[UserSSLMode.default()] + + @classmethod + def parse(cls, user_sslmode: UserSSLMode) -> "RedshiftSSLConfig": + try: + raw_redshift_ssl = { + "ssl": user_sslmode != UserSSLMode.disable, + "sslmode": SSL_MODE_TRANSLATION[user_sslmode], + } + cls.validate(raw_redshift_ssl) + except ValidationError as exc: + raise SSLConfigError(exc) + + redshift_ssl = cls.from_dict(raw_redshift_ssl) + + if redshift_ssl.ssl: + message = ( + f"Establishing connection using ssl with `sslmode` set to '{user_sslmode}'." + f"To connect without ssl, set `sslmode` to 'disable'." + ) + else: + message = "Establishing connection without ssl." + + logger.debug(message) + + return redshift_ssl @dataclass @@ -75,12 +128,14 @@ class RedshiftCredentials(Credentials): autocreate: bool = False db_groups: List[str] = field(default_factory=list) ra3_node: Optional[bool] = False - connect_timeout: int = 30 + connect_timeout: Optional[int] = None role: Optional[str] = None - sslmode: Optional[str] = None + sslmode: Optional[UserSSLMode] = field(default_factory=UserSSLMode.default) retries: int = 1 - region: Optional[str] = None # if not provided, will be determined from host current_db_only: Optional[bool] = False + region: Optional[str] = None + # opt-in by default per team deliberation on https://peps.python.org/pep-0249/#autocommit + autocommit: Optional[bool] = True _ALIASES = {"dbname": "database", "pass": "password"} @@ -91,15 +146,25 @@ def type(self): def _connection_keys(self): return ( "host", - "port", "user", + "port", "database", - "schema", "method", "cluster_id", "iam_profile", + "schema", + "sslmode", + "region", "sslmode", "region", + "iam_profile", + "autocreate", + "db_groups", + "ra3_node", + "connect_timeout", + "role", + "retries", + "autocommit", ) @property @@ -107,13 +172,6 @@ def unique_field(self) -> str: return self.host -def _is_valid_region(region): - if region is None or len(region) == 0: - logger.warning("Couldn't determine AWS regions. Skipping validation to avoid blocking.") - return True - return region in _AVAILABLE_AWS_REGIONS - - class RedshiftConnectMethodFactory: credentials: RedshiftCredentials @@ -125,33 +183,16 @@ def get_connect_method(self): kwargs = { "host": self.credentials.host, "database": self.credentials.database, - "port": self.credentials.port if self.credentials.port else 5439, + "port": int(self.credentials.port) if self.credentials.port else int(5439), "auto_create": self.credentials.autocreate, "db_groups": self.credentials.db_groups, "region": self.credentials.region, "timeout": self.credentials.connect_timeout, "database_metadata_current_db_only": self.credentials.current_db_only, } - if kwargs["region"] is None: - logger.debug("No region provided, attempting to determine from host.") - try: - region_value = self.credentials.host.split(".")[2] - except IndexError: - raise dbt.exceptions.FailedToConnectError( - "No region provided and unable to determine region from host: " - "{}".format(self.credentials.host) - ) - - kwargs["region"] = region_value - - # Validate the set region - if not _is_valid_region(kwargs["region"]): - raise dbt.exceptions.FailedToConnectError( - "Invalid region provided: {}".format(kwargs["region"]) - ) - if self.credentials.sslmode: - kwargs["sslmode"] = self.credentials.sslmode + redshift_ssl_config = RedshiftSSLConfig.parse(self.credentials.sslmode) + kwargs.update(redshift_ssl_config.to_dict()) # Support missing 'method' for backwards compatibility if method == RedshiftConnectionMethod.DATABASE or method is None: @@ -169,6 +210,8 @@ def connect(): password=self.credentials.password, **kwargs, ) + if self.credentials.autocommit: + c.autocommit = True if self.credentials.role: c.cursor().execute("set role {}".format(self.credentials.role)) return c @@ -191,6 +234,8 @@ def connect(): profile=self.credentials.iam_profile, **kwargs, ) + if self.credentials.autocommit: + c.autocommit = True if self.credentials.role: c.cursor().execute("set role {}".format(self.credentials.role)) return c @@ -310,6 +355,7 @@ def execute( fetch: bool = False, limit: Optional[int] = None, ) -> Tuple[AdapterResponse, agate.Table]: + sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) response = self.get_response(cursor) if fetch: @@ -342,7 +388,7 @@ def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): if cursor is None: conn = self.get_thread_connection() conn_name = conn.name if conn and conn.name else "" - raise dbt.exceptions.DbtRuntimeError(f"Tried to run invalid SQL: {sql} on {conn_name}") + raise DbtRuntimeError(f"Tried to run invalid SQL: {sql} on {conn_name}") return connection, cursor diff --git a/tests/functional/adapter/test_adapter_methods.py b/tests/functional/adapter/test_adapter_methods.py index d35787132..043d55c21 100644 --- a/tests/functional/adapter/test_adapter_methods.py +++ b/tests/functional/adapter/test_adapter_methods.py @@ -3,18 +3,20 @@ from dbt.tests.util import run_dbt, check_relations_equal from dbt.tests.fixtures.project import write_project_files -models__upstream_sql = """ -select 1 as id - -""" -models__expected_sql = """ +tests__get_relation_invalid = """ +{% set upstream = ref('upstream') %} +{% set relations = adapter.get_relation(database=upstream.database, schema=upstream.schema, identifier="doesnotexist") %} +{% set limit_query = 0 %} +{% if relations.identifier %} + {% set limit_query = 1 %} +{% endif %} -select 2 as id +select 1 as id limit {{ limit_query }} """ -models__invalid_schema_model = """ +models__get_invalid_schema = """ {% set upstream = ref('upstream') %} @@ -24,10 +26,9 @@ {% else %} select 1 as id {% endif %} - """ -models__valid_schema_model = """ +models__get_valid_schema = """ {% set upstream = ref('upstream') %} @@ -37,27 +38,88 @@ {% else %} select 1 as id {% endif %} +""" + +models__upstream_sql = """ +select 1 as id """ +models__expected_sql = """ +select 1 as valid_relation + +""" + +models__model_sql = """ + +{% set upstream = ref('upstream') %} + +select * from {{ upstream }} + +""" + +models__call_get_relation = """ + +{% set model = ref('model') %} + +{% set relation = adapter.get_relation(database=model.database, schema=model.schema, identifier=model.identifier) %} +{% if relation.identifier == model.identifier %} + +select 1 as valid_relation + +{% else %} + +select 0 as valid_relation + +{% endif %} + +""" + +models__get_relation_type = """ + +{% set base_view = ref('base_view') %} + +{% set relation = adapter.get_relation(database=base_view.database, schema=base_view.schema, identifier=base_view.identifier) %} +{% if relation.type == 'view' %} + +select 1 as valid_type + +{% else %} + +select 0 as valid_type + +{% endif %} + +""" + + +class RedshiftAdapterMethod: + @pytest.fixture(scope="class") + def tests(self): + return {"get_relation_invalid.sql": tests__get_relation_invalid} -class BaseAdapterMethod: @pytest.fixture(scope="class") def models(self): return { "upstream.sql": models__upstream_sql, "expected.sql": models__expected_sql, - "invalid_schema.sql": models__invalid_schema_model, - "valid_schema.sql": models__valid_schema_model, + "model.sql": models__model_sql, + "call_get_relation.sql": models__call_get_relation, + "base_view.sql": "{{ config(bind=True) }} select * from {{ ref('model') }}", + "get_relation_type.sql": models__get_relation_type, + "expected_type.sql": "select 1 as valid_type", + "get_invalid_schema.sql": models__get_invalid_schema, + "get_valid_schema.sql": models__get_invalid_schema, + "get_schema_expected.sql": "select 2 as id", } - @pytest.fixture(scope="class") def project_files( self, project_root, tests, models, ): + write_project_files(project_root, "tests", tests) write_project_files(project_root, "models", models) @pytest.fixture(scope="class") @@ -66,22 +128,17 @@ def project_config_update(self): "name": "adapter_methods", } - # snowflake need all tables in CAP name - @pytest.fixture(scope="class") - def equal_tables(self): - return ["invalid_schema", "expected"] - - @pytest.fixture(scope="class") - def equal_tables2(self): - return ["valid_schema", "expected"] - - def test_adapter_methods(self, project, equal_tables): + def test_adapter_methods(self, project): run_dbt(["compile"]) # trigger any compile-time issues result = run_dbt() - assert len(result) == 4 - check_relations_equal(project.adapter, equal_tables) - check_relations_equal(project.adapter, ["valid_schema", "expected"]) + assert len(result) == 10 + + run_dbt(["test"]) + check_relations_equal(project.adapter, ["call_get_relation", "expected"]) + check_relations_equal(project.adapter, ["get_relation_type", "expected_type"]) + check_relations_equal(project.adapter, ["get_invalid_schema", "get_schema_expected"]) + check_relations_equal(project.adapter, ["get_valid_schema", "get_schema_expected"]) -class TestBaseCaching(BaseAdapterMethod): +class TestRedshiftAdapterMethod(RedshiftAdapterMethod): pass From b2de60afdc4431ec736f8baff1d20abedd904d07 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Mon, 24 Jul 2023 14:52:45 -0700 Subject: [PATCH 18/21] adding functional tests for new check_schema_exists method --- dbt/adapters/redshift/__init__.py | 1 - dbt/adapters/redshift/connections.py | 4 +--- tests/unit/test_redshift_adapter.py | 6 ++++++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/dbt/adapters/redshift/__init__.py b/dbt/adapters/redshift/__init__.py index 0d1ff026d..70e77ef5e 100644 --- a/dbt/adapters/redshift/__init__.py +++ b/dbt/adapters/redshift/__init__.py @@ -1,4 +1,3 @@ -from dbt.adapters.redshift.column import RedshiftColumn # noqa from dbt.adapters.base import AdapterPlugin from dbt.adapters.redshift.connections import ( # noqa: F401 diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 7bbcffe5f..d129f40c8 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -155,9 +155,6 @@ def _connection_keys(self): "schema", "sslmode", "region", - "sslmode", - "region", - "iam_profile", "autocreate", "db_groups", "ra3_node", @@ -165,6 +162,7 @@ def _connection_keys(self): "role", "retries", "autocommit", + "current_db_only", ) @property diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index b2f4f878b..c4c8c85fd 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -143,6 +143,7 @@ def test_conn_timeout_30(self): db_groups=[], region=None, timeout=30, + database_metadata_current_db_only=False, **DEFAULT_SSL_CONFIG, ) @@ -313,6 +314,7 @@ def test_sslmode_disable(self): timeout=None, ssl=False, sslmode=None, + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -332,6 +334,7 @@ def test_sslmode_allow(self): timeout=None, ssl=True, sslmode="verify-ca", + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -351,6 +354,7 @@ def test_sslmode_verify_full(self): timeout=None, ssl=True, sslmode="verify-full", + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -370,6 +374,7 @@ def test_sslmode_verify_ca(self): timeout=None, ssl=True, sslmode="verify-ca", + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -389,6 +394,7 @@ def test_sslmode_prefer(self): timeout=None, ssl=True, sslmode="verify-ca", + database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) From fd7aaf6607e4fc38ae201e384678ca766feb3feb Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Tue, 25 Jul 2023 11:14:50 -0700 Subject: [PATCH 19/21] remove current_db_only param --- dbt/adapters/redshift/connections.py | 3 --- tests/unit/test_redshift_adapter.py | 15 --------------- 2 files changed, 18 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index d129f40c8..22b69f75a 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -132,7 +132,6 @@ class RedshiftCredentials(Credentials): role: Optional[str] = None sslmode: Optional[UserSSLMode] = field(default_factory=UserSSLMode.default) retries: int = 1 - current_db_only: Optional[bool] = False region: Optional[str] = None # opt-in by default per team deliberation on https://peps.python.org/pep-0249/#autocommit autocommit: Optional[bool] = True @@ -162,7 +161,6 @@ def _connection_keys(self): "role", "retries", "autocommit", - "current_db_only", ) @property @@ -186,7 +184,6 @@ def get_connect_method(self): "db_groups": self.credentials.db_groups, "region": self.credentials.region, "timeout": self.credentials.connect_timeout, - "database_metadata_current_db_only": self.credentials.current_db_only, } redshift_ssl_config = RedshiftSSLConfig.parse(self.credentials.sslmode) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index c4c8c85fd..f0ac40dd8 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -75,7 +75,6 @@ def test_implicit_database_conn(self): port=5439, auto_create=False, db_groups=[], - database_metadata_current_db_only=False, timeout=None, region=None, **DEFAULT_SSL_CONFIG, @@ -95,7 +94,6 @@ def test_explicit_region_with_database_conn(self): port=5439, auto_create=False, db_groups=[], - database_metadata_current_db_only=False, region=None, timeout=None, **DEFAULT_SSL_CONFIG, @@ -124,7 +122,6 @@ def test_explicit_iam_conn_without_profile(self): db_groups=[], profile=None, port=5439, - database_metadata_current_db_only=False, **DEFAULT_SSL_CONFIG, ) @@ -143,7 +140,6 @@ def test_conn_timeout_30(self): db_groups=[], region=None, timeout=30, - database_metadata_current_db_only=False, **DEFAULT_SSL_CONFIG, ) @@ -173,7 +169,6 @@ def test_explicit_iam_conn_with_profile(self): profile="test", timeout=None, port=5439, - database_metadata_current_db_only=False, **DEFAULT_SSL_CONFIG, ) @@ -201,7 +196,6 @@ def test_explicit_iam_serverless_with_profile(self): profile="test", timeout=None, port=5439, - database_metadata_current_db_only=False, **DEFAULT_SSL_CONFIG, ) @@ -231,7 +225,6 @@ def test_explicit_region(self): profile="test", timeout=None, port=5439, - database_metadata_current_db_only=False, **DEFAULT_SSL_CONFIG, ) @@ -262,7 +255,6 @@ def test_explicit_region_failure(self): profile="test", timeout=None, port=5439, - database_metadata_current_db_only=False, **DEFAULT_SSL_CONFIG, ) @@ -293,7 +285,6 @@ def test_explicit_invalid_region(self): profile="test", timeout=None, port=5439, - database_metadata_current_db_only=False, **DEFAULT_SSL_CONFIG, ) @@ -314,7 +305,6 @@ def test_sslmode_disable(self): timeout=None, ssl=False, sslmode=None, - database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -334,7 +324,6 @@ def test_sslmode_allow(self): timeout=None, ssl=True, sslmode="verify-ca", - database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -354,7 +343,6 @@ def test_sslmode_verify_full(self): timeout=None, ssl=True, sslmode="verify-full", - database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -374,7 +362,6 @@ def test_sslmode_verify_ca(self): timeout=None, ssl=True, sslmode="verify-ca", - database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -394,7 +381,6 @@ def test_sslmode_prefer(self): timeout=None, ssl=True, sslmode="verify-ca", - database_metadata_current_db_only=False, ) @mock.patch("redshift_connector.connect", Mock()) @@ -421,7 +407,6 @@ def test_serverless_iam_failure(self): user="", profile="test", port=5439, - database_metadata_current_db_only=False, timeout=None, **DEFAULT_SSL_CONFIG, ) From a9885551d321198538f6c93ecb6311daab78bf36 Mon Sep 17 00:00:00 2001 From: Jessie Chen <121250701+jiezhen-chen@users.noreply.github.com> Date: Tue, 25 Jul 2023 12:57:08 -0700 Subject: [PATCH 20/21] Update Features-20230407-104723.yaml --- .changes/unreleased/Features-20230407-104723.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.changes/unreleased/Features-20230407-104723.yaml b/.changes/unreleased/Features-20230407-104723.yaml index 6f57a9f18..41fa799da 100644 --- a/.changes/unreleased/Features-20230407-104723.yaml +++ b/.changes/unreleased/Features-20230407-104723.yaml @@ -2,4 +2,4 @@ kind: Features time: 2023-04-07T10:47:23.105369-07:00 custom: Author: jiezhen-chen - Issue: 17 179 217 + Issue: From 68f2c730c59245a1ea8b5a80bca83557f2b29144 Mon Sep 17 00:00:00 2001 From: Jessie Chen <121250701+jiezhen-chen@users.noreply.github.com> Date: Tue, 25 Jul 2023 14:17:24 -0700 Subject: [PATCH 21/21] Update Features-20230407-104723.yaml --- .changes/unreleased/Features-20230407-104723.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.changes/unreleased/Features-20230407-104723.yaml b/.changes/unreleased/Features-20230407-104723.yaml index 41fa799da..edbb2369a 100644 --- a/.changes/unreleased/Features-20230407-104723.yaml +++ b/.changes/unreleased/Features-20230407-104723.yaml @@ -2,4 +2,4 @@ kind: Features time: 2023-04-07T10:47:23.105369-07:00 custom: Author: jiezhen-chen - Issue: + Issue: 555