
Cellular Automata Accelerated in JAX


CAX: Cellular Automata Accelerated

CAX is a high-performance cellular automata library built on top of JAX and Flax and designed for flexibility.

Overview 🔎

CAX is a library for implementing and accelerating many types of cellular automata with JAX. Whether you're a researcher, a hobbyist, or just curious about emergent and self-organizing systems, CAX has you covered! 🧬

Despite their conceptual simplicity, cellular automata often demand significant computational resources. The parallel update of numerous cells, coupled with the need for backpropagation through time in neural cellular automata training, can render these models computationally intensive. CAX leverages hardware accelerators and massive parallelization to run cellular automata experiments in minutes. 🚀
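To make the cost of backpropagation through time concrete, here is a minimal pure-JAX sketch (not CAX's API): a toy continuous CA with a learnable 3x3 kernel is unrolled with `jax.lax.scan` and differentiated with `jax.value_and_grad`. Reverse-mode differentiation through the scan keeps intermediate states from every step, which is roughly where the memory and compute cost comes from. The grid size, the 0.1 step size, the periodic boundary and the random target are arbitrary assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def neighborhood_conv(x, kernel):
    """3x3 cross-correlation on a 2D torus, built from shifted copies of the grid."""
    out = jnp.zeros_like(x)
    for di in range(3):
        for dj in range(3):
            # Rolling by (1 - di, 1 - dj) puts neighbor (i + di - 1, j + dj - 1) at (i, j).
            out = out + kernel[di, dj] * jnp.roll(x, (1 - di, 1 - dj), axis=(0, 1))
    return out


def step(x, kernel):
    # Continuous, residual update applied to every cell in parallel.
    return x + 0.1 * jnp.tanh(neighborhood_conv(x, kernel))


def rollout_loss(kernel, x0, target, num_steps):
    def body(x, _):
        return step(x, kernel), None

    x_final, _ = jax.lax.scan(body, x0, None, length=num_steps)
    return jnp.mean((x_final - target) ** 2)


# value_and_grad differentiates through all num_steps updates (BPTT);
# num_steps is static because it fixes the length of the unrolled scan.
loss_and_grad = jax.jit(jax.value_and_grad(rollout_loss), static_argnums=3)

k_kernel, k_x0, k_target = jax.random.split(jax.random.PRNGKey(0), 3)
kernel = 0.1 * jax.random.normal(k_kernel, (3, 3))
x0 = jax.random.normal(k_x0, (32, 32))
target = jax.random.normal(k_target, (32, 32))

loss, grad = loss_and_grad(kernel, x0, target, 64)
print(loss, grad.shape)  # grad has the same shape as the kernel: (3, 3)
```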

The library works with discrete or continuous cellular automata of any spatial dimension, offering exceptional flexibility. From simulating one-dimensional elementary cellular automata to training three-dimensional self-autoencoding neural cellular automata, and even creating beautiful Lenia simulations, CAX provides a versatile platform for exploring the rich world of self-organizing systems. ✨
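As an illustration of dimension-agnostic, fully parallel updates, the sketch below (plain JAX, not CAX's implementation) counts Moore neighbors on a torus of any dimension using `jnp.roll`, reuses that count for a Game of Life step, and batches independent grids with `jax.vmap`. The helper names `moore_neighbor_sum` and `life_step` are hypothetical, and the periodic boundary is an assumption.

```python
import itertools

import jax
import jax.numpy as jnp


def moore_neighbor_sum(state):
    """Sum over the 3**d - 1 Moore neighbors of every cell on a d-dimensional torus."""
    axes = tuple(range(state.ndim))
    total = jnp.zeros_like(state)
    for offset in itertools.product((-1, 0, 1), repeat=state.ndim):
        if any(offset):  # skip the all-zero offset (the cell itself)
            total = total + jnp.roll(state, shift=offset, axis=axes)
    return total


@jax.jit
def life_step(grid):
    """One synchronous Game of Life update (birth on 3, survival on 2 or 3)."""
    n = moore_neighbor_sum(grid)
    born = (grid == 0) & (n == 3)
    survives = (grid == 1) & ((n == 2) | (n == 3))
    return (born | survives).astype(grid.dtype)


# The same neighbor count works unchanged in 1D and 3D.
print(moore_neighbor_sum(jnp.ones(8, dtype=jnp.int32))[0])  # 2
print(moore_neighbor_sum(jnp.ones((4, 4, 4), dtype=jnp.int32))[0, 0, 0])  # 26

# A horizontal "blinker" on a 5x5 torus flips to vertical and back.
blinker = jnp.zeros((5, 5), dtype=jnp.int32).at[2, 1:4].set(1)
print(life_step(blinker))

# vmap updates a whole batch of independent grids in one call.
batch = jnp.stack([blinker, life_step(blinker)])
batched = jax.vmap(life_step)(batch)
```

Because every cell reads only the previous grid, the whole update is a handful of array operations that XLA can run in parallel on an accelerator.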

Implemented Cellular Automata 🦎

| Cellular Automata | Reference | Example |
| --- | --- | --- |
| Elementary Cellular Automata | Wolfram, Stephen (2002) | Colab |
| Conway's Game of Life | Gardner, Martin (1970) | Colab |
| Lenia | Chan, Bert Wang-Chak (2020) | Colab |
| Growing Neural Cellular Automata | Mordvintsev, et al. (2020) | Colab |
| Growing Conditional Neural Cellular Automata | Faldor, et al. (2024) | Colab |
| Growing Unsupervised Neural Cellular Automata | Faldor, et al. (2024) | Colab |
| Self-classifying MNIST Digits | Randazzo, et al. (2020) | Colab |
| Self-autoencoding MNIST Digits | Faldor, et al. (2024) | Colab |
| Diffusing Neural Cellular Automata | Faldor, et al. (2024) | Colab |
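The simplest entry in the table, the elementary cellular automaton, fits in a few lines of plain JAX. This sketch is not CAX's implementation: it encodes the rule by its Wolfram number (0 to 255), reads each (left, center, right) neighborhood as a 3-bit index into that number, and unrolls the automaton with `jax.lax.scan`. The periodic boundary and the names `eca_step` and `eca_rollout` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def eca_step(cells, rule):
    """One update of an elementary CA with Wolfram rule number `rule` (0-255)."""
    left = jnp.roll(cells, 1)    # left[i] == cells[i - 1] (periodic boundary)
    right = jnp.roll(cells, -1)  # right[i] == cells[i + 1]
    # Read each (left, center, right) neighborhood as a 3-bit number 0..7 ...
    pattern = 4 * left + 2 * cells + right
    # ... and look up the matching bit of the rule number.
    return jnp.right_shift(rule, pattern) & 1


def eca_rollout(cells, rule, num_steps):
    """Return the initial row plus num_steps updated rows, shape (num_steps + 1, width)."""
    def body(row, _):
        nxt = eca_step(row, rule)
        return nxt, nxt

    _, history = jax.lax.scan(body, cells, None, length=num_steps)
    return jnp.concatenate([cells[None], history])


eca_rollout = jax.jit(eca_rollout, static_argnums=(1, 2))

# Rule 90 grown from a single live cell draws a Sierpinski triangle.
cells = jnp.zeros(9, dtype=jnp.int32).at[4].set(1)
history = eca_rollout(cells, 90, 2)
print(history)  # rows: 000010000, 000101000, 001000100
```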

Installation 🛠️

You will need Python 3.12 or later. Then, install CAX from PyPI:

pip install cax

To install the latest development version directly from GitHub, use:

pip install --upgrade git+https://github.com/maxencefaldor/cax.git

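By default, CAX runs on CPU. To run it on an NVIDIA GPU (CUDA 12) or a TPU, install the matching JAX build (the quotes stop shells such as zsh from expanding the brackets):

pip install -U "jax[cuda12]"
pip install -U "jax[tpu]"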
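After installing, you can check which backend JAX picked up. This uses only standard JAX calls; it assumes CAX simply runs on whatever backend JAX reports.

```python
import jax

# Name of the default backend JAX will use, e.g. "cpu", "gpu" or "tpu".
print(jax.default_backend())
# List of the devices JAX can see on that backend.
print(jax.devices())
```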

Getting Started 🚦

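import jax
from cax.core.ca import CA
from cax.core.perceive.dwconv_perceive import DWConvPerceive
from cax.core.update.nca_update import NCAUpdate
from flax import nnx

seed = 0

channel_size = 16  # number of state channels per cell
num_kernels = 3  # perception kernels per channel
hidden_size = 128  # width of the update MLP's hidden layer
cell_dropout_rate = 0.5  # stochastic updates: rate at which cells are dropped from each step

key = jax.random.PRNGKey(seed)
rngs = nnx.Rngs(seed)

# Perception: depthwise convolution over each cell's neighborhood
perceive = DWConvPerceive(channel_size, rngs)
# Update: maps the num_kernels * channel_size perceived features to a state update
update = NCAUpdate(
    channel_size,
    num_kernels * channel_size,
    (hidden_size,),
    rngs,
    cell_dropout_rate=cell_dropout_rate,
)
ca = CA(perceive, update)

# Random 64x64 initial state with channel_size channels, then 128 CA steps
state = jax.random.normal(key, (64, 64, channel_size))
state = ca(state, num_steps=128)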
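For intuition about what a perceive/update pair like `DWConvPerceive` and `NCAUpdate` computes, here is a pure-JAX sketch of the Growing NCA recipe of Mordvintsev et al. (2020), not CAX's internals. Fixed identity and Sobel filters applied depthwise give `num_kernels * channel_size = 3 * 16` features per cell, a per-cell MLP with a zero-initialized output layer maps them to a residual update, and each cell skips the update with probability `cell_dropout_rate`. CAX's perception may use learnable kernels instead; the fixed filters, the periodic boundary and the helper names (`perceive`, `update`, `init_update_params`, `ca_step`) are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

# Fixed 3x3 filters in the style of Growing NCA: identity plus Sobel x / y.
IDENTITY = jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
SOBEL_X = jnp.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) / 8.0
SOBEL_Y = SOBEL_X.T


def filter3x3(x, kernel):
    """Apply one 3x3 kernel to every channel of an (H, W, C) grid on a torus."""
    out = jnp.zeros_like(x)
    for di in range(3):
        for dj in range(3):
            out = out + kernel[di, dj] * jnp.roll(x, (1 - di, 1 - dj), axis=(0, 1))
    return out


def perceive(state):
    """Depthwise perception: (H, W, C) -> (H, W, 3 * C)."""
    return jnp.concatenate([filter3x3(state, k) for k in (IDENTITY, SOBEL_X, SOBEL_Y)], axis=-1)


def init_update_params(key, channel_size, num_kernels=3, hidden_size=128):
    in_size = num_kernels * channel_size
    w1 = jax.random.normal(key, (in_size, hidden_size)) / jnp.sqrt(in_size)
    b1 = jnp.zeros(hidden_size)
    w2 = jnp.zeros((hidden_size, channel_size))  # zero init: initial update is a no-op
    b2 = jnp.zeros(channel_size)
    return w1, b1, w2, b2


def update(params, state, perception, key, cell_dropout_rate=0.5):
    w1, b1, w2, b2 = params
    hidden = jax.nn.relu(perception @ w1 + b1)  # per-cell MLP, i.e. a 1x1 convolution
    delta = hidden @ w2 + b2
    # Stochastic update: each cell is kept with probability 1 - cell_dropout_rate.
    keep = jax.random.bernoulli(key, 1.0 - cell_dropout_rate, state.shape[:-1] + (1,))
    return state + delta * keep


channel_size = 16
k_params, k_state, k_run = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_update_params(k_params, channel_size)
state0 = jax.random.normal(k_state, (64, 64, channel_size))


@jax.jit
def ca_step(state, key):
    return update(params, state, perceive(state), key)


state = state0
for step_key in jax.random.split(k_run, 8):
    state = ca_step(state, step_key)
```

With the output layer zero-initialized, these untrained steps leave the state unchanged; training moves `w2` away from zero so the cells start to grow a pattern.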

Citing CAX 📝

To cite this repository:

@software{cax2024,
	author = {Faldor, Maxence and Cully, Antoine},
	title = {{CAX}: Cellular Automata Accelerated in {JAX}},
	url = {http://github.com/maxencefaldor/cax},
	version = {0.1.0},
	year = {2024},
}


Download files

Source distribution: cax-0.1.1.tar.gz (17.5 kB)
Built distribution: cax-0.1.1-py3-none-any.whl (22.4 kB, Python 3)
