Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Use apply_squeeze in opset 13 rather than add_node #669

Merged
merged 9 commits into from
Dec 18, 2020
91 changes: 26 additions & 65 deletions applications/mask_rcnn/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,14 @@ def convert_BatchNorm(scope, operator, container):


def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpose, score_identity, deltas_transpose, windows_transpose):
box_squeeze = scope.get_unique_variable_name('box_squeeze')
attrs = {'axes': [0]}
container.add_node('Squeeze', box_transpose, box_squeeze, op_version=operator.target_opset,
**attrs)
oopb = OnnxOperatorBuilder(container, scope)
box_squeeze = oopb.apply_squeeze(box_transpose, name=operator.full_name + '_box_squeeze', axes=[0])[0]
# output shape: [spatial_dimension, 4]

deltas_squeeze = scope.get_unique_variable_name('deltas_squeeze')
attrs = {'axes': [0]}
container.add_node('Squeeze', deltas_transpose, deltas_squeeze, op_version=operator.target_opset,
**attrs)
deltas_squeeze = oopb.apply_squeeze(deltas_transpose, name=operator.full_name + '_deltas_squeeze', axes=[0])[0]
# output shape: [spatial_dimension, num_classes, 4]

score_squeeze = scope.get_unique_variable_name('score_squeeze')
attrs = {'axes': [0]}
container.add_node('Squeeze', score_identity, score_squeeze, op_version=operator.target_opset,
**attrs)
score_squeeze = oopb.apply_squeeze(score_identity, name=operator.full_name + '_score_squeeze', axes=[0])[0]
# output shape: [spatial_dimension, num_classes]

class_ids = scope.get_unique_variable_name('class_ids')
Expand Down Expand Up @@ -113,11 +105,9 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo
op_domain='com.microsoft',
op_version=1)

attrs = {'axes': [1]}
prob_range_unsqueeze = oopb.add_node('Unsqueeze',
[prob_range],
operator.inputs[1].full_name + '_prob_range_unsqueeze',
**attrs)
prob_range_unsqueeze = oopb.apply_unsqueeze([prob_range],
operator.inputs[1].full_name + '_prob_range_unsqueeze',
axes=[1])[0]
# output shape: [spatial_dimension, 1]

attrs = {'axis': 1}
Expand Down Expand Up @@ -272,10 +262,8 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo
[x1, width_exp],
operator.inputs[0].full_name + '_x2')

windows_squeeze = scope.get_unique_variable_name('windows_squeeze')
attrs = {'axes': [0]}
container.add_node('Squeeze', windows_transpose, windows_squeeze, op_version=operator.target_opset,
**attrs)
windows_squeeze = oopb.apply_squeeze(windows_transpose, name=operator.full_name + '_windows_squeeze',
axes=[0])[0]
wy1 = oopb.add_node('Slice',
[windows_squeeze,
('_start', oopb.int64, np.array([0], dtype='int64')),
Expand Down Expand Up @@ -336,10 +324,8 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo
op_version=operator.target_opset,
name=operator.outputs[0].full_name + '_concat_result', **attrs)

concat_unsqueeze = scope.get_unique_variable_name('concat_unsqueeze')
attrs = {'axes': [0]}
container.add_node('Unsqueeze', concat_result, concat_unsqueeze, op_version=operator.target_opset,
**attrs)
concat_unsqueeze = oopb.apply_unsqueeze(concat_result, name=operator.full_name + '_concat_unsqueeze',
axes=[0])[0]
return concat_unsqueeze


Expand All @@ -358,10 +344,8 @@ def norm_boxes_graph(scope, operator, container, oopb, image_meta):
('_axes', oopb.int64, np.array([0], dtype='int64'))
],
operator.inputs[0].full_name + '_image_shape')
image_shape_squeeze = scope.get_unique_variable_name('image_shape_squeeze')
attrs = {'axes': [0]}
container.add_node('Squeeze', image_shape, image_shape_squeeze, op_version=operator.target_opset,
**attrs)
image_shape_squeeze = oopb.apply_squeeze(image_shape, name=operator.full_name + '_image_shape_squeeze', axes=[0])[0]

window = oopb.add_node('Slice',
[image_meta,
('_start', oopb.int64, np.array([7], dtype='int64')),
Expand Down Expand Up @@ -516,13 +500,8 @@ def convert_DetectionLayer(scope, operator, container):
name=nms_node.name + '_box_idx')
# output shape: [num_selected_indices, 1]

box_idx_squeeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_idx_squeeze')
attrs = {'axes': [1]}
container.add_node("Squeeze",
box_idx_output,
box_idx_squeeze,
op_version=operator.target_opset,
name=nms_node.name + '_box_idx_squeeze', **attrs)
box_idx_squeeze = oopb.apply_squeeze(box_idx_output,
name=nms_node.name + '_box_idx_squeeze', axes=[1])[0]
# output shape: [num_selected_indices]

starts_init_3 = scope.get_unique_variable_name('starts')
Expand All @@ -548,23 +527,12 @@ def convert_DetectionLayer(scope, operator, container):
name=nms_node.name + '_class_box_idx')
# output shape: [num_selected_indices, 2]

box_squeeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_squeeze')
attrs = {'axes': [0]}
container.add_node("Squeeze",
delta_mul_output,
box_squeeze,
op_version=operator.target_opset,
name=nms_node.name + '_box_squeeze', **attrs)
box_squeeze = oopb.apply_squeeze(delta_mul_output,
name=nms_node.name + '_box_squeeze', axes=[0])[0]
# output shape: [spatial_dimension, 4]

score_squeeze = scope.get_local_variable_or_declare_one(operator.output_full_names[0] + '_score_squeeze',
type=FloatTensorType(shape=[None]))
attrs = {'axes': [0]}
container.add_node("Squeeze",
score_identity,
score_squeeze.full_name,
op_version=operator.target_opset,
name=nms_node.name + '_score_squeeze', **attrs)
score_squeeze = oopb.apply_squeeze(score_identity,
name=nms_node.name + '_score_squeeze', axes=[0])[0]
# output shape: [spatial_dimension, num_classes]

box_gather = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_gather')
Expand All @@ -578,19 +546,14 @@ def convert_DetectionLayer(scope, operator, container):

score_gather = scope.get_unique_variable_name(operator.output_full_names[0] + '_score_gather')
container.add_node("GatherND",
[score_squeeze.full_name, class_box_idx_output.full_name],
[score_squeeze, class_box_idx_output.full_name],
score_gather,
op_version=operator.target_opset,
name=nms_node.name + '_score_gather')
# output shape: [num_selected_indices]

score_gather_unsqueeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_score_gather_unsqueeze')
attrs = {'axes': [1]}
container.add_node("Unsqueeze",
score_gather,
score_gather_unsqueeze,
op_version=operator.target_opset,
name=nms_node.name + '_score_gather_unsqueeze', **attrs)
score_gather_unsqueeze = oopb.apply_unsqueeze(score_gather,
name=nms_node.name + '_score_gather_unsqueeze', axes=[1])[0]
# output shape: [num_selected_indices, 1]


Expand Down Expand Up @@ -661,12 +624,10 @@ def convert_DetectionLayer(scope, operator, container):
nms_node.name + '_detection_final'
)

attrs = {'axes': [0]}
container.add_node("Unsqueeze",
detection_final,
operator.output_full_names[0],
op_version=operator.target_opset,
name=nms_node.name + '_concat_unsqueeze', **attrs)
oopb.apply_op_with_output('apply_unsqueeze',
detection_final,
operator.output_full_names[0],
name=nms_node.name + '_concat_unsqueeze', axes=[0])
# output shape: [1, num_top_K, 6]


Expand Down
16 changes: 5 additions & 11 deletions applications/yolov3/yolov3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from keras.models import load_model
from keras2onnx import convert_keras
from keras2onnx import set_converter
from keras2onnx.common.onnx_ops import apply_transpose, apply_identity, apply_cast
from keras2onnx.common.onnx_ops import apply_transpose, apply_identity, apply_cast, OnnxOperatorBuilder
from keras2onnx.proto import onnx_proto
from onnxconverter_common.onnx_ex import get_maximum_opset_supported
from onnxconverter_common.onnx_fx import Graph
Expand Down Expand Up @@ -324,19 +324,15 @@ def detect_img(yolo, img_url, model_file_name):

def convert_NMSLayer(scope, operator, container):
# type: (keras2onnx.common.InterimContext, keras2onnx.common.Operator, keras2onnx.common.OnnxObjectContainer) -> None
oopb = OnnxOperatorBuilder(container, scope)
box_transpose = scope.get_unique_variable_name(operator.inputs[0].full_name + '_tx')
score_transpose = scope.get_unique_variable_name(operator.inputs[1].full_name + '_tx')

apply_identity(scope, operator.inputs[0].full_name, box_transpose, container)
apply_transpose(scope, operator.inputs[1].full_name, score_transpose, container, perm=[1, 0])

box_batch = scope.get_unique_variable_name(operator.inputs[0].full_name + '_btc')
score_batch = scope.get_unique_variable_name(operator.inputs[1].full_name + '_btc')

container.add_node("Unsqueeze", box_transpose,
box_batch, op_version=operator.target_opset, axes=[0])
container.add_node("Unsqueeze", score_transpose,
score_batch, op_version=operator.target_opset, axes=[0])
box_batch = oopb.apply_unsqueeze(box_transpose, name=operator.inputs[0].full_name + '_btc', axes=[0])[0]
score_batch = oopb.apply_unsqueeze(score_transpose, name=operator.inputs[1].full_name + '_btc', axes=[0])[0]

layer = operator.raw_operator # type: YOLONMSLayer

Expand All @@ -359,9 +355,7 @@ def convert_NMSLayer(scope, operator, container):
op_version=operator.target_opset,
name=nms_node.name)

cast_batch = scope.get_unique_variable_name(operator.output_full_names[2] + '_btc')
container.add_node("Unsqueeze", cast_name,
cast_batch, op_version=operator.target_opset, axes=[0])
cast_batch = oopb.apply_unsqueeze(cast_name, name=operator.output_full_names[2] + '_btc', axes=[0])[0]
apply_cast(scope, cast_batch, operator.output_full_names[2], container, to=onnx_proto.TensorProto.INT32)

apply_identity(scope, box_batch, operator.output_full_names[0], container)
Expand Down
61 changes: 36 additions & 25 deletions keras2onnx/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ def convert_tf_expand_dims(scope, operator, container):
oopb = OnnxOperatorBuilder(container, scope)
node = operator.raw_operator
axis = _cal_tensor_value(node.inputs[1]).tolist()
rank = len(_cal_tensor_shape(node.inputs[0]))
rank = len(_cal_tensor_shape(node.inputs[0])) + 1
oopb.apply_op_with_output("apply_unsqueeze",
[operator.inputs[0].full_name],
operator.output_full_names,
Expand Down Expand Up @@ -1295,13 +1295,20 @@ def convert_tf_any_all(scope, operator, container):
to=oopb.float,
name=operator.full_name + '_cast')
keepdims = node.get_attr("keep_dims")
op_type = "ReduceMin" if node.type == "All" else "ReduceSum"
reduce_op = oopb.add_node(op_type, cast_op,
axes=axis,
keepdims=keepdims,
name=operator.full_name + '_reduce')
if node.type == 'All':
reduce_op = oopb.add_node('ReduceMin', cast_op,
axes=axis,
keepdims=keepdims,
name=operator.full_name + '_reduce')
else:
reduce_op = oopb.apply_reducesum(cast_op,
axes=axis,
keepdims=keepdims,
name=operator.full_name + '_reduce')
if not isinstance(reduce_op, list):
reduce_op = [reduce_op]
oopb.apply_op_with_output('apply_greater',
[reduce_op, np.array(0, dtype=np.float32)],
reduce_op + [np.array(0, dtype=np.float32)],
operator.output_full_names,
name=operator.full_name)

Expand All @@ -1316,10 +1323,9 @@ def convert_tf_pack(scope, operator, container):

inputs = []
for i in range(len(node.inputs)):
unsqueeze = oopb.add_node('Unsqueeze',
operator.inputs[i].full_name,
operator.full_name + '_unsqueeze' + str(i), axes=[axis])
inputs.append(unsqueeze)
unsqueeze = oopb.apply_unsqueeze(operator.inputs[i].full_name,
operator.full_name + '_unsqueeze' + str(i), axes=[axis])
inputs.extend(unsqueeze)

oopb.apply_op_with_output("apply_concat",
inputs,
Expand Down Expand Up @@ -1429,11 +1435,18 @@ def _convert_tf_reduce_op(scope, operator, container, onnx_op):
axes = [val + input_rank if val < 0 else val for val in axes]

keepdims = node.get_attr("keep_dims")
oopb.add_node_with_output(onnx_op,
operator.inputs[0].full_name,
operator.outputs[0].full_name,
name=operator.full_name + '_reduce_min',
axes=axes, keepdims=keepdims)
if onnx_op == 'ReduceSum':
oopb.apply_op_with_output("apply_"+onnx_op.lower(),
[operator.inputs[0].full_name],
operator.outputs[0].full_name,
name=operator.full_name + '_' + onnx_op.lower(),
axes=axes, keepdims=keepdims)
else:
oopb.add_node_with_output(onnx_op,
operator.inputs[0].full_name,
operator.outputs[0].full_name,
name=operator.full_name + '_' + onnx_op.lower(),
axes=axes, keepdims=keepdims)


@converter_func(TYPES.Max)
Expand Down Expand Up @@ -1768,7 +1781,7 @@ def convert_tf_squeeze(scope, operator, container):
if shape is None:
raise ValueError("Squeeze input shape cannot be None for node {}".format(node.name))

oopb.add_node_with_output('Squeeze',
oopb.apply_op_with_output('apply_squeeze',
operator.input_full_names[0],
operator.output_full_names,
operator.inputs[0].full_name + '_squeeze',
Expand Down Expand Up @@ -1801,9 +1814,8 @@ def convert_tf_topkv2(scope, operator, container):
cast_1 = oopb.add_node('Cast',
operator.inputs[1].full_name,
operator.inputs[1].full_name + '_1_cast', to=oopb.int64)
unsqueeze = oopb.add_node('Unsqueeze',
cast_1,
operator.inputs[1].full_name + '_unsqueeze', axes=[0])
unsqueeze = oopb.apply_unsqueeze(cast_1,
operator.inputs[1].full_name + '_unsqueeze', axes=[0])[0]
k_value = unsqueeze
else:
k_value = k.item(0)
Expand Down Expand Up @@ -2168,10 +2180,9 @@ def convert_tf_strided_slice(scope, operator, container):
oopb = OnnxOperatorBuilder(container, scope)

if len(new_axis_axes) > 0:
new_axis_unsqueeze = oopb.add_node('Unsqueeze',
operator.inputs[0].full_name,
operator.inputs[0].full_name + '_unsqueeze',
axes=new_axis_axes)
new_axis_unsqueeze = oopb.apply_unsqueeze(operator.inputs[0].full_name,
operator.inputs[0].full_name + '_unsqueeze',
axes=new_axis_axes)[0]
else:
new_axis_unsqueeze = operator.inputs[0].full_name

Expand Down Expand Up @@ -2236,7 +2247,7 @@ def convert_tf_strided_slice(scope, operator, container):
operator.inputs[0].full_name + '_cropping')

if needs_squeeze:
oopb.add_node_with_output('Squeeze',
oopb.apply_op_with_output('apply_squeeze',
cropped_tensor_name,
operator.output_full_names,
operator.inputs[0].full_name + '_squeeze',
Expand Down
Loading