Skip to content

Commit

Permalink
add an option to use target_critic or critic when critic values are needed during rollout or training
Browse files (browse the repository at this point in the history)
  • Loading branch information
runjerry committed Oct 29, 2024
1 parent 33c9b27 commit d76576e
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
21 changes: 17 additions & 4 deletions alf/algorithms/oaec_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(self,
# align_optimization_noise=False,
beta_ub=1.0,
beta_lb=0.5,
output_target_critic=True,
use_target_actor=True,
target_update_tau=0.05,
target_update_period=1,
Expand Down Expand Up @@ -185,6 +186,9 @@ def __init__(self,
beta_ub (float): parameter for computing the upperbound of Q value:
:math:`Q_ub(s,a) = \mu_Q(s,a) + \beta_ub * \sigma_Q(s,a)`
beta_lb (float): parameter for computing the lowerbound of Q value; the counterpart of ``beta_ub``.
output_target_critic (bool): whether to use the target critic networks
whenever critic values are needed, such as during explorative rollout
and actor training. If False, the critic networks are used instead.
use_target_actor (bool): whether to use target actor for actor.
rollout_random_action (float): the probability of taking a uniform
random action during a ``rollout_step()``. 0 means always directly
Expand Down Expand Up @@ -226,6 +230,7 @@ def __init__(self,
self._reward_noise_scale = reward_noise_scale
self._beta_ub = beta_ub
self._beta_lb = beta_lb
self._output_target_critic = output_target_critic
self._use_target_actor = use_target_actor
self._num_rollout_sampled_actions = num_rollout_sampled_actions
self._bootstrap_mask_prob = bootstrap_mask_prob
Expand Down Expand Up @@ -334,8 +339,12 @@ def _predict_action(self,

## Step 2: forward critic_network to get the Q_values
# [n_sampled * n_env, n_opt_ptb + n_bootstrap + 1]
q_values, critic_states = self._target_critic_networks(
(critic_observations, critic_actions), state=state.critics)
if self._output_target_critic:
q_values, critic_states = self._target_critic_networks(
(critic_observations, critic_actions), state=state.critics)
else:
q_values, critic_states = self._critic_networks(
(critic_observations, critic_actions), state=state.critics)
# [n_sampled * n_env, n_bootstrap]
q_bootstrap = q_values[:, 1:1 + self._num_bootstrap_critics]
# [n_sampled * n_env, n_opt_ptb]
Expand Down Expand Up @@ -447,8 +456,12 @@ def _critic_train_step(self, inputs: TimeStep, state: OaecCriticState,
return state, info

def _actor_train_step(self, inputs: TimeStep, state, action):
q_values, critic_states = self._target_critic_networks(
(inputs.observation, action), state=state)
if self._output_target_critic:
q_values, critic_states = self._target_critic_networks(
(inputs.observation, action), state=state)
else:
q_values, critic_states = self._critic_networks(
(inputs.observation, action), state=state)
if self.has_multidim_reward():
# Multidimensional reward: [B, replicas, reward_dim]
q_values = q_values * self.reward_weights
Expand Down
7 changes: 4 additions & 3 deletions alf/examples/oaec_dmc.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"desc": [
"DM Control tasks with 4 seeds on each environment"
],
"version": "hopper_oaec_a4371_nb_2-ces",
"version": "hopper_oaec_33c9b_nb_2-ces",
"use_gpu": true,
"gpus": [
0,
Expand All @@ -15,13 +15,14 @@
"hopper-hop"
],
"OaecAlgorithm.beta_ub": "[1.]",
"OaecAlgorithm.beta_lb": "[.5]",
"OaecAlgorithm.beta_lb": "[.2, .1]",
"OaecAlgorithm.output_target_critic": "[False]",
"OaecAlgorithm.num_rollout_sampled_actions": "[10]",
"OaecAlgorithm.num_bootstrap_critics": "[2]",
"OaecAlgorithm.bootstrap_mask_prob": "[0.8]",
"OaecAlgorithm.use_target_actor": "[False]",
"OaecAlgorithm.target_update_tau": "[0.005]",
"TrainerConfig.random_seed": "list(range(4))"
"TrainerConfig.random_seed": "list(range(2))"
}
}

1 change: 1 addition & 0 deletions alf/examples/oaec_dmc_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
critic_network_cls=critic_network_cls,
beta_ub=1.,
beta_lb=.5,
output_target_critic=True,
reward_noise_scale=None,
num_rollout_sampled_actions=10,
num_bootstrap_critics=2,
Expand Down

0 comments on commit d76576e

Please sign in to comment.