Skip to content

Commit

Permalink
Serialization to produce registerable entities (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
wild-endeavor authored May 4, 2020
1 parent 07f8e80 commit fe5b620
Show file tree
Hide file tree
Showing 12 changed files with 344 additions and 53 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

__version__ = '0.7.0'
__version__ = '0.7.1b0'
24 changes: 19 additions & 5 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import absolute_import

import logging as _logging

import click

from flytekit.clis.sdk_in_container.constants import CTX_PROJECT, CTX_DOMAIN, CTX_TEST, CTX_PACKAGES, CTX_VERSION
from flytekit.common import utils as _utils
from flytekit.common.core import identifier as _identifier
from flytekit.common.tasks import task as _task
from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag, \
IMAGE as _IMAGE
Expand All @@ -13,21 +16,31 @@
def register_all(project, domain, pkgs, test, version):
if test:
click.echo('Test switch enabled, not doing anything...')

click.echo('Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
project, domain, pkgs, version))

# m = module (i.e. python file)
# k = value of dir(m), type str
# o = object (e.g. SdkWorkflow)
loaded_entities = []
for m, k, o in iterate_registerable_entities_in_order(pkgs):
name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)

_logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
o._id = _identifier.Identifier(
o.resource_type,
project,
domain,
name,
version
)
loaded_entities.append(o)

for o in loaded_entities:
if test:
click.echo("Would register {:20} {}".format("{}:".format(o.entity_type_text), name))
click.echo("Would register {:20} {}".format("{}:".format(o.entity_type_text), o.id.name))
else:
click.echo("Registering {:20} {}".format("{}:".format(o.entity_type_text), name))
o.register(project, domain, name, version)
click.echo("Registering {:20} {}".format("{}:".format(o.entity_type_text), o.id.name))
o.register(project, domain, o.id.name, version)


def register_tasks_only(project, domain, pkgs, test, version):
Expand All @@ -47,6 +60,7 @@ def register_tasks_only(project, domain, pkgs, test, version):
click.echo("Registering task {:20} {}".format("{}:".format(t.entity_type_text), name))
t.register(project, domain, name, version)


@click.group('register')
# --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead
@click.option('--pkgs', multiple=True, help="DEPRECATED. This arg can only be used before the 'register' keyword")
Expand Down
206 changes: 164 additions & 42 deletions flytekit/clis/sdk_in_container/serialize.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,132 @@
from __future__ import absolute_import
from __future__ import print_function

import logging as _logging
import math as _math
import os as _os

import click

from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES, CTX_PROJECT, CTX_DOMAIN, CTX_VERSION
from flytekit.common import workflow as _workflow, utils as _utils
from flytekit.common import utils as _utils
from flytekit.common.core import identifier as _identifier
from flytekit.common.exceptions.scopes import system_entry_point
from flytekit.common.tasks import task as _sdk_task
from flytekit.common.utils import write_proto_to_file as _write_proto_to_file
from flytekit.configuration import TemporaryConfiguration
from flytekit.configuration.internal import CONFIGURATION_PATH
from flytekit.configuration.internal import IMAGE as _IMAGE
from flytekit.models.workflow_closure import WorkflowClosure as _WorkflowClosure
from flytekit.configuration import internal as _internal_configuration
from flytekit.tools.module_loader import iterate_registerable_entities_in_order


@system_entry_point
def serialize_tasks(pkgs):
# Serialize all tasks
for m, k, t in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}):
fname = '{}.pb'.format(_utils.fqdn(m.__name__, k, entity_type=t.resource_type))
click.echo('Writing task {} to {}'.format(t.id, fname))
pb = t.to_flyte_idl()
_write_proto_to_file(pb, fname)
def serialize_tasks_only(project, domain, pkgs, version, folder=None):
"""
:param Text project:
:param Text domain:
:param list[Text] pkgs:
:param Text version:
:param Text folder:
:return:
"""
# m = module (i.e. python file)
# k = value of dir(m), type str
# o = object (e.g. SdkWorkflow)
loaded_entities = []
for m, k, o in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}):
name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
_logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
o._id = _identifier.Identifier(
o.resource_type,
project,
domain,
name,
version
)
loaded_entities.append(o)

zero_padded_length = _determine_text_chars(len(loaded_entities))
for i, entity in enumerate(loaded_entities):
serialized = entity.serialize()
fname_index = str(i).zfill(zero_padded_length)
fname = '{}_{}.pb'.format(fname_index, entity._id.name)
click.echo(' Writing {} to\n {}'.format(entity._id, fname))
_write_proto_to_file(serialized, fname)

identifier_fname = '{}_{}.identifier.pb'.format(fname_index, entity._id.name)
if folder:
identifier_fname = _os.path.join(folder, identifier_fname)
_write_proto_to_file(entity._id.to_flyte_idl(), identifier_fname)


@system_entry_point
def serialize_workflows(pkgs):
# Create map to look up tasks by their unique identifier. This is so we can compile them into the workflow closure.
tmap = {}
for _, _, t in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}):
tmap[t.id] = t
def serialize_all(project, domain, pkgs, version, folder=None):
"""
In order to register, we have to comply with Admin's endpoints. Those endpoints take the following object. These
flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
flyteidl.admin.workflow_pb2.WorkflowSpec
flyteidl.admin.task_pb2.TaskSpec
for m, k, w in iterate_registerable_entities_in_order(pkgs, include_entities={_workflow.SdkWorkflow}):
click.echo('Serializing {}'.format(_utils.fqdn(m.__name__, k, entity_type=w.resource_type)))
task_templates = []
for n in w.nodes:
if n.task_node is not None:
task_templates.append(tmap[n.task_node.reference_id])
However, if we were to merely call .to_flyte_idl() on all the discovered entities, what we would get are:
flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
flyteidl.core.workflow_pb2.WorkflowTemplate
flyteidl.core.tasks_pb2.TaskTemplate
For Workflows and Tasks therefore, there is special logic in the serialize function that translates these objects.
:param Text project:
:param Text domain:
:param list[Text] pkgs:
:param Text version:
:param Text folder:
:return:
"""

wc = _WorkflowClosure(workflow=w, tasks=task_templates)
wc_pb = wc.to_flyte_idl()
# m = module (i.e. python file)
# k = value of dir(m), type str
# o = object (e.g. SdkWorkflow)
loaded_entities = []
for m, k, o in iterate_registerable_entities_in_order(pkgs):
name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
_logging.debug("Found module {}\n K: {} Instantiated in {}".format(m, k, o._instantiated_in))
o._id = _identifier.Identifier(
o.resource_type,
project,
domain,
name,
version
)
loaded_entities.append(o)

zero_padded_length = _determine_text_chars(len(loaded_entities))
for i, entity in enumerate(loaded_entities):
serialized = entity.serialize()
fname_index = str(i).zfill(zero_padded_length)
fname = '{}_{}.pb'.format(fname_index, entity._id.name)
click.echo(' Writing {} to\n {}'.format(entity._id, fname))
_write_proto_to_file(serialized, fname)

# Not everything serialized will necessarily have an identifier field in it, even though some do (like the
# TaskTemplate). To be more rigorous, we write an explicit identifier file that reflects the choices (like
# project/domain, etc.) made for this serialize call. We should not allow users to specify a different project
# for instance come registration time, to avoid mismatches between potential internal ids like the TaskTemplate
# and the registered entity.
identifier_fname = '{}_{}.identifier.pb'.format(fname_index, entity._id.name)
if folder:
identifier_fname = _os.path.join(folder, identifier_fname)
_write_proto_to_file(entity._id.to_flyte_idl(), identifier_fname)


def _determine_text_chars(length):
"""
This function is used to help prefix files. If there are only 10 entries, then we just need one digit (0-9) to be
the prefix. If there are 11, then we'll need two (00-10).
fname = '{}.pb'.format(_utils.fqdn(m.__name__, k, entity_type=w.resource_type))
click.echo(' Writing workflow closure {}'.format(fname))
_write_proto_to_file(wc_pb, fname)
:param int length:
:rtype: int
"""
return _math.ceil(_math.log(length, 10))


@click.group('serialize')
Expand All @@ -57,37 +139,77 @@ def serialize(ctx):
object contains the WorkflowTemplate, along with the relevant tasks for that workflow. In lieu of Admin,
this serialization step will set the URN of the tasks to the fully qualified name of the task function.
"""
click.echo('Serializing Flyte elements with image {}'.format(_IMAGE.get()))
click.echo('Serializing Flyte elements with image {}'.format(_internal_configuration.IMAGE.get()))


@click.command('tasks')
@click.option('-v', '--version', type=str, help='Version to serialize tasks with. This is normally parsed from the'
'image, but you can override here.')
@click.option('-f', '--folder', type=click.Path(exists=True))
@click.pass_context
def tasks(ctx):
def tasks(ctx, version=None, folder=None):
project = ctx.obj[CTX_PROJECT]
domain = ctx.obj[CTX_DOMAIN]
pkgs = ctx.obj[CTX_PACKAGES]

if folder:
click.echo(f"Writing output to {folder}")

version = version or ctx.obj[CTX_VERSION] or _internal_configuration.look_up_version_from_image_tag(
_internal_configuration.IMAGE.get())

internal_settings = {
'project': ctx.obj[CTX_PROJECT],
'domain': ctx.obj[CTX_DOMAIN],
'version': ctx.obj[CTX_VERSION]
'project': project,
'domain': domain,
'version': version,
}
# Populate internal settings for project/domain/version from the environment so that the file names are resolved
# with the correct strings. The file itself doesn't need to change though.
with TemporaryConfiguration(CONFIGURATION_PATH.get(), internal_settings):
serialize_tasks(pkgs)
# with the correct strings. The file itself doesn't need to change though.
with TemporaryConfiguration(_internal_configuration.CONFIGURATION_PATH.get(), internal_settings):
_logging.debug("Serializing with settings\n"
"\n Project: {}"
"\n Domain: {}"
"\n Version: {}"
"\n\nover the following packages {}".format(project, domain, version, pkgs)
)
serialize_tasks_only(project, domain, pkgs, version, folder)


@click.command('workflows')
@click.option('-v', '--version', type=str, help='Version to serialize tasks with. This is normally parsed from the'
'image, but you can override here.')
# For now let's just assume that the directory needs to exist. If you're docker run -v'ing, docker will create the
# directory for you so it shouldn't be a problem.
@click.option('-f', '--folder', type=click.Path(exists=True))
@click.pass_context
def workflows(ctx):
def workflows(ctx, version=None, folder=None):
_logging.getLogger().setLevel(_logging.DEBUG)

if folder:
click.echo(f"Writing output to {folder}")

project = ctx.obj[CTX_PROJECT]
domain = ctx.obj[CTX_DOMAIN]
pkgs = ctx.obj[CTX_PACKAGES]

version = version or ctx.obj[CTX_VERSION] or _internal_configuration.look_up_version_from_image_tag(
_internal_configuration.IMAGE.get())

internal_settings = {
'project': ctx.obj[CTX_PROJECT],
'domain': ctx.obj[CTX_DOMAIN],
'version': ctx.obj[CTX_VERSION]
'project': project,
'domain': domain,
'version': version,
}
# Populate internal settings for project/domain/version from the environment so that the file names are resolved
# with the correct strings. The file itself doesn't need to change though.
with TemporaryConfiguration(CONFIGURATION_PATH.get(), internal_settings):
serialize_workflows(pkgs)
# with the correct strings. The file itself doesn't need to change though.
with TemporaryConfiguration(_internal_configuration.CONFIGURATION_PATH.get(), internal_settings):
_logging.debug("Serializing with settings\n"
"\n Project: {}"
"\n Domain: {}"
"\n Version: {}"
"\n\nover the following packages {}".format(project, domain, version, pkgs)
)
serialize_all(project, domain, pkgs, version, folder)


serialize.add_command(tasks)
Expand Down
8 changes: 8 additions & 0 deletions flytekit/common/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ def register(self, project, domain, name, version):
self._id = id_to_register
return _six.text_type(self.id)

@_exception_scopes.system_entry_point
def serialize(self):
"""
Unlike the SdkWorkflow serialize call, nothing special needs to be done here.
:rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
"""
return self.to_flyte_idl()

@classmethod
def from_flyte_idl(cls, _):
raise _user_exceptions.FlyteAssertion(
Expand Down
14 changes: 14 additions & 0 deletions flytekit/common/mixins/registerable.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def register(self, project, domain, name, version):
"""
pass

@_abc.abstractmethod
def serialize(self, project, domain, name, version):
"""
Registerable entities also are required to be serialized. This allows flytekit to separate serialization from
the network call to Admin (mostly at least, if a Launch Plan is fetched for instance as part of another
workflow, it will still hit Admin.
:param Text project: The project in which to serialize this task.
:param Text domain: The domain in which to serialize this task.
:param Text name: The name to give this task.
:param Text version: The version in which to serialize this task.
"""
pass

@_abc.abstractproperty
def resource_type(self):
"""
Expand Down
7 changes: 7 additions & 0 deletions flytekit/common/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,13 @@ def register(self, project, domain, name, version):
self._id = old_id
raise

@_exception_scopes.system_entry_point
def serialize(self):
"""
:rtype: flyteidl.admin.task_pb2.TaskSpec
"""
return _task_model.TaskSpec(self).to_flyte_idl()

@classmethod
@_exception_scopes.system_entry_point
def fetch(cls, project, domain, name, version):
Expand Down
15 changes: 15 additions & 0 deletions flytekit/common/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from flytekit.models.core import workflow as _workflow_models, identifier as _identifier_model
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.common import constants as _constants
from flytekit.models.admin import workflow as _admin_workflow_model


class Output(object):
Expand Down Expand Up @@ -286,6 +287,20 @@ def register(self, project, domain, name, version):
self._id = old_id
raise

@_exception_scopes.system_entry_point
def serialize(self):
"""
Serializing a workflow should produce an object similar to what the registration step produces, in preparation
for actual registration to Admin.
:rtype: flyteidl.admin.workflow_pb2.WorkflowSpec
"""
sub_workflows = self.get_sub_workflows()
return _admin_workflow_model.WorkflowSpec(
self,
sub_workflows,
).to_flyte_idl()

@_exception_scopes.system_entry_point
def validate(self):
pass
Expand Down
Loading

0 comments on commit fe5b620

Please sign in to comment.