Commit
Fix types & rst link
VolkerSchiewe committed May 7, 2024
1 parent 6fda2db commit c42f5ad
Showing 3 changed files with 62 additions and 38 deletions.
68 changes: 42 additions & 26 deletions cosmos/operators/eks.py
@@ -1,12 +1,21 @@
from typing import Sequence
from __future__ import annotations

from typing import Any, Sequence

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.utils.context import Context

from cosmos.operators.kubernetes import DbtKubernetesBaseOperator, DbtTestKubernetesOperator, \
DbtBuildKubernetesOperator, DbtRunOperationKubernetesOperator, DbtRunKubernetesOperator, \
DbtSnapshotKubernetesOperator, DbtSeedKubernetesOperator, DbtLSKubernetesOperator
from cosmos.operators.kubernetes import (
DbtBuildKubernetesOperator,
DbtKubernetesBaseOperator,
DbtLSKubernetesOperator,
DbtRunKubernetesOperator,
DbtRunOperationKubernetesOperator,
DbtSeedKubernetesOperator,
DbtSnapshotKubernetesOperator,
DbtTestKubernetesOperator,
)

DEFAULT_CONN_ID = "aws_default"
DEFAULT_NAMESPACE = "default"
@@ -21,17 +30,18 @@ class DbtEksBaseOperator(DbtKubernetesBaseOperator):
"pod_name",
"aws_conn_id",
"region",
} | set(DbtKubernetesBaseOperator.template_fields)
}
| set(DbtKubernetesBaseOperator.template_fields)
)

def __init__(
self,
cluster_name: str,
pod_name: str | None = None,
namespace: str | None = DEFAULT_NAMESPACE,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs,
self,
cluster_name: str,
pod_name: str | None = None,
namespace: str | None = DEFAULT_NAMESPACE,
aws_conn_id: str = DEFAULT_CONN_ID,
region: str | None = None,
**kwargs: Any,
) -> None:
self.cluster_name = cluster_name
self.pod_name = pod_name
@@ -48,13 +58,13 @@ def __init__(
if self.config_file:
raise AirflowException("The config_file is not an allowed parameter for the EksPodOperator.")

def execute(self, context: Context):
def execute(self, context: Context) -> Any | None: # type: ignore
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)
with eks_hook.generate_config_file(
eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
) as self.config_file:
return super().execute(context)
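
For context (not part of the diff): EksHook.generate_config_file is a context manager that writes a temporary kubeconfig for the named cluster and yields its path; the execute override above binds that path to self.config_file before delegating to the inherited Kubernetes execute. A minimal standalone sketch, with an illustrative connection id, region and cluster name:

from airflow.providers.amazon.aws.hooks.eks import EksHook

# Illustrative values -- any configured AWS connection and reachable EKS cluster would do.
hook = EksHook(aws_conn_id="aws_default", region_name="eu-central-1")
with hook.generate_config_file(eks_cluster_name="my-cluster", pod_namespace="default") as config_path:
    # config_path is the temporary kubeconfig the Kubernetes pod operator can point at.
    print(config_path)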

@@ -64,8 +74,9 @@ class DbtBuildEksOperator(DbtEksBaseOperator, DbtBuildKubernetesOperator):
Executes a dbt core build command.
"""

template_fields: Sequence[
str] = DbtEksBaseOperator.template_fields + DbtBuildKubernetesOperator.template_fields # type: ignore[operator]
template_fields: Sequence[str] = (
DbtEksBaseOperator.template_fields + DbtBuildKubernetesOperator.template_fields # type: ignore[operator]
)


class DbtLSEksOperator(DbtEksBaseOperator, DbtLSKubernetesOperator):
@@ -79,8 +90,9 @@ class DbtSeedEksOperator(DbtEksBaseOperator, DbtSeedKubernetesOperator):
Executes a dbt core seed command.
"""

template_fields: Sequence[
str] = DbtEksBaseOperator.template_fields + DbtSeedKubernetesOperator.template_fields # type: ignore[operator]
template_fields: Sequence[str] = (
DbtEksBaseOperator.template_fields + DbtSeedKubernetesOperator.template_fields # type: ignore[operator]
)


class DbtSnapshotEksOperator(DbtEksBaseOperator, DbtSnapshotKubernetesOperator):
@@ -94,22 +106,26 @@ class DbtRunEksOperator(DbtEksBaseOperator, DbtRunKubernetesOperator):
Executes a dbt core run command.
"""

template_fields: Sequence[
str] = DbtEksBaseOperator.template_fields + DbtRunKubernetesOperator.template_fields # type: ignore[operator]
template_fields: Sequence[str] = (
DbtEksBaseOperator.template_fields + DbtRunKubernetesOperator.template_fields # type: ignore[operator]
)


class DbtTestEksOperator(DbtEksBaseOperator, DbtTestKubernetesOperator):
"""
Executes a dbt core test command.
"""
template_fields: Sequence[
str] = DbtEksBaseOperator.template_fields + DbtTestKubernetesOperator.template_fields # type: ignore[operator]
Executes a dbt core test command.
"""

template_fields: Sequence[str] = (
DbtEksBaseOperator.template_fields + DbtTestKubernetesOperator.template_fields # type: ignore[operator]
)


class DbtRunOperationEksOperator(DbtEksBaseOperator, DbtRunOperationKubernetesOperator):
"""
Executes a dbt core run-operation command.
"""

template_fields: Sequence[
str] = DbtEksBaseOperator.template_fields + DbtRunOperationKubernetesOperator.template_fields # type: ignore[operator]
template_fields: Sequence[str] = (
DbtEksBaseOperator.template_fields + DbtRunOperationKubernetesOperator.template_fields # type: ignore[operator]
)
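
As a rough usage sketch for the operators added above (illustrative only: the task id, project path and image are hypothetical, and any further keyword arguments come from the underlying Kubernetes operator):

from cosmos.operators.eks import DbtRunEksOperator

run_models = DbtRunEksOperator(
    task_id="dbt_run",                                # hypothetical task id
    cluster_name="my-cluster",                        # required: target EKS cluster
    project_dir="/usr/local/airflow/dbt/my_project",  # hypothetical dbt project path
    image="ghcr.io/my-org/dbt-image:latest",          # hypothetical container image
)
# aws_conn_id defaults to "aws_default" and region to the connection's default when omitted.
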
2 changes: 1 addition & 1 deletion docs/getting_started/execution-modes.rst
@@ -168,7 +168,7 @@ EKS
----------

The ``eks`` approach is very similar to the ``kubernetes`` approach, but it is specifically designed to run on AWS EKS clusters.
It uses the [EKSPodOperator](https://airflow.apache.org/docs/apache-airflow-providers-amazon/2.2.0/operators/eks.html#perform-a-task-on-an-amazon-eks-cluster)
It uses the `EKSPodOperator <https://airflow.apache.org/docs/apache-airflow-providers-amazon/2.2.0/operators/eks.html#perform-a-task-on-an-amazon-eks-cluster>`_
to run the dbt commands. You need to provide the ``cluster_name`` in your operator_args to connect to the EKS cluster.
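
A sketch of what that configuration can look like (not part of the documented file; it assumes an ExecutionMode.EKS value is available after this change, the usual profile and scheduling settings are omitted, and the dag id, project path and image are illustrative):

from cosmos import DbtDag, ExecutionConfig, ProjectConfig
from cosmos.constants import ExecutionMode

dag = DbtDag(
    dag_id="dbt_on_eks",
    project_config=ProjectConfig("/usr/local/airflow/dbt/my_project"),   # hypothetical path
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.EKS),  # assumed enum value
    operator_args={
        "cluster_name": "my-cluster",                # required to reach the EKS cluster
        "image": "ghcr.io/my-org/dbt-image:latest",  # hypothetical dbt container image
    },
)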


30 changes: 19 additions & 11 deletions tests/operators/test_eks.py
@@ -2,8 +2,13 @@

import pytest

from cosmos.operators.eks import DbtLSEksOperator, DbtSeedEksOperator, DbtBuildEksOperator, DbtTestEksOperator, \
DbtRunEksOperator
from cosmos.operators.eks import (
DbtBuildEksOperator,
DbtLSEksOperator,
DbtRunEksOperator,
DbtSeedEksOperator,
DbtTestEksOperator,
)


@pytest.fixture()
@@ -26,14 +31,17 @@ def mock_kubernetes_execute():
}


@pytest.mark.parametrize("command_name,command_operator", [
("ls", DbtLSEksOperator(**base_kwargs)),
("run", DbtRunEksOperator(**base_kwargs)),
("test", DbtTestEksOperator(**base_kwargs)),
("build", DbtBuildEksOperator(**base_kwargs)),
("seed", DbtSeedEksOperator(**base_kwargs)),
])
def test_dbt_kubernetes_build_command(command_name, command_operator ):
@pytest.mark.parametrize(
"command_name,command_operator",
[
("ls", DbtLSEksOperator(**base_kwargs)),
("run", DbtRunEksOperator(**base_kwargs)),
("test", DbtTestEksOperator(**base_kwargs)),
("build", DbtBuildEksOperator(**base_kwargs)),
("seed", DbtSeedEksOperator(**base_kwargs)),
],
)
def test_dbt_kubernetes_build_command(command_name, command_operator):
"""
Since we know that the KubernetesOperator is tested, we can just test that the
command is built correctly and added to the "arguments" parameter.
@@ -67,7 +75,7 @@ def test_dbt_kubernetes_operator_execute(mock_generate_config_file, mock_build_k
mock_build_kube_args.assert_called_once()

# Assert that the generate_config_file method was called in the execution to create the kubeconfig for eks
mock_generate_config_file.assert_called_once_with(eks_cluster_name='my-cluster', pod_namespace='default')
mock_generate_config_file.assert_called_once_with(eks_cluster_name="my-cluster", pod_namespace="default")

# Assert that the kubernetes execute method was called in the execution
mock_kubernetes_execute.assert_called_once()
