More cli test #288

Open · wants to merge 33 commits into main
Changes from all commits
Commits (33)
67543d9  add model converter API (fumihwh, Oct 5, 2018)
70b05ed  refactor and add console_scripts (fumihwh, Oct 6, 2018)
e08fff8  edit readme (fumihwh, Oct 6, 2018)
84cf577  add cli (fumihwh, Oct 8, 2018)
6b62296  edit README.md (fumihwh, Oct 8, 2018)
88babe3  refactor README.md (fumihwh, Oct 9, 2018)
50be671  edit README.md (fumihwh, Oct 9, 2018)
dc31c4b  add converting from ckpt (fumihwh, Oct 9, 2018)
2ce670a  use underscore (fumihwh, Oct 10, 2018)
e822d03  add cli test (fumihwh, Oct 10, 2018)
1db5fd6  rename var (fumihwh, Oct 10, 2018)
489b3f4  bug fix (fumihwh, Oct 10, 2018)
fda1f85  make prepare_model as staticmethod (fumihwh, Oct 10, 2018)
aa4a5c4  Merge branch 'add-model-converter' of https://github.com/fumihwh/onnx… (tjingrant, Oct 11, 2018)
755220f  add model tests (tjingrant, Oct 12, 2018)
2c13097  model test checkin (tjingrant, Oct 13, 2018)
5dde1ad  support strided slice (tjingrant, Oct 13, 2018)
0eb5b7b  assert unsupported case (tjingrant, Oct 13, 2018)
e096a37  Update shape.py (tjingrant, Oct 14, 2018)
3833ec8  strided slice support (tjingrant, Oct 14, 2018)
61cf2b8  comment (tjingrant, Oct 14, 2018)
a3e8f80  Merge branch 'fix-shape-default-type' of https://github.com/onnx/onnx… (tjingrant, Oct 14, 2018)
0949b51  Merge branch 'support-strided-slice-rebase' of https://github.com/onn… (tjingrant, Oct 14, 2018)
45afb4a  more tests (tjingrant, Oct 14, 2018)
f2b7901  fix shape type incompatibility with onnx (tjingrant, Oct 15, 2018)
19b8f9f  Merge branch 'master' into more-cli-test (tjingrant, Oct 15, 2018)
5dce36d  Merge branch 'master' into more-cli-test (tjingrant, Oct 16, 2018)
ab4d678  Revert unintended changes (tjingrant, Oct 16, 2018)
f40adde  add mobilenet (tjingrant, Oct 17, 2018)
ad7913a  Merge branch 'more-cli-test' of https://github.com/onnx/onnx-tensorfl… (tjingrant, Oct 17, 2018)
18ad849  Merge branch 'master' into more-cli-test (tjingrant, Oct 17, 2018)
24861d2  Merge branch 'master' into more-cli-test (tjingrant, Nov 21, 2018)
84ab963  Merge branch 'master' into more-cli-test (tjingrant, Dec 19, 2018)
30 changes: 29 additions & 1 deletion onnx_tf/handlers/frontend/shape.py
@@ -1,6 +1,10 @@
from onnx_tf.handlers.frontend_handler import FrontendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import tf_op
from onnx_tf.common import get_unique_suffix

import tensorflow as tf
from onnx import TensorProto


@onnx_op("Shape")
@@ -9,4 +13,28 @@ class Shape(FrontendHandler):

@classmethod
def version_1(cls, node, **kwargs):
return cls.make_node_from_tf_node(node)
out_type = node.attr.get("out_type", tf.int32)

# A flag to indicate whether output is int32.
# If so, we need to insert a Cast node because
# ONNX shape op only supports int64 as output type.
need_cast_to_int32 = False
if tf.as_dtype(out_type) == tf.int32:
need_cast_to_int32 = True

shape_suffix = "_" + get_unique_suffix() if need_cast_to_int32 else ""
shape_name = cls.get_outputs_names(node)[0] + shape_suffix
shape_node = cls.make_node_from_tf_node(
node, outputs=[shape_name], name=shape_name)

if need_cast_to_int32:
dst_t = TensorProto.INT32
cast_node = cls.make_node_from_tf_node(
node,
inputs=[shape_name],
outputs=cls.get_outputs_names(node),
op_type="Cast",
to=TensorProto.INT32)
return [shape_node, cast_node]

return [shape_node]
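
For context on the Shape handler change above: TensorFlow's tf.shape defaults to an int32 output, while the ONNX Shape op only produces int64, so the frontend now emits a trailing Cast back to INT32 in that case. Below is a minimal sketch of what a caller should observe, assuming the tensorflow_graph_to_onnx_model API used by the tests in this PR; the placeholder and op names are illustrative and the expected op ordering has not been verified here.

import tensorflow as tf
from onnx_tf.frontend import tensorflow_graph_to_onnx_model

with tf.Graph().as_default() as graph:
  x = tf.placeholder(tf.float32, shape=[2, 3], name="x")
  s = tf.shape(x, name="shape")  # out_type defaults to tf.int32

# Convert the graph; "shape" is the output node name defined above.
model = tensorflow_graph_to_onnx_model(graph.as_graph_def(), ["shape"])
print([n.op_type for n in model.graph.node])  # expected to include 'Shape' followed by 'Cast'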
79 changes: 61 additions & 18 deletions test/frontend/test_model.py
@@ -8,6 +8,7 @@
import sys, os, tempfile
import zipfile
import logging
import subprocess
if sys.version_info >= (3,):
import urllib.request as urllib2
import urllib.parse as urlparse
@@ -17,6 +18,7 @@

import numpy as np
import tensorflow as tf
import onnx
from tensorflow.python.tools import freeze_graph

from onnx_tf.frontend import tensorflow_graph_to_onnx_model
@@ -85,22 +87,45 @@ class TestModel(unittest.TestCase):
pass


def create_test(test_model):
def create_test(test_model, interface):

def do_test_expected(self):
tf.reset_default_graph()
work_dir = "".join([test_model["name"], "-", "workspace"])
work_dir_prefix = work_dir + "/"
download_and_extract(test_model["asset_url"], work_dir)
freeze_graph.freeze_graph(
work_dir_prefix + test_model["graph_proto_path"], "", True,
work_dir_prefix + test_model["checkpoint_path"], ",".join(
test_model["outputs"]), "", "", work_dir_prefix + "frozen_graph.pb",
"", "")

with tf.gfile.GFile(work_dir_prefix + "frozen_graph.pb", "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
if "frozen_path" in test_model:
frozen_path = work_dir_prefix + test_model["frozen_path"]
else:
# Parse metagraph def to obtain graph_def
meta_graph_def_file = open(work_dir_prefix + test_model["metagraph_path"],
"rb")

meta_graph_def = tf.MetaGraphDef()
meta_graph_def.ParseFromString(meta_graph_def_file.read())

with open(work_dir_prefix + "graph_def.pb", 'wb') as f:
f.write(meta_graph_def.graph_def.SerializeToString())

# Proceed to freeze graph:
freeze_graph.freeze_graph(input_graph=work_dir_prefix + "graph_def.pb",
input_saver="",
input_binary=True,
input_checkpoint=work_dir_prefix + test_model["checkpoint_path"],
output_node_names=",".join(test_model["outputs"]),
restore_op_name="",
filename_tensor_name="",
output_graph=work_dir_prefix + "frozen_graph.pb",
clear_devices=True,
initializer_nodes="")
# Set the frozen graph path.
frozen_path = work_dir_prefix + "frozen_graph.pb"

# Now read the frozen graph and import it:
with tf.gfile.GFile(frozen_path, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
tf.import_graph_def(
@@ -125,14 +150,31 @@ def do_test_expected(self):
tf_output_tensors.append(graph.get_tensor_by_name(name + ":0"))
backend_output_names.append(name)

# Obtain reference tensorflow outputs:
with tf.Session(graph=graph) as sess:
logging.debug("ops in the graph:")
logging.debug(graph.get_operations())
output_tf = sess.run(tf_output_tensors, feed_dict=tf_feed_dict)

onnx_model = tensorflow_graph_to_onnx_model(graph_def, backend_output_names)

model = onnx_model
# Now we convert tensorflow models to onnx,
# if we use python API for conversion:
if interface == "python":
model = tensorflow_graph_to_onnx_model(graph_def, backend_output_names, ignore_unimplemented=True)
else: # else we use CLI utility
assert interface == "cli"
subprocess.check_call([
"onnx-tf",
"convert",
"-t",
"onnx",
"-i",
work_dir_prefix + "frozen_graph.pb",
"-o",
work_dir_prefix + "model.onnx",
])
model = onnx.load_model(work_dir_prefix + "model.onnx")

# Run the output onnx models in our backend:
tf_rep = prepare(model)
output_onnx_tf = tf_rep.run(backend_feed_dict)

@@ -149,12 +191,13 @@ def do_test_expected(self):
try:
for test_model in yaml.safe_load_all(config):
for device in test_model["devices"]:
if supports_device(device):
test_method = create_test(test_model)
test_name_parts = ["test", test_model["name"], device]
test_name = str("_".join(map(str, test_name_parts)))
test_method.__name__ = test_name
setattr(TestModel, test_method.__name__, test_method)
for interface in ["python", "cli"]:
if supports_device(device):
test_method = create_test(test_model, interface)
test_name_parts = ["test", test_model["name"], device, interface]
test_name = str("_".join(map(str, test_name_parts)))
test_method.__name__ = test_name
setattr(TestModel, test_method.__name__, test_method)
except yaml.YAMLError as exception:
print(exception)

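For reference, the parametrized test above exercises each model through two conversion paths. The following is a hedged sketch of a small helper mirroring those paths; the function name, argument names, and file paths are placeholders and are not part of the PR.

import subprocess
import onnx
from onnx_tf.frontend import tensorflow_graph_to_onnx_model

def convert(graph_def, output_names, frozen_pb_path, onnx_out_path, interface="python"):
  """Convert a frozen TF graph to ONNX via either interface, as the test does."""
  if interface == "python":
    return tensorflow_graph_to_onnx_model(graph_def, output_names, ignore_unimplemented=True)
  # CLI path: shell out to the onnx-tf console script, then load the result.
  subprocess.check_call(["onnx-tf", "convert", "-t", "onnx",
                         "-i", frozen_pb_path, "-o", onnx_out_path])
  return onnx.load_model(onnx_out_path)
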
27 changes: 25 additions & 2 deletions test/frontend/test_model.yaml
@@ -1,10 +1,33 @@
name: mnist_tutorial
asset_url: https://s3-api.us-geo.objectstorage.softlayer.net/onnx-tensorflow/public_tutorial_assets.zip
input_name: Placeholder
graph_proto_path: tutorial_assets/graph.proto
checkpoint_dir: tutorial_assets/ckpt
metagraph_path: tutorial_assets/ckpt/model.ckpt.meta
checkpoint_path: tutorial_assets/ckpt/model.ckpt
devices: [CUDA]
inputs:
Placeholder: [1, 784]
outputs:
- fc2/add
- fc2/add
---
name: mnist-lstm
asset_url: https://s3-api.us-geo.objectstorage.softlayer.net/onnx-tensorflow/test-assets/mnist-lstm.zip
input_name: Placeholder
checkpoint_dir: checkpoint
metagraph_path: checkpoint/lstm.ckpt.meta
checkpoint_path: checkpoint/lstm.ckpt
devices: [CUDA]
inputs:
Placeholder: [4, 28, 28]
outputs:
- Softmax
---
name: mobilenet_v1_1
asset_url: https://s3-api.us-geo.objectstorage.softlayer.net/onnx-tensorflow/mobilenet_v1_1.0_224.zip
input_name: input
frozen_path: mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_frozen.pb
devices: [CUDA]
inputs:
input: [1, 224, 224, 3]
outputs:
- MobilenetV1/Predictions/Softmax
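
With the three model entries above and the two interfaces, the loop in test_model.py should generate six test methods on a CUDA-capable machine. This is a sketch of the expected names only, assuming supports_device("CUDA") returns True:

test_mnist_tutorial_CUDA_python
test_mnist_tutorial_CUDA_cli
test_mnist-lstm_CUDA_python
test_mnist-lstm_CUDA_cli
test_mobilenet_v1_1_CUDA_python
test_mobilenet_v1_1_CUDA_cli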
3 changes: 3 additions & 0 deletions test/frontend/test_node.py
@@ -166,6 +166,8 @@ def do_test_expected(self):
("test_concat", tf.concat, "concat", [[get_rnd([1, 10]),get_rnd([10, 10]),get_rnd([20, 10])], 0], {}),
("test_bias_add_nchw", tf.nn.bias_add, "BiasAdd", [get_rnd([10, 32, 10, 10]),get_rnd([32])], {"data_format":"NCHW"}),
("test_bias_add_nhwc", tf.nn.bias_add, "BiasAdd", [get_rnd([10, 10, 10, 32]),get_rnd([32])], {"data_format":"NHWC"}),
("test_strided_slice", tf.strided_slice, "StridedSlice", [get_rnd([5, 5]), [0, 0], [1, 5]], {}),
("test_strided_slice_shrink", tf.strided_slice, "StridedSlice", [get_rnd([5, 5]), [0, 0], [1, 3]], {"shrink_axis_mask":1}),
]

if not legacy_opset_pre_ver(6):
@@ -185,3 +187,4 @@

if __name__ == '__main__':
unittest.main()
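
As a quick reference for the two new StridedSlice node tests, here is a hedged TF 1.x sketch of what each call computes; the shapes are inferred from the test inputs rather than taken from this test file's output.

import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(25, dtype=np.float32).reshape(5, 5))
a = tf.strided_slice(x, [0, 0], [1, 5])                      # rows 0:1, cols 0:5 -> shape (1, 5)
b = tf.strided_slice(x, [0, 0], [1, 3], shrink_axis_mask=1)  # axis 0 shrunk away -> shape (3,)

with tf.Session() as sess:
  print(sess.run(a).shape)  # expect (1, 5)
  print(sess.run(b).shape)  # expect (3,)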