Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Commit

Permalink
Set axes=[0] for apply_squeeze rather than leaving unset (#668)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiafatom authored Dec 17, 2020
1 parent 6d799ad commit 16a42f5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions keras2onnx/ke2onnx/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def build_output_states(scope, operator, container, output_names, bidirectional=
squeeze_names.extend(list(zip(split_names, outputs)))

for split_name, output_name in squeeze_names:
apply_squeeze(scope, split_name, output_name, container)
apply_squeeze(scope, split_name, output_name, container, axes=[0])

else:
output_state = op.return_state
Expand All @@ -234,8 +234,8 @@ def build_output_states(scope, operator, container, output_names, bidirectional=

output_h = operator.outputs[1].full_name
output_c = operator.outputs[2].full_name
apply_squeeze(scope, lstm_h, output_h, container)
apply_squeeze(scope, lstm_c, output_c, container)
apply_squeeze(scope, lstm_h, output_h, container, axes=[0])
apply_squeeze(scope, lstm_c, output_c, container, axes=[0])


def _calculate_keras_lstm_output_shapes(operator):
Expand Down
4 changes: 2 additions & 2 deletions keras2onnx/ke2onnx/simplernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,14 +391,14 @@ def build_output_states(scope, operator, container, output_names, bidirectional=
apply_split(scope, rnn_h, split_names, container)

for split_name, output_name in zip(split_names, output_names):
apply_squeeze(scope, split_name, output_name, container)
apply_squeeze(scope, split_name, output_name, container, axes=[0])

else:
output_state = op.return_state

if output_state:
output_h = operator.outputs[1].full_name
apply_squeeze(scope, rnn_h, output_h, container)
apply_squeeze(scope, rnn_h, output_h, container, axes=[0])


def is_time_major(op, bidirectional):
Expand Down

0 comments on commit 16a42f5

Please sign in to comment.