-
Notifications
You must be signed in to change notification settings - Fork 287
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SageMaker on Flyte: TrainingJob for training with built-in algorithms…
… and basic HPOJob support [Alpha] (#120) * adding trainingjob model and sagemaker task * adding models for sagemaker proto messages * add new line at eof * adding common trainingjob task * redo flytekit changes to comply with new interface and proto definition * Fix a logic bug in training job model. Adding SdkSimpleTrainingJobTask type * Add a comment * Add SdkSimpleHPOJobTask * Remove the embedding of the underlying trainingjob's output from the hpojob's interface * fix a typo * add new line at eof * adding custom training job sdk type * add code for tranlating an enum in hpo_job model; fix hpo_job_task sdk task * missing a colon * add the missing input stopping_condition for training job tasks * bump flyteidl version * bump to a beta version * fixing unit tests * fixing unit tests * replacing interface types * change * fixed training job unit test * fix hpo job task interface and hide task type from users * fix hpo job task interface * fix hpo models * fix serialization of the underlying trainingjob of a hpo job * Expose training job as a parameter * Working! * replacing hyphens with underscores * updated * bug fix * Sagemaker nb * Sagemaker HPO * remove .demo directory * register and launch standalone trainingjob task * Merge * adding unit test for SdkSimpleHPOJobTask * fixing unit tests * preventing installing numpy==1.19.0 which introduces a breaking change for unit tests * fix semver * make changes corresponding to flyteidl changes (renaming hpo to hyperparameter tuning) * bump beta version * Delete config.yaml * make changes to reflect changes in flyteidl * make task name consistent * add missing properties for hyperparameter models * add missing type hints and remove unused imports * remove unused sdk sagemaker dir * remove unused test file * revert numpy semver * remove type hints for self because CI is using python 3.6.3 while __future__.annotations requires python 3.7 * complete docstrings for hpo job task * fix unit test * adding input_file_type (wip) * add input file type support * add docs * reflecting the renamed type and field * reflecting remove of libsvm content type * reflecting remove of libsvm content type * Give metric_definitions a None as the default value because built-in algorithm does not allow custom metrics * nix a print statement * nix custom training job for the current release * rename SdkSimpleTrainingJobTask to SdkBuiltinAlgorithmTrainingJobTask * revert setup.py dependency Co-authored-by: Yee Hing Tong <[email protected]> Co-authored-by: Ketan Umare <[email protected]> Co-authored-by: Haytham AbuelFutuh <[email protected]>
- Loading branch information
1 parent
a32dd2d
commit c2e8424
Showing
13 changed files
with
1,176 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,4 @@ build/ | |
dist | ||
*.iml | ||
.eggs | ||
.demo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,4 @@ | |
|
||
import flytekit.plugins | ||
|
||
__version__ = "0.10.12" | ||
__version__ = '0.11.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import absolute_import | ||
import datetime as _datetime | ||
|
||
from google.protobuf.json_format import MessageToDict | ||
|
||
from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job | ||
from flytekit import __version__ | ||
from flytekit.common.constants import SdkTaskType | ||
from flytekit.common.tasks import task as _sdk_task | ||
from flytekit.common import interface as _interface | ||
from flytekit.common.tasks.sagemaker.training_job_task import SdkBuiltinAlgorithmTrainingJobTask | ||
from flytekit.models import task as _task_models | ||
from flytekit.models import interface as _interface_model | ||
from flytekit.models.sagemaker import hpo_job as _hpo_job_model | ||
from flytekit.models import literals as _literal_models | ||
from flytekit.models import types as _types_models | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.sdk import types as _sdk_types | ||
|
||
|
||
class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask): | ||
|
||
def __init__( | ||
self, | ||
max_number_of_training_jobs: int, | ||
max_parallel_training_jobs: int, | ||
training_job: SdkBuiltinAlgorithmTrainingJobTask, | ||
retries: int = 0, | ||
cacheable: bool = False, | ||
cache_version: str = "", | ||
): | ||
""" | ||
:param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this | ||
hyperparameter tuning job | ||
:param max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter | ||
tuning job in parallel | ||
:param training_job: The reference to the training job definition | ||
:param retries: Number of retries to attempt | ||
:param cacheable: The flag to set if the user wants the output of the task execution to be cached | ||
:param cache_version: String describing the caching version for task discovery purposes | ||
""" | ||
# Use the training job model as a measure of type checking | ||
hpo_job = _hpo_job_model.HyperparameterTuningJob( | ||
max_number_of_training_jobs=max_number_of_training_jobs, | ||
max_parallel_training_jobs=max_parallel_training_jobs, | ||
training_job=training_job.training_job_model, | ||
).to_flyte_idl() | ||
|
||
# Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of | ||
# the underlying training job | ||
# TODO: Discuss whether this is a viable interface or contract | ||
timeout = _datetime.timedelta(seconds=0) | ||
|
||
inputs = { | ||
"hyperparameter_tuning_job_config": _interface_model.Variable( | ||
_sdk_types.Types.Proto( | ||
_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), "" | ||
), | ||
} | ||
inputs.update(training_job.interface.inputs) | ||
|
||
super(SdkSimpleHyperparameterTuningJobTask, self).__init__( | ||
type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, | ||
metadata=_task_models.TaskMetadata( | ||
runtime=_task_models.RuntimeMetadata( | ||
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, | ||
version=__version__, | ||
flavor='sagemaker' | ||
), | ||
discoverable=cacheable, | ||
timeout=timeout, | ||
retries=_literal_models.RetryStrategy(retries=retries), | ||
interruptible=False, | ||
discovery_version=cache_version, | ||
deprecated_error_message="", | ||
), | ||
interface=_interface.TypedInterface( | ||
inputs=inputs, | ||
outputs={ | ||
"model": _interface_model.Variable( | ||
type=_types_models.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="", | ||
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE | ||
) | ||
), | ||
description="" | ||
) | ||
} | ||
), | ||
custom=MessageToDict(hpo_job), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from __future__ import absolute_import | ||
|
||
from typing import Dict, Callable | ||
import datetime as _datetime | ||
|
||
from flytekit import __version__ | ||
from flytekit.common.tasks import task as _sdk_task, sdk_runnable as _sdk_runnable | ||
from flytekit.models import task as _task_models | ||
from flytekit.models import interface as _interface_model | ||
from flytekit.common import interface as _interface | ||
from flytekit.models.sagemaker import training_job as _training_job_models | ||
from flyteidl.plugins.sagemaker import training_job_pb2 as _training_job_pb2 | ||
from google.protobuf.json_format import MessageToDict | ||
from flytekit.models import types as _idl_types | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.models import literals as _literal_models | ||
from flytekit.common.constants import SdkTaskType | ||
from flytekit.common.exceptions import user as _user_exceptions | ||
|
||
|
||
def _content_type_to_blob_format(content_type: _training_job_models) -> str: | ||
if content_type == _training_job_models.InputContentType.TEXT_CSV: | ||
return "csv" | ||
else: | ||
raise _user_exceptions.FlyteValueException("Unsupported InputContentType: {}".format(content_type)) | ||
|
||
|
||
class SdkBuiltinAlgorithmTrainingJobTask(_sdk_task.SdkTask): | ||
def __init__( | ||
self, | ||
training_job_resource_config: _training_job_models.TrainingJobResourceConfig, | ||
algorithm_specification: _training_job_models.AlgorithmSpecification, | ||
retries: int = 0, | ||
cacheable: bool = False, | ||
cache_version: str = "", | ||
): | ||
""" | ||
:param training_job_resource_config: The options to configure the training job | ||
:param algorithm_specification: The options to configure the target algorithm of the training | ||
:param retries: Number of retries to attempt | ||
:param cacheable: The flag to set if the user wants the output of the task execution to be cached | ||
:param cache_version: String describing the caching version for task discovery purposes | ||
""" | ||
# Use the training job model as a measure of type checking | ||
self._training_job_model = _training_job_models.TrainingJob( | ||
algorithm_specification=algorithm_specification, | ||
training_job_resource_config=training_job_resource_config, | ||
) | ||
|
||
# Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training | ||
# job gracefully | ||
timeout = _datetime.timedelta(seconds=0) | ||
|
||
super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__( | ||
type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK, | ||
metadata=_task_models.TaskMetadata( | ||
runtime=_task_models.RuntimeMetadata( | ||
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, | ||
version=__version__, | ||
flavor='sagemaker' | ||
), | ||
discoverable=cacheable, | ||
timeout=timeout, | ||
retries=_literal_models.RetryStrategy(retries=retries), | ||
interruptible=False, | ||
discovery_version=cache_version, | ||
deprecated_error_message="", | ||
), | ||
interface=_interface.TypedInterface( | ||
inputs={ | ||
"static_hyperparameters": _interface_model.Variable( | ||
type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), | ||
description="", | ||
), | ||
"train": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format=_content_type_to_blob_format(algorithm_specification.input_content_type), | ||
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART | ||
), | ||
), | ||
description="", | ||
), | ||
"validation": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format=_content_type_to_blob_format(algorithm_specification.input_content_type), | ||
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART | ||
), | ||
), | ||
description="", | ||
), | ||
}, | ||
outputs={ | ||
"model": _interface_model.Variable( | ||
type=_idl_types.LiteralType( | ||
blob=_core_types.BlobType( | ||
format="", | ||
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE | ||
) | ||
), | ||
description="" | ||
) | ||
} | ||
), | ||
custom=MessageToDict(self._training_job_model.to_flyte_idl()), | ||
) | ||
|
||
@property | ||
def training_job_model(self) -> _training_job_models.TrainingJob: | ||
return self._training_job_model |
Oops, something went wrong.