diff --git a/sac/algos/sac.py b/sac/algos/sac.py index ea3d1e9..8d3a326 100644 --- a/sac/algos/sac.py +++ b/sac/algos/sac.py @@ -310,10 +310,10 @@ def _init_actor_update(self): min_log_target = tf.minimum(log_target1, log_target2) if self._reparameterize: - policy_kl_loss = tf.reduce_mean(log_pi - log_target1) + policy_kl_loss = tf.reduce_mean(log_pi - min_log_target) else: policy_kl_loss = tf.reduce_mean(log_pi * tf.stop_gradient( - log_pi - log_target1 + self._vf_t - policy_prior_log_probs)) + log_pi - min_log_target + self._vf_t - policy_prior_log_probs)) policy_regularization_losses = tf.get_collection( tf.GraphKeys.REGULARIZATION_LOSSES,