From d4e015a91b5e3e890bb350afd8f42c5142503ce5 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 23 Oct 2024 13:31:27 +0100 Subject: [PATCH] Use attrs converters for access_control [skip ci] skip ci --- task_sdk/src/airflow/sdk/definitions/dag.py | 24 +++++++++------------ 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 8224f3a31673..71dbdc6f9eb4 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -170,6 +170,13 @@ def _convert_tags(tags: Collection[str] | None) -> MutableSet[str]: return set(tags or []) +def _convert_access_control(value, self_: DAG): + if hasattr(self_, "_upgrade_outdated_dag_access_control"): + return self_._upgrade_outdated_dag_access_control(value) + else: + return value + + def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]): i = iter(fields) f = next(i) @@ -353,7 +360,9 @@ class DAG: default=None, converter=attrs.Converter(_convert_params, takes_self=True), # type: ignore[misc, call-overload] ) - _access_control: dict | None = None + access_control: dict | None = attrs.field( + default=None, converter=attrs.Converter(_convert_access_control, takes_self=True) + ) is_paused_upon_creation: bool | None = None jinja_environment_kwargs: dict | None = None render_template_as_native_obj: bool = attrs.field(default=False, converter=bool) @@ -381,8 +390,6 @@ def __attrs_post_init__(self): self.start_date = timezone.convert_to_utc(self.start_date) self.end_date = timezone.convert_to_utc(self.end_date) - # This should trigger the setters for access_control - self.access_control = self.access_control @fileloc.default def _default_fileloc(self) -> str: @@ -686,17 +693,6 @@ def __deepcopy__(self, memo: dict[int, Any]): result._log = self._log # type: ignore[attr-defined] return result - @property - def access_control(self): - return self._access_control - - @access_control.setter - def access_control(self, value): - if hasattr(self, "_upgrade_outdated_dag_access_control"): - self._access_control = self._upgrade_outdated_dag_access_control(value) - else: - self._access_control = value - def partial_subset( self, task_ids_or_regex: str | Pattern | Iterable[str],