Add small fixes
RobertTLange committed Feb 16, 2024
1 parent 67b361b commit 5ccd16a
Showing 15 changed files with 224 additions and 196 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -17,6 +17,8 @@ tboards/
 dev_notes.md
 configs/
 experiments/
+v2/
+evotf_es.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
28 changes: 16 additions & 12 deletions CHANGELOG.md
@@ -1,15 +1,19 @@
 ### Work-in-Progress
 
-- Implement more strategies
-  - [ ] Large-scale CMA-ES variants
-    - [ ] [LM-CMA](https://www.researchgate.net/publication/282612269_LM-CMA_An_alternative_to_L-BFGS_for_large-scale_black_Box_optimization)
-    - [ ] [VkD-CMA](https://hal.inria.fr/hal-01306551v1/document), [Code](https://gist.github.com/youheiakimoto/2fb26c0ace43c22b8f19c7796e69e108)
-  - [ ] [RBO](http://proceedings.mlr.press/v100/choromanski20a/choromanski20a.pdf)
-
-- Encoding methods - via special reshape wrappers
-  - [ ] Discrete Cosine Transform
-  - [ ] Wavelet Based Encoding (van Steenkiste, 2016)
-  - [ ] CNN Hypernetwork (Ha - start with simple MLP)
-
+### [v0.1.6] - [TBD]
+
+##### Added
+
+- Implemented Hill Climbing strategy as a simple baseline.
+- Added `use_antithetic_sampling` option to OpenAI-ES.
+- Added EvoTransformer strategy and trained checkpoint.
+
+##### Fixed
+
+- Gradientless Descent best member replacement.
+
+##### Changed
+
+- SNES imports DES weights directly and reuses code.
+- Made Sep_CMA_ES and OpenAI-ES use vector sigmas for EvoTransformer data collection.
 
 ### [v0.1.5] - [11/2023]
 
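The `use_antithetic_sampling` entry above refers to mirrored sampling: half of the population is perturbed by +ε and the other half by the same ε negated, which reduces the variance of the ES gradient estimate. A minimal sketch of the idea (illustrative function name, not the evosax implementation):

```python
import jax
import jax.numpy as jnp


def ask_antithetic(rng, mean, sigma, popsize):
    """Sample a mirrored (antithetic) population around `mean`.

    Draws popsize // 2 Gaussian noise vectors and appends their negation,
    so each perturbation +eps is paired with -eps. Assumes popsize is even.
    """
    z_half = jax.random.normal(rng, (popsize // 2, mean.shape[0]))
    z = jnp.concatenate([z_half, -z_half], axis=0)
    return mean + sigma * z  # candidate solutions, shape (popsize, num_dims)


# Example: 8 mirrored candidates around a 3-dimensional mean.
x = ask_antithetic(jax.random.PRNGKey(0), jnp.zeros(3), 0.1, 8)
```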
6 changes: 6 additions & 0 deletions evosax/__init__.py
@@ -34,6 +34,8 @@
     LES,
     LGA,
     NoiseReuseES,
+    HillClimber,
+    # EvoTF_ES,
 )
 from .core import FitnessShaper, ParameterReshaper
 from .utils import ESLog
@@ -76,6 +78,8 @@
     "LES": LES,
     "LGA": LGA,
     "NoiseReuseES": NoiseReuseES,
+    "HillClimber": HillClimber,
+    # "EvoTF_ES": EvoTF_ES,
 }
 
 __all__ = [
@@ -122,4 +126,6 @@
     "LES",
     "LGA",
     "NoiseReuseES",
+    "HillClimber",
+    # "EvoTF_ES",
 ]
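The `Strategies` dictionary doubles as a name-based registry, so the new strategy can be constructed by key. A quick sketch (evosax 0.1.x-style constructor arguments):

```python
from evosax import Strategies

# Instantiate the newly registered baseline by name.
strategy = Strategies["HillClimber"](popsize=16, num_dims=10)
```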
6 changes: 3 additions & 3 deletions evosax/experimental/distributed/__init__.py
@@ -1,7 +1,7 @@
-from .dist_open_es import DistributedOpenES
+from .open_es import OpenES
 
 
-DistributedStrategies = {"DistributedOpenES": DistributedOpenES}
+DistributedStrategies = {"OpenES": OpenES}
 
 
-__all__ = ["DistributedOpenES", "DistributedStrategies"]
+__all__ = ["OpenES", "DistributedStrategies"]
evosax/experimental/distributed/{dist_open_es.py → open_es.py}
@@ -1,4 +1,5 @@
 """Distributed version of OpenAI-ES. Supports z-scoring fitness trafo only."""
+
 import jax
 import jax.numpy as jnp
 import chex
@@ -30,7 +31,7 @@ class EvoParams:
     clip_max: float = jnp.finfo(jnp.float32).max
 
 
-class DistributedOpenES(Strategy):
+class OpenES(Strategy):
     def __init__(
         self,
         popsize: int,
@@ -94,9 +95,7 @@ def params_strategy(self) -> EvoParams:
         es_params = es_params
         return es_params
 
-    def initialize_strategy(
-        self, rng: chex.PRNGKey, params: EvoParams
-    ) -> EvoState:
+    def initialize_strategy(self, rng: chex.PRNGKey, params: EvoParams) -> EvoState:
         """`initialize` the evolution strategy."""
         if self.n_devices > 1:
             mean, sigma, opt_state = self.multi_init(rng, params)
@@ -207,9 +206,7 @@ def multi_tell(
         def calc_per_device_grad(x, fitness, mean, sigma):
             # Reconstruct noise from last mean/std estimates
             noise = (x - mean) / sigma
-            theta_grad = (
-                1.0 / (self.popsize * sigma) * jnp.dot(noise.T, fitness)
-            )
+            theta_grad = 1.0 / (self.popsize * sigma) * jnp.dot(noise.T, fitness)
             return jax.lax.pmean(theta_grad, axis_name="p")
 
         theta_grad = jax.pmap(calc_per_device_grad, axis_name="p")(
@@ -220,12 +217,8 @@ def calc_per_device_grad(x, fitness, mean, sigma):
         mean, opt_state = jax.pmap(self.optimizer.step)(
             state.mean, theta_grad, state.opt_state, params.opt_params
         )
-        opt_state = jax.pmap(self.optimizer.update)(
-            opt_state, params.opt_params
-        )
-        sigma = jax.pmap(exp_decay)(
-            state.sigma, params.sigma_decay, params.sigma_limit
-        )
+        opt_state = jax.pmap(self.optimizer.update)(opt_state, params.opt_params)
+        sigma = jax.pmap(exp_decay)(state.sigma, params.sigma_decay, params.sigma_limit)
         return mean, sigma, opt_state
 
     def single_tell(
@@ -235,9 +228,7 @@ def single_tell(
         fitness = (fitness - jnp.mean(fitness)) / (jnp.std(fitness) + 1e-10)
         # Reconstruct noise from last mean/std estimates
         noise = (x - state.mean) / state.sigma
-        theta_grad = (
-            1.0 / (self.popsize * state.sigma) * jnp.dot(noise.T, fitness)
-        )
+        theta_grad = 1.0 / (self.popsize * state.sigma) * jnp.dot(noise.T, fitness)
 
         # Grad update using optimizer instance - decay lrate if desired
         mean, opt_state = self.optimizer.step(
@@ -254,7 +245,7 @@ def pmap_zscore(fitness: chex.Array) -> chex.Array:
     def zscore(fit: chex.Array) -> chex.Array:
         all_mean = jax.lax.pmean(fit, axis_name="p").mean()
         diff = fit - all_mean
-        std = jnp.sqrt(jax.lax.pmean(diff ** 2, axis_name="p").mean())
+        std = jnp.sqrt(jax.lax.pmean(diff**2, axis_name="p").mean())
         return diff / (std + 1e-10)
 
     out = jax.pmap(zscore, axis_name="p")(fitness)
@@ -271,7 +262,7 @@ def zscore(fit: chex.Array) -> chex.Array:
 
     def run_es(lrate_decay=0.99, sigma_decay=0.99, opt_name="adam"):
         rng = jax.random.PRNGKey(0)
-        strategy = DistributedOpenES(
+        strategy = OpenES(
             popsize,
             num_dims,
             opt_name=opt_name,
@@ -289,7 +280,7 @@ def run_es(lrate_decay=0.99, sigma_decay=0.99, opt_name="adam"):
         print("Solution shape", x.shape)
 
         def sphere(x):
-            return jnp.sum(x ** 2)
+            return jnp.sum(x**2)
 
         if n_devices > 1:
             psphere = jax.pmap(jax.vmap(sphere))
@@ -321,9 +312,7 @@ def sphere(x):
     plt.plot(all_fitness095, label="lrate decay = 0.95, sigma_decay = 0.95")
     plt.xlabel("Generations")
     plt.ylabel("Mean Population Fitness")
-    plt.title(
-        f"{num_dims}D Quadratic - {popsize} Pop - Lrate 0.05, Sigma 0.04 - Adam"
-    )
+    plt.title(f"{num_dims}D Quadratic - {popsize} Pop - Lrate 0.05, Sigma 0.04 - Adam")
     plt.legend()
     plt.ylim(0, 500)
     plt.savefig("quadratic_adam.png", dpi=300)
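The reformatted `theta_grad` lines above implement the standard OpenAI-ES search-gradient estimate: for candidates x_i = m + σ ε_i with fitness f_i, the gradient with respect to the mean is approximated by (1 / (N σ)) Σ_i ε_i f_i. A single-device sketch of the arithmetic in `single_tell` (standalone function for illustration, not the class method):

```python
import jax.numpy as jnp


def openai_es_grad(x, fitness, mean, sigma, popsize):
    """Mirror of the estimator in `single_tell`: z-score the fitness,
    reconstruct the noise from the last mean/sigma, and form
    (1 / (popsize * sigma)) * noise^T fitness."""
    fitness = (fitness - jnp.mean(fitness)) / (jnp.std(fitness) + 1e-10)
    noise = (x - mean) / sigma  # shape (popsize, num_dims)
    return 1.0 / (popsize * sigma) * jnp.dot(noise.T, fitness)
```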
4 changes: 4 additions & 0 deletions evosax/strategies/__init__.py
@@ -32,7 +32,9 @@
 from .les import LES
 from .lga import LGA
 from .noise_reuse_es import NoiseReuseES
+from .hill_climber import HillClimber
 
+# from .evotf_es import EvoTF_ES
 
 __all__ = [
     "SimpleGA",
@@ -69,4 +71,6 @@
     "LES",
     "LGA",
     "NoiseReuseES",
+    "HillClimber",
+    # "EvoTF_ES",
 ]
31 changes: 8 additions & 23 deletions evosax/strategies/des.py
@@ -36,7 +36,7 @@ def get_des_weights(popsize: int, temperature: float = 12.5):
     ranks = ranks - 0.5
     sigout = nn.sigmoid(temperature * ranks)
     weights = nn.softmax(-20 * sigout)
-    return weights
+    return weights[:, None]
 
 
 class DES(Strategy):
@@ -53,12 +53,7 @@ def __init__(
     ):
         """Discovered Evolution Strategy (Lange et al., 2023)"""
         super().__init__(
-            popsize,
-            num_dims,
-            pholder_params,
-            mean_decay,
-            n_devices,
-            **fitness_kwargs
+            popsize, num_dims, pholder_params, mean_decay, n_devices, **fitness_kwargs
         )
         self.strategy_name = "DES"
         self.temperature = temperature
@@ -67,13 +62,9 @@ def __init__(
     @property
     def params_strategy(self) -> EvoParams:
         """Return default parameters of evolution strategy."""
-        return EvoParams(
-            temperature=self.temperature, sigma_init=self.sigma_init
-        )
+        return EvoParams(temperature=self.temperature, sigma_init=self.sigma_init)
 
-    def initialize_strategy(
-        self, rng: chex.PRNGKey, params: EvoParams
-    ) -> EvoState:
+    def initialize_strategy(self, rng: chex.PRNGKey, params: EvoParams) -> EvoState:
         """`initialize` the evolution strategy."""
         # Get DES discovered recombination weights.
         weights = get_des_weights(self.popsize, params.temperature)
@@ -86,7 +77,7 @@ def initialize_strategy(
         state = EvoState(
             mean=initialization,
             sigma=params.sigma_init * jnp.ones(self.num_dims),
-            weights=weights.reshape(-1, 1),
+            weights=weights,
             best_member=initialization,
         )
         return state
@@ -96,9 +87,7 @@ def ask_strategy(
     ) -> Tuple[chex.Array, EvoState]:
         """`ask` for new proposed candidates to evaluate next."""
         z = jax.random.normal(rng, (self.popsize, self.num_dims))  # ~ N(0, I)
-        x = state.mean + z * state.sigma.reshape(
-            1, self.num_dims
-        )  # ~ N(m, σ^2 I)
+        x = state.mean + z * state.sigma.reshape(1, self.num_dims)  # ~ N(m, σ^2 I)
         return x, state
 
     def tell_strategy(
@@ -113,11 +102,7 @@ def tell_strategy(
         x = x[fitness.argsort()]
         # Weighted updates
         weighted_mean = (weights * x).sum(axis=0)
-        weighted_sigma = jnp.sqrt(
-            (weights * (x - state.mean) ** 2).sum(axis=0) + 1e-06
-        )
+        weighted_sigma = jnp.sqrt((weights * (x - state.mean) ** 2).sum(axis=0) + 1e-06)
         mean = state.mean + params.lrate_mean * (weighted_mean - state.mean)
-        sigma = state.sigma + params.lrate_sigma * (
-            weighted_sigma - state.sigma
-        )
+        sigma = state.sigma + params.lrate_sigma * (weighted_sigma - state.sigma)
         return state.replace(mean=mean, sigma=sigma)
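The `weights[:, None]` change moves the column reshape into `get_des_weights`, so the returned weights already have shape `(popsize, 1)` and broadcast against the fitness-sorted population of shape `(popsize, num_dims)` in `tell_strategy`. A shape check of that broadcast:

```python
import jax.numpy as jnp

popsize, num_dims = 4, 3
weights = jnp.array([0.5, 0.3, 0.15, 0.05])[:, None]  # (popsize, 1)
x_sorted = jnp.ones((popsize, num_dims))  # population, best-to-worst

# (popsize, 1) * (popsize, num_dims) broadcasts to (popsize, num_dims);
# summing over the population axis yields the weighted mean, shape (num_dims,).
weighted_mean = (weights * x_sorted).sum(axis=0)
assert weighted_mean.shape == (num_dims,)
```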
16 changes: 5 additions & 11 deletions evosax/strategies/gld.py
@@ -4,6 +4,7 @@
 import chex
 from flax import struct
 from ..strategy import Strategy
+from ..utils import get_best_fitness_member
 
 
 @struct.dataclass
@@ -38,12 +39,7 @@ def __init__(
         """Gradientless Descent (Golovin et al., 2019)
         Reference: https://arxiv.org/pdf/1911.06317.pdf"""
         super().__init__(
-            popsize,
-            num_dims,
-            pholder_params,
-            mean_decay,
-            n_devices,
-            **fitness_kwargs
+            popsize, num_dims, pholder_params, mean_decay, n_devices, **fitness_kwargs
         )
         self.strategy_name = "GLD"
 
@@ -52,9 +48,7 @@ def params_strategy(self) -> EvoParams:
         """Return default parameters of evolution strategy."""
         return EvoParams()
 
-    def initialize_strategy(
-        self, rng: chex.PRNGKey, params: EvoParams
-    ) -> EvoState:
+    def initialize_strategy(self, rng: chex.PRNGKey, params: EvoParams) -> EvoState:
         """`initialize` the evolution strategy."""
         initialization = jax.random.uniform(
             rng,
@@ -94,5 +88,5 @@ def tell_strategy(
         params: EvoParams,
     ) -> EvoState:
         """`tell` update to ES state."""
-        # No state update needed - everything happens with best_member update
-        return state.replace(mean=state.best_member)
+        best_member, best_fitness = get_best_fitness_member(x, fitness, state, False)
+        return state.replace(mean=best_member)
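This is the "Gradientless Descent best member replacement" fix from the changelog: instead of relying on a previously updated `state.best_member`, `tell_strategy` now selects the best member explicitly via `get_best_fitness_member` and moves the mean there. A sketch of the assumed selection semantics (minimization; the real helper lives in `evosax.utils`):

```python
import jax.numpy as jnp


def best_member_sketch(x, fitness, best_member, best_fitness):
    """Assumed behavior: compare this generation's best candidate against
    the running best stored in the state and keep whichever is better."""
    idx = jnp.argmin(fitness)  # best of the current generation
    improved = fitness[idx] < best_fitness
    new_member = jnp.where(improved, x[idx], best_member)
    new_fitness = jnp.where(improved, fitness[idx], best_fitness)
    return new_member, new_fitness
```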
85 changes: 85 additions & 0 deletions evosax/strategies/hill_climber.py
@@ -0,0 +1,85 @@
from typing import Tuple, Optional, Union
import jax
import jax.numpy as jnp
import chex
from flax import struct
from ..strategy import Strategy
from ..utils import get_best_fitness_member


@struct.dataclass
class EvoState:
    mean: chex.Array
    sigma: chex.Array
    best_member: chex.Array
    best_fitness: float = jnp.finfo(jnp.float32).max
    gen_counter: int = 0


@struct.dataclass
class EvoParams:
    sigma_init: float = 1.0
    init_min: float = 0.0
    init_max: float = 0.0
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max


class HillClimber(Strategy):
    def __init__(
        self,
        popsize: int,
        num_dims: Optional[int] = None,
        pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
        mean_decay: float = 0.0,
        n_devices: Optional[int] = None,
        **fitness_kwargs: Union[bool, int, float]
    ):
        """Simple Gaussian Hill Climbing"""
        super().__init__(
            popsize, num_dims, pholder_params, mean_decay, n_devices, **fitness_kwargs
        )
        self.strategy_name = "HillClimber"

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy."""
        return EvoParams()

    def initialize_strategy(self, rng: chex.PRNGKey, params: EvoParams) -> EvoState:
        """`initialize` the evolution strategy."""
        initialization = jax.random.uniform(
            rng,
            (self.num_dims,),
            minval=params.init_min,
            maxval=params.init_max,
        )
        state = EvoState(
            mean=initialization,
            sigma=params.sigma_init * jnp.ones((self.num_dims,)),
            best_member=initialization,
        )
        return state

    def ask_strategy(
        self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, EvoState]:
        """`ask` for new proposed candidates to evaluate next."""
        # Sampling of N(0, 1) noise
        z = jax.random.normal(
            rng,
            (self.popsize, self.num_dims),
        )
        x = state.best_member + state.sigma.reshape(1, self.num_dims) * z
        return x, state

    def tell_strategy(
        self,
        x: chex.Array,
        fitness: chex.Array,
        state: EvoState,
        params: EvoParams,
    ) -> EvoState:
        """`tell` update to ES state."""
        best_member, best_fitness = get_best_fitness_member(x, fitness, state, False)
        return state.replace(mean=best_member)
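A short ask/tell loop for the new baseline on a sphere objective, following the same pattern as the `run_es` example above (standard evosax 0.1.x API; treat as a sketch):

```python
import jax
import jax.numpy as jnp
from evosax import HillClimber

rng = jax.random.PRNGKey(0)
strategy = HillClimber(popsize=16, num_dims=10)
params = strategy.default_params
state = strategy.initialize(rng, params)

for _ in range(100):
    rng, rng_ask = jax.random.split(rng)
    x, state = strategy.ask(rng_ask, state, params)
    fitness = jnp.sum(x**2, axis=1)  # sphere objective, minimized
    state = strategy.tell(x, fitness, state, params)

print(state.best_fitness)
```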
(Diffs for the remaining 6 changed files are not shown.)
