This repository has been archived by the owner on Mar 26, 2019. It is now read-only.

Merge pull request #16 from Roshrini/code-fixes
Code fixes for the pooling layer, plus a shape fix for broadcast operators.
lupesko authored Jan 22, 2018
2 parents 625cde1 + f6c061e commit 013a6a2
Showing 5 changed files with 47 additions and 36 deletions.
onnx_mxnet/common.py (3 changes: 0 additions & 3 deletions)

@@ -15,7 +15,6 @@
 # pylint: disable=invalid-name,no-self-use,too-many-branches,too-few-public-methods,too-many-arguments
 """Shared functions and classes for frontends."""
 from __future__ import absolute_import as _abs
-import warnings
 from mxnet.base import string_types
 
 class Renamer(object):
@@ -93,8 +92,6 @@ def __call__(self, attrs):
         for k in attrs.keys():
             if k in self._excludes:
                 raise NotImplementedError("Attribute {} not supported yet.".format(k))
-            elif k in self._disables:
-                warnings.warn("Attribute {} is disabled in sym.{}".format(k, op_name))
             elif k in self._ignores:
                 pass
             elif k in self._transforms:
onnx_mxnet/import_helper.py (23 changes: 9 additions & 14 deletions)

@@ -62,10 +62,9 @@ def _pooling(name):
         transforms={
             'kernel_shape': 'kernel',
             'strides': 'stride',
-            'pads': ('pad', (0, 0), _revert_caffe2_pad)},
+            'pads': 'pad'},
         # pooling convention full to match caffe2
-        extras={'pool_type': name, 'pooling_convention':'full'},
-        ignores=['dilations'],
+        extras={'pool_type': name, 'pooling_convention':'valid'},
         custom_check=_dimension_constraint())
 
 def _conv():
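
For context on the pooling_convention switch above: MXNet's 'full' convention rounds the output size up (the Caffe2 behaviour the old comment refers to), while 'valid' rounds down, matching ONNX's default. A minimal sketch of the two rules, assuming MXNet's documented floor/ceil semantics; pool_output_dim is a hypothetical helper, not part of this repository:

    import math

    def pool_output_dim(size, kernel, stride, pad, convention='valid'):
        # 'valid' rounds down (ONNX default); 'full' rounds up (Caffe2-style).
        span = (size + 2 * pad - kernel) / stride
        rounded = math.floor(span) if convention == 'valid' else math.ceil(span)
        return int(rounded) + 1

    # size=7, kernel=2, stride=2, pad=0:
    # 'valid' -> floor(2.5) + 1 = 3; 'full' -> ceil(2.5) + 1 = 4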
@@ -88,21 +87,17 @@ def _conv_transpose():
             'kernel_shape': 'kernel',
             'strides': 'stride',
             'dilations': ('dilate', (0, 0)),
-            'pads': ('pad', (0, 0), _revert_caffe2_pad)},
+            'pads': ('pad', (0, 0), _revert_caffe2_pad),
+            'group': ('num_group', 1)},
         disables=['output_shape'],
         custom_check=_dimension_constraint())
 
-def _change_eps_cudnn(attr):
-    """Limiting eps value to 1e-5 for cudnn batchnorm."""
-    if attr < 1e-5:
-        attr = 1e-4
-    return attr
-
 def _batch_norm():
     """converting attributes for BatchNorm operator"""
     return AttrCvt(
         op_name='BatchNorm',
-        transforms={'epsilon': ('eps', (1e-5), _change_eps_cudnn)},
+        transforms={'epsilon': 'eps'},
         extras={'cudnn_off': 1},
         ignores=['spatial', 'is_test', 'consumed_inputs'])
 
 def _activation(name):
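
On the _change_eps_cudnn removal above: the clamp existed because cuDNN's batch normalization rejects eps values below its 1e-5 minimum, but this converter already sets cudnn_off in the extras, so the ONNX epsilon can presumably pass through unchanged. A hedged illustration (not the converter's code):

    import mxnet as mx

    # With cuDNN disabled for BatchNorm, an ONNX epsilon below cuDNN's
    # 1e-5 minimum no longer needs clamping (illustrative values):
    data = mx.sym.Variable('data')
    bn = mx.sym.BatchNorm(data, eps=1e-6, cudnn_off=True)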
@@ -118,9 +113,9 @@ def _pad_sequence_fix(attr):
     mxnet: (x1_begin, x1_end, ... , xn_begin, xn_end)
     onnx: (x1_begin, x2_begin, ... , xn_end, xn_end)"""
     new_attr = ()
-    if len(attr)%2 == 0:
-        for index in range(len(attr) / 2):
-            new_attr = new_attr + attr[index::len(attr) / 2]
+    if len(attr) % 2 == 0:
+        for index in range(int(len(attr) / 2)):
+            new_attr = new_attr + attr[index::int(len(attr) / 2)]
     return new_attr
 
 def _pad():
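
A worked example of the reordering _pad_sequence_fix performs (illustrative values, not part of the diff): ONNX lists all begin-pads and then all end-pads, while MXNet interleaves begin and end per axis.

    # ONNX pads for two axes: (x1_begin, x2_begin, x1_end, x2_end)
    attr = (1, 2, 3, 4)
    half = int(len(attr) / 2)                      # 2
    new_attr = ()
    for index in range(half):
        new_attr = new_attr + attr[index::half]    # 0 -> (1, 3); 1 -> (2, 4)
    assert new_attr == (1, 3, 2, 4)  # MXNet: (x1_begin, x1_end, x2_begin, x2_end)

The int() casts added in this hunk matter under Python 3, where / is true division and range() rejects floats.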
onnx_mxnet/import_onnx.py (36 changes: 27 additions & 9 deletions)

@@ -16,7 +16,7 @@
 """ Support import export formats."""
 from __future__ import absolute_import as _abs
 import mxnet as mx
-from onnx_mxnet.import_helper import _identity_list, _convert_map
+from onnx_mxnet.import_helper import _identity_list, _convert_map, _pad_sequence_fix
 
 def _convert_operator(op_name, attrs, identity_list=None, convert_map=None):
     """Convert from onnx operator to mxnet operator.
@@ -111,13 +111,13 @@ def from_onnx(self, graph):
             op_name = node.op_type
             node_name = node.name.strip()
             node_name = node_name if node_name else None
-            attr = self._parse_attr(node.attribute)
-            new_op, new_attr = _convert_operator(op_name, attr)
+            onnx_attr = self._parse_attr(node.attribute)
+            new_op, mx_attr = _convert_operator(op_name, onnx_attr)
             inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
 
             # some workarounds for onnx problem
-            new_attr = self._fix_bias(new_op, new_attr, len(inputs))
-            new_attr = self._fix_channels(new_op, new_attr, list(node.input))
+            mx_attr = self._fix_bias(new_op, mx_attr, len(inputs))
+            mx_attr = self._fix_channels(new_op, mx_attr, list(node.input))
             self._fix_bias_shape(node.op_type, graph.node[idx - 1].op_type, node.input)
 
             # calling again to get new symbols after some workarounds
@@ -127,14 +127,19 @@ def from_onnx(self, graph):
             # mxnet's equivalent linalg_gemm doesn't. So using combination of
             # transpose and FullyConnected operators.
             if op_name == 'Gemm':
-                new_op, inputs, new_attr = self._fix_gemm('FullyConnected', inputs, attr)
+                new_op, inputs, mx_attr = self._fix_gemm('FullyConnected', inputs, onnx_attr)
 
             # onnx slice works on multiple axes whereas mxnet's slice_axis is for single axis
             if op_name == 'Slice':
-                op = self._fix_slice(inputs, new_attr)
+                op = self._fix_slice(inputs, mx_attr)
+            elif op_name == 'AveragePool' and onnx_attr.get('pads') is not None or \
+                    op_name == 'MaxPool' and onnx_attr.get('pads') is not None:
+                op = self._fix_pooling(op_name, inputs, onnx_attr)
             else:
-                op = new_op(name=node_name, *inputs, **new_attr)
+                op = new_op(name=node_name, *inputs, **mx_attr)
 
             node_output = self._fix_outputs(op_name, node.output)
 
             assert len(node_output) == len(op.list_outputs()), (
                 "Number of output mismatch {} vs {} in {}.".format(
                     len(node_output), len(op.list_outputs()), op_name))
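
The comment in this hunk explains the Gemm rewrite: ONNX Gemm computes Y = alpha * A * B + beta * C with optional transposes, which has no direct MXNet equivalent, so it is expressed as transpose plus FullyConnected. The rename from the ambiguous attr/new_attr pair to onnx_attr/mx_attr also makes clear that _fix_gemm consumes the original ONNX attributes. A hedged sketch of the mapping, assuming the output dimension is known at conversion time; gemm_as_fully_connected is illustrative, not the converter's actual _fix_gemm:

    import mxnet as mx

    def gemm_as_fully_connected(a, b, c, num_hidden, alpha=1.0, beta=1.0,
                                trans_a=False, trans_b=False):
        # FullyConnected computes X @ W.T + bias, so the weight must arrive
        # in (num_hidden, input_dim) layout, i.e. transposed relative to A @ B.
        if trans_a:
            a = mx.sym.transpose(a)
        if not trans_b:
            b = mx.sym.transpose(b)
        return mx.sym.FullyConnected(data=alpha * a, weight=b, bias=beta * c,
                                     num_hidden=num_hidden)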
@@ -176,6 +181,19 @@ def run_node(self, node, device='CPU'): # pylint: disable=unused-argument
         # now return the outputs
         return op
 
+    def _fix_pooling(self, op_name, inputs, new_attr):
+        """onnx pooling operator supports asymmetrical padding
+        Adding pad operator before pooling in mxnet to work with onnx"""
+        pool_type = 'avg' if op_name == 'AveragePool' else 'max'
+        stride = new_attr.get('strides')
+        kernel = new_attr.get('kernel_shape')
+        padding = new_attr.get('pads')
+        pad_width = (0, 0, 0, 0) + _pad_sequence_fix(padding)
+        new_pad_op = mx.sym.pad(inputs[0], mode='constant', pad_width=pad_width)
+        new_pooling_op = mx.sym.Pooling(new_pad_op, pool_type=pool_type,
+                                        stride=stride, kernel=kernel)
+        return new_pooling_op
+
     def _fix_slice(self, inputs, new_attr):
         """onnx slice provides slicing on multiple axis. Adding multiple slice_axis operator
         for multiple axes from mxnet"""
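
To make the new _fix_pooling concrete: an ONNX MaxPool with pads (top, left, bottom, right) = (0, 0, 1, 1) becomes an explicit constant pad followed by an unpadded pool. Illustrative symbols only, assuming NCHW data:

    import mxnet as mx

    data = mx.sym.Variable('data')    # NCHW
    # ONNX pads (0, 0, 1, 1) -> _pad_sequence_fix -> (0, 1, 0, 1); the leading
    # (0, 0, 0, 0) leaves the batch and channel axes unpadded.
    padded = mx.sym.pad(data, mode='constant', pad_width=(0, 0, 0, 0, 0, 1, 0, 1))
    pooled = mx.sym.Pooling(padded, pool_type='max', kernel=(2, 2), stride=(2, 2))

One caveat: for AveragePool, zero-padding before the pool includes the padded cells in each average, which can differ from ONNX configurations that exclude padding from the count.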
@@ -258,7 +276,7 @@ def _fix_bias(self, op, attrs, num_inputs):
 
     def _fix_bias_shape(self, op_name, last_op_name, inputs):
         """A workaround to reshape bias term to (1, num_channel)."""
-        if op_name == 'Add' and last_op_name == 'Conv':
+        if op_name == 'Add' and last_op_name == 'Conv' or op_name == 'Mul' or op_name == 'Add':
             assert len(list(inputs)) == 2
             bias_name = self._renames.get(inputs[1], inputs[1])
             bias = self._params[bias_name]
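
The widened condition above is the "shape fix for broadcast operators" from the commit title: bias-like operands of Add and Mul are reshaped to (1, num_channel). The motivation, sketched with hypothetical shapes, is that MXNet's broadcast operators historically required both operands to carry the same number of dimensions:

    import mxnet as mx

    data = mx.nd.ones((2, 3))              # (batch, channel)
    bias = mx.nd.array([0., 1., 2.])       # (3,)
    # Reshaping to (1, 3) gives both operands rank 2, so broadcasting
    # along the batch axis is well-defined:
    out = mx.nd.broadcast_add(data, bias.reshape((1, 3)))   # shape (2, 3)

Note that under Python's operator precedence the new condition parses as (Add and Conv) or Mul or Add, so any Add or Mul node takes this path regardless of the preceding operator.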
onnx_mxnet/tests/test_models.py (17 changes: 9 additions & 8 deletions)

@@ -36,7 +36,7 @@
 def extract_file(model_tar):
     """Extract tar file and returns model path and input, output data"""
     # extract tar file
-    tar = tarfile.open(model_tar, "r:gz")
+    tar = tarfile.open(model_tar, "r:*")
     tar.extractall()
     tar.close()
     path = model_tar.rsplit('_', 1)[0]
@@ -53,14 +53,15 @@ def verify_onnx_forward_impl(model_path, input_data, output_data):
     """Verifies result after inference"""
     print("Converting onnx format to mxnet's symbol and params...")
     sym, params = onnx_mxnet.import_model(model_path)
+
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=['input_0'], context=mx.cpu(), label_names=None)
     mod.bind(for_training=False, data_shapes=[('input_0', input_data.shape)], label_shapes=None)
-    mod.set_params(arg_params=params, aux_params=None)
+    mod.set_params(arg_params=params, aux_params=params, allow_missing=True, allow_extra=True)
     # run inference
     Batch = namedtuple('Batch', ['data'])
 
-    mod.forward(Batch([mx.nd.array(input_data)]))
+    mod.forward(Batch([mx.nd.array(input_data)]), is_train=False)
 
     # Run the model with an onnx backend and verify the results
     npt.assert_equal(mod.get_outputs()[0].shape, output_data.shape)
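
Two changes here track how the importer packages parameters: import_model returns a single dict that apparently mixes argument weights with auxiliary states such as BatchNorm's moving mean and variance, so the test now passes it as both arg_params and aux_params and tolerates the mismatches; is_train=False makes operators like BatchNorm and Dropout run in inference mode. A hypothetical way to split the dict explicitly (sketch, not part of the test):

    # Partition a mixed params dict by the symbol's declared names:
    arg_names = set(sym.list_arguments())
    aux_names = set(sym.list_auxiliary_states())
    arg_params = {k: v for k, v in params.items() if k in arg_names}
    aux_params = {k: v for k, v in params.items() if k in aux_names}
    mod.set_params(arg_params, aux_params)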
@@ -81,8 +82,8 @@ def verify_model(name):
 
 verify_model('bvlc_alexnet_onnx') # working
 verify_model('vgg16_onnx') # working
 verify_model('vgg19_onnx') # working
-# verify_model('inception_v1_onnx') # working, accuracy is different
-# verify_model('inception_v2_onnx') # [WIP]
-# verify_model('shufflenet_onnx') # [WIP]
-# verify_model('densenet121_onnx') # [WIP]
-# verify_model('resnet50_onnx') # [WIP]
+#verify_model('inception_v1_onnx') # working, accuracy is different 1.4
+#verify_model('inception_v2_onnx') # working, accuracy is different 7.4
+#verify_model('shufflenet_onnx') # working, accuracy is different 10.2
+verify_model('densenet121_onnx') # working
+#verify_model('resnet50_onnx') # working, accuracy is different 18.1
setup.py (4 changes: 2 additions & 2 deletions)

@@ -16,13 +16,13 @@
 
 setup(
     name='onnx-mxnet',
-    version='0.3.1',
+    version='0.3.2',
     description='ONNX-MXNet Model converter',
     url='https://github.com/onnx/onnx-mxnet',
     keywords='ONNX MXNet model converter deep learning',
     packages=pkgs,
     install_requires=['mxnet>=0.11.0', 'onnx>=0.2'],
-    tests_require=['mxnet>=0.11.0', 'onnx>=0.2', 'pylint'],
+    tests_require=['mxnet>=0.11.0', 'onnx>=0.2', 'numpy', 'pylint'],
     include_package_data=True,
     license='Apache 2.0'
 )
