Skip to content

Commit

Permalink
Add antithetic sampling option to PGPE
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertTLange committed Feb 21, 2024
1 parent 5ccd16a commit b207f56
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
##### Added

- Implemented Hill Climbing strategy as a simple baseline.
- Adds `use_antithetic_sampling` option to OpenAI-ES.
- Adds `use_antithetic_sampling` option to OpenAI-ES & PGPE.
- Added EvoTransformer strategy and trained checkpoint.

##### Fixed
Expand Down
33 changes: 15 additions & 18 deletions evosax/strategies/pgpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ class EvoParams:


class PGPE(Strategy):

def __init__(
self,
popsize: int,
num_dims: Optional[int] = None,
pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
use_antithetic_sampling: bool = True,
elite_ratio: float = 1.0,
opt_name: str = "adam",
lrate_init: float = 0.15,
Expand All @@ -53,12 +55,7 @@ def __init__(
Reference: https://tinyurl.com/2p8bn956
Inspired by: https://github.com/hardmaru/estool/blob/master/es.py"""
super().__init__(
popsize,
num_dims,
pholder_params,
mean_decay,
n_devices,
**fitness_kwargs
popsize, num_dims, pholder_params, mean_decay, n_devices, **fitness_kwargs
)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
Expand All @@ -68,6 +65,7 @@ def __init__(
assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
self.optimizer = GradientOptimizer[opt_name](self.num_dims)
self.strategy_name = "PGPE"
self.use_antithetic_sampling = use_antithetic_sampling

# Set core kwargs es_params (lrate/sigma schedules)
self.lrate_init = lrate_init
Expand All @@ -92,9 +90,7 @@ def params_strategy(self) -> EvoParams:
sigma_limit=self.sigma_limit,
)

def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
) -> EvoState:
def initialize_strategy(self, rng: chex.PRNGKey, params: EvoParams) -> EvoState:
"""`initialize` the evolution strategy."""
initialization = jax.random.uniform(
rng,
Expand All @@ -115,11 +111,14 @@ def ask_strategy(
) -> Tuple[chex.Array, EvoState]:
"""`ask` for new parameter candidates to evaluate next."""
# Antithetic sampling of noise
z_plus = jax.random.normal(
rng,
(int(self.popsize / 2), self.num_dims),
)
z = jnp.hstack([z_plus, -1.0 * z_plus]).reshape(-1, self.num_dims)
if self.use_antithetic_sampling:
z_plus = jax.random.normal(
rng,
(int(self.popsize / 2), self.num_dims),
)
z = jnp.concatenate([z_plus, -1.0 * z_plus])
else:
z = jax.random.normal(rng, (self.popsize, self.num_dims))
x = state.mean + state.sigma * z
return x, state

Expand Down Expand Up @@ -150,14 +149,12 @@ def tell_strategy(
opt_state = self.optimizer.update(opt_state, params.opt_params)

baseline = jnp.mean(fitness_elite)
all_avg_scores = (
jnp.stack([fit_1[elite_idx], fit_2[elite_idx]]).sum(axis=0) / 2
)
all_avg_scores = jnp.stack([fit_1[elite_idx], fit_2[elite_idx]]).sum(axis=0) / 2

# Update sigma vector
delta_sigma = (
(jnp.expand_dims(all_avg_scores, axis=1) - baseline)
* (noise_1 ** 2 - jnp.expand_dims(state.sigma ** 2, axis=0))
* (noise_1**2 - jnp.expand_dims(state.sigma**2, axis=0))
/ state.sigma
).mean(axis=0)

Expand Down

0 comments on commit b207f56

Please sign in to comment.