Neural Nets for JAX

JAXnet is a neural net library for JAX. Unlike popular neural net libraries, it is completely functional:

  • No mutable weights in modules
  • No global compute graph
  • No global random key

This is an early version. Expect breaking changes! Install with:

pip install jaxnet

To run networks on GPU/TPU, first install the matching version of jaxlib.

Overview

Defining networks looks similar to the TensorFlow 2 / Keras functional API:

from jax import numpy as np, jit
from jax.random import PRNGKey
from jaxnet import *

net = Sequential([Conv(2, (3, 3)), relu, flatten, Dense(4), softmax])

To initialize parameter values for a network, call init_params on any module (with example inputs and a random key):

inputs = np.zeros((3, 5, 5, 1))
params = net.init_params(PRNGKey(0), inputs)

It initializes and returns all parameters, accessible via attributes:

print(params.layers[3].bias) # [0.00212132 0.01169001 0.00331698 0.00460713]

Invoke the network with:

output = net(params, inputs)

For acceleration, use jit:

output = jit(net)(params, inputs)
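Since parameters are an explicit argument, training is an ordinary functional loop. A minimal SGD step might look like the sketch below. The targets and the cross-entropy loss are illustrative, the parameter structure is assumed to be a JAX pytree, and tree_map over multiple trees requires a recent JAX (older versions call this tree_multimap):

from jax import grad, tree_util

targets = np.zeros((3, 4))  # hypothetical one-hot labels

def loss(params, inputs, targets):
    # illustrative cross-entropy on the softmax outputs
    return -np.mean(np.sum(np.log(net(params, inputs)) * targets, axis=-1))

grads = grad(loss)(params, inputs, targets)
params = tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)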

See JAXnet in action in these demos: MNIST VAE and OCR with RNNs. Alternative design ideas are discussed here.

Defining modules

Modules are functions decorated with @parametrized, with parameters defined through default values:

def Dense(out_dim, kernel_init=glorot(), bias_init=randn()):
    @parametrized
    def dense(inputs,
              kernel=Param(lambda inputs: (inputs.shape[-1], out_dim), kernel_init),
              bias=Param(lambda _: (out_dim,), bias_init)):
        return np.dot(inputs, kernel) + bias

    return dense

Param specifies the parameter's shape (as a function of the module's inputs) and its initialization function. @parametrized transforms the function so that it can be used as above.
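Following the same pattern, here is a sketch of a module with a single parameter; the Bias name and shape function are illustrative, not part of the library:

def Bias(init=randn()):
    @parametrized
    def bias_layer(inputs, bias=Param(lambda inputs: (inputs.shape[-1],), init)):
        # adds a learned bias matching the trailing input dimension
        return inputs + bias

    return bias_layer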

Nesting modules

Modules can be used in other modules through default arguments:

@parametrized
def net(inputs, layer1=Dense(10), layer2=Dense(20)):
    inputs = layer1(inputs)
    return layer2(inputs)

Use many modules at once with collections:

def Sequential(layers):
    @parametrized
    def sequential(inputs, layers=layers):
        for module in layers:
            inputs = module(inputs)
        return inputs

    return sequential

Nested tuples/lists/dicts of modules work, as in the sketch below. The same is true for Params.
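For example, a two-headed module can hold its submodules in a dict (a sketch; the module and key names are illustrative):

@parametrized
def multi_headed(inputs, heads={'policy': Dense(4), 'value': Dense(1)}):
    # each dict entry is a full submodule with its own parameters
    return heads['policy'](inputs), heads['value'](inputs)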

Using parameter-free functions is seamless:

def relu(inputs):
    return np.maximum(inputs, 0)

layer = Sequential([Dense(10), relu])

Parameter sharing

Parameters can be shared by using the same module or parameter object multiple times (not yet implemented):

shared_net = Sequential([layer, layer])

This is equivalent to (already implemented):

@parametrized
def shared_net(inputs, layer=layer):
    return layer(layer(inputs))
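Initialization and application then work as for any other module. A minimal sketch; the input shape is illustrative and must match the layer's output dimension so that the shared parameters fit both applications:

inputs = np.zeros((1, 10))
params = shared_net.init_params(PRNGKey(0), inputs)
output = shared_net(params, inputs)  # layer runs twice with the same parameters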

Parameter reuse

You can reuse parameters of submodules:

inputs = np.zeros((1, 2))

layer = Dense(5)
net1 = Sequential([layer, Dense(2)])
net2 = Sequential([layer, Dense(3)])

layer_params = layer.init_params(PRNGKey(0), inputs)
net1_params = net1.init_params(PRNGKey(0), inputs, reuse={layer: layer_params})
net2_params = net2.init_params(PRNGKey(1), inputs, reuse={layer: layer_params})

# Now net1_params.layers[0] equals net2_params.layers[0] equals layer_params

If all parameters are reused, you can use join_params instead of init_params:

inputs = np.zeros((1, 2))

net = Dense(5)
prediction = Sequential([net, softmax])

net_params = net.init_params(PRNGKey(0), inputs)
prediction_params = prediction.join_params({net: net_params})

# prediction_params.layers[0] is now equal to net_params

output = jit(prediction)(prediction_params, inputs)

If you just want to call the network with these joined parameters, you can use the shorthand:

output = prediction.apply_joined({net: net_params}, inputs, jit=True)

What about stax?

JAXnet is independent of stax. Its main advantage over stax is simpler nesting of modules. Find details and porting instructions here.
