From a25e713ddf3fae4e609588f080ead9ab64d524a9 Mon Sep 17 00:00:00 2001 From: David Date: Thu, 17 Dec 2020 13:42:20 -0800 Subject: [PATCH 1/2] Use apply_squeeze in opset 13 rather than add_node --- applications/mask_rcnn/mask_rcnn.py | 91 +++++++++-------------------- applications/yolov3/yolov3.py | 16 ++--- keras2onnx/_builtin.py | 59 +++++++++++-------- keras2onnx/ke2onnx/dot.py | 57 +++++++++--------- keras2onnx/ke2onnx/main.py | 10 ++-- keras2onnx/ke2onnx/merge.py | 8 +-- keras2onnx/ke2onnx/simplernn.py | 10 +++- keras2onnx/topology.py | 3 +- 8 files changed, 112 insertions(+), 142 deletions(-) diff --git a/applications/mask_rcnn/mask_rcnn.py b/applications/mask_rcnn/mask_rcnn.py index 004b609e..6d35cb76 100644 --- a/applications/mask_rcnn/mask_rcnn.py +++ b/applications/mask_rcnn/mask_rcnn.py @@ -69,22 +69,14 @@ def convert_BatchNorm(scope, operator, container): def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpose, score_identity, deltas_transpose, windows_transpose): - box_squeeze = scope.get_unique_variable_name('box_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', box_transpose, box_squeeze, op_version=operator.target_opset, - **attrs) + oopb = OnnxOperatorBuilder(container, scope) + box_squeeze = oopb.apply_squeeze(box_transpose, name=operator.full_name + '_box_squeeze', axes=[0])[0] # output shape: [spatial_dimension, 4] - deltas_squeeze = scope.get_unique_variable_name('deltas_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', deltas_transpose, deltas_squeeze, op_version=operator.target_opset, - **attrs) + deltas_squeeze = oopb.apply_squeeze(deltas_transpose, name=operator.full_name + '_deltas_squeeze', axes=[0])[0] # output shape: [spatial_dimension, num_classes, 4] - score_squeeze = scope.get_unique_variable_name('score_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', score_identity, score_squeeze, op_version=operator.target_opset, - **attrs) + score_squeeze = oopb.apply_squeeze(score_identity, name=operator.full_name + '_score_squeeze', axes=[0])[0] # output shape: [spatial_dimension, num_classes] class_ids = scope.get_unique_variable_name('class_ids') @@ -113,11 +105,9 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo op_domain='com.microsoft', op_version=1) - attrs = {'axes': [1]} - prob_range_unsqueeze = oopb.add_node('Unsqueeze', - [prob_range], - operator.inputs[1].full_name + '_prob_range_unsqueeze', - **attrs) + prob_range_unsqueeze = oopb.apply_unsqueeze([prob_range], + operator.inputs[1].full_name + '_prob_range_unsqueeze', + axes=[1])[0] # output shape: [spatial_dimension, 1] attrs = {'axis': 1} @@ -272,10 +262,8 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo [x1, width_exp], operator.inputs[0].full_name + '_x2') - windows_squeeze = scope.get_unique_variable_name('windows_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', windows_transpose, windows_squeeze, op_version=operator.target_opset, - **attrs) + windows_squeeze = oopb.apply_squeeze(windows_transpose, name=operator.full_name + '_windows_squeeze', + axes=[0])[0] wy1 = oopb.add_node('Slice', [windows_squeeze, ('_start', oopb.int64, np.array([0], dtype='int64')), @@ -336,10 +324,8 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo op_version=operator.target_opset, name=operator.outputs[0].full_name + '_concat_result', **attrs) - concat_unsqueeze = scope.get_unique_variable_name('concat_unsqueeze') - attrs = {'axes': [0]} - container.add_node('Unsqueeze', concat_result, concat_unsqueeze, op_version=operator.target_opset, - **attrs) + concat_unsqueeze = oopb.apply_unsqueeze(concat_result, name=operator.full_name + '_concat_unsqueeze', + axes=[0])[0] return concat_unsqueeze @@ -358,10 +344,8 @@ def norm_boxes_graph(scope, operator, container, oopb, image_meta): ('_axes', oopb.int64, np.array([0], dtype='int64')) ], operator.inputs[0].full_name + '_image_shape') - image_shape_squeeze = scope.get_unique_variable_name('image_shape_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', image_shape, image_shape_squeeze, op_version=operator.target_opset, - **attrs) + image_shape_squeeze = oopb.apply_squeeze(image_shape, name=operator.full_name + '_image_shape_squeeze', axes=[0])[0] + window = oopb.add_node('Slice', [image_meta, ('_start', oopb.int64, np.array([7], dtype='int64')), @@ -516,13 +500,8 @@ def convert_DetectionLayer(scope, operator, container): name=nms_node.name + '_box_idx') # output shape: [num_selected_indices, 1] - box_idx_squeeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_idx_squeeze') - attrs = {'axes': [1]} - container.add_node("Squeeze", - box_idx_output, - box_idx_squeeze, - op_version=operator.target_opset, - name=nms_node.name + '_box_idx_squeeze', **attrs) + box_idx_squeeze = oopb.apply_squeeze(box_idx_output, + name=nms_node.name + '_box_idx_squeeze', axes=[1])[0] # output shape: [num_selected_indices] starts_init_3 = scope.get_unique_variable_name('starts') @@ -548,23 +527,12 @@ def convert_DetectionLayer(scope, operator, container): name=nms_node.name + '_class_box_idx') # output shape: [num_selected_indices, 2] - box_squeeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_squeeze') - attrs = {'axes': [0]} - container.add_node("Squeeze", - delta_mul_output, - box_squeeze, - op_version=operator.target_opset, - name=nms_node.name + '_box_squeeze', **attrs) + box_squeeze = oopb.apply_squeeze(delta_mul_output, + name=nms_node.name + '_box_squeeze', axes=[0])[0] # output shape: [spatial_dimension, 4] - score_squeeze = scope.get_local_variable_or_declare_one(operator.output_full_names[0] + '_score_squeeze', - type=FloatTensorType(shape=[None])) - attrs = {'axes': [0]} - container.add_node("Squeeze", - score_identity, - score_squeeze.full_name, - op_version=operator.target_opset, - name=nms_node.name + '_score_squeeze', **attrs) + score_squeeze = oopb.apply_squeeze(score_identity, + name=nms_node.name + '_score_squeeze', axes=[0])[0] # output shape: [spatial_dimension, num_classes] box_gather = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_gather') @@ -578,19 +546,14 @@ def convert_DetectionLayer(scope, operator, container): score_gather = scope.get_unique_variable_name(operator.output_full_names[0] + '_score_gather') container.add_node("GatherND", - [score_squeeze.full_name, class_box_idx_output.full_name], + [score_squeeze, class_box_idx_output.full_name], score_gather, op_version=operator.target_opset, name=nms_node.name + '_score_gather') # output shape: [num_selected_indices] - score_gather_unsqueeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_score_gather_unsqueeze') - attrs = {'axes': [1]} - container.add_node("Unsqueeze", - score_gather, - score_gather_unsqueeze, - op_version=operator.target_opset, - name=nms_node.name + '_score_gather_unsqueeze', **attrs) + score_gather_unsqueeze = oopb.apply_unsqueeze(score_gather, + name=nms_node.name + '_score_gather_unsqueeze', axes=[1])[0] # output shape: [num_selected_indices, 1] @@ -661,12 +624,10 @@ def convert_DetectionLayer(scope, operator, container): nms_node.name + '_detection_final' ) - attrs = {'axes': [0]} - container.add_node("Unsqueeze", - detection_final, - operator.output_full_names[0], - op_version=operator.target_opset, - name=nms_node.name + '_concat_unsqueeze', **attrs) + oopb.apply_op_with_output('apply_unsqueeze', + detection_final, + operator.output_full_names[0], + name=nms_node.name + '_concat_unsqueeze', axes=[0]) # output shape: [1, num_top_K, 6] diff --git a/applications/yolov3/yolov3.py b/applications/yolov3/yolov3.py index 0b8774d7..648804e6 100644 --- a/applications/yolov3/yolov3.py +++ b/applications/yolov3/yolov3.py @@ -12,7 +12,7 @@ from keras.models import load_model from keras2onnx import convert_keras from keras2onnx import set_converter -from keras2onnx.common.onnx_ops import apply_transpose, apply_identity, apply_cast +from keras2onnx.common.onnx_ops import apply_transpose, apply_identity, apply_cast, OnnxOperatorBuilder from keras2onnx.proto import onnx_proto from onnxconverter_common.onnx_ex import get_maximum_opset_supported from onnxconverter_common.onnx_fx import Graph @@ -324,19 +324,15 @@ def detect_img(yolo, img_url, model_file_name): def convert_NMSLayer(scope, operator, container): # type: (keras2onnx.common.InterimContext, keras2onnx.common.Operator, keras2onnx.common.OnnxObjectContainer) -> None + oopb = OnnxOperatorBuilder(container, scope) box_transpose = scope.get_unique_variable_name(operator.inputs[0].full_name + '_tx') score_transpose = scope.get_unique_variable_name(operator.inputs[1].full_name + '_tx') apply_identity(scope, operator.inputs[0].full_name, box_transpose, container) apply_transpose(scope, operator.inputs[1].full_name, score_transpose, container, perm=[1, 0]) - box_batch = scope.get_unique_variable_name(operator.inputs[0].full_name + '_btc') - score_batch = scope.get_unique_variable_name(operator.inputs[1].full_name + '_btc') - - container.add_node("Unsqueeze", box_transpose, - box_batch, op_version=operator.target_opset, axes=[0]) - container.add_node("Unsqueeze", score_transpose, - score_batch, op_version=operator.target_opset, axes=[0]) + box_batch = oopb.apply_unsqueeze(box_transpose, name=operator.inputs[0].full_name + '_btc', axes=[0])[0] + score_batch = oopb.apply_unsqueeze(score_transpose, name=operator.inputs[1].full_name + '_btc', axes=[0])[0] layer = operator.raw_operator # type: YOLONMSLayer @@ -359,9 +355,7 @@ def convert_NMSLayer(scope, operator, container): op_version=operator.target_opset, name=nms_node.name) - cast_batch = scope.get_unique_variable_name(operator.output_full_names[2] + '_btc') - container.add_node("Unsqueeze", cast_name, - cast_batch, op_version=operator.target_opset, axes=[0]) + cast_batch = oopb.apply_unsqueeze(cast_name, name=operator.output_full_names[2] + '_btc', axes=[0])[0] apply_cast(scope, cast_batch, operator.output_full_names[2], container, to=onnx_proto.TensorProto.INT32) apply_identity(scope, box_batch, operator.output_full_names[0], container) diff --git a/keras2onnx/_builtin.py b/keras2onnx/_builtin.py index 6bfea789..150b0017 100644 --- a/keras2onnx/_builtin.py +++ b/keras2onnx/_builtin.py @@ -1295,13 +1295,20 @@ def convert_tf_any_all(scope, operator, container): to=oopb.float, name=operator.full_name + '_cast') keepdims = node.get_attr("keep_dims") - op_type = "ReduceMin" if node.type == "All" else "ReduceSum" - reduce_op = oopb.add_node(op_type, cast_op, - axes=axis, - keepdims=keepdims, - name=operator.full_name + '_reduce') + if node.type == 'All': + reduce_op = oopb.add_node('ReduceMin', cast_op, + axes=axis, + keepdims=keepdims, + name=operator.full_name + '_reduce') + else: + reduce_op = oopb.apply_reducesum(cast_op, + axes=axis, + keepdims=keepdims, + name=operator.full_name + '_reduce') + if not isinstance(reduce_op, list): + reduce_op = [reduce_op] oopb.apply_op_with_output('apply_greater', - [reduce_op, np.array(0, dtype=np.float32)], + reduce_op + [np.array(0, dtype=np.float32)], operator.output_full_names, name=operator.full_name) @@ -1316,10 +1323,9 @@ def convert_tf_pack(scope, operator, container): inputs = [] for i in range(len(node.inputs)): - unsqueeze = oopb.add_node('Unsqueeze', - operator.inputs[i].full_name, - operator.full_name + '_unsqueeze' + str(i), axes=[axis]) - inputs.append(unsqueeze) + unsqueeze = oopb.apply_unsqueeze(operator.inputs[i].full_name, + operator.full_name + '_unsqueeze' + str(i), axes=[axis]) + inputs.extend(unsqueeze) oopb.apply_op_with_output("apply_concat", inputs, @@ -1429,11 +1435,18 @@ def _convert_tf_reduce_op(scope, operator, container, onnx_op): axes = [val + input_rank if val < 0 else val for val in axes] keepdims = node.get_attr("keep_dims") - oopb.add_node_with_output(onnx_op, - operator.inputs[0].full_name, - operator.outputs[0].full_name, - name=operator.full_name + '_reduce_min', - axes=axes, keepdims=keepdims) + if onnx_op == 'ReduceSum': + oopb.apply_op_with_output("apply_"+onnx_op.lower(), + [operator.inputs[0].full_name], + operator.outputs[0].full_name, + name=operator.full_name + '_' + onnx_op.lower(), + axes=axes, keepdims=keepdims) + else: + oopb.add_node_with_output(onnx_op, + operator.inputs[0].full_name, + operator.outputs[0].full_name, + name=operator.full_name + '_' + onnx_op.lower(), + axes=axes, keepdims=keepdims) @converter_func(TYPES.Max) @@ -1768,7 +1781,7 @@ def convert_tf_squeeze(scope, operator, container): if shape is None: raise ValueError("Squeeze input shape cannot be None for node {}".format(node.name)) - oopb.add_node_with_output('Squeeze', + oopb.apply_op_with_output('apply_squeeze', operator.input_full_names[0], operator.output_full_names, operator.inputs[0].full_name + '_squeeze', @@ -1801,9 +1814,8 @@ def convert_tf_topkv2(scope, operator, container): cast_1 = oopb.add_node('Cast', operator.inputs[1].full_name, operator.inputs[1].full_name + '_1_cast', to=oopb.int64) - unsqueeze = oopb.add_node('Unsqueeze', - cast_1, - operator.inputs[1].full_name + '_unsqueeze', axes=[0]) + unsqueeze = oopb.apply_unsqueeze(cast_1, + operator.inputs[1].full_name + '_unsqueeze', axes=[0])[0] k_value = unsqueeze else: k_value = k.item(0) @@ -2168,10 +2180,9 @@ def convert_tf_strided_slice(scope, operator, container): oopb = OnnxOperatorBuilder(container, scope) if len(new_axis_axes) > 0: - new_axis_unsqueeze = oopb.add_node('Unsqueeze', - operator.inputs[0].full_name, - operator.inputs[0].full_name + '_unsqueeze', - axes=new_axis_axes) + new_axis_unsqueeze = oopb.apply_unsqueeze(operator.inputs[0].full_name, + operator.inputs[0].full_name + '_unsqueeze', + axes=new_axis_axes)[0] else: new_axis_unsqueeze = operator.inputs[0].full_name @@ -2236,7 +2247,7 @@ def convert_tf_strided_slice(scope, operator, container): operator.inputs[0].full_name + '_cropping') if needs_squeeze: - oopb.add_node_with_output('Squeeze', + oopb.apply_op_with_output('apply_squeeze', cropped_tensor_name, operator.output_full_names, operator.inputs[0].full_name + '_squeeze', diff --git a/keras2onnx/ke2onnx/dot.py b/keras2onnx/ke2onnx/dot.py index f3e260fc..fa89976f 100644 --- a/keras2onnx/ke2onnx/dot.py +++ b/keras2onnx/ke2onnx/dot.py @@ -116,10 +116,9 @@ def convert_keras_dot_224(scope, operator, container): result_mul = oopb.add_node('Mul', [x_reshape, y_reshape], operator.inputs[0].full_name + '_result_mul') - out = oopb.add_node('ReduceSum', - [result_mul], - operator.inputs[0].full_name + '_out', - axes=[axes[0]]) + out = oopb.apply_reducesum([result_mul], + operator.inputs[0].full_name + '_out', + axes=[axes[0]]) else: x_transpose = oopb.add_node('Transpose', [x_reshape], @@ -128,10 +127,10 @@ def convert_keras_dot_224(scope, operator, container): result_mul = oopb.add_node('Mul', [x_transpose, y_reshape], operator.inputs[0].full_name + '_result_mul') - out = oopb.add_node('ReduceSum', - [result_mul], - operator.inputs[0].full_name + '_out', - axes=[axes[1]]) + out = oopb.apply_reducesum([result_mul], + operator.inputs[0].full_name + '_out', + axes=[axes[1]]) + out = out[0] else: if axes is not None: adj_x = None if axes[0] == max_ndim - 1 else True @@ -168,19 +167,17 @@ def convert_keras_dot_224(scope, operator, container): idx = x_ndim + y_ndim - 3 else: idx = x_ndim - 1 - out_squeeze = oopb.add_node('Squeeze', - [out], - operator.inputs[0].full_name + '_out_squeeze', - axes=list(range(idx, idx + diff))) + out_squeeze = oopb.apply_squeeze([out], + operator.inputs[0].full_name + '_out_squeeze', + axes=list(range(idx, idx + diff))) matrix_len = matrix_len - diff else: out_squeeze = out if matrix_len == 1: - out_expand = oopb.add_node('Unsqueeze', - [out_squeeze], - operator.inputs[0].full_name + '_out_expand', - axes=[1]) + out_squeeze = oopb.apply_unsqueeze([out_squeeze], + operator.inputs[0].full_name + '_out_expand', + axes=[1]) else: out_expand = out_squeeze container.add_node('Identity', out_expand, operator.output_full_names, @@ -216,20 +213,18 @@ def convert_keras_dot_post_224(scope, operator, container): raise RuntimeError('Dimension incompatibility: %s != %s' % (x_shape[axes[0]], y_shape[axes[1]])) if x_ndim == 2: - x_expand = oopb.add_node('Unsqueeze', - [normalized_input_names[0]], - operator.inputs[0].full_name + '_expand', - axes=[1]) + x_expand = oopb.apply_unsqueeze([normalized_input_names[0]], + operator.inputs[0].full_name + '_expand', + axes=[1])[0] a0 += 1 x_ndim += 1 else: x_expand = normalized_input_names[0] if y_ndim == 2: - y_expand = oopb.add_node('Unsqueeze', - [normalized_input_names[1]], - operator.inputs[1].full_name + '_expand', - axes=[2]) + y_expand = oopb.apply_unsqueeze([normalized_input_names[1]], + operator.inputs[1].full_name + '_expand', + axes=[2])[0] y_ndim += 1 else: y_expand = normalized_input_names[1] @@ -421,11 +416,17 @@ def convert_keras_dot_post_224(scope, operator, container): # if the inputs were originally rank 2, we remove the added 1 dim. if orig_x_ndim == 2: - container.add_node('Squeeze', output_reshape, operator.output_full_names, - name=scope.get_unique_operator_name('Squeeze'), axes=[1]) + oopb.apply_op_with_output("apply_squeeze", + output_reshape, + operator.output_full_names, + name=operator.full_name + '_squeeze', + axes=[1]) elif orig_y_ndim == 2: - container.add_node('Squeeze', output_reshape, operator.output_full_names, - name=scope.get_unique_operator_name('Squeeze'), axes=[y_ndim - 1]) + oopb.apply_op_with_output("apply_squeeze", + output_reshape, + operator.output_full_names, + name=operator.full_name + '_squeeze', + axes=[y_ndim - 1]) else: container.add_node('Identity', output_reshape, operator.output_full_names, name=scope.get_unique_operator_name('Identity')) diff --git a/keras2onnx/ke2onnx/main.py b/keras2onnx/ke2onnx/main.py index ab761810..6f6ac2f1 100644 --- a/keras2onnx/ke2onnx/main.py +++ b/keras2onnx/ke2onnx/main.py @@ -104,15 +104,13 @@ def convert_keras_masking(scope, operator, container): not_o = _apply_not_equal(oopb, container.target_opset, operator) cast_o = oopb.apply_cast(not_o, to=oopb.float, name=operator.full_name + '_cast') if operator.output_masks: - reduce_node = oopb.add_node("ReduceSum", - cast_o[0], keepdims=False, axes=[-1], name=operator.full_name + '_reduced') - oopb.add_node_with_output("Greater", [reduce_node, np.array(0, dtype=np.float32)], + reduce_node = oopb.apply_reducesum(cast_o[0], keepdims=False, axes=[-1], name=operator.full_name + '_reduced') + oopb.add_node_with_output("Greater", reduce_node + [np.array(0, dtype=np.float32)], [operator.output_masks[0].full_name], name=operator.full_name + '_greater') - reduce_node2 = oopb.add_node("ReduceSum", - cast_o, keepdims=True, axes=[-1], name=operator.full_name + 'reduced2') + reduce_node2 = oopb.apply_reducesum(cast_o, keepdims=True, axes=[-1], name=operator.full_name + 'reduced2') greater_o = oopb.add_node("Greater", - [reduce_node2, np.array(0, dtype=np.float32)], name=operator.full_name + '_greater2') + reduce_node2 + [np.array(0, dtype=np.float32)], name=operator.full_name + '_greater2') cast2_o = oopb.apply_cast(greater_o, to=oopb.float, name=operator.full_name + '_cast2') oopb.add_node_with_output('Mul', [cast2_o[0], operator.inputs[0].full_name], [operator.outputs[0].full_name], diff --git a/keras2onnx/ke2onnx/merge.py b/keras2onnx/ke2onnx/merge.py index 38ddfc66..3da454b8 100644 --- a/keras2onnx/ke2onnx/merge.py +++ b/keras2onnx/ke2onnx/merge.py @@ -50,16 +50,16 @@ def convert_keras_merge_layer(scope, operator, container): oopb = OnnxOperatorBuilder(container, scope) expanded = [] for idx_, i_ in enumerate(operator.input_masks): - expanded.append(oopb.add_node('Unsqueeze', i_.full_name, i_.full_name + '_i' + str(idx_), axes=[0])) + expanded.extend(oopb.apply_unsqueeze(i_.full_name, i_.full_name + '_i' + str(idx_), axes=[0])) if len(expanded) > 1: concat = oopb.apply_concat(expanded, name=operator.full_name + '_concat') else: concat = expanded[0] cast = oopb.add_node('Cast', concat, name=operator.full_name + '_cast', to=1) - reduced = oopb.add_node('ReduceSum', cast, name=operator.full_name + '_reduced', op_version=1, axes=[0], - keepdims=0) + reduced = oopb.apply_reducesum(cast, name=operator.full_name + '_reduced', axes=[0], + keepdims=0) oopb.apply_op_with_output('apply_greater', - [reduced, np.array([0], dtype=np.float32)], + reduced + [np.array([0], dtype=np.float32)], [operator.output_masks[0].full_name], name=operator.raw_operator.name) diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py index 20f0ab84..01e67413 100644 --- a/keras2onnx/ke2onnx/simplernn.py +++ b/keras2onnx/ke2onnx/simplernn.py @@ -163,9 +163,13 @@ def build_sequence_lengths(scope, operator, container): mask_cast = scope.get_unique_operator_name(operator.full_name + '_mask_cast') sequence_lengths = scope.get_unique_operator_name(operator.full_name + '_seq_lens') - apply_cast(scope, input_mask_name, mask_cast, container, to=TensorProto.INT32) - container.add_node('ReduceSum', mask_cast, sequence_lengths, keepdims=False, axes=[-1]) - return sequence_lengths + oopb = OnnxOperatorBuilder(container, scope) + mask_cast = oopb.apply_cast(input_mask_name, + to=oopb.int32, + name=operator.full_name + 'cast') + sequence_lengths = oopb.apply_reducesum(mask_cast, name=operator.full_name + '_reduced', axes=[-1], + keepdims=False) + return sequence_lengths[0] def build_initial_states(scope, operator, container, bidirectional=False): diff --git a/keras2onnx/topology.py b/keras2onnx/topology.py index 234072f8..a7af8b7e 100644 --- a/keras2onnx/topology.py +++ b/keras2onnx/topology.py @@ -323,7 +323,8 @@ def convert_topology(topology, model_name, doc_string, target_opset, channel_fir if target_opset < 9: nodes = onnxconverter_common.optimizer.optimize_onnx(nodes, nchw_inputs=nchw_inputs, inputs=container.inputs + extra_inputs, - outputs=container.outputs) + outputs=container.outputs, + target_opset=container.target_opset) node_number = len(nodes) else: graph = onnxconverter_common.optimizer.optimize_onnx_graph(nodes, nchw_inputs=nchw_inputs, From 34967e2bb69fafe8f4ce8af07e154429528349c8 Mon Sep 17 00:00:00 2001 From: David Date: Thu, 17 Dec 2020 13:42:20 -0800 Subject: [PATCH 2/2] Use apply_squeeze in opset 13 rather than add_node --- applications/mask_rcnn/mask_rcnn.py | 91 +++++++++-------------------- applications/yolov3/yolov3.py | 16 ++--- keras2onnx/_builtin.py | 59 +++++++++++-------- keras2onnx/ke2onnx/dot.py | 57 +++++++++--------- keras2onnx/ke2onnx/main.py | 10 ++-- keras2onnx/ke2onnx/merge.py | 8 +-- keras2onnx/ke2onnx/simplernn.py | 11 ++-- keras2onnx/topology.py | 3 +- 8 files changed, 112 insertions(+), 143 deletions(-) diff --git a/applications/mask_rcnn/mask_rcnn.py b/applications/mask_rcnn/mask_rcnn.py index 004b609e..6d35cb76 100644 --- a/applications/mask_rcnn/mask_rcnn.py +++ b/applications/mask_rcnn/mask_rcnn.py @@ -69,22 +69,14 @@ def convert_BatchNorm(scope, operator, container): def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpose, score_identity, deltas_transpose, windows_transpose): - box_squeeze = scope.get_unique_variable_name('box_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', box_transpose, box_squeeze, op_version=operator.target_opset, - **attrs) + oopb = OnnxOperatorBuilder(container, scope) + box_squeeze = oopb.apply_squeeze(box_transpose, name=operator.full_name + '_box_squeeze', axes=[0])[0] # output shape: [spatial_dimension, 4] - deltas_squeeze = scope.get_unique_variable_name('deltas_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', deltas_transpose, deltas_squeeze, op_version=operator.target_opset, - **attrs) + deltas_squeeze = oopb.apply_squeeze(deltas_transpose, name=operator.full_name + '_deltas_squeeze', axes=[0])[0] # output shape: [spatial_dimension, num_classes, 4] - score_squeeze = scope.get_unique_variable_name('score_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', score_identity, score_squeeze, op_version=operator.target_opset, - **attrs) + score_squeeze = oopb.apply_squeeze(score_identity, name=operator.full_name + '_score_squeeze', axes=[0])[0] # output shape: [spatial_dimension, num_classes] class_ids = scope.get_unique_variable_name('class_ids') @@ -113,11 +105,9 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo op_domain='com.microsoft', op_version=1) - attrs = {'axes': [1]} - prob_range_unsqueeze = oopb.add_node('Unsqueeze', - [prob_range], - operator.inputs[1].full_name + '_prob_range_unsqueeze', - **attrs) + prob_range_unsqueeze = oopb.apply_unsqueeze([prob_range], + operator.inputs[1].full_name + '_prob_range_unsqueeze', + axes=[1])[0] # output shape: [spatial_dimension, 1] attrs = {'axis': 1} @@ -272,10 +262,8 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo [x1, width_exp], operator.inputs[0].full_name + '_x2') - windows_squeeze = scope.get_unique_variable_name('windows_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', windows_transpose, windows_squeeze, op_version=operator.target_opset, - **attrs) + windows_squeeze = oopb.apply_squeeze(windows_transpose, name=operator.full_name + '_windows_squeeze', + axes=[0])[0] wy1 = oopb.add_node('Slice', [windows_squeeze, ('_start', oopb.int64, np.array([0], dtype='int64')), @@ -336,10 +324,8 @@ def convert_apply_box_deltas_graph(scope, operator, container, oopb, box_transpo op_version=operator.target_opset, name=operator.outputs[0].full_name + '_concat_result', **attrs) - concat_unsqueeze = scope.get_unique_variable_name('concat_unsqueeze') - attrs = {'axes': [0]} - container.add_node('Unsqueeze', concat_result, concat_unsqueeze, op_version=operator.target_opset, - **attrs) + concat_unsqueeze = oopb.apply_unsqueeze(concat_result, name=operator.full_name + '_concat_unsqueeze', + axes=[0])[0] return concat_unsqueeze @@ -358,10 +344,8 @@ def norm_boxes_graph(scope, operator, container, oopb, image_meta): ('_axes', oopb.int64, np.array([0], dtype='int64')) ], operator.inputs[0].full_name + '_image_shape') - image_shape_squeeze = scope.get_unique_variable_name('image_shape_squeeze') - attrs = {'axes': [0]} - container.add_node('Squeeze', image_shape, image_shape_squeeze, op_version=operator.target_opset, - **attrs) + image_shape_squeeze = oopb.apply_squeeze(image_shape, name=operator.full_name + '_image_shape_squeeze', axes=[0])[0] + window = oopb.add_node('Slice', [image_meta, ('_start', oopb.int64, np.array([7], dtype='int64')), @@ -516,13 +500,8 @@ def convert_DetectionLayer(scope, operator, container): name=nms_node.name + '_box_idx') # output shape: [num_selected_indices, 1] - box_idx_squeeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_idx_squeeze') - attrs = {'axes': [1]} - container.add_node("Squeeze", - box_idx_output, - box_idx_squeeze, - op_version=operator.target_opset, - name=nms_node.name + '_box_idx_squeeze', **attrs) + box_idx_squeeze = oopb.apply_squeeze(box_idx_output, + name=nms_node.name + '_box_idx_squeeze', axes=[1])[0] # output shape: [num_selected_indices] starts_init_3 = scope.get_unique_variable_name('starts') @@ -548,23 +527,12 @@ def convert_DetectionLayer(scope, operator, container): name=nms_node.name + '_class_box_idx') # output shape: [num_selected_indices, 2] - box_squeeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_squeeze') - attrs = {'axes': [0]} - container.add_node("Squeeze", - delta_mul_output, - box_squeeze, - op_version=operator.target_opset, - name=nms_node.name + '_box_squeeze', **attrs) + box_squeeze = oopb.apply_squeeze(delta_mul_output, + name=nms_node.name + '_box_squeeze', axes=[0])[0] # output shape: [spatial_dimension, 4] - score_squeeze = scope.get_local_variable_or_declare_one(operator.output_full_names[0] + '_score_squeeze', - type=FloatTensorType(shape=[None])) - attrs = {'axes': [0]} - container.add_node("Squeeze", - score_identity, - score_squeeze.full_name, - op_version=operator.target_opset, - name=nms_node.name + '_score_squeeze', **attrs) + score_squeeze = oopb.apply_squeeze(score_identity, + name=nms_node.name + '_score_squeeze', axes=[0])[0] # output shape: [spatial_dimension, num_classes] box_gather = scope.get_unique_variable_name(operator.output_full_names[0] + '_box_gather') @@ -578,19 +546,14 @@ def convert_DetectionLayer(scope, operator, container): score_gather = scope.get_unique_variable_name(operator.output_full_names[0] + '_score_gather') container.add_node("GatherND", - [score_squeeze.full_name, class_box_idx_output.full_name], + [score_squeeze, class_box_idx_output.full_name], score_gather, op_version=operator.target_opset, name=nms_node.name + '_score_gather') # output shape: [num_selected_indices] - score_gather_unsqueeze = scope.get_unique_variable_name(operator.output_full_names[0] + '_score_gather_unsqueeze') - attrs = {'axes': [1]} - container.add_node("Unsqueeze", - score_gather, - score_gather_unsqueeze, - op_version=operator.target_opset, - name=nms_node.name + '_score_gather_unsqueeze', **attrs) + score_gather_unsqueeze = oopb.apply_unsqueeze(score_gather, + name=nms_node.name + '_score_gather_unsqueeze', axes=[1])[0] # output shape: [num_selected_indices, 1] @@ -661,12 +624,10 @@ def convert_DetectionLayer(scope, operator, container): nms_node.name + '_detection_final' ) - attrs = {'axes': [0]} - container.add_node("Unsqueeze", - detection_final, - operator.output_full_names[0], - op_version=operator.target_opset, - name=nms_node.name + '_concat_unsqueeze', **attrs) + oopb.apply_op_with_output('apply_unsqueeze', + detection_final, + operator.output_full_names[0], + name=nms_node.name + '_concat_unsqueeze', axes=[0]) # output shape: [1, num_top_K, 6] diff --git a/applications/yolov3/yolov3.py b/applications/yolov3/yolov3.py index 0b8774d7..648804e6 100644 --- a/applications/yolov3/yolov3.py +++ b/applications/yolov3/yolov3.py @@ -12,7 +12,7 @@ from keras.models import load_model from keras2onnx import convert_keras from keras2onnx import set_converter -from keras2onnx.common.onnx_ops import apply_transpose, apply_identity, apply_cast +from keras2onnx.common.onnx_ops import apply_transpose, apply_identity, apply_cast, OnnxOperatorBuilder from keras2onnx.proto import onnx_proto from onnxconverter_common.onnx_ex import get_maximum_opset_supported from onnxconverter_common.onnx_fx import Graph @@ -324,19 +324,15 @@ def detect_img(yolo, img_url, model_file_name): def convert_NMSLayer(scope, operator, container): # type: (keras2onnx.common.InterimContext, keras2onnx.common.Operator, keras2onnx.common.OnnxObjectContainer) -> None + oopb = OnnxOperatorBuilder(container, scope) box_transpose = scope.get_unique_variable_name(operator.inputs[0].full_name + '_tx') score_transpose = scope.get_unique_variable_name(operator.inputs[1].full_name + '_tx') apply_identity(scope, operator.inputs[0].full_name, box_transpose, container) apply_transpose(scope, operator.inputs[1].full_name, score_transpose, container, perm=[1, 0]) - box_batch = scope.get_unique_variable_name(operator.inputs[0].full_name + '_btc') - score_batch = scope.get_unique_variable_name(operator.inputs[1].full_name + '_btc') - - container.add_node("Unsqueeze", box_transpose, - box_batch, op_version=operator.target_opset, axes=[0]) - container.add_node("Unsqueeze", score_transpose, - score_batch, op_version=operator.target_opset, axes=[0]) + box_batch = oopb.apply_unsqueeze(box_transpose, name=operator.inputs[0].full_name + '_btc', axes=[0])[0] + score_batch = oopb.apply_unsqueeze(score_transpose, name=operator.inputs[1].full_name + '_btc', axes=[0])[0] layer = operator.raw_operator # type: YOLONMSLayer @@ -359,9 +355,7 @@ def convert_NMSLayer(scope, operator, container): op_version=operator.target_opset, name=nms_node.name) - cast_batch = scope.get_unique_variable_name(operator.output_full_names[2] + '_btc') - container.add_node("Unsqueeze", cast_name, - cast_batch, op_version=operator.target_opset, axes=[0]) + cast_batch = oopb.apply_unsqueeze(cast_name, name=operator.output_full_names[2] + '_btc', axes=[0])[0] apply_cast(scope, cast_batch, operator.output_full_names[2], container, to=onnx_proto.TensorProto.INT32) apply_identity(scope, box_batch, operator.output_full_names[0], container) diff --git a/keras2onnx/_builtin.py b/keras2onnx/_builtin.py index 6bfea789..150b0017 100644 --- a/keras2onnx/_builtin.py +++ b/keras2onnx/_builtin.py @@ -1295,13 +1295,20 @@ def convert_tf_any_all(scope, operator, container): to=oopb.float, name=operator.full_name + '_cast') keepdims = node.get_attr("keep_dims") - op_type = "ReduceMin" if node.type == "All" else "ReduceSum" - reduce_op = oopb.add_node(op_type, cast_op, - axes=axis, - keepdims=keepdims, - name=operator.full_name + '_reduce') + if node.type == 'All': + reduce_op = oopb.add_node('ReduceMin', cast_op, + axes=axis, + keepdims=keepdims, + name=operator.full_name + '_reduce') + else: + reduce_op = oopb.apply_reducesum(cast_op, + axes=axis, + keepdims=keepdims, + name=operator.full_name + '_reduce') + if not isinstance(reduce_op, list): + reduce_op = [reduce_op] oopb.apply_op_with_output('apply_greater', - [reduce_op, np.array(0, dtype=np.float32)], + reduce_op + [np.array(0, dtype=np.float32)], operator.output_full_names, name=operator.full_name) @@ -1316,10 +1323,9 @@ def convert_tf_pack(scope, operator, container): inputs = [] for i in range(len(node.inputs)): - unsqueeze = oopb.add_node('Unsqueeze', - operator.inputs[i].full_name, - operator.full_name + '_unsqueeze' + str(i), axes=[axis]) - inputs.append(unsqueeze) + unsqueeze = oopb.apply_unsqueeze(operator.inputs[i].full_name, + operator.full_name + '_unsqueeze' + str(i), axes=[axis]) + inputs.extend(unsqueeze) oopb.apply_op_with_output("apply_concat", inputs, @@ -1429,11 +1435,18 @@ def _convert_tf_reduce_op(scope, operator, container, onnx_op): axes = [val + input_rank if val < 0 else val for val in axes] keepdims = node.get_attr("keep_dims") - oopb.add_node_with_output(onnx_op, - operator.inputs[0].full_name, - operator.outputs[0].full_name, - name=operator.full_name + '_reduce_min', - axes=axes, keepdims=keepdims) + if onnx_op == 'ReduceSum': + oopb.apply_op_with_output("apply_"+onnx_op.lower(), + [operator.inputs[0].full_name], + operator.outputs[0].full_name, + name=operator.full_name + '_' + onnx_op.lower(), + axes=axes, keepdims=keepdims) + else: + oopb.add_node_with_output(onnx_op, + operator.inputs[0].full_name, + operator.outputs[0].full_name, + name=operator.full_name + '_' + onnx_op.lower(), + axes=axes, keepdims=keepdims) @converter_func(TYPES.Max) @@ -1768,7 +1781,7 @@ def convert_tf_squeeze(scope, operator, container): if shape is None: raise ValueError("Squeeze input shape cannot be None for node {}".format(node.name)) - oopb.add_node_with_output('Squeeze', + oopb.apply_op_with_output('apply_squeeze', operator.input_full_names[0], operator.output_full_names, operator.inputs[0].full_name + '_squeeze', @@ -1801,9 +1814,8 @@ def convert_tf_topkv2(scope, operator, container): cast_1 = oopb.add_node('Cast', operator.inputs[1].full_name, operator.inputs[1].full_name + '_1_cast', to=oopb.int64) - unsqueeze = oopb.add_node('Unsqueeze', - cast_1, - operator.inputs[1].full_name + '_unsqueeze', axes=[0]) + unsqueeze = oopb.apply_unsqueeze(cast_1, + operator.inputs[1].full_name + '_unsqueeze', axes=[0])[0] k_value = unsqueeze else: k_value = k.item(0) @@ -2168,10 +2180,9 @@ def convert_tf_strided_slice(scope, operator, container): oopb = OnnxOperatorBuilder(container, scope) if len(new_axis_axes) > 0: - new_axis_unsqueeze = oopb.add_node('Unsqueeze', - operator.inputs[0].full_name, - operator.inputs[0].full_name + '_unsqueeze', - axes=new_axis_axes) + new_axis_unsqueeze = oopb.apply_unsqueeze(operator.inputs[0].full_name, + operator.inputs[0].full_name + '_unsqueeze', + axes=new_axis_axes)[0] else: new_axis_unsqueeze = operator.inputs[0].full_name @@ -2236,7 +2247,7 @@ def convert_tf_strided_slice(scope, operator, container): operator.inputs[0].full_name + '_cropping') if needs_squeeze: - oopb.add_node_with_output('Squeeze', + oopb.apply_op_with_output('apply_squeeze', cropped_tensor_name, operator.output_full_names, operator.inputs[0].full_name + '_squeeze', diff --git a/keras2onnx/ke2onnx/dot.py b/keras2onnx/ke2onnx/dot.py index f3e260fc..fa89976f 100644 --- a/keras2onnx/ke2onnx/dot.py +++ b/keras2onnx/ke2onnx/dot.py @@ -116,10 +116,9 @@ def convert_keras_dot_224(scope, operator, container): result_mul = oopb.add_node('Mul', [x_reshape, y_reshape], operator.inputs[0].full_name + '_result_mul') - out = oopb.add_node('ReduceSum', - [result_mul], - operator.inputs[0].full_name + '_out', - axes=[axes[0]]) + out = oopb.apply_reducesum([result_mul], + operator.inputs[0].full_name + '_out', + axes=[axes[0]]) else: x_transpose = oopb.add_node('Transpose', [x_reshape], @@ -128,10 +127,10 @@ def convert_keras_dot_224(scope, operator, container): result_mul = oopb.add_node('Mul', [x_transpose, y_reshape], operator.inputs[0].full_name + '_result_mul') - out = oopb.add_node('ReduceSum', - [result_mul], - operator.inputs[0].full_name + '_out', - axes=[axes[1]]) + out = oopb.apply_reducesum([result_mul], + operator.inputs[0].full_name + '_out', + axes=[axes[1]]) + out = out[0] else: if axes is not None: adj_x = None if axes[0] == max_ndim - 1 else True @@ -168,19 +167,17 @@ def convert_keras_dot_224(scope, operator, container): idx = x_ndim + y_ndim - 3 else: idx = x_ndim - 1 - out_squeeze = oopb.add_node('Squeeze', - [out], - operator.inputs[0].full_name + '_out_squeeze', - axes=list(range(idx, idx + diff))) + out_squeeze = oopb.apply_squeeze([out], + operator.inputs[0].full_name + '_out_squeeze', + axes=list(range(idx, idx + diff))) matrix_len = matrix_len - diff else: out_squeeze = out if matrix_len == 1: - out_expand = oopb.add_node('Unsqueeze', - [out_squeeze], - operator.inputs[0].full_name + '_out_expand', - axes=[1]) + out_squeeze = oopb.apply_unsqueeze([out_squeeze], + operator.inputs[0].full_name + '_out_expand', + axes=[1]) else: out_expand = out_squeeze container.add_node('Identity', out_expand, operator.output_full_names, @@ -216,20 +213,18 @@ def convert_keras_dot_post_224(scope, operator, container): raise RuntimeError('Dimension incompatibility: %s != %s' % (x_shape[axes[0]], y_shape[axes[1]])) if x_ndim == 2: - x_expand = oopb.add_node('Unsqueeze', - [normalized_input_names[0]], - operator.inputs[0].full_name + '_expand', - axes=[1]) + x_expand = oopb.apply_unsqueeze([normalized_input_names[0]], + operator.inputs[0].full_name + '_expand', + axes=[1])[0] a0 += 1 x_ndim += 1 else: x_expand = normalized_input_names[0] if y_ndim == 2: - y_expand = oopb.add_node('Unsqueeze', - [normalized_input_names[1]], - operator.inputs[1].full_name + '_expand', - axes=[2]) + y_expand = oopb.apply_unsqueeze([normalized_input_names[1]], + operator.inputs[1].full_name + '_expand', + axes=[2])[0] y_ndim += 1 else: y_expand = normalized_input_names[1] @@ -421,11 +416,17 @@ def convert_keras_dot_post_224(scope, operator, container): # if the inputs were originally rank 2, we remove the added 1 dim. if orig_x_ndim == 2: - container.add_node('Squeeze', output_reshape, operator.output_full_names, - name=scope.get_unique_operator_name('Squeeze'), axes=[1]) + oopb.apply_op_with_output("apply_squeeze", + output_reshape, + operator.output_full_names, + name=operator.full_name + '_squeeze', + axes=[1]) elif orig_y_ndim == 2: - container.add_node('Squeeze', output_reshape, operator.output_full_names, - name=scope.get_unique_operator_name('Squeeze'), axes=[y_ndim - 1]) + oopb.apply_op_with_output("apply_squeeze", + output_reshape, + operator.output_full_names, + name=operator.full_name + '_squeeze', + axes=[y_ndim - 1]) else: container.add_node('Identity', output_reshape, operator.output_full_names, name=scope.get_unique_operator_name('Identity')) diff --git a/keras2onnx/ke2onnx/main.py b/keras2onnx/ke2onnx/main.py index ab761810..6f6ac2f1 100644 --- a/keras2onnx/ke2onnx/main.py +++ b/keras2onnx/ke2onnx/main.py @@ -104,15 +104,13 @@ def convert_keras_masking(scope, operator, container): not_o = _apply_not_equal(oopb, container.target_opset, operator) cast_o = oopb.apply_cast(not_o, to=oopb.float, name=operator.full_name + '_cast') if operator.output_masks: - reduce_node = oopb.add_node("ReduceSum", - cast_o[0], keepdims=False, axes=[-1], name=operator.full_name + '_reduced') - oopb.add_node_with_output("Greater", [reduce_node, np.array(0, dtype=np.float32)], + reduce_node = oopb.apply_reducesum(cast_o[0], keepdims=False, axes=[-1], name=operator.full_name + '_reduced') + oopb.add_node_with_output("Greater", reduce_node + [np.array(0, dtype=np.float32)], [operator.output_masks[0].full_name], name=operator.full_name + '_greater') - reduce_node2 = oopb.add_node("ReduceSum", - cast_o, keepdims=True, axes=[-1], name=operator.full_name + 'reduced2') + reduce_node2 = oopb.apply_reducesum(cast_o, keepdims=True, axes=[-1], name=operator.full_name + 'reduced2') greater_o = oopb.add_node("Greater", - [reduce_node2, np.array(0, dtype=np.float32)], name=operator.full_name + '_greater2') + reduce_node2 + [np.array(0, dtype=np.float32)], name=operator.full_name + '_greater2') cast2_o = oopb.apply_cast(greater_o, to=oopb.float, name=operator.full_name + '_cast2') oopb.add_node_with_output('Mul', [cast2_o[0], operator.inputs[0].full_name], [operator.outputs[0].full_name], diff --git a/keras2onnx/ke2onnx/merge.py b/keras2onnx/ke2onnx/merge.py index 38ddfc66..3da454b8 100644 --- a/keras2onnx/ke2onnx/merge.py +++ b/keras2onnx/ke2onnx/merge.py @@ -50,16 +50,16 @@ def convert_keras_merge_layer(scope, operator, container): oopb = OnnxOperatorBuilder(container, scope) expanded = [] for idx_, i_ in enumerate(operator.input_masks): - expanded.append(oopb.add_node('Unsqueeze', i_.full_name, i_.full_name + '_i' + str(idx_), axes=[0])) + expanded.extend(oopb.apply_unsqueeze(i_.full_name, i_.full_name + '_i' + str(idx_), axes=[0])) if len(expanded) > 1: concat = oopb.apply_concat(expanded, name=operator.full_name + '_concat') else: concat = expanded[0] cast = oopb.add_node('Cast', concat, name=operator.full_name + '_cast', to=1) - reduced = oopb.add_node('ReduceSum', cast, name=operator.full_name + '_reduced', op_version=1, axes=[0], - keepdims=0) + reduced = oopb.apply_reducesum(cast, name=operator.full_name + '_reduced', axes=[0], + keepdims=0) oopb.apply_op_with_output('apply_greater', - [reduced, np.array([0], dtype=np.float32)], + reduced + [np.array([0], dtype=np.float32)], [operator.output_masks[0].full_name], name=operator.raw_operator.name) diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py index 20f0ab84..417e46b8 100644 --- a/keras2onnx/ke2onnx/simplernn.py +++ b/keras2onnx/ke2onnx/simplernn.py @@ -7,7 +7,6 @@ from ..proto import onnx_proto, keras from ..common import name_func from ..common.onnx_ops import ( - apply_cast, apply_concat, apply_reshape, apply_slice, @@ -163,9 +162,13 @@ def build_sequence_lengths(scope, operator, container): mask_cast = scope.get_unique_operator_name(operator.full_name + '_mask_cast') sequence_lengths = scope.get_unique_operator_name(operator.full_name + '_seq_lens') - apply_cast(scope, input_mask_name, mask_cast, container, to=TensorProto.INT32) - container.add_node('ReduceSum', mask_cast, sequence_lengths, keepdims=False, axes=[-1]) - return sequence_lengths + oopb = OnnxOperatorBuilder(container, scope) + mask_cast = oopb.apply_cast(input_mask_name, + to=oopb.int32, + name=operator.full_name + 'cast') + sequence_lengths = oopb.apply_reducesum(mask_cast, name=operator.full_name + '_reduced', axes=[-1], + keepdims=False) + return sequence_lengths[0] def build_initial_states(scope, operator, container, bidirectional=False): diff --git a/keras2onnx/topology.py b/keras2onnx/topology.py index 234072f8..a7af8b7e 100644 --- a/keras2onnx/topology.py +++ b/keras2onnx/topology.py @@ -323,7 +323,8 @@ def convert_topology(topology, model_name, doc_string, target_opset, channel_fir if target_opset < 9: nodes = onnxconverter_common.optimizer.optimize_onnx(nodes, nchw_inputs=nchw_inputs, inputs=container.inputs + extra_inputs, - outputs=container.outputs) + outputs=container.outputs, + target_opset=container.target_opset) node_number = len(nodes) else: graph = onnxconverter_common.optimizer.optimize_onnx_graph(nodes, nchw_inputs=nchw_inputs,