From 16a42f59f398526062e66a31f73da9383d7adbd0 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Wed, 16 Dec 2020 21:42:50 -0800 Subject: [PATCH] Set axes=[0] for apply_squeeze rather than leaving unset (#668) --- keras2onnx/ke2onnx/lstm.py | 6 +++--- keras2onnx/ke2onnx/simplernn.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/keras2onnx/ke2onnx/lstm.py b/keras2onnx/ke2onnx/lstm.py index 26cd0591..c02144ca 100644 --- a/keras2onnx/ke2onnx/lstm.py +++ b/keras2onnx/ke2onnx/lstm.py @@ -224,7 +224,7 @@ def build_output_states(scope, operator, container, output_names, bidirectional= squeeze_names.extend(list(zip(split_names, outputs))) for split_name, output_name in squeeze_names: - apply_squeeze(scope, split_name, output_name, container) + apply_squeeze(scope, split_name, output_name, container, axes=[0]) else: output_state = op.return_state @@ -234,8 +234,8 @@ def build_output_states(scope, operator, container, output_names, bidirectional= output_h = operator.outputs[1].full_name output_c = operator.outputs[2].full_name - apply_squeeze(scope, lstm_h, output_h, container) - apply_squeeze(scope, lstm_c, output_c, container) + apply_squeeze(scope, lstm_h, output_h, container, axes=[0]) + apply_squeeze(scope, lstm_c, output_c, container, axes=[0]) def _calculate_keras_lstm_output_shapes(operator): diff --git a/keras2onnx/ke2onnx/simplernn.py b/keras2onnx/ke2onnx/simplernn.py index 4b648c50..20f0ab84 100644 --- a/keras2onnx/ke2onnx/simplernn.py +++ b/keras2onnx/ke2onnx/simplernn.py @@ -391,14 +391,14 @@ def build_output_states(scope, operator, container, output_names, bidirectional= apply_split(scope, rnn_h, split_names, container) for split_name, output_name in zip(split_names, output_names): - apply_squeeze(scope, split_name, output_name, container) + apply_squeeze(scope, split_name, output_name, container, axes=[0]) else: output_state = op.return_state if output_state: output_h = operator.outputs[1].full_name - apply_squeeze(scope, rnn_h, output_h, container) + apply_squeeze(scope, rnn_h, output_h, container, axes=[0]) def is_time_major(op, bidirectional):