diff --git a/.github/ISSUE_TEMPLATE/01-bug.yml b/.github/ISSUE_TEMPLATE/01-bug.yml index ad663a79e..2765661b8 100644 --- a/.github/ISSUE_TEMPLATE/01-bug.yml +++ b/.github/ISSUE_TEMPLATE/01-bug.yml @@ -52,9 +52,12 @@ body: label: ExecutionMode description: Which ExecutionMode are you using? options: + - "AIRFLOW_ASYNC" + - "AWS_ECS" - "AWS_EKS" - "AZURE_CONTAINER_INSTANCE" - "DOCKER" + - "GCP_CLOUD_RUN_JOB" - "KUBERNETES" - "LOCAL" - "VIRTUALENV" diff --git a/cosmos/operators/_asynchronous/bigquery.py b/cosmos/operators/_asynchronous/bigquery.py index e8879b0fe..1d28f5da3 100644 --- a/cosmos/operators/_asynchronous/bigquery.py +++ b/cosmos/operators/_asynchronous/bigquery.py @@ -73,9 +73,6 @@ def __init__( self.project_dir = project_dir self.profile_config = profile_config self.gcp_conn_id = self.profile_config.profile_mapping.conn_id # type: ignore - profile = self.profile_config.profile_mapping.profile # type: ignore - self.gcp_project = profile["project"] - self.dataset = profile["dataset"] self.extra_context = extra_context or {} self.configuration: dict[str, Any] = {} self.dbt_kwargs = dbt_kwargs or {} @@ -103,6 +100,8 @@ def __init__( self.async_context["profile_type"] = self.profile_config.get_profile_type() self.async_context["async_operator"] = BigQueryInsertJobOperator self.compiled_sql = "" + self.gcp_project = "" + self.dataset = "" @property def base_cmd(self) -> list[str]: @@ -145,10 +144,10 @@ def execute(self, context: Context, **kwargs: Any) -> None: super().execute(context=context) else: self.build_and_run_cmd(context=context, run_as_async=True, async_context=self.async_context) - self._store_compiled_sql(context=context) + self._store_template_fields(context=context) @provide_session - def _store_compiled_sql(self, context: Context, session: Session = NEW_SESSION) -> None: + def _store_template_fields(self, context: Context, session: Session = NEW_SESSION) -> None: from airflow.models.renderedtifields import RenderedTaskInstanceFields from 
airflow.models.taskinstance import TaskInstance @@ -159,6 +158,10 @@ def _store_template_fields(self, context: Context, session: Session = NEW_SESSION) self.log.debug("Executed SQL is: %s", sql) self.compiled_sql = sql + profile = self.profile_config.profile_mapping.profile + self.gcp_project = profile["project"] + self.dataset = profile["dataset"] + # need to refresh the rendered task field record in the db because Airflow only does this # before executing the task, not after ti = context["ti"] @@ -188,5 +191,5 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: """ job_id = super().execute_complete(context=context, event=event) self.log.info("Configuration is %s", str(self.configuration)) - self._store_compiled_sql(context=context) + self._store_template_fields(context=context) return job_id diff --git a/tests/operators/_asynchronous/test_bigquery.py b/tests/operators/_asynchronous/test_bigquery.py index f339c9880..40717cd62 100644 --- a/tests/operators/_asynchronous/test_bigquery.py +++ b/tests/operators/_asynchronous/test_bigquery.py @@ -38,8 +38,6 @@ def test_dbt_run_airflow_async_bigquery_operator_init(profile_config_mock): assert operator.project_dir == "/path/to/project" assert operator.profile_config == profile_config_mock assert operator.gcp_conn_id == "google_cloud_default" - assert operator.gcp_project == "test_project" - assert operator.dataset == "test_dataset" def test_dbt_run_airflow_async_bigquery_operator_base_cmd(profile_config_mock): @@ -134,15 +132,19 @@ def test_store_compiled_sql(mock_rendered_ti, mock_get_remote_sql, profile_confi mock_task_instance.task = operator mock_context = {"ti": mock_task_instance} - operator._store_compiled_sql(mock_context, session=mock_session) + operator._store_template_fields(mock_context, session=mock_session) + # check if gcp_project and dataset are set after the task gets executed assert operator.compiled_sql == "SELECT * FROM test_table;" + assert operator.dataset == "test_dataset" + 
assert operator.gcp_project == "test_project" + mock_rendered_ti.assert_called_once() mock_session.add.assert_called_once() mock_session.query().filter().delete.assert_called_once() -@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator._store_compiled_sql") +@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator._store_template_fields") def test_execute_complete(mock_store_sql, profile_config_mock): mock_context = Mock() mock_event = {"job_id": "test_job"}