From 1308118ff02ad4c0dc3715ebf0a3a1351840d305 Mon Sep 17 00:00:00 2001 From: akbir Date: Wed, 15 Mar 2023 12:59:23 +0000 Subject: [PATCH 1/3] make opponents consistent across elites --- pax/runners/runner_evo.py | 22 ++++++++++++---------- pax/runners/runner_marl.py | 8 ++++---- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index 18813d60..f9948f1e 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -308,13 +308,13 @@ def _rollout( _env_params: Any, ): # env reset - rngs = jnp.concatenate( + env_rngs = jnp.concatenate( [jax.random.split(_rng_run, args.num_envs)] * args.num_opps * args.popsize ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) - obs, env_state = env.reset(rngs, _env_params) + obs, env_state = env.reset(env_rngs, _env_params) rewards = [ jnp.zeros((args.popsize, args.num_opps, args.num_envs)), jnp.zeros((args.popsize, args.num_opps, args.num_envs)), @@ -329,10 +329,11 @@ def _rollout( else: # meta-experiments - init 2nd agent per trial + a2_rng = jnp.concatenate( + [jax.random.split(_rng_run, args.num_opps)] * args.popsize + ).reshape(args.popsize, args.num_opps, -1) a2_state, a2_mem = agent2.batch_init( - jax.random.split( - _rng_run, args.popsize * args.num_opps - ).reshape(args.popsize, args.num_opps, -1), + a2_rng, agent2._mem.hidden, ) @@ -340,7 +341,7 @@ def _rollout( vals, stack = jax.lax.scan( _outer_rollout, ( - rngs, + env_rngs, *obs, *rewards, _a1_state, @@ -355,7 +356,7 @@ def _rollout( ) ( - rngs, + env_rngs, obs1, obs2, r1, @@ -480,18 +481,19 @@ def run_loop( agent1._mem.hidden, (popsize, num_opps, 1, 1), ) + a1_rng = jax.random.split(rng, popsize) agent1._state, agent1._mem = agent1.batch_init( - jax.random.split(agent1._state.random_key, popsize), + a1_rng, init_hidden, ) a1_state, a1_mem = agent1._state, agent1._mem for gen in range(num_gens): - rng, rng_run, rng_gen, rng_key = jax.random.split(rng, 4) + rng, rng_run, rng_evo, rng_key = jax.random.split(rng, 4) # Ask - x, evo_state = strategy.ask(rng_gen, evo_state, es_params) + x, evo_state = strategy.ask(rng_evo, evo_state, es_params) params = param_reshaper.reshape(x) if self.args.num_devices == 1: params = jax.tree_util.tree_map( diff --git a/pax/runners/runner_marl.py b/pax/runners/runner_marl.py index 7ba6bbe3..2eed3e1d 100644 --- a/pax/runners/runner_marl.py +++ b/pax/runners/runner_marl.py @@ -152,8 +152,9 @@ def _reshape_opp_dim(x): if args.agent2 != "NaiveEx": # NaiveEx requires env first step to init. 
init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) + a2_rng = jax.random.split(agent2._state.random_key, args.num_opps) agent2._state, agent2._mem = agent2.batch_init( - jax.random.split(agent2._state.random_key, args.num_opps), + a2_rng, init_hidden, ) @@ -321,9 +322,8 @@ def _rollout( elif self.args.env_type in ["meta"]: # meta-experiments - init 2nd agent per trial - _a2_state, _a2_mem = agent2.batch_init( - jax.random.split(_rng_run, self.num_opps), _a2_mem.hidden - ) + a2_rng = jax.random.split(_rng_run, self.num_opps) + _a2_state, _a2_mem = agent2.batch_init(a2_rng, _a2_mem.hidden) # run trials vals, stack = jax.lax.scan( _outer_rollout, From c7b190c642468382341ed77ca1d001f14e24dc4c Mon Sep 17 00:00:00 2001 From: akbir Date: Wed, 15 Mar 2023 13:08:54 +0000 Subject: [PATCH 2/3] fixed coplayer rng in init --- pax/runners/runner_evo.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index f9948f1e..9ce590b0 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -163,12 +163,13 @@ def __init__( # NaiveEx requires env first step to init. init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) - key = jax.random.split( - agent2._state.random_key, args.popsize * args.num_opps + a2_rng = jnp.concatenate( + [jax.random.split(agent2._state.random_key, args.num_opps)] + * args.popsize ).reshape(args.popsize, args.num_opps, -1) agent2._state, agent2._mem = agent2.batch_init( - key, + a2_rng, init_hidden, ) From b17225ef45bc09812044afcab6dd297c1e6c4348 Mon Sep 17 00:00:00 2001 From: akbir Date: Wed, 15 Mar 2023 18:12:11 +0000 Subject: [PATCH 3/3] run eval_x_play --- pax/agents/mfos_ppo/ppo_gru.py | 10 +- pax/conf/experiment/ipd/shaper_v_tabular.yaml | 2 +- pax/conf/experiment/ipditm/eval_gs_gs.yaml | 87 ++++++++++++ pax/conf/experiment/ipditm/eval_gs_mfos.yaml | 85 ++++++++++++ .../experiment/ipditm/eval_gs_shaper.yaml | 87 ++++++++++++ pax/conf/experiment/ipditm/eval_mfos_gs.yaml | 85 ++++++++++++ .../experiment/ipditm/eval_mfos_mfos.yaml | 86 ++++++++++++ .../experiment/ipditm/eval_mfos_shaper.yaml | 86 ++++++++++++ .../experiment/ipditm/eval_shaper_gs.yaml | 87 ++++++++++++ .../experiment/ipditm/eval_shaper_mfos.yaml | 87 ++++++++++++ .../experiment/ipditm/eval_shaper_shaper.yaml | 87 ++++++++++++ pax/experiment.py | 20 ++- pax/runners/runner_ipditm_eval.py | 125 ++++++------------ 13 files changed, 835 insertions(+), 99 deletions(-) create mode 100644 pax/conf/experiment/ipditm/eval_gs_gs.yaml create mode 100644 pax/conf/experiment/ipditm/eval_gs_mfos.yaml create mode 100644 pax/conf/experiment/ipditm/eval_gs_shaper.yaml create mode 100644 pax/conf/experiment/ipditm/eval_mfos_gs.yaml create mode 100644 pax/conf/experiment/ipditm/eval_mfos_mfos.yaml create mode 100644 pax/conf/experiment/ipditm/eval_mfos_shaper.yaml create mode 100644 pax/conf/experiment/ipditm/eval_shaper_gs.yaml create mode 100644 pax/conf/experiment/ipditm/eval_shaper_mfos.yaml create mode 100644 pax/conf/experiment/ipditm/eval_shaper_shaper.yaml diff --git a/pax/agents/mfos_ppo/ppo_gru.py b/pax/agents/mfos_ppo/ppo_gru.py index 2191043e..2091a6a2 100644 --- a/pax/agents/mfos_ppo/ppo_gru.py +++ b/pax/agents/mfos_ppo/ppo_gru.py @@ -61,9 +61,7 @@ def __init__( random_key: jnp.ndarray, gru_dim: int, obs_spec: Tuple, - batch_size: int = 2000, num_envs: int = 4, - num_steps: int = 500, num_minibatches: int = 16, num_epochs: int = 4, clip_value: bool = True, @@ -479,15 +477,12 @@ def prepare_batch( # Other 
useful hyperparameters self._num_envs = num_envs # number of environments - self._num_steps = num_steps # number of steps per environment - self._batch_size = int(num_envs * num_steps) # number in one batch self._num_minibatches = num_minibatches # number of minibatches self._num_epochs = num_epochs # number of epochs to use sample self._gru_dim = gru_dim def reset_memory(self, memory, eval=False) -> TrainingState: num_envs = 1 if eval else self._num_envs - memory = memory._replace( extras={ "values": jnp.zeros(num_envs), @@ -573,8 +568,7 @@ def make_mfos_agent( # Optimizer transition_steps = ( - num_iterations, - *agent_args.num_epochs * agent_args.num_minibatches, + num_iterations * agent_args.num_epochs * agent_args.num_minibatches, ) if agent_args.lr_scheduling: @@ -607,9 +601,7 @@ def make_mfos_agent( random_key=random_key, gru_dim=gru_dim, obs_spec=obs_spec, - batch_size=None, num_envs=args.num_envs, - num_steps=args.num_steps, num_minibatches=agent_args.num_minibatches, num_epochs=agent_args.num_epochs, clip_value=agent_args.clip_value, diff --git a/pax/conf/experiment/ipd/shaper_v_tabular.yaml b/pax/conf/experiment/ipd/shaper_v_tabular.yaml index 10ba189f..364a14d2 100644 --- a/pax/conf/experiment/ipd/shaper_v_tabular.yaml +++ b/pax/conf/experiment/ipd/shaper_v_tabular.yaml @@ -17,7 +17,7 @@ runner: evo top_k: 5 popsize: 1000 num_envs: 2 -num_opps: 1 +num_opps: 10 num_outer_steps: 100 num_inner_steps: 100 num_iters: 5000 diff --git a/pax/conf/experiment/ipditm/eval_gs_gs.yaml b/pax/conf/experiment/ipditm/eval_gs_gs.yaml new file mode 100644 index 00000000..c9927d5e --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_gs_gs.yaml @@ -0,0 +1,87 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent2: 'PPO' + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 +save_gif: False +# Evaluation +# Shaper +run_path1: ucl-dark/ipditm/2wjr55mr +model_path1: exp/shaping-PPO-vs-PPO_memory/run-seed-0/2023-01-05_09.59.53.063797/generation_1000 +# GS +run_path2: ucl-dark/ipditm/2wjr55mr +model_path2: exp/shaping-PPO-vs-PPO_memory/run-seed-0/2023-01-05_09.59.53.063797/generation_1000 + +# PPO agent parameters +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 8 + +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 8 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: 
True \ No newline at end of file diff --git a/pax/conf/experiment/ipditm/eval_gs_mfos.yaml b/pax/conf/experiment/ipditm/eval_gs_mfos.yaml new file mode 100644 index 00000000..2d789430 --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_gs_mfos.yaml @@ -0,0 +1,85 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent2: 'MFOS' + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 +save_gif: False +# Evaluation +# GS +run_path1: ucl-dark/ipditm/2wjr55mr +model_path1: exp/shaping-PPO-vs-PPO_memory/run-seed-0/2023-01-05_09.59.53.063797/generation_1000 +run_path2: ucl-dark/ipditm/226zwu1v +model_path2: exp/shaping-MFOS-vs-PPO_memory/run-seed-0/2023-01-09_10.27.04.619601/generation_1000 +# PPO agent parameters +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 8 + +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.05 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 16 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: True \ No newline at end of file diff --git a/pax/conf/experiment/ipditm/eval_gs_shaper.yaml b/pax/conf/experiment/ipditm/eval_gs_shaper.yaml new file mode 100644 index 00000000..e4c695a2 --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_gs_shaper.yaml @@ -0,0 +1,87 @@ +# @package _global_ + +# Agents +agent2: 'PPO_memory' +agent1: 'PPO' + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 +save_gif: False +# Evaluation +# Shaper +run_path2: ucl-dark/ipditm/1vpl5161 +model_path2: exp/shaping-PPO_memory-vs-PPO_memory/run-seed-0/2023-01-05_14.13.25.169599/generation_1000 +# GS +run_path1: ucl-dark/ipditm/2wjr55mr +model_path1: exp/shaping-PPO-vs-PPO_memory/run-seed-0/2023-01-05_09.59.53.063797/generation_1000 + +# PPO agent parameters +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True 
+ output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 32 + +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 8 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: True \ No newline at end of file diff --git a/pax/conf/experiment/ipditm/eval_mfos_gs.yaml b/pax/conf/experiment/ipditm/eval_mfos_gs.yaml new file mode 100644 index 00000000..f319ff35 --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_mfos_gs.yaml @@ -0,0 +1,85 @@ +# @package _global_ + +# Agents +agent1: 'MFOS' +agent2: 'PPO' + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 +save_gif: False +# Evaluation +run_path1: ucl-dark/ipditm/226zwu1v +model_path1: exp/shaping-MFOS-vs-PPO_memory/run-seed-0/2023-01-09_10.27.04.619601/generation_1000 +run_path2: ucl-dark/ipditm/2wjr55mr +model_path2: exp/shaping-PPO-vs-PPO_memory/run-seed-0/2023-01-05_09.59.53.063797/generation_1000 +# PPO agent parameters + +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.05 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 16 + +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 8 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: True \ No newline at end of file diff --git a/pax/conf/experiment/ipditm/eval_mfos_mfos.yaml b/pax/conf/experiment/ipditm/eval_mfos_mfos.yaml new file mode 100644 index 00000000..719a5ed2 --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_mfos_mfos.yaml @@ -0,0 +1,86 @@ +# @package _global_ + +# Agents +agent1: 'MFOS' +agent2: 'MFOS' + + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 
+save_gif: False +# Evaluation +run_path1: ucl-dark/ipditm/226zwu1v +model_path1: exp/shaping-MFOS-vs-PPO_memory/run-seed-0/2023-01-09_10.27.04.619601/generation_1000 +run_path2: ucl-dark/ipditm/226zwu1v +model_path2: exp/shaping-MFOS-vs-PPO_memory/run-seed-0/2023-01-09_10.27.04.619601/generation_1000 +# PPO agent parameters +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.05 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 16 + + +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.05 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 16 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: True \ No newline at end of file diff --git a/pax/conf/experiment/ipditm/eval_mfos_shaper.yaml b/pax/conf/experiment/ipditm/eval_mfos_shaper.yaml new file mode 100644 index 00000000..52ceb1c3 --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_mfos_shaper.yaml @@ -0,0 +1,86 @@ +# @package _global_ + +# Agents +agent1: 'MFOS' +agent2: 'PPO_memory' + + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 +save_gif: False +# Evaluation +run_path1: ucl-dark/ipditm/226zwu1v +model_path1: exp/shaping-MFOS-vs-PPO_memory/run-seed-0/2023-01-09_10.27.04.619601/generation_1000 +run_path2: ucl-dark/ipditm/1vpl5161 +model_path2: exp/shaping-PPO_memory-vs-PPO_memory/run-seed-0/2023-01-05_14.13.25.169599/generation_1000 +# PPO agent parameters +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.05 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 16 + + +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 32 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 
'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: True \ No newline at end of file diff --git a/pax/conf/experiment/ipditm/eval_shaper_gs.yaml b/pax/conf/experiment/ipditm/eval_shaper_gs.yaml new file mode 100644 index 00000000..3e5c4a1f --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_shaper_gs.yaml @@ -0,0 +1,87 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO' + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 +save_gif: False +# Evaluation +# Shaper +run_path1: ucl-dark/ipditm/1vpl5161 +model_path1: exp/shaping-PPO_memory-vs-PPO_memory/run-seed-0/2023-01-05_14.13.25.169599/generation_1000 +# GS +run_path2: ucl-dark/ipditm/2wjr55mr +model_path2: exp/shaping-PPO-vs-PPO_memory/run-seed-0/2023-01-05_09.59.53.063797/generation_1000 + +# PPO agent parameters +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 32 + +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 8 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: True \ No newline at end of file diff --git a/pax/conf/experiment/ipditm/eval_shaper_mfos.yaml b/pax/conf/experiment/ipditm/eval_shaper_mfos.yaml new file mode 100644 index 00000000..35e129fa --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_shaper_mfos.yaml @@ -0,0 +1,87 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'MFOS' + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 +save_gif: False +# Evaluation +# Shaper +run_path1: ucl-dark/ipditm/1vpl5161 +model_path1: exp/shaping-PPO_memory-vs-PPO_memory/run-seed-0/2023-01-05_14.13.25.169599/generation_1000 +# MFOS +run_path2: ucl-dark/ipditm/226zwu1v +model_path2: exp/shaping-MFOS-vs-PPO_memory/run-seed-0/2023-01-09_10.27.04.619601/generation_1000 + +# PPO agent parameters +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 
0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 32 + +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.05 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 16 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: True \ No newline at end of file diff --git a/pax/conf/experiment/ipditm/eval_shaper_shaper.yaml b/pax/conf/experiment/ipditm/eval_shaper_shaper.yaml new file mode 100644 index 00000000..d0d95888 --- /dev/null +++ b/pax/conf/experiment/ipditm/eval_shaper_shaper.yaml @@ -0,0 +1,87 @@ +# @package _global_ + +# Agents +agent1: 'PPO_memory' +agent2: 'PPO_memory' + +# Environment +env_id: InTheMatrix +env_type: meta +env_discount: 0.96 +payoff: [[[3, 0], [5, 1]], [[3, 5], [0, 1]]] +runner: ipditm_eval +freeze: 5 +fixed_coins: False + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 50 +num_opps: 1 +num_outer_steps: 100 +num_inner_steps: 152 +save_interval: 100 +num_iters: 1 +save_gif: False +# Evaluation +# Shaper +run_path1: ucl-dark/ipditm/1vpl5161 +model_path1: exp/shaping-PPO_memory-vs-PPO_memory/run-seed-0/2023-01-05_14.13.25.169599/generation_1000 +# GS +run_path2: ucl-dark/ipditm/1vpl5161 +model_path2: exp/shaping-PPO_memory-vs-PPO_memory/run-seed-0/2023-01-05_14.13.25.169599/generation_1000 + +# PPO agent parameters +ppo1: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 32 + +ppo2: + num_minibatches: 8 + num_epochs: 2 + gamma: 0.96 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: False + entropy_coeff_start: 0.1 + entropy_coeff_horizon: 0.6e8 + entropy_coeff_end: 0.005 + lr_scheduling: False + learning_rate: 0.005 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: True + output_channels: 16 + kernel_shape: [3, 3] + separate: False # only works with CNN + hidden_size: 32 + +# Logging setup +wandb: + entity: "ucl-dark" + project: ipditm + group: 'xplay-eval-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: True \ No newline at end of file diff --git a/pax/experiment.py b/pax/experiment.py index 46f3e327..39b7d739 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -285,9 +285,10 @@ def agent_setup(args, env, env_params, logger): def get_PPO_memory_agent(seed, player_id): player_args = args.ppo1 if player_id == 1 else args.ppo2 - num_iterations = args.num_iters if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps + else: + num_iterations = args.num_iters return make_gru_agent( args, 
player_args, @@ -332,10 +333,11 @@ def get_PPO_tabular_agent(seed, player_id): return ppo_agent def get_mfos_agent(seed, player_id): - agent_args = args.ppo1 - num_iterations = args.num_iters + agent_args = args.ppo1 if player_id == 1 else args.ppo2 if player_id == 1 and args.env_type == "meta": num_iterations = args.num_outer_steps + else: + num_iterations = args.num_iters ppo_agent = make_mfos_agent( args, agent_args, @@ -453,7 +455,11 @@ def watcher_setup(args, logger): def ppo_memory_log(agent): losses = losses_ppo(agent) - if args.env_id not in ["coin_game", "InTheMatrix", "iterated_matrix_game"]: + if args.env_id not in [ + "coin_game", + "InTheMatrix", + "iterated_matrix_game", + ]: policy = policy_logger_ppo_with_memory(agent) losses.update(policy) if args.wandb.log: @@ -465,7 +471,11 @@ def ppo_memory_log(agent): def ppo_log(agent): losses = losses_ppo(agent) - if args.env_id not in ["coin_game", "InTheMatrix", "iterated_matrix_game"]: + if args.env_id not in [ + "coin_game", + "InTheMatrix", + "iterated_matrix_game", + ]: policy = policy_logger_ppo(agent) value = value_logger_ppo(agent) losses.update(value) diff --git a/pax/runners/runner_ipditm_eval.py b/pax/runners/runner_ipditm_eval.py index 7eedaac7..a8171ad3 100644 --- a/pax/runners/runner_ipditm_eval.py +++ b/pax/runners/runner_ipditm_eval.py @@ -76,32 +76,17 @@ def __init__(self, agents, env, save_dir, args): self.random_key = jax.random.PRNGKey(args.seed) self.save_dir = save_dir - def _reshape_opp_dim(x): - # x: [num_opps, num_envs ...] - # x: [batch_size, ...] - batch_size = args.num_envs * args.num_opps - return jax.tree_util.tree_map( - lambda x: x.reshape((batch_size,) + x.shape[2:]), x - ) - - self.reduce_opp_dim = jax.jit(_reshape_opp_dim) self.ipd_stats = jax.jit(ipd_visitation) self.cg_stats = jax.jit(cg_visitation) + # VMAP for num envs: we vmap over the rng but not params env.reset = jax.vmap(env.reset, (0, None), 0) env.step = jax.vmap( env.step, (0, 0, 0, None), 0 # rng, state, actions, params ) - self.ipditm_stats = jax.jit(ipditm_stats) - # VMAP for num opps: we vmap over the rng but not params - env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) - env.step = jax.jit( - jax.vmap( - env.step, (0, 0, 0, None), 0 # rng, state, actions, params - ) - ) - self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) + self.ipditm_stats = jax.jit(ipditm_stats) + self.split = jax.vmap(jax.random.split, (0, None)) self.num_outer_steps = self.args.num_outer_steps agent1, agent2 = agents @@ -112,18 +97,10 @@ def _reshape_opp_dim(x): agent1.batch_init = jax.jit(jax.vmap(agent1.make_initial_state)) else: # batch MemoryState not TrainingState - agent1.batch_init = jax.vmap( - agent1.make_initial_state, - (None, 0), - (None, 0), - ) - agent1.batch_reset = jax.jit( - jax.vmap(agent1.reset_memory, (0, None), 0), static_argnums=1 - ) + agent1.batch_init = jax.jit(agent1.make_initial_state) - agent1.batch_policy = jax.jit( - jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)) - ) + agent1.batch_reset = jax.jit(agent1.reset_memory, static_argnums=1) + agent1.batch_policy = jax.jit(agent1._policy) # batch all for Agent2 if args.agent2 == "NaiveEx": @@ -131,26 +108,25 @@ def _reshape_opp_dim(x): agent2.batch_init = jax.jit(jax.vmap(agent2.make_initial_state)) else: agent2.batch_init = jax.vmap( - agent2.make_initial_state, (0, None), 0 + agent2.make_initial_state, + (None, 0), + (None, 0), ) - agent2.batch_policy = jax.jit(jax.vmap(agent2._policy)) - agent2.batch_reset = jax.jit( - jax.vmap(agent2.reset_memory, 
(0, None), 0), static_argnums=1 - ) - agent2.batch_update = jax.jit(jax.vmap(agent2.update, (1, 0, 0, 0), 0)) + agent2.batch_reset = jax.jit(agent2.reset_memory, static_argnums=1) + agent2.batch_policy = jax.jit(agent2._policy) if args.agent1 != "NaiveEx": # NaiveEx requires env first step to init. - init_hidden = jnp.tile(agent1._mem.hidden, (args.num_opps, 1, 1)) + init_hidden = agent1._mem.hidden agent1._state, agent1._mem = agent1.batch_init( agent1._state.random_key, init_hidden ) if args.agent2 != "NaiveEx": # NaiveEx requires env first step to init. - init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) + init_hidden = jnp.tile(agent2._mem.hidden, (1, 1)) agent2._state, agent2._mem = agent2.batch_init( - jax.random.split(agent2._state.random_key, args.num_opps), + agent2._state.random_key, init_hidden, ) @@ -172,11 +148,10 @@ def _inner_rollout(carry, unused): # unpack rngs rngs = self.split(rngs, 4) - env_rng = rngs[:, :, 0, :] + env_rng = rngs[:, 0, :] # a1_rng = rngs[:, :, 1, :] # a2_rng = rngs[:, :, 2, :] - rngs = rngs[:, :, 3, :] - + rngs = rngs[:, 3, :] a1, a1_state, new_a1_mem = agent1.batch_policy( a1_state, obs1, @@ -188,6 +163,7 @@ def _inner_rollout(carry, unused): obs2, a2_mem, ) + (next_obs1, next_obs2), env_state, rewards, done, info = env.step( env_rng, env_state, @@ -258,13 +234,8 @@ def _outer_rollout(carry, unused): if args.agent1 == "MFOS": a1_mem = agent1.meta_policy(a1_mem) - # update second agent - a2_state, a2_mem, a2_metrics = agent2.batch_update( - trajectories[1], - obs2, - a2_state, - a2_mem, - ) + if args.agent2 == "MFOS": + a2_mem = agent2.meta_policy(a2_mem) return ( rngs, obs1, @@ -277,7 +248,7 @@ def _outer_rollout(carry, unused): a2_mem, env_state, env_params, - ), (*trajectories, a2_metrics) + ), trajectories def _rollout( _rng_run: jnp.ndarray, @@ -288,30 +259,18 @@ def _rollout( _env_params: Any, ): # env reset - rngs = jnp.concatenate( - [jax.random.split(_rng_run, args.num_envs)] * args.num_opps - ).reshape((args.num_opps, args.num_envs, -1)) - + rngs = jax.random.split(_rng_run, args.num_envs) obs, env_state = env.reset(rngs, _env_params) rewards = [ - jnp.zeros((args.num_opps, args.num_envs)), - jnp.zeros((args.num_opps, args.num_envs)), + jnp.zeros(args.num_envs), + jnp.zeros(args.num_envs), ] # Player 1 _a1_mem = agent1.batch_reset(_a1_mem, False) # Player 2 - if args.agent1 == "NaiveEx": - _a1_state, _a1_mem = agent1.batch_init(obs[0]) - - if args.agent2 == "NaiveEx": - _a2_state, _a2_mem = agent2.batch_init(obs[1]) + _a2_mem = agent2.batch_reset(_a2_mem, False) - elif self.args.env_type in ["meta"]: - # meta-experiments - init 2nd agent per trial - _a2_state, _a2_mem = agent2.batch_init( - jax.random.split(_rng_run, self.num_opps), _a2_mem.hidden - ) # run trials vals, stack = jax.lax.scan( _outer_rollout, @@ -327,7 +286,7 @@ def _rollout( _env_params, ), None, - length=num_outer_steps, + length=self.args.num_outer_steps, ) ( @@ -343,7 +302,7 @@ def _rollout( env_state, env_params, ) = vals - traj_1, traj_2, a2_metrics = stack + traj_1, traj_2 = stack # reset memory a1_mem = agent1.batch_reset(a1_mem, False) @@ -384,7 +343,6 @@ def _rollout( a1_mem, a2_state, a2_mem, - a2_metrics, traj_1, traj_2, ) @@ -407,6 +365,7 @@ def run_loop(self, env_params, agents, watchers): root=os.getcwd(), ) + print("Loading First Agent") pretrained_params = load(self.args.model_path1) a1_state = a1_state._replace(params=pretrained_params) @@ -431,7 +390,6 @@ def run_loop(self, env_params, agents, watchers): a1_mem, a2_state, a2_mem, - a2_metrics, 
traj, other_traj, ) = self.rollout( @@ -447,13 +405,13 @@ def run_loop(self, env_params, agents, watchers): if watchers: # metrics [outer_timesteps, num_opps] - flattened_metrics_2 = jax.tree_util.tree_map( - lambda x: x.flatten(), a2_metrics - ) - list_of_metrics = [ - {k: v[i] for k, v in flattened_metrics_2.items()} - for i in range(len(list(flattened_metrics_2.values())[0])) - ] + # flattened_metrics_2 = jax.tree_util.tree_map( + # lambda x: x.flatten(), a2_metrics + # ) + # list_of_metrics = [ + # {k: v[i] for k, v in flattened_metrics_2.items()} + # for i in range(len(list(flattened_metrics_2.values())[0])) + # ] env_state = traj.env_state list_of_env_states = [ EnvState( @@ -522,11 +480,11 @@ def run_loop(self, env_params, agents, watchers): watchers[0](agents[0]) # log the inner episodes - for i, metric in enumerate(list_of_metrics): - agents[1]._logger.metrics = metric - agents[1]._logger.metrics["sgd_steps"] = i - watchers[1](agents[1]) - wandb.log({"train_iteration": i} | list_of_env_stats[i]) + # for i, metric in enumerate(list_of_metrics): + # agents[1]._logger.metrics = metric + # agents[1]._logger.metrics["sgd_steps"] = i + # watchers[1](agents[1]) + # wandb.log({"train_iteration": i} | list_of_env_stats[i]) wandb.log( { @@ -561,11 +519,10 @@ def run_loop(self, env_params, agents, watchers): lambda x: x.reshape((x.shape[0] * x.shape[1], *x.shape[2:])), env_state, ) - env_idx = jax.random.choice(rng, env_state.red_pos.shape[2]) - opp_idx = jax.random.choice(rng, env_state.red_pos.shape[1]) + env_idx = jax.random.choice(rng, env_state.red_pos.shape[1]) env_state = jax.tree_util.tree_map( - lambda x: x[:, opp_idx, env_idx, ...], env_state + lambda x: x[:, env_idx, ...], env_state ) env_states = [ EnvState(
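
To make the intent of patches 1 and 2 concrete, the snippet below isolates the RNG pattern they introduce: instead of splitting a fresh key for every (elite, opponent) pair, the opponent keys are split once and tiled across the population, so every elite is evaluated against identically initialised co-players. This is an illustrative standalone sketch, not part of the patch; `popsize` and `num_opps` are stand-ins for `args.popsize` and `args.num_opps`.

# Illustrative sketch only: the key-sharing pattern used in _rollout / batch_init above.
import jax
import jax.numpy as jnp

popsize, num_opps = 4, 2
rng_run = jax.random.PRNGKey(0)

# Old behaviour: an independent key per (elite, opponent) pair, so each
# population member faced differently initialised opponents.
per_pair = jax.random.split(rng_run, popsize * num_opps).reshape(
    popsize, num_opps, -1
)

# New behaviour: split once per opponent, then tile the same keys across the
# population, so opponents are consistent across elites.
shared = jnp.concatenate(
    [jax.random.split(rng_run, num_opps)] * popsize
).reshape(popsize, num_opps, -1)

assert not bool((per_pair[0] == per_pair[1]).all())  # keys differ between elites
assert bool((shared[0] == shared[1]).all())          # keys identical between elites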