Branch on `is_tf_keras` inside `CRF.step`, and read the CRF layer's inputs and
masks through `get_input_at` / `get_input_mask_at` instead of `_inbound_nodes`.

Review notes (text before the first `diff --git` line is ignored by patch tools):
  * `import tensorflow as tf` appears in BOTH tf.keras branches of `step`.
    A function-level import makes `tf` a local name for the whole function,
    so importing it only in the later `return_logZ` branch left the earlier
    `tf.slice` call raising UnboundLocalError whenever that branch ran.
  * The `index` post-image hash for crf.py predates this adjustment.
  * Context-line indentation was reconstructed from the code structure;
    confirm it against the target files before applying.

diff --git a/keras_contrib/layers/crf.py b/keras_contrib/layers/crf.py
index 88a64ac69..33a9a38ef 100644
--- a/keras_contrib/layers/crf.py
+++ b/keras_contrib/layers/crf.py
@@ -14,7 +14,7 @@
 from keras_contrib.losses import crf_loss
 from keras_contrib.metrics import crf_marginal_accuracy
 from keras_contrib.metrics import crf_viterbi_accuracy
-from keras_contrib.utils.test_utils import to_tuple
+from keras_contrib.utils.test_utils import to_tuple, is_tf_keras
 
 
 class CRF(Layer):
@@ -460,7 +460,11 @@ def step(self, input_energy_t, states, return_logZ=True):
             if K.backend() == 'theano':
                 m = states[3][:, t:(t + 2)]
             else:
-                m = K.slice(states[3], [0, t], [-1, 2])
+                if is_tf_keras:
+                    import tensorflow as tf
+                    m = tf.slice(states[3], [0, t], [-1, 2])
+                else:
+                    m = K.slice(states[3], [0, t], [-1, 2])
             input_energy_t = input_energy_t * K.expand_dims(m[:, 0])
             # (1, F, F)*(B, 1, 1) -> (B, F, F)
             chain_energy = chain_energy * K.expand_dims(
@@ -468,7 +472,11 @@ def step(self, input_energy_t, states, return_logZ=True):
         if return_logZ:
             # shapes: (1, B, F) + (B, F, 1) -> (B, F, F)
             energy = chain_energy + K.expand_dims(input_energy_t - prev_target_val, 2)
-            new_target_val = K.logsumexp(-energy, 1)  # shapes: (B, F)
+            if is_tf_keras:
+                import tensorflow as tf
+                new_target_val = tf.reduce_logsumexp(-energy, 1)  # shapes: (B, F)
+            else:
+                new_target_val = K.logsumexp(-energy, 1)  # shapes: (B, F)
             return new_target_val, [new_target_val, i + 1]
         else:
             energy = chain_energy + K.expand_dims(input_energy_t + prev_target_val, 2)
diff --git a/keras_contrib/losses/crf_losses.py b/keras_contrib/losses/crf_losses.py
index c9daed1d2..cc087a672 100644
--- a/keras_contrib/losses/crf_losses.py
+++ b/keras_contrib/losses/crf_losses.py
@@ -29,8 +29,8 @@ def crf_nll(y_true, y_pred):
         raise TypeError('When learn_model="join", CRF must be the last layer.')
     if crf.sparse_target:
         y_true = K.one_hot(K.cast(y_true[:, :, 0], 'int32'), crf.units)
-    X = crf._inbound_nodes[idx].input_tensors[0]
-    mask = crf._inbound_nodes[idx].input_masks[0]
+    X = crf.get_input_at(idx)
+    mask = crf.get_input_mask_at(idx)
     nloglik = crf.get_negative_log_likelihood(y_true, X, mask)
     return nloglik
 
diff --git a/keras_contrib/metrics/crf_accuracies.py b/keras_contrib/metrics/crf_accuracies.py
index 13fdac0d7..abfb0f0c9 100644
--- a/keras_contrib/metrics/crf_accuracies.py
+++ b/keras_contrib/metrics/crf_accuracies.py
@@ -19,8 +19,8 @@ def crf_viterbi_accuracy(y_true, y_pred):
     '''Use Viterbi algorithm to get best path, and compute its accuracy.
     `y_pred` must be an output from CRF.'''
     crf, idx = y_pred._keras_history[:2]
-    X = crf._inbound_nodes[idx].input_tensors[0]
-    mask = crf._inbound_nodes[idx].input_masks[0]
+    X = crf.get_input_at(idx)
+    mask = crf.get_input_mask_at(idx)
     y_pred = crf.viterbi_decoding(X, mask)
     return _get_accuracy(y_true, y_pred, mask, crf.sparse_target)
 
@@ -29,8 +29,8 @@ def crf_marginal_accuracy(y_true, y_pred):
     '''Use time-wise marginal argmax as prediction.
     `y_pred` must be an output from CRF with `learn_mode="marginal"`.'''
     crf, idx = y_pred._keras_history[:2]
-    X = crf._inbound_nodes[idx].input_tensors[0]
-    mask = crf._inbound_nodes[idx].input_masks[0]
+    X = crf.get_input_at(idx)
+    mask = crf.get_input_mask_at(idx)
     y_pred = crf.get_marginal_prob(X, mask)
     return _get_accuracy(y_true, y_pred, mask, crf.sparse_target)
 