diff --git a/alf/algorithms/oaec_algorithm.py b/alf/algorithms/oaec_algorithm.py
index cf946f0e9..85df0f676 100644
--- a/alf/algorithms/oaec_algorithm.py
+++ b/alf/algorithms/oaec_algorithm.py
@@ -43,16 +43,20 @@
     "OaecCriticInfo", ["q_values", "target_q_values"], default_value=())
 OaecActionState = namedtuple(
     "OaecActionState", ['actor_network', 'critics'], default_value=())
+OaecConsActorState = namedtuple(
+    "OaecConsActorState", ['cons_actor_network', 'critics'], default_value=())
 OaecState = namedtuple(
-    "OaecState", ['action', 'actor', 'target_actor', 'critics'], default_value=())
+    "OaecState", ['action', 'actor', 'target_actor', 'conservative_actor', 'critics'],
+    default_value=())
 OaecInfo = namedtuple(
     "OaecInfo", [
         "reward", "reward_noise", "mask", "step_type", "discount",
-        "action", "action_distribution", "actor_loss", "critic", "discounted_return"
+        "action", "action_distribution", "actor_loss", "cons_actor_loss",
+        "critic", "discounted_return"
     ], default_value=())
 OaecLossInfo = namedtuple(
-    'OaecLossInfo', ('actor', 'critic'), default_value=())
+    'OaecLossInfo', ('actor', 'conservative_actor', 'critic'), default_value=())
 
 
 @alf.configurable
@@ -121,6 +125,7 @@ def __init__(self,
                  beta_lb=0.5,
                  conservative_actor_training=False,
                  conservative_critic_training=True,
+                 separate_conservative_actor=False,
                  output_target_critic=True,
                  use_target_actor=True,
                  std_for_overestimate='tot',
@@ -130,6 +135,7 @@ def __init__(self,
                  dqda_clipping=None,
                  action_l2=0,
                  actor_optimizer=None,
+                 conservative_actor_optimizer=None,
                  critic_optimizer=None,
                  debug_summaries=False,
                  name="OaecAlgorithm"):
@@ -207,6 +213,8 @@ def __init__(self,
                 conservative critic values.
             conservative_critic_training (bool): whether to train critic using
                 conservative target critic values.
+            separate_conservative_actor (bool): whether to train a separate
+                conservative actor for generating target critic actions.
             output_target_critic (bool): whether to use the target critic output
                 whenever critic values are needed, such as explorative rollout
                 and actor training.
@@ -220,6 +228,8 @@ def __init__(self,
                 Does not perform clipping if ``dqda_clipping == 0``.
             action_l2 (float): weight of squared action l2-norm on actor loss.
             actor_optimizer (torch.optim.optimizer): The optimizer for actor.
+            conservative_actor_optimizer (torch.optim.optimizer): The optimizer for
+                the conservative actor; only used if ``separate_conservative_actor`` is True.
             critic_optimizer (torch.optim.optimizer): The optimizer for critic.
             debug_summaries (bool): True if debug summaries should be created.
             name (str): The name of this algorithm.
@@ -258,6 +268,7 @@ def __init__(self,
         self._beta_lb = beta_lb
         self._conservative_actor_training = conservative_actor_training
         self._conservative_critic_training = conservative_critic_training
+        self._separate_conservative_actor = separate_conservative_actor
         if output_target_critic:
             self._output_critic_name = 'q'
         else:
@@ -280,10 +291,15 @@ def __init__(self,
         action_state_spec = OaecActionState(
             actor_network=actor_network.state_spec,
             critics=critic_networks.state_spec)
+        cons_actor_state_spec = OaecConsActorState(
+            cons_actor_network=actor_network.state_spec,
+            critics=critic_networks.state_spec)
+
         train_state_spec = OaecState(
             action=action_state_spec,
             actor=critic_networks.state_spec,
             target_actor=actor_network.state_spec,
+            conservative_actor=cons_actor_state_spec,
             critics=OaecCriticState(
                 critics=critic_networks.state_spec,
                 target_critics=critic_networks.state_spec))
@@ -308,6 +324,13 @@ def __init__(self,
         self._actor_network = actor_network
         self._critic_networks = critic_networks
 
+        if self._separate_conservative_actor:
+            self._conservative_actor_network = actor_network.copy(
+                name='conservative_actor_network')
+            if conservative_actor_optimizer is not None:
+                self.add_optimizer(conservative_actor_optimizer,
+                                   [self._conservative_actor_network])
+
         self._target_critic_networks = critic_networks.copy(
             name='target_critic_networks')
         original_models = [self._critic_networks]
@@ -473,6 +496,8 @@ def rollout_step(self, inputs: TimeStep, state=None):
         new_state = OaecState(
             action=action_state,
             actor=state.actor,
+            target_actor=state.target_actor,
+            conservative_actor=state.conservative_actor,
             critics=state.critics)
 
         info = OaecRolloutInfo(action=action, mask=mask)
@@ -544,7 +569,8 @@ def _critic_train_step(self, inputs: TimeStep, state: OaecCriticState,
 
         return state, info
 
-    def _actor_train_step(self, inputs: TimeStep, state, action):
+    def _actor_train_step(self, inputs: TimeStep, state, action,
+                          conservative_training=False):
         if self._output_target_critic:
             q_values, critic_states = self._target_critic_networks(
                 (inputs.observation, action), state=state)
@@ -556,7 +582,7 @@ def _actor_train_step(self, inputs: TimeStep, state, action):
             q_values = q_values * self.reward_weights
         # use the mean of default and bootstrapped target critics
         q_value = q_values[:, :1 + self._num_bootstrap_critics].mean(-1)
-        if self._conservative_actor_training:
+        if conservative_training:
             if self._std_for_overestimate == 'tot':
                 q_bootstrap = q_values[:, 1:1 + self._num_bootstrap_critics]
                 q_bootstrap_diff = q_bootstrap - q_values[:, :1]
@@ -594,12 +620,32 @@ def train_step(self, inputs: TimeStep, state: OaecState,
         action_dist, action, action_state = self._predict_action(
             inputs.observation, state=state.action)
         actor_state, actor_loss_info = self._actor_train_step(
-            inputs=inputs, state=state.actor, action=action)
+            inputs=inputs, state=state.actor, action=action,
+            conservative_training=not self._separate_conservative_actor and
+            self._conservative_actor_training)
 
         # collect infor for critic_networks training
-        target_actor_state = ()
         target_critic_action = action
         target_action_dist = action_dist
+
+        # train conservative_actor_network
+        if self._separate_conservative_actor:
+            cons_action_dist, cons_actor_state = self._conservative_actor_network(
+                inputs.observation, state=state.conservative_actor.cons_actor_network)
+            cons_action = dist_utils.rsample_action_distribution(cons_action_dist)
+            cons_critic_states, cons_actor_loss_info = self._actor_train_step(
+                inputs=inputs, state=state.conservative_actor.critics,
+                action=cons_action, conservative_training=True)
+            cons_actor_state = OaecConsActorState(
+                cons_actor_network=cons_actor_state,
+                critics=cons_critic_states)
+            target_critic_action = cons_action
+            target_action_dist = cons_action_dist
+        else:
+            cons_actor_state = state.conservative_actor
+            cons_actor_loss_info = LossInfo()
+
+        target_actor_state = ()
         if self._use_target_actor:
             target_action_dist, target_actor_state = self._target_actor_network(
                 inputs.observation, state=state.target_actor)
@@ -611,7 +657,8 @@ def train_step(self, inputs: TimeStep, state: OaecState,
         state = OaecState(
             action=action_state,
             actor=actor_state,
-            target_actor=target_actor_state, critics=critic_states),
+            target_actor=target_actor_state,
+            conservative_actor=cons_actor_state, critics=critic_states),
         info = OaecInfo(
             reward=inputs.reward,
             reward_noise=rollout_info.reward_noise,
@@ -620,7 +667,8 @@ def train_step(self, inputs: TimeStep, state: OaecState,
             discount=inputs.discount,
             action_distribution=action_dist,
             critic=critic_info,
-            actor_loss=actor_loss_info)
+            actor_loss=actor_loss_info,
+            cons_actor_loss=cons_actor_loss_info)
         return AlgStep(output=action, state=state, info=info)
 
     def calc_loss(self, info: OaecInfo):
@@ -710,6 +758,7 @@ def calc_loss(self, info: OaecInfo):
             priority = ()
 
         actor_loss = info.actor_loss
+        cons_actor_loss = info.cons_actor_loss
 
         if self._debug_summaries and alf.summary.should_record_summaries():
             with alf.summary.scope(self._name):
@@ -720,9 +769,12 @@ def calc_loss(self, info: OaecInfo):
                 safe_mean_hist_summary("opt_ptb_weights", self._opt_ptb_weights)
 
         return LossInfo(
-            loss=critic_loss + actor_loss.loss,
+            loss=critic_loss + math_ops.add_ignore_empty(
+                actor_loss.loss, cons_actor_loss.loss),
             priority=priority,
-            extra=OaecLossInfo(critic=critic_loss, actor=actor_loss.extra))
+            extra=OaecLossInfo(critic=critic_loss,
+                               actor=actor_loss.extra,
+                               conservative_actor=cons_actor_loss.extra))
 
     def after_update(self, root_inputs, info: OaecInfo):
         self._update_target()
diff --git a/alf/examples/oaec_dmc.json b/alf/examples/oaec_dmc.json
index 9b8836f8b..05b7f13ab 100644
--- a/alf/examples/oaec_dmc.json
+++ b/alf/examples/oaec_dmc.json
@@ -2,7 +2,7 @@
     "desc": [
         "DM Control tasks with 4 seeds on each environment"
     ],
-    "version": "hopper_oaec_43224_nb_2-ces",
+    "version": "hopper_oaec_bfce1_nb_2-ces",
     "use_gpu": true,
     "gpus": [
         0, 1
@@ -17,6 +17,7 @@
         "OaecAlgorithm.beta_lb": "[.1]",
         "OaecAlgorithm.conservative_actor_training": "[True]",
         "OaecAlgorithm.conservative_critic_training": "[False]",
+        "OaecAlgorithm.separate_conservative_actor": "[True]",
         "OaecAlgorithm.output_target_critic": "[True]",
         "OaecAlgorithm.num_sampled_target_q_actions": "[0]",
         "OaecAlgorithm.target_q_from_sampled_actions": "['mean']",
diff --git a/alf/examples/oaec_dmc_conf.py b/alf/examples/oaec_dmc_conf.py
index 5e30447ad..105302ee7 100644
--- a/alf/examples/oaec_dmc_conf.py
+++ b/alf/examples/oaec_dmc_conf.py
@@ -36,6 +36,7 @@
     beta_lb=.5,
     conservative_actor_training=False,
     conservative_critic_training=True,
+    separate_conservative_actor=False,
     output_target_critic=True,
     std_for_overestimate='tot',
     opt_ptb_single_data=True,
@@ -49,6 +50,7 @@
     bootstrap_mask_prob=0.8,
     use_target_actor=True,
     actor_optimizer=AdamTF(lr=3e-4),
+    conservative_actor_optimizer=AdamTF(lr=3e-4),
     critic_optimizer=AdamTF(lr=3e-4),
     target_update_tau=0.005)
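
For quick reference, a minimal conf-file sketch (not part of the patch) showing how the new option could be enabled, mirroring the settings swept in `alf/examples/oaec_dmc.json` above; it assumes the usual `alf.config` API and `alf.optimizers.AdamTF` used elsewhere in the ALF examples:

```python
import alf
from alf.optimizers import AdamTF

# Sketch only: enable the separate conservative actor that generates the
# target critic actions, matching the grid values in oaec_dmc.json.
alf.config(
    'OaecAlgorithm',
    conservative_actor_training=True,
    conservative_critic_training=False,
    separate_conservative_actor=True,
    # consulted only when separate_conservative_actor is True
    conservative_actor_optimizer=AdamTF(lr=3e-4),
    actor_optimizer=AdamTF(lr=3e-4),
    critic_optimizer=AdamTF(lr=3e-4))
```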
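
The `calc_loss` change combines the two actor losses with `math_ops.add_ignore_empty`, since `cons_actor_loss` is an empty `LossInfo()` whenever the separate conservative actor is disabled. A standalone sketch of the assumed ignore-empty addition (an illustration of the intent, not ALF's actual `alf.utils.math_ops` implementation):

```python
def add_ignore_empty(x, y):
    """Return x + y, treating an empty nest () or None as zero."""

    def _empty(v):
        return v is None or (isinstance(v, tuple) and len(v) == 0)

    if _empty(x):
        return y
    if _empty(y):
        return x
    return x + y
```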