Skip to content

Commit

Permalink
Merge branch 'main' into feat/create-task-groups-by-dbt-models
Browse files Browse the repository at this point in the history
  • Loading branch information
maximilianoarcieri authored Mar 4, 2025
2 parents 1f6a8a4 + 46912b6 commit b3c6b0e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .github/ISSUE_TEMPLATE/01-bug.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ body:
label: ExecutionMode
description: Which ExecutionMode are you using?
options:
- "AIRFLOW_ASYNC"
- "AWS_ECS"
- "AWS_EKS"
- "AZURE_CONTAINER_INSTANCE"
- "DOCKER"
- "GCP_CLOUD_RUN_JOB"
- "KUBERNETES"
- "LOCAL"
- "VIRTUALENV"
Expand Down
15 changes: 9 additions & 6 deletions cosmos/operators/_asynchronous/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ def __init__(
self.project_dir = project_dir
self.profile_config = profile_config
self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore
profile = self.profile_config.profile_mapping.profile # type: ignore
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]
self.extra_context = extra_context or {}
self.configuration: dict[str, Any] = {}
self.dbt_kwargs = dbt_kwargs or {}
Expand Down Expand Up @@ -103,6 +100,8 @@ def __init__(
self.async_context["profile_type"] = self.profile_config.get_profile_type()
self.async_context["async_operator"] = BigQueryInsertJobOperator
self.compiled_sql = ""
self.gcp_project = ""
self.dataset = ""

@property
def base_cmd(self) -> list[str]:
Expand Down Expand Up @@ -145,10 +144,10 @@ def execute(self, context: Context, **kwargs: Any) -> None:
super().execute(context=context)
else:
self.build_and_run_cmd(context=context, run_as_async=True, async_context=self.async_context)
self._store_compiled_sql(context=context)
self._store_template_fields(context=context)

@provide_session
def _store_compiled_sql(self, context: Context, session: Session = NEW_SESSION) -> None:
def _store_template_fields(self, context: Context, session: Session = NEW_SESSION) -> None:
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.taskinstance import TaskInstance

Expand All @@ -159,6 +158,10 @@ def _store_compiled_sql(self, context: Context, session: Session = NEW_SESSION)
self.log.debug("Executed SQL is: %s", sql)
self.compiled_sql = sql

profile = self.profile_config.profile_mapping.profile
self.gcp_project = profile["project"]
self.dataset = profile["dataset"]

# need to refresh the rendered task field record in the db because Airflow only does this
# before executing the task, not after
ti = context["ti"]
Expand Down Expand Up @@ -188,5 +191,5 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""
job_id = super().execute_complete(context=context, event=event)
self.log.info("Configuration is %s", str(self.configuration))
self._store_compiled_sql(context=context)
self._store_template_fields(context=context)
return job_id
10 changes: 6 additions & 4 deletions tests/operators/_asynchronous/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def test_dbt_run_airflow_async_bigquery_operator_init(profile_config_mock):
assert operator.project_dir == "/path/to/project"
assert operator.profile_config == profile_config_mock
assert operator.gcp_conn_id == "google_cloud_default"
assert operator.gcp_project == "test_project"
assert operator.dataset == "test_dataset"


def test_dbt_run_airflow_async_bigquery_operator_base_cmd(profile_config_mock):
Expand Down Expand Up @@ -134,15 +132,19 @@ def test_store_compiled_sql(mock_rendered_ti, mock_get_remote_sql, profile_confi
mock_task_instance.task = operator
mock_context = {"ti": mock_task_instance}

operator._store_compiled_sql(mock_context, session=mock_session)
operator._store_template_fields(mock_context, session=mock_session)
# check if gcp_project and dataset are set after the tasks gets executed

assert operator.compiled_sql == "SELECT * FROM test_table;"
assert operator.dataset == "test_dataset"
assert operator.gcp_project == "test_project"

mock_rendered_ti.assert_called_once()
mock_session.add.assert_called_once()
mock_session.query().filter().delete.assert_called_once()


@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator._store_compiled_sql")
@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator._store_template_fields")
def test_execute_complete(mock_store_sql, profile_config_mock):
mock_context = Mock()
mock_event = {"job_id": "test_job"}
Expand Down

0 comments on commit b3c6b0e

Please sign in to comment.