forked from shyamsn97/jax-nca

Neural Cellular Automata implemented with Jax


ej159/jax-nca


Neural Cellular Automata (Based on https://distill.pub/2020/growing-ca/) implemented in Jax (Flax)

Gecko gif


Installation

From source:

git clone git@github.com:shyamsn97/jax-nca.git
cd jax-nca
python setup.py install

From PyPI:

pip install jax-nca

How do NCAs work?

For more information, see the excellent article https://distill.pub/2020/growing-ca/ (Mordvintsev et al., "Growing Neural Cellular Automata", Distill, 2020).

The image below describes a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg

NCA update
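As a rough illustration of the perception stage of that update step, the sketch below builds each cell's perception vector by applying an identity filter and two Sobel filters to every channel separately, as described in the Distill article. It is a minimal sketch in plain JAX, not jax_nca's own code: the function name `perceive`, the NHWC layout, and the 1/8 scaling of the Sobel kernels are assumptions.

```python
import jax
import jax.numpy as jnp


def perceive(state):
    """Identity + Sobel perception for a batch of NCA states.

    state: array of shape (batch, height, width, channels), NHWC layout.
    Returns shape (batch, height, width, 3 * channels): for each input
    channel c, output channels 3c, 3c+1, 3c+2 hold [identity, d/dx, d/dy].
    """
    identity = jnp.array([[0.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0],
                          [0.0, 0.0, 0.0]])
    sobel_x = jnp.array([[-1.0, 0.0, 1.0],
                         [-2.0, 0.0, 2.0],
                         [-1.0, 0.0, 1.0]]) / 8.0
    sobel_y = sobel_x.T
    # HWIO kernel of shape (3, 3, 1, 3): one input channel, three filters
    kernel = jnp.stack([identity, sobel_x, sobel_y], axis=-1)[:, :, None, :]
    channels = state.shape[-1]
    # Repeat the three filters once per channel; feature_group_count makes the
    # convolution depthwise, so every channel is filtered independently
    kernel = jnp.tile(kernel, (1, 1, 1, channels))
    return jax.lax.conv_general_dilated(
        state,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=channels,
    )
```

With C input channels the output has 3C channels; in the Distill model this perception vector is then passed to a small per-cell network that produces the update.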


Why Jax?

Note: this project served as a nice introduction to JAX, so its performance can probably be improved.

NCAs are autoregressive models, like RNNs: each new state is computed from the previous one. JAX can make these repeated updates considerably faster by combining jax.lax.scan with jax.jit (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html).

Instead of writing the NCA growth process as a Python loop:

def multi_step(params, nca, current_state, num_steps):
    # params: parameters for NCA
    # nca: Flax Module describing NCA
    # current_state: Current NCA state
    # num_steps: number of steps to run

    for _ in range(num_steps):
        current_state = nca.apply({"params": params}, current_state)
    return current_state

we can write it with jax.lax.scan:

import jax

def multi_step(params, nca, current_state, num_steps):
    # params: parameters for NCA
    # nca: Flax Module describing NCA
    # current_state: Current NCA state
    # num_steps: number of steps to run (must be a static Python int)

    def forward(carry, inp):
        carry = nca.apply({"params": params}, carry)
        return carry, carry

    # nca_states stacks every intermediate state along a new leading axis
    final_state, nca_states = jax.lax.scan(forward, current_state, None, length=num_steps)
    return final_state
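To check that the two forms agree, here is a self-contained sketch that swaps nca.apply for a hypothetical toy update `toy_step` (a residual tanh layer, not part of jax_nca). Under jax.jit a Python loop is unrolled into num_steps copies of the update at trace time, while jax.lax.scan traces the update once; num_steps must still be a static Python int, hence `static_argnums`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_step(params, state):
    # Hypothetical stand-in for nca.apply: residual update state + tanh(state @ W)
    return state + jnp.tanh(state @ params)


def multi_step_loop(params, state, num_steps):
    # Plain Python loop: under jax.jit this would be unrolled num_steps times
    for _ in range(num_steps):
        state = toy_step(params, state)
    return state


@partial(jax.jit, static_argnums=2)
def multi_step_scan(params, state, num_steps):
    def forward(carry, _):
        carry = toy_step(params, carry)
        return carry, carry  # (next carry, per-step output to stack)

    # The step function is traced once; `length` needs a static int
    return jax.lax.scan(forward, state, None, length=num_steps)


params = 0.1 * jnp.eye(4)
state = jnp.ones((2, 4))
final_state, states = multi_step_scan(params, state, 8)
print(states.shape)  # (8, 2, 4): one stacked state per step
```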

The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103


Usage

See notebooks/Gecko.ipynb for a full example.

Note: there is currently a bug in the stochastic update, so only cell_fire_rate = 1.0 works at the moment.

Creating and using an NCA:
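For reference, the Distill article defines the stochastic update and the living mask roughly as follows: each cell independently receives its update with probability cell_fire_rate, and a cell is kept only if some cell in its 3x3 neighbourhood has alpha above the living threshold, both before and after the update. The sketch below illustrates that rule in plain JAX and is not taken from jax_nca; the helper names `alive_mask` and `stochastic_update`, and the assumption that alpha sits in channel 3 (as in the article's RGBA layout), are hypothetical.

```python
import jax
import jax.numpy as jnp


def alive_mask(state, threshold=0.1, alpha_channel=3):
    """True where any cell in the 3x3 neighbourhood has alpha > threshold."""
    # Keep a trailing channel axis: (batch, height, width, 1)
    alpha = state[..., alpha_channel:alpha_channel + 1]
    # 3x3 max pooling; padded border cells take the init value -inf
    pooled = jax.lax.reduce_window(
        alpha, -jnp.inf, jax.lax.max,
        window_dimensions=(1, 3, 3, 1),
        window_strides=(1, 1, 1, 1),
        padding="SAME",
    )
    return pooled > threshold


def stochastic_update(state, ds, rng, cell_fire_rate=1.0, threshold=0.1, alpha_channel=3):
    """Apply update ds to a random subset of cells, then zero out dead cells."""
    pre_alive = alive_mask(state, threshold, alpha_channel)
    # One Bernoulli draw per cell, shared by all channels of that cell
    fire = jax.random.uniform(rng, state.shape[:-1] + (1,)) <= cell_fire_rate
    new_state = state + ds * fire
    post_alive = alive_mask(new_state, threshold, alpha_channel)
    # A cell survives only if it was alive both before and after the update
    return new_state * (pre_alive & post_alive)
```

With cell_fire_rate = 1.0 every cell fires, so the rule reduces to a deterministic update followed by the living mask.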

import flax.linen as nn

class NCA(nn.Module):
    """
        num_hidden_channels: Number of hidden channels for each cell to use
        num_target_channels: Number of target channels to be used
        alpha_living_threshold: threshold to determine whether a cell lives or dies
        cell_fire_rate: probability that a cell receives an update per step
        trainable_perception: if true, use a trainable conv net instead of Sobel filters
        alpha: scalar value that the updates are multiplied by
    """
    num_hidden_channels: int
    num_target_channels: int = 3
    alpha_living_threshold: float = 0.1
    cell_fire_rate: float = 1.0
    trainable_perception: bool = False
    alpha: float = 1.0

    ...

import jax
from jax_nca.nca import NCA

# usage
nca = NCA(
    num_hidden_channels=16,
    num_target_channels=3,
    trainable_perception=False,
    cell_fire_rate=1.0,
    alpha_living_threshold=0.1,
)

nca_seed = nca.create_seed(
    nca.num_hidden_channels, nca.num_target_channels, shape=(64, 64), batch_size=1
)
rng = jax.random.PRNGKey(0)
params = nca.init(rng, nca_seed, rng)["params"]
update = nca.apply({"params": params}, nca_seed, jax.random.PRNGKey(10))

# multi step
final_state, nca_states = nca.multi_step(params, nca_seed, jax.random.PRNGKey(10), num_steps=32)

To train the NCA:

from jax_nca.dataset import ImageDataset
from jax_nca.nca import NCA
from jax_nca.trainer import EmojiTrainer

dataset = ImageDataset(emoji='🦎', img_size=64)

nca = NCA(
    num_hidden_channels=16,
    num_target_channels=3,
    trainable_perception=False,
    cell_fire_rate=1.0,
    alpha_living_threshold=0.1,
)

trainer = EmojiTrainer(dataset, nca, n_damage=0)

trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)
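The min_steps and max_steps arguments mirror the Distill training setup, where each training iteration runs the CA for a random number of steps (64 to 96 in the article) before comparing the result with the target image. One way to express that in JAX is sketched below, under several assumptions: a hypothetical toy update `toy_step` stands in for the NCA, the loss is a plain L2 loss on the target's channels, and parameters are updated with plain SGD instead of whichever optimizer EmojiTrainer uses. The scan always runs max_steps (a static length) and the state after a randomly drawn num_steps is selected by index, which keeps the loss differentiable under jax.jit.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_step(params, state):
    # Hypothetical stand-in for nca.apply: residual update state + tanh(state @ W)
    return state + jnp.tanh(state @ params)


def loss_fn(params, seed, target, num_steps, max_steps):
    """L2 loss between the target and the state after num_steps updates."""
    def forward(carry, _):
        carry = toy_step(params, carry)
        return carry, carry

    # Always run max_steps (static), then pick the state after num_steps (may be traced)
    _, states = jax.lax.scan(forward, seed, None, length=max_steps)
    final = states[num_steps - 1]
    # Compare only the channels present in the target
    channels = target.shape[-1]
    return jnp.mean((final[..., :channels] - target) ** 2)


@partial(jax.jit, static_argnums=(4, 5))
def train_step(params, seed, target, rng, min_steps, max_steps, lr=1e-2):
    # Number of steps drawn uniformly from [min_steps, max_steps] (maxval is exclusive)
    num_steps = jax.random.randint(rng, (), min_steps, max_steps + 1)
    loss, grads = jax.value_and_grad(loss_fn)(params, seed, target, num_steps, max_steps)
    # Plain SGD update (assumption made for brevity)
    return params - lr * grads, loss


params = 0.5 * jnp.eye(4)
seed = jnp.ones((2, 4))
target = jnp.zeros((2, 3))
rng = jax.random.PRNGKey(0)
for _ in range(3):
    rng, step_rng = jax.random.split(rng)
    params, loss = train_step(params, seed, target, step_rng, 2, 4)
```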

# to access train state:

state = trainer.state

# save
nca.save(state.params, "saved_params")

# load params
loaded_params = nca.load("saved_params")
