
Highly flexible, input/output-space-agnostic neural network models in JAX.

Project description

JaxZoo

JaxZoo is a package built on top of JAX and Flax (a neural network library for JAX) that provides a zoo of neural network models. It is designed to be easy to use and easy to extend. JaxZoo models adapt to their input and output spaces and can be created in one or two lines of code.

import jax
from jaxzoo.mlp import JaxzooMLP
from jaxzoo.spaces import ContinuousSpace, DiscreteSpace, TupleSpace, DictSpace, ProbabilitySpace

key_random, subkey = jax.random.split(jax.random.PRNGKey(0))

model = JaxzooMLP(
    space_input=DictSpace({
        "figures" : TupleSpace([DiscreteSpace(10), DiscreteSpace(10)]),
        "embedding" : ContinuousSpace(64),
        "image" : ContinuousSpace((28, 28), low=0.0, high=1.0),
    }),
    space_output=ProbabilitySpace(10),
    hidden_dims=[32],
    name_activation_fn="swish",
)
variables = model.get_initialized_variables(key_random=subkey)

print(f"Model table summary : {model.get_table_summary()}")

JaxZoo's two main features are simplicity and flexibility. Simplicity: you can create a model in one or two lines of code. Flexibility: JaxZoo supports discrete, continuous and hierarchical (nested) input and output spaces, and its models automatically adapt to the structure of their inputs and outputs.

Installation

You need to install NumPy, JAX and Flax before installing JaxZoo.

pip install jaxzoo

Quickstart

To create a simple MLP model that takes images of shape (28, 28, 3) as input and outputs a probability vector of shape (10,), you can write:

import jax
from jaxzoo.mlp import JaxzooMLP
from jaxzoo.spaces import ContinuousSpace, ProbabilitySpace

key_random, subkey = jax.random.split(jax.random.PRNGKey(0))

model = JaxzooMLP(
    space_input=ContinuousSpace((28, 28, 3)),
    space_output=ProbabilitySpace(10),
    hidden_dims=[32],
    name_activation_fn="swish",
)
variables = model.get_initialized_variables(key_random=subkey)
print(f"Model table summary : {model.get_table_summary()}")

Features

Simple basic models

JaxZoo provides a zoo of simple models, each of which can be created in one or two lines of code. The available models are:

  • JaxzooMLP: a simple multilayer perceptron (MLP)
  • JaxzooCNN: a simple convolutional neural network (CNN)
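For intuition about what an MLP computes once its input has been flattened into a single vector, here is a plain-JAX sketch of a forward pass with `hidden_dims=[32]` and a swish activation. It is not JaxZoo's implementation: the helpers `init_mlp` and `mlp_forward`, the scaled Gaussian initialisation and the final softmax are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, hidden_dims, out_dim):
    """Create (weight, bias) pairs for layers in_dim -> *hidden_dims -> out_dim."""
    dims = [in_dim, *hidden_dims, out_dim]
    keys = jax.random.split(key, len(dims) - 1)
    params = []
    for k, d_in, d_out in zip(keys, dims[:-1], dims[1:]):
        w = jax.random.normal(k, (d_in, d_out)) / jnp.sqrt(d_in)  # scaled Gaussian init
        params.append((w, jnp.zeros((d_out,))))
    return params


def mlp_forward(params, x):
    """Dense layers with swish in between; the last layer stays linear."""
    for w, b in params[:-1]:
        x = jax.nn.swish(x @ w + b)
    w, b = params[-1]
    return x @ w + b


key = jax.random.PRNGKey(0)
params = init_mlp(key, in_dim=28 * 28 * 3, hidden_dims=[32], out_dim=10)
x = jnp.ones((28 * 28 * 3,))  # an already-flattened (28, 28, 3) input
logits = mlp_forward(params, x)
probs = jax.nn.softmax(logits)  # turn logits into a probability vector of size 10
```

With `hidden_dims=[32]` there are exactly two dense layers: input to 32 hidden units, then 32 to 10 outputs.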

Input space agnosticism

You can give the model any kind of input space, including continuous, discrete and hierarchical spaces. The model automatically adapts to the input space, provided that the space stays the same for as long as the model is used.

For example, if your input is a dictionary containing a tuple of two digits, an embedding vector and an image with values between 0 and 1, you can write:

from jaxzoo.spaces import Space, ContinuousSpace, DiscreteSpace, TupleSpace, DictSpace

# Define the input space
space_input = DictSpace({
    "figures" : TupleSpace([DiscreteSpace(10), DiscreteSpace(10)]),
    "embedding" : ContinuousSpace(64),
    "image" : ContinuousSpace((28, 28), low=0.0, high=1.0),
})

Under the hood, the model processes the input space hierarchically, applying model-specific functions to each sub-input. For example, JaxzooMLP flattens and concatenates all input components before applying the MLP layers, while JaxzooCNN applies a CNN to the image inputs and concatenates the resulting embedding with the non-image inputs before applying an MLP.
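As a rough sketch of the MLP-style treatment described above, the snippet below flattens every component of a nested input and concatenates them into one vector with `jax.tree_util`. The helper `flatten_and_concat` is hypothetical, and treating each discrete value as a single float is an assumption; JaxZoo may encode discrete components differently (for example as one-hot vectors). The sketch also suggests why the input space should stay fixed: the length of the flattened vector determines the size of the first dense layer.

```python
import jax
import jax.numpy as jnp


def flatten_and_concat(x):
    """Flatten every leaf of a nested input (dicts, tuples, arrays) and concatenate them.

    Discrete values are simply cast to float here (an assumption for illustration).
    Dict leaves are visited in sorted key order by jax.tree_util.
    """
    leaves = jax.tree_util.tree_leaves(x)
    return jnp.concatenate([jnp.ravel(jnp.asarray(leaf, dtype=jnp.float32)) for leaf in leaves])


# A hypothetical input matching the DictSpace above: two digits, a 64-d embedding, a 28x28 image.
x = {
    "figures": (3, 7),
    "embedding": jnp.zeros((64,)),
    "image": jnp.full((28, 28), 0.5),
}
flat = flatten_and_concat(x)
print(flat.shape)  # (850,) = 64 + 2 + 28 * 28
```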

Output space agnosticism

As with the input space, you can give the model any kind of output space, except spaces with two or more dimensions, which are not yet supported.

For example, if your output is a dictionary containing a probability vector of size 10 and an embedding of size 2, you can write:

import jax.numpy as jnp
from jaxzoo.spaces import ContinuousSpace, DictSpace, ProbabilitySpace

# Define the output space
space_output = DictSpace({
    "probability" : ProbabilitySpace(10),
    # lower bounds are given per coordinate; high=jnp.inf leaves the upper side unbounded
    "embedding" : ContinuousSpace(2, low=jnp.array([-1.0, jnp.nan]), high=jnp.inf),
})

Under the hood, each model processes the input into an embedding vector. If this embedding already matches the output space, the model outputs it directly. Otherwise, for each sub-output space the model applies a space-specific operation to the embedding: for example, a shape constraint leads to a dense layer, and a probability space leads to a softmax.
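The space-wise output heads described above can be sketched in plain JAX as follows. This is an illustration, not JaxZoo's code: `OUTPUT_SPEC`, `init_heads` and `apply_heads` are hypothetical names, and using `low + softplus(...)` to respect a lower bound is one common choice that JaxZoo may or may not use. A single scalar lower bound is used for simplicity; the per-coordinate `jnp.nan`/`jnp.inf` bounds of the example above are not modelled.

```python
import jax
import jax.numpy as jnp

# Hypothetical output specification: name -> (kind, size, lower bound or None)
OUTPUT_SPEC = {
    "probability": ("probability", 10, None),
    "embedding": ("continuous", 2, -1.0),
}


def init_heads(key, embed_dim, spec):
    """One dense layer (weight, bias) per sub-output space."""
    keys = jax.random.split(key, len(spec))
    heads = {}
    for k, (name, (_, size, _)) in zip(keys, spec.items()):
        w = jax.random.normal(k, (embed_dim, size)) / jnp.sqrt(embed_dim)
        heads[name] = (w, jnp.zeros((size,)))
    return heads


def apply_heads(heads, embedding, spec):
    """Map a shared embedding to each sub-output with a space-specific operation."""
    out = {}
    for name, (kind, _, low) in spec.items():
        w, b = heads[name]
        y = embedding @ w + b  # shape constraint -> dense layer
        if kind == "probability":
            y = jax.nn.softmax(y)  # probability space -> softmax
        elif low is not None:
            y = low + jax.nn.softplus(y)  # lower bound -> shifted positive transform
        out[name] = y
    return out


key = jax.random.PRNGKey(0)
heads = init_heads(key, embed_dim=32, spec=OUTPUT_SPEC)
pred = apply_heads(heads, jnp.ones((32,)), OUTPUT_SPEC)
```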

Stochastic models

JaxZoo supports stochastic models, i.e. models that take a random key as input. The key is passed to the model as an additional argument and is used to generate random numbers inside the model.

import jax

# model, variables and an input x from the model's input space are defined beforehand
key_random = jax.random.PRNGKey(0)
key_random, subkey = jax.random.split(key_random)
pred = model.apply(variables=variables, x=x, key_random=subkey)
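Because JAX random number generation is explicit, reusing the same key reproduces the same random numbers, which is why the key is split before each call. The toy function `noisy_linear` below is hypothetical and only illustrates this behaviour; it is not part of JaxZoo.

```python
import jax
import jax.numpy as jnp


def noisy_linear(w, x, key_random, noise_scale=0.1):
    """A toy stochastic 'model': a linear map plus Gaussian noise drawn from key_random."""
    noise = noise_scale * jax.random.normal(key_random, (w.shape[1],))
    return x @ w + noise


w = jnp.eye(3)
x = jnp.array([1.0, 2.0, 3.0])

key_random = jax.random.PRNGKey(0)
key_random, subkey_a = jax.random.split(key_random)
key_random, subkey_b = jax.random.split(key_random)

y_a1 = noisy_linear(w, x, subkey_a)
y_a2 = noisy_linear(w, x, subkey_a)  # same key -> identical output
y_b = noisy_linear(w, x, subkey_b)   # fresh key -> different noise
```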

Model summary

Call the get_table_summary method to get a summary table of the model, listing the input and output spaces, the number of parameters, the number of layers, and more.

print(f"Model table summary : {model.get_table_summary()}")
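If you only need the parameter count, it can be computed directly from a nested dictionary of arrays. The sketch below assumes `variables` is a Flax-style nested dict (as returned by Flax's `init`); whether `get_initialized_variables` returns exactly this structure is an assumption, and `count_parameters` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def count_parameters(variables):
    """Total number of scalar entries across all arrays in a nested variables dict."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(variables))


# Hypothetical Flax-style variables for a 4 -> 32 -> 10 MLP.
variables = {
    "params": {
        "dense_0": {"kernel": jnp.zeros((4, 32)), "bias": jnp.zeros((32,))},
        "dense_1": {"kernel": jnp.zeros((32, 10)), "bias": jnp.zeros((10,))},
    }
}
print(count_parameters(variables))  # 490 = 4*32 + 32 + 32*10 + 10
```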

Inference for single data and batched data

Use the apply method to run the model on a single data point, or apply_batched to run it on a batch of data points.

# Single data point (x_batch stacks data points along a leading batch axis)
pred = model.apply(variables=variables, x=x_batch[0], key_random=subkey)

# Batch of data points
pred = model.apply_batched(variables=variables, x=x_batch, key_random=subkey)
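In plain JAX, `jax.vmap` is the standard way to turn a single-example function into a batched one. Whether `apply_batched` is built on `jax.vmap`, and how it uses the single `key_random` it receives, is an assumption; in this sketch each example gets its own key from `jax.random.split`, and `apply_single` is a hypothetical toy model.

```python
import jax
import jax.numpy as jnp


def apply_single(w, x, key_random):
    """Hypothetical single-example model: linear map plus small Gaussian noise."""
    return x @ w + 0.01 * jax.random.normal(key_random, (w.shape[1],))


# Vectorise over the batch axis of x and of the keys; the weights are shared (in_axes=None).
apply_batch = jax.vmap(apply_single, in_axes=(None, 0, 0))

w = jnp.ones((4, 2))
x_batch = jnp.arange(12.0).reshape(3, 4)  # 3 examples with 4 features each
keys = jax.random.split(jax.random.PRNGKey(0), x_batch.shape[0])  # one key per example

preds = apply_batch(w, x_batch, keys)
print(preds.shape)  # (3, 2)
```

Row `i` of `preds` matches calling `apply_single` on `x_batch[i]` with `keys[i]`, so the batched version is a drop-in replacement for a Python loop.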

Project details


Release history

This version

1.0

Download files

Download the file for your platform.

Source Distribution

jaxzoo-1.0.tar.gz (16.1 kB)

Uploaded Source

Built Distribution


jaxzoo-1.0-py3-none-any.whl (19.7 kB)

Uploaded Python 3

File details

Details for the file jaxzoo-1.0.tar.gz.

File metadata

  • Download URL: jaxzoo-1.0.tar.gz
  • Upload date:
  • Size: 16.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.12

File hashes

Hashes for jaxzoo-1.0.tar.gz
Algorithm Hash digest
SHA256 6c45f3865003c6f96a7baba654e96debacddc3d43ba353212064bd0185243ff1
MD5 940f2f3f72d722d1e72ae5212d5af3a7
BLAKE2b-256 08faba184b4348a008ebea30e5d5fb37c7088f2ddf083cd707189eaa2fae236c


File details

Details for the file jaxzoo-1.0-py3-none-any.whl.

File metadata

  • Download URL: jaxzoo-1.0-py3-none-any.whl
  • Upload date:
  • Size: 19.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.10.12

File hashes

Hashes for jaxzoo-1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 21b9ed1546799d07bac0f11cdd8761109759ba885855e358ba7aaf65edcb77ac
MD5 c315fd3c92550bdc11c75cb99c8b7192
BLAKE2b-256 ebc750a1d2ae2a993fa44facc3573267a0e2d55dad7e92dfaceb05f93a45f3f9

