Skip to content

Commit

Permalink
Add identity to task execution metadata (#2315)
Browse files Browse the repository at this point in the history
Signed-off-by: noahjax <[email protected]>
Signed-off-by: ddl-rliu <[email protected]>
Co-authored-by: ddl-rliu <[email protected]>
  • Loading branch information
noahjax and ddl-rliu authored Jun 14, 2024
1 parent 76fe8cb commit 8562fd9
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 1 deletion.
3 changes: 3 additions & 0 deletions flytekit/models/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ class Identity(_common.FlyteIdlEntity):
iam_role: Optional[str] = None
k8s_service_account: Optional[str] = None
oauth2_client: Optional[OAuth2Client] = None
execution_identity: Optional[str] = None

def to_flyte_idl(self) -> _sec.Identity:
return _sec.Identity(
iam_role=self.iam_role if self.iam_role else None,
k8s_service_account=self.k8s_service_account if self.k8s_service_account else None,
oauth2_client=self.oauth2_client.to_flyte_idl() if self.oauth2_client else None,
execution_identity=self.execution_identity if self.execution_identity else None,
)

@classmethod
Expand All @@ -108,6 +110,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Identity) -> "Identity":
oauth2_client=OAuth2Client.from_flyte_idl(pb2_object.oauth2_client)
if pb2_object.oauth2_client and pb2_object.oauth2_client.ByteSize()
else None,
execution_identity=pb2_object.execution_identity if pb2_object.execution_identity else None,
)


Expand Down
9 changes: 9 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def __init__(
annotations,
k8s_service_account,
environment_variables,
identity,
):
"""
Runtime task execution metadata.
Expand All @@ -539,13 +540,15 @@ def __init__(
:param dict[str, str] annotations: Annotations to use for the execution of this task.
:param Text k8s_service_account: Service account to use for execution of this task.
:param dict[str, str] environment_variables: Environment variables for this task.
:param flytekit.models.security.Identity identity: Identity of user executing this task
"""
self._task_execution_id = task_execution_id
self._namespace = namespace
self._labels = labels
self._annotations = annotations
self._k8s_service_account = k8s_service_account
self._environment_variables = environment_variables
self._identity = identity

@property
def task_execution_id(self):
Expand All @@ -571,6 +574,10 @@ def k8s_service_account(self):
def environment_variables(self):
return self._environment_variables

@property
def identity(self):
return self._identity

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.agent_pb2.TaskExecutionMetadata
Expand All @@ -584,6 +591,7 @@ def to_flyte_idl(self):
environment_variables={k: v for k, v in self.environment_variables.items()}
if self.labels is not None
else None,
identity=self.identity.to_flyte_idl() if self.identity else None,
)
return task_execution_metadata

Expand All @@ -604,6 +612,7 @@ def from_flyte_idl(cls, pb2_object):
environment_variables={k: v for k, v in pb2_object.environment_variables.items()}
if pb2_object.environment_variables is not None
else None,
identity=_sec.Identity.from_flyte_idl(pb2_object.identity) if pb2_object.identity else None,
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0,<7.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.11.0b1",
"flyteidl>=1.12.0",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
WorkflowExecutionIdentifier,
)
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Identity
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate
from flytekit.tools.translator import get_serializable

Expand Down Expand Up @@ -159,6 +160,7 @@ def simple_task(i: int):
annotations={"annotation_key": "annotation_val"},
k8s_service_account="k8s service account",
environment_variables={"env_var_key": "env_var_val"},
identity=Identity(execution_identity="task executor"),
)


Expand Down

0 comments on commit 8562fd9

Please sign in to comment.