Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Commit

Permalink
Fix tf.keras CI build on RNN recurrent_v2 (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiafatom authored Dec 4, 2020
1 parent 577d13d commit ba10e6e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
5 changes: 2 additions & 3 deletions keras2onnx/ke2onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
# license information.
###############################################################################
import numpy as np
from ..proto import keras, is_tf_keras, is_keras_older_than
from ..proto.tfcompat import is_tf2
from ..proto import keras, is_tf_keras, is_keras_older_than, is_tensorflow_later_than
from ..common import with_variable, k2o_logger
from ..common.onnx_ops import OnnxOperatorBuilder

Expand Down Expand Up @@ -261,7 +260,7 @@ def convert_keras_training_only_layer(scope, operator, container):
_adv_activations.ReLU: convert_keras_advanced_activation,
})

if is_tf_keras and is_tf2:
if is_tf_keras and is_tensorflow_later_than("1.14.0"):
keras_layer_to_operator.update({
_layer.recurrent.GRU: convert_keras_gru,
_layer.recurrent.LSTM: convert_keras_lstm,
Expand Down
1 change: 0 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2122,7 +2122,6 @@ def test_bidirectional_seqlen_none(runner, rnn_class):
assert runner(onnx_model.graph.name, onnx_model, x, expected)


@pytest.mark.skipif(is_tf2, reason='TODO')
@pytest.mark.parametrize("rnn_class", RNN_CLASSES)
def test_rnn_state_passing(runner, rnn_class):
input1 = Input(shape=(None, 5))
Expand Down

0 comments on commit ba10e6e

Please sign in to comment.