Skip to content

Commit

Permalink
add std summaries in rollout
Browse files Browse the repository at this point in the history
  • Loading branch information
runjerry committed Nov 1, 2024
1 parent d76576e commit 0e29251
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
13 changes: 13 additions & 0 deletions alf/algorithms/oaec_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def __init__(self,
self._reward_noise_scale = reward_noise_scale
self._beta_ub = beta_ub
self._beta_lb = beta_lb
if output_target_critic:
self._output_critic_name = 'q'
else:
self._output_critic_name = 'target_q'
self._output_target_critic = output_target_critic
self._use_target_actor = use_target_actor
self._num_rollout_sampled_actions = num_rollout_sampled_actions
Expand Down Expand Up @@ -381,6 +385,15 @@ def _predict_action(self,
# [n_env, ...]
action = actions[action_idx, batch_idx, ...]

if self._debug_summaries and alf.summary.should_record_summaries():
with alf.summary.scope(self._name):
safe_mean_hist_summary(f"explore/{self._output_critic_name}_tot_std",
q_tot_std)
safe_mean_hist_summary(f"explore/{self._output_critic_name}_opt_std",
q_opt_std)
safe_mean_hist_summary(f"explore/{self._output_critic_name}_epi_std",
q_epi_std)

# else:
# # This uniform sampling during initial collect stage is
# # important since current explore_network is deterministic
Expand Down
14 changes: 6 additions & 8 deletions alf/examples/oaec_dmc.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,22 @@
"desc": [
"DM Control tasks with 4 seeds on each environment"
],
"version": "hopper_oaec_33c9b_nb_2-ces",
"version": "hopper_oaec_d7657_nb_2-ces",
"use_gpu": true,
"gpus": [
0,
1
0
],
"max_worker_num": 4,
"max_worker_num": 2,
"repeats": 1,
"parameters": {
"create_environment.env_name": [
"hopper-hop"
],
"OaecAlgorithm.beta_ub": "[1.]",
"OaecAlgorithm.beta_lb": "[.2, .1]",
"OaecAlgorithm.output_target_critic": "[False]",
"OaecAlgorithm.num_rollout_sampled_actions": "[10]",
"OaecAlgorithm.beta_lb": "[.1]",
"OaecAlgorithm.output_target_critic": "[True]",
"OaecAlgorithm.num_bootstrap_critics": "[2]",
"OaecAlgorithm.bootstrap_mask_prob": "[0.8]",
"OaecAlgorithm.bootstrap_mask_prob": "[0.5]",
"OaecAlgorithm.use_target_actor": "[False]",
"OaecAlgorithm.target_update_tau": "[0.005]",
"TrainerConfig.random_seed": "list(range(2))"
Expand Down

0 comments on commit 0e29251

Please sign in to comment.