Highly flexible, input/output-space-agnostic neural network models in JAX.
Project description
JaxZoo
jaxzoo is a package built on top of JAX and Flax (a neural network library for JAX) that provides a zoo of neural network models. It is designed to be easy to use and easy to extend. JaxZoo models automatically adapt to their input and output spaces and can be created with one or two lines of code, for example:
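import jax
from jaxzoo.mlp import JaxzooMLP
from jaxzoo.spaces import ContinuousSpace, DiscreteSpace, TupleSpace, DictSpace, ProbabilitySpace
key_random = jax.random.PRNGKey(0)
key_random, subkey = jax.random.split(key_random)
model = JaxzooMLP(
    space_input=DictSpace({
        "figures": TupleSpace([DiscreteSpace(10), DiscreteSpace(10)]),
        "embedding": ContinuousSpace(64),
        "image": ContinuousSpace((28, 28), low=0.0, high=1.0),
    }),
    space_output=ProbabilitySpace(10),
    hidden_dims=[32],
    name_activation_fn="swish",
)
variables = model.get_initialized_variables(key_random=subkey)
print(f"Model table summary : {model.get_table_summary()}")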
JaxZoo's two main features are simplicity, since a model can be created with one or two lines of code, and flexibility, since JaxZoo supports a wide range of input and output spaces (discrete, continuous, hierarchical) and automatically adapts the model to their structure.
Installation
You need to install NumPy, JAX, and Flax before installing JaxZoo:
pip install jaxzoo
Quickstart
To create a simple MLP model that takes images of shape (28, 28, 3) as input and outputs a probability vector of shape (10,), you can write:
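import jax
from jaxzoo.mlp import JaxzooMLP
from jaxzoo.spaces import ContinuousSpace, ProbabilitySpace
key_random = jax.random.PRNGKey(0)
key_random, subkey = jax.random.split(key_random)
model = JaxzooMLP(
    space_input=ContinuousSpace((28, 28, 3)),
    space_output=ProbabilitySpace(10),
    hidden_dims=[32],
    name_activation_fn="swish",
)
variables = model.get_initialized_variables(key_random=subkey)
print(f"Model table summary : {model.get_table_summary()}")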
Features
Simple basic models
JaxZoo provides a zoo of simple models, each of which can be created with one or two lines of code:
- JaxzooMLP: a simple MLP model
- JaxzooCNN: a simple CNN model
Input space agnosticism
You can give the model any kind of input space, including continuous, discrete, and hierarchical (nested) spaces. The model automatically adapts to the input space, provided the space stays the same for as long as the model is used.
For example, if your input is a dictionary containing a tuple of two digits, an embedding vector, and an image with values between 0 and 1, you can define:
from jaxzoo.spaces import Space, ContinuousSpace, DiscreteSpace, TupleSpace, DictSpace
# Define the input space
space_input = DictSpace({
    "figures": TupleSpace([DiscreteSpace(10), DiscreteSpace(10)]),
    "embedding": ContinuousSpace(64),
    "image": ContinuousSpace((28, 28), low=0.0, high=1.0),
})
Under the hood, the model processes the input space hierarchically, applying a model-specific function to each sub-input. For example, JaxzooMLP flattens and concatenates all input components before applying its MLP layers, while JaxzooCNN applies a CNN to the image inputs, then concatenates the resulting embedding with the non-image inputs before applying an MLP.
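The flatten-and-concatenate step described above for JaxzooMLP can be pictured with plain JAX. This is an illustrative sketch under assumptions, not JaxZoo's actual code: the helper `flatten_and_concat` is hypothetical, and it simply casts discrete values to floats, whereas the library may encode them differently (for example as one-hot vectors).

```python
import jax
import jax.numpy as jnp


def flatten_and_concat(x):
    """Flatten every leaf of a nested dict/tuple input and concatenate them.

    Hypothetical helper for illustration; not part of the jaxzoo API.
    """
    # tree_leaves walks the nested structure; dict keys are visited in sorted order.
    leaves = jax.tree_util.tree_leaves(x)
    # Assumption: discrete values are used as raw numbers, cast to float32,
    # and every leaf is flattened to 1D before concatenation.
    return jnp.concatenate(
        [jnp.ravel(jnp.asarray(leaf, dtype=jnp.float32)) for leaf in leaves]
    )


# One data point matching the DictSpace defined above.
x = {
    "figures": (jnp.array(3), jnp.array(7)),
    "embedding": jnp.zeros(64),
    "image": jnp.ones((28, 28)),
}
vec = flatten_and_concat(x)
print(vec.shape)  # (850,) = 64 + 2 + 28 * 28
```

Because the length of this vector fixes the input size of the first dense layer, this sketch also suggests why the input space has to stay constant while the model is used.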
Output space agnosticism
As with the input space, you can give the model any kind of output space, except multi-dimensional (2D and higher) spaces, which are not yet supported.
For example, if your output is a dictionary containing a probability vector of size 10 and an embedding of size 2, you can define:
import jax.numpy as jnp
from jaxzoo.spaces import Space, ContinuousSpace, DiscreteSpace, TupleSpace, DictSpace, ProbabilitySpace
# Define the output space
space_output = DictSpace({
    "probability": ProbabilitySpace(10),
    "embedding": ContinuousSpace(2, low=jnp.array([-1.0, jnp.nan]), high=jnp.inf),
})
Under the hood, each model processes the input into an embedding vector. If this embedding already matches the output space, the model outputs it directly. Otherwise, for each sub-output space, the model applies a space-specific operation to the embedding: for example, a shape constraint leads to a dense layer, and a probability space leads to a softmax.
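The per-sub-output step can be sketched the same way for the output space defined above. The example assumes a 32-dimensional embedding and gives each sub-output its own dense layer, applying a softmax only to sub-outputs marked as probability spaces. `make_head_params` and `apply_heads` are hypothetical names, and the bound constraints of `ContinuousSpace` are not modelled here.

```python
import jax
import jax.numpy as jnp


def make_head_params(key, dim_embedding, dims_out):
    """Create one dense layer (W, b) per named sub-output (hypothetical helper)."""
    params = {}
    for name, dim in sorted(dims_out.items()):
        key, subkey = jax.random.split(key)
        params[name] = {
            "W": 0.1 * jax.random.normal(subkey, (dim_embedding, dim)),
            "b": jnp.zeros(dim),
        }
    return params


def apply_heads(params, embedding, probability_keys):
    """Map a shared embedding to each sub-output (hypothetical helper)."""
    out = {}
    for name, p in params.items():
        # Shape constraint: a dense layer projects the embedding to the target size.
        logits = embedding @ p["W"] + p["b"]
        # Probability space: a softmax makes the values non-negative and sum to 1.
        out[name] = jax.nn.softmax(logits) if name in probability_keys else logits
    return out


key = jax.random.PRNGKey(0)
params = make_head_params(key, dim_embedding=32, dims_out={"probability": 10, "embedding": 2})
embedding = jnp.ones(32)  # stand-in for the embedding produced by the model body
out = apply_heads(params, embedding, probability_keys={"probability"})
print({name: value.shape for name, value in out.items()})  # {'embedding': (2,), 'probability': (10,)}
```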
Stochastic models
JaxZoo supports stochastic models, i.e. models that take a random key as input. The key is passed to the model as an additional argument and is used to generate random numbers inside the model.
import jax
# `model`, `variables` and an input `x` as defined in the previous examples
key_random = jax.random.PRNGKey(0)
key_random, subkey = jax.random.split(key_random)
pred = model.apply(variables=variables, x=x, key_random=subkey)
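To illustrate what a random key does inside a model, here is a minimal stochastic forward pass in plain JAX: a linear layer followed by dropout. It is a hypothetical example (`noisy_forward` is not a JaxZoo function). It shows that reusing a key reproduces exactly the same randomness, which is why a fresh subkey is split off before each call.

```python
import jax
import jax.numpy as jnp


def noisy_forward(params, x, key_random, rate=0.5):
    """Linear layer followed by dropout driven by an explicit key (hypothetical)."""
    h = x @ params["W"] + params["b"]
    # Keep each unit with probability 1 - rate; the key fully determines the mask.
    keep = jax.random.bernoulli(key_random, p=1.0 - rate, shape=h.shape)
    # Rescale kept units so the expected activation is unchanged.
    return jnp.where(keep, h / (1.0 - rate), 0.0)


params = {"W": jnp.ones((4, 8)), "b": jnp.zeros(8)}
x = jnp.ones(4)

key_random = jax.random.PRNGKey(0)
key_random, subkey = jax.random.split(key_random)
y1 = noisy_forward(params, x, subkey)
y2 = noisy_forward(params, x, subkey)  # same key: identical output
key_random, subkey = jax.random.split(key_random)
y3 = noisy_forward(params, x, subkey)  # fresh key: an independently drawn dropout mask
```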
Model summary
You can get a summary of the model by calling its get_table_summary method, which gives you a table listing the input and output spaces, the number of parameters, the number of layers, and more.
print(f"Model table summary : {model.get_table_summary()}")
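As a rough cross-check of the parameter count in such a summary, a plain stack of dense layers has `d_in * d_out + d_out` parameters per layer (weights plus biases). The hypothetical helper below applies this to an MLP with a flattened (28, 28, 3) input, `hidden_dims=[32]`, and 10 outputs. It assumes one dense layer per hidden size plus one output layer, so the total JaxZoo reports may differ if its architecture adds other parameters.

```python
def count_dense_mlp_params(dim_in, hidden_dims, dim_out):
    """Weights + biases of a plain stack of dense layers (hypothetical helper)."""
    dims = [dim_in, *hidden_dims, dim_out]
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims[:-1], dims[1:]))


# Flattened (28, 28, 3) input, one hidden layer of 32 units, 10 outputs.
n_params = count_dense_mlp_params(28 * 28 * 3, [32], 10)
print(n_params)  # 75626 = (2352 * 32 + 32) + (32 * 10 + 10)
```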
Inference on single and batched data
Use the apply method to run the model on a single data point, or apply_batched to run it on a batch of data points stacked along a leading batch dimension.
# Single data point
pred = model.apply(variables=variables, x=x_batch[0], key_random=subkey)
# Batch of data points
pred = model.apply_batched(variables=variables, x=x_batch, key_random=subkey)
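In JAX, batched inference is commonly built from a single-example function with `jax.vmap`. The sketch below shows that pattern; it is an assumption about how a method like apply_batched could work, not JaxZoo's code, and `apply_single` and `apply_batched` here are hypothetical stand-ins that take plain parameter dicts. It also splits the key so that each example gets its own random stream.

```python
import jax
import jax.numpy as jnp


def apply_single(params, x, key_random):
    """Forward pass for ONE data point of shape (dim_in,) (hypothetical model)."""
    del key_random  # unused here; a stochastic model would draw its randomness from it
    return jax.nn.softmax(x @ params["W"] + params["b"])


def apply_batched(params, x_batch, key_random):
    """Vectorize apply_single over the leading batch axis of x_batch."""
    keys = jax.random.split(key_random, x_batch.shape[0])  # one key per example
    # in_axes: share params, map over axis 0 of the inputs and of the keys.
    return jax.vmap(apply_single, in_axes=(None, 0, 0))(params, x_batch, keys)


params = {"W": jnp.arange(15.0).reshape(5, 3) / 10.0, "b": jnp.zeros(3)}
x_batch = jnp.ones((4, 5))
key_random = jax.random.PRNGKey(0)
preds = apply_batched(params, x_batch, key_random)
print(preds.shape)  # (4, 3)
```

The batched result row for example `i` equals the single-example output for `x_batch[i]`, which mirrors the relationship between apply and apply_batched described above.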
Project details
Download files
Source Distribution
Built Distribution
File details
Details for the file jaxzoo-1.0.tar.gz.
File metadata
- Download URL: jaxzoo-1.0.tar.gz
- Upload date:
- Size: 16.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6c45f3865003c6f96a7baba654e96debacddc3d43ba353212064bd0185243ff1 |
| MD5 | 940f2f3f72d722d1e72ae5212d5af3a7 |
| BLAKE2b-256 | 08faba184b4348a008ebea30e5d5fb37c7088f2ddf083cd707189eaa2fae236c |
File details
Details for the file jaxzoo-1.0-py3-none-any.whl.
File metadata
- Download URL: jaxzoo-1.0-py3-none-any.whl
- Upload date:
- Size: 19.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 21b9ed1546799d07bac0f11cdd8761109759ba885855e358ba7aaf65edcb77ac |
| MD5 | c315fd3c92550bdc11c75cb99c8b7192 |
| BLAKE2b-256 | ebc750a1d2ae2a993fa44facc3573267a0e2d55dad7e92dfaceb05f93a45f3f9 |