Skip to content

Commit

Permalink
Use attrs converters for access_control [skip ci]
Browse files Browse the repository at this point in the history
skip ci
  • Loading branch information
kaxil committed Oct 23, 2024
1 parent a92d531 commit d4e015a
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ def _convert_tags(tags: Collection[str] | None) -> MutableSet[str]:
return set(tags or [])


def _convert_access_control(value, self_: DAG):
if hasattr(self_, "_upgrade_outdated_dag_access_control"):
return self_._upgrade_outdated_dag_access_control(value)
else:
return value


def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]):
i = iter(fields)
f = next(i)
Expand Down Expand Up @@ -353,7 +360,9 @@ class DAG:
default=None,
converter=attrs.Converter(_convert_params, takes_self=True), # type: ignore[misc, call-overload]
)
_access_control: dict | None = None
access_control: dict | None = attrs.field(
default=None, converter=attrs.Converter(_convert_access_control, takes_self=True)
)
is_paused_upon_creation: bool | None = None
jinja_environment_kwargs: dict | None = None
render_template_as_native_obj: bool = attrs.field(default=False, converter=bool)
Expand Down Expand Up @@ -381,8 +390,6 @@ def __attrs_post_init__(self):

self.start_date = timezone.convert_to_utc(self.start_date)
self.end_date = timezone.convert_to_utc(self.end_date)
# This should trigger the setters for access_control
self.access_control = self.access_control

@fileloc.default
def _default_fileloc(self) -> str:
Expand Down Expand Up @@ -686,17 +693,6 @@ def __deepcopy__(self, memo: dict[int, Any]):
result._log = self._log # type: ignore[attr-defined]
return result

@property
def access_control(self):
return self._access_control

@access_control.setter
def access_control(self, value):
if hasattr(self, "_upgrade_outdated_dag_access_control"):
self._access_control = self._upgrade_outdated_dag_access_control(value)
else:
self._access_control = value

def partial_subset(
self,
task_ids_or_regex: str | Pattern | Iterable[str],
Expand Down

0 comments on commit d4e015a

Please sign in to comment.