-
Notifications
You must be signed in to change notification settings - Fork 1
/
run.py
37 lines (24 loc) · 873 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import logging
import sys
import os
import boto3
import botocore
import sagemaker
from sagemaker.tensorflow import TensorFlow
import sagemaker
import tensorflow as tf
import numpy as np
# Train the model defined in itemembd.py, host it on a real-time endpoint,
# and probe which request payload formats the endpoint accepts.
# Written against the SageMaker Python SDK v1 TensorFlow estimator API
# (train_instance_count / train_instance_type / training_steps).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared by training and hosting so the two stay in sync.
INSTANCE_TYPE = 'ml.p2.xlarge'

# NOTE(review): 'SageMakerRole' and the S3 URI look like placeholders —
# confirm / replace with a real IAM role and training-data prefix.
tf_estimator = TensorFlow(
    entry_point='itemembd.py',
    role='SageMakerRole',
    training_steps=10,
    evaluation_steps=None,
    train_instance_count=1,
    train_instance_type=INSTANCE_TYPE,
)
tf_estimator.fit('s3://bucket/path/to/training/data')

# deploy() takes initial_instance_count and instance_type as required
# arguments in the v1 SDK; the original bare deploy() raised TypeError.
# NOTE: the endpoint is left running after this script exits; call
# predictor.delete_endpoint() when it is no longer needed.
predictor = tf_estimator.deploy(initial_instance_count=1, instance_type=INSTANCE_TYPE)


def _float64_tensor(value):
    """Return a 1-element float64 TensorProto holding *value*."""
    return tf.make_tensor_proto(values=np.asarray([value]), shape=[1], dtype=tf.float64)


user = _float64_tensor(10)
item = _float64_tensor(10)

# Payload formats previously noted as NOT working against this endpoint.
# Each is attempted and its failure logged, so the script still reaches the
# working call below instead of aborting on the first failed request.
failing_payloads = [
    {'user': 10, 'item': 10},
    {'user': [10], 'item': [10]},
    {'user': user, 'item': item},
]
for payload in failing_payloads:
    try:
        predictor.predict(payload)
    except Exception:  # deliberate probe boundary: report the failure, keep going
        logger.exception('predict() failed for payload %r', payload)

# A single TensorProto is the payload format noted as working.
result = predictor.predict(item)
logger.info('prediction: %r', result)