diff --git a/cosmos/operators/eks.py b/cosmos/operators/eks.py index 5165aa067..4e3e8cc97 100644 --- a/cosmos/operators/eks.py +++ b/cosmos/operators/eks.py @@ -1,12 +1,21 @@ -from typing import Sequence +from __future__ import annotations + +from typing import Any, Sequence from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.eks import EksHook from airflow.utils.context import Context -from cosmos.operators.kubernetes import DbtKubernetesBaseOperator, DbtTestKubernetesOperator, \ - DbtBuildKubernetesOperator, DbtRunOperationKubernetesOperator, DbtRunKubernetesOperator, \ - DbtSnapshotKubernetesOperator, DbtSeedKubernetesOperator, DbtLSKubernetesOperator +from cosmos.operators.kubernetes import ( + DbtBuildKubernetesOperator, + DbtKubernetesBaseOperator, + DbtLSKubernetesOperator, + DbtRunKubernetesOperator, + DbtRunOperationKubernetesOperator, + DbtSeedKubernetesOperator, + DbtSnapshotKubernetesOperator, + DbtTestKubernetesOperator, +) DEFAULT_CONN_ID = "aws_default" DEFAULT_NAMESPACE = "default" @@ -21,17 +30,18 @@ class DbtEksBaseOperator(DbtKubernetesBaseOperator): "pod_name", "aws_conn_id", "region", - } | set(DbtKubernetesBaseOperator.template_fields) + } + | set(DbtKubernetesBaseOperator.template_fields) ) def __init__( - self, - cluster_name: str, - pod_name: str | None = None, - namespace: str | None = DEFAULT_NAMESPACE, - aws_conn_id: str = DEFAULT_CONN_ID, - region: str | None = None, - **kwargs, + self, + cluster_name: str, + pod_name: str | None = None, + namespace: str | None = DEFAULT_NAMESPACE, + aws_conn_id: str = DEFAULT_CONN_ID, + region: str | None = None, + **kwargs: Any, ) -> None: self.cluster_name = cluster_name self.pod_name = pod_name @@ -48,13 +58,13 @@ def __init__( if self.config_file: raise AirflowException("The config_file is not an allowed parameter for the EksPodOperator.") - def execute(self, context: Context): + def execute(self, context: Context) -> Any | None: # type: ignore eks_hook = 
EksHook( aws_conn_id=self.aws_conn_id, region_name=self.region, ) with eks_hook.generate_config_file( - eks_cluster_name=self.cluster_name, pod_namespace=self.namespace + eks_cluster_name=self.cluster_name, pod_namespace=self.namespace ) as self.config_file: return super().execute(context) @@ -64,8 +74,9 @@ class DbtBuildEksOperator(DbtEksBaseOperator, DbtBuildKubernetesOperator): Executes a dbt core build command. """ - template_fields: Sequence[ - str] = DbtEksBaseOperator.template_fields + DbtBuildKubernetesOperator.template_fields # type: ignore[operator] + template_fields: Sequence[str] = ( + DbtEksBaseOperator.template_fields + DbtBuildKubernetesOperator.template_fields # type: ignore[operator] + ) class DbtLSEksOperator(DbtEksBaseOperator, DbtLSKubernetesOperator): @@ -79,8 +90,9 @@ class DbtSeedEksOperator(DbtEksBaseOperator, DbtSeedKubernetesOperator): Executes a dbt core seed command. """ - template_fields: Sequence[ - str] = DbtEksBaseOperator.template_fields + DbtSeedKubernetesOperator.template_fields # type: ignore[operator] + template_fields: Sequence[str] = ( + DbtEksBaseOperator.template_fields + DbtSeedKubernetesOperator.template_fields # type: ignore[operator] + ) class DbtSnapshotEksOperator(DbtEksBaseOperator, DbtSnapshotKubernetesOperator): @@ -94,16 +106,19 @@ class DbtRunEksOperator(DbtEksBaseOperator, DbtRunKubernetesOperator): Executes a dbt core run command. """ - template_fields: Sequence[ - str] = DbtEksBaseOperator.template_fields + DbtRunKubernetesOperator.template_fields # type: ignore[operator] + template_fields: Sequence[str] = ( + DbtEksBaseOperator.template_fields + DbtRunKubernetesOperator.template_fields # type: ignore[operator] + ) class DbtTestEksOperator(DbtEksBaseOperator, DbtTestKubernetesOperator): """ - Executes a dbt core test command. - """ - template_fields: Sequence[ - str] = DbtEksBaseOperator.template_fields + DbtTestKubernetesOperator.template_fields # type: ignore[operator] + Executes a dbt core test command. 
+ """ + + template_fields: Sequence[str] = ( + DbtEksBaseOperator.template_fields + DbtTestKubernetesOperator.template_fields # type: ignore[operator] + ) class DbtRunOperationEksOperator(DbtEksBaseOperator, DbtRunOperationKubernetesOperator): @@ -111,5 +126,6 @@ class DbtRunOperationEksOperator(DbtEksBaseOperator, DbtRunOperationKubernetesOp Executes a dbt core run-operation command. """ - template_fields: Sequence[ - str] = DbtEksBaseOperator.template_fields + DbtRunOperationKubernetesOperator.template_fields # type: ignore[operator] + template_fields: Sequence[str] = ( + DbtEksBaseOperator.template_fields + DbtRunOperationKubernetesOperator.template_fields # type: ignore[operator] + ) diff --git a/docs/getting_started/execution-modes.rst b/docs/getting_started/execution-modes.rst index 7549c82db..9588bead2 100644 --- a/docs/getting_started/execution-modes.rst +++ b/docs/getting_started/execution-modes.rst @@ -168,7 +168,7 @@ EKS ---------- The ``eks`` approach is very similar to the ``kubernetes`` approach, but it is specifically designed to run on AWS EKS clusters. -It uses the [EKSPodOperator](https://airflow.apache.org/docs/apache-airflow-providers-amazon/2.2.0/operators/eks.html#perform-a-task-on-an-amazon-eks-cluster) +It uses the `EKSPodOperator <https://airflow.apache.org/docs/apache-airflow-providers-amazon/2.2.0/operators/eks.html#perform-a-task-on-an-amazon-eks-cluster>`_ to run the dbt commands. You need to provide the ``cluster_name`` in your operator_args to connect to the EKS cluster. 
diff --git a/tests/operators/test_eks.py b/tests/operators/test_eks.py index 0809724d6..842563348 100644 --- a/tests/operators/test_eks.py +++ b/tests/operators/test_eks.py @@ -2,8 +2,13 @@ import pytest -from cosmos.operators.eks import DbtLSEksOperator, DbtSeedEksOperator, DbtBuildEksOperator, DbtTestEksOperator, \ - DbtRunEksOperator +from cosmos.operators.eks import ( + DbtBuildEksOperator, + DbtLSEksOperator, + DbtRunEksOperator, + DbtSeedEksOperator, + DbtTestEksOperator, +) @pytest.fixture() @@ -26,14 +31,17 @@ def mock_kubernetes_execute(): } -@pytest.mark.parametrize("command_name,command_operator", [ - ("ls", DbtLSEksOperator(**base_kwargs)), - ("run", DbtRunEksOperator(**base_kwargs)), - ("test", DbtTestEksOperator(**base_kwargs)), - ("build", DbtBuildEksOperator(**base_kwargs)), - ("seed", DbtSeedEksOperator(**base_kwargs)), -]) -def test_dbt_kubernetes_build_command(command_name, command_operator ): +@pytest.mark.parametrize( + "command_name,command_operator", + [ + ("ls", DbtLSEksOperator(**base_kwargs)), + ("run", DbtRunEksOperator(**base_kwargs)), + ("test", DbtTestEksOperator(**base_kwargs)), + ("build", DbtBuildEksOperator(**base_kwargs)), + ("seed", DbtSeedEksOperator(**base_kwargs)), + ], +) +def test_dbt_kubernetes_build_command(command_name, command_operator): """ Since we know that the KubernetesOperator is tested, we can just test that the command is built correctly and added to the "arguments" parameter. 
@@ -67,7 +75,7 @@ def test_dbt_kubernetes_operator_execute(mock_generate_config_file, mock_build_k mock_build_kube_args.assert_called_once() # Assert that the generate_config_file method was called in the execution to create the kubeconfig for eks - mock_generate_config_file.assert_called_once_with(eks_cluster_name='my-cluster', pod_namespace='default') + mock_generate_config_file.assert_called_once_with(eks_cluster_name="my-cluster", pod_namespace="default") # Assert that the kubernetes execute method was called in the execution mock_kubernetes_execute.assert_called_once()