add an option to use a separate conservative actor just for generating target critic actions
runjerry committed Nov 13, 2024
1 parent bfce1e9 commit b93160a
Showing 3 changed files with 67 additions and 12 deletions.
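
For context, here is a minimal configuration sketch of how the option added by this commit might be enabled; it mirrors the alf/examples/oaec_dmc_conf.py and oaec_dmc.json changes further down. The alf.config entry point and the AdamTF import path are assumptions based on ALF's usual configuration style, not part of this commit.

# Hypothetical usage sketch (not part of the commit). Parameter names are taken
# from the diff below; alf.config and alf.optimizers.AdamTF are assumed to be
# available as in other ALF examples.
import alf
from alf.optimizers import AdamTF

alf.config(
    'OaecAlgorithm',
    conservative_actor_training=True,
    # New option from this commit: train a dedicated conservative actor whose
    # actions are used only for generating target critic actions.
    separate_conservative_actor=True,
    conservative_actor_optimizer=AdamTF(lr=3e-4),
    actor_optimizer=AdamTF(lr=3e-4),
    critic_optimizer=AdamTF(lr=3e-4),
)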
74 changes: 63 additions & 11 deletions alf/algorithms/oaec_algorithm.py
@@ -43,16 +43,20 @@
"OaecCriticInfo", ["q_values", "target_q_values"], default_value=())
OaecActionState = namedtuple(
"OaecActionState", ['actor_network', 'critics'], default_value=())
OaecConsActorState = namedtuple(
"OaecConsActorState", ['cons_actor_network', 'critics'], default_value=())
OaecState = namedtuple(
"OaecState", ['action', 'actor', 'target_actor', 'critics'], default_value=())
"OaecState", ['action', 'actor', 'target_actor', 'conservative_actor', 'critics'],
default_value=())
OaecInfo = namedtuple(
"OaecInfo", [
"reward", "reward_noise", "mask", "step_type", "discount",
"action", "action_distribution", "actor_loss", "critic", "discounted_return"
"action", "action_distribution", "actor_loss", "cons_actor_loss",
"critic", "discounted_return"
],
default_value=())
OaecLossInfo = namedtuple(
'OaecLossInfo', ('actor', 'critic'), default_value=())
'OaecLossInfo', ('actor', 'conservative_actor', 'critic'), default_value=())


@alf.configurable
@@ -121,6 +125,7 @@ def __init__(self,
beta_lb=0.5,
conservative_actor_training=False,
conservative_critic_training=True,
separate_conservative_actor=False,
output_target_critic=True,
use_target_actor=True,
std_for_overestimate='tot',
@@ -130,6 +135,7 @@
dqda_clipping=None,
action_l2=0,
actor_optimizer=None,
conservative_actor_optimizer=None,
critic_optimizer=None,
debug_summaries=False,
name="OaecAlgorithm"):
@@ -207,6 +213,8 @@ def __init__(self,
conservative critic values.
conservative_critic_training (bool): whether to train critic using
conservative target critic values.
separate_conservative_actor (bool): whether to train a separate
conservative actor for generating target critic actions.
output_target_critic (bool): whether to use the target critic output
whenever critic values are needed, such as explorative rollout
and actor training.
@@ -220,6 +228,8 @@ def __init__(self,
Does not perform clipping if ``dqda_clipping == 0``.
action_l2 (float): weight of squared action l2-norm on actor loss.
actor_optimizer (torch.optim.optimizer): The optimizer for actor.
conservative_actor_optimizer (torch.optim.optimizer): The optimizer for
conservative_actor, only used if separate_conservative_actor is True.
critic_optimizer (torch.optim.optimizer): The optimizer for critic.
debug_summaries (bool): True if debug summaries should be created.
name (str): The name of this algorithm.
@@ -258,6 +268,7 @@ def __init__(self,
self._beta_lb = beta_lb
self._conservative_actor_training = conservative_actor_training
self._conservative_critic_training = conservative_critic_training
self._separate_conservative_actor = separate_conservative_actor
if output_target_critic:
self._output_critic_name = 'q'
else:
@@ -280,10 +291,15 @@ def __init__(self,
action_state_spec = OaecActionState(
actor_network=actor_network.state_spec,
critics=critic_networks.state_spec)
cons_actor_state_spec = OaecConsActorState(
cons_actor_network=actor_network.state_spec,
critics=critic_networks.state_spec)

train_state_spec = OaecState(
action=action_state_spec,
actor=critic_networks.state_spec,
target_actor=actor_network.state_spec,
conservative_actor=cons_actor_state_spec,
critics=OaecCriticState(
critics=critic_networks.state_spec,
target_critics=critic_networks.state_spec))
@@ -308,6 +324,13 @@ def __init__(self,
self._actor_network = actor_network
self._critic_networks = critic_networks

if self._separate_conservative_actor:
self._conservative_actor_network = actor_network.copy(
name='conservative_actor_network')
if conservative_actor_optimizer is not None:
self.add_optimizer(conservative_actor_optimizer,
[self._conservative_actor_network])

self._target_critic_networks = critic_networks.copy(
name='target_critic_networks')
original_models = [self._critic_networks]
@@ -473,6 +496,8 @@ def rollout_step(self, inputs: TimeStep, state=None):
new_state = OaecState(
action=action_state,
actor=state.actor,
target_actor=state.target_actor,
conservative_actor=state.conservative_actor,
critics=state.critics)
info = OaecRolloutInfo(action=action,
mask=mask)
@@ -544,7 +569,8 @@ def _critic_train_step(self, inputs: TimeStep, state: OaecCriticState,

return state, info

def _actor_train_step(self, inputs: TimeStep, state, action):
def _actor_train_step(self, inputs: TimeStep, state, action,
conservative_training=False):
if self._output_target_critic:
q_values, critic_states = self._target_critic_networks(
(inputs.observation, action), state=state)
@@ -556,7 +582,7 @@ def _actor_train_step(self, inputs: TimeStep, state, action):
q_values = q_values * self.reward_weights
# use the mean of default and bootstrapped target critics
q_value = q_values[:, :1 + self._num_bootstrap_critics].mean(-1)
if self._conservative_actor_training:
if conservative_training:
if self._std_for_overestimate == 'tot':
q_bootstrap = q_values[:, 1:1 + self._num_bootstrap_critics]
q_bootstrap_diff = q_bootstrap - q_values[:, :1]
@@ -594,12 +620,32 @@ def train_step(self, inputs: TimeStep, state: OaecState,
action_dist, action, action_state = self._predict_action(
inputs.observation, state=state.action)
actor_state, actor_loss_info = self._actor_train_step(
inputs=inputs, state=state.actor, action=action)
inputs=inputs, state=state.actor, action=action,
conservative_training=not self._separate_conservative_actor and
self._conservative_actor_training)

# collect info for critic_networks training
target_actor_state = ()
target_critic_action = action
target_action_dist = action_dist

# train conservative_actor_network
if self._separate_conservative_actor:
cons_action_dist, cons_actor_state = self._conservative_actor_network(
inputs.observation, state=state.conservative_actor.cons_actor_network)
cons_action = dist_utils.rsample_action_distribution(cons_action_dist)
cons_critic_states, cons_actor_loss_info = self._actor_train_step(
inputs=inputs, state=state.conservative_actor.critics,
action=cons_action, conservative_training=True)
cons_actor_state = OaecConsActorState(
cons_actor_network=cons_actor_state,
critics=cons_critic_states)
target_critic_action = cons_action
target_action_dist = cons_action_dist
else:
cons_actor_state = state.conservative_actor
cons_actor_loss_info = LossInfo()

target_actor_state = ()
if self._use_target_actor:
target_action_dist, target_actor_state = self._target_actor_network(
inputs.observation, state=state.target_actor)
@@ -611,7 +657,8 @@

state = OaecState(
action=action_state, actor=actor_state,
target_actor=target_actor_state, critics=critic_states),
target_actor=target_actor_state,
conservative_actor=cons_actor_state, critics=critic_states)
info = OaecInfo(
reward=inputs.reward,
reward_noise=rollout_info.reward_noise,
Expand All @@ -620,7 +667,8 @@ def train_step(self, inputs: TimeStep, state: OaecState,
discount=inputs.discount,
action_distribution=action_dist,
critic=critic_info,
actor_loss=actor_loss_info)
actor_loss=actor_loss_info,
cons_actor_loss=cons_actor_loss_info)
return AlgStep(output=action, state=state, info=info)

def calc_loss(self, info: OaecInfo):
@@ -710,6 +758,7 @@ def calc_loss(self, info: OaecInfo):
priority = ()

actor_loss = info.actor_loss
cons_actor_loss = info.cons_actor_loss

if self._debug_summaries and alf.summary.should_record_summaries():
with alf.summary.scope(self._name):
Expand All @@ -720,9 +769,12 @@ def calc_loss(self, info: OaecInfo):
safe_mean_hist_summary("opt_ptb_weights", self._opt_ptb_weights)

return LossInfo(
loss=critic_loss + actor_loss.loss,
loss=critic_loss + math_ops.add_ignore_empty(
actor_loss.loss, cons_actor_loss.loss),
priority=priority,
extra=OaecLossInfo(critic=critic_loss, actor=actor_loss.extra))
extra=OaecLossInfo(critic=critic_loss,
actor=actor_loss.extra,
conservative_actor=cons_actor_loss.extra))

def after_update(self, root_inputs, info: OaecInfo):
self._update_target()
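A note on the calc_loss change above: when separate_conservative_actor is False, cons_actor_loss is an empty LossInfo(), so its .loss field is (). The sketch below illustrates the assumed semantics of math_ops.add_ignore_empty that make the combined loss reduce to the plain actor loss in that case; it is an illustration of the assumed behavior, not ALF's actual implementation.

# Assumed behavior of math_ops.add_ignore_empty as used in calc_loss above:
# an operand that is the empty tuple () (e.g. LossInfo().loss when no separate
# conservative actor is trained) contributes nothing to the sum.
def add_ignore_empty(x, y):
    """Return x + y, skipping any operand that is an empty () placeholder."""
    def _is_empty(v):
        return isinstance(v, tuple) and len(v) == 0
    if _is_empty(x):
        return y
    if _is_empty(y):
        return x
    return x + y
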
3 changes: 2 additions & 1 deletion alf/examples/oaec_dmc.json
@@ -2,7 +2,7 @@
"desc": [
"DM Control tasks with 4 seeds on each environment"
],
"version": "hopper_oaec_43224_nb_2-ces",
"version": "hopper_oaec_bfce1_nb_2-ces",
"use_gpu": true,
"gpus": [
0, 1
@@ -17,6 +17,7 @@
"OaecAlgorithm.beta_lb": "[.1]",
"OaecAlgorithm.conservative_actor_training": "[True]",
"OaecAlgorithm.conservative_critic_training": "[False]",
"OaecAlgorithm.separate_conservative_actor": "[True]",
"OaecAlgorithm.output_target_critic": "[True]",
"OaecAlgorithm.num_sampled_target_q_actions": "[0]",
"OaecAlgorithm.target_q_from_sampled_actions": "['mean']",
2 changes: 2 additions & 0 deletions alf/examples/oaec_dmc_conf.py
@@ -36,6 +36,7 @@
beta_lb=.5,
conservative_actor_training=False,
conservative_critic_training=True,
separate_conservative_actor=False,
output_target_critic=True,
std_for_overestimate='tot',
opt_ptb_single_data=True,
@@ -49,6 +50,7 @@
bootstrap_mask_prob=0.8,
use_target_actor=True,
actor_optimizer=AdamTF(lr=3e-4),
conservative_actor_optimizer=AdamTF(lr=3e-4),
critic_optimizer=AdamTF(lr=3e-4),
target_update_tau=0.005)
