From 4a3c7dce364b31ff738e6b87dd4cf9dd844e35f5 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 29 Nov 2023 21:48:18 +0100 Subject: [PATCH] Use minimum of Q-functions for KL loss * Cf. Equation 13 paper --- sac/algos/sac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sac/algos/sac.py b/sac/algos/sac.py index ea3d1e9..8d3a326 100644 --- a/sac/algos/sac.py +++ b/sac/algos/sac.py @@ -310,10 +310,10 @@ def _init_actor_update(self): min_log_target = tf.minimum(log_target1, log_target2) if self._reparameterize: - policy_kl_loss = tf.reduce_mean(log_pi - log_target1) + policy_kl_loss = tf.reduce_mean(log_pi - min_log_target) else: policy_kl_loss = tf.reduce_mean(log_pi * tf.stop_gradient( - log_pi - log_target1 + self._vf_t - policy_prior_log_probs)) + log_pi - min_log_target + self._vf_t - policy_prior_log_probs)) policy_regularization_losses = tf.get_collection( tf.GraphKeys.REGULARIZATION_LOSSES,