From c2e84245fa4248f2ad29cea93e6eb89d88122bd5 Mon Sep 17 00:00:00 2001
From: Chang-Hong Hsu
Date: Fri, 31 Jul 2020 11:06:27 -0700
Subject: [PATCH] SageMaker on Flyte: TrainingJob for training with built-in
 algorithms and basic HPOJob support [Alpha] (#120)

* adding trainingjob model and sagemaker task
* adding models for sagemaker proto messages
* add new line at eof
* adding common trainingjob task
* redo flytekit changes to comply with new interface and proto definition
* Fix a logic bug in training job model. Adding SdkSimpleTrainingJobTask type
* Add a comment
* Add SdkSimpleHPOJobTask
* Remove the embedding of the underlying trainingjob's output from the hpojob's interface
* fix a typo
* add new line at eof
* adding custom training job sdk type
* add code for translating an enum in hpo_job model; fix hpo_job_task sdk task
* missing a colon
* add the missing input stopping_condition for training job tasks
* bump flyteidl version
* bump to a beta version
* fixing unit tests
* fixing unit tests
* replacing interface types
* change
* fixed training job unit test
* fix hpo job task interface and hide task type from users
* fix hpo job task interface
* fix hpo models
* fix serialization of the underlying trainingjob of a hpo job
* Expose training job as a parameter
* Working!
* replacing hyphens with underscores
* updated
* bug fix
* Sagemaker nb
* Sagemaker HPO
* remove .demo directory
* register and launch standalone trainingjob task
* Merge
* adding unit test for SdkSimpleHPOJobTask
* fixing unit tests
* preventing installing numpy==1.19.0 which introduces a breaking change for unit tests
* fix semver
* make changes corresponding to flyteidl changes (renaming hpo to hyperparameter tuning)
* bump beta version
* Delete config.yaml
* make changes to reflect changes in flyteidl
* make task name consistent
* add missing properties for hyperparameter models
* add missing type hints and remove unused imports
* remove unused sdk sagemaker dir
* remove unused test file
* revert numpy semver
* remove type hints for self because CI is using python 3.6.3 while __future__.annotations requires python 3.7
* complete docstrings for hpo job task
* fix unit test
* adding input_file_type (wip)
* add input file type support
* add docs
* reflecting the renamed type and field
* reflecting remove of libsvm content type
* reflecting remove of libsvm content type
* Give metric_definitions a None as the default value because built-in algorithm does not allow custom metrics
* nix a print statement
* nix custom training job for the current release
* rename SdkSimpleTrainingJobTask to SdkBuiltinAlgorithmTrainingJobTask
* revert setup.py dependency

Co-authored-by: Yee Hing Tong
Co-authored-by: Ketan Umare
Co-authored-by: Haytham AbuelFutuh
---
 .gitignore                                    |   1 +
 flytekit/__init__.py                          |   2 +-
 flytekit/common/constants.py                  |   2 +
 .../common/tasks/sagemaker/hpo_job_task.py    |  93 ++++++
 .../tasks/sagemaker/training_job_task.py      | 112 +++++++
 flytekit/models/sagemaker/hpo_job.py          | 198 ++++++++++++
 flytekit/models/sagemaker/parameter_ranges.py | 211 +++++++++++++
 flytekit/models/sagemaker/training_job.py     | 293 ++++++++++++++++++
 flytekit/plugins/__init__.py                  |   2 +-
 flytekit/sdk/tasks.py                         |  75 +++++
 sample-notebooks/raw-container-shell.ipynb    |   6 +-
 sample-notebooks/raw-container.ipynb          |   2 +-
 .../unit/sdk/tasks/test_sagemaker_tasks.py    | 185 +++++++++++
 13 files changed, 1176 insertions(+), 6 deletions(-)
 create mode 100644 flytekit/common/tasks/sagemaker/hpo_job_task.py
 create mode 100644 flytekit/common/tasks/sagemaker/training_job_task.py
 create mode 100644 flytekit/models/sagemaker/hpo_job.py
 create mode 100644 flytekit/models/sagemaker/parameter_ranges.py
 create mode 100644 flytekit/models/sagemaker/training_job.py
 create mode 100644 tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py
diff --git a/.gitignore b/.gitignore
index 7df4cb78cc..64ecdb7e1d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,3 +18,4 @@ build/
 dist
 *.iml
 .eggs
+.demo
diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index e175f3b8e1..6757707c15 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -2,4 +2,4 @@
 
 import flytekit.plugins
 
-__version__ = "0.10.12"
+__version__ = '0.11.0'
diff --git a/flytekit/common/constants.py b/flytekit/common/constants.py
index 9b6ab77a35..8f3af75de3 100644
--- a/flytekit/common/constants.py
+++ b/flytekit/common/constants.py
@@ -24,6 +24,8 @@ class SdkTaskType(object):
     PYTORCH_TASK = "pytorch"
     # Raw container task is just a name, it defaults to using the regular container task (like python etc), but sets the data_config in the container
     RAW_CONTAINER_TASK = "raw-container"
+    SAGEMAKER_TRAINING_JOB_TASK = "sagemaker_training_job_task"
+    SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK = "sagemaker_hyperparameter_tuning_job_task"
 
 GLOBAL_INPUT_NODE_ID = ''
diff --git a/flytekit/common/tasks/sagemaker/hpo_job_task.py b/flytekit/common/tasks/sagemaker/hpo_job_task.py
new file mode 100644
index 0000000000..780655ca77
--- /dev/null
+++ b/flytekit/common/tasks/sagemaker/hpo_job_task.py
@@ -0,0 +1,93 @@
+from __future__ import absolute_import
+import datetime as _datetime
+
+from google.protobuf.json_format import MessageToDict
+
+from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job
+from flytekit import __version__
+from flytekit.common.constants import SdkTaskType
+from flytekit.common.tasks import task as _sdk_task
+from flytekit.common import interface as _interface
+from flytekit.common.tasks.sagemaker.training_job_task import SdkBuiltinAlgorithmTrainingJobTask
+from flytekit.models import task as _task_models
+from flytekit.models import interface as _interface_model
+from flytekit.models.sagemaker import hpo_job as _hpo_job_model
+from flytekit.models import literals as _literal_models
+from flytekit.models import types as _types_models
+from flytekit.models.core import types as _core_types
+from flytekit.sdk import types as _sdk_types
+
+
+class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask):
+
+    def __init__(
+            self,
+            max_number_of_training_jobs: int,
+            max_parallel_training_jobs: int,
+            training_job: SdkBuiltinAlgorithmTrainingJobTask,
+            retries: int = 0,
+            cacheable: bool = False,
+            cache_version: str = "",
+    ):
+        """
+
+        :param max_number_of_training_jobs: The maximum number of training jobs that can be launched by this
+            hyperparameter tuning job
+        :param max_parallel_training_jobs: The maximum number of training jobs that can be launched by this
+            hyperparameter tuning job in parallel
+        :param training_job: The reference to the training job definition
+        :param retries: Number of retries to attempt
+        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
+        :param cache_version: String describing the caching version for task discovery purposes
+        """
+        # Use the training job model as a measure of type checking
+        hpo_job = _hpo_job_model.HyperparameterTuningJob(
+            max_number_of_training_jobs=max_number_of_training_jobs,
+            max_parallel_training_jobs=max_parallel_training_jobs,
+            training_job=training_job.training_job_model,
+        ).to_flyte_idl()
+
+        # Setting the flyte-level timeout to 0, and letting SageMaker respect the StoppingCondition of
+        # the underlying training job
+        # TODO: Discuss whether this is a viable interface or contract
+        timeout = _datetime.timedelta(seconds=0)
+
+        inputs = {
+            "hyperparameter_tuning_job_config": _interface_model.Variable(
+                _sdk_types.Types.Proto(
+                    _pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), ""
+            ),
+        }
+        inputs.update(training_job.interface.inputs)
+
+        super(SdkSimpleHyperparameterTuningJobTask, self).__init__(
+            type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK,
+            metadata=_task_models.TaskMetadata(
+                runtime=_task_models.RuntimeMetadata(
+                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
+                    version=__version__,
+                    flavor='sagemaker'
+                ),
+                discoverable=cacheable,
+                timeout=timeout,
+                retries=_literal_models.RetryStrategy(retries=retries),
+                interruptible=False,
+                discovery_version=cache_version,
+                deprecated_error_message="",
+            ),
+            interface=_interface.TypedInterface(
+                inputs=inputs,
+                outputs={
+                    "model": _interface_model.Variable(
+                        type=_types_models.LiteralType(
+                            blob=_core_types.BlobType(
+                                format="",
+                                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
+                            )
+                        ),
+                        description=""
+                    )
+                }
+            ),
+            custom=MessageToDict(hpo_job),
+        )
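+
+# A minimal usage sketch (illustrative values only, mirroring this patch's unit tests;
+# "xgboost_training_task" is a hypothetical SdkBuiltinAlgorithmTrainingJobTask defined elsewhere):
+#
+#     xgboost_hpo_task = SdkSimpleHyperparameterTuningJobTask(
+#         training_job=xgboost_training_task,
+#         max_number_of_training_jobs=10,  # total trials SageMaker may launch
+#         max_parallel_training_jobs=5,    # trials allowed to run concurrently
+#         cacheable=True,
+#         cache_version='1',
+#     )
+#
+# The resulting task exposes the wrapped training job's inputs (train, validation,
+# static_hyperparameters) plus a "hyperparameter_tuning_job_config" input supplied at launch time.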
diff --git a/flytekit/common/tasks/sagemaker/training_job_task.py b/flytekit/common/tasks/sagemaker/training_job_task.py
new file mode 100644
index 0000000000..f7134673be
--- /dev/null
+++ b/flytekit/common/tasks/sagemaker/training_job_task.py
@@ -0,0 +1,112 @@
+from __future__ import absolute_import
+
+from typing import Dict, Callable
+import datetime as _datetime
+
+from flytekit import __version__
+from flytekit.common.tasks import task as _sdk_task, sdk_runnable as _sdk_runnable
+from flytekit.models import task as _task_models
+from flytekit.models import interface as _interface_model
+from flytekit.common import interface as _interface
+from flytekit.models.sagemaker import training_job as _training_job_models
+from flyteidl.plugins.sagemaker import training_job_pb2 as _training_job_pb2
+from google.protobuf.json_format import MessageToDict
+from flytekit.models import types as _idl_types
+from flytekit.models.core import types as _core_types
+from flytekit.models import literals as _literal_models
+from flytekit.common.constants import SdkTaskType
+from flytekit.common.exceptions import user as _user_exceptions
+
+
+def _content_type_to_blob_format(content_type: int) -> str:
+    # content_type is an enum value of _training_job_models.InputContentType
+    if content_type == _training_job_models.InputContentType.TEXT_CSV:
+        return "csv"
+    else:
+        raise _user_exceptions.FlyteValueException("Unsupported InputContentType: {}".format(content_type))
+
+
+class SdkBuiltinAlgorithmTrainingJobTask(_sdk_task.SdkTask):
+    def __init__(
+            self,
+            training_job_resource_config: _training_job_models.TrainingJobResourceConfig,
+            algorithm_specification: _training_job_models.AlgorithmSpecification,
+            retries: int = 0,
+            cacheable: bool = False,
+            cache_version: str = "",
+    ):
+        """
+
+        :param training_job_resource_config: The options to configure the training job
+        :param algorithm_specification: The options to configure the target algorithm of the training
+        :param retries: Number of retries to attempt
+        :param cacheable: The flag to set if the user wants the output of the task execution to be cached
+        :param cache_version: String describing the caching version for task discovery purposes
+        """
+        # Use the training job model as a measure of type checking
+        self._training_job_model = _training_job_models.TrainingJob(
+            algorithm_specification=algorithm_specification,
+            training_job_resource_config=training_job_resource_config,
+        )
+
+        # Setting the flyte-level timeout to 0, and letting SageMaker honor the StoppingCondition and
+        # terminate the training job gracefully
+        timeout = _datetime.timedelta(seconds=0)
+
+        super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__(
+            type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
+            metadata=_task_models.TaskMetadata(
+                runtime=_task_models.RuntimeMetadata(
+                    type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
+                    version=__version__,
+                    flavor='sagemaker'
+                ),
+                discoverable=cacheable,
+                timeout=timeout,
+                retries=_literal_models.RetryStrategy(retries=retries),
+                interruptible=False,
+                discovery_version=cache_version,
+                deprecated_error_message="",
+            ),
+            interface=_interface.TypedInterface(
+                inputs={
+                    "static_hyperparameters": _interface_model.Variable(
+                        type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT),
+                        description="",
+                    ),
+                    "train": _interface_model.Variable(
+                        type=_idl_types.LiteralType(
+                            blob=_core_types.BlobType(
+                                format=_content_type_to_blob_format(algorithm_specification.input_content_type),
+                                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+                            ),
+                        ),
+                        description="",
+                    ),
+                    "validation": _interface_model.Variable(
+                        type=_idl_types.LiteralType(
+                            blob=_core_types.BlobType(
+                                format=_content_type_to_blob_format(algorithm_specification.input_content_type),
+                                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+                            ),
+                        ),
+                        description="",
+                    ),
+                },
+                outputs={
+                    "model": _interface_model.Variable(
+                        type=_idl_types.LiteralType(
+                            blob=_core_types.BlobType(
+                                format="",
+                                dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
+                            )
+                        ),
+                        description=""
+                    )
+                }
+            ),
+            custom=MessageToDict(self._training_job_model.to_flyte_idl()),
+        )
+
+    @property
+    def training_job_model(self) -> _training_job_models.TrainingJob:
+        return self._training_job_model
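+
+# A minimal usage sketch (values are illustrative, mirroring this patch's unit tests):
+# train a built-in XGBoost model on CSV inputs.
+#
+#     xgboost_training_task = SdkBuiltinAlgorithmTrainingJobTask(
+#         training_job_resource_config=_training_job_models.TrainingJobResourceConfig(
+#             instance_type="ml.m4.xlarge",
+#             instance_count=1,
+#             volume_size_in_gb=25,
+#         ),
+#         algorithm_specification=_training_job_models.AlgorithmSpecification(
+#             input_mode=_training_job_models.InputMode.FILE,
+#             input_content_type=_training_job_models.InputContentType.TEXT_CSV,
+#             algorithm_name=_training_job_models.AlgorithmName.XGBOOST,
+#             algorithm_version="0.72",
+#         ),
+#     )
+#
+# The task's interface is derived from the specification: "train" and "validation" become
+# multipart CSV blobs, "static_hyperparameters" a generic struct, and "model" a single blob output.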
diff --git a/flytekit/models/sagemaker/hpo_job.py b/flytekit/models/sagemaker/hpo_job.py
new file mode 100644
index 0000000000..900fb58e2a
--- /dev/null
+++ b/flytekit/models/sagemaker/hpo_job.py
@@ -0,0 +1,198 @@
+from __future__ import absolute_import
+
+from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job
+from flytekit.models import common as _common
+from flytekit.models.sagemaker import parameter_ranges as _parameter_ranges_models, training_job as _training_job
+
+
+class HyperparameterTuningObjectiveType(object):
+    MINIMIZE = _pb2_hpo_job.HyperparameterTuningObjectiveType.MINIMIZE
+    MAXIMIZE = _pb2_hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE
+
+
+class HyperparameterTuningObjective(_common.FlyteIdlEntity):
+    """
+    HyperparameterTuningObjective is a data structure that contains the target metric and the
+    objective of the hyperparameter tuning.
+
+    https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-metrics.html
+    """
+    def __init__(
+            self,
+            objective_type: int,
+            metric_name: str,
+    ):
+        self._objective_type = objective_type
+        self._metric_name = metric_name
+
+    @property
+    def objective_type(self) -> int:
+        """
+        Enum value of HyperparameterTuningObjectiveType. objective_type determines the direction of the tuning of
+        the Hyperparameter Tuning Job with respect to the specified metric.
+        :rtype: int
+        """
+        return self._objective_type
+
+    @property
+    def metric_name(self) -> str:
+        """
+        The target metric name, which is the user-defined name of the metric specified in the
+        training job's algorithm specification
+        :rtype: str
+        """
+        return self._metric_name
+
+    def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningObjective:
+        return _pb2_hpo_job.HyperparameterTuningObjective(
+            objective_type=self.objective_type,
+            metric_name=self._metric_name,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningObjective):
+        return cls(
+            objective_type=pb2_object.objective_type,
+            metric_name=pb2_object.metric_name,
+        )
+
+
+class HyperparameterTuningStrategy:
+    BAYESIAN = _pb2_hpo_job.HyperparameterTuningStrategy.BAYESIAN
+    RANDOM = _pb2_hpo_job.HyperparameterTuningStrategy.RANDOM
+
+
+class TrainingJobEarlyStoppingType:
+    OFF = _pb2_hpo_job.TrainingJobEarlyStoppingType.OFF
+    AUTO = _pb2_hpo_job.TrainingJobEarlyStoppingType.AUTO
+
+
+class HyperparameterTuningJobConfig(_common.FlyteIdlEntity):
+    """
+    The specification of the hyperparameter tuning process
+    https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-ex-tuning-job.html#automatic-model-tuning-ex-low-tuning-config
+    """
+    def __init__(
+            self,
+            hyperparameter_ranges: _parameter_ranges_models.ParameterRanges,
+            tuning_strategy: int,
+            tuning_objective: HyperparameterTuningObjective,
+            training_job_early_stopping_type: int,
+    ):
+        self._hyperparameter_ranges = hyperparameter_ranges
+        self._tuning_strategy = tuning_strategy
+        self._tuning_objective = tuning_objective
+        self._training_job_early_stopping_type = training_job_early_stopping_type
+
+    @property
+    def hyperparameter_ranges(self) -> _parameter_ranges_models.ParameterRanges:
+        """
+        hyperparameter_ranges is a structure containing a map that maps a hyperparameter name to the corresponding
+        hyperparameter range object
+        :rtype: _parameter_ranges_models.ParameterRanges
+        """
+        return self._hyperparameter_ranges
+
+    @property
+    def tuning_strategy(self) -> int:
+        """
+        Enum value of HyperparameterTuningStrategy. Sets the strategy used when searching the hyperparameter space
+        :rtype: int
+        """
+        return self._tuning_strategy
+
+    @property
+    def tuning_objective(self) -> HyperparameterTuningObjective:
+        """
+        The target metric and the objective of the hyperparameter tuning.
+        :rtype: HyperparameterTuningObjective
+        """
+        return self._tuning_objective
+
+    @property
+    def training_job_early_stopping_type(self) -> int:
+        """
+        Enum value of TrainingJobEarlyStoppingType. When the training jobs launched by the hyperparameter tuning job
+        are not improving significantly, a hyperparameter tuning job can be stopped early. This attribute determines
+        how the early stopping is to be done.
+        Note that only a subset of the built-in algorithms supports early stopping.
+        see: https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-early-stopping.html
+        :rtype: int
+        """
+        return self._training_job_early_stopping_type
+
+    def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningJobConfig:
+        return _pb2_hpo_job.HyperparameterTuningJobConfig(
+            hyperparameter_ranges=self._hyperparameter_ranges.to_flyte_idl(),
+            tuning_strategy=self._tuning_strategy,
+            tuning_objective=self._tuning_objective.to_flyte_idl(),
+            training_job_early_stopping_type=self._training_job_early_stopping_type,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJobConfig):
+        return cls(
+            hyperparameter_ranges=(
+                _parameter_ranges_models.ParameterRanges.from_flyte_idl(pb2_object.hyperparameter_ranges)),
+            tuning_strategy=pb2_object.tuning_strategy,
+            tuning_objective=HyperparameterTuningObjective.from_flyte_idl(pb2_object.tuning_objective),
+            training_job_early_stopping_type=pb2_object.training_job_early_stopping_type,
+        )
+
+
+class HyperparameterTuningJob(_common.FlyteIdlEntity):
+
+    def __init__(
+            self,
+            max_number_of_training_jobs: int,
+            max_parallel_training_jobs: int,
+            training_job: _training_job.TrainingJob,
+    ):
+        self._max_number_of_training_jobs = max_number_of_training_jobs
+        self._max_parallel_training_jobs = max_parallel_training_jobs
+        self._training_job = training_job
+
+    @property
+    def max_number_of_training_jobs(self) -> int:
+        """
+        The maximum number of training jobs that a hyperparameter tuning job can launch.
+        https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ResourceLimits.html
+        :rtype: int
+        """
+        return self._max_number_of_training_jobs
+
+    @property
+    def max_parallel_training_jobs(self) -> int:
+        """
+        The maximum number of concurrent training jobs that a hyperparameter tuning job can launch
+        https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ResourceLimits.html
+        :rtype: int
+        """
+        return self._max_parallel_training_jobs
+
+    @property
+    def training_job(self) -> _training_job.TrainingJob:
+        """
+        The reference to the underlying training job that the hyperparameter tuning job will launch during the process
+        :rtype: _training_job.TrainingJob
+        """
+        return self._training_job
+
+    def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningJob:
+        return _pb2_hpo_job.HyperparameterTuningJob(
+            max_number_of_training_jobs=self._max_number_of_training_jobs,
+            max_parallel_training_jobs=self._max_parallel_training_jobs,
+            training_job=self._training_job.to_flyte_idl(),  # The SDK task has already serialized it
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJob):
+        return cls(
+            max_number_of_training_jobs=pb2_object.max_number_of_training_jobs,
+            max_parallel_training_jobs=pb2_object.max_parallel_training_jobs,
+            training_job=_training_job.TrainingJob.from_flyte_idl(pb2_object.training_job),
+        )
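+
+# A minimal sketch of assembling a tuning configuration from these models ("ranges" is a
+# hypothetical ParameterRanges instance from flytekit.models.sagemaker.parameter_ranges, and
+# the metric name is illustrative; it must match a metric emitted by the chosen algorithm):
+#
+#     tuning_config = HyperparameterTuningJobConfig(
+#         hyperparameter_ranges=ranges,
+#         tuning_strategy=HyperparameterTuningStrategy.BAYESIAN,
+#         tuning_objective=HyperparameterTuningObjective(
+#             objective_type=HyperparameterTuningObjectiveType.MINIMIZE,
+#             metric_name="validation:error",
+#         ),
+#         training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO,
+#     )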
diff --git a/flytekit/models/sagemaker/parameter_ranges.py b/flytekit/models/sagemaker/parameter_ranges.py
new file mode 100644
index 0000000000..d02b81c4c9
--- /dev/null
+++ b/flytekit/models/sagemaker/parameter_ranges.py
@@ -0,0 +1,211 @@
+from __future__ import absolute_import
+
+from typing import Dict, List
+from flyteidl.plugins.sagemaker import parameter_ranges_pb2 as _idl_parameter_ranges
+from flytekit.models import common as _common
+
+
+class HyperparameterScalingType(object):
+    AUTO = _idl_parameter_ranges.HyperparameterScalingType.AUTO
+    LINEAR = _idl_parameter_ranges.HyperparameterScalingType.LINEAR
+    LOGARITHMIC = _idl_parameter_ranges.HyperparameterScalingType.LOGARITHMIC
+    REVERSELOGARITHMIC = _idl_parameter_ranges.HyperparameterScalingType.REVERSELOGARITHMIC
+
+
+class ContinuousParameterRange(_common.FlyteIdlEntity):
+    def __init__(
+            self,
+            max_value: float,
+            min_value: float,
+            scaling_type: int,
+    ):
+        """
+        :param float max_value:
+        :param float min_value:
+        :param int scaling_type:
+        """
+        self._max_value = max_value
+        self._min_value = min_value
+        self._scaling_type = scaling_type
+
+    @property
+    def max_value(self) -> float:
+        """
+        :rtype: float
+        """
+        return self._max_value
+
+    @property
+    def min_value(self) -> float:
+        """
+        :rtype: float
+        """
+        return self._min_value
+
+    @property
+    def scaling_type(self) -> int:
+        """
+        enum value from HyperparameterScalingType
+        :rtype: int
+        """
+        return self._scaling_type
+
+    def to_flyte_idl(self) -> _idl_parameter_ranges.ContinuousParameterRange:
+        """
+        :rtype: _idl_parameter_ranges.ContinuousParameterRange
+        """
+        return _idl_parameter_ranges.ContinuousParameterRange(
+            max_value=self._max_value,
+            min_value=self._min_value,
+            scaling_type=self.scaling_type,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ContinuousParameterRange):
+        """
+        :param pb2_object:
+        :rtype: ContinuousParameterRange
+        """
+        return cls(
+            max_value=pb2_object.max_value,
+            min_value=pb2_object.min_value,
+            scaling_type=pb2_object.scaling_type,
+        )
+
+
+class IntegerParameterRange(_common.FlyteIdlEntity):
+    def __init__(
+            self,
+            max_value: int,
+            min_value: int,
+            scaling_type: int,
+    ):
+        """
+        :param int max_value:
+        :param int min_value:
+        :param int scaling_type:
+        """
+        self._max_value = max_value
+        self._min_value = min_value
+        self._scaling_type = scaling_type
+
+    @property
+    def max_value(self) -> int:
+        """
+        :rtype: int
+        """
+        return self._max_value
+
+    @property
+    def min_value(self) -> int:
+        """
+        :rtype: int
+        """
+        return self._min_value
+
+    @property
+    def scaling_type(self) -> int:
+        """
+        enum value from HyperparameterScalingType
+        :rtype: int
+        """
+        return self._scaling_type
+
+    def to_flyte_idl(self) -> _idl_parameter_ranges.IntegerParameterRange:
+        """
+        :rtype: _idl_parameter_ranges.IntegerParameterRange
+        """
+        return _idl_parameter_ranges.IntegerParameterRange(
+            max_value=self._max_value,
+            min_value=self._min_value,
+            scaling_type=self.scaling_type,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.IntegerParameterRange):
+        """
+        :param pb2_object:
+        :rtype: IntegerParameterRange
+        """
+        return cls(
+            max_value=pb2_object.max_value,
+            min_value=pb2_object.min_value,
+            scaling_type=pb2_object.scaling_type,
+        )
+
+
+class CategoricalParameterRange(_common.FlyteIdlEntity):
+    def __init__(
+            self,
+            values: List[str],
+    ):
+        """
+        :param List[str] values: list of strings representing categorical values
+        """
+        self._values = values
+
+    @property
+    def values(self) -> List[str]:
+        """
+        :rtype: List[str]
+        """
+        return self._values
+
+    def to_flyte_idl(self) -> _idl_parameter_ranges.CategoricalParameterRange:
+        """
+        :rtype: _idl_parameter_ranges.CategoricalParameterRange
+        """
+        return _idl_parameter_ranges.CategoricalParameterRange(
+            values=self._values
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.CategoricalParameterRange):
+        return cls(
+            values=pb2_object.values
+        )
+
+
+class ParameterRanges(_common.FlyteIdlEntity):
+    def __init__(
+            self,
+            parameter_range_map: Dict[str, _common.FlyteIdlEntity],
+    ):
+        self._parameter_range_map = parameter_range_map
+
+    def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRanges:
+        converted = {}
+        for k, v in self._parameter_range_map.items():
+            if isinstance(v, IntegerParameterRange):
+                converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(integer_parameter_range=v.to_flyte_idl())
+            elif isinstance(v, ContinuousParameterRange):
+                converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(continuous_parameter_range=v.to_flyte_idl())
+            else:
+                converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(categorical_parameter_range=v.to_flyte_idl())
+
+        return _idl_parameter_ranges.ParameterRanges(
+            parameter_range_map=converted,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges):
+        converted = {}
+        for k, v in pb2_object.parameter_range_map.items():
+            # Each map value is a ParameterRangeOneOf wrapper, so unwrap the oneof field
+            # rather than type-checking the wrapper itself
+            if v.HasField("continuous_parameter_range"):
+                converted[k] = ContinuousParameterRange.from_flyte_idl(v.continuous_parameter_range)
+            elif v.HasField("integer_parameter_range"):
+                converted[k] = IntegerParameterRange.from_flyte_idl(v.integer_parameter_range)
+            else:
+                converted[k] = CategoricalParameterRange.from_flyte_idl(v.categorical_parameter_range)
+
+        return cls(
+            parameter_range_map=converted,
+        )
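+
+# A minimal sketch of building a ParameterRanges map (hyperparameter names and bounds are
+# illustrative; keys must match hyperparameters of the chosen algorithm):
+#
+#     ranges = ParameterRanges(
+#         parameter_range_map={
+#             "num_round": IntegerParameterRange(
+#                 min_value=3, max_value=10, scaling_type=HyperparameterScalingType.LINEAR),
+#             "eta": ContinuousParameterRange(
+#                 min_value=0.01, max_value=0.3, scaling_type=HyperparameterScalingType.LOGARITHMIC),
+#             "booster": CategoricalParameterRange(values=["gbtree", "dart"]),
+#         },
+#     )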
diff --git a/flytekit/models/sagemaker/training_job.py b/flytekit/models/sagemaker/training_job.py
new file mode 100644
index 0000000000..f4862a9b66
--- /dev/null
+++ b/flytekit/models/sagemaker/training_job.py
@@ -0,0 +1,293 @@
+from __future__ import absolute_import
+
+from typing import List
+from flyteidl.plugins.sagemaker import training_job_pb2 as _training_job_pb2
+from flytekit.models import common as _common
+
+
+class TrainingJobResourceConfig(_common.FlyteIdlEntity):
+    """
+    TrainingJobResourceConfig is a pass-through, specifying the instance type to use for the training job, the
+    number of instances to launch, and the size of the ML storage volume the user wants to provision
+    Refer to the SageMaker official doc for more details: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html
+    """
+    def __init__(
+            self,
+            instance_count: int,
+            instance_type: str,
+            volume_size_in_gb: int,
+    ):
+        self._instance_count = instance_count
+        self._instance_type = instance_type
+        self._volume_size_in_gb = volume_size_in_gb
+
+    @property
+    def instance_count(self) -> int:
+        """
+        The number of ML compute instances to use. For distributed training, provide a value greater than 1.
+        :rtype: int
+        """
+        return self._instance_count
+
+    @property
+    def instance_type(self) -> str:
+        """
+        The ML compute instance type.
+        :rtype: str
+        """
+        return self._instance_type
+
+    @property
+    def volume_size_in_gb(self) -> int:
+        """
+        The size of the ML storage volume that you want to provision to store the data, intermediate artifacts, etc.
+        :rtype: int
+        """
+        return self._volume_size_in_gb
+
+    def to_flyte_idl(self) -> _training_job_pb2.TrainingJobResourceConfig:
+        """
+        :rtype: _training_job_pb2.TrainingJobResourceConfig
+        """
+        return _training_job_pb2.TrainingJobResourceConfig(
+            instance_count=self.instance_count,
+            instance_type=self.instance_type,
+            volume_size_in_gb=self.volume_size_in_gb,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _training_job_pb2.TrainingJobResourceConfig):
+        """
+        :param pb2_object:
+        :rtype: TrainingJobResourceConfig
+        """
+        return cls(
+            instance_count=pb2_object.instance_count,
+            instance_type=pb2_object.instance_type,
+            volume_size_in_gb=pb2_object.volume_size_in_gb,
+        )
+
+
+class MetricDefinition(_common.FlyteIdlEntity):
+    def __init__(
+            self,
+            name: str,
+            regex: str,
+    ):
+        self._name = name
+        self._regex = regex
+
+    @property
+    def name(self) -> str:
+        """
+        The user-defined name of the metric
+        :rtype: str
+        """
+        return self._name
+
+    @property
+    def regex(self) -> str:
+        """
+        SageMaker hyperparameter tuning uses this regex to parse your algorithm's stdout and stderr
+        streams and find the algorithm metrics that users want to track
+        :rtype: str
+        """
+        return self._regex
+
+    def to_flyte_idl(self) -> _training_job_pb2.MetricDefinition:
+        """
+        :rtype: _training_job_pb2.MetricDefinition
+        """
+        return _training_job_pb2.MetricDefinition(
+            name=self.name,
+            regex=self.regex,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _training_job_pb2.MetricDefinition):
+        """
+        :param pb2_object: _training_job_pb2.MetricDefinition
+        :rtype: MetricDefinition
+        """
+        return cls(
+            name=pb2_object.name,
+            regex=pb2_object.regex,
+        )
+
+
+class InputMode(object):
+    """
+    When using FILE input mode, different SageMaker built-in algorithms require different file types of input data
+    See https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
+    https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html
+    """
+    PIPE = _training_job_pb2.InputMode.PIPE
+    FILE = _training_job_pb2.InputMode.FILE
+
+
+class AlgorithmName(object):
+    """
+    The algorithm name is used for deciding which pre-built image to point to.
+    This is only required for use cases where SageMaker's built-in algorithm mode is used.
+    While we currently only support a subset of the algorithms, more will be added to the list.
+    See: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html
+    """
+    CUSTOM = _training_job_pb2.AlgorithmName.CUSTOM
+    XGBOOST = _training_job_pb2.AlgorithmName.XGBOOST
+
+
+class InputContentType(object):
+    """
+    Specifies the type of content for input data. Different SageMaker built-in algorithms require different content types of input data
+    See https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
+    https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html
+    """
+    TEXT_CSV = _training_job_pb2.InputContentType.TEXT_CSV
+
+
+class AlgorithmSpecification(_common.FlyteIdlEntity):
+    """
+    Specifies the training algorithm to be used in the training job
+    This object is mostly a pass-through, with a couple of exceptions: (1) in Flyte, users don't need to specify
+    TrainingImage; instead, they use the built-in algorithm mode by using Flytekit's Simple Training Job and
+    specifying an algorithm name and an algorithm version, and (2) when users want to supply a custom algorithm,
+    they should set the algorithm_name field to CUSTOM. In this case, the value of the algorithm_version field has
+    no effect
+    For pass-through use cases: refer to this AWS official document for more details
+    https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html
+    """
+    def __init__(
+            self,
+            algorithm_name: int,
+            algorithm_version: str,
+            input_mode: int,
+            metric_definitions: List[MetricDefinition] = None,
+            input_content_type: int = InputContentType.TEXT_CSV,
+    ):
+        self._input_mode = input_mode
+        self._input_content_type = input_content_type
+        self._algorithm_name = algorithm_name
+        self._algorithm_version = algorithm_version
+        self._metric_definitions = metric_definitions or []
+
+    @property
+    def input_mode(self) -> int:
+        """
+        enum value from InputMode. The input mode can be either PIPE or FILE
+        :rtype: int
+        """
+        return self._input_mode
+
+    @property
+    def input_content_type(self) -> int:
+        """
+        enum value from InputContentType. The content type of the input data
+        See https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
+        https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html
+        :rtype: int
+        """
+        return self._input_content_type
+
+    @property
+    def algorithm_name(self) -> int:
+        """
+        The algorithm name is used for deciding which pre-built image to point to.
+        enum value from AlgorithmName.
+        :rtype: int
+        """
+        return self._algorithm_name
+
+    @property
+    def algorithm_version(self) -> str:
+        """
+        version of the algorithm (if using built-in algorithm mode).
+        :rtype: str
+        """
+        return self._algorithm_version
+
+    @property
+    def metric_definitions(self) -> List[MetricDefinition]:
+        """
+        A list of metric definitions for SageMaker to evaluate/track the progress of the training job
+        See this: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html
+
+        Note that, when you use one of the Amazon SageMaker built-in algorithms, you cannot define custom metrics.
+        If you are doing hyperparameter tuning, built-in algorithms automatically send metrics to hyperparameter tuning.
+        When using hyperparameter tuning, you do need to choose one of the metrics that the built-in algorithm emits as
+        the objective metric for the tuning job.
+        See this: https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-metrics.html
+        :rtype: List[MetricDefinition]
+        """
+        return self._metric_definitions
+
+    def to_flyte_idl(self) -> _training_job_pb2.AlgorithmSpecification:
+        return _training_job_pb2.AlgorithmSpecification(
+            input_mode=self.input_mode,
+            algorithm_name=self.algorithm_name,
+            algorithm_version=self.algorithm_version,
+            metric_definitions=[m.to_flyte_idl() for m in self.metric_definitions],
+            input_content_type=self.input_content_type,
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _training_job_pb2.AlgorithmSpecification):
+        return cls(
+            input_mode=pb2_object.input_mode,
+            algorithm_name=pb2_object.algorithm_name,
+            algorithm_version=pb2_object.algorithm_version,
+            metric_definitions=[MetricDefinition.from_flyte_idl(m) for m in pb2_object.metric_definitions],
+            input_content_type=pb2_object.input_content_type,
+        )
+
+
+class TrainingJob(_common.FlyteIdlEntity):
+    def __init__(
+            self,
+            algorithm_specification: AlgorithmSpecification,
+            training_job_resource_config: TrainingJobResourceConfig,
+    ):
+        self._algorithm_specification = algorithm_specification
+        self._training_job_resource_config = training_job_resource_config
+
+    @property
+    def algorithm_specification(self) -> AlgorithmSpecification:
+        """
+        Contains the information related to the algorithm to use in the training job
+        :rtype: AlgorithmSpecification
+        """
+        return self._algorithm_specification
+
+    @property
+    def training_job_resource_config(self) -> TrainingJobResourceConfig:
+        """
+        Specifies the information around the instances that will be used to run the training job.
+        :rtype: TrainingJobResourceConfig
+        """
+        return self._training_job_resource_config
+
+    def to_flyte_idl(self) -> _training_job_pb2.TrainingJob:
+        """
+        :rtype: _training_job_pb2.TrainingJob
+        """
+        return _training_job_pb2.TrainingJob(
+            algorithm_specification=self.algorithm_specification.to_flyte_idl(),
+            training_job_resource_config=self.training_job_resource_config.to_flyte_idl(),
+        )
+
+    @classmethod
+    def from_flyte_idl(cls, pb2_object: _training_job_pb2.TrainingJob):
+        """
+        :param pb2_object:
+        :rtype: TrainingJob
+        """
+        # Convert the nested pb2 messages back into flytekit model objects rather than
+        # passing the raw protobufs through to the constructor
+        return cls(
+            algorithm_specification=AlgorithmSpecification.from_flyte_idl(pb2_object.algorithm_specification),
+            training_job_resource_config=TrainingJobResourceConfig.from_flyte_idl(
+                pb2_object.training_job_resource_config),
+        )
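+
+# A minimal sketch tying these models together (illustrative values mirroring this patch's
+# unit tests; all names below are defined in this module):
+#
+#     training_job = TrainingJob(
+#         algorithm_specification=AlgorithmSpecification(
+#             input_mode=InputMode.FILE,
+#             input_content_type=InputContentType.TEXT_CSV,
+#             algorithm_name=AlgorithmName.XGBOOST,
+#             algorithm_version="0.72",
+#         ),
+#         training_job_resource_config=TrainingJobResourceConfig(
+#             instance_type="ml.m4.xlarge",
+#             instance_count=1,
+#             volume_size_in_gb=25,
+#         ),
+#     )
+#     proto = training_job.to_flyte_idl()  # the message embedded in the task's custom field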
diff --git a/flytekit/plugins/__init__.py b/flytekit/plugins/__init__.py
index a155108136..b626856a99 100644
--- a/flytekit/plugins/__init__.py
+++ b/flytekit/plugins/__init__.py
@@ -59,4 +59,4 @@
     "pytorch",
     ["torch>=1.0.0,<2.0.0"],
     [torch]
-)
\ No newline at end of file
+)
diff --git a/flytekit/sdk/tasks.py b/flytekit/sdk/tasks.py
index 86f8c8d32f..0fa08035a0 100644
--- a/flytekit/sdk/tasks.py
+++ b/flytekit/sdk/tasks.py
@@ -8,6 +8,10 @@
 from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \
     spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, \
     sidecar_task as _sdk_sidecar_tasks, pytorch_task as _sdk_pytorch_tasks
+from flytekit.common.tasks.sagemaker import (
+    training_job_task as _sdk_sagemaker_training_job_tasks,
+    hpo_job_task as _sdk_hpo_job_tasks
+)
 from flytekit.common.tasks import task as _task
 from flytekit.common.types import helpers as _type_helpers
 from flytekit.sdk.spark_types import SparkType as _spark_type
@@ -1156,3 +1160,74 @@ def wrapper(fn):
         return wrapper(_task_function)
     else:
         return wrapper
+
+
+# NOTE: custom training job support was nixed for the current release (see the commit log
+# above); this decorator is a placeholder and is not yet wired to a runnable SageMaker task.
+def custom_training_job_task(
+        _task_function=None,
+        cache_version='',
+        retries=0,
+        interruptible=False,
+        deprecated='',
+        cache=False,
+        timeout=None,
+        workers_count=1,
+        per_replica_storage_request="",
+        per_replica_cpu_request="",
+        per_replica_gpu_request="",
+        per_replica_memory_request="",
+        per_replica_storage_limit="",
+        per_replica_cpu_limit="",
+        per_replica_gpu_limit="",
+        per_replica_memory_limit="",
+        environment=None,
+        cls=None
+):
+    """
+
+    :param _task_function:
+    :param cache_version:
+    :param retries:
+    :param interruptible:
+    :param deprecated:
+    :param cache:
+    :param timeout:
+    :param workers_count:
+    :param per_replica_storage_request:
+    :param per_replica_cpu_request:
+    :param per_replica_gpu_request:
+    :param per_replica_memory_request:
+    :param per_replica_storage_limit:
+    :param per_replica_cpu_limit:
+    :param per_replica_gpu_limit:
+    :param per_replica_memory_limit:
+    :param environment:
+    :param cls:
+    :return:
+    """
+
+    def wrapper(fn):
+        return (cls or _sdk_sagemaker_training_job_tasks.SdkBuiltinAlgorithmTrainingJobTask)(
+            task_function=fn,
+            task_type=_common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK,
+            discovery_version=cache_version,
+            retries=retries,
+            interruptible=interruptible,
+            deprecated=deprecated,
+            discoverable=cache,
+            timeout=timeout or _datetime.timedelta(seconds=0),
+            workers_count=workers_count,
+            per_replica_storage_request=per_replica_storage_request,
+            per_replica_cpu_request=per_replica_cpu_request,
+            per_replica_gpu_request=per_replica_gpu_request,
+            per_replica_memory_request=per_replica_memory_request,
+            per_replica_storage_limit=per_replica_storage_limit,
+            per_replica_cpu_limit=per_replica_cpu_limit,
+            per_replica_gpu_limit=per_replica_gpu_limit,
+            per_replica_memory_limit=per_replica_memory_limit,
+            environment=environment or {}
+        )
+
+    if _task_function:
+        return wrapper(_task_function)
+    else:
+        return wrapper
\ No newline at end of file
diff --git a/sample-notebooks/raw-container-shell.ipynb b/sample-notebooks/raw-container-shell.ipynb
index 8353dfafef..021525a7e2 100644
--- a/sample-notebooks/raw-container-shell.ipynb
+++ b/sample-notebooks/raw-container-shell.ipynb
@@ -164,9 +164,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3.7.4 64-bit ('flytekit': virtualenv)",
+   "display_name": "Python 3",
    "language": "python",
-   "name": "python37464bitflytekitvirtualenv72cbb5e9968e4a299c6026c09cce8d4c"
+   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
@@ -178,7 +178,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.4"
+   "version": "3.7.6"
   }
  },
 "nbformat": 4,
diff --git a/sample-notebooks/raw-container.ipynb b/sample-notebooks/raw-container.ipynb
index 9afa501c96..a1faf7179b 100644
--- a/sample-notebooks/raw-container.ipynb
+++ b/sample-notebooks/raw-container.ipynb
@@ -284,7 +284,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.4"
+   "version": "3.7.6"
   }
  },
 "nbformat": 4,
diff --git a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py
new file mode 100644
index 0000000000..ae9bb7dfff
--- /dev/null
+++ b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py
@@ -0,0 +1,185 @@
+from __future__ import absolute_import
+from flytekit.common.tasks.sagemaker.training_job_task import SdkBuiltinAlgorithmTrainingJobTask
+from flytekit.common.tasks.sagemaker.hpo_job_task import SdkSimpleHyperparameterTuningJobTask
+from flytekit.common import constants as _common_constants
+from flytekit.common.tasks import task as _sdk_task
+from flytekit.models.core import identifier as _identifier
+import datetime as _datetime
+from flytekit.models.sagemaker.training_job import TrainingJobResourceConfig, AlgorithmSpecification, \
+    MetricDefinition, AlgorithmName, InputMode, InputContentType
+from google.protobuf.json_format import ParseDict
+from flyteidl.plugins.sagemaker.training_job_pb2 import TrainingJobResourceConfig as _pb2_TrainingJobResourceConfig
+from flyteidl.plugins.sagemaker.hyperparameter_tuning_job_pb2 import HyperparameterTuningJobConfig as _pb2_HPOJobConfig
+from flytekit.sdk import types as _sdk_types
+from flytekit.common.tasks.sagemaker import hpo_job_task
+from flytekit.models import types as _idl_types
+from flytekit.models.core import types as _core_types
+
+example_hyperparams = {
+    "base_score": "0.5",
+    "booster": "gbtree",
+    "csv_weights": "0",
+    "dsplit": "row",
+    "grow_policy": "depthwise",
+    "lambda_bias": "0.0",
+    "max_bin": "256",
+    "max_leaves": "0",
+    "normalize_type": "tree",
+    "objective": "reg:linear",
+    "one_drop": "0",
+    "prob_buffer_row": "1.0",
+    "process_type": "default",
+    "rate_drop": "0.0",
+    "refresh_leaf": "1",
+    "sample_type": "uniform",
+    "scale_pos_weight": "1.0",
+    "silent": "0",
+    "sketch_eps": "0.03",
+    "skip_drop": "0.0",
+    "tree_method": "auto",
+    "tweedie_variance_power": "1.5",
+    "updater": "grow_colmaker,prune",
+}
+
+builtin_algorithm_training_job_task = SdkBuiltinAlgorithmTrainingJobTask(
+    training_job_resource_config=TrainingJobResourceConfig(
+        instance_type="ml.m4.xlarge",
+        instance_count=1,
+        volume_size_in_gb=25,
+    ),
+    algorithm_specification=AlgorithmSpecification(
+        input_mode=InputMode.FILE,
+        input_content_type=InputContentType.TEXT_CSV,
+        algorithm_name=AlgorithmName.XGBOOST,
+        algorithm_version="0.72",
+    ),
+)
+
+builtin_algorithm_training_job_task._id = _identifier.Identifier(
+    _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version")
+
+
+def test_builtin_algorithm_training_job_task():
+    assert isinstance(builtin_algorithm_training_job_task, SdkBuiltinAlgorithmTrainingJobTask)
+    assert isinstance(builtin_algorithm_training_job_task, _sdk_task.SdkTask)
+    assert builtin_algorithm_training_job_task.interface.inputs['train'].description == ''
+    assert builtin_algorithm_training_job_task.interface.inputs['train'].type == \
+        _idl_types.LiteralType(
+            blob=_core_types.BlobType(
+                format="csv",
+                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+            )
+        )
+    assert builtin_algorithm_training_job_task.interface.inputs['train'].type == \
+        _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
+    assert builtin_algorithm_training_job_task.interface.inputs['validation'].description == ''
+    assert builtin_algorithm_training_job_task.interface.inputs['validation'].type == \
+        _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
+    assert builtin_algorithm_training_job_task.interface.inputs['validation'].type == \
+        _idl_types.LiteralType(
+            blob=_core_types.BlobType(
+                format="csv",
+                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+            )
+        )
+    assert builtin_algorithm_training_job_task.interface.inputs['static_hyperparameters'].description == ''
+    assert builtin_algorithm_training_job_task.interface.inputs['static_hyperparameters'].type == \
+        _sdk_types.Types.Generic.to_flyte_literal_type()
+    assert builtin_algorithm_training_job_task.interface.outputs['model'].description == ''
+    assert builtin_algorithm_training_job_task.interface.outputs['model'].type == \
+        _sdk_types.Types.Blob.to_flyte_literal_type()
+    assert builtin_algorithm_training_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK
+    assert builtin_algorithm_training_job_task.metadata.timeout == _datetime.timedelta(seconds=0)
+    assert builtin_algorithm_training_job_task.metadata.deprecated_error_message == ''
+    assert builtin_algorithm_training_job_task.metadata.discoverable is False
+    assert builtin_algorithm_training_job_task.metadata.discovery_version == ''
+    assert builtin_algorithm_training_job_task.metadata.retries.retries == 0
+    assert "metricDefinitions" not in builtin_algorithm_training_job_task.custom["algorithmSpecification"].keys()
+
+    ParseDict(builtin_algorithm_training_job_task.custom['trainingJobResourceConfig'],
+              _pb2_TrainingJobResourceConfig)  # fails the test if it cannot be parsed
+
+
+builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask(
+    training_job_resource_config=TrainingJobResourceConfig(
+        instance_type="ml.m4.xlarge",
+        instance_count=1,
+        volume_size_in_gb=25,
+    ),
+    algorithm_specification=AlgorithmSpecification(
+        input_mode=InputMode.FILE,
+        input_content_type=InputContentType.TEXT_CSV,
+        algorithm_name=AlgorithmName.XGBOOST,
+        algorithm_version="0.72",
+        metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")]
+    ),
+)
+
+simple_xgboost_hpo_job_task = hpo_job_task.SdkSimpleHyperparameterTuningJobTask(
+    training_job=builtin_algorithm_training_job_task2,
+    max_number_of_training_jobs=10,
+    max_parallel_training_jobs=5,
+    cache_version='1',
+    retries=2,
+    cacheable=True,
+)
+
+simple_xgboost_hpo_job_task._id = _identifier.Identifier(
+    _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version")
+
+
+def test_simple_hpo_job_task():
+    assert isinstance(simple_xgboost_hpo_job_task, SdkSimpleHyperparameterTuningJobTask)
+    assert isinstance(simple_xgboost_hpo_job_task, _sdk_task.SdkTask)
+    # Checking if the input of the underlying SdkTrainingJobTask has been embedded
+    assert simple_xgboost_hpo_job_task.interface.inputs['train'].description == ''
+    assert simple_xgboost_hpo_job_task.interface.inputs['train'].type == \
+        _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
+    assert simple_xgboost_hpo_job_task.interface.inputs['train'].type == \
+        _idl_types.LiteralType(
+            blob=_core_types.BlobType(
+                format="csv",
+                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+            )
+        )
+    assert simple_xgboost_hpo_job_task.interface.inputs['validation'].description == ''
+    assert simple_xgboost_hpo_job_task.interface.inputs['validation'].type == \
+        _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
+    assert simple_xgboost_hpo_job_task.interface.inputs['validation'].type == \
+        _idl_types.LiteralType(
+            blob=_core_types.BlobType(
+                format="csv",
+                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
+            )
+        )
+    assert simple_xgboost_hpo_job_task.interface.inputs['static_hyperparameters'].description == ''
+    assert simple_xgboost_hpo_job_task.interface.inputs['static_hyperparameters'].type == \
+        _sdk_types.Types.Generic.to_flyte_literal_type()
+
+    # Checking if the hpo-specific input is defined
+    assert simple_xgboost_hpo_job_task.interface.inputs['hyperparameter_tuning_job_config'].description == ''
+    assert simple_xgboost_hpo_job_task.interface.inputs['hyperparameter_tuning_job_config'].type == \
+        _sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type()
+    assert simple_xgboost_hpo_job_task.interface.outputs['model'].description == ''
+    assert simple_xgboost_hpo_job_task.interface.outputs['model'].type == \
+        _sdk_types.Types.Blob.to_flyte_literal_type()
+    assert simple_xgboost_hpo_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK
+
+    # Checking if the spec of the TrainingJob is embedded into the custom field
+    # of this SdkSimpleHyperparameterTuningJobTask
+    assert simple_xgboost_hpo_job_task.to_flyte_idl().custom["trainingJob"] == (
+        builtin_algorithm_training_job_task2.to_flyte_idl().custom)
+
+    assert simple_xgboost_hpo_job_task.metadata.timeout == _datetime.timedelta(seconds=0)
+    assert simple_xgboost_hpo_job_task.metadata.discoverable is True
+    assert simple_xgboost_hpo_job_task.metadata.discovery_version == '1'
+    assert simple_xgboost_hpo_job_task.metadata.retries.retries == 2
+
+    assert simple_xgboost_hpo_job_task.metadata.deprecated_error_message == ''
+    assert "metricDefinitions" in simple_xgboost_hpo_job_task.custom["trainingJob"]["algorithmSpecification"].keys()
+    assert len(simple_xgboost_hpo_job_task.custom["trainingJob"]["algorithmSpecification"]["metricDefinitions"]) == 1
+    """ These are attributes for SdkRunnable. We will need these when supporting CustomTrainingJobTask and CustomHPOJobTask
+    assert simple_xgboost_hpo_job_task.task_module == __name__
+    assert simple_xgboost_hpo_job_task._get_container_definition().args[0] == 'pyflyte-execute'
+    """
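+
+# A hypothetical extension (a sketch, not an executable test): the launch-time
+# "hyperparameter_tuning_job_config" input is typed as a HyperparameterTuningJobConfig proto,
+# so a config built from the flytekit models can be serialized and fed through the same
+# ParseDict check used above:
+#
+#     from flytekit.models.sagemaker import hpo_job as _hpo_job_models
+#     from flytekit.models.sagemaker import parameter_ranges as _params
+#     from google.protobuf.json_format import MessageToDict
+#
+#     config_proto = _hpo_job_models.HyperparameterTuningJobConfig(
+#         hyperparameter_ranges=_params.ParameterRanges(parameter_range_map={
+#             "num_round": _params.IntegerParameterRange(
+#                 min_value=3, max_value=10, scaling_type=_params.HyperparameterScalingType.LINEAR),
+#         }),
+#         tuning_strategy=_hpo_job_models.HyperparameterTuningStrategy.BAYESIAN,
+#         tuning_objective=_hpo_job_models.HyperparameterTuningObjective(
+#             objective_type=_hpo_job_models.HyperparameterTuningObjectiveType.MINIMIZE,
+#             metric_name="validation:error",
+#         ),
+#         training_job_early_stopping_type=_hpo_job_models.TrainingJobEarlyStoppingType.AUTO,
+#     ).to_flyte_idl()
+#     ParseDict(MessageToDict(config_proto), _pb2_HPOJobConfig)  # round-trips cleanly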