Skip to content

SAC-RND implementation #32

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
927 changes: 927 additions & 0 deletions algorithms/sac_rnd_jax.py

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions configs/sac_rnd/antmaze/large_diverse_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 0.5
actor_learning_rate: 0.0003
alpha_learning_rate: 0.0003
batch_size: 256
critic_beta: 0.01
critic_layernorm: true
critic_learning_rate: 0.0003
dataset_name: "antmaze-large-diverse-v1"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.999
group: "sac-rnd-antmaze-large-diverse-v1-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: true
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/antmaze/large_play_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 1.0
actor_learning_rate: 0.0003
alpha_learning_rate: 0.0003
batch_size: 256
critic_beta: 0.01
critic_layernorm: true
critic_learning_rate: 0.0003
dataset_name: "antmaze-large-play-v1"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.999
group: "sac-rnd-antmaze-large-play-v1-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: true
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/antmaze/medium_diverse_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 1.0
actor_learning_rate: 0.0003
alpha_learning_rate: 0.0003
batch_size: 256
critic_beta: 0.01
critic_layernorm: true
critic_learning_rate: 0.0003
dataset_name: "antmaze-medium-diverse-v1"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.999
group: "sac-rnd-antmaze-medium-diverse-v1-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: true
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/antmaze/medium_play_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 0.5
actor_learning_rate: 0.0003
alpha_learning_rate: 0.0003
batch_size: 256
critic_beta: 0.001
critic_layernorm: true
critic_learning_rate: 0.0003
dataset_name: "antmaze-medium-play-v1"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.999
group: "sac-rnd-antmaze-medium-play-v1-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: true
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/antmaze/umaze_diverse_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 1.0
actor_learning_rate: 0.0003
alpha_learning_rate: 0.0003
batch_size: 256
critic_beta: 0.1
critic_layernorm: true
critic_learning_rate: 0.0003
dataset_name: "antmaze-umaze-diverse-v1"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.999
group: "sac-rnd-antmaze-umaze-diverse-v1-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: true
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/antmaze/umaze_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 1.0
actor_learning_rate: 0.0003
alpha_learning_rate: 0.0003
batch_size: 256
critic_beta: 0.1
critic_layernorm: true
critic_learning_rate: 0.0003
dataset_name: "antmaze-umaze-v1"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.999
group: "sac-rnd-antmaze-umaze-v1-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: true
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/halfcheetah/expert_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 6.0
actor_learning_rate: 0.001
alpha_learning_rate: 0.001
batch_size: 1024
critic_beta: 6.0
critic_layernorm: true
critic_learning_rate: 0.001
dataset_name: "halfcheetah-expert-v2"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.99
group: "sac-rnd-halfcheetah-expert-v2-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: false
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/halfcheetah/full_replay_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 3.0
actor_learning_rate: 0.001
alpha_learning_rate: 0.001
batch_size: 1024
critic_beta: 3.0
critic_layernorm: true
critic_learning_rate: 0.001
dataset_name: "halfcheetah-full-replay-v2"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.99
group: "sac-rnd-halfcheetah-full-replay-v2-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: false
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/halfcheetah/medium_expert_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 0.1
actor_learning_rate: 0.001
alpha_learning_rate: 0.001
batch_size: 1024
critic_beta: 0.1
critic_layernorm: true
critic_learning_rate: 0.001
dataset_name: "halfcheetah-medium-expert-v2"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.99
group: "sac-rnd-halfcheetah-medium-expert-v2-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: false
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/halfcheetah/medium_replay_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 0.1
actor_learning_rate: 0.001
alpha_learning_rate: 0.001
batch_size: 1024
critic_beta: 0.1
critic_layernorm: true
critic_learning_rate: 0.001
dataset_name: "halfcheetah-medium-replay-v2"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.99
group: "sac-rnd-halfcheetah-medium-replay-v2-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: false
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/halfcheetah/medium_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 0.3
actor_learning_rate: 0.001
alpha_learning_rate: 0.001
batch_size: 1024
critic_beta: 0.3
critic_layernorm: true
critic_learning_rate: 0.001
dataset_name: "halfcheetah-medium-v2"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.99
group: "sac-rnd-halfcheetah-medium-v2-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: false
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/halfcheetah/random_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 0.1
actor_learning_rate: 0.001
alpha_learning_rate: 0.001
batch_size: 1024
critic_beta: 0.1
critic_layernorm: true
critic_learning_rate: 0.001
dataset_name: "halfcheetah-random-v2"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.99
group: "sac-rnd-halfcheetah-random-v2-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: false
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/hopper/expert_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 20.0
actor_learning_rate: 0.001
alpha_learning_rate: 0.001
batch_size: 1024
critic_beta: 20.0
critic_layernorm: true
critic_learning_rate: 0.001
dataset_name: "hopper-expert-v2"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.99
group: "sac-rnd-hopper-expert-v2-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: false
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
26 changes: 26 additions & 0 deletions configs/sac_rnd/hopper/full_replay_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
actor_beta: 3.0
actor_learning_rate: 0.001
alpha_learning_rate: 0.001
batch_size: 1024
critic_beta: 3.0
critic_layernorm: true
critic_learning_rate: 0.001
dataset_name: "hopper-full-replay-v2"
eval_episodes: 10
eval_every: 50
eval_seed: 42
gamma: 0.99
group: "sac-rnd-hopper-full-replay-v2-multiseed-v0"
hidden_dim: 256
name: "SAC-RND"
normalize_reward: false
num_critics: 2
num_epochs: 3000
num_updates_on_epoch: 1000
project: "CORL"
rnd_embedding_dim: 32
rnd_hidden_dim: 256
rnd_learning_rate: 0.0003
rnd_num_epochs: 1
tau: 0.005
train_seed: 10
Loading