Skip to content

Commit

Permalink
Remove the antithetic sampling option from PGPE
Browse files (browse the repository at this point in the history)
  • Loading branch information
RobertTLange committed Feb 21, 2024
1 parent b207f56 commit 9bfe4b4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
##### Added

- Implemented Hill Climbing strategy as a simple baseline.
- Adds `use_antithetic_sampling` option to OpenAI-ES & PGPE.
- Adds `use_antithetic_sampling` option to OpenAI-ES.
- Added EvoTransformer strategy and trained checkpoint.

##### Fixed
Expand Down
16 changes: 5 additions & 11 deletions evosax/strategies/pgpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@ class EvoParams:


class PGPE(Strategy):

def __init__(
self,
popsize: int,
num_dims: Optional[int] = None,
pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
use_antithetic_sampling: bool = True,
elite_ratio: float = 1.0,
opt_name: str = "adam",
lrate_init: float = 0.15,
Expand All @@ -65,7 +63,6 @@ def __init__(
assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
self.optimizer = GradientOptimizer[opt_name](self.num_dims)
self.strategy_name = "PGPE"
self.use_antithetic_sampling = use_antithetic_sampling

# Set core kwargs es_params (lrate/sigma schedules)
self.lrate_init = lrate_init
Expand Down Expand Up @@ -111,14 +108,11 @@ def ask_strategy(
) -> Tuple[chex.Array, EvoState]:
"""`ask` for new parameter candidates to evaluate next."""
# Antithetic sampling of noise
if self.use_antithetic_sampling:
z_plus = jax.random.normal(
rng,
(int(self.popsize / 2), self.num_dims),
)
z = jnp.concatenate([z_plus, -1.0 * z_plus])
else:
z = jax.random.normal(rng, (self.popsize, self.num_dims))
z_plus = jax.random.normal(
rng,
(int(self.popsize / 2), self.num_dims),
)
z = jnp.hstack([z_plus, -1.0 * z_plus]).reshape(-1, self.num_dims)
x = state.mean + state.sigma * z
return x, state

Expand Down

0 comments on commit 9bfe4b4

Please sign in to comment.