Skip to content

Commit

Permalink
Use shaped fitness instead of weights to avoid expensive population sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
maxencefaldor committed Mar 6, 2025
1 parent 49b9c7f commit b8bb602
Show file tree
Hide file tree
Showing 62 changed files with 990 additions and 632 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
flake8 ./evosax --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Run unit/integration tests
run: |
python -m pytest --all -vv --durations=0 --cov=./ --cov-report=term-missing --cov-report=xml
python -m pytest tests/ -vv --durations=0 --cov=./ --cov-report=term-missing --cov-report=xml
- name: "Upload coverage to Codecov"
uses: codecov/codecov-action@v2
with:
Expand Down
14 changes: 7 additions & 7 deletions evosax/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .distribution_based.bipop_cma_es import BIPOP_CMA_ES
from .distribution_based.cma_es import CMA_ES
from .distribution_based.cr_fm_nes import CR_FM_NES
from .distribution_based.des import DES
from .distribution_based.discovered_es import DiscoveredES
from .distribution_based.esmc import ESMC
from .distribution_based.evotf_es import EvoTF_ES
from .distribution_based.gradientless_descent import GradientlessDescent
Expand All @@ -29,7 +29,7 @@
from .distribution_based.iamalgam_full import iAMaLGaM_Full
from .distribution_based.iamalgam_univariate import iAMaLGaM_Univariate
from .distribution_based.ipop_cma_es import IPOP_CMA_ES
from .distribution_based.les import LES
from .distribution_based.learned_es import LearnedES
from .distribution_based.lm_ma_es import LM_MA_ES
from .distribution_based.ma_es import MA_ES
from .distribution_based.noise_reuse_es import NoiseReuseES
Expand All @@ -42,16 +42,16 @@
from .distribution_based.simple_es import SimpleES
from .distribution_based.simulated_annealing import SimulatedAnnealing
from .distribution_based.snes import SNES
from .distribution_based.sv_cma_es import SV_CMA_ES
from .distribution_based.sv_open_es import SV_Open_ES
from .distribution_based.sv.sv_cma_es import SV_CMA_ES
from .distribution_based.sv.sv_open_es import SV_Open_ES
from .distribution_based.xnes import xNES

# Population-based algorithms
from .population_based import population_based_algorithms
from .population_based.de import DE
from .population_based.diffusion import DiffusionEvolution
from .population_based.differential_evolution import DifferentialEvolution
from .population_based.diffusion_evolution import DiffusionEvolution
from .population_based.gesmr_ga import GESMR_GA
from .population_based.lga import LGA
from .population_based.learned_ga import LearnedGA
from .population_based.mr15_ga import MR15_GA
from .population_based.pso import PSO
from .population_based.samr_ga import SAMR_GA
Expand Down
4 changes: 2 additions & 2 deletions evosax/algorithms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def init(
state = self._init(key, params)
return state

# @partial(jax.jit, static_argnames=("self",))
@partial(jax.jit, static_argnames=("self",))
def ask(
self,
key: jax.Array,
Expand All @@ -111,7 +111,7 @@ def ask(

return population, state

# @partial(jax.jit, static_argnames=("self",))
@partial(jax.jit, static_argnames=("self",))
def tell(
self,
key: jax.Array,
Expand Down
12 changes: 6 additions & 6 deletions evosax/algorithms/distribution_based/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .bipop_cma_es import BIPOP_CMA_ES
from .cma_es import CMA_ES
from .cr_fm_nes import CR_FM_NES
from .des import DES
from .discovered_es import DiscoveredES
from .esmc import ESMC
from .evotf_es import EvoTF_ES
from .gradientless_descent import GradientlessDescent
Expand All @@ -14,7 +14,7 @@
from .iamalgam_full import iAMaLGaM_Full
from .iamalgam_univariate import iAMaLGaM_Univariate
from .ipop_cma_es import IPOP_CMA_ES
from .les import LES
from .learned_es import LearnedES
from .lm_ma_es import LM_MA_ES
from .ma_es import MA_ES
from .noise_reuse_es import NoiseReuseES
Expand All @@ -27,8 +27,8 @@
from .simple_es import SimpleES
from .simulated_annealing import SimulatedAnnealing
from .snes import SNES
from .sv_cma_es import SV_CMA_ES
from .sv_open_es import SV_Open_ES
from .sv.sv_cma_es import SV_CMA_ES
from .sv.sv_open_es import SV_Open_ES
from .xnes import xNES

distribution_based_algorithms = {
Expand All @@ -37,7 +37,7 @@
"BIPOP_CMA_ES": BIPOP_CMA_ES,
"CMA_ES": CMA_ES,
"CR_FM_NES": CR_FM_NES,
"DES": DES,
"DES": DiscoveredES,
"ESMC": ESMC,
"EvoTF_ES": EvoTF_ES,
"GradientlessDescent": GradientlessDescent,
Expand All @@ -46,7 +46,7 @@
"iAMaLGaM_Full": iAMaLGaM_Full,
"iAMaLGaM_Univariate": iAMaLGaM_Univariate,
"IPOP_CMA_ES": IPOP_CMA_ES,
"LES": LES,
"LES": LearnedES,
"LM_MA_ES": LM_MA_ES,
"MA_ES": MA_ES,
"NoiseReuseES": NoiseReuseES,
Expand Down
21 changes: 13 additions & 8 deletions evosax/algorithms/distribution_based/ars.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Augmented Random Search (Mania et al., 2018).
Reference: https://arxiv.org/abs/1803.07055
[1] https://arxiv.org/abs/1803.07055
"""

from collections.abc import Callable
Expand All @@ -10,8 +10,9 @@
import optax
from flax import struct

from ...core.fitness_shaping import identity_fitness_shaping_fn
from ...types import Fitness, Population, Solution
from evosax.core.fitness_shaping import identity_fitness_shaping_fn
from evosax.types import Fitness, Population, Solution

from .base import DistributionBasedAlgorithm, Params, State, metrics_fn


Expand All @@ -20,7 +21,6 @@ class State(State):
mean: jax.Array
std: float
opt_state: optax.OptState
z: jax.Array


@struct.dataclass
Expand Down Expand Up @@ -48,6 +48,11 @@ def __init__(
# Optimizer
self.optimizer = optimizer

@property
def num_elites(self):
    """Number of elite solutions kept per half-population: max(1, elite_ratio * population_size // 2)."""
    return max(1, int(self.elite_ratio * self.population_size // 2))

@property
def _default_params(self) -> Params:
return Params(std_init=1.0)
Expand All @@ -57,7 +62,6 @@ def _init(self, key: jax.Array, params: Params) -> State:
mean=jnp.full((self.num_dims,), jnp.nan),
std=params.std_init,
opt_state=self.optimizer.init(jnp.zeros(self.num_dims)),
z=jnp.zeros((self.population_size, self.num_dims)),
best_solution=jnp.full((self.num_dims,), jnp.nan),
best_fitness=jnp.inf,
generation_counter=0,
Expand All @@ -73,8 +77,9 @@ def _ask(
# Antithetic sampling
z_plus = jax.random.normal(key, (self.population_size // 2, self.num_dims))
z = jnp.concatenate([z_plus, -z_plus])

population = state.mean + state.std * z
return population, state.replace(z=z)
return population, state

def _tell(
self,
Expand All @@ -98,9 +103,9 @@ def _tell(
fitness_std = jnp.clip(jnp.std(fitness_elite), min=1e-8)

# Compute grad
z = state.z[: self.population_size // 2]
z = (population[: self.population_size // 2] - state.mean) / state.std
delta = fitness_plus[elite_idx] - fitness_minus[elite_idx]
grad = jnp.dot(z[elite_idx].T, delta) / (self.num_elites * fitness_std)
grad = jnp.dot(delta, z[elite_idx]) / (self.num_elites * fitness_std)

# Update mean
updates, opt_state = self.optimizer.update(grad, state.opt_state)
Expand Down
31 changes: 16 additions & 15 deletions evosax/algorithms/distribution_based/asebo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Adaptive ES-Active Subspaces for Blackbox Optimization (Choromanski et al., 2019).
Reference: https://arxiv.org/abs/1903.04268
[1] https://arxiv.org/abs/1903.04268
Note that there are a couple of adaptations:
1. We always sample a fixed population size per generation
Expand All @@ -14,8 +14,9 @@
import optax
from flax import struct

from ...core.fitness_shaping import identity_fitness_shaping_fn
from ...types import Fitness, Population, Solution
from evosax.core.fitness_shaping import identity_fitness_shaping_fn
from evosax.types import Fitness, Population, Solution

from .base import DistributionBasedAlgorithm, Params, State, metrics_fn


Expand Down Expand Up @@ -135,25 +136,23 @@ def _tell(
state: State,
params: Params,
) -> State:
# Reconstruct noise from last mean/std estimates
noise = (population - state.mean) / state.std
noise_1 = noise[: int(self.population_size / 2)]
fit_1 = fitness[: int(self.population_size / 2)]
fit_2 = fitness[int(self.population_size / 2) :]
fit_diff_noise = jnp.dot(noise_1.T, fit_1 - fit_2)
grad = 1.0 / 2.0 * fit_diff_noise
# Compute grad
fitness_plus = fitness[: self.population_size // 2]
fitness_minus = fitness[self.population_size // 2 :]
grad = 0.5 * jnp.dot(
fitness_plus - fitness_minus,
(population[: self.population_size // 2] - state.mean) / state.std,
)

alpha = jnp.linalg.norm(jnp.dot(grad, state.UUT_ort)) / jnp.linalg.norm(
jnp.dot(grad, state.UUT)
)
subspace_ready = state.generation_counter > self.subspace_dims
alpha = jax.lax.select(subspace_ready, alpha, 1.0)

# Add grad FIFO-style to subspace archive (only if provided else FD)
grad_subspace = jnp.zeros((self.subspace_dims, self.num_dims))
grad_subspace = grad_subspace.at[:-1, :].set(state.grad_subspace[1:, :])
# FIFO grad subspace (same as in guided_es.py)
grad_subspace = jnp.roll(state.grad_subspace, shift=-1, axis=0)
grad_subspace = grad_subspace.at[-1, :].set(grad)
state = state.replace(grad_subspace=grad_subspace)

# Normalize gradients by norm / num_dims
grad /= jnp.linalg.norm(grad) / self.num_dims + 1e-8
Expand All @@ -162,4 +161,6 @@ def _tell(
updates, opt_state = self.optimizer.update(grad, state.opt_state)
mean = optax.apply_updates(state.mean, updates)

return state.replace(mean=mean, opt_state=opt_state, alpha=alpha)
return state.replace(
mean=mean, opt_state=opt_state, grad_subspace=grad_subspace, alpha=alpha
)
11 changes: 7 additions & 4 deletions evosax/algorithms/distribution_based/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import jax.numpy as jnp
from flax import struct

from ...core.fitness_shaping import identity_fitness_shaping_fn
from ...types import Fitness, Metrics, Population, Solution
from evosax.core.fitness_shaping import identity_fitness_shaping_fn
from evosax.types import Fitness, Metrics, Population, Solution

from ..base import EvolutionaryAlgorithm, Params, State
from ..base import metrics_fn as base_metrics_fn

Expand All @@ -32,7 +33,10 @@ def metrics_fn(
) -> Metrics:
"""Compute metrics for distribution-based algorithm."""
metrics = base_metrics_fn(key, population, fitness, state, params)
return metrics | {"mean": state.mean, "mean_norm": jnp.linalg.norm(state.mean)}
return metrics | {
"mean": state.mean,
"mean_norm": jnp.linalg.norm(state.mean, axis=-1),
}


class DistributionBasedAlgorithm(EvolutionaryAlgorithm):
Expand All @@ -57,7 +61,6 @@ def init(
) -> State:
"""Initialize distribution-based algorithm."""
state = self._init(key, params)

state = state.replace(mean=self._ravel_solution(mean))
return state

Expand Down
9 changes: 5 additions & 4 deletions evosax/algorithms/distribution_based/bipop_cma_es.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""BIPOP-CMA-ES (Hansen, 2009).
Reference: https://hal.inria.fr/inria-00382093/document
Inspired by: https://tinyurl.com/44y3ryhf
[1] https://hal.inria.fr/inria-00382093/document
[2] https://tinyurl.com/44y3ryhf
"""

from collections.abc import Callable
Expand All @@ -10,14 +10,15 @@
import jax
from flax import struct

from ...core.fitness_shaping import identity_fitness_shaping_fn
from evosax.core.fitness_shaping import identity_fitness_shaping_fn
from evosax.types import Fitness, Population, Solution

from ...restarts.restarter import (
WrapperParams,
WrapperState,
cma_criterion,
spread_criterion,
)
from ...types import Fitness, Population, Solution
from .base import metrics_fn
from .cma_es import CMA_ES

Expand Down
34 changes: 16 additions & 18 deletions evosax/algorithms/distribution_based/cma_es.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Covariance Matrix Adaptation Evolution Strategy (Hansen et al., 2001).
Reference: https://arxiv.org/abs/1604.00772
Inspired by: https://github.com/CyberAgentAILab/cmaes
[1] https://arxiv.org/abs/1604.00772
[2] https://github.com/CyberAgentAILab/cmaes
"""

from collections.abc import Callable
Expand All @@ -10,8 +10,9 @@
import jax.numpy as jnp
from flax import struct

from ...core.fitness_shaping import identity_fitness_shaping_fn
from ...types import Fitness, Population, Solution
from evosax.core.fitness_shaping import weights_fitness_shaping_fn
from evosax.types import Fitness, Population, Solution

from .base import DistributionBasedAlgorithm, Params, State, metrics_fn


Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(
self,
population_size: int,
solution: Solution,
fitness_shaping_fn: Callable = identity_fitness_shaping_fn,
fitness_shaping_fn: Callable = weights_fitness_shaping_fn,
metrics_fn: Callable = metrics_fn,
):
"""Initialize CMA-ES."""
Expand All @@ -64,10 +65,10 @@ def _default_params(self) -> Params:
jnp.arange(1, self.population_size + 1)
) # Eq. (48)

mu_eff = (jnp.sum(weights_prime[: self.num_elites]) ** 2) / jnp.sum(
mu_eff = jnp.sum(weights_prime[: self.num_elites]) ** 2 / jnp.sum(
weights_prime[: self.num_elites] ** 2
) # Eq. (8)
mu_eff_minus = (jnp.sum(weights_prime[self.num_elites :]) ** 2) / jnp.sum(
mu_eff_minus = jnp.sum(weights_prime[self.num_elites :]) ** 2 / jnp.sum(
weights_prime[self.num_elites :] ** 2
) # Table 1

Expand Down Expand Up @@ -197,7 +198,9 @@ def _tell(

delta_h_std = self.delta_h_std(h_std, params)
rank_one = self.rank_one(p_c)
rank_mu = self.rank_mu(y_k, (y_k @ state.B) * (1 / state.D) @ state.B.T, params)
rank_mu = self.rank_mu(
fitness, y_k, (y_k @ state.B) * (1 / state.D) @ state.B.T
)
C = self.update_C(state.C, delta_h_std, rank_one, rank_mu, params)

return state.replace(mean=mean, std=std, p_std=p_std, p_c=p_c, C=C)
Expand All @@ -211,13 +214,8 @@ def update_mean(
params: Params,
) -> tuple[jax.Array, jax.Array, jax.Array]:
"""Update the mean of the distribution."""
# Sort
idx = jnp.argsort(fitness)

y_k = (population[idx] - mean) / std # ~ N(0, C)
y_w = jnp.dot(
params.weights[: self.num_elites], y_k[: self.num_elites]
) # Eq. (41)
y_k = (population - mean) / std # ~ N(0, C)
y_w = jnp.dot(jnp.where(fitness < 0.0, 0.0, fitness), y_k) # Eq. (41)
return mean + params.c_mean * std * y_w, y_k, y_w # Eq. (42)

def update_p_std(
Expand Down Expand Up @@ -260,11 +258,11 @@ def rank_one(self, p_c: jax.Array) -> jax.Array:
return jnp.outer(p_c, p_c)

def rank_mu(
self, y_k: jax.Array, C_inv_sqrt_y_k: jax.Array, params: Params
self, fitness: Fitness, y_k: jax.Array, C_inv_sqrt_y_k: jax.Array
) -> jax.Array:
"""Compute the rank-mu update term for the covariance matrix."""
w_o = params.weights * jnp.where(
params.weights >= 0,
w_o = fitness * jnp.where(
fitness >= 0,
1,
self.num_dims
/ jnp.clip(jnp.sum(jnp.square(C_inv_sqrt_y_k), axis=-1), min=1e-8),
Expand Down
Loading

0 comments on commit b8bb602

Please sign in to comment.