Add EvoTF_ES
RobertTLange committed Mar 5, 2024
1 parent 9bfe4b4 commit 1f3a8bc
Showing 20 changed files with 1,561 additions and 13 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -18,7 +18,6 @@ dev_notes.md
configs/
experiments/
v2/
-evotf_es.py

# Byte-compiled / optimized / DLL files
__pycache__/
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -1,10 +1,10 @@
-### [v0.1.6] - [TBD]
+### [v0.1.6] - [03/2024]

##### Added

- Implemented Hill Climbing strategy as a simple baseline.
- Adds `use_antithetic_sampling` option to OpenAI-ES.
-- Added EvoTransformer strategy and trained checkpoint.
+- Added `EvoTransformer` and `EvoTF_ES` strategy with example trained checkpoint.

##### Fixed

3 changes: 2 additions & 1 deletion MANIFEST.in
@@ -1,2 +1,3 @@
include evosax/strategies/ckpt/les/*.pkl
-include evosax/strategies/ckpt/lga/*.pkl
+include evosax/strategies/ckpt/lga/*.pkl
+include evosax/strategies/ckpt/evotf/*.pkl
1 change: 1 addition & 0 deletions README.md
@@ -62,6 +62,7 @@ state.best_member, state.best_fitness
| Simple Gaussian | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`SimpleES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| DES | [Lange et al. (2023a)](https://arxiv.org/abs/2211.11260) | [`DES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/des.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| LES | [Lange et al. (2023a)](https://arxiv.org/abs/2211.11260) | [`LES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/les.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| EvoTF | [Lange et al. (2024)]() | [`EvoTF_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/evotf_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| Particle Swarm Optimization | [Kennedy & Eberhart (1995)](https://ieeexplore.ieee.org/document/488968) | [`PSO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pso.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| Differential Evolution | [Storn & Price (1997)](https://www.metabolic-economics.de/pages/seminar_theoretische_biologie_2007/literatur/schaber/Storn1997JGlobOpt11.pdf) | [`DE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/de.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| GLD | [Golovin et al. (2019)](https://arxiv.org/pdf/1911.06317.pdf) | [`GLD`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gld.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
6 changes: 3 additions & 3 deletions evosax/__init__.py
@@ -35,7 +35,7 @@
LGA,
NoiseReuseES,
HillClimber,
-    # EvoTF_ES,
+    EvoTF_ES,
)
from .core import FitnessShaper, ParameterReshaper
from .utils import ESLog
@@ -79,7 +79,7 @@
"LGA": LGA,
"NoiseReuseES": NoiseReuseES,
"HillClimber": HillClimber,
# "EvoTF_ES": EvoTF_ES,
"EvoTF_ES": EvoTF_ES,
}

__all__ = [
@@ -127,5 +127,5 @@
"LGA",
"NoiseReuseES",
"HillClimber",
# "EvoTF_ES",
"EvoTF_ES",
]
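
For context (not part of this commit's diff), here is a minimal sketch of how the newly registered `EvoTF_ES` could be driven through the generic evosax ask/tell loop on a toy sphere objective. The import and the `state.best_member` / `state.best_fitness` fields follow this diff and the README snippet above; the constructor arguments and the assumption that the bundled `evotf` checkpoint is loaded by default are guesses based on the library's other strategies.

```python
import jax
import jax.numpy as jnp
from evosax import EvoTF_ES

rng = jax.random.PRNGKey(0)
# Assumed constructor signature (popsize/num_dims), mirroring other evosax strategies.
strategy = EvoTF_ES(popsize=16, num_dims=2)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

for _ in range(10):
    rng, rng_ask = jax.random.split(rng)
    # Sample a population of candidate solutions ...
    x, state = strategy.ask(rng_ask, state, es_params)
    # ... evaluate them on a toy sphere objective ...
    fitness = jnp.sum(x ** 2, axis=-1)
    # ... and update the search distribution.
    state = strategy.tell(x, fitness, state, es_params)

print(state.best_member, state.best_fitness)
```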
19 changes: 19 additions & 0 deletions evosax/learned_eo/evotf_tools/__init__.py
@@ -0,0 +1,19 @@
from .evo_transformer import EvoTransformer
from .features import (
FitnessFeaturizer,
FitnessFeaturesState,
SolutionFeaturizer,
SolutionFeaturesState,
DistributionFeaturizer,
DistributionFeaturesState,
)

__all__ = [
"EvoTransformer",
"FitnessFeaturizer",
"FitnessFeaturesState",
"SolutionFeaturizer",
"SolutionFeaturesState",
"DistributionFeaturizer",
"DistributionFeaturesState",
]
129 changes: 129 additions & 0 deletions evosax/learned_eo/evotf_tools/attention.py
@@ -0,0 +1,129 @@
from typing import Tuple, Optional, List
from flax import linen as nn
import jax.numpy as jnp
import chex
from .shared import scaled_dot_product, expand_mask, MLP, PositionalEncoding


class MultiheadAttention(nn.Module):
embed_dim: int
num_heads: int
dropout_prob: float = 0.0
use_bias: bool = False
out_att_maps: bool = False

def setup(self):
self.qkv_proj = nn.Dense(
features=3 * self.embed_dim,
kernel_init=nn.initializers.xavier_uniform(),
bias_init=nn.initializers.zeros,
use_bias=self.use_bias,
)
self.out_proj = nn.Dense(
features=self.embed_dim,
kernel_init=nn.initializers.xavier_uniform(),
bias_init=nn.initializers.zeros,
use_bias=self.use_bias,
)
self.attn_dropout = nn.Dropout(self.dropout_prob)
self.resid_dropout = nn.Dropout(self.dropout_prob)

def __call__(
self,
x: chex.Array,
mask: Optional[chex.Array] = None,
train: bool = True,
) -> Tuple[chex.Array, chex.Array]:
batch_size, seq_length, embed_dim = x.shape
if mask is not None:
mask = expand_mask(mask)
qkv = self.qkv_proj(x)
qkv = qkv.reshape(batch_size, seq_length, self.num_heads, -1)
qkv = qkv.transpose(0, 2, 1, 3)
q, k, v = jnp.array_split(qkv, 3, axis=-1)

attention = scaled_dot_product(q, k, mask)
attention = self.attn_dropout(attention, deterministic=not train)
values = jnp.matmul(attention, v)
values = values.transpose(0, 2, 1, 3)
values = values.reshape(batch_size, seq_length, embed_dim)
out = self.out_proj(values)
out = self.resid_dropout(out, deterministic=not train)
if self.out_att_maps:
return out, attention
else:
return out, None


class AttentionBlock(nn.Module):
num_heads: int
embed_dim: int
dropout_prob: float
use_bias: bool
out_att_maps: bool

def setup(self):
self.ln_1 = nn.LayerNorm(use_bias=self.use_bias)
self.attn = MultiheadAttention(
self.embed_dim,
self.num_heads,
self.dropout_prob,
self.use_bias,
self.out_att_maps,
)
self.ln_2 = nn.LayerNorm(use_bias=self.use_bias)
self.mlp = MLP(self.embed_dim, self.dropout_prob, self.use_bias)

def __call__(
self,
x: chex.Array,
mask: Optional[chex.Array] = None,
train: bool = True,
) -> Tuple[chex.Array, chex.Array]:
attn_out, attn = self.attn(self.ln_1(x), mask, train)
x = x + attn_out
x = x + self.mlp(self.ln_2(x), train)
return x, attn


class AttentionEncoder(nn.Module):
embed_dim: int
num_heads: int
num_layers: int
dropout_prob: float = 0.0
input_dropout_prob: float = 0.0
use_bias: bool = False
out_att_maps: bool = False

def setup(self):
self.input_dropout = nn.Dropout(self.input_dropout_prob)
self.input_layer = nn.Dense(self.embed_dim, use_bias=self.use_bias)
self.positional_encoding = PositionalEncoding(self.embed_dim)
self.transformer = [
AttentionBlock(
num_heads=self.num_heads,
embed_dim=self.embed_dim,
dropout_prob=self.dropout_prob,
use_bias=self.use_bias,
out_att_maps=self.out_att_maps,
)
for _ in range(self.num_layers)
]

def __call__(
self,
x: chex.Array,
mask: Optional[chex.Array] = None,
add_positional_encoding: bool = True,
        train: bool = True,
) -> Tuple[chex.Array, List[chex.Array]]:
x = self.input_layer(x)
if add_positional_encoding:
x = self.positional_encoding(x)
x = self.input_dropout(x, deterministic=not train)
# Loop over transformer blocks and collect attention maps
attn_maps = []
for layer in self.transformer:
x, attn = layer(x, mask, train)
attn_maps.append(attn)
return x, attn_maps
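
As a quick illustration (not part of the commit), the `AttentionEncoder` defined above could be exercised on its own roughly as follows. The import path, the input shape, and the behavior of the `.shared` helpers (`PositionalEncoding`, `MLP`, `scaled_dot_product`), which are not shown in this excerpt, are assumptions.

```python
import jax
import jax.numpy as jnp
from evosax.learned_eo.evotf_tools.attention import AttentionEncoder

# Hypothetical sizes; embed_dim must be divisible by num_heads.
model = AttentionEncoder(embed_dim=32, num_heads=4, num_layers=2)
# Assumed layout: (batch, sequence length, raw feature dim).
x = jax.random.normal(jax.random.PRNGKey(1), (8, 16, 10))

# train=False keeps both dropout layers deterministic, so a single PRNG key suffices.
params = model.init(jax.random.PRNGKey(0), x, train=False)
out, attn_maps = model.apply(params, x, train=False)
print(out.shape)  # (8, 16, 32): inputs are projected to embed_dim
# attn_maps is a list of None here, since out_att_maps defaults to False.
```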
180 changes: 180 additions & 0 deletions evosax/learned_eo/evotf_tools/evo_modules.py
@@ -0,0 +1,180 @@
from typing import List, Tuple
from functools import partial
import jax.numpy as jnp
from flax import linen as nn
import chex
from .attention import AttentionEncoder
from .perceiver import PerceiverEncoder


class CompressionPerceiver(nn.Module):
num_latents: int
latent_dim: int
embed_dim: int
num_heads: int
num_layers: int = 1
dropout_prob: float = 0.0
input_dropout_prob: float = 0.0
use_bias: bool = False
out_att_maps: bool = False

def setup(self):
self.cross_attn_population = partial(
PerceiverEncoder,
num_latents=self.num_latents,
latent_dim=self.latent_dim,
embed_dim=self.embed_dim,
num_heads=self.num_heads,
num_layers=self.num_layers,
dropout_prob=self.dropout_prob,
input_dropout_prob=self.input_dropout_prob,
use_bias=self.use_bias,
out_att_maps=self.out_att_maps,
)
self.lift_cross = nn.vmap(
self.cross_attn_population,
variable_axes={"params": None},
split_rngs={"params": False, "dropout": True},
in_axes=(0, None, None, None),
out_axes=0,
)

@nn.compact
def __call__(
self, x: chex.Array, train: bool = False
) -> Tuple[chex.Array, List[chex.Array]]:
x = x.transpose(1, 0, 2, 3)
out, att = self.lift_cross(name="CompressionPerceiver")(
x,
None,
False,
train,
)
out = out.transpose(1, 0, 2, 3)
if self.out_att_maps:
att = [jnp.array(a).transpose(1, 0, 2, 3, 4) for a in att]
return out, att


class SolutionPerceiver(nn.Module):
num_latents: int
latent_dim: int
embed_dim: int
num_heads: int
num_layers: int = 1
dropout_prob: float = 0.0
input_dropout_prob: float = 0.0
use_bias: bool = False
out_att_maps: bool = False

def setup(self):
self.cross_attn_population = partial(
CompressionPerceiver,
num_latents=self.num_latents,
latent_dim=self.latent_dim,
embed_dim=self.embed_dim,
num_heads=self.num_heads,
num_layers=self.num_layers,
dropout_prob=self.dropout_prob,
input_dropout_prob=self.input_dropout_prob,
use_bias=self.use_bias,
out_att_maps=self.out_att_maps,
)
self.lift_cross = nn.vmap(
self.cross_attn_population,
variable_axes={"params": None},
split_rngs={"params": False, "dropout": False},
in_axes=(0, None),
out_axes=0,
)

@nn.compact
def __call__(
self, x: chex.Array, train: bool = False
) -> Tuple[chex.Array, List[chex.Array]]:
x = x.transpose(3, 0, 1, 2, 4)
out, att = self.lift_cross(name="SolutionPerceiver")(
x,
train,
)
out = out.transpose(1, 2, 3, 0, 4)
if self.out_att_maps:
att = [jnp.array(a).transpose(1, 2, 0, 3, 4, 5) for a in att]
return out, att


class DistributionAttention(nn.Module):
embed_dim: int
num_heads: int
num_layers: int = 1
dropout_prob: float = 0.0
input_dropout_prob: float = 0.0
use_bias: bool = False
out_att_maps: bool = False

def setup(self):
self.transformer = partial(
AttentionEncoder,
num_heads=self.num_heads,
embed_dim=self.embed_dim,
num_layers=self.num_layers,
dropout_prob=self.dropout_prob,
input_dropout_prob=self.input_dropout_prob,
use_bias=self.use_bias,
out_att_maps=self.out_att_maps,
)

self.lift_att = nn.vmap(
self.transformer,
variable_axes={"params": None},
split_rngs={"params": False, "dropout": False},
in_axes=(0, None, None, None),
out_axes=0,
)

@nn.compact
def __call__(
self, x: chex.Array, train: bool = True
) -> Tuple[chex.Array, List[chex.Array]]:
x = x.transpose(1, 0, 2, 3)
out, att = self.lift_att(name="DistributionAttention")(x, None, False, train)
out = out.transpose(1, 0, 2, 3)
if self.out_att_maps:
att = jnp.array(att).transpose(2, 1, 0, 3, 4, 5)
return out, att


class DistributionUpdateNetwork(nn.Module):
embed_dim: int
dropout_prob: float = 0.0
use_bias: bool = False

def setup(self):
self.output_net = [
nn.Dense(
features=self.embed_dim,
kernel_init=nn.initializers.xavier_uniform(),
bias_init=nn.initializers.zeros,
use_bias=self.use_bias,
),
            nn.LayerNorm(use_bias=self.use_bias),
nn.relu,
nn.Dropout(self.dropout_prob),
nn.Dense(
features=2,
kernel_init=nn.initializers.xavier_uniform(),
bias_init=nn.initializers.zeros,
use_bias=self.use_bias,
),
]

@nn.compact
def __call__(self, x: chex.Array, train: bool = False) -> chex.Array:
out = x
for l in self.output_net:
out = (
l(out)
if not isinstance(l, nn.Dropout)
else l(out, deterministic=not train)
)
return out
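
For orientation (again not part of the diff), a hedged sketch of calling the `DistributionUpdateNetwork` head in isolation. The input layout is an assumption; the module only fixes the trailing feature axis and maps each entry to two outputs, presumably mean and sigma update statistics.

```python
import jax
from evosax.learned_eo.evotf_tools.evo_modules import DistributionUpdateNetwork

net = DistributionUpdateNetwork(embed_dim=32)
# Assumed layout, e.g. (batch, num_dims, embed_dim); only the last axis matters here.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 10, 32))
params = net.init(jax.random.PRNGKey(1), x, train=False)
out = net.apply(params, x, train=False)
print(out.shape)  # (4, 10, 2): two outputs per entry
```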
(Diffs for the remaining 12 changed files are not shown.)