Skip to content

Commit

Permalink
SageMaker on Flyte: TrainingJob for training with built-in algorithms…
Browse files Browse the repository at this point in the history
… and basic HPOJob support [Alpha] (#120)

* adding trainingjob model and sagemaker task

* adding models for sagemaker proto messages

* add new line at eof

* adding common trainingjob task

* redo flytekit changes to comply with new interface and proto definition

* Fix a logic bug in training job model. Adding SdkSimpleTrainingJobTask type

* Add a comment

* Add SdkSimpleHPOJobTask

* Remove the embedding of the underlying trainingjob's output from the hpojob's interface

* fix a typo

* add new line at eof

* adding custom training job sdk type

* add code for translating an enum in hpo_job model; fix hpo_job_task sdk task

* missing a colon

* add the missing input stopping_condition for training job tasks

* bump flyteidl version

* bump to a beta version

* fixing unit tests

* fixing unit tests

* replacing interface types

* change

* fixed training job unit test

* fix hpo job task interface and hide task type from users

* fix hpo job task interface

* fix hpo models

* fix serialization of the underlying trainingjob of a hpo job

* Expose training job as a parameter

* Working!

* replacing hyphens with underscores

* updated

* bug fix

* Sagemaker nb

* Sagemaker HPO

* remove .demo directory

* register and launch standalone trainingjob task

* Merge

* adding unit test for SdkSimpleHPOJobTask

* fixing unit tests

* preventing installing numpy==1.19.0 which introduces a breaking change for unit tests

* fix semver

* make changes corresponding to flyteidl changes (renaming hpo to hyperparameter tuning)

* bump beta version

* Delete config.yaml

* make changes to reflect changes in flyteidl

* make task name consistent

* add missing properties for hyperparameter models

* add missing type hints and remove unused imports

* remove unused sdk sagemaker dir

* remove unused test file

* revert numpy semver

* remove type hints for self because CI is using python 3.6.3 while __future__.annotations requires python 3.7

* complete docstrings for hpo job task

* fix unit test

* adding input_file_type (wip)

* add input file type support

* add docs

* reflecting the renamed type and field

* reflecting removal of libsvm content type

* reflecting removal of libsvm content type

* Give metric_definitions a None as the default value because built-in algorithm does not allow custom metrics

* nix a print statement

* nix custom training job for the current release

* rename SdkSimpleTrainingJobTask to SdkBuiltinAlgorithmTrainingJobTask

* revert setup.py dependency

Co-authored-by: Yee Hing Tong <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Haytham AbuelFutuh <[email protected]>
  • Loading branch information
4 people authored Jul 31, 2020
1 parent a32dd2d commit c2e8424
Show file tree
Hide file tree
Showing 13 changed files with 1,176 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ build/
dist
*.iml
.eggs
.demo
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import flytekit.plugins

__version__ = "0.10.12"
__version__ = '0.11.0'
2 changes: 2 additions & 0 deletions flytekit/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class SdkTaskType(object):
PYTORCH_TASK = "pytorch"
# Raw container task is just a name, it defaults to using the regular container task (like python etc), but sets the data_config in the container
RAW_CONTAINER_TASK = "raw-container"
SAGEMAKER_TRAINING_JOB_TASK = "sagemaker_training_job_task"
SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK = "sagemaker_hyperparameter_tuning_job_task"

GLOBAL_INPUT_NODE_ID = ''

Expand Down
93 changes: 93 additions & 0 deletions flytekit/common/tasks/sagemaker/hpo_job_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import absolute_import
import datetime as _datetime

from google.protobuf.json_format import MessageToDict

from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job
from flytekit import __version__
from flytekit.common.constants import SdkTaskType
from flytekit.common.tasks import task as _sdk_task
from flytekit.common import interface as _interface
from flytekit.common.tasks.sagemaker.training_job_task import SdkBuiltinAlgorithmTrainingJobTask
from flytekit.models import task as _task_models
from flytekit.models import interface as _interface_model
from flytekit.models.sagemaker import hpo_job as _hpo_job_model
from flytekit.models import literals as _literal_models
from flytekit.models import types as _types_models
from flytekit.models.core import types as _core_types
from flytekit.sdk import types as _sdk_types


class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask):
    """
    SDK task of type ``SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK`` built on top of an existing
    :class:`SdkBuiltinAlgorithmTrainingJobTask`.

    The task's inputs are the wrapped training job's inputs plus one extra
    ``hyperparameter_tuning_job_config`` input; its only output is a single-part ``model`` blob.
    """

    def __init__(
        self,
        max_number_of_training_jobs: int,
        max_parallel_training_jobs: int,
        training_job: SdkBuiltinAlgorithmTrainingJobTask,
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """
        :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
            hyperparameter tuning job
        :param max_parallel_training_jobs: The maximum number of training jobs that can be launched by this
            hyperparameter tuning job in parallel
        :param training_job: The reference to the training job definition
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Route the arguments through the hyperparameter tuning job model before anything else is built;
        # its IDL form is what ends up in the task's ``custom`` field below.
        tuning_job_model = _hpo_job_model.HyperparameterTuningJob(
            max_number_of_training_jobs=max_number_of_training_jobs,
            max_parallel_training_jobs=max_parallel_training_jobs,
            training_job=training_job.training_job_model,
        )
        tuning_job_idl = tuning_job_model.to_flyte_idl()

        # Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of
        # the underlying training job
        # TODO: Discuss whether this is a viable interface or contract
        zero_timeout = _datetime.timedelta(seconds=0)

        # Start from the tuning-specific config input, then layer the training job's own inputs on top.
        config_literal_type = _sdk_types.Types.Proto(
            _pb2_hpo_job.HyperparameterTuningJobConfig
        ).to_flyte_literal_type()
        task_inputs = {
            "hyperparameter_tuning_job_config": _interface_model.Variable(config_literal_type, ""),
        }
        task_inputs.update(training_job.interface.inputs)

        runtime_metadata = _task_models.RuntimeMetadata(
            type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            version=__version__,
            flavor='sagemaker'
        )
        task_metadata = _task_models.TaskMetadata(
            runtime=runtime_metadata,
            discoverable=cacheable,
            timeout=zero_timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        model_blob_type = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
        )
        task_outputs = {
            "model": _interface_model.Variable(
                type=_types_models.LiteralType(blob=model_blob_type),
                description=""
            )
        }
        task_interface = _interface.TypedInterface(inputs=task_inputs, outputs=task_outputs)

        super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
            metadata=task_metadata,
            interface=task_interface,
            custom=MessageToDict(tuning_job_idl),
        )
112 changes: 112 additions & 0 deletions flytekit/common/tasks/sagemaker/training_job_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import absolute_import

from typing import Dict, Callable
import datetime as _datetime

from flytekit import __version__
from flytekit.common.tasks import task as _sdk_task, sdk_runnable as _sdk_runnable
from flytekit.models import task as _task_models
from flytekit.models import interface as _interface_model
from flytekit.common import interface as _interface
from flytekit.models.sagemaker import training_job as _training_job_models
from flyteidl.plugins.sagemaker import training_job_pb2 as _training_job_pb2
from google.protobuf.json_format import MessageToDict
from flytekit.models import types as _idl_types
from flytekit.models.core import types as _core_types
from flytekit.models import literals as _literal_models
from flytekit.common.constants import SdkTaskType
from flytekit.common.exceptions import user as _user_exceptions


def _content_type_to_blob_format(content_type: int) -> str:
    """
    Translate a training-job input content type into the blob format string used in the task interface.

    :param content_type: One of the ``_training_job_models.InputContentType`` constants
    :return: ``"csv"`` for ``InputContentType.TEXT_CSV``
    :raises flytekit.common.exceptions.user.FlyteValueException: for any other content type
    """
    if content_type == _training_job_models.InputContentType.TEXT_CSV:
        return "csv"
    # FlyteValueException takes the offending value and the explanation as two separate arguments;
    # passing a single pre-formatted string would fail with a TypeError instead of the intended error.
    raise _user_exceptions.FlyteValueException(content_type, "Unsupported InputContentType")


class SdkBuiltinAlgorithmTrainingJobTask(_sdk_task.SdkTask):
    """
    SDK task of type ``SAGEMAKER_TRAINING_JOB_TASK`` describing a training job.

    The task takes ``static_hyperparameters`` (a struct) plus multipart ``train`` and ``validation`` blobs
    as inputs, and produces a single-part ``model`` blob as its output.
    """

    def __init__(
        self,
        training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
        algorithm_specification: _training_job_models.AlgorithmSpecification,
        retries: int = 0,
        cacheable: bool = False,
        cache_version: str = "",
    ):
        """
        :param training_job_resource_config: The options to configure the training job
        :param algorithm_specification: The options to configure the target algorithm of the training
        :param retries: Number of retries to attempt
        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
        :param cache_version: String describing the caching version for task discovery purposes
        """
        # Use the training job model as a measure of type checking; it is also kept on the instance so
        # that other tasks can reach it through the ``training_job_model`` property.
        self._training_job_model = _training_job_models.TrainingJob(
            algorithm_specification=algorithm_specification,
            training_job_resource_config=training_job_resource_config,
        )

        # Setting flyte-level timeout to 0, and let SageMaker take the StoppingCondition and terminate the
        # training job gracefully
        zero_timeout = _datetime.timedelta(seconds=0)

        runtime_metadata = _task_models.RuntimeMetadata(
            type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            version=__version__,
            flavor='sagemaker'
        )
        task_metadata = _task_models.TaskMetadata(
            runtime=runtime_metadata,
            discoverable=cacheable,
            timeout=zero_timeout,
            retries=_literal_models.RetryStrategy(retries=retries),
            interruptible=False,
            discovery_version=cache_version,
            deprecated_error_message="",
        )

        def _data_channel_variable():
            # Each data channel gets its own multipart blob variable whose format follows the
            # algorithm's input content type.
            return _interface_model.Variable(
                type=_idl_types.LiteralType(
                    blob=_core_types.BlobType(
                        format=_content_type_to_blob_format(algorithm_specification.input_content_type),
                        dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
                    ),
                ),
                description="",
            )

        task_inputs = {
            "static_hyperparameters": _interface_model.Variable(
                type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
                description="",
            ),
            "train": _data_channel_variable(),
            "validation": _data_channel_variable(),
        }

        model_blob_type = _core_types.BlobType(
            format="",
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
        )
        task_outputs = {
            "model": _interface_model.Variable(
                type=_idl_types.LiteralType(blob=model_blob_type),
                description=""
            )
        }
        task_interface = _interface.TypedInterface(inputs=task_inputs, outputs=task_outputs)

        super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
            type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
            metadata=task_metadata,
            interface=task_interface,
            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
        )

    @property
    def training_job_model(self) -> _training_job_models.TrainingJob:
        """
        :return: The training job model built from the constructor arguments
        """
        return self._training_job_model
Loading

0 comments on commit c2e8424

Please sign in to comment.