From ec91f467c82b3e13cc8b873a6412f623c007d133 Mon Sep 17 00:00:00 2001 From: noahjax Date: Mon, 1 Apr 2024 08:57:15 -0700 Subject: [PATCH] add identity to task execution metadata Signed-off-by: noahjax --- dev-requirements.in | 2 +- flytekit/models/security.py | 3 +++ flytekit/models/task.py | 11 +++++++++++ pyproject.toml | 2 +- tests/flytekit/unit/extend/test_agent.py | 2 ++ 5 files changed, 18 insertions(+), 2 deletions(-) diff --git a/dev-requirements.in b/dev-requirements.in index fb90c597b9..1ff7de0071 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,5 +1,5 @@ -e file:.#egg=flytekit -git+https://github.com/flyteorg/flyte.git@master#subdirectory=flyteidl +git+https://github.com/dominodatalab/flyte.git@noahjax.add-owner-reference-to-create-task#subdirectory=flyteidl coverage[toml] hypothesis diff --git a/flytekit/models/security.py b/flytekit/models/security.py index a9ee7e7cb9..cdc1065cc4 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -88,12 +88,14 @@ class Identity(_common.FlyteIdlEntity): iam_role: Optional[str] = None k8s_service_account: Optional[str] = None oauth2_client: Optional[OAuth2Client] = None + execution_identity: Optional[str] = None def to_flyte_idl(self) -> _sec.Identity: return _sec.Identity( iam_role=self.iam_role if self.iam_role else None, k8s_service_account=self.k8s_service_account if self.k8s_service_account else None, oauth2_client=self.oauth2_client.to_flyte_idl() if self.oauth2_client else None, + execution_identity=self.execution_identity, ) @classmethod @@ -104,6 +106,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Identity) -> "Identity": oauth2_client=OAuth2Client.from_flyte_idl(pb2_object.oauth2_client) if pb2_object.oauth2_client and pb2_object.oauth2_client.ByteSize() else None, + execution_identity=pb2_object.execution_identity, ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 198adf2859..d831aa52c3 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -528,6 +528,7 @@ def __init__( annotations, k8s_service_account, environment_variables, + identity, ): """ Runtime task execution metadata. @@ -539,6 +540,7 @@ def __init__( :param dict[str, str] annotations: Annotations to use for the execution of this task. :param Text k8s_service_account: Service account to use for execution of this task. :param dict[str, str] environment_variables: Environment variables for this task. + :param flytekit.models.security.Identity identity: Identity of user executing this task """ self._task_execution_id = task_execution_id self._namespace = namespace @@ -546,6 +548,7 @@ def __init__( self._annotations = annotations self._k8s_service_account = k8s_service_account self._environment_variables = environment_variables + self._identity = identity @property def task_execution_id(self): @@ -571,6 +574,10 @@ def k8s_service_account(self): def environment_variables(self): return self._environment_variables + @property + def identity(self): + return self._identity + def to_flyte_idl(self): """ :rtype: flyteidl.admin.agent_pb2.TaskExecutionMetadata @@ -584,6 +591,7 @@ def to_flyte_idl(self): environment_variables={k: v for k, v in self.environment_variables.items()} if self.labels is not None else None, + identity=self.identity.to_flyte_idl() if self.identity else None, ) return task_execution_metadata @@ -604,6 +612,9 @@ def from_flyte_idl(cls, pb2_object): environment_variables={k: v for k, v in pb2_object.environment_variables.items()} if pb2_object.environment_variables is not None else None, + identity=_sec.Identity.from_flyte_idl(pb2_object.identity) + if pb2_object.identity + else None, ) diff --git a/pyproject.toml b/pyproject.toml index df457ac3ad..1cc852185e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0,<7.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.11.0b1", + # "flyteidl>=1.11.0b1", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 2bf23abb25..2ddb0e5074 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -45,6 +45,7 @@ WorkflowExecutionIdentifier, ) from flytekit.models.literals import LiteralMap +from flytekit.models.security import Identity from flytekit.models.task import TaskExecutionMetadata, TaskTemplate from flytekit.tools.translator import get_serializable @@ -157,6 +158,7 @@ def simple_task(i: int): annotations={"annotation_key": "annotation_val"}, k8s_service_account="k8s service account", environment_variables={"env_var_key": "env_var_val"}, + identity=Identity(execution_identity="task executor"), )