
Neural Cellular Automata (https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., "Growing Neural Cellular Automata", Distill, 2020) implemented in JAX


Neural Cellular Automata (based on https://distill.pub/2020/growing-ca/) implemented in JAX (Flax)

Gecko gif


Installation

From source:

git clone git@github.com:shyamsn97/jax-nca.git
cd jax-nca
python setup.py install

From PyPI:

pip install jax-nca

How do NCAs work?

For a full explanation, see the excellent article https://distill.pub/2020/growing-ca/ (Mordvintsev et al., "Growing Neural Cellular Automata", Distill, 2020).

The image below, from https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg, shows a single update step:

NCA update
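
As a minimal sketch of the first half of that step, following the Distill article rather than this library's code: each cell perceives its 3x3 neighbourhood through fixed identity and Sobel filters (the non-trainable perception option), and a small per-cell network then maps that perception vector to a residual update. The function names perceive and update_step and the weight arguments w1, b1, w2 are hypothetical, and the stochastic and alive masks are left out.

```python
import jax
import jax.numpy as jnp


def perceive(state):
    # state: (batch, height, width, channels), NHWC layout
    # Returns (batch, height, width, 3 * channels): identity, Sobel-x, Sobel-y per channel
    identity = jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
    sobel_x = jnp.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) / 8.0
    sobel_y = sobel_x.T
    channels = state.shape[-1]
    # Stack the three filters, then repeat them once per input channel
    kernels = jnp.stack([identity, sobel_x, sobel_y], axis=-1)  # (3, 3, 3)
    kernels = jnp.tile(kernels, (1, 1, channels))[:, :, None, :]  # HWIO: (3, 3, 1, 3 * channels)
    # Depthwise convolution: each input channel is filtered separately
    return jax.lax.conv_general_dilated(
        state,
        kernels,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=channels,
    )


def update_step(state, w1, b1, w2):
    # Hypothetical per-cell MLP; `@` on the last axis acts like a 1x1 convolution
    perception = perceive(state)
    hidden = jax.nn.relu(perception @ w1 + b1)
    delta = hidden @ w2
    # Residual update: the network predicts a change to the state, not a new state
    return state + delta
```

Because the weights are shared by every cell, the same small network runs everywhere on the grid. The Distill article also zero-initialises the last layer, so with w2 set to zeros a fresh model leaves the state unchanged.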


Why Jax?

Note: this project served as my introduction to JAX, so its performance can probably still be improved.

Like RNNs, NCAs are autoregressive: each new state is computed from the previous one. JAX can make this repeated update much faster with jax.lax.scan and jax.jit (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html).
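
To make the scan pattern concrete before applying it to NCAs: jax.lax.scan(f, init, xs) calls f(carry, x) once per element of xs, passes the returned carry into the next call, and stacks the per-step outputs. A toy running sum (running_sum is just an illustrative name):

```python
import jax
import jax.numpy as jnp


def running_sum(xs):
    def step(carry, x):
        # Return (carry for the next step, output recorded for this step)
        new_carry = carry + x
        return new_carry, new_carry

    # Match the carry dtype to xs so the carry type stays fixed across steps
    init = jnp.zeros((), dtype=xs.dtype)
    final, history = jax.lax.scan(step, init, xs)
    return final, history


final, history = jax.jit(running_sum)(jnp.arange(1.0, 5.0))
```

Here final is the last carry (10.0) and history stacks every intermediate carry ([1., 3., 6., 10.]). In the NCA version, the carry is the grid state.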

Instead of writing the NCA growth process as a Python loop:

def multi_step(params, nca, current_state, num_steps):
    # params: parameters for NCA
    # nca: Flax Module describing NCA
    # current_state: Current NCA state
    # num_steps: number of steps to run

    for _ in range(num_steps):
        current_state = nca.apply({"params": params}, current_state)
    return current_state

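We can write this with jax.lax.scan, which traces the update once and compiles it as a single loop instead of unrolling num_steps copies of it under jit: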

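import jax


def multi_step(params, nca, current_state, num_steps):
    # params: parameters for NCA
    # nca: Flax Module describing NCA
    # current_state: Current NCA state
    # num_steps: number of steps to run (a Python int, since it sets the scan length)

    def forward(carry, _):
        # xs is None, so the second argument is unused; carry is the NCA state
        carry = nca.apply({"params": params}, carry)
        return carry, carry

    # nca_states stacks every intermediate state along a new leading axis
    final_state, nca_states = jax.lax.scan(forward, current_state, None, length=num_steps)
    return final_state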

The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
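
A self-contained check that the loop and scan formulations agree, using a hypothetical toy_update in place of nca.apply so it runs without Flax or this library. Because length fixes the shape of the stacked states output, num_steps has to be a static Python integer when the scanned version is jitted, which is why it is marked with static_argnames here:

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_update(params, state):
    # Hypothetical stand-in for nca.apply: moves every cell towards 1.0
    return state + params["rate"] * (1.0 - state)


def multi_step_loop(params, state, num_steps):
    for _ in range(num_steps):
        state = toy_update(params, state)
    return state


@partial(jax.jit, static_argnames="num_steps")
def multi_step_scan(params, state, num_steps):
    def forward(carry, _):
        carry = toy_update(params, carry)
        return carry, carry

    final_state, states = jax.lax.scan(forward, state, None, length=num_steps)
    return final_state, states


params = {"rate": jnp.float32(0.5)}
state = jnp.zeros((2, 2))
final_scan, history = multi_step_scan(params, state, num_steps=3)
final_loop = multi_step_loop(params, state, 3)
```

Each cell goes 0.0, 0.5, 0.75, 0.875 in both versions, and history has shape (3, 2, 2): one stacked state per step. Each new value of num_steps triggers one extra compilation, which is then cached.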


Usage

See notebooks/Gecko.ipynb for a full example.

Note: there is currently a bug in the stochastic update, so only cell_fire_rate = 1.0 works.
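
For context: in the Distill article, cell_fire_rate is the probability that each cell applies its update on a given step (one independent draw per cell), and alpha_living_threshold decides which cells are alive (a cell is alive if any cell in its 3x3 neighbourhood has alpha above the threshold). The sketch below follows the article, not this library's code; the function names and alpha_index (which channel holds alpha) are assumptions.

```python
import jax
import jax.numpy as jnp


def stochastic_mask(rng, state, cell_fire_rate):
    # One Bernoulli draw per cell, shared across that cell's channels
    batch, height, width, _ = state.shape
    return jax.random.uniform(rng, (batch, height, width, 1)) <= cell_fire_rate


def alive_mask(state, alpha_index, alpha_living_threshold):
    # Alive if the max alpha in the 3x3 neighbourhood exceeds the threshold
    alpha = state[..., alpha_index:alpha_index + 1]
    neighbourhood_max = jax.lax.reduce_window(
        alpha,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=(1, 3, 3, 1),
        window_strides=(1, 1, 1, 1),
        padding="SAME",
    )
    return neighbourhood_max > alpha_living_threshold


def masked_update(rng, state, delta, cell_fire_rate, alpha_index, alpha_living_threshold):
    pre_alive = alive_mask(state, alpha_index, alpha_living_threshold)
    fire = stochastic_mask(rng, state, cell_fire_rate)
    # Only cells selected by the random mask receive their update this step
    new_state = state + delta * fire
    post_alive = alive_mask(new_state, alpha_index, alpha_living_threshold)
    # Keep only cells that are alive both before and after the update
    return new_state * (pre_alive & post_alive)
```

Since jax.random.uniform samples from [0, 1), a cell_fire_rate of 1.0 makes the mask all ones, so every cell updates on every step and the randomness has no effect.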

Creating and using an NCA:

import flax.linen as nn


class NCA(nn.Module):
    """
    num_hidden_channels: number of hidden channels for each cell to use
    num_target_channels: number of target channels to be used
    alpha_living_threshold: threshold to determine whether a cell lives or dies
    cell_fire_rate: probability that a cell receives an update per step
    trainable_perception: if true, use a trainable conv net instead of Sobel filters
    alpha: scalar value multiplied with the updates
    """

    num_hidden_channels: int
    num_target_channels: int = 3
    alpha_living_threshold: float = 0.1
    cell_fire_rate: float = 1.0
    trainable_perception: bool = False
    alpha: float = 1.0
    ...

import jax
from jax_nca.nca import NCA

# usage
nca = NCA(
    num_hidden_channels=16,
    num_target_channels=3,
    trainable_perception=False,
    cell_fire_rate=1.0,
    alpha_living_threshold=0.1,
)

# starting grid for a batch of one 64x64 image
nca_seed = nca.create_seed(
    nca.num_hidden_channels, nca.num_target_channels, shape=(64, 64), batch_size=1
)
rng = jax.random.PRNGKey(0)
# initialise the parameters, then apply a single update step
params = nca.init(rng, nca_seed, rng)["params"]
update = nca.apply({"params": params}, nca_seed, jax.random.PRNGKey(10))

# multi step
final_state, nca_states = nca.multi_step(params, nca_seed, jax.random.PRNGKey(10), num_steps=32)
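
create_seed above builds the grid the NCA grows from. As a rough illustration of the seed used in the Distill article (which may differ from what create_seed actually returns), every cell starts at zero except the centre one, whose alpha and hidden channels are set to 1. make_seed is a hypothetical helper. The library's channel layout is not spelled out here, so the total channel count and the index of the first live channel (alpha) are explicit arguments:

```python
import jax.numpy as jnp


def make_seed(num_channels, height, width, batch_size, first_live_channel):
    # Hypothetical helper: all-zero grid except the centre cell, whose channels
    # from first_live_channel onward (alpha + hidden) are set to 1
    seed = jnp.zeros((batch_size, height, width, num_channels), dtype=jnp.float32)
    return seed.at[:, height // 2, width // 2, first_live_channel:].set(1.0)


seed = make_seed(num_channels=16, height=64, width=64, batch_size=1, first_live_channel=3)
```

Here the colour channels of the centre cell stay 0 and its 13 remaining channels are 1, so a single "alive" cell is all the NCA starts from.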

To train the NCA:

from jax_nca.dataset import ImageDataset
from jax_nca.nca import NCA
from jax_nca.trainer import EmojiTrainer


dataset = ImageDataset(emoji='🦎', img_size=64)

nca = NCA(
    num_hidden_channels=16,
    num_target_channels=3,
    trainable_perception=False,
    cell_fire_rate=1.0,
    alpha_living_threshold=0.1,
)

trainer = EmojiTrainer(dataset, nca, n_damage=0)

trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)

# to access train state:
state = trainer.state

# save
nca.save(state.params, "saved_params")

# load params
loaded_params = nca.load("saved_params")
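
In the Distill article, training runs the NCA for a random number of steps (the article samples between 64 and 96, matching the min_steps and max_steps values above) and minimises the pixel-wise squared error between the visible channels and the target image. The sketch below illustrates that objective under stated assumptions; it is not EmojiTrainer's code. rollout is a hypothetical stand-in for the NCA so the example runs without Flax, the step count is sampled in plain Python so it stays a static argument to jit, and plain gradient descent replaces a real optimiser to keep it short.

```python
import random
from functools import partial

import jax
import jax.numpy as jnp


def rollout(params, state, num_steps):
    # Hypothetical stand-in for nca.multi_step: a small per-cell update
    def forward(carry, _):
        return carry + 0.1 * jnp.tanh(carry @ params["w"]), None

    final_state, _ = jax.lax.scan(forward, state, None, length=num_steps)
    return final_state


def loss_fn(params, seed, target, num_steps):
    final_state = rollout(params, seed, num_steps)
    num_target = target.shape[-1]
    # Squared error on the target channels only; hidden channels are unconstrained
    return jnp.mean((final_state[..., :num_target] - target) ** 2)


@partial(jax.jit, static_argnames="num_steps")
def loss_and_grad(params, seed, target, num_steps):
    return jax.value_and_grad(loss_fn)(params, seed, target, num_steps)


def train_step(params, seed, target, min_steps=64, max_steps=96, lr=2e-4):
    # Sample the rollout length in Python (inclusive range) so it stays static for jit
    num_steps = random.randint(min_steps, max_steps)
    loss, grads = loss_and_grad(params, seed, target, num_steps=num_steps)
    # Plain gradient descent, to keep the sketch short
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```

Randomising the step count means the model is not rewarded for hitting the target at one exact step; each distinct num_steps compiles once and is then reused.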



