diff --git a/latch/idl/admin/common.py b/latch/idl/admin/common.py new file mode 100644 index 00000000..bd1c6f3e --- /dev/null +++ b/latch/idl/admin/common.py @@ -0,0 +1,168 @@ +import typing +from collections.abc import Iterable, Mapping +from dataclasses import dataclass + +import flyteidl.admin.common_pb2 as pb + +from ..core.execution import WorkflowExecution +from ..utils import merged_pb, to_idl_many + + +@dataclass +class EmailNotification: + """Defines an email notification specification.""" + + recipients_email: Iterable[str] + """ + The list of email addresses recipients for this notification. + +required + """ + + def to_idl(self) -> pb.EmailNotification: + return pb.EmailNotification(recipients_email=self.recipients_email) + + +@dataclass +class PagerDutyNotification: + """Defines a pager duty notification specification.""" + + recipients_email: Iterable[str] + """ + Currently, PagerDuty notifications leverage email to trigger a notification. + +required + """ + + def to_idl(self) -> pb.PagerDutyNotification: + return pb.PagerDutyNotification(recipients_email=self.recipients_email) + + +@dataclass +class SlackNotification: + """Defines a slack notification specification.""" + + recipients_email: Iterable[str] + """ + Currently, Slack notifications leverage email to trigger a notification. + +required + """ + + def to_idl(self) -> pb.SlackNotification: + return pb.SlackNotification(recipients_email=self.recipients_email) + + +@dataclass +class Notification: + """ + Represents a structure for notifications based on execution status. + The notification content is configured within flyte admin but can be templatized. + Future iterations could expose configuring notifications with custom content. + """ + + phases: Iterable[WorkflowExecution.Phase] + """ + A list of phases to which users can associate the notifications to. + +required + """ + + type: "typing.Union[NotificationTypeEmail, NotificationTypePagerDuty, NotificationTypeSlack]" + """ + The type of notification to trigger. + +required + """ + + def to_idl(self) -> pb.Notification: + return merged_pb(pb.Notification(phases=to_idl_many(self.phases)), self.type) + + +@dataclass +class NotificationTypeEmail: + email: EmailNotification + + def to_idl(self) -> pb.Notification: + return pb.Notification(email=self.email.to_idl()) + + +@dataclass +class NotificationTypePagerDuty: + pager_duty: PagerDutyNotification + + def to_idl(self) -> pb.Notification: + return pb.Notification(pager_duty=self.pager_duty.to_idl()) + + +@dataclass +class NotificationTypeSlack: + slack: SlackNotification + + def to_idl(self) -> pb.Notification: + return pb.Notification(slack=self.slack.to_idl()) + + +@dataclass +class Labels: + """ + Label values to be applied to an execution resource. + In the future a mode (e.g. OVERRIDE, APPEND, etc) can be defined + to specify how to merge labels defined at registration and execution time. + """ + + values: Mapping[str, str] + """Map of custom labels to be applied to the execution resource.""" + + def to_idl(self) -> pb.Labels: + return pb.Labels(values=self.values) + + +@dataclass +class Annotations: + """ + Annotation values to be applied to an execution resource. + In the future a mode (e.g. OVERRIDE, APPEND, etc) can be defined + to specify how to merge annotations defined at registration and execution time. 
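+
+    Illustrative example (a sketch; the key/value pairs are placeholders):
+
+        Annotations(values={"team": "genomics", "cost-center": "1234"}).to_idl()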
+ """ + + values: Mapping[str, str] + """Map of custom annotations to be applied to the execution resource.""" + + def to_idl(self) -> pb.Annotations: + return pb.Annotations(values=self.values) + + +@dataclass +class AuthRole: + """ + Defines permissions associated with executions created by this launch plan spec. + Use either of these roles when they have permissions required by your workflow execution. + Deprecated. + """ + + assumable_iam_role: str + """Defines an optional iam role which will be used for tasks run in executions created with this launch plan.""" + + kubernetes_service_account: str + """Defines an optional kubernetes service account which will be used for tasks run in executions created with this launch plan.""" + + def to_idl(self) -> pb.AuthRole: + return pb.AuthRole( + assumable_iam_role=self.assumable_iam_role, + kubernetes_service_account=self.kubernetes_service_account, + ) + + +@dataclass +class RawOutputDataConfig: + """ + Encapsulates user settings pertaining to offloaded data (i.e. Blobs, Schema, query data, etc.). + See https://github.com/flyteorg/flyte/issues/211 for more background information. + """ + + output_location_prefix: str + """ + Prefix for where offloaded data from user workflows will be written + e.g. s3://bucket/key or s3://bucket/ + """ + + def to_idl(self) -> pb.RawOutputDataConfig: + return pb.RawOutputDataConfig( + output_location_prefix=self.output_location_prefix + ) diff --git a/latch/idl/admin/launch_plan.py b/latch/idl/admin/launch_plan.py new file mode 100644 index 00000000..06287a7f --- /dev/null +++ b/latch/idl/admin/launch_plan.py @@ -0,0 +1,143 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Optional + +import flyteidl.admin.launch_plan_pb2 as pb +import google.protobuf.wrappers_pb2 as pb_wrap + +from ..core.execution import QualityOfService +from ..core.identifier import Identifier +from ..core.interface import ParameterMap +from ..core.literals import LiteralMap +from ..core.security import SecurityContext +from ..utils import to_idl_many, try_to_idl +from .common import Annotations, AuthRole, Labels, Notification, RawOutputDataConfig +from .schedule import Schedule + + +@dataclass +class Auth: + """ + Defines permissions associated with executions created by this launch plan spec. + Use either of these roles when they have permissions required by your workflow execution. + Deprecated. + """ + + assumable_iam_role: str + """Defines an optional iam role which will be used for tasks run in executions created with this launch plan.""" + + kubernetes_service_account: str + """Defines an optional kubernetes service account which will be used for tasks run in executions created with this launch plan.""" + + def to_idl(self) -> pb.Auth: + return pb.Auth( + assumable_iam_role=self.assumable_iam_role, + kubernetes_service_account=self.kubernetes_service_account, + ) + + +@dataclass +class LaunchPlanSpec: + """User-provided launch plan definition and configuration values.""" + + workflow_id: Identifier + """Reference to the Workflow template that the launch plan references""" + + entity_metadata: "LaunchPlanMetadata" + """Metadata for the Launch Plan""" + + default_inputs: ParameterMap + """ + Input values to be passed for the execution. + These can be overriden when an execution is created with this launch plan. + """ + + fixed_inputs: LiteralMap + """ + Fixed, non-overridable inputs for the Launch Plan. + These can not be overriden when an execution is created with this launch plan. 
+ """ + + """ + String to indicate the role to use to execute the workflow underneath + + Deprecated + """ + role: str + + labels: Labels + """Custom labels to be applied to the execution resource.""" + + annotations: Annotations + """Custom annotations to be applied to the execution resource.""" + + security_context: SecurityContext + """Indicates security context for permissions triggered with this launch plan""" + + quality_of_service: QualityOfService + """Indicates the runtime priority of the execution.""" + + raw_output_data_config: RawOutputDataConfig + """Encapsulates user settings pertaining to offloaded data (i.e. Blobs, Schema, query data, etc.).""" + + max_parallelism: int + """ + Controls the maximum number of tasknodes that can be run in parallel for the entire workflow. + This is useful to achieve fairness. Note: MapTasks are regarded as one unit, + and parallelism/concurrency of MapTasks is independent from this. + """ + + interruptible: Optional[bool] = None + """ + Allows for the interruptible flag of a workflow to be overwritten for a single execution. + Omitting this field uses the workflow's value as a default. + As we need to distinguish between the field not being provided and its default value false, we have to use a wrapper + around the bool field. + """ + + auth: Optional[Auth] = None + """ + Indicates the permission associated with workflow executions triggered with this launch plan. + + Deprecated + """ + + auth_role: Optional[AuthRole] = None + + def to_idl(self) -> pb.LaunchPlanSpec: + return pb.LaunchPlanSpec( + workflow_id=self.workflow_id.to_idl(), + entity_metadata=self.entity_metadata.to_idl(), + default_inputs=self.default_inputs.to_idl(), + fixed_inputs=self.fixed_inputs.to_idl(), + role=self.role, + labels=self.labels.to_idl(), + annotations=self.annotations.to_idl(), + auth=try_to_idl(self.auth), + auth_role=try_to_idl(self.auth_role), + security_context=self.security_context.to_idl(), + quality_of_service=self.quality_of_service.to_idl(), + raw_output_data_config=self.raw_output_data_config.to_idl(), + max_parallelism=self.max_parallelism, + interruptible=pb_wrap.BoolValue(value=self.interruptible), + ) + + +@dataclass +class LaunchPlanMetadata: + """ + Additional launch plan attributes included in the LaunchPlanSpec not strictly required to launch + the reference workflow. + """ + + schedule: Schedule + """Schedule to execute the Launch Plan""" + + notifications: Iterable[Notification] + """List of notifications based on Execution status transitions""" + + def to_idl(self) -> pb.LaunchPlanMetadata: + return pb.LaunchPlanMetadata( + schedule=self.schedule.to_idl(), + notifications=to_idl_many(self.notifications), + ) diff --git a/latch/idl/admin/schedule.py b/latch/idl/admin/schedule.py new file mode 100644 index 00000000..5611f0ac --- /dev/null +++ b/latch/idl/admin/schedule.py @@ -0,0 +1,95 @@ +import typing +from dataclasses import dataclass +from enum import Enum + +import flyteidl.admin.schedule_pb2 as pb + +from ..utils import merged_pb + + +class FixedRateUnit(int, Enum): + """Represents a frequency at which to run a schedule.""" + + minute = pb.MINUTE + hour = pb.HOUR + day = pb.DAY + + def to_idl(self) -> pb.FixedRateUnit: + return self.value + + +@dataclass +class FixedRate: + """Option for schedules run at a certain frequency e.g. 
every 2 minutes.""" + + value: int + unit: FixedRateUnit + + def to_idl(self) -> pb.FixedRate: + return pb.FixedRate(value=self.value, unit=self.unit.to_idl()) + + +@dataclass +class CronSchedule: + """Options for schedules to run according to a cron expression.""" + + schedule: str + """ + Standard/default cron implementation as described by https://en.wikipedia.org/wiki/Cron#CRON_expression; + Also supports nonstandard predefined scheduling definitions + as described by https://docs.aws.amazon.com/AmazonCloudWatch/latest/events/ScheduledEvents.html#CronExpressions + except @reboot + """ + + offset: str + """ISO 8601 duration as described by https://en.wikipedia.org/wiki/ISO_8601#Durations""" + + def to_idl(self) -> pb.CronSchedule: + return pb.CronSchedule(schedule=self.schedule, offset=self.offset) + + +@dataclass +class Schedule: + """Defines complete set of information required to trigger an execution on a schedule.""" + + ScheduleExpression: "typing.Union[ScheduleExpressionCronExpression, ScheduleExpressionFixedRate, ScheduleExpressionCronSchedule]" + + kickoff_time_input_arg: str + """Name of the input variable that the kickoff time will be supplied to when the workflow is kicked off.""" + + def to_idl(self) -> pb.Schedule: + return merged_pb( + pb.Schedule(kickoff_time_input_arg=self.kickoff_time_input_arg), + self.ScheduleExpression, + ) + + +@dataclass +class ScheduleExpressionCronExpression: + """ + Uses AWS syntax: Minutes Hours Day-of-month Month Day-of-week Year + e.g. for a schedule that runs every 15 minutes: 0/15 * * * ? * + + Deprecated + """ + + cron_expression: str + + def to_idl(self) -> pb.Schedule: + return pb.Schedule(cron_expression=self.cron_expression) + + +@dataclass +class ScheduleExpressionFixedRate: + rate: FixedRate + + def to_idl(self) -> pb.Schedule: + return pb.Schedule(rate=self.rate.to_idl()) + + +@dataclass +class ScheduleExpressionCronSchedule: + cron_schedule: CronSchedule + + def to_idl(self) -> pb.Schedule: + return pb.Schedule(cron_schedule=self.cron_schedule.to_idl()) diff --git a/latch/idl/admin/workflow.py b/latch/idl/admin/workflow.py new file mode 100644 index 00000000..724c7da9 --- /dev/null +++ b/latch/idl/admin/workflow.py @@ -0,0 +1,28 @@ +from collections.abc import Iterable +from dataclasses import dataclass, field + +import flyteidl.admin.workflow_pb2 as pb + +from ..core.workflow import WorkflowTemplate +from ..utils import to_idl_many + + +@dataclass +class WorkflowSpec: + """Represents a structure that encapsulates the specification of the workflow.""" + + template: WorkflowTemplate + """Template of the task that encapsulates all the metadata of the workflow.""" + + sub_workflows: Iterable[WorkflowTemplate] = field(default_factory=list) + """ + Workflows that are embedded into other workflows need to be passed alongside the parent workflow to the + propeller compiler (since the compiler doesn't have any knowledge of other workflows - ie, it doesn't reach out + to Admin to see other registered workflows). In fact, subworkflows do not even need to be registered. 
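+
+    Illustrative example (a sketch; `main_template` and `helper_template` are
+    assumed to be WorkflowTemplate instances built elsewhere):
+
+        WorkflowSpec(
+            template=main_template,
+            sub_workflows=[helper_template],
+        ).to_idl()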
+ """ + + def to_idl(self) -> pb.WorkflowSpec: + return pb.WorkflowSpec( + template=self.template.to_idl(), + sub_workflows=to_idl_many(self.sub_workflows), + ) diff --git a/latch/idl/core/condition.py b/latch/idl/core/condition.py new file mode 100644 index 00000000..4b339e34 --- /dev/null +++ b/latch/idl/core/condition.py @@ -0,0 +1,124 @@ +import typing +from dataclasses import dataclass +from enum import Enum + +import flyteidl.core.condition_pb2 as pb + +from .literals import Primitive + + +@dataclass +class ComparsionExpression: + """ + Defines a 2-level tree where the root is a comparison operator and Operands are primitives or known variables. + Each expression results in a boolean result. + """ + + class Operator(int, Enum): + """Binary Operator for each expression""" + + eq = pb.ComparisonExpression.EQ + neq = pb.ComparisonExpression.NEQ + # Greater Than + gt = pb.ComparisonExpression.GT + gte = pb.ComparisonExpression.GTE + # Less Than + lt = pb.ComparisonExpression.LT + lte = pb.ComparisonExpression.LTE + + def to_idl(self) -> pb.ComparisonExpression.Operator: + return self.value + + operator: Operator + left_value: "Operand" + right_value: "Operand" + + def to_idl(self) -> pb.ComparisonExpression: + return pb.ComparisonExpression( + operator=self.operator.to_idl(), + left_value=self.left_value.to_idl(), + right_value=self.right_value.to_idl(), + ) + + +@dataclass +class Operand: + """Defines an operand to a comparison expression.""" + + val: "typing.Union[OperandPrimitive, OperandVar]" + + def to_idl(self) -> pb.Operand: + return self.val.to_idl() + + +@dataclass +class OperandPrimitive: + primitive: Primitive + """Can be a constant""" + + def to_idl(self) -> pb.Operand: + return pb.Operand(primitive=self.primitive.to_idl()) + + +@dataclass +class OperandVar: + var: str + """Or one of this node's input variables""" + + def to_idl(self) -> pb.Operand: + return pb.Operand(var=self.var) + + +@dataclass +class BooleanExpression: + """ + Defines a boolean expression tree. It can be a simple or a conjunction expression. + Multiple expressions can be combined using a conjunction or a disjunction to result in a final boolean result. + """ + + expr: "typing.Union[BooleanExpressionConjuctionExpression, BooleanExpressionComparisonExpression]" + + def to_idl(self) -> pb.BooleanExpression: + return self.expr.to_idl() + + +@dataclass +class BooleanExpressionConjuctionExpression: + conjunction_expression: "ConjuctionExpression" + + def to_idl(self) -> pb.BooleanExpression: + return pb.BooleanExpression(conjunction=self.conjunction_expression.to_idl()) + + +@dataclass +class BooleanExpressionComparisonExpression: + comparison_expression: ComparsionExpression + + def to_idl(self) -> pb.BooleanExpression: + return pb.BooleanExpression(comparison=self.comparison_expression.to_idl()) + + +@dataclass +class ConjuctionExpression: + """Defines a conjunction expression of two boolean expressions.""" + + class LogicalOperator(int, Enum): + """Nested conditions. 
They can be conjoined using AND / OR""" + + # Conjunction + and_ = pb.ConjunctionExpression.AND + or_ = pb.ConjunctionExpression.OR + + def to_idl(self) -> pb.ConjunctionExpression.LogicalOperator: + return self.value + + operator: LogicalOperator + left_expression: BooleanExpression + right_expression: BooleanExpression + + def to_idl(self) -> pb.ConjunctionExpression: + return pb.ConjunctionExpression( + operator=self.operator.to_idl(), + left_expression=self.left_expression.to_idl(), + right_expression=self.right_expression.to_idl(), + ) diff --git a/latch/idl/core/execution.py b/latch/idl/core/execution.py new file mode 100644 index 00000000..67a34bd1 --- /dev/null +++ b/latch/idl/core/execution.py @@ -0,0 +1,80 @@ +import typing +from dataclasses import dataclass, field +from datetime import timedelta +from enum import Enum + +import flyteidl.core.execution_pb2 as pb + +from ..utils import dur_from_td + + +@dataclass +class WorkflowExecution: + """Indicates various phases of Workflow Execution""" + + class Phase(int, Enum): + undefined = pb.WorkflowExecution.UNDEFINED + queued = pb.WorkflowExecution.QUEUED + running = pb.WorkflowExecution.RUNNING + succeeding = pb.WorkflowExecution.SUCCEEDING + succeeded = pb.WorkflowExecution.SUCCEEDED + failing = pb.WorkflowExecution.FAILING + failed = pb.WorkflowExecution.FAILED + aborted = pb.WorkflowExecution.ABORTED + timed_out = pb.WorkflowExecution.TIMED_OUT + aborting = pb.WorkflowExecution.ABORTING + + def to_idl(self) -> pb.WorkflowExecution.Phase: + return self.value + + +@dataclass +class QualityOfServiceSpec: + """Represents customized execution run-time attributes.""" + + queueing_budget: timedelta + """Indicates how much queueing delay an execution can tolerate.""" + + # Add future, user-configurable options here + + def to_idl(self) -> pb.QualityOfServiceSpec: + return pb.QualityOfServiceSpec(dur_from_td(self.queueing_budget)) + + +@dataclass +class QualityOfService: + """Indicates the priority of an execution.""" + + class Tier(int, Enum): + undefined = pb.QualityOfService.UNDEFINED + """Default: no quality of service specified.""" + + high = pb.QualityOfService.HIGH + medium = pb.QualityOfService.MEDIUM + low = pb.QualityOfService.LOW + + def to_idl(self) -> pb.QualityOfService.Tier: + return self.value + + designation: "typing.Union[QualityOfServiceDesignationTier, QualityOfServiceDesignationSpec]" = field( + default_factory=lambda: QualityOfServiceDesignationTier() + ) + + def to_idl(self) -> pb.QualityOfService: + return self.designation.to_idl() + + +@dataclass +class QualityOfServiceDesignationTier: + tier: QualityOfService.Tier = QualityOfService.Tier.undefined + + def to_idl(self) -> pb.QualityOfService: + return pb.QualityOfService(tier=self.tier.to_idl()) + + +@dataclass +class QualityOfServiceDesignationSpec: + spec: QualityOfServiceSpec + + def to_idl(self) -> pb.QualityOfService: + return pb.QualityOfService(spec=self.spec.to_idl()) diff --git a/latch/idl/core/identifier.py b/latch/idl/core/identifier.py new file mode 100644 index 00000000..8c4b2561 --- /dev/null +++ b/latch/idl/core/identifier.py @@ -0,0 +1,100 @@ +from dataclasses import dataclass +from enum import Enum + +import flyteidl.core.identifier_pb2 as pb + + +class ResourceType(int, Enum): + unspecified = pb.UNSPECIFIED + task = pb.TASK + workflow = pb.WORKFLOW + launch_plan = pb.LAUNCH_PLAN + dataset = pb.DATASET + """ + A dataset represents an entity modeled in Flyte DataCatalog. 
A Dataset is also a versioned entity and can be a compilation of multiple individual objects. + Eventually all Catalog objects should be modeled similar to Flyte Objects. The Dataset entities makes it possible for the UI and CLI to act on the objects + in a similar manner to other Flyte objects + """ + + def to_idl(self) -> pb.ResourceType: + return self.value + + +@dataclass +class Identifier: + """Encapsulation of fields that uniquely identifies a Flyte resource.""" + + resource_type: ResourceType + """Identifies the specific type of resource that this identifier corresponds to.""" + + project: str + """Name of the project the resource belongs to.""" + domain: str + """ + Name of the domain the resource belongs to. + A domain can be considered as a subset within a specific project. + """ + name: str + """User provided value for the resource.""" + version: str + """Specific version of the resource.""" + + def to_idl(self) -> pb.Identifier: + return pb.Identifier( + resource_type=self.resource_type.to_idl(), + project=self.project, + domain=self.domain, + name=self.name, + version=self.version, + ) + + +@dataclass +class WorkflowExecutionIdentifier: + """Encapsulation of fields that uniquely identifies a Flyte workflow execution""" + + project: str + """Name of the project the resource belongs to.""" + domain: str + """ + Name of the domain the resource belongs to. + A domain can be considered as a subset within a specific project. + """ + name: str + """User provided value for the resource.""" + + def to_idl(self) -> pb.WorkflowExecutionIdentifier: + return pb.WorkflowExecutionIdentifier( + project=self.project, + domain=self.domain, + name=self.name, + ) + + +@dataclass +class NodeExecutionIdentifier: + """Encapsulation of fields that identify a Flyte node execution entity.""" + + node_id: str + execution_id: WorkflowExecutionIdentifier + + def to_idl(self) -> pb.NodeExecutionIdentifier: + return pb.NodeExecutionIdentifier( + node_id=self.node_id, execution_id=self.execution_id.to_idl() + ) + + +@dataclass +class TaskExecutionIdentifier: + """Encapsulation of fields that identify a Flyte task execution entity.""" + + task_id: Identifier + node_execution_id: NodeExecutionIdentifier + retry_attempt: int = 0 + + def to_idl(self) -> pb.TaskExecutionIdentifier: + return pb.TaskExecutionIdentifier( + task_id=self.task_id.to_idl(), + node_execution_id=self.node_execution_id.to_idl(), + retry_attempt=self.retry_attempt, + ) diff --git a/latch/idl/core/interface.py b/latch/idl/core/interface.py new file mode 100644 index 00000000..5d7e6056 --- /dev/null +++ b/latch/idl/core/interface.py @@ -0,0 +1,97 @@ +import typing +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Optional + +import flyteidl.core.interface_pb2 as pb + +from ..utils import merged_pb, to_idl_mapping +from .literals import Literal +from .types import LiteralType + + +@dataclass +class Variable: + """Defines a strongly typed variable.""" + + type: LiteralType + """Variable literal type.""" + + description: str + """+optional string describing input variable""" + + def to_idl(self) -> pb.Variable: + return pb.Variable(type=self.type.to_idl(), description=self.description) + + +@dataclass +class VariableMap: + """A map of Variables""" + + variables: Mapping[str, Variable] + """Defines a map of variable names to variables.""" + + def to_idl(self) -> pb.VariableMap: + return pb.VariableMap(variables=to_idl_mapping(self.variables)) + + +@dataclass +class TypedInterface: + """Defines strongly 
typed inputs and outputs.""" + + inputs: VariableMap + outputs: VariableMap + + def to_idl(self) -> pb.TypedInterface: + return pb.TypedInterface( + inputs=self.inputs.to_idl(), outputs=self.outputs.to_idl() + ) + + +@dataclass +class Parameter: + """ + A parameter is used as input to a launch plan and has + the special ability to have a default value or mark itself as required. + """ + + var: Variable + """+required Variable. Defines the type of the variable backing this parameter.""" + + behavior: "Optional[typing.Union[ParameterBehaviorDefault, ParameterBehaviorRequired]]" = ( + None + ) + + def to_idl(self) -> pb.Parameter: + return merged_pb(pb.Parameter(var=self.var.to_idl()), self.behavior) + + +@dataclass +class ParameterBehaviorDefault: + """Defines a default value that has to match the variable type defined.""" + + default: Literal + + def to_idl(self) -> pb.Parameter: + return pb.Parameter(default=self.default.to_idl()) + + +@dataclass +class ParameterBehaviorRequired: + """+optional, is this value required to be filled.""" + + required: bool + + def to_idl(self) -> pb.Parameter: + return pb.Parameter(required=self.required) + + +@dataclass +class ParameterMap: + """A map of Parameters.""" + + parameters: Mapping[str, Parameter] + """Defines a map of parameter names to parameters.""" + + def to_idl(self) -> pb.ParameterMap: + return pb.ParameterMap(parameters=to_idl_mapping(self.parameters)) diff --git a/latch/idl/core/literals.py b/latch/idl/core/literals.py new file mode 100644 index 00000000..6cc7d002 --- /dev/null +++ b/latch/idl/core/literals.py @@ -0,0 +1,521 @@ +import typing +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from datetime import datetime as datetime_ +from datetime import timedelta +from typing import Optional + +import flyteidl.core.literals_pb2 as pb +import google.protobuf.struct_pb2 as pb_struct +import google.protobuf.timestamp_pb2 as pb_ts + +from ..utils import ( + dur_from_td, + merged_pb, + timestamp_from_datetime, + to_idl_many, + to_idl_mapping, + try_to_idl, +) +from .types import ( + BlobType, + Error, + LiteralType, + OutputReference, + SchemaType, + StructuredDatasetType, +) + + +@dataclass +class Primitive: + """Primitive Types""" + + value: "typing.Union[PrimitiveInt, PrimitiveFloat, PrimitiveString, PrimitiveBoolean, PrimitiveDatetime, PrimitiveDuration]" + """ + Defines one of simple primitive types. These types will get translated into different programming languages as + described in https://developers.google.com/protocol-buffers/docs/proto#scalar. 
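+
+    Illustrative examples (sketches):
+
+        Primitive(value=PrimitiveInt(integer=42)).to_idl()
+        Primitive(value=PrimitiveDuration(duration=timedelta(minutes=5))).to_idl()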
+ """ + + def to_idl(self) -> pb.Primitive: + return self.value.to_idl() + + +@dataclass +class PrimitiveInt: + integer: int + + def to_idl(self) -> pb.Primitive: + return pb.Primitive(integer=self.integer) + + +@dataclass +class PrimitiveFloat: + float_value: float + + def to_idl(self) -> pb.Primitive: + return pb.Primitive(float_value=self.float_value) + + +@dataclass +class PrimitiveString: + string_value: str + + def to_idl(self) -> pb.Primitive: + return pb.Primitive(string_value=self.string_value) + + +@dataclass +class PrimitiveBoolean: + boolean: bool + + def to_idl(self) -> pb.Primitive: + return pb.Primitive(boolean=self.boolean) + + +@dataclass +class PrimitiveDatetime: + datetime: datetime_ + + def to_idl(self) -> pb.Primitive: + return pb.Primitive(datetime=timestamp_from_datetime(self.datetime)) + + +@dataclass +class PrimitiveDuration: + duration: timedelta + + def to_idl(self) -> pb.Primitive: + return pb.Primitive(duration=dur_from_td(self.duration)) + + +@dataclass +class Void: + """ + Used to denote a nil/null/None assignment to a scalar value. The underlying LiteralType for Void is intentionally + undefined since it can be assigned to a scalar of any LiteralType. + + maximsmol: note: Void can no longer be assigned to a scalar of any type since union types were introduced + """ + + def to_idl(self) -> pb.Void: + return pb.Void() + + +@dataclass +class Blob: + """ + Refers to an offloaded set of files. It encapsulates the type of the store and a unique uri for where the data is. + There are no restrictions on how the uri is formatted since it will depend on how to interact with the store. + """ + + metadata: "BlobMetadata" + uri: str + + def to_idl(self) -> pb.Blob: + return pb.Blob(metadata=self.metadata.to_idl(), uri=self.uri) + + +@dataclass +class BlobMetadata: + type: BlobType + + def to_idl(self) -> pb.BlobMetadata: + return pb.BlobMetadata(type=self.type.to_idl()) + + +@dataclass +class Binary: + """ + A simple byte array with a tag to help different parts of the system communicate about what is in the byte array. + It's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data. + """ + + value: bytes + tag: str + + def to_idl(self) -> pb.Binary: + return pb.Binary(value=self.value, tag=self.tag) + + +@dataclass +class Schema: + """ + A strongly typed schema that defines the interface of data retrieved from the underlying storage medium. + + maximsmol: note: pretty much unsupported + """ + + uri: str + type: SchemaType + + def to_idl(self) -> pb.Schema: + return pb.Schema(uri=self.uri, type=self.type.to_idl()) + + +@dataclass +class Union: + """The runtime representation of a tagged union value. See `UnionType` for more details.""" + + value: "Literal" + type: LiteralType + + def to_idl(self) -> pb.Union: + return pb.Union(value=self.value.to_idl(), type=self.type.to_idl()) + + +@dataclass +class RecordField: + key: str + value: "Literal" + + def to_idl(self) -> pb.RecordField: + return pb.RecordField(key=self.key, value=self.value.to_idl()) + + +@dataclass +class Record: + fields: Iterable[RecordField] + + def to_idl(self) -> pb.Record: + return pb.Record(fields=to_idl_many(self.fields)) + + +@dataclass +class StructuredDatasetMetadata: + """ + Bundle the type information along with the literal. + This is here because StructuredDatasets can often be more defined at run time than at compile time. 
+ That is, at compile time you might only declare a task to return a pandas dataframe or a StructuredDataset, + without any column information, but at run time, you might have that column information. + flytekit python will copy this type information into the literal, from the type information, if not provided by + the various plugins (encoders). + Since this field is run time generated, it's not used for any type checking. + """ + + structured_dataset_type: StructuredDatasetType + + def to_idl(self) -> pb.StructuredDatasetMetadata: + return pb.StructuredDatasetMetadata( + structured_dataset_type=self.structured_dataset_type.to_idl() + ) + + +@dataclass +class StructuredDataset: + uri: str + """ + String location uniquely identifying where the data is. + Should start with the storage location (e.g. s3://, gs://, bq://, etc.) + """ + + metadata: StructuredDatasetMetadata + + def to_idl(self) -> pb.StructuredDataset: + return pb.StructuredDataset(uri=self.uri, metadata=self.metadata.to_idl()) + + +@dataclass +class Scalar: + value: "typing.Union[ScalarPrimitive, ScalarBlob, ScalarBinary, ScalarSchema, ScalarVoid, ScalarError, ScalarGeneric, ScalarStructuredDataset, ScalarUnion]" + + def to_idl(self) -> pb.Scalar: + raise NotImplementedError() + + +@dataclass +class ScalarPrimitive: + primitive: Primitive + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(primitive=self.primitive.to_idl()) + + +@dataclass +class ScalarBlob: + blob: Blob + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(blob=self.blob.to_idl()) + + +@dataclass +class ScalarBinary: + binary: Binary + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(binary=self.binary.to_idl()) + + +@dataclass +class ScalarSchema: + schema: Schema + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(schema=self.schema.to_idl()) + + +@dataclass +class ScalarVoid: + none_type: Void + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(none_type=self.none_type.to_idl()) + + +@dataclass +class ScalarError: + error: Error + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(error=self.error.to_idl()) + + +@dataclass +class ScalarGeneric: + """maximsmol: note: use Records i.e. dataclasses instead""" + + generic: pb_struct.Struct + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(generic=self.generic) + + +@dataclass +class ScalarStructuredDataset: + structured_dataset: StructuredDataset + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(structured_dataset=self.structured_dataset.to_idl()) + + +@dataclass +class ScalarUnion: + union: Union + + def to_idl(self) -> pb.Scalar: + return pb.Scalar(union=self.union.to_idl()) + + +@dataclass +class Literal: + value: "typing.Union[LiteralScalar, LiteralLiteralCollection, LiteralLiteralMap, LiteralRecord]" + + hash: Optional[str] = None + """ + A hash representing this literal. + This is used for caching purposes. 
For more details refer to RFC 1893 + (https://github.com/flyteorg/flyte/blob/516dd3926957af83c1c3ba6c12817477486be5c5/rfc/system/1893-caching-of-offloaded-objects.md) + """ + + def to_idl(self) -> pb.Literal: + return merged_pb(pb.Literal(hash=self.hash), self.value) + + +@dataclass +class LiteralScalar: + """A simple value.""" + + scalar: Scalar + + def to_idl(self) -> pb.Literal: + return pb.Literal(scalar=self.scalar.to_idl()) + + +@dataclass +class LiteralLiteralCollection: + """A collection of literals to allow nesting.""" + + collection: "LiteralCollection" + + def to_idl(self) -> pb.Literal: + return pb.Literal(collection=self.collection.to_idl()) + + +@dataclass +class LiteralLiteralMap: + """A map of strings to literals.""" + + map: "LiteralMap" + + def to_idl(self) -> pb.Literal: + return pb.Literal(map=self.map.to_idl()) + + +@dataclass +class LiteralRecord: + record: Record + + def to_idl(self) -> pb.Literal: + return pb.Literal(record=self.record.to_idl()) + + +@dataclass +class LiteralCollection: + """A collection of literals. This is a workaround since oneofs in proto messages cannot contain a repeated field.""" + + literals: Iterable[Literal] + + def to_idl(self) -> pb.LiteralCollection: + return pb.LiteralCollection(literals=to_idl_many(self.literals)) + + +@dataclass +class LiteralMap: + """A map of literals. This is a workaround since oneofs in proto messages cannot contain a repeated field.""" + + literals: Mapping[str, Literal] + + def to_idl(self) -> pb.LiteralMap: + return pb.LiteralMap(literals=to_idl_mapping(self.literals)) + + +@dataclass +class BindingDataCollection: + """A collection of BindingData items.""" + + bindings: "Iterable[BindingData]" + + def to_idl(self) -> pb.BindingDataCollection: + return pb.BindingDataCollection(bindings=to_idl_many(self.bindings)) + + +@dataclass +class BindingDataMap: + """A map of BindingData items.""" + + bindings: "Mapping[str, BindingData]" + + def to_idl(self) -> pb.BindingDataMap: + return pb.BindingDataMap(bindings=to_idl_mapping(self.bindings)) + + +@dataclass +class BindingDataRecordField: + key: str + binding: "BindingData" + + def to_idl(self) -> pb.BindingDataRecordField: + return pb.BindingDataRecordField(key=self.key, binding=self.binding.to_idl()) + + +@dataclass +class BindingDataRecord: + fields: Iterable[BindingDataRecordField] + + def to_idl(self) -> pb.BindingDataRecord: + return pb.BindingDataRecord(fields=to_idl_many(self.fields)) + + +@dataclass +class UnionInfo: + targetType: LiteralType + + def to_idl(self) -> pb.UnionInfo: + return pb.UnionInfo(targetType=self.targetType.to_idl()) + + +@dataclass +class BindingData: + """Specifies either a simple value or a reference to another output.""" + + value: "typing.Union[BindingDataScalar, BindingDataBindingCollection, BindingDataPromise, BindingDataBindingMap, BindingDataBindingRecord]" + + union: Optional[UnionInfo] = None + + def to_idl(self) -> pb.BindingData: + return merged_pb(pb.BindingData(union=try_to_idl(self.union)), self.value) + + +@dataclass +class BindingDataScalar: + """A simple scalar value.""" + + scalar: Scalar + + def to_idl(self) -> pb.BindingData: + return pb.BindingData(scalar=self.scalar.to_idl()) + + +@dataclass +class BindingDataBindingCollection: + """ + A collection of binding data. This allows nesting of binding data to any number + of levels. 
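+
+    Illustrative example (a sketch; `some_scalar` is an assumed Scalar value):
+
+        BindingDataBindingCollection(
+            collection=BindingDataCollection(bindings=[
+                BindingData(value=BindingDataScalar(scalar=some_scalar)),
+            ])
+        )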
+ """ + + collection: BindingDataCollection + + def to_idl(self) -> pb.BindingData: + return pb.BindingData(collection=self.collection.to_idl()) + + +@dataclass +class BindingDataPromise: + """References an output promised by another node.""" + + promise: OutputReference + + def to_idl(self) -> pb.BindingData: + return pb.BindingData(promise=self.promise.to_idl()) + + +@dataclass +class BindingDataBindingMap: + """A map of bindings. The key is always a string.""" + + map: BindingDataMap + + def to_idl(self) -> pb.BindingData: + return pb.BindingData(map=self.map.to_idl()) + + +@dataclass +class BindingDataBindingRecord: + """A map of bindings. The key is always a string.""" + + record: BindingDataRecord + + def to_idl(self) -> pb.BindingData: + return pb.BindingData(record=self.record.to_idl()) + + +@dataclass +class Binding: + """An input/output binding of a variable to either static value or a node output.""" + + var: str + """Variable name must match an input/output variable of the node.""" + binding: BindingData + """Data to use to bind this variable.""" + + def to_idl(self) -> pb.Binding: + return pb.Binding(var=self.var, binding=self.binding.to_idl()) + + +@dataclass +class KeyValuePair: + """A generic key value pair.""" + + key: str + """required.""" + + value: str + """+optional.""" + + def to_idl(self) -> pb.KeyValuePair: + return pb.KeyValuePair(key=self.key, value=self.value) + + +@dataclass +class RetryStrategy: + """Retry strategy associated with an executable unit.""" + + retries: int + """ + Number of retries. Retries will be consumed when the job fails with a recoverable error. + The number of retries must be less than or equals to 10. + """ + + def to_idl(self) -> pb.RetryStrategy: + return pb.RetryStrategy(retries=self.retries) diff --git a/latch/idl/core/security.py b/latch/idl/core/security.py new file mode 100644 index 00000000..170cf16a --- /dev/null +++ b/latch/idl/core/security.py @@ -0,0 +1,217 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +import flyteidl.core.security_pb2 as pb + +from ..utils import to_idl_many, try_to_idl + + +@dataclass +class Secret: + """ + Secret encapsulates information about the secret a task needs to proceed. An environment variable + FLYTE_SECRETS_ENV_PREFIX will be passed to indicate the prefix of the environment variables that will be present if + secrets are passed through environment variables. + FLYTE_SECRETS_DEFAULT_DIR will be passed to indicate the prefix of the path where secrets will be mounted if secrets + are passed through file mounts. + """ + + class MountType(int, Enum): + any = pb.Secret.ANY + """Default case, indicates the client can tolerate either mounting options.""" + + env_var = pb.Secret.ENV_VAR + """ENV_VAR indicates the secret needs to be mounted as an environment variable.""" + + file = pb.Secret.FILE + """FILE indicates the secret needs to be mounted as a file.""" + + def to_idl(self) -> pb.Secret.MountType: + return self.value + + group: str + """ + The name of the secret group where to find the key referenced below. For K8s secrets, this should be the name of + the v1/secret object. For Confidant, this should be the Credential name. For Vault, this should be the secret name. + For AWS Secret Manager, this should be the name of the secret. + +required + """ + + group_version: Optional[str] = None + """ + The group version to fetch. This is not supported in all secret management systems. 
It'll be ignored for the ones + that do not support it. + +optional + """ + + key: Optional[str] = None + """ + The name of the secret to mount. This has to match an existing secret in the system. It's up to the implementation + of the secret management system to require case sensitivity. For K8s secrets, Confidant and Vault, this should + match one of the keys inside the secret. For AWS Secret Manager, it's ignored. + +optional + """ + + mount_requirement: Optional[MountType] = None + """ + mount_requirement is optional. Indicates where the secret has to be mounted. If provided, the execution will fail + if the underlying key management system cannot satisfy that requirement. If not provided, the default location + will depend on the key management system. + +optional + """ + + def to_idl(self) -> pb.Secret: + return pb.Secret( + group=self.group, + group_version=self.group_version, + key=self.key, + mount_requirement=try_to_idl(self.mount_requirement), + ) + + +@dataclass +class OAuth2Client: + """OAuth2Client encapsulates OAuth2 Client Credentials to be used when making calls on behalf of that task.""" + + client_id: str + """ + client_id is the public id for the client to use. The system will not perform any pre-auth validation that the + secret requested matches the client_id indicated here. + +required + """ + + client_secret: Secret + """ + client_secret is a reference to the secret used to authenticate the OAuth2 client. + +required + """ + + def to_idl(self) -> pb.OAuth2Client: + return pb.OAuth2Client( + client_id=self.client_id, client_secret=self.client_secret.to_idl() + ) + + +@dataclass +class Identity: + """ + Identity encapsulates the various security identities a task can run as. It's up to the underlying plugin to pick the + right identity for the execution environment. + """ + + iam_role: str + """iam_role references the fully qualified name of Identity & Access Management role to impersonate.""" + + k8s_service_account: str + """k8s_service_account references a kubernetes service account to impersonate.""" + + oauth2_client: OAuth2Client + """ + oauth2_client references an oauth2 client. Backend plugins can use this information to impersonate the client when + making external calls. + """ + + def to_idl(self) -> pb.Identity: + return pb.Identity( + iam_role=self.iam_role, + k8s_service_account=self.k8s_service_account, + oauth2_client=self.oauth2_client.to_idl(), + ) + + +@dataclass +class OAuth2TokenRequest: + """ + OAuth2TokenRequest encapsulates information needed to request an OAuth2 token. + FLYTE_TOKENS_ENV_PREFIX will be passed to indicate the prefix of the environment variables that will be present if + tokens are passed through environment variables. + FLYTE_TOKENS_PATH_PREFIX will be passed to indicate the prefix of the path where secrets will be mounted if tokens + are passed through file mounts. + """ + + class Type(int, Enum): + """Type of the token requested.""" + + client_credentials = pb.OAuth2TokenRequest.CLIENT_CREDENTIALS + """CLIENT_CREDENTIALS indicates a 2-legged OAuth token requested using client credentials.""" + + def to_idl(self) -> pb.OAuth2TokenRequest.Type: + return self.value + + name: str + """ + name indicates a unique id for the token request within this task token requests. It'll be used as a suffix for + environment variables and as a filename for mounting tokens as files. + +required + """ + + type: Type + """ + type indicates the type of the request to make. Defaults to CLIENT_CREDENTIALS. 
+ +required + """ + + client: OAuth2Client + """ + client references the client_id/secret to use to request the OAuth2 token. + +required + """ + + idp_discovery_endpoint: Optional[str] = None + """ + idp_discovery_endpoint references the discovery endpoint used to retrieve token endpoint and other related + information. + +optional + """ + + token_endpoint: Optional[str] = None + """ + token_endpoint references the token issuance endpoint. If idp_discovery_endpoint is not provided, this parameter is + mandatory. + +optional + """ + + def to_idl(self) -> pb.OAuth2TokenRequest: + return pb.OAuth2TokenRequest( + name=self.name, + type=self.type.to_idl(), + client=self.client.to_idl(), + idp_discovery_endpoint=self.idp_discovery_endpoint, + token_endpoint=self.token_endpoint, + ) + + +@dataclass +class SecurityContext: + """SecurityContext holds security attributes that apply to tasks.""" + + run_as: Identity + """ + run_as encapsulates the identity a pod should run as. If the task fills in multiple fields here, it'll be up to the + backend plugin to choose the appropriate identity for the execution engine the task will run on. + """ + + secrets: Iterable[Secret] + """ + secrets indicate the list of secrets the task needs in order to proceed. Secrets will be mounted/passed to the + pod as it starts. If the plugin responsible for kicking of the task will not run it on a flyte cluster (e.g. AWS + Batch), it's the responsibility of the plugin to fetch the secret (which means propeller identity will need access + to the secret) and to pass it to the remote execution engine. + """ + + tokens: Iterable[OAuth2TokenRequest] + """ + tokens indicate the list of token requests the task needs in order to proceed. Tokens will be mounted/passed to the + pod as it starts. If the plugin responsible for kicking of the task will not run it on a flyte cluster (e.g. AWS + Batch), it's the responsibility of the plugin to fetch the secret (which means propeller identity will need access + to the secret) and to pass it to the remote execution engine. + """ + + def to_idl(self) -> pb.SecurityContext: + return pb.SecurityContext( + run_as=self.run_as.to_idl(), + secrets=to_idl_many(self.secrets), + tokens=to_idl_many(self.tokens), + ) diff --git a/latch/idl/core/tasks.py b/latch/idl/core/tasks.py new file mode 100644 index 00000000..4df3d9c8 --- /dev/null +++ b/latch/idl/core/tasks.py @@ -0,0 +1,502 @@ +import typing +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from datetime import timedelta +from enum import Enum +from typing import Optional + +import flyteidl.core.tasks_pb2 as pb +import google.protobuf.struct_pb2 as pb_struct + +from ..utils import dur_from_td, merged_pb, to_idl_many +from .identifier import Identifier +from .interface import TypedInterface +from .literals import KeyValuePair, RetryStrategy +from .security import SecurityContext + + +@dataclass +class Resources: + """ + A customizable interface to convey resources requested for a container. This can be interpreted differently for different + container engines. 
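+
+    Illustrative example (a sketch; the quantities are placeholders):
+
+        Resources(
+            requests=[
+                Resources.ResourceEntry(name=Resources.ResourceName.cpu, value="2"),
+                Resources.ResourceEntry(name=Resources.ResourceName.memory, value="4Gi"),
+            ],
+            limits=[],
+        ).to_idl()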
+ """ + + class ResourceName(int, Enum): + """Known resource names.""" + + unknown = pb.Resources.UNKNOWN + cpu = pb.Resources.CPU + gpu = pb.Resources.GPU + memory = pb.Resources.MEMORY + storage = pb.Resources.STORAGE + ephemeral_storage = pb.Resources.EPHEMERAL_STORAGE + """For Kubernetes-based deployments, pods use ephemeral local storage for scratch space, caching, and for logs.""" + + def to_idl(self) -> pb.Resources.ResourceName: + return self.value + + @dataclass + class ResourceEntry: + """Encapsulates a resource name and value.""" + + name: "Resources.ResourceName" + """Resource name.""" + + value: str + """ + Value must be a valid k8s quantity. See + https://github.com/kubernetes/apimachinery/blob/master/pkg/api/resource/quantity.go#L30-L80 + """ + + def to_idl(self) -> pb.Resources.ResourceEntry: + return pb.Resources.ResourceEntry(name=self.name.to_idl(), value=self.value) + + requests: Iterable[ResourceEntry] + """The desired set of resources requested. ResourceNames must be unique within the list.""" + + limits: Iterable[ResourceEntry] + """ + Defines a set of bounds (e.g. min/max) within which the task can reliably run. ResourceNames must be unique + within the list. + """ + + def to_idl(self) -> pb.Resources: + return pb.Resources( + requests=to_idl_many(self.requests), + limits=to_idl_many(self.limits), + ) + + +@dataclass +class RuntimeMetadatta: + """Runtime information. This is loosely defined to allow for extensibility.""" + + class RuntimeType(int, Enum): + other = pb.RuntimeMetadata.OTHER + flyte_sdk = pb.RuntimeMetadata.FLYTE_SDK + + def to_idl(self) -> pb.RuntimeMetadata.RuntimeType: + return self.value + + type: RuntimeType + """Type of runtime.""" + + version: str + """ + Version of the runtime. All versions should be backward compatible. However, certain cases call for version + checks to ensure tighter validation or setting expectations. + """ + + flavor: str + """+optional It can be used to provide extra information about the runtime (e.g. python, golang... etc.).""" + + def to_idl(self) -> pb.RuntimeMetadata: + return pb.RuntimeMetadata( + type=self.type.to_idl(), version=self.version, flavor=self.flavor + ) + + +@dataclass +class TaskMetadata: + """Task Metadata""" + + discoverable: bool + """Indicates whether the system should attempt to lookup this task's output to avoid duplication of work.""" + + runtime: RuntimeMetadatta + """Runtime information about the task.""" + + timeout: timedelta + """The overall timeout of a task including user-triggered retries.""" + + retries: RetryStrategy + """Number of retries per task.""" + + discovery_version: str + """Indicates a logical version to apply to this task for the purpose of discovery.""" + + deprecated_error_message: str + """ + If set, this indicates that this task is deprecated. This will enable owners of tasks to notify consumers + of the ending of support for a given task. + """ + + cache_serializable: bool + """Indicates whether the system should attempt to execute discoverable instances in serial to avoid duplicate work""" + + interruptible: Optional[bool] = None + """ + Identify whether task is interruptible + + For interruptible we will populate it at the node level but require it be part of TaskMetadata + for a user to set the value. + We are using oneof instead of bool because otherwise we would be unable to distinguish between value being + set by the user or defaulting to false. + The logic of handling precedence will be done as part of flytepropeller. 
+ """ + + def to_idl(self) -> pb.TaskMetadata: + res = pb.TaskMetadata( + discoverable=self.discoverable, + runtime=self.runtime.to_idl(), + timeout=dur_from_td(self.timeout), + retries=self.retries.to_idl(), + discovery_version=self.discovery_version, + deprecated_error_message=self.deprecated_error_message, + cache_serializable=self.cache_serializable, + ) + + if self.interruptible is not None: + res.interruptible = self.interruptible + + return res + + +@dataclass +class TaskTemplate: + """ + A Task structure that uniquely identifies a task in the system + Tasks are registered as a first step in the system. + """ + + id: Identifier + """Auto generated taskId by the system. Task Id uniquely identifies this task globally.""" + + type: str + """ + A predefined yet extensible Task type identifier. This can be used to customize any of the components. If no + extensions are provided in the system, Flyte will resolve the this task to its TaskCategory and default the + implementation registered for the TaskCategory. + """ + + metadata: TaskMetadata + """Extra metadata about the task.""" + + interface: TypedInterface + """ + A strongly typed interface for the task. This enables others to use this task within a workflow and guarantees + compile-time validation of the workflow to avoid costly runtime failures. + """ + + custom: pb_struct.Struct + """Custom data about the task. This is extensible to allow various plugins in the system.""" + + target: "typing.Union[TaskTemplateTargetContainer, TaskTemplateTargetK8sPod, TaskTemplateTargetSql]" + """ + Known target types that the system will guarantee plugins for. Custom SDK plugins are allowed to set these if needed. + If no corresponding execution-layer plugins are found, the system will default to handling these using built-in + handlers. + """ + + task_type_version: int + """This can be used to customize task handling at execution time for the same task type.""" + + security_context: SecurityContext + """security_context encapsulates security attributes requested to run this task.""" + + config: Mapping[str, str] + """ + Metadata about the custom defined for this task. This is extensible to allow various plugins in the system + to use as required. + reserve the field numbers 1 through 15 for very frequently occurring message elements + """ + + def to_idl(self) -> pb.TaskTemplate: + return merged_pb( + pb.TaskTemplate( + id=self.id.to_idl(), + type=self.type, + metadata=self.metadata.to_idl(), + interface=self.interface.to_idl(), + custom=self.custom, + task_type_version=self.task_type_version, + security_context=self.security_context.to_idl(), + config=self.config, + ), + self.target, + ) + + +@dataclass +class TaskTemplateTargetContainer: + container: "Container" + + def to_idl(self) -> pb.TaskTemplate: + return pb.TaskTemplate(container=self.container.to_idl()) + + +@dataclass +class TaskTemplateTargetK8sPod: + k8s_pod: "K8sPod" + + def to_idl(self) -> pb.TaskTemplate: + return pb.TaskTemplate(k8s_pod=self.k8s_pod.to_idl()) + + +@dataclass +class TaskTemplateTargetSql: + sql: "Sql" + + def to_idl(self) -> pb.TaskTemplate: + return pb.TaskTemplate(sql=self.sql.to_idl()) + + +@dataclass +class ContainerPort: + """Defines port properties for a container.""" + + container_port: int + """ + Number of port to expose on the pod's IP address. + This must be a valid port number, 0 < x < 65536. 
+ """ + + def to_idl(self) -> pb.ContainerPort: + return pb.ContainerPort(container_port=self.container_port) + + +@dataclass +class Container: + image: str + """Container image url. Eg: docker/redis:latest""" + + command: Iterable[str] + """Command to be executed, if not provided, the default entrypoint in the container image will be used.""" + + args: Iterable[str] + """ + These will default to Flyte given paths. If provided, the system will not append known paths. If the task still + needs flyte's inputs and outputs path, add $(FLYTE_INPUT_FILE), $(FLYTE_OUTPUT_FILE) wherever makes sense and the + system will populate these before executing the container. + """ + + resources: Resources + """Container resources requirement as specified by the container engine.""" + + env: Iterable[KeyValuePair] + """Environment variables will be set as the container is starting up.""" + + config: Iterable[KeyValuePair] + """ + Allows extra configs to be available for the container. + TODO: elaborate on how configs will become available. + Deprecated, please use TaskTemplate.config instead. + """ + + ports: Iterable[ContainerPort] + """ + Ports to open in the container. This feature is not supported by all execution engines. (e.g. supported on K8s but + not supported on AWS Batch) + Only K8s + """ + + data_config: "DataLoadingConfig" + """ + BETA: Optional configuration for DataLoading. If not specified, then default values are used. + This makes it possible to to run a completely portable container, that uses inputs and outputs + only from the local file-system and without having any reference to flyteidl. This is supported only on K8s at the moment. + If data loading is enabled, then data will be mounted in accompanying directories specified in the DataLoadingConfig. If the directories + are not specified, inputs will be mounted onto and outputs will be uploaded from a pre-determined file-system path. Refer to the documentation + to understand the default paths. + Only K8s + """ + + class Architecture(int, Enum): + unknown = pb.Container.UNKNOWN + amd64 = pb.Container.AMD64 + arm64 = pb.Container.ARM64 + arm_v6 = pb.Container.ARM_V6 + arm_v7 = pb.Container.ARM_V7 + + def to_idl(self) -> pb.Container.Architecture: + return self.value + + architecture: Architecture + """Architecture-type the container image supports.""" + + def to_idl(self) -> pb.Container: + return pb.Container( + image=self.image, + command=self.command, + args=self.args, + resources=self.resources.to_idl(), + env=to_idl_many(self.env), + config=to_idl_many(self.config), + ports=to_idl_many(self.ports), + data_config=self.data_config.to_idl(), + architecture=self.architecture.to_idl(), + ) + + +@dataclass +class IOStrategy: + """Strategy to use when dealing with Blob, Schema, or multipart blob data (large datasets)""" + + class DownloadMode(int, Enum): + """Mode to use for downloading""" + + download_eager = pb.IOStrategy.DOWNLOAD_EAGER + """All data will be downloaded before the main container is executed""" + download_stream = pb.IOStrategy.DOWNLOAD_STREAM + """Data will be downloaded as a stream and an End-Of-Stream marker will be written to indicate all data has been downloaded. 
Refer to protocol for details""" + do_not_download = pb.IOStrategy.DO_NOT_DOWNLOAD + """Large objects (offloaded) will not be downloaded""" + + def to_idl(self) -> pb.IOStrategy.DownloadMode: + return self.value + + class UploadMode(int, Enum): + """Mode to use for uploading""" + + uplaod_on_exit = pb.IOStrategy.UPLOAD_ON_EXIT + """All data will be uploaded after the main container exits""" + upload_eager = pb.IOStrategy.UPLOAD_EAGER + """Data will be uploaded as it appears. Refer to protocol specification for details""" + do_not_upload = pb.IOStrategy.DO_NOT_UPLOAD + """Data will not be uploaded, only references will be written""" + + def to_idl(self) -> pb.IOStrategy.UploadMode: + return self.value + + download_mode: DownloadMode + """Mode to use to manage downloads""" + upload_mode: UploadMode + """Mode to use to manage uploads""" + + def to_idl(self) -> pb.IOStrategy: + return pb.IOStrategy( + download_mode=self.download_mode.to_idl(), + upload_mode=self.upload_mode.to_idl(), + ) + + +@dataclass +class DataLoadingConfig: + """ + This configuration allows executing raw containers in Flyte using the Flyte CoPilot system. + Flyte CoPilot, eliminates the needs of flytekit or sdk inside the container. Any inputs required by the users container are side-loaded in the input_path + Any outputs generated by the user container - within output_path are automatically uploaded. + """ + + class LiteralMapFormat(int, Enum): + """ + LiteralMapFormat decides the encoding format in which the input metadata should be made available to the containers. + If the user has access to the protocol buffer definitions, it is recommended to use the PROTO format. + JSON and YAML do not need any protobuf definitions to read it + All remote references in core.LiteralMap are replaced with local filesystem references (the data is downloaded to local filesystem) + """ + + json = pb.DataLoadingConfig.JSON + """JSON for the metadata (which contains inlined primitive values). The representation is inline with the standard json specification as specified - https://www.json.org/json-en.html""" + yaml = pb.DataLoadingConfig.YAML + """YAML for the metadata (which contains inlined primitive values)""" + proto = pb.DataLoadingConfig.PROTO + """Proto is a serialized binary of `core.LiteralMap` defined in flyteidl/core""" + + def to_idl(self) -> pb.DataLoadingConfig.LiteralMapFormat: + return self.value + + enabled: bool + """Flag enables DataLoading Config. If this is not set, data loading will not be used!""" + + input_path: str + """ + File system path (start at root). This folder will contain all the inputs exploded to a separate file. + Example, if the input interface needs (x: int, y: blob, z: multipart_blob) and the input path is "/var/flyte/inputs", then the file system will look like + /var/flyte/inputs/inputs. .pb .json .yaml> -> Format as defined previously. The Blob and Multipart blob will reference local filesystem instead of remote locations + /var/flyte/inputs/x -> X is a file that contains the value of x (integer) in string format + /var/flyte/inputs/y -> Y is a file in Binary format + /var/flyte/inputs/z/... -> Note Z itself is a directory + More information about the protocol - refer to docs #TODO reference docs here + """ + + output_path: str + """File system path (start at root). 
This folder should contain all the outputs for the task as individual files and/or an error text file""" + + format: LiteralMapFormat + """ + In the inputs folder, there will be an additional summary/metadata file that contains references to all files or inlined primitive values. + This format decides the actual encoding for the data. Refer to the encoding to understand the specifics of the contents and the encoding + """ + + io_strategy: IOStrategy + + def to_idl(self) -> pb.DataLoadingConfig: + return pb.DataLoadingConfig( + enabled=self.enabled, + input_path=self.input_path, + output_path=self.output_path, + format=self.format.to_idl(), + io_strategy=self.io_strategy.to_idl(), + ) + + +@dataclass +class K8sPod: + """Defines a pod spec and additional pod metadata that is created when a task is executed.""" + + metadata: "K8sObjectMetadata" + """Contains additional metadata for building a kubernetes pod.""" + + pod_spec: pb_struct.Struct + """ + Defines the primary pod spec created when a task is executed. + This should be a JSON-marshalled pod spec, which can be defined in + - go, using: https://github.com/kubernetes/api/blob/release-1.21/core/v1/types.go#L2936 + - python: using https://github.com/kubernetes-client/python/blob/release-19.0/kubernetes/client/models/v1_pod_spec.py + """ + + def to_idl(self) -> pb.K8sPod: + return pb.K8sPod(metadata=self.metadata.to_idl(), pod_spec=self.pod_spec) + + +@dataclass +class K8sObjectMetadata: + """Metadata for building a kubernetes object when a task is executed.""" + + labels: Mapping[str, str] + """Optional labels to add to the pod definition.""" + + annotations: Mapping[str, str] + """Optional annotations to add to the pod definition.""" + + def to_idl(self) -> pb.K8sObjectMetadata: + return pb.K8sObjectMetadata(labels=self.labels, annotations=self.annotations) + + +@dataclass +class Sql: + """Sql represents a generic sql workload with a statement and dialect.""" + + statement: str + """ + The actual query to run, the query can have templated parameters. + We use Flyte's Golang templating format for Query templating. + Refer to the templating documentation. + https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/external_services/hive/hive.html#sphx-glr-auto-integrations-external-services-hive-hive-py + For example, + insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet + select * + from my_table + where ds = '{{ .Inputs.ds }}' + """ + + class Dialect(int, Enum): + undefined = pb.Sql.UNDEFINED + ansi = pb.Sql.ANSI + hive = pb.Sql.HIVE + other = pb.Sql.OTHER + + def to_idl(self) -> pb.Sql.Dialect: + return self.value + + dialect: Dialect + """ + The dialect of the SQL statement. This is used to validate and parse SQL statements at compilation time to avoid + expensive runtime operations. If set to an unsupported dialect, no validation will be done on the statement. + We support the following dialect: ansi, hive. 
+ """ + + def to_idl(self) -> pb.Sql: + return pb.Sql(statement=self.statement, dialect=self.dialect.to_idl()) diff --git a/latch/idl/core/types.py b/latch/idl/core/types.py new file mode 100644 index 00000000..436d4dac --- /dev/null +++ b/latch/idl/core/types.py @@ -0,0 +1,374 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +import flyteidl.core.types_pb2 as pb +import google.protobuf.struct_pb2 as pb_struct + +from ..utils import merged_pb, to_idl_many, try_to_idl + + +class SimpleType(int, Enum): + """Define a set of simple types.""" + + none = pb.NONE + integer = pb.INTEGER + float = pb.FLOAT + string = pb.STRING + boolean = pb.BOOLEAN + datetime = pb.DATETIME + duration = pb.DURATION + binary = pb.BINARY + error = pb.ERROR + struct = pb.STRUCT + + def to_idl(self) -> pb.SimpleType: + return self.value + + +@dataclass +class SchemaType: + """Defines schema columns and types to strongly type-validate schemas interoperability.""" + + @dataclass + class SchemaColumn: + name: str + """A unique name -within the schema type- for the column""" + + class SchemaColumnType(int, Enum): + """Define a set of simple types.""" + + integer = pb.SchemaType.SchemaColumn.INTEGER + float = pb.SchemaType.SchemaColumn.FLOAT + string = pb.SchemaType.SchemaColumn.STRING + boolean = pb.SchemaType.SchemaColumn.BOOLEAN + datetime = pb.SchemaType.SchemaColumn.DATETIME + duration = pb.SchemaType.SchemaColumn.DURATION + + def to_idl(self) -> pb.SchemaType.SchemaColumn.SchemaColumnType: + return self.value + + type: SchemaColumnType + """The column type. This allows a limited set of types currently.""" + + def to_idl(self) -> pb.SchemaType.SchemaColumn: + return pb.SchemaType.SchemaColumn(name=self.name, type=self.type.to_idl()) + + columns: Iterable[SchemaColumn] + """A list of ordered columns this schema comprises of.""" + + def to_idl(self) -> pb.SchemaType: + return pb.SchemaType(columns=to_idl_many(self.columns)) + + +@dataclass +class StructuredDatasetType: + @dataclass + class DatasetColumn: + name: str + """A unique name within the schema type for the column.""" + + literal_type: "LiteralType" + """The column type.""" + + def to_idl(self) -> pb.StructuredDatasetType.DatasetColumn: + return pb.StructuredDatasetType.DatasetColumn( + name=self.name, literal_type=self.literal_type.to_idl() + ) + + columns: Iterable[DatasetColumn] + """A list of ordered columns this schema comprises of.""" + + format: str + """ + This is the storage format, the format of the bits at rest + parquet, feather, csv, etc. + For two types to be compatible, the format will need to be an exact match. + """ + + external_schema_type: Optional[str] = None + """ + This is a string representing the type that the bytes in external_schema_bytes are formatted in. + This is an optional field that will not be used for type checking. + """ + + external_schema_bytes: Optional[bytes] = None + """ + The serialized bytes of a third-party schema library like Arrow. + This is an optional field that will not be used for type checking. 
+ """ + + def to_idl(self) -> pb.StructuredDatasetType: + return pb.StructuredDatasetType( + columns=to_idl_many(self.columns), + format=self.format, + external_schema_type=self.external_schema_type, + external_schema_bytes=self.external_schema_bytes, + ) + + +@dataclass +class BlobType: + """Defines type behavior for blob objects""" + + class BlobDimensionality(int, Enum): + """Define a set of simple types.""" + + single = pb.BlobType.SINGLE + multipart = pb.BlobType.MULTIPART + + def to_idl(self) -> pb.BlobType.BlobDimensionality: + return self.value + + dimensionality: BlobDimensionality + + format: str = "" + """ + Format can be a free form string understood by SDK/UI etc like + csv, parquet etc + """ + + def to_idl(self) -> pb.BlobType: + return pb.BlobType( + format=self.format, dimensionality=self.dimensionality.to_idl() + ) + + +@dataclass +class EnumType: + """ + Enables declaring enum types, with predefined string values + For len(values) > 0, the first value in the ordered list is regarded as the default value. If you wish + To provide no defaults, make the first value as undefined. + """ + + values: Iterable[str] + """Predefined set of enum values.""" + + def to_idl(self) -> pb.EnumType: + return pb.EnumType(values=self.values) + + +@dataclass +class UnionType: + """ + Defines a tagged union type, also known as a variant (and formally as the sum type). + + A sum type S is defined by a sequence of types (A, B, C, ...), each tagged by a string tag + A value of type S is constructed from a value of any of the variant types. The specific choice of type is recorded by + storing the varaint's tag with the literal value and can be examined in runtime. + + Type S is typically written as + S := Apple A | Banana B | Cantaloupe C | ... + + Notably, a nullable (optional) type is a sum type between some type X and the singleton type representing a null-value: + Optional X := X | Null + + See also: https://en.wikipedia.org/wiki/Tagged_union + """ + + variants: "Iterable[LiteralType]" + """Predefined set of variants in union.""" + + def to_idl(self) -> pb.UnionType: + return pb.UnionType(variants=to_idl_many(self.variants)) + + +@dataclass +class RecordFieldType: + key: str + type: "LiteralType" + + def to_idl(self) -> pb.RecordFieldType: + return pb.RecordFieldType(key=self.key, type=self.type.to_idl()) + + +@dataclass +class RecordType: + fields: Iterable[RecordFieldType] + + def to_idl(self) -> pb.RecordType: + return pb.RecordType(fields=to_idl_many(self.fields)) + + +@dataclass +class TypeStructure: + """ + Hints to improve type matching + e.g. allows distinguishing output from custom type transformers + even if the underlying IDL serialization matches. + """ + + tag: str + """Must exactly match for types to be castable""" + + def to_idl(self) -> pb.TypeStructure: + return pb.TypeStructure(tag=self.tag) + + +@dataclass +class TypeAnnotation: + """TypeAnnotation encapsulates registration time information about a type. This can be used for various control-plane operations. 
TypeAnnotation will not be available at runtime when a task runs.""" + + annotations: pb_struct.Struct + """A arbitrary JSON payload to describe a type.""" + + def to_idl(self) -> pb.TypeAnnotation: + return pb.TypeAnnotation(annotations=self.annotations) + + +@dataclass +class LiteralType: + """Defines a strong type to allow type checking between interfaces.""" + + type: "Union[LiteralTypeSimple, LiteralTypeSchema, LiteralTypeCollection, LiteralTypeMap, LiteralTypeBlob, LiteralTypeEnum, LiteralTypeStructuredDataset, LiteralTypeUnion, LiteralTypeRecord]" + + metadata: Optional[pb_struct.Struct] = None + """ + This field contains type metadata that is descriptive of the type, but is NOT considered in type-checking. This might be used by + consumers to identify special behavior or display extended information for the type. + + maximsmol: note: old-style dataclass serialization used metadata when type-checking + though the original comment really refers to how propeller treats the type + and iirc propeller always ignores .metadata + """ + + annotation: Optional[TypeAnnotation] = None + """ + This field contains arbitrary data that might have special semantic + meaning for the client but does not effect internal flyte behavior. + """ + + structure: Optional[TypeStructure] = None + """Hints to improve type matching.""" + + def to_idl(self) -> pb.LiteralType: + return merged_pb( + pb.LiteralType( + metadata=self.metadata, + annotation=try_to_idl(self.annotation), + structure=try_to_idl(self.structure), + ), + self.type, + ) + + +@dataclass +class LiteralTypeSimple: + """A simple type that can be compared one-to-one with another.""" + + simple: SimpleType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType(simple=self.simple.to_idl()) + + +@dataclass +class LiteralTypeSchema: + """A complex type that requires matching of inner fields.""" + + schema: SchemaType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType(schema=self.schema.to_idl()) + + +@dataclass +class LiteralTypeCollection: + """Defines the type of the value of a collection. Only homogeneous collections are allowed.""" + + collection_type: LiteralType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType(collection_type=self.collection_type.to_idl()) + + +@dataclass +class LiteralTypeMap: + """Defines the type of the value of a map type. 
The type of the key is always a string.""" + + map_value_type: LiteralType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType(map_value_type=self.map_value_type.to_idl()) + + +@dataclass +class LiteralTypeBlob: + """A blob might have specialized implementation details depending on associated metadata.""" + + blob: BlobType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType(blob=self.blob.to_idl()) + + +@dataclass +class LiteralTypeEnum: + """Defines an enum with pre-defined string values.""" + + enum_type: EnumType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType(enum_type=self.enum_type.to_idl()) + + +@dataclass +class LiteralTypeStructuredDataset: + """Generalized schema support""" + + structured_dataset_type: StructuredDatasetType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType( + structured_dataset_type=self.structured_dataset_type.to_idl() + ) + + +@dataclass +class LiteralTypeUnion: + """Defines an union type with pre-defined LiteralTypes.""" + + union_type: UnionType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType(union_type=self.union_type.to_idl()) + + +@dataclass +class LiteralTypeRecord: + record_type: RecordType + + def to_idl(self) -> pb.LiteralType: + return pb.LiteralType(record_type=self.record_type.to_idl()) + + +@dataclass +class OutputReference: + """ + A reference to an output produced by a node. The type can be retrieved -and validated- from + the underlying interface of the node. + """ + + node_id: str + """Node id must exist at the graph layer.""" + + var: str + """Variable name must refer to an output variable for the node.""" + + def to_idl(self) -> pb.OutputReference: + return pb.OutputReference(node_id=self.node_id, var=self.var) + + +@dataclass +class Error: + """Represents an error thrown from a node.""" + + failed_node_id: str + """The node id that threw the error.""" + message: str + """Error message thrown.""" + + def to_idl(self) -> pb.Error: + return pb.Error(failed_node_id=self.failed_node_id, message=self.message) diff --git a/latch/idl/core/workflow.py b/latch/idl/core/workflow.py new file mode 100644 index 00000000..a556a414 --- /dev/null +++ b/latch/idl/core/workflow.py @@ -0,0 +1,401 @@ +import typing +from collections.abc import Iterable +from dataclasses import dataclass, field +from datetime import timedelta +from enum import Enum +from typing import Optional + +import flyteidl.core.workflow_pb2 as pb + +from ..utils import dur_from_td, to_idl_many, try_to_idl +from .condition import BooleanExpression +from .execution import QualityOfService, QualityOfServiceDesignationTier +from .identifier import Identifier +from .interface import TypedInterface +from .literals import Binding, RetryStrategy +from .tasks import Resources +from .types import Error + + +@dataclass +class IfBlock: + """Defines a condition and the execution unit that should be executed if the condition is satisfied.""" + + condition: BooleanExpression + then_node: "Node" + + def to_idl(self) -> pb.IfBlock: + return pb.IfBlock( + condition=self.condition.to_idl(), then_node=self.then_node.to_idl() + ) + + +@dataclass +class IfElseBlock: + """ + Defines a series of if/else blocks. The first branch whose condition evaluates to true is the one to execute. + If no conditions were satisfied, the else_node or the error will execute. + """ + + case: IfBlock + """+required. First condition to evaluate.""" + + other: Iterable[IfBlock] + """+optional. 
Additional branches to evaluate.""" + + default: "typing.Union[IfElseBlockElseNode, IfElseBlockError]" + """+required.""" + + def to_idl(self) -> pb.IfElseBlock: + res = pb.IfElseBlock( + case=self.case.to_idl(), other=(x.to_idl() for x in self.other) + ) + res.MergeFrom(self.default.to_idl()) + return res + + +@dataclass +class IfElseBlockElseNode: + """Execute a node in case none of the branches were taken.""" + + else_node: "Node" + + def to_idl(self) -> pb.IfElseBlock: + return pb.IfElseBlock(else_node=self.else_node.to_idl()) + + +@dataclass +class IfElseBlockError: + """Throw an error in case none of the branches were taken.""" + + error: Error + + def to_idl(self) -> pb.IfElseBlock: + return pb.IfElseBlock(error=self.error.to_idl()) + + +@dataclass +class BranchNode: + """ + BranchNode is a special node that alter the flow of the workflow graph. It allows the control flow to branch at + runtime based on a series of conditions that get evaluated on various parameters (e.g. inputs, primitives). + """ + + if_else: IfElseBlock + """+required""" + + def to_idl(self) -> pb.BranchNode: + return pb.BranchNode(if_else=self.if_else.to_idl()) + + +@dataclass +class TaskNode: + """Refers to the task that the Node is to execute.""" + + reference: "TaskNodeReferenceId" # oneof with one element + + overrides: "TaskNodeOverrides" + """Optional overrides applied at task execution time.""" + + def to_idl(self) -> pb.TaskNode: + res = pb.TaskNode(overrides=self.overrides.to_idl()) + res.MergeFrom(self.reference.to_idl()) + return res + + +@dataclass +class TaskNodeReferenceId: + """Use a globally unique identifier for the task.""" + + reference_id: Identifier + + def to_idl(self) -> pb.TaskNode: + return pb.TaskNode(reference_id=self.reference_id.to_idl()) + + +@dataclass +class WorkflowNode: + """Refers to a the workflow the node is to execute.""" + + reference: "typing.Union[WorkflowNodeLaunchplanRef, WorkflowNodeSubWorkflowRef]" + + def to_idl(self) -> pb.WorkflowNode: + return self.reference.to_idl() + + +@dataclass +class WorkflowNodeLaunchplanRef: + """Use a launch plan with a globally unique identifier.""" + + launchplan_ref: Identifier + + def to_idl(self) -> pb.WorkflowNode: + return pb.WorkflowNode(launchplan_ref=self.launchplan_ref.to_idl()) + + +@dataclass +class WorkflowNodeSubWorkflowRef: + """Reference a subworkflow, that should be defined with the compiler context""" + + sub_workflow_ref: Identifier + + def to_idl(self) -> pb.WorkflowNode: + return pb.WorkflowNode(sub_workflow_ref=self.sub_workflow_ref.to_idl()) + + +@dataclass +class NodeMetadata: + """Defines extra information about the Node.""" + + name: str + """A friendly name for the Node""" + + timeout: timedelta + """The overall timeout of a task.""" + + retries: RetryStrategy + """Number of retries per task.""" + + interruptible_value: "NodeMetadataInterruptible" # oneof with one element + + def to_idl(self) -> pb.NodeMetadata: + res = pb.NodeMetadata( + name=self.name, + timeout=dur_from_td(self.timeout), + retries=self.retries.to_idl(), + ) + res.MergeFrom(self.interruptible_value.to_idl()) + return res + + +@dataclass +class NodeMetadataInterruptible: + """Identify whether node is interruptible""" + + interruptible: bool + + def to_idl(self) -> pb.NodeMetadata: + return pb.NodeMetadata(interruptible=self.interruptible) + + +@dataclass +class Alias: + """Links a variable to an alias.""" + + var: str + """Must match one of the output variable names on a node.""" + + alias: str + """A workflow-level unique alias that 
downstream nodes can refer to in their input."""
+
+ def to_idl(self) -> pb.Alias:
+ return pb.Alias(var=self.var, alias=self.alias)
+
+
+@dataclass
+class Node:
+ """
+ A Workflow graph Node. One unit of execution in the graph. Each node can be linked to a Task, a Workflow or a branch
+ node.
+ """
+
+ id: str
+ """
+ A workflow-level unique identifier that identifies this node in the workflow. "inputs" and "outputs" are reserved
+ node ids that cannot be used by other nodes.
+ """
+
+ metadata: NodeMetadata
+ """Extra metadata about the node."""
+
+ inputs: Iterable[Binding]
+ """
+ Specifies how to bind the underlying interface's inputs. All required inputs specified in the underlying interface
+ must be fulfilled.
+ """
+
+ target: "typing.Union[NodeTargetTask, NodeTargetWorkflow, NodeTargetBranch]"
+ """Information about the target to execute in this node."""
+
+ upstream_node_ids: Iterable[str] = field(default_factory=list)
+ """
+ +optional Specifies execution dependency for this node ensuring it will only get scheduled to run after all its
+ upstream nodes have completed. This node will have an implicit dependency on any node that appears in the inputs
+ field.
+ """
+
+ output_aliases: Iterable[Alias] = field(default_factory=list)
+ """
+ +optional. A node can define aliases for a subset of its outputs. This is particularly useful if different nodes
+ need to conform to the same interface (e.g. all branches in a branch node). Downstream nodes must refer to this
+ node's outputs using the alias if one is specified.
+ """
+
+ def to_idl(self) -> pb.Node:
+ res = pb.Node(
+ id=self.id,
+ metadata=self.metadata.to_idl(),
+ inputs=(x.to_idl() for x in self.inputs),
+ upstream_node_ids=self.upstream_node_ids,
+ output_aliases=(x.to_idl() for x in self.output_aliases),
+ )
+
+ res.MergeFrom(self.target.to_idl())
+
+ return res
+
+
+@dataclass
+class NodeTargetTask:
+ """Information about the Task to execute in this node."""
+
+ task_node: TaskNode
+
+ def to_idl(self) -> pb.Node:
+ return pb.Node(task_node=self.task_node.to_idl())
+
+
+@dataclass
+class NodeTargetWorkflow:
+ """Information about the Workflow to execute in this node."""
+
+ workflow_node: WorkflowNode
+
+ def to_idl(self) -> pb.Node:
+ return pb.Node(workflow_node=self.workflow_node.to_idl())
+
+
+@dataclass
+class NodeTargetBranch:
+ """Information about the Branch to execute in this node."""
+
+ branch_node: BranchNode
+
+ def to_idl(self) -> pb.Node:
+ return pb.Node(branch_node=self.branch_node.to_idl())
+
+
+@dataclass
+class WorkflowMetadata:
+ """
+ This is workflow layer metadata. These settings are only applicable to the workflow as a whole, and do not
+ percolate down to child entities (like tasks) launched by the workflow.
+ """
+
+ quality_of_service: QualityOfService = field(default_factory=QualityOfService)
+ """Indicates the runtime priority of workflow executions."""
+
+ class OnFailurePolicy(int, Enum):
+ """Failure Handling Strategy"""
+
+ fail_immediately = pb.WorkflowMetadata.FAIL_IMMEDIATELY
+ """
+ FAIL_IMMEDIATELY instructs the system to fail as soon as a node fails in the workflow. It will automatically
+ abort all currently running nodes and clean up resources before finally marking the workflow execution as
+ failed.
+ """
+
+ fail_after_executable_nodes_complete = (
+ pb.WorkflowMetadata.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
+ )
+ """
+ FAIL_AFTER_EXECUTABLE_NODES_COMPLETE instructs the system to make as much progress as it can. 
The system will + not alter the dependencies of the execution graph so any node that depend on the failed node will not be run. + Other nodes that will be executed to completion before cleaning up resources and marking the workflow + execution as failed. + """ + + def to_idl(self) -> pb.WorkflowMetadata.OnFailurePolicy: + return self.value + + on_failure: OnFailurePolicy = OnFailurePolicy.fail_immediately + """Defines how the system should behave when a failure is detected in the workflow execution.""" + + def to_idl(self) -> pb.WorkflowMetadata: + return pb.WorkflowMetadata( + quality_of_service=self.quality_of_service.to_idl(), + on_failure=self.on_failure.to_idl(), + ) + + +@dataclass +class WorkflowMetadataDefaults: + """ + The difference between these settings and the WorkflowMetadata ones is that these are meant to be passed down to + a workflow's underlying entities (like tasks). For instance, 'interruptible' has no meaning at the workflow layer, it + is only relevant when a task executes. The settings here are the defaults that are passed to all nodes + unless explicitly overridden at the node layer. + If you are adding a setting that applies to both the Workflow itself, and everything underneath it, it should be + added to both this object and the WorkflowMetadata object above. + """ + + interruptible: bool = False + """Whether child nodes of the workflow are interruptible.""" + + def to_idl(self) -> pb.WorkflowMetadataDefaults: + return pb.WorkflowMetadataDefaults(interruptible=self.interruptible) + + +@dataclass +class WorkflowTemplate: + """ + Flyte Workflow Structure that encapsulates task, branch and subworkflow nodes to form a statically analyzable, + directed acyclic graph. + """ + + id: Identifier + """A globally unique identifier for the workflow.""" + + interface: TypedInterface + """Defines a strongly typed interface for the Workflow. This can include some optional parameters.""" + + nodes: Iterable[Node] + """A list of nodes. In addition, "globals" is a special reserved node id that can be used to consume workflow inputs.""" + + outputs: Iterable[Binding] + """ + A list of output bindings that specify how to construct workflow outputs. Bindings can pull node outputs or + specify literals. All workflow outputs specified in the interface field must be bound in order for the workflow + to be validated. A workflow has an implicit dependency on all of its nodes to execute successfully in order to + bind final outputs. + Most of these outputs will be Binding's with a BindingData of type OutputReference. That is, your workflow can + just have an output of some constant (`Output(5)`), but usually, the workflow will be pulling + outputs from the output of a task. + """ + + metadata_defaults: WorkflowMetadataDefaults = field( + default_factory=WorkflowMetadataDefaults + ) + """workflow defaults""" + + metadata: WorkflowMetadata = field(default_factory=WorkflowMetadata) + """Extra metadata about the workflow.""" + + failure_node: Optional[Node] = None + """ + +optional A catch-all node. This node is executed whenever the execution engine determines the workflow has failed. + The interface of this node must match the Workflow interface with an additional input named "error" of type + pb.lyft.flyte.core.Error. 
+ """ + + def to_idl(self) -> pb.WorkflowTemplate: + return pb.WorkflowTemplate( + id=self.id.to_idl(), + metadata=self.metadata.to_idl(), + interface=self.interface.to_idl(), + nodes=to_idl_many(self.nodes), + outputs=to_idl_many(self.outputs), + failure_node=try_to_idl(self.failure_node), + metadata_defaults=self.metadata_defaults.to_idl(), + ) + + +@dataclass +class TaskNodeOverrides: + """Optional task node overrides that will be applied at task execution time.""" + + resources: Resources + """A customizable interface to convey resources requested for a task container. """ + + def to_idl(self) -> pb.TaskNodeOverrides: + return pb.TaskNodeOverrides(resources=self.resources.to_idl()) diff --git a/latch/idl/utils.py b/latch/idl/utils.py new file mode 100644 index 00000000..92989af7 --- /dev/null +++ b/latch/idl/utils.py @@ -0,0 +1,53 @@ +from collections.abc import Iterable, Mapping +from datetime import datetime, timedelta +from typing import Optional, Protocol, TypeVar + +from google.protobuf.duration_pb2 import Duration +from google.protobuf.message import Message +from google.protobuf.timestamp_pb2 import Timestamp + +K = TypeVar("K") +T = TypeVar("T", bound=Message) +R = TypeVar("R", covariant=True) + + +class HasToIdl(Protocol[R]): + def to_idl(self) -> R: + ... + + +def try_to_idl(x: Optional[HasToIdl[R]]) -> Optional[R]: + if x is None: + return + + return x.to_idl() + + +def dur_from_td(x: timedelta) -> Duration: + res = Duration() + res.FromTimedelta(x) + return res + + +def timestamp_from_datetime(x: datetime) -> Timestamp: + res = Timestamp() + res.FromDatetime(x) + return res + + +def to_idl_many(xs: Iterable[HasToIdl[R]]) -> Iterable[R]: + return (x.to_idl() for x in xs) + + +def to_idl_mapping(xs: Mapping[K, HasToIdl[R]]) -> Mapping[K, R]: + return {k: v.to_idl() for k, v in xs.items()} + + +def merged_pb(x: T, *mixins: Optional[HasToIdl[T]]) -> T: + for m in mixins: + if m is None: + continue + + x.MergeFrom(m.to_idl()) + + return x
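Usage sketch (illustrative only): each dataclass above mirrors one flyteidl message, `to_idl()` returns the corresponding `*_pb2` object, and oneof-style wrappers such as `LiteralTypeSimple` or `LiteralTypeMap` are folded into the parent message via `merged_pb`. A minimal, hypothetical example, assuming the module path `latch.idl.core.types` introduced by this diff:

    from latch.idl.core.types import (
        LiteralType,
        LiteralTypeMap,
        LiteralTypeSimple,
        SimpleType,
    )

    # A map-of-integers literal type (map keys are always strings).
    int_type = LiteralType(type=LiteralTypeSimple(simple=SimpleType.integer))
    map_of_ints = LiteralType(type=LiteralTypeMap(map_value_type=int_type))

    # Yields a flyteidl.core.types_pb2.LiteralType with the map_value_type
    # oneof branch set (merged_pb merges the variant into the base message).
    msg = map_of_ints.to_idl()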