diff --git a/onnx_tf/handlers/frontend/shape.py b/onnx_tf/handlers/frontend/shape.py
index af130eaf1..0ef83277b 100644
--- a/onnx_tf/handlers/frontend/shape.py
+++ b/onnx_tf/handlers/frontend/shape.py
@@ -1,6 +1,10 @@
 from onnx_tf.handlers.frontend_handler import FrontendHandler
 from onnx_tf.handlers.handler import onnx_op
 from onnx_tf.handlers.handler import tf_op
+from onnx_tf.common import get_unique_suffix
+
+import tensorflow as tf
+from onnx import TensorProto
 
 
 @onnx_op("Shape")
@@ -9,4 +13,28 @@ class Shape(FrontendHandler):
 
   @classmethod
   def version_1(cls, node, **kwargs):
-    return cls.make_node_from_tf_node(node)
+    out_type = node.attr.get("out_type", tf.int32)
+
+    # A flag to indicate whether the output is int32.
+    # If so, we need to insert a Cast node, because
+    # the ONNX Shape op only supports int64 as its output type.
+    need_cast_to_int32 = False
+    if tf.as_dtype(out_type) == tf.int32:
+      need_cast_to_int32 = True
+
+    shape_suffix = "_" + get_unique_suffix() if need_cast_to_int32 else ""
+    shape_name = cls.get_outputs_names(node)[0] + shape_suffix
+    shape_node = cls.make_node_from_tf_node(
+        node, outputs=[shape_name], name=shape_name)
+
+    if need_cast_to_int32:
+      dst_t = TensorProto.INT32
+      cast_node = cls.make_node_from_tf_node(
+          node,
+          inputs=[shape_name],
+          outputs=cls.get_outputs_names(node),
+          op_type="Cast",
+          to=dst_t)
+      return [shape_node, cast_node]
+
+    return [shape_node]
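For context on the shape.py change: the ONNX Shape op is defined to emit int64, while tf.shape defaults to int32, so the handler renames the Shape output and casts it back. A minimal sketch (not part of this diff; names are hypothetical) of the equivalent node pair built directly with onnx.helper:

```python
from onnx import helper, TensorProto

# ONNX Shape always produces int64 output...
shape_node = helper.make_node(
    "Shape", inputs=["x"], outputs=["shape_out_abc123"], name="shape_out_abc123")
# ...so a Cast node restores the int32 dtype the TF op originally requested.
cast_node = helper.make_node(
    "Cast", inputs=["shape_out_abc123"], outputs=["shape_out"],
    to=TensorProto.INT32)
```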
diff --git a/test/frontend/test_model.py b/test/frontend/test_model.py
index ec754a486..3e83787bf 100644
--- a/test/frontend/test_model.py
+++ b/test/frontend/test_model.py
@@ -8,6 +8,7 @@
 import sys, os, tempfile
 import zipfile
 import logging
+import subprocess
 if sys.version_info >= (3,):
   import urllib.request as urllib2
   import urllib.parse as urlparse
@@ -17,6 +18,7 @@
 
 import numpy as np
 import tensorflow as tf
+import onnx
 from tensorflow.python.tools import freeze_graph
 
 from onnx_tf.frontend import tensorflow_graph_to_onnx_model
@@ -85,22 +87,45 @@ class TestModel(unittest.TestCase):
     pass
 
 
-def create_test(test_model):
+def create_test(test_model, interface):
 
   def do_test_expected(self):
     tf.reset_default_graph()
     work_dir = "".join([test_model["name"], "-", "workspace"])
     work_dir_prefix = work_dir + "/"
     download_and_extract(test_model["asset_url"], work_dir)
 
-    freeze_graph.freeze_graph(
-        work_dir_prefix + test_model["graph_proto_path"], "", True,
-        work_dir_prefix + test_model["checkpoint_path"], ",".join(
-            test_model["outputs"]), "", "", work_dir_prefix + "frozen_graph.pb",
-        "", "")
-    with tf.gfile.GFile(work_dir_prefix + "frozen_graph.pb", "rb") as f:
-      graph_def = tf.GraphDef()
-      graph_def.ParseFromString(f.read())
+    if "frozen_path" in test_model:
+      frozen_path = work_dir_prefix + test_model["frozen_path"]
+    else:
+      # Parse the metagraph def to obtain the graph_def.
+      meta_graph_def_file = open(work_dir_prefix + test_model["metagraph_path"],
+                                 "rb")
+      meta_graph_def = tf.MetaGraphDef()
+      meta_graph_def.ParseFromString(meta_graph_def_file.read())
+      meta_graph_def_file.close()
+
+      with open(work_dir_prefix + "graph_def.pb", 'wb') as f:
+        f.write(meta_graph_def.graph_def.SerializeToString())
+
+      # Proceed to freeze the graph:
+      freeze_graph.freeze_graph(input_graph=work_dir_prefix + "graph_def.pb",
+                                input_saver="",
+                                input_binary=True,
+                                input_checkpoint=work_dir_prefix + test_model["checkpoint_path"],
+                                output_node_names=",".join(test_model["outputs"]),
+                                restore_op_name="",
+                                filename_tensor_name="",
+                                output_graph=work_dir_prefix + "frozen_graph.pb",
+                                clear_devices=True,
+                                initializer_nodes="")
+      # Set the frozen graph path.
+      frozen_path = work_dir_prefix + "frozen_graph.pb"
+
+    # Now read the frozen graph and import it:
+    with tf.gfile.GFile(frozen_path, "rb") as f:
+      graph_def = tf.GraphDef()
+      graph_def.ParseFromString(f.read())
 
     with tf.Graph().as_default() as graph:
       tf.import_graph_def(
@@ -125,14 +150,31 @@ def do_test_expected(self):
       tf_output_tensors.append(graph.get_tensor_by_name(name + ":0"))
       backend_output_names.append(name)
 
+    # Obtain reference tensorflow outputs:
     with tf.Session(graph=graph) as sess:
       logging.debug("ops in the graph:")
       logging.debug(graph.get_operations())
       output_tf = sess.run(tf_output_tensors, feed_dict=tf_feed_dict)
 
-    onnx_model = tensorflow_graph_to_onnx_model(graph_def, backend_output_names)
-
-    model = onnx_model
+    # Now we convert the tensorflow model to onnx,
+    # either via the python API:
+    if interface == "python":
+      model = tensorflow_graph_to_onnx_model(graph_def, backend_output_names, ignore_unimplemented=True)
+    else:  # or via the CLI utility.
+      assert interface == "cli"
+      subprocess.check_call([
+          "onnx-tf",
+          "convert",
+          "-t",
+          "onnx",
+          "-i",
+          work_dir_prefix + "frozen_graph.pb",
+          "-o",
+          work_dir_prefix + "model.onnx",
+      ])
+      model = onnx.load_model(work_dir_prefix + "model.onnx")
+
+    # Run the output onnx model in our backend:
     tf_rep = prepare(model)
     output_onnx_tf = tf_rep.run(backend_feed_dict)
@@ -149,12 +191,13 @@ def do_test_expected(self):
 try:
   for test_model in yaml.safe_load_all(config):
     for device in test_model["devices"]:
-      if supports_device(device):
-        test_method = create_test(test_model)
-        test_name_parts = ["test", test_model["name"], device]
-        test_name = str("_".join(map(str, test_name_parts)))
-        test_method.__name__ = test_name
-        setattr(TestModel, test_method.__name__, test_method)
+      for interface in ["python", "cli"]:
+        if supports_device(device):
+          test_method = create_test(test_model, interface)
+          test_name_parts = ["test", test_model["name"], device, interface]
+          test_name = str("_".join(map(str, test_name_parts)))
+          test_method.__name__ = test_name
+          setattr(TestModel, test_method.__name__, test_method)
 except yaml.YAMLError as exception:
   print(exception)
diff --git a/test/frontend/test_model.yaml b/test/frontend/test_model.yaml
index e41b89ba5..8a280d400 100644
--- a/test/frontend/test_model.yaml
+++ b/test/frontend/test_model.yaml
@@ -1,10 +1,33 @@
 name: mnist_tutorial
 asset_url: https://s3-api.us-geo.objectstorage.softlayer.net/onnx-tensorflow/public_tutorial_assets.zip
 input_name: Placeholder
-graph_proto_path: tutorial_assets/graph.proto
+checkpoint_dir: tutorial_assets/ckpt
+metagraph_path: tutorial_assets/ckpt/model.ckpt.meta
 checkpoint_path: tutorial_assets/ckpt/model.ckpt
 devices: [CUDA]
 inputs:
   Placeholder: [1, 784]
 outputs:
-  - fc2/add
\ No newline at end of file
+  - fc2/add
+---
+name: mnist-lstm
+asset_url: https://s3-api.us-geo.objectstorage.softlayer.net/onnx-tensorflow/test-assets/mnist-lstm.zip
+input_name: Placeholder
+checkpoint_dir: checkpoint
+metagraph_path: checkpoint/lstm.ckpt.meta
+checkpoint_path: checkpoint/lstm.ckpt
+devices: [CUDA]
+inputs:
+  Placeholder: [4, 28, 28]
+outputs:
+  - Softmax
+---
+name: mobilenet_v1_1
+asset_url: https://s3-api.us-geo.objectstorage.softlayer.net/onnx-tensorflow/mobilenet_v1_1.0_224.zip
+input_name: input
+frozen_path: mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb
+devices: [CUDA]
+inputs:
+  input: [1, 224, 224, 3]
+outputs:
+  - MobilenetV1/Predictions/Softmax
\ No newline at end of file
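Both conversion interfaces above produce a model that the test then checks against the TensorFlow reference outputs via the onnx-tensorflow backend. A rough standalone equivalent of that last step, for illustration only (the path is hypothetical; input name and shape are taken from the mnist_tutorial entry in test_model.yaml):

```python
import numpy as np
import onnx
from onnx_tf.backend import prepare

# Load the model written by `onnx-tf convert` (hypothetical workspace path).
model = onnx.load_model("mnist_tutorial-workspace/model.onnx")
tf_rep = prepare(model)
# Feed-dict keys use the graph's input name, as in test_model.yaml.
outputs = tf_rep.run({"Placeholder": np.random.rand(1, 784).astype(np.float32)})
```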
diff --git a/test/frontend/test_node.py b/test/frontend/test_node.py
index 3428cae8b..d2acc7c28 100644
--- a/test/frontend/test_node.py
+++ b/test/frontend/test_node.py
@@ -166,6 +166,8 @@ def do_test_expected(self):
 ("test_concat", tf.concat, "concat", [[get_rnd([1, 10]),get_rnd([10, 10]),get_rnd([20, 10])], 0], {}),
 ("test_bias_add_nchw", tf.nn.bias_add, "BiasAdd", [get_rnd([10, 32, 10, 10]),get_rnd([32])], {"data_format":"NCHW"}),
 ("test_bias_add_nhwc", tf.nn.bias_add, "BiasAdd", [get_rnd([10, 10, 10, 32]),get_rnd([32])], {"data_format":"NHWC"}),
+("test_strided_slice", tf.strided_slice, "StridedSlice", [get_rnd([5, 5]), [0, 0], [1, 5]], {}),
+("test_strided_slice_shrink", tf.strided_slice, "StridedSlice", [get_rnd([5, 5]), [0, 0], [1, 3]], {"shrink_axis_mask":1}),
 ]
 
 if not legacy_opset_pre_ver(6):
@@ -185,3 +187,4 @@ def do_test_expected(self):
 
 if __name__ == '__main__':
   unittest.main()
+
\ No newline at end of file
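A note on the two new StridedSlice cases: they differ only in shrink_axis_mask, which collapses the sliced axis rather than keeping it at length 1, so the second case exercises a rank-reducing slice. A quick sketch of the difference (not part of the test suite):

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(25).reshape(5, 5), dtype=tf.float32)
# A plain strided slice keeps every axis: result shape is (1, 5).
a = tf.strided_slice(x, [0, 0], [1, 5])
# shrink_axis_mask=1 turns begin[0] into a plain index and drops
# axis 0 from the result: shape is (3,).
b = tf.strided_slice(x, [0, 0], [1, 3], shrink_axis_mask=1)
```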