
Neural Nets for JAX


JAXnet

JAXnet is a neural net library built with JAX. Unlike popular alternatives, its API is purely functional:

  • Immutable weights
  • No global compute graph
  • No global random key

This makes code more concise, robust, and easier to optimize (see the motivation below).

This is a preview. Expect breaking changes! Install with

pip3 install jaxnet

To use GPU/TPU, first install the right version of jaxlib.

API overview

from jaxnet import *

net = Sequential(Conv(2, (3, 3)), relu, flatten, Dense(4), softmax)

creates a neural net model. To initialize parameters, call init_params with a random key and example inputs:

from jax import numpy as np, jit
from jax.random import PRNGKey

inputs = np.zeros((3, 5, 5, 1))
params = net.init_params(PRNGKey(0), inputs)

print(params.dense.bias) # [-0.0178184   0.02460396 -0.00353479  0.00492503]

Invoke the network with:

output = net.apply(params, inputs) # use "jit(net.apply)(params, inputs)" for acceleration

Modules are defined as @parametrized functions that can use other modules:

@parametrized
def encode(images):
    hidden = Sequential(Dense(512), relu, Dense(512), relu)(images)
    means = Dense(10)(hidden)
    variances = Sequential(Dense(10), softplus)(hidden)
    return means, variances

All modules are composed in this way. Find more details on the API here.
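Such a custom module is then initialized and applied just like the built-in ones above. A minimal sketch, assuming encode exposes the same init_params/apply interface as Sequential (the input shape is a made-up example):

images = np.zeros((3, 5 * 5))  # example batch of 3 flattened inputs
encode_params = encode.init_params(PRNGKey(1), images)
means, variances = encode.apply(encode_params, images)  # use "jit(encode.apply)" for acceleration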

JAXnet allows step-by-step debugging with concrete values like any plain Python function (when jit compilation is not used).

See JAXnet in action in these demos: Mnist Classifier, Mnist VAE, OCR with RNNs (to be fixed), ResNet and WaveNet.

Why JAXnet?

Side effects and mutable state come at a cost. Deep learning is no exception.

Functional parameter handling allows concise regularization and reparametrization.

JAXnet makes things like L2 regularization and variational inference for models concise (see API). It also lets you regularize or reparametrize any custom module without changing its code.
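For example, because params is just a nested container of arrays, an L2 penalty can be written directly over it. A hedged sketch using plain JAX tree utilities rather than any JAXnet-specific helper (l2_penalty is a made-up name, and it assumes params is a JAX pytree of weight arrays):

from jax import numpy as np
from jax.tree_util import tree_leaves

def l2_penalty(params, scale=1e-4):
    # Sum of squared entries over every array in the nested params object.
    return scale * sum(np.sum(np.square(leaf)) for leaf in tree_leaves(params))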

In contrast, TensorFlow 2 requires:

  • Regularization arguments at the layer level, with a custom implementation for each layer type.
  • Reparametrization arguments at the layer level, again with a separate implementation for every layer.

Functional code allows new ways of optimization.

JAX allows functional numpy code to be accelerated with jit and run on GPU. Here are two use cases:

  • In JAXnet, weights are explicitly initialized into an object controlled by the user. Optimization returns a new version of the weights instead of mutating them in place, which allows whole training loops to be compiled / run on GPU (demo); see the sketch after this list.
  • If you use functional numpy/scipy for pre-/postprocessing, replacing numpy with jax.numpy in your imports lets you compile that code / run it on GPU as well (demo).
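Here is a hedged sketch of such a functional training step, using hand-rolled SGD with jax.grad rather than JAXnet's own optimizer API (loss and sgd_step are made-up names, and the mean-squared-error objective is only a placeholder):

from jax import grad, jit, numpy as np
from jax.tree_util import tree_map

def loss(params, inputs, targets):
    # Placeholder objective on the network output.
    return np.mean((net.apply(params, inputs) - targets) ** 2)

@jit
def sgd_step(params, inputs, targets, lr=0.1):
    grads = grad(loss)(params, inputs, targets)
    # Return a new params object instead of mutating the old one.
    return tree_map(lambda p, g: p - lr * g, params, grads)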

Reusing code relying on a global compute graph can be a hassle.

This is particularly true for more advanced use cases, say: You want to use existing TensorFlow code that manipulates variables via their global names. You need to instantiate this network with two different sets of weights and combine their outputs. Since you want your code to be fast, you'd like to run the combined network on GPU. While solutions exist, code like this is typically brittle and hard to maintain.

JAXnet has no global compute graph. All network definitions and weights are contained in (read-only) objects. This encourages code that is easy to reuse.

Global random state is inflexible.

Example: While training a VAE, you might want to see how reconstructions for a fixed latent variable sample improve over time. In popular frameworks, the easiest solution is typically to sample a latent variable once and resupply it to the network, requiring some extra code.

In JAXnet, you can fix the random key for this specific part of the network (demo).
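A minimal sketch of the idea in plain JAX: hold the key for that sampling step fixed across evaluations (decoder and decoder_params are hypothetical names for the relevant part of the VAE):

from jax.random import PRNGKey, normal

fixed_key = PRNGKey(42)                     # reused every time reconstructions are rendered
fixed_latents = normal(fixed_key, (1, 10))  # latent sample stays the same across training steps
# reconstructions = decoder.apply(decoder_params, fixed_latents)  # hypothetical decoder module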

What about existing libraries?

Here is a crude comparison with popular deep learning libraries:

                            TensorFlow2/Keras   PyTorch   JAXnet
Immutable weights                   ✗              ✗         ✓
No global compute graph             ✗              ✓         ✓
No global random key                ✗              ✗         ✓

JAXnet is independent of stax. The main motivation over stax is to simplify nesting modules. Find details and porting instructions here.

Questions

Feel free to create an issue on GitHub.

