Skip to content

Commit

Permalink
Merge pull request #147 from joemansfield/ss-jm-vertexai-train-cpuonly
Browse files (browse the repository at this point in the history)
ss-jm-add-tf-version-flag
  • Loading branch information
mylenebiddle authored Mar 29, 2022
2 parents ee1db45 + 1778c31 commit d3fd309
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions 10_mlops/train_on_vertexai.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,12 +25,12 @@
ENDPOINT_NAME = 'flights'


- def train_custom_model(data_set, timestamp, develop_mode, cpu_only_mode, extra_args=None):
+ def train_custom_model(data_set, timestamp, develop_mode, cpu_only_mode, tf_version, extra_args=None):
# Set up training and deployment infra
- tf_version = '2-' + tf.__version__[2:3]
+
if cpu_only_mode:
- train_image='us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-6:latest'
- deploy_image='us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest'
+ train_image='us-docker.pkg.dev/vertex-ai/training/tf-cpu.{}:latest'.format(tf_version)
+ deploy_image='us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.{}:latest'.format(tf_version)
else:
train_image = "us-docker.pkg.dev/vertex-ai/training/tf-gpu.{}:latest".format(tf_version)
deploy_image = "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.{}:latest".format(tf_version)
Expand Down Expand Up @@ -103,11 +103,10 @@ def train_automl_model(data_set, timestamp, develop_mode):
return model


- def do_hyperparameter_tuning(data_set, timestamp, develop_mode, cpu_only_mode):
+ def do_hyperparameter_tuning(data_set, timestamp, develop_mode, cpu_only_mode, tf_version):
# Vertex AI services require regional API endpoints.
- tf_version = '2-' + tf.__version__[2:3]
if cpu_only_mode:
- train_image='us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-6:latest'
+ train_image='us-docker.pkg.dev/vertex-ai/training/tf-cpu.{}:latest'.format(tf_version)
else:
train_image = "us-docker.pkg.dev/vertex-ai/training/tf-gpu.{}:latest".format(tf_version)

Expand Down Expand Up @@ -182,7 +181,7 @@ def do_hyperparameter_tuning(data_set, timestamp, develop_mode, cpu_only_mode):

# run the best trial to completion
logging.info('Launching full training job with {}'.format(best_params))
- return train_custom_model(data_set, timestamp, develop_mode, cpu_only_mode, extra_args=best_params)
+ return train_custom_model(data_set, timestamp, develop_mode, cpu_only_mode, tf_version, extra_args=best_params)


@dsl.pipeline(name="flights-ch9-pipeline",
Expand All @@ -198,14 +197,18 @@ def main():
display_name='data-{}'.format(ENDPOINT_NAME),
gcs_source=all_files
)
+ if TF_VERSION is not None:
+ tf_version = TF_VERSION.replace(".", "-")
+ else:
+ tf_version = '2-' + tf.__version__[2:3]

# train
if AUTOML:
model = train_automl_model(data_set, TIMESTAMP, DEVELOP_MODE)
elif NUM_HPARAM_TRIALS > 1:
- model = do_hyperparameter_tuning(data_set, TIMESTAMP, DEVELOP_MODE, CPU_ONLY_MODE)
+ model = do_hyperparameter_tuning(data_set, TIMESTAMP, DEVELOP_MODE, CPU_ONLY_MODE, tf_version)
else:
- model = train_custom_model(data_set, TIMESTAMP, DEVELOP_MODE, CPU_ONLY_MODE)
+ model = train_custom_model(data_set, TIMESTAMP, DEVELOP_MODE, CPU_ONLY_MODE, tf_version)

# create endpoint if it doesn't already exist
endpoints = aiplatform.Endpoint.list(
Expand Down Expand Up @@ -294,6 +297,10 @@ def run_pipeline():
dest='cpuonly',
action='store_true')
parser.set_defaults(cpuonly=False)
+ parser.add_argument(
+ '--tfversion',
+ help='TensorFlow version to use'
+ )

# parse args
logging.getLogger().setLevel(logging.INFO)
Expand All @@ -303,6 +310,7 @@ def run_pipeline():
REGION = args['region']
DEVELOP_MODE = args['develop']
CPU_ONLY_MODE = args['cpuonly']
+ TF_VERSION = args['tfversion']
AUTOML = args['automl']
NUM_HPARAM_TRIALS = args['num_hparam_trials']
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
Expand Down

0 comments on commit d3fd309

Please sign in to comment.