diff --git a/keras2onnx/ke2onnx/main.py b/keras2onnx/ke2onnx/main.py index e07b2595..ab761810 100644 --- a/keras2onnx/ke2onnx/main.py +++ b/keras2onnx/ke2onnx/main.py @@ -4,8 +4,7 @@ # license information. ############################################################################### import numpy as np -from ..proto import keras, is_tf_keras, is_keras_older_than -from ..proto.tfcompat import is_tf2 +from ..proto import keras, is_tf_keras, is_keras_older_than, is_tensorflow_later_than from ..common import with_variable, k2o_logger from ..common.onnx_ops import OnnxOperatorBuilder @@ -261,7 +260,7 @@ def convert_keras_training_only_layer(scope, operator, container): _adv_activations.ReLU: convert_keras_advanced_activation, }) -if is_tf_keras and is_tf2: +if is_tf_keras and is_tensorflow_later_than("1.14.0"): keras_layer_to_operator.update({ _layer.recurrent.GRU: convert_keras_gru, _layer.recurrent.LSTM: convert_keras_lstm, diff --git a/tests/test_layers.py b/tests/test_layers.py index 591dfae4..3df6954b 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -2122,7 +2122,6 @@ def test_bidirectional_seqlen_none(runner, rnn_class): assert runner(onnx_model.graph.name, onnx_model, x, expected) -@pytest.mark.skipif(is_tf2, reason='TODO') @pytest.mark.parametrize("rnn_class", RNN_CLASSES) def test_rnn_state_passing(runner, rnn_class): input1 = Input(shape=(None, 5))