From 7f12fa80bac21f349247ac219259badf0417e451 Mon Sep 17 00:00:00 2001 From: nicor88 <6278547+nicor88@users.noreply.github.com> Date: Fri, 19 May 2023 19:57:17 +0200 Subject: [PATCH 01/10] chore: add svdimchenko as code owner (#318) --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 8d8f2945..c6354a37 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @thenaturalist +* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko @thenaturalist From efe16199af45b8ae1e8372dd0a476b1de39ada1f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 May 2023 07:01:45 +0200 Subject: [PATCH 02/10] chore: Update moto requirement from ~=4.1.9 to ~=4.1.10 (#319) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 39bc8838..07b55264 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -4,7 +4,7 @@ dbt-tests-adapter~=1.5.0 flake8~=5.0 Flake8-pyproject~=1.2 isort~=5.11 -moto~=4.1.9 +moto~=4.1.10 pre-commit~=2.21 pyparsing~=3.0.9 pytest~=7.3 From b085ad1d44d9cc248c16b5abc4ec01f36d7de1c6 Mon Sep 17 00:00:00 2001 From: nicor88 <6278547+nicor88@users.noreply.github.com> Date: Mon, 22 May 2023 09:38:29 +0200 Subject: [PATCH 03/10] chore: add reusable workflow for functional tests (#317) --- .../workflows/functional-tests-workflow.yml | 51 ++++++++++++ .github/workflows/functional-tests.yml | 80 +++++-------------- 2 files changed, 69 insertions(+), 62 deletions(-) create mode 100644 .github/workflows/functional-tests-workflow.yml diff --git a/.github/workflows/functional-tests-workflow.yml b/.github/workflows/functional-tests-workflow.yml new file mode 100644 index 00000000..d24e0535 --- /dev/null +++ b/.github/workflows/functional-tests-workflow.yml @@ -0,0 +1,51 @@ +# reusable workflow to be called from the main workflow +name: functional-tests-workflow + +on: + workflow_call: + inputs: + checkout-ref: + required: true + type: string + checkout-repository: + required: true + type: string + aws-region: + required: true + type: string + +env: + DBT_TEST_ATHENA_DATABASE: awsdatacatalog + DBT_TEST_ATHENA_SCHEMA: dbt-tests + DBT_TEST_ATHENA_WORK_GROUP: athena-dbt-tests + DBT_TEST_ATHENA_THREADS: 16 + DBT_TEST_ATHENA_POLL_INTERVAL: 0.5 + +jobs: + functional-tests: + name: Functional Tests + runs-on: ubuntu-latest + env: + DBT_TEST_ATHENA_S3_STAGING_DIR: s3://dbt-athena-query-results-${{ inputs.aws-region }} + DBT_TEST_ATHENA_REGION_NAME: ${{ inputs.aws-region }} + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + ref: ${{ inputs.checkout-ref }} + repository: ${{ inputs.checkout-repository }} + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install dependencies + run: | + make install_deps + - name: Configure AWS credentials from Test account + uses: aws-actions/configure-aws-credentials@v2 + with: + role-to-assume: arn:aws:iam::${{ secrets.AWS_ACCOUNT_ID }}:role/${{ secrets.ASSUMABLE_ROLE_NAME }} + aws-region: ${{ inputs.aws-region }} + - name: Functional Test + run: | + pytest -n 8 tests/functional diff --git a/.github/workflows/functional-tests.yml b/.github/workflows/functional-tests.yml index 732e29ef..6ce86ba9 100644 --- 
a/.github/workflows/functional-tests.yml +++ b/.github/workflows/functional-tests.yml @@ -2,86 +2,42 @@ name: functional-tests on: # we use pull_request_target to run the CI for forks as well - pull_request_target: # Please read https://securitylab.github.com/research/github-actions-preventing-pwn-requests/ before using + pull_request_target: types: [opened, reopened, synchronize, labeled] push: branches: [main] -env: - DBT_TEST_ATHENA_DATABASE: awsdatacatalog - DBT_TEST_ATHENA_SCHEMA: dbt-tests - DBT_TEST_ATHENA_WORK_GROUP: athena-dbt-tests - DBT_TEST_ATHENA_THREADS: 16 - DBT_TEST_ATHENA_POLL_INTERVAL: 0.5 - jobs: + # workflow that is invoked for PRs with the label 'enable-functional-tests' functional-tests-pr: - name: Functional Test - PR - # trigger on PRs with label 'enable-ci' + name: Functional Tests - PR if: contains(github.event.pull_request.labels.*.name, 'enable-functional-tests') - runs-on: ubuntu-latest + uses: ./.github/workflows/functional-tests-workflow.yml strategy: matrix: aws-region: ['us-east-1', 'eu-west-1', 'eu-west-2', 'eu-central-1'] permissions: id-token: write contents: read - env: - DBT_TEST_ATHENA_S3_STAGING_DIR: s3://dbt-athena-query-results-${{ matrix.aws-region }} - DBT_TEST_ATHENA_REGION_NAME: ${{ matrix.aws-region }} - steps: - - name: Checkout - uses: actions/checkout@v3 - with: - # this is needed to checkout the PR branch - ref: ${{ github.event.pull_request.head.ref }} - repository: ${{ github.event.pull_request.head.repo.full_name }} - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Install dependencies - run: | - make install_deps - - name: Configure AWS credentials from Test account - uses: aws-actions/configure-aws-credentials@v2 - with: - role-to-assume: arn:aws:iam::${{ secrets.AWS_ACCOUNT_ID }}:role/${{ secrets.ASSUMABLE_ROLE_NAME }} - aws-region: ${{ matrix.aws-region }} - - name: Functional Test - run: | - pytest -n 8 tests/functional + with: + # this allows picking the branch from the PR + checkout-ref: ${{ github.event.pull_request.head.ref }} + # this allows working on forks + checkout-repository: ${{ github.event.pull_request.head.repo.full_name }} + aws-region: ${{ matrix.aws-region }} - # TODO: this is a workaround for now, we should use the same job for PR and main branch + # workflow that is invoked when a push to main happens functional-tests-main: - name: Functional Test - main - # trigger push to main branch + name: Functional Tests - main if: github.event_name == 'push' && github.ref == 'refs/heads/main' - runs-on: ubuntu-latest + uses: ./.github/workflows/functional-tests-workflow.yml strategy: matrix: - aws-region: [ 'us-east-1', 'eu-west-1', 'eu-west-2', 'eu-central-1' ] + aws-region: ['us-east-1', 'eu-west-1', 'eu-west-2', 'eu-central-1'] permissions: id-token: write contents: read - env: - DBT_TEST_ATHENA_S3_STAGING_DIR: s3://dbt-athena-query-results-${{ matrix.aws-region }} - DBT_TEST_ATHENA_REGION_NAME: ${{ matrix.aws-region }} - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Install dependencies - run: | - make install_deps - - name: Configure AWS credentials from Test account - uses: aws-actions/configure-aws-credentials@v2 - with: - role-to-assume: arn:aws:iam::${{ secrets.AWS_ACCOUNT_ID }}:role/${{ secrets.ASSUMABLE_ROLE_NAME }} - aws-region: ${{ matrix.aws-region }} - - name: Functional Test - run: | - pytest -n 8 tests/functional + with: + checkout-ref: ${{ 
github.ref }} + checkout-repository: ${{ github.repository }} + aws-region: ${{ matrix.aws-region }} From 28ff0ad9cb7127e96884602cd7fef1e6d6b6c663 Mon Sep 17 00:00:00 2001 From: nicor88 <6278547+nicor88@users.noreply.github.com> Date: Mon, 22 May 2023 09:46:40 +0200 Subject: [PATCH 04/10] fix: functional tests add fetch secrets (#320) --- .github/workflows/functional-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/functional-tests.yml b/.github/workflows/functional-tests.yml index 6ce86ba9..1f2fb8d2 100644 --- a/.github/workflows/functional-tests.yml +++ b/.github/workflows/functional-tests.yml @@ -13,6 +13,7 @@ jobs: name: Functional Tests - PR if: contains(github.event.pull_request.labels.*.name, 'enable-functional-tests') uses: ./.github/workflows/functional-tests-workflow.yml + secrets: inherit strategy: matrix: aws-region: ['us-east-1', 'eu-west-1', 'eu-west-2', 'eu-central-1'] @@ -31,6 +32,7 @@ jobs: name: Functional Tests - main if: github.event_name == 'push' && github.ref == 'refs/heads/main' uses: ./.github/workflows/functional-tests-workflow.yml + secrets: inherit strategy: matrix: aws-region: ['us-east-1', 'eu-west-1', 'eu-west-2', 'eu-central-1'] From ba1cffe44be2ff040660fee0eb1745a47b8d83f5 Mon Sep 17 00:00:00 2001 From: nicor88 <6278547+nicor88@users.noreply.github.com> Date: Mon, 22 May 2023 13:35:43 +0200 Subject: [PATCH 05/10] chore: trim regions for functional testing (#321) --- .github/workflows/functional-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/functional-tests.yml b/.github/workflows/functional-tests.yml index 1f2fb8d2..fe0b1194 100644 --- a/.github/workflows/functional-tests.yml +++ b/.github/workflows/functional-tests.yml @@ -16,7 +16,7 @@ jobs: secrets: inherit strategy: matrix: - aws-region: ['us-east-1', 'eu-west-1', 'eu-west-2', 'eu-central-1'] + aws-region: ['us-east-1', 'eu-central-1'] permissions: id-token: write contents: read @@ -35,7 +35,7 @@ jobs: secrets: inherit strategy: matrix: - aws-region: ['us-east-1', 'eu-west-1', 'eu-west-2', 'eu-central-1'] + aws-region: ['us-east-1', 'eu-central-1'] permissions: id-token: write contents: read From 9c72c3b0f733c8d8aecd09f5f083c3ffcae9f25a Mon Sep 17 00:00:00 2001 From: Kristina Dubina <65438582+krisstinkou@users.noreply.github.com> Date: Tue, 23 May 2023 18:30:31 +0300 Subject: [PATCH 06/10] fix: athena list schemas argument (#322) Co-authored-by: Krystsina Dubina --- dbt/include/athena/macros/adapters/metadata.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/include/athena/macros/adapters/metadata.sql b/dbt/include/athena/macros/adapters/metadata.sql index 7b118048..0ca7cdb8 100644 --- a/dbt/include/athena/macros/adapters/metadata.sql +++ b/dbt/include/athena/macros/adapters/metadata.sql @@ -4,7 +4,7 @@ {% macro athena__list_schemas(database) -%} - {{ return(adapter.list_schemas()) }} + {{ return(adapter.list_schemas(database)) }} {% endmacro %} From 163b3d5a71ec131d9e0a7c3888b6d23d05f51b02 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 May 2023 09:41:11 +0200 Subject: [PATCH 07/10] chore: Update pytest-cov requirement from ~=4.0 to ~=4.1 (#323) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 07b55264..e1c9f308 100644 --- 
a/dev-requirements.txt +++ b/dev-requirements.txt @@ -8,7 +8,7 @@ moto~=4.1.10 pre-commit~=2.21 pyparsing~=3.0.9 pytest~=7.3 -pytest-cov~=4.0 +pytest-cov~=4.1 pytest-dotenv~=0.5 pytest-xdist~=3.3 pyupgrade~=3.3 From 90eecf1fcde2994caea9f6ead316bf54ad6af137 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Jun 2023 08:54:23 +0200 Subject: [PATCH 08/10] chore: Update moto requirement from ~=4.1.10 to ~=4.1.11 (#326) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 07b55264..c5074ee1 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -4,7 +4,7 @@ dbt-tests-adapter~=1.5.0 flake8~=5.0 Flake8-pyproject~=1.2 isort~=5.11 -moto~=4.1.10 +moto~=4.1.11 pre-commit~=2.21 pyparsing~=3.0.9 pytest~=7.3 From a28d36545fc605e15495471ac99408f219954061 Mon Sep 17 00:00:00 2001 From: Julian Steger <108534789+juliansteger-sc@users.noreply.github.com> Date: Fri, 9 Jun 2023 13:40:28 +0200 Subject: [PATCH 09/10] fix: BatchDeletePartitions only accepts up to 25 partitions (#328) --- dbt/adapters/athena/impl.py | 37 ++++++++++++++++++++++-------------- dbt/adapters/athena/utils.py | 11 ++++++++++- tests/unit/test_adapter.py | 4 ++-- tests/unit/test_utils.py | 19 +++++++++++++++++- tests/unit/utils.py | 2 +- 5 files changed, 54 insertions(+), 19 deletions(-) diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index 88d01d50..af0f2cd3 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -34,7 +34,7 @@ get_table_type, ) from dbt.adapters.athena.s3 import S3DataNaming -from dbt.adapters.athena.utils import clean_sql_comment, get_catalog_id +from dbt.adapters.athena.utils import clean_sql_comment, get_catalog_id, get_chunks from dbt.adapters.base import ConstraintSupport, available from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.sql import SQLAdapter @@ -46,6 +46,9 @@ class AthenaAdapter(SQLAdapter): + BATCH_CREATE_PARTITION_API_LIMIT = 100 + BATCH_DELETE_PARTITION_API_LIMIT = 25 + ConnectionManager = AthenaConnectionManager Relation = AthenaRelation @@ -522,21 +525,27 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati # if the source table has partitions we need to delete and add partitions # if the source table has no partitions we need to delete the target table partitions if target_table_partitions: - glue_client.batch_delete_partition( - DatabaseName=target_relation.schema, - TableName=target_relation.identifier, - PartitionsToDelete=[{"Values": i["Values"]} for i in target_table_partitions], - ) + for partition_batch in get_chunks(target_table_partitions, AthenaAdapter.BATCH_DELETE_PARTITION_API_LIMIT): + glue_client.batch_delete_partition( + DatabaseName=target_relation.schema, + TableName=target_relation.identifier, + PartitionsToDelete=[{"Values": partition["Values"]} for partition in partition_batch], + ) if src_table_partitions: - glue_client.batch_create_partition( - DatabaseName=target_relation.schema, - TableName=target_relation.identifier, - PartitionInputList=[ - {"Values": p["Values"], "StorageDescriptor": p["StorageDescriptor"], "Parameters": p["Parameters"]} - for p in src_table_partitions - ], - ) + for partition_batch in get_chunks(src_table_partitions, AthenaAdapter.BATCH_CREATE_PARTITION_API_LIMIT): + 
glue_client.batch_create_partition( + DatabaseName=target_relation.schema, + TableName=target_relation.identifier, + PartitionInputList=[ + { + "Values": partition["Values"], + "StorageDescriptor": partition["StorageDescriptor"], + "Parameters": partition["Parameters"], + } + for partition in partition_batch + ], + ) def _get_glue_table_versions_to_expire(self, relation: AthenaRelation, to_keep: int): """ diff --git a/dbt/adapters/athena/utils.py b/dbt/adapters/athena/utils.py index b0e9bbdf..ee3b20bc 100644 --- a/dbt/adapters/athena/utils.py +++ b/dbt/adapters/athena/utils.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Collection, List, Optional, TypeVar from mypy_boto3_athena.type_defs import DataCatalogTypeDef @@ -11,3 +11,12 @@ def clean_sql_comment(comment: str) -> str: def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]: if catalog: return catalog["Parameters"]["catalog-id"] + + +T = TypeVar("T") + + +def get_chunks(lst: Collection[T], n: int) -> Collection[List[T]]: + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 55397f10..ad66d536 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -728,7 +728,7 @@ def test_swap_table_with_partitions(self, mock_aws_service): mock_aws_service.create_table(source_table) mock_aws_service.add_partitions_to_table(DATABASE_NAME, source_table) mock_aws_service.create_table(target_table) - mock_aws_service.add_partitions_to_table(DATABASE_NAME, source_table) + mock_aws_service.add_partitions_to_table(DATABASE_NAME, target_table) source_relation = self.adapter.Relation.create( database=DATA_CATALOG_NAME, schema=DATABASE_NAME, @@ -836,7 +836,7 @@ def test_swap_table_with_no_partitions_to_one_with(self, mock_aws_service): ).get("Partitions") assert self.adapter.get_glue_table_location(target_relation) == f"s3://{BUCKET}/tables/{source_table}" - assert len(target_table_partitions_after) == 3 + assert len(target_table_partitions_after) == 26 @mock_athena @mock_glue diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index b72d9311..23851360 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,4 @@ -from dbt.adapters.athena.utils import clean_sql_comment +from dbt.adapters.athena.utils import clean_sql_comment, get_chunks def test_clean_comment(): @@ -12,3 +12,20 @@ def test_clean_comment(): ) == "my long comment on several lines with weird spaces and indents." 
) + + +def test_get_chunks_empty(): + assert len(list(get_chunks([], 5))) == 0 + + +def test_get_chunks_uneven(): + chunks = list(get_chunks([1, 2, 3], 2)) + assert chunks[0] == [1, 2] + assert chunks[1] == [3] + assert len(chunks) == 2 + + +def test_get_chunks_more_elements_than_chunk(): + chunks = list(get_chunks([1, 2, 3], 4)) + assert chunks[0] == [1, 2, 3] + assert len(chunks) == 1 diff --git a/tests/unit/utils.py b/tests/unit/utils.py index eeacd230..60ec02ed 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -450,7 +450,7 @@ def add_partitions_to_table(self, database, table_name): }, "Parameters": {"compressionType": "snappy", "classification": "parquet"}, } - for dt in ["2022-01-01", "2022-01-02", "2022-01-03"] + for dt in [f"2022-01-{day:02d}" for day in range(1, 27)] ] glue = boto3.client("glue", region_name=AWS_REGION) glue.batch_create_partition( From 12b4a0cd675cd3facbe3d4482e3a782e9a07f041 Mon Sep 17 00:00:00 2001 From: Serhii Dimchenko <39801237+svdimchenko@users.noreply.github.com> Date: Fri, 9 Jun 2023 14:00:25 +0200 Subject: [PATCH 10/10] feat: enable mypy pre-commit check (#329) --- .pre-commit-config.yaml | 13 ++++++ README.md | 1 + dbt/adapters/athena/column.py | 10 ++--- dbt/adapters/athena/connections.py | 28 ++++++------ dbt/adapters/athena/impl.py | 65 ++++++++++++++-------------- dbt/adapters/athena/lakeformation.py | 10 +++-- dbt/adapters/athena/relation.py | 15 +++---- dbt/adapters/athena/utils.py | 7 ++- setup.py | 5 ++- 9 files changed, 85 insertions(+), 69 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0381617f..9e88365e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -59,3 +59,16 @@ repos: - 'Flake8-pyproject~=1.1' args: - '.' + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.3.0 + hooks: + - id: mypy + args: + - --strict + - --ignore-missing-imports + - --install-types + - --allow-subclassing-any + - --allow-untyped-decorators + additional_dependencies: + - types-setuptools==67.8.0.0 + exclude: ^tests/ diff --git a/README.md b/README.md index a22b128e..c7aa35f9 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ +

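Note: the mypy hook added above runs with --strict, and the diffs that follow are largely the annotations and guards needed to satisfy it. As a minimal, hypothetical sketch of the pattern behind the column.py change below (standalone illustration, not part of the patch):

from typing import Optional

def string_size(char_size: Optional[int]) -> int:
    # Returning `char_size` directly fails type checking: Optional[int] is not int.
    # Coalescing with 0 narrows the type and covers columns with no declared size.
    return char_size or 0

assert string_size(None) == 0
assert string_size(38) == 38
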
diff --git a/dbt/adapters/athena/column.py b/dbt/adapters/athena/column.py index e5dfdadd..75fc1bdc 100644 --- a/dbt/adapters/athena/column.py +++ b/dbt/adapters/athena/column.py @@ -37,19 +37,17 @@ def timestamp_type(self) -> str: def string_size(self) -> int: if not self.is_string(): raise DbtRuntimeError("Called string_size() on non-string field!") - if not self.char_size: - # Handle error: '>' not supported between instances of 'NoneType' and 'NoneType' for union relations macro - return 0 - return self.char_size + # Handle error: '>' not supported between instances of 'NoneType' and 'NoneType' for union relations macro + return self.char_size or 0 @property def data_type(self) -> str: if self.is_string(): return self.string_type(self.string_size()) elif self.is_numeric(): - return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) + return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) # type: ignore elif self.is_binary(): return self.binary_type() elif self.is_timestamp(): return self.timestamp_type() - return self.dtype + return self.dtype # type: ignore diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py index ba622eb1..4bca6184 100644 --- a/dbt/adapters/athena/connections.py +++ b/dbt/adapters/athena/connections.py @@ -4,7 +4,7 @@ from copy import deepcopy from dataclasses import dataclass from decimal import Decimal -from typing import Any, ContextManager, Dict, List, Optional, Tuple, Union +from typing import Any, ContextManager, Dict, List, Optional, Tuple import tenacity from pyathena.connection import Connection as AthenaConnection @@ -56,7 +56,7 @@ def type(self) -> str: return "athena" @property - def unique_field(self): + def unique_field(self) -> str: return f"athena-{hashlib.md5(self.s3_staging_dir.encode()).hexdigest()}" def _connection_keys(self) -> Tuple[str, ...]: @@ -78,7 +78,7 @@ def _connection_keys(self) -> Tuple[str, ...]: class AthenaCursor(Cursor): - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: # type: ignore super().__init__(**kwargs) self._executor = ThreadPoolExecutor() @@ -92,7 +92,7 @@ def _collect_result_set(self, query_id: str) -> AthenaResultSet: retry_config=self._retry_config, ) - def execute( + def execute( # type: ignore self, operation: str, parameters: Optional[Dict[str, Any]] = None, @@ -103,7 +103,7 @@ def execute( cache_expiration_time: int = 0, **kwargs, ): - def inner(): + def inner() -> AthenaCursor: query_id = self._execute( operation, parameters=parameters, @@ -143,7 +143,7 @@ class AthenaConnectionManager(SQLConnectionManager): TYPE = "athena" @classmethod - def data_type_code_to_name(cls, type_code: Union[int, str]) -> str: + def data_type_code_to_name(cls, type_code: str) -> str: """ Get the string representation of the data type from the Athena metadata. Dbt performs a query to retrieve the types of the columns in the SQL query. 
Then these types are compared @@ -152,8 +152,8 @@ def data_type_code_to_name(cls, type_code: Union[int, str]) -> str: """ return type_code.split("(")[0].split("<")[0].upper() - @contextmanager - def exception_handler(self, sql: str) -> ContextManager: + @contextmanager # type: ignore + def exception_handler(self, sql: str) -> ContextManager: # type: ignore try: yield except Exception as e: @@ -201,23 +201,23 @@ def open(cls, connection: Connection) -> Connection: return connection @classmethod - def get_response(cls, cursor) -> AdapterResponse: + def get_response(cls, cursor: AthenaCursor) -> AdapterResponse: code = "OK" if cursor.state == AthenaQueryExecution.STATE_SUCCEEDED else "ERROR" return AdapterResponse(_message=f"{code} {cursor.rowcount}", rows_affected=cursor.rowcount, code=code) - def cancel(self, connection: Connection): + def cancel(self, connection: Connection) -> None: connection.handle.cancel() - def add_begin_query(self): + def add_begin_query(self) -> None: pass - def add_commit_query(self): + def add_commit_query(self) -> None: pass - def begin(self): + def begin(self) -> None: pass - def commit(self): + def commit(self) -> None: pass diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index af0f2cd3..f231498e 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -12,6 +12,7 @@ import agate from botocore.exceptions import ClientError from mypy_boto3_athena.type_defs import DataCatalogTypeDef +from mypy_boto3_glue.type_defs import ColumnTypeDef, TableTypeDef, TableVersionTypeDef from dbt.adapters.athena import AthenaConnectionManager from dbt.adapters.athena.column import AthenaColumn @@ -101,7 +102,7 @@ def apply_lf_grants(self, relation: AthenaRelation, lf_grants_config: Dict[str, lf = client.session.client("lakeformation", region_name=client.region_name, config=get_boto3_config()) catalog = self._get_data_catalog(relation.database) catalog_id = get_catalog_id(catalog) - lf_permissions = LfPermissions(catalog_id, relation, lf) + lf_permissions = LfPermissions(catalog_id, relation, lf) # type: ignore lf_permissions.process_filters(lf_config) lf_permissions.process_permissions(lf_config) @@ -146,7 +147,7 @@ def _s3_table_prefix(self, s3_data_dir: Optional[str]) -> str: return path.join(creds.s3_staging_dir, "tables") - def _s3_data_naming(self, s3_data_naming: Optional[str]) -> Optional[S3DataNaming]: + def _s3_data_naming(self, s3_data_naming: Optional[str]) -> S3DataNaming: """ Returns the s3 data naming strategy if provided, otherwise the value from the connection. """ @@ -175,7 +176,6 @@ def generate_s3_location( s3_path_table_part = relation.s3_path_table_part or relation.identifier schema_name = relation.schema - s3_data_naming = self._s3_data_naming(s3_data_naming) table_prefix = self._s3_table_prefix(s3_data_dir) mapping = { @@ -186,7 +186,7 @@ def generate_s3_location( S3DataNaming.SCHEMA_TABLE_UNIQUE: path.join(table_prefix, schema_name, s3_path_table_part, str(uuid4())), } - return mapping[s3_data_naming] + return mapping[self._s3_data_naming(s3_data_naming)] @available def get_glue_table_location(self, relation: AthenaRelation) -> Optional[str]: @@ -216,12 +216,11 @@ def get_glue_table_location(self, relation: AthenaRelation) -> Optional[str]: f"but no location returned by Glue." 
) LOGGER.debug(f"{relation.render()} is stored in {table_location}") - return table_location - + return str(table_location) return None @available - def clean_up_partitions(self, relation: AthenaRelation, where_condition: str): + def clean_up_partitions(self, relation: AthenaRelation, where_condition: str) -> None: conn = self.connections.get_thread_connection() client = conn.handle @@ -240,17 +239,17 @@ def clean_up_partitions(self, relation: AthenaRelation, where_condition: str): self.delete_from_s3(partition["StorageDescriptor"]["Location"]) @available - def clean_up_table(self, relation: AthenaRelation): + def clean_up_table(self, relation: AthenaRelation) -> None: table_location = self.get_glue_table_location(relation) - # this check avoid issues for when the table location is an empty string - # or when the table do not exist and table location is None + # this check avoids issues for when the table location is an empty string + # or when the table does not exist and table location is None if table_location: self.delete_from_s3(table_location) @available def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str: - return super().quote_seed_column(column, False) + return str(super().quote_seed_column(column, False)) @available def upload_seed_to_s3( @@ -281,10 +280,10 @@ def upload_seed_to_s3( s3_client.upload_file(tmpfile, bucket, object_name) os.remove(tmpfile) - return s3_location + return str(s3_location) @available - def delete_from_s3(self, s3_path: str): + def delete_from_s3(self, s3_path: str) -> None: """ Deletes files from s3 given a s3 path in the format: s3://my_bucket/prefix Additionally, parses the response from the s3 delete request and raises @@ -359,7 +358,7 @@ def _join_catalog_table_owners(self, table: agate.Table, manifest: Manifest) -> right_key=join_keys, ) - def _get_one_table_for_catalog(self, table: dict, database: str) -> list: + def _get_one_table_for_catalog(self, table: TableTypeDef, database: str) -> List[Dict[str, Any]]: table_catalog = { "table_database": database, "table_schema": table["DatabaseName"], @@ -406,7 +405,7 @@ def _get_one_catalog( for page in paginator.paginate(**kwargs): for table in page["TableList"]: - if table["Name"] in relations: + if relations and table["Name"] in relations: catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database)) table = agate.Table.from_object(catalog) @@ -433,17 +432,17 @@ def _get_data_catalog(self, database: str) -> Optional[DataCatalogTypeDef]: sts = client.session.client("sts", region_name=client.region_name, config=get_boto3_config()) catalog_id = sts.get_caller_identity()["Account"] return {"Name": database, "Type": "GLUE", "Parameters": {"catalog-id": catalog_id}} - else: - with boto3_client_lock: - athena = client.session.client("athena", region_name=client.region_name, config=get_boto3_config()) - return athena.get_data_catalog(Name=database)["DataCatalog"] + with boto3_client_lock: + athena = client.session.client("athena", region_name=client.region_name, config=get_boto3_config()) + return athena.get_data_catalog(Name=database)["DataCatalog"] + return None def list_relations_without_caching(self, schema_relation: AthenaRelation) -> List[BaseRelation]: data_catalog = self._get_data_catalog(schema_relation.database) catalog_id = get_catalog_id(data_catalog) if data_catalog and data_catalog["Type"] != "GLUE": # For non-Glue Data Catalogs, use the original Athena query against INFORMATION_SCHEMA approach - return 
super().list_relations_without_caching(schema_relation) + return super().list_relations_without_caching(schema_relation) # type: ignore conn = self.connections.get_thread_connection() client = conn.handle @@ -492,7 +491,7 @@ def list_relations_without_caching(self, schema_relation: AthenaRelation) -> Lis return relations @available - def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelation): + def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelation) -> None: conn = self.connections.get_thread_connection() client = conn.handle @@ -547,7 +546,7 @@ def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelati ], ) - def _get_glue_table_versions_to_expire(self, relation: AthenaRelation, to_keep: int): + def _get_glue_table_versions_to_expire(self, relation: AthenaRelation, to_keep: int) -> List[TableVersionTypeDef]: """ Given a table and the amount of its version to keep, it returns the versions to delete """ @@ -570,7 +569,9 @@ def _get_glue_table_versions_to_expire(self, relation: AthenaRelation, to_keep: return table_versions_ordered[int(to_keep) :] @available - def expire_glue_table_versions(self, relation: AthenaRelation, to_keep: int, delete_s3: bool): + def expire_glue_table_versions( + self, relation: AthenaRelation, to_keep: int, delete_s3: bool + ) -> List[TableVersionTypeDef]: conn = self.connections.get_thread_connection() client = conn.handle @@ -606,7 +607,7 @@ def persist_docs_to_glue( model: Dict[str, Any], persist_relation_docs: bool = False, persist_column_docs: bool = False, - ): + ) -> None: conn = self.connections.get_thread_connection() client = conn.handle @@ -651,9 +652,9 @@ def list_schemas(self, database: str) -> List[str]: return result @staticmethod - def _is_current_column(col: dict) -> bool: + def _is_current_column(col: ColumnTypeDef) -> bool: """ - Check if a column is explicit set as not current. If not, it is considered as current. + Check if a column is explicitly set as not current. If not, it is considered as current. """ if col.get("Parameters", {}).get("iceberg.field.current") == "false": return False @@ -689,7 +690,7 @@ def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn ] @available - def delete_from_glue_catalog(self, relation: AthenaRelation): + def delete_from_glue_catalog(self, relation: AthenaRelation) -> None: schema_name = relation.schema table_name = relation.identifier @@ -721,16 +722,16 @@ def valid_snapshot_target(self, relation: BaseRelation) -> None: if "dbt_unique_key" in names: sql = self._generate_snapshot_migration_sql(relation=relation, table_columns=table_columns) msg = ( - f"{'!'*90}\n" + f"{'!' * 90}\n" "The snapshot logic of dbt-athena has changed in an incompatible way to be more consistent " "with the dbt-core implementation.\nYou will need to migrate your existing snapshot tables to be " "able to keep using them with the latest dbt-athena version.\nYou can find more information " "in the release notes:\nhttps://github.com/dbt-athena/dbt-athena/releases\n" - f"{'!'*90}\n\n" + f"{'!' 
* 90}\n\n" "You can use the example query below as a baseline to perform the migration:\n\n" - f"{'-'*90}\n" + f"{'-' * 90}\n" f"{sql}\n" - f"{'-'*90}\n\n" + f"{'-' * 90}\n\n" ) LOGGER.error(msg) raise SnapshotMigrationRequired("Look into 1.5 dbt-athena docs for the complete migration procedure") @@ -745,7 +746,7 @@ def _generate_snapshot_migration_sql(self, relation: AthenaRelation, table_colum - Copy the content of the staging table to the final table - Delete the staging table """ - col_csv = f",\n{' '*16}".join(table_columns) + col_csv = f",\n{' ' * 16}".join(table_columns) staging_relation = relation.incorporate( path={"identifier": relation.identifier + "__dbt_tmp_migration_staging"} ) diff --git a/dbt/adapters/athena/lakeformation.py b/dbt/adapters/athena/lakeformation.py index 043b26d3..cd29e7e6 100644 --- a/dbt/adapters/athena/lakeformation.py +++ b/dbt/adapters/athena/lakeformation.py @@ -74,7 +74,11 @@ def _apply_lf_tags_table( logger.debug(f"EXISTING TABLE TAGS: {lf_tags_table}") logger.debug(f"CONFIG TAGS: {self.lf_tags}") - to_remove = {tag["TagKey"]: tag["TagValues"] for tag in lf_tags_table if tag["TagKey"] not in self.lf_tags} + to_remove = { + tag["TagKey"]: tag["TagValues"] + for tag in lf_tags_table + if tag["TagKey"] not in self.lf_tags # type: ignore + } logger.debug(f"TAGS TO REMOVE: {to_remove}") if to_remove: response = self.lf_client.remove_lf_tags_from_resource( @@ -105,7 +109,7 @@ def _parse_lf_response( self, response: Union[AddLFTagsToResourceResponseTypeDef, RemoveLFTagsFromResourceResponseTypeDef], columns: Optional[List[str]] = None, - lf_tags: Dict[str, str] = None, + lf_tags: Optional[Dict[str, str]] = None, verb: str = "add", ) -> str: failures = response.get("Failures", []) @@ -195,7 +199,7 @@ def process_filters(self, config: LfGrantsConfig) -> None: for f in to_update: self.lf_client.update_data_cells_filter(TableData=f) - def process_permissions(self, config: LfGrantsConfig): + def process_permissions(self, config: LfGrantsConfig) -> None: for name, f in config.data_cell_filters.filters.items(): logger.debug(f"Start processing permissions for filter: {name}") current_permissions = self.lf_client.list_permissions( diff --git a/dbt/adapters/athena/relation.py b/dbt/adapters/athena/relation.py index 2d396ff9..bb07393f 100644 --- a/dbt/adapters/athena/relation.py +++ b/dbt/adapters/athena/relation.py @@ -32,9 +32,9 @@ class AthenaRelation(BaseRelation): include_policy: Policy = field(default_factory=lambda: AthenaIncludePolicy()) s3_path_table_part: Optional[str] = None - def render_hive(self): + def render_hive(self) -> str: """ - Render relation with Hive format. Athena uses Hive format for some DDL statements. + Render relation with Hive format. Athena uses a Hive format for some DDL statements. See: - https://aws.amazon.com/athena/faqs/ "Q: How do I create tables and schemas for my data on Amazon S3?" @@ -45,9 +45,9 @@ def render_hive(self): object.__setattr__(self, "quote_character", "`") # Hive quote char rendered = self.render() object.__setattr__(self, "quote_character", old_value) - return rendered + return str(rendered) - def render_pure(self): + def render_pure(self) -> str: """ Render relation without quotes characters. 
This is needed for non-standard executions like optimize and vacuum """ old_value = self.quote_character object.__setattr__(self, "quote_character", "") rendered = self.render() object.__setattr__(self, "quote_character", old_value) - return rendered + return str(rendered) class AthenaSchemaSearchMap(Dict[InformationSchema, Dict[str, Set[Optional[str]]]]): """A utility class to keep track of what information_schema tables to search for what schemas and relations. The schema and relation values are all - lowercased to avoid duplication. + lowercase to avoid duplication. """ - def add(self, relation: AthenaRelation): + def add(self, relation: AthenaRelation) -> None: key = relation.information_schema_only() if key not in self: self[key] = {} - schema: Optional[str] = None if relation.schema is not None: schema = relation.schema.lower() relation_name = relation.name.lower() diff --git a/dbt/adapters/athena/utils.py b/dbt/adapters/athena/utils.py index ee3b20bc..778fb4c2 100644 --- a/dbt/adapters/athena/utils.py +++ b/dbt/adapters/athena/utils.py @@ -1,4 +1,4 @@ -from typing import Collection, List, Optional, TypeVar +from typing import Generator, List, Optional, TypeVar from mypy_boto3_athena.type_defs import DataCatalogTypeDef @@ -9,14 +9,13 @@ def clean_sql_comment(comment: str) -> str: def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]: - if catalog: - return catalog["Parameters"]["catalog-id"] + return catalog["Parameters"]["catalog-id"] if catalog else None T = TypeVar("T") -def get_chunks(lst: Collection[T], n: int) -> Collection[List[T]]: +def get_chunks(lst: List[T], n: int) -> Generator[List[T], None, None]: """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): yield lst[i : i + n] diff --git a/setup.py b/setup.py index b4ff1191..dae9a695 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import os import re +from typing import Any, Dict from setuptools import find_namespace_packages, setup @@ -13,7 +14,7 @@ # get version from a separate file -def _get_plugin_version_dict(): +def _get_plugin_version_dict() -> Dict[str, Any]: _version_path = os.path.join(this_directory, "dbt", "adapters", "athena", "__version__.py") _semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)""" _pre = r"""((?P<prekind>a|b|rc)(?P<pre>\d+))?"""
@@ -25,7 +26,7 @@ def _get_plugin_version_dict():
         return match.groupdict()
 
 
-def _get_package_version():
+def _get_package_version() -> str:
     parts = _get_plugin_version_dict()
     return f'{parts["major"]}.{parts["minor"]}.{parts["patch"]}'
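
As a quick check of the version pattern above: the group names major/minor/patch are implied by the groupdict() usage in _get_package_version, while prekind/pre are assumed names for the pre-release groups. A minimal sketch, not part of the patch:

import re

_semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"""
_pre = r"""((?P<prekind>a|b|rc)(?P<pre>\d+))?"""

# e.g. a __version__.py containing: version = "1.5.1rc1"
match = re.search(rf"""version\s*=\s*["']{_semver}{_pre}["']""", 'version = "1.5.1rc1"')
assert match is not None
parts = match.groupdict()
assert f'{parts["major"]}.{parts["minor"]}.{parts["patch"]}' == "1.5.1"
assert parts["prekind"] == "rc" and parts["pre"] == "1"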
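
Patch 09 earlier in this series chunks the Glue calls because BatchCreatePartition and BatchDeletePartition accept at most 100 and 25 partitions per request, respectively (the two class constants it introduces). A self-contained sketch of the chunking it relies on, with get_chunks copied from the patch and an assumed 60-partition table:

from typing import Generator, List, TypeVar

T = TypeVar("T")

def get_chunks(lst: List[T], n: int) -> Generator[List[T], None, None]:
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]

# 60 partitions against the BatchDeletePartition limit of 25 -> calls of 25, 25, 10
partitions = [{"Values": [str(i)]} for i in range(60)]
batches = list(get_chunks(partitions, 25))
assert [len(b) for b in batches] == [25, 25, 10]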