From d3779940623b9766d03b2be71df27ff2347d40f3 Mon Sep 17 00:00:00 2001 From: Cornelius Braun Date: Fri, 18 Oct 2024 15:06:07 +0200 Subject: [PATCH 1/4] added SV_CMA_ES --- evosax/__init__.py | 3 + evosax/strategies/__init__.py | 2 + evosax/strategies/sv_cma_es.py | 357 +++++++++++++++++++++++++++++++++ 3 files changed, 362 insertions(+) create mode 100644 evosax/strategies/sv_cma_es.py diff --git a/evosax/__init__.py b/evosax/__init__.py index db6fcb8..4dc3e32 100755 --- a/evosax/__init__.py +++ b/evosax/__init__.py @@ -37,6 +37,7 @@ HillClimber, EvoTF_ES, DiffusionEvolution, + SV_CMA_ES ) from .core import FitnessShaper, ParameterReshaper from .utils import ESLog @@ -82,6 +83,7 @@ "HillClimber": HillClimber, "EvoTF_ES": EvoTF_ES, "DiffusionEvolution": DiffusionEvolution, + "SV_CMA_ES": SV_CMA_ES, } __all__ = [ @@ -131,4 +133,5 @@ "HillClimber", "EvoTF_ES", "DiffusionEvolution", + "SV_CMA_ES" ] diff --git a/evosax/strategies/__init__.py b/evosax/strategies/__init__.py index 5c88eeb..dd15329 100755 --- a/evosax/strategies/__init__.py +++ b/evosax/strategies/__init__.py @@ -35,6 +35,7 @@ from .hill_climber import HillClimber from .evotf_es import EvoTF_ES from .diffusion import DiffusionEvolution +from .sv_cma_es import SV_CMA_ES __all__ = [ "SimpleGA", @@ -74,4 +75,5 @@ "HillClimber", "EvoTF_ES", "DiffusionEvolution", + "SV_CMA_ES" ] diff --git a/evosax/strategies/sv_cma_es.py b/evosax/strategies/sv_cma_es.py new file mode 100644 index 0000000..cadad64 --- /dev/null +++ b/evosax/strategies/sv_cma_es.py @@ -0,0 +1,357 @@ +from typing import Optional + +import jax +import jax.numpy as jnp +from chex import Array, ArrayTree, PRNGKey +from flax.struct import dataclass + +from evosax.strategies.cma_es import get_cma_elite_weights, update_p_c, update_p_sigma, sample, update_sigma, update_covariance, EvoParams, CMA_ES +from evosax.utils.eigen_decomp import full_eigen_decomp + + +class Kernel: + def __call__(self, x1: Array, x2: Array, bandwidth: float) -> Array: + pass + 
+class RBF(Kernel): + def __call__(self, x1: Array, x2: Array, bandwidth: float) -> Array: + return jnp.exp(-0.5 * jnp.sum((x1 - x2) ** 2) / bandwidth) + + +@dataclass +class EvoState: + p_sigma: Array + p_c: Array + C: Array + D: Optional[Array] + B: Optional[Array] + mean: Array + sigma: Array + weights: Array + weights_truncated: Array + best_member: Array + best_fitness: float = jnp.finfo(jnp.float32).max + gen_counter: int = 0 + bandwidth: float = 1. + alpha: float = 1. + + +class SV_CMA_ES(CMA_ES): + def __init__( + self, + npop: int, + subpopsize: int, + kernel: Kernel, + num_dims: Optional[int] = None, + pholder_params: Optional[ArrayTree | Array] = None, + elite_ratio: float = 0.5, + sigma_init: float = 1.0, + mean_decay: float = 0.0, + n_devices: Optional[int] = None, + **fitness_kwargs: bool | int | float + ): + """Stein Variational CMA-ES (Braun et al., 2024) + Reference: https://arxiv.org/abs/2410.10390""" + self.npop = npop + self.subpopsize = subpopsize + popsize = int(npop * subpopsize) + super().__init__( + popsize, + num_dims, + pholder_params, + elite_ratio, + sigma_init, + mean_decay, + n_devices, + **fitness_kwargs + ) + self.elite_popsize = max(1, int(self.subpopsize * self.elite_ratio)) + self.strategy_name = "SV_CMA_ES" + self.kernel = kernel + + def initialize_strategy( + self, rng: PRNGKey, params: EvoParams + ) -> EvoState: + """`initialize` the evolution strategy.""" + weights, weights_truncated, _, _, _ = get_cma_elite_weights( + self.subpopsize, self.elite_popsize, self.num_dims, self.max_dims_sq + ) + # Initialize evolution paths & covariance matrix + initialization = jax.random.uniform( + rng, + (self.npop, self.num_dims), + minval=params.init_min, + maxval=params.init_max, + ) + + state = EvoState( + p_sigma=jnp.zeros((self.npop, self.num_dims)), + p_c=jnp.zeros((self.npop, self.num_dims)), + sigma=jnp.ones(self.npop) * params.sigma_init, + mean=initialization, + C=jnp.tile(jnp.eye(self.num_dims), (self.npop, 1, 1)), + D=None, + 
B=None, + weights=weights, + weights_truncated=weights_truncated, + best_member=initialization[0], # Take any random member of the means + ) + return state + + def ask_strategy( + self, rng: PRNGKey, state: EvoState, params: EvoParams + ) -> [Array, EvoState]: + """`ask` for new parameter candidates to evaluate next.""" + Cs, Bs, Ds = jax.vmap(full_eigen_decomp, (0, 0, 0, None))( + state.C, state.B, state.D, state.gen_counter + ) + keys = jax.random.split(rng, num=self.npop) + x = jax.vmap(sample, (0, 0, 0, 0, 0, None, None))( + keys, + state.mean, + state.sigma, + Bs, + Ds, + self.num_dims, + self.subpopsize, + ) + + # Reshape for evaluation + x = x.reshape(self.popsize, self.num_dims) + + return x, state.replace(C=Cs, B=Bs, D=Ds) + + def tell_strategy( + self, + x: Array, + fitness: Array, + state: EvoState, + params: EvoParams, + ) -> EvoState: + """`tell` performance data for strategy state update.""" + x = x.reshape(self.npop, self.subpopsize, self.num_dims) + fitness = fitness.reshape(self.npop, self.subpopsize) + + # Compute grads + y_ks, y_ws = jax.vmap(cmaes_grad, (0, 0, 0, 0, None))( + x, + fitness, + state.mean, + state.sigma, + state.weights_truncated + ) + + # Compute kernel grads + bandwidth = state.bandwidth + kernel_grads = jax.vmap( + lambda xi: jnp.mean( + jax.vmap(lambda xj: jax.grad(self.kernel)(xj, xi, bandwidth))(state.mean), + axis=0 + ) + )(state.mean) + + # Update means using the kernel gradients + alpha = state.alpha + projected_steps = y_ws + alpha * kernel_grads / state.sigma[:, None] + means = state.mean + params.c_m * state.sigma[:, None] * projected_steps + + # Search distribution updates + p_sigmas, C_2s, Cs, Bs, Ds = jax.vmap(update_p_sigma, (0, 0, 0, 0, 0, None, None, None))( + state.C, + state.B, + state.D, + state.p_sigma, + projected_steps, + params.c_sigma, + params.mu_eff, + state.gen_counter, + ) + + p_cs, norms_p_sigma, h_sigmas = jax.vmap(update_p_c, (0, 0, 0, None, 0, None, None, None, None))( + means, + p_sigmas, + 
state.p_c, + state.gen_counter + 1, + projected_steps, + params.c_sigma, + params.chi_n, + params.c_c, + params.mu_eff, + ) + + Cs = jax.vmap(update_covariance, (0, 0, 0, 0, 0, 0, None, None, None, None))( + means, + p_cs, + Cs, + y_ks, + h_sigmas, + C_2s, + state.weights, + params.c_c, + params.c_1, + params.c_mu + ) + + sigmas = jax.vmap(update_sigma, (0, 0, None, None, None))( + state.sigma, + norms_p_sigma, + params.c_sigma, + params.d_sigma, + params.chi_n, + ) + + return state.replace( + mean=means, p_sigma=p_sigmas, C=Cs, B=Bs, D=Ds, p_c=p_cs, sigma=sigmas + ) + + +def cmaes_grad( + x: Array, + fitness: Array, + mean: Array, + sigma: float, + weights_truncated: Array, +) -> [Array, Array]: + """Approximate gradient using samples from a search distribution.""" + # get sorted solutions + concat_p_f = jnp.hstack([jnp.expand_dims(fitness, 1), x]) + sorted_solutions = concat_p_f[concat_p_f[:, 0].argsort()] + # get the scores + x_k = sorted_solutions[:, 1:] # ~ N(m, σ^2 C) + y_k = (x_k - mean) / sigma # ~ N(0, C) + grad = jnp.dot(weights_truncated.T, y_k) # y_w can be seen as score estimate of CMA-ES + + return y_k, grad + + +if __name__ == "__main__": + from typing import Optional + + import numpy as np + import matplotlib.pyplot as plt + + + def KDE(kernel: Kernel, modes: Array, weights: Optional[Array] = None, bandwidth: float = 1.): + """Kernel density estimation.""" + if weights is None: + weights = jnp.ones(modes.shape[0]) / modes.shape[0] + return lambda xi: jnp.sum( + jax.vmap(kernel, in_axes=(None, 0, None))(xi, modes, bandwidth) * weights + ) + + + def plot_pdf(pdf, ax: Optional = None, xmin: float = -3., xmax: float = 3.): + x = np.linspace(xmin, xmax, 50) + x_ = np.stack(np.meshgrid(x, x), axis=-1).reshape(-1, 2) + energies = pdf(x_) + + plt.figure(figsize=(7, 7)) + if ax is not None: + ax.contourf(x, x, energies.reshape(50, 50), levels=20) # , cmap="Greys" + else: + plt.contourf(x, x, energies.reshape(50, 50), levels=20) # , cmap="Greys" + + + def 
plot_particles_pdf(x, objective, score_fn, npop: int, xmin: float = -3., xmax: float = 3.): + plot_pdf(objective, score_fn, xmin, xmax) + plt.scatter(*x.T, color="salmon") + # for xi in x.reshape(npop, -1, 2): + # plt.scatter(*xi.T) + plt.xlim(xmin, xmax) + plt.ylim(xmin, xmax) + # plt.show() + + + class Benchmark: + def __init__(self, lb: Array | float, ub: Array | float, dim: int, fglob: float, name: str) -> None: + self.lower_bounds = lb + self.upper_bounds = ub + self.dim = dim + self.fglob = fglob + self.name = name + + def get_objective_derivative(self): + pass + + def get_objective(self): + return self.get_objective_derivative()[0] + + def plot(self, x: Array, lb: Array, ub: Array): + """Plotting code. Works for synthetic benchmarks.""" + plot_particles_pdf( + x[0], + lambda y: jnp.exp(-self.get_objective()(y)), + None, + 1, + lb, + ub + ) + + + class GMM(Benchmark): + def __init__( + self, + rng: PRNGKey, + lb: float = -6., + ub: float = 6., + kernel_rad: float = 1., + n_modes: int = 4, + dim: int = 2, + name: str = "GMM" + ) -> None: + fglob = -1 / (kernel_rad * (2 ** dim + 1)) # pdf = height * width * npeaks != 1 solves to this + super().__init__(lb, ub, dim, fglob, name) + self.kernel_rad = kernel_rad + + # Instantiate problem + rng_w, rng_m = jax.random.split(rng) + self.weights = jax.random.uniform(rng_w, (n_modes,), minval=0., maxval=10.) 
+ self.weights /= jnp.sum(self.weights) + self.modes = jax.random.uniform(rng_m, (n_modes, dim), minval=lb + 2, + maxval=ub - 2) # Add some slack beyond the bounds so they do not all overlap + + def get_objective_derivative(self): + """Return the objective and its derivative functions.""" + eval_fn = jax.jit(lambda x: -jnp.log(KDE(RBF(), self.modes, self.weights, self.kernel_rad)(x))) + return jax.vmap(eval_fn), jax.vmap(jax.grad(lambda x: -eval_fn(x))) + + # Benchmark + rng = jax.random.PRNGKey(2) + rng, init_rng, sample_rng = jax.random.split(rng, 3) + dim = 2 + bench = GMM(init_rng, dim=dim, n_modes=4, lb=-6., ub=6.) + n_iter = 1_000 + npop = 100 + popsize = 4 + cb_freq = 50 + + def plot_cb(x): + bench.plot((x, None), bench.lower_bounds - 2, bench.upper_bounds + 2) + plt.show() + + + rng, rng_init, rng_sample = jax.random.split(rng, 3) + strategy = SV_CMA_ES(npop=npop, subpopsize=popsize, kernel=RBF(), num_dims=bench.dim, elite_ratio=0.5, sigma_init=.05) + es_params = strategy.default_params.replace( + init_min=bench.lower_bounds, + init_max=bench.upper_bounds, + clip_min=bench.lower_bounds - 2, + clip_max=bench.upper_bounds + 2 + ) + state = strategy.initialize(rng_init, es_params) + state = state.replace(alpha=1., bandwidth=.5) + + # Get objective + objective_fn, score_fn = bench.get_objective_derivative() + + samples = [] + for t in range(n_iter): + rng, rng_gen = jax.random.split(rng) + x, state = strategy.ask(rng_gen, state, es_params) + fitness = objective_fn(x) # Evaluate score for gradient-based SVGD + state = strategy.tell(x, fitness, state, es_params) + + if t % cb_freq == 0: + print(t + 1, fitness.min()) + if plot_cb: + plot_cb(state.mean) From 7b01152c4962de19c0e9484f965a3e45fb6ece1e Mon Sep 17 00:00:00 2001 From: Cornelius Braun Date: Fri, 18 Oct 2024 15:25:53 +0200 Subject: [PATCH 2/4] include sv-cma-es in tests --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index 38182b0..c62f1e5 
100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,6 +42,7 @@ def pytest_generate_tests(metafunc): "HillClimber", "EvoTF_ES", "DiffusionEvolution", + "SV_CMA_ES", ], ) else: From 90d8411e73a202f5f44e0b8012ec4da509d563a8 Mon Sep 17 00:00:00 2001 From: Cornelius Braun Date: Fri, 18 Oct 2024 15:50:17 +0200 Subject: [PATCH 3/4] adapted tests for sv_cma_es --- evosax/strategies/sv_cma_es.py | 4 ++-- tests/test_strategy_api.py | 10 ++++++++-- tests/test_strategy_run.py | 10 ++++++++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/evosax/strategies/sv_cma_es.py b/evosax/strategies/sv_cma_es.py index cadad64..ada0d4f 100644 --- a/evosax/strategies/sv_cma_es.py +++ b/evosax/strategies/sv_cma_es.py @@ -41,7 +41,7 @@ def __init__( self, npop: int, subpopsize: int, - kernel: Kernel, + kernel: Kernel = RBF, num_dims: Optional[int] = None, pholder_params: Optional[ArrayTree | Array] = None, elite_ratio: float = 0.5, @@ -67,7 +67,7 @@ def __init__( ) self.elite_popsize = max(1, int(self.subpopsize * self.elite_ratio)) self.strategy_name = "SV_CMA_ES" - self.kernel = kernel + self.kernel = kernel() def initialize_strategy( self, rng: PRNGKey, params: EvoParams diff --git a/tests/test_strategy_api.py b/tests/test_strategy_api.py index 3d7318e..eb7ee74 100644 --- a/tests/test_strategy_api.py +++ b/tests/test_strategy_api.py @@ -10,7 +10,10 @@ def test_strategy_ask(strategy_name): popsize = 21 else: popsize = 20 - strategy = Strategies[strategy_name](popsize=popsize, num_dims=2) + if strategy_name == "SV_CMA_ES": + strategy = Strategies[strategy_name](npop=1, subpopsize=popsize, num_dims=2) + else: + strategy = Strategies[strategy_name](popsize=popsize, num_dims=2) params = strategy.default_params state = strategy.initialize(rng, params) x, state = strategy.ask(rng, state, params) @@ -26,7 +29,10 @@ def test_strategy_ask_tell(strategy_name): popsize = 21 else: popsize = 20 - strategy = Strategies[strategy_name](popsize=popsize, num_dims=2) + if 
strategy_name == "SV_CMA_ES": + strategy = Strategies[strategy_name](npop=1, subpopsize=popsize, num_dims=2) + else: + strategy = Strategies[strategy_name](popsize=popsize, num_dims=2) params = strategy.default_params state = strategy.initialize(rng, params) x, state = strategy.ask(rng, state, params) diff --git a/tests/test_strategy_run.py b/tests/test_strategy_run.py index a3d016a..16549f5 100644 --- a/tests/test_strategy_run.py +++ b/tests/test_strategy_run.py @@ -17,11 +17,14 @@ def test_strategy_run(strategy_name): popsize = 21 else: popsize = 20 + if strategy_name == "SV_CMA_ES": + strategy = Strat(npop=1, subpopsize=popsize, num_dims=2) + else: + strategy = Strat(popsize=popsize, num_dims=2) evaluator = BBOBFitness("Sphere", 2) fitness_shaper = FitnessShaper() batch_eval = evaluator.rollout - strategy = Strat(popsize=popsize, num_dims=2) params = strategy.default_params state = strategy.initialize(rng, params) @@ -46,11 +49,14 @@ def test_strategy_scan(strategy_name): popsize = 21 else: popsize = 20 + if strategy_name == "SV_CMA_ES": + strategy = Strat(npop=1, subpopsize=popsize, num_dims=2) + else: + strategy = Strat(popsize=popsize, num_dims=2) evaluator = BBOBFitness("Sphere", 2) fitness_shaper = FitnessShaper() batch_eval = evaluator.rollout - strategy = Strat(popsize=popsize, num_dims=2) es_params = strategy.default_params @partial(jax.jit, static_argnums=(1,)) From 2fdd0d3bd68a8a7d063f3a0e119af2dfb79d3c29 Mon Sep 17 00:00:00 2001 From: Cornelius Braun Date: Mon, 21 Oct 2024 14:21:21 +0200 Subject: [PATCH 4/4] Add SVOpenES --- README.md | 78 +++++++-------- evosax/__init__.py | 7 +- evosax/strategies/__init__.py | 4 +- evosax/strategies/sv_cma_es.py | 147 +--------------------------- evosax/strategies/sv_open_es.py | 167 ++++++++++++++++++++++++++++++++ evosax/utils/__init__.py | 4 + evosax/utils/kernel.py | 16 +++ tests/test_strategy_run.py | 4 +- 8 files changed, 240 insertions(+), 187 deletions(-) create mode 100644 evosax/strategies/sv_open_es.py 
create mode 100644 evosax/utils/kernel.py diff --git a/README.md b/README.md index 09e5529..277794c 100755 --- a/README.md +++ b/README.md @@ -32,44 +32,46 @@ state.best_member, state.best_fitness ## Implemented Evolution Strategies 🦎 -| Strategy | Reference | Import | Example | -| --- | --- | --- | --- | -| OpenAI-ES | [Salimans et al. (2017)](https://arxiv.org/pdf/1703.03864.pdf) | [`OpenES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/open_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb) -| PGPE | [Sehnke et al. (2010)](https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=A64D1AE8313A364B814998E9E245B40A?doi=10.1.1.180.7104&rep=rep1&type=pdf) | [`PGPE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pgpe.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/02_mlp_control.ipynb) -| ARS | [Mania et al. (2018)](https://arxiv.org/pdf/1803.07055.pdf) | [`ARS`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ars.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/00_getting_started.ipynb) -| ESMC | [Merchant et al. (2021)](https://proceedings.mlr.press/v139/merchant21a.html) | [`ESMC`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/esmc.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Persistent ES | [Vicol et al. 
(2021)](http://proceedings.mlr.press/v139/vicol21a.html) | [`PersistentES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/persistent_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_lrate_pes.ipynb) -| Noise-Reuse ES | [Li et al. (2023)](https://arxiv.org/pdf/2304.12180.pdf) | [`NoiseReuseES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/noise_reuse_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_lrate_pes.ipynb) -| xNES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`XNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/xnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| SNES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`SNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sxnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| CR-FM-NES | [Nomura & Ono (2022)](https://arxiv.org/abs/2201.11422) | [`CR_FM_NES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cr_fm_nes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Guided ES | [Maheswaranathan et al. 
(2018)](https://arxiv.org/abs/1806.10230) | [`GuidedES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/guided_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| ASEBO | [Choromanski et al. (2019)](https://arxiv.org/abs/1903.04268) | [`ASEBO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/asebo.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| CMA-ES | [Hansen & Ostermeier (2001)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf) | [`CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Sep-CMA-ES | [Ros & Hansen (2008)](https://hal.inria.fr/inria-00287367/document) | [`Sep_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sep_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| BIPOP-CMA-ES | [Hansen (2009)](https://hal.inria.fr/inria-00382093/document) | [`BIPOP_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/bipop_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb) -| IPOP-CMA-ES | [Auer & Hansen (2005)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cec2005ipopcmaes.pdf) | [`IPOP_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ipop_cma_es.py) | 
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb) -| Full-iAMaLGaM | [Bosman et al. (2013)](https://tinyurl.com/y9fcccx2) | [`Full_iAMaLGaM`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/full_iamalgam.py) |[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Independent-iAMaLGaM | [Bosman et al. (2013)](https://tinyurl.com/y9fcccx2) | [`Indep_iAMaLGaM`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/indep_iamalgam.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| MA-ES | [Bayer & Sendhoff (2017)](https://www.honda-ri.de/pubs/pdf/3376.pdf) | [`MA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| LM-MA-ES | [Loshchilov et al. (2017)](https://arxiv.org/pdf/1705.06693.pdf) | [`LM_MA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/lm_ma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| RmES | [Li & Zhang (2017)](https://ieeexplore.ieee.org/document/8080257) | [`RmES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/rm_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Simple Genetic | [Such et al. 
(2017)](https://arxiv.org/abs/1712.06567) | [`SimpleGA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| SAMR-GA | [Clune et al. (2008)](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1000187) | [`SAMR_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/samr_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| GESMR-GA | [Kumar et al. (2022)](https://arxiv.org/abs/2204.04817) | [`GESMR_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gesmr_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| MR15-GA | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`MR15_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/mr15_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| LGA | [Lange et al. 
(2023b)](https://arxiv.org/abs/2304.03995) | [`LGA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/lga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Simple Gaussian | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`SimpleES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| DES | [Lange et al. (2023a)](https://arxiv.org/abs/2211.11260) | [`DES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/des.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| LES | [Lange et al. (2023a)](https://arxiv.org/abs/2211.11260) | [`LES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/les.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| EvoTF | [Lange et al. (2024)](https://arxiv.org/abs/2403.02985) | [`EvoTF_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/evotf_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Diffusion Evolution | [Zhang et al. 
(2024)](https://arxiv.org/pdf/2410.02543) | [`DiffusionEvolution`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/diffusion.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Particle Swarm Optimization | [Kennedy & Eberhart (1995)](https://ieeexplore.ieee.org/document/488968) | [`PSO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pso.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Differential Evolution | [Storn & Price (1997)](https://www.metabolic-economics.de/pages/seminar_theoretische_biologie_2007/literatur/schaber/Storn1997JGlobOpt11.pdf) | [`DE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/de.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| GLD | [Golovin et al. (2019)](https://arxiv.org/pdf/1911.06317.pdf) | [`GLD`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gld.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Simulated Annealing | [Rasdi Rere et al. (2015)](https://www.sciencedirect.com/science/article/pii/S1877050915035759) | [`SimAnneal`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sim_anneal.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Population-Based Training | [Jaderberg et al. 
(2017)](https://arxiv.org/abs/1711.09846) | [`PBT`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pbt.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb) -| Random Search | [Bergstra & Bengio (2012)](https://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf) | [`RandomSearch`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/random.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Strategy | Reference | Import | Example | +|-----------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------| --- | +| OpenAI-ES | [Salimans et al. (2017)](https://arxiv.org/pdf/1703.03864.pdf) | [`OpenES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/open_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb) +| PGPE | [Sehnke et al. (2010)](https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=A64D1AE8313A364B814998E9E245B40A?doi=10.1.1.180.7104&rep=rep1&type=pdf) | [`PGPE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pgpe.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/02_mlp_control.ipynb) +| ARS | [Mania et al. 
(2018)](https://arxiv.org/pdf/1803.07055.pdf) | [`ARS`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ars.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/00_getting_started.ipynb) +| ESMC | [Merchant et al. (2021)](https://proceedings.mlr.press/v139/merchant21a.html) | [`ESMC`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/esmc.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Persistent ES | [Vicol et al. (2021)](http://proceedings.mlr.press/v139/vicol21a.html) | [`PersistentES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/persistent_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_lrate_pes.ipynb) +| Noise-Reuse ES | [Li et al. (2023)](https://arxiv.org/pdf/2304.12180.pdf) | [`NoiseReuseES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/noise_reuse_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_lrate_pes.ipynb) +| xNES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`XNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/xnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| SNES | [Wierstra et al. 
(2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`SNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sxnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| CR-FM-NES | [Nomura & Ono (2022)](https://arxiv.org/abs/2201.11422) | [`CR_FM_NES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cr_fm_nes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Guided ES | [Maheswaranathan et al. (2018)](https://arxiv.org/abs/1806.10230) | [`GuidedES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/guided_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| ASEBO | [Choromanski et al. 
(2019)](https://arxiv.org/abs/1903.04268) | [`ASEBO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/asebo.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| CMA-ES | [Hansen & Ostermeier (2001)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf) | [`CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Sep-CMA-ES | [Ros & Hansen (2008)](https://hal.inria.fr/inria-00287367/document) | [`Sep_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sep_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| BIPOP-CMA-ES | [Hansen (2009)](https://hal.inria.fr/inria-00382093/document) | [`BIPOP_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/bipop_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb) +| IPOP-CMA-ES | [Auer & Hansen (2005)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cec2005ipopcmaes.pdf) | [`IPOP_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ipop_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb) +| Full-iAMaLGaM | [Bosman et al. 
(2013)](https://tinyurl.com/y9fcccx2) | [`Full_iAMaLGaM`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/full_iamalgam.py) |[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Independent-iAMaLGaM | [Bosman et al. (2013)](https://tinyurl.com/y9fcccx2) | [`Indep_iAMaLGaM`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/indep_iamalgam.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| MA-ES | [Bayer & Sendhoff (2017)](https://www.honda-ri.de/pubs/pdf/3376.pdf) | [`MA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| LM-MA-ES | [Loshchilov et al. (2017)](https://arxiv.org/pdf/1705.06693.pdf) | [`LM_MA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/lm_ma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| RmES | [Li & Zhang (2017)](https://ieeexplore.ieee.org/document/8080257) | [`RmES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/rm_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Simple Genetic | [Such et al. 
(2017)](https://arxiv.org/abs/1712.06567) | [`SimpleGA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| SAMR-GA | [Clune et al. (2008)](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1000187) | [`SAMR_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/samr_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| GESMR-GA | [Kumar et al. (2022)](https://arxiv.org/abs/2204.04817) | [`GESMR_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gesmr_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| MR15-GA | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`MR15_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/mr15_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| LGA | [Lange et al. 
(2023b)](https://arxiv.org/abs/2304.03995) | [`LGA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/lga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Simple Gaussian | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`SimpleES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| DES | [Lange et al. (2023a)](https://arxiv.org/abs/2211.11260) | [`DES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/des.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| LES | [Lange et al. (2023a)](https://arxiv.org/abs/2211.11260) | [`LES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/les.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| EvoTF | [Lange et al. (2024)](https://arxiv.org/abs/2403.02985) | [`EvoTF_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/evotf_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Diffusion Evolution | [Zhang et al. 
(2024)](https://arxiv.org/pdf/2410.02543) | [`DiffusionEvolution`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/diffusion.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| SV-OpenAI-ES | [Liu et al. (2017)](https://arxiv.org/abs/1704.02399) | [`SV_OpenES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sv_open_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| SV-CMA-ES | [Braun et al. (2024)](https://arxiv.org/abs/2410.10390) | [`SV_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sv_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Particle Swarm Optimization | [Kennedy & Eberhart (1995)](https://ieeexplore.ieee.org/document/488968) | [`PSO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pso.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Differential Evolution | [Storn & Price (1997)](https://www.metabolic-economics.de/pages/seminar_theoretische_biologie_2007/literatur/schaber/Storn1997JGlobOpt11.pdf) | [`DE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/de.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| GLD | [Golovin et al. 
(2019)](https://arxiv.org/pdf/1911.06317.pdf) | [`GLD`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gld.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Simulated Annealing | [Rasdi Rere et al. (2015)](https://www.sciencedirect.com/science/article/pii/S1877050915035759) | [`SimAnneal`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sim_anneal.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Population-Based Training | [Jaderberg et al. (2017)](https://arxiv.org/abs/1711.09846) | [`PBT`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pbt.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb) +| Random Search | [Bergstra & Bengio (2012)](https://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf) | [`RandomSearch`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/random.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) diff --git a/evosax/__init__.py b/evosax/__init__.py index 4dc3e32..fa900e7 100755 --- a/evosax/__init__.py +++ b/evosax/__init__.py @@ -37,7 +37,8 @@ HillClimber, EvoTF_ES, DiffusionEvolution, - SV_CMA_ES + SV_CMA_ES, + SV_OpenES ) from .core import FitnessShaper, ParameterReshaper from .utils import ESLog @@ -84,6 +85,7 @@ "EvoTF_ES": EvoTF_ES, "DiffusionEvolution": DiffusionEvolution, "SV_CMA_ES": SV_CMA_ES, + "SV_OpenES": SV_OpenES, } __all__ = [ @@ -133,5 +135,6 @@ "HillClimber", "EvoTF_ES", "DiffusionEvolution", - "SV_CMA_ES" + "SV_CMA_ES", 
+ "SV_OpenES" ] diff --git a/evosax/strategies/__init__.py b/evosax/strategies/__init__.py index dd15329..f0326e8 100755 --- a/evosax/strategies/__init__.py +++ b/evosax/strategies/__init__.py @@ -36,6 +36,7 @@ from .evotf_es import EvoTF_ES from .diffusion import DiffusionEvolution from .sv_cma_es import SV_CMA_ES +from .sv_open_es import SV_OpenES __all__ = [ "SimpleGA", @@ -75,5 +76,6 @@ "HillClimber", "EvoTF_ES", "DiffusionEvolution", - "SV_CMA_ES" + "SV_CMA_ES", + "SV_OpenES" ] diff --git a/evosax/strategies/sv_cma_es.py b/evosax/strategies/sv_cma_es.py index ada0d4f..9843390 100644 --- a/evosax/strategies/sv_cma_es.py +++ b/evosax/strategies/sv_cma_es.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type import jax import jax.numpy as jnp @@ -7,15 +7,7 @@ from evosax.strategies.cma_es import get_cma_elite_weights, update_p_c, update_p_sigma, sample, update_sigma, update_covariance, EvoParams, CMA_ES from evosax.utils.eigen_decomp import full_eigen_decomp - - -class Kernel: - def __call__(self, x1: Array, x2: Array, bandwidth: float) -> Array: - pass - -class RBF(Kernel): - def __call__(self, x1: Array, x2: Array, bandwidth: float) -> Array: - return jnp.exp(-0.5 * jnp.sum((x1 - x2) ** 2) / bandwidth) +from evosax.utils.kernel import Kernel, RBF @dataclass @@ -41,7 +33,7 @@ def __init__( self, npop: int, subpopsize: int, - kernel: Kernel = RBF, + kernel: Type[Kernel] = RBF, num_dims: Optional[int] = None, pholder_params: Optional[ArrayTree | Array] = None, elite_ratio: float = 0.5, @@ -222,136 +214,3 @@ def cmaes_grad( grad = jnp.dot(weights_truncated.T, y_k) # y_w can be seen as score estimate of CMA-ES return y_k, grad - - -if __name__ == "__main__": - from typing import Optional - - import numpy as np - import matplotlib.pyplot as plt - - - def KDE(kernel: Kernel, modes: Array, weights: Optional[Array] = None, bandwidth: float = 1.): - """Kernel density estimation.""" - if weights is None: - weights = jnp.ones(modes.shape[0]) / 
modes.shape[0] - return lambda xi: jnp.sum( - jax.vmap(kernel, in_axes=(None, 0, None))(xi, modes, bandwidth) * weights - ) - - - def plot_pdf(pdf, ax: Optional = None, xmin: float = -3., xmax: float = 3.): - x = np.linspace(xmin, xmax, 50) - x_ = np.stack(np.meshgrid(x, x), axis=-1).reshape(-1, 2) - energies = pdf(x_) - - plt.figure(figsize=(7, 7)) - if ax is not None: - ax.contourf(x, x, energies.reshape(50, 50), levels=20) # , cmap="Greys" - else: - plt.contourf(x, x, energies.reshape(50, 50), levels=20) # , cmap="Greys" - - - def plot_particles_pdf(x, objective, score_fn, npop: int, xmin: float = -3., xmax: float = 3.): - plot_pdf(objective, score_fn, xmin, xmax) - plt.scatter(*x.T, color="salmon") - # for xi in x.reshape(npop, -1, 2): - # plt.scatter(*xi.T) - plt.xlim(xmin, xmax) - plt.ylim(xmin, xmax) - # plt.show() - - - class Benchmark: - def __init__(self, lb: Array | float, ub: Array | float, dim: int, fglob: float, name: str) -> None: - self.lower_bounds = lb - self.upper_bounds = ub - self.dim = dim - self.fglob = fglob - self.name = name - - def get_objective_derivative(self): - pass - - def get_objective(self): - return self.get_objective_derivative()[0] - - def plot(self, x: Array, lb: Array, ub: Array): - """Plotting code. Works for synthetic benchmarks.""" - plot_particles_pdf( - x[0], - lambda y: jnp.exp(-self.get_objective()(y)), - None, - 1, - lb, - ub - ) - - - class GMM(Benchmark): - def __init__( - self, - rng: PRNGKey, - lb: float = -6., - ub: float = 6., - kernel_rad: float = 1., - n_modes: int = 4, - dim: int = 2, - name: str = "GMM" - ) -> None: - fglob = -1 / (kernel_rad * (2 ** dim + 1)) # pdf = height * width * npeaks != 1 solves to this - super().__init__(lb, ub, dim, fglob, name) - self.kernel_rad = kernel_rad - - # Instantiate problem - rng_w, rng_m = jax.random.split(rng) - self.weights = jax.random.uniform(rng_w, (n_modes,), minval=0., maxval=10.) 
- self.weights /= jnp.sum(self.weights) - self.modes = jax.random.uniform(rng_m, (n_modes, dim), minval=lb + 2, - maxval=ub - 2) # Add some slack beyond the bounds so they do not all overlap - - def get_objective_derivative(self): - """Return the objective and its derivative functions.""" - eval_fn = jax.jit(lambda x: -jnp.log(KDE(RBF(), self.modes, self.weights, self.kernel_rad)(x))) - return jax.vmap(eval_fn), jax.vmap(jax.grad(lambda x: -eval_fn(x))) - - # Benchmark - rng = jax.random.PRNGKey(2) - rng, init_rng, sample_rng = jax.random.split(rng, 3) - dim = 2 - bench = GMM(init_rng, dim=dim, n_modes=4, lb=-6., ub=6.) - n_iter = 1_000 - npop = 100 - popsize = 4 - cb_freq = 50 - - def plot_cb(x): - bench.plot((x, None), bench.lower_bounds - 2, bench.upper_bounds + 2) - plt.show() - - - rng, rng_init, rng_sample = jax.random.split(rng, 3) - strategy = SV_CMA_ES(npop=npop, subpopsize=popsize, kernel=RBF(), num_dims=bench.dim, elite_ratio=0.5, sigma_init=.05) - es_params = strategy.default_params.replace( - init_min=bench.lower_bounds, - init_max=bench.upper_bounds, - clip_min=bench.lower_bounds - 2, - clip_max=bench.upper_bounds + 2 - ) - state = strategy.initialize(rng_init, es_params) - state = state.replace(alpha=1., bandwidth=.5) - - # Get objective - objective_fn, score_fn = bench.get_objective_derivative() - - samples = [] - for t in range(n_iter): - rng, rng_gen = jax.random.split(rng) - x, state = strategy.ask(rng_gen, state, es_params) - fitness = objective_fn(x) # Evaluate score for gradient-based SVGD - state = strategy.tell(x, fitness, state, es_params) - - if t % cb_freq == 0: - print(t + 1, fitness.min()) - if plot_cb: - plot_cb(state.mean) diff --git a/evosax/strategies/sv_open_es.py b/evosax/strategies/sv_open_es.py new file mode 100644 index 0000000..5223c3f --- /dev/null +++ b/evosax/strategies/sv_open_es.py @@ -0,0 +1,167 @@ +from typing import Optional, Type, Union + +import jax +from flax import struct +from chex import Array, ArrayTree +import 
jax.numpy as jnp + +from evosax.core import OptState, exp_decay +from evosax.strategies.open_es import EvoParams, OpenES +from evosax.utils.kernel import RBF, Kernel + + +@struct.dataclass +class EvoState: + mean: Array + sigma: Array + opt_state: OptState + best_member: Array + best_fitness: float = jnp.finfo(jnp.float32).max + gen_counter: int = 0 + bandwidth: float = 1. + alpha: float = 1. + + +class SV_OpenES(OpenES): + def __init__( + self, + npop: int, + subpopsize: int, + kernel: Type[Kernel] = RBF, + num_dims: Optional[int] = None, + pholder_params: Optional[ArrayTree | Array] = None, + use_antithetic_sampling: bool = True, + opt_name: str = "adam", + lrate_init: float = 0.05, + lrate_decay: float = 1.0, + lrate_limit: float = 0.001, + sigma_init: float = 0.03, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + mean_decay: float = 0.0, + n_devices: Optional[int] = None, + **fitness_kwargs: Union[bool, int, float] + ): + """Stein Variational OpenAI-ES (Liu et al., 2017) + Reference: https://arxiv.org/abs/1704.02399""" + super().__init__( + npop * subpopsize, + num_dims, + pholder_params, + use_antithetic_sampling, + opt_name, + lrate_init, + lrate_decay, + lrate_limit, + sigma_init, + sigma_decay, + sigma_limit, + mean_decay, + n_devices, + **fitness_kwargs + ) + assert not subpopsize & 1, "Sub-population size must be even" + self.strategy_name = "SV_OpenAI_ES" + self.npop = npop + self.subpopsize = subpopsize + self.kernel = kernel() + + @property + def params_strategy(self) -> EvoParams: + """Return default parameters of evolution strategy.""" + opt_params = self.optimizer.default_params.replace( + lrate_init=self.lrate_init, + lrate_decay=self.lrate_decay, + lrate_limit=self.lrate_limit, + ) + return EvoParams( + opt_params=opt_params, + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) + + def initialize_strategy( + self, rng: jax.random.PRNGKey, params: EvoParams + ) -> EvoState: +
"""`initialize` the evolution strategy.""" + x_init = jax.random.uniform( + rng, + (self.npop, self.num_dims), + minval=params.init_min, + maxval=params.init_max + ) + state = EvoState( + mean=x_init, + sigma=jnp.ones((self.npop, self.num_dims)) * params.sigma_init, + opt_state=jax.vmap(lambda _: self.optimizer.initialize(params.opt_params))(jnp.arange(self.npop)), + best_member=x_init[0], # pholder best + ) + + return state + + def ask_strategy( + self, rng: jax.random.PRNGKey, state: EvoState, params: EvoParams + ) -> [Array, EvoState]: + """`ask` for new parameter candidates to evaluate next.""" + # Antithetic sampling of noise + if self.use_antithetic_sampling: + z_plus = jax.random.normal( + rng, + (self.npop, int(self.subpopsize / 2), self.num_dims), + ) + z = jnp.concatenate([z_plus, -1.0 * z_plus], axis=1) + else: + z = jax.random.normal(rng, (self.npop, self.subpopsize, self.num_dims)) + + x = state.mean[:, None] + state.sigma[:, None] * z + x = x.reshape(self.popsize, self.num_dims) + + return x, state + + def tell_strategy( + self, + x: Array, + fitness: Array, + state: EvoState, + params: EvoParams, + ) -> EvoState: + """`tell` performance data for strategy state update.""" + x = x.reshape(self.npop, self.subpopsize, self.num_dims) + fitness = fitness.reshape(self.npop, self.subpopsize) + + # Compute MC gradients from fitness scores + noise = (state.mean[:, None] - x) / state.sigma[:, None] + scores = jnp.einsum("ijk,ij->ik", noise, fitness) / (self.subpopsize * state.sigma) + + # Compute SVGD steps + svgd_scores = svgd_grad(state.mean, scores, self.kernel, state.bandwidth) + svgd_kerns = svgd_kern(state.mean, scores, self.kernel, state.bandwidth) + gradients = -(svgd_scores + svgd_kerns * state.alpha) # flip the grads for minimization + + # Grad update using optimizer instance - decay lrate if desired + mean, opt_state = jax.vmap(self.optimizer.step, (0, 0, 0, None))( + state.mean, gradients, state.opt_state, params.opt_params + ) + opt_state = 
jax.vmap(self.optimizer.update, (0, None))(opt_state, params.opt_params) + sigma = jax.vmap(exp_decay, (0, None, None))(state.sigma, params.sigma_decay, params.sigma_limit) + + return state.replace(mean=mean, sigma=sigma, opt_state=opt_state) + + +def svgd_kern(x: Array, scores: Array, kernel: Kernel, bandwidth: float) -> Array: + """SVGD repulsive force.""" + phi = lambda xi: jnp.mean( + jax.vmap(lambda xj, scorej: jax.grad(kernel)(xj, xi, bandwidth))(x, scores), + axis=0 + ) + return jax.vmap(phi)(x) + + +def svgd_grad(x: Array, scores: Array, kernel: Kernel, bandwidth: float) -> Array: + """SVGD driving force.""" + phi = lambda xi: jnp.mean( + jax.vmap(lambda xj, scorej: kernel(xj, xi, bandwidth) * scorej)(x, scores), + axis=0 + ) + return jax.vmap(phi)(x) diff --git a/evosax/utils/__init__.py b/evosax/utils/__init__.py index 13851ae..d00c38e 100755 --- a/evosax/utils/__init__.py +++ b/evosax/utils/__init__.py @@ -7,9 +7,13 @@ # 2D Fitness visualization tools from .visualizer_2d import BBOBVisualizer +# Kernels +from .kernel import Kernel, RBF __all__ = [ "get_best_fitness_member", "ESLog", "BBOBVisualizer", + "Kernel", + "RBF" ] diff --git a/evosax/utils/kernel.py b/evosax/utils/kernel.py new file mode 100644 index 0000000..85dc603 --- /dev/null +++ b/evosax/utils/kernel.py @@ -0,0 +1,16 @@ +from abc import abstractmethod, ABC + +import jax.numpy as jnp +from chex import Array + + +class Kernel(ABC): + @abstractmethod + def __call__(self, x1: Array, x2: Array, bandwidth: float) -> Array: + """Compute the kernel function between two input arrays.""" + pass + +class RBF(Kernel): + """Radial Basis Function (RBF) kernel implementation.""" + def __call__(self, x1: Array, x2: Array, bandwidth: float) -> Array: + return jnp.exp(-0.5 * jnp.sum((x1 - x2) ** 2) / bandwidth) diff --git a/tests/test_strategy_run.py b/tests/test_strategy_run.py index 16549f5..601694f 100644 --- a/tests/test_strategy_run.py +++ b/tests/test_strategy_run.py @@ -17,7 +17,7 @@ def 
test_strategy_run(strategy_name): popsize = 21 else: popsize = 20 - if strategy_name == "SV_CMA_ES": + if strategy_name in ["SV_CMA_ES", "SV_OpenES"]: strategy = Strat(npop=1, subpopsize=popsize, num_dims=2) else: strategy = Strat(popsize=popsize, num_dims=2) @@ -49,7 +49,7 @@ def test_strategy_scan(strategy_name): popsize = 21 else: popsize = 20 - if strategy_name == "SV_CMA_ES": + if strategy_name in ["SV_CMA_ES", "SV_OpenES"]: strategy = Strat(npop=1, subpopsize=popsize, num_dims=2) else: strategy = Strat(popsize=popsize, num_dims=2)