JAXnet

Neural Nets for JAX

JAXnet is a neural net library for JAX. Unlike popular neural net libraries, it is completely functional:

  • No mutable weights in modules
  • No global compute graph
  • No global random key

This is an early version. Expect bugs, sharp edges, and breaking changes! Install with:

pip install jaxnet

Overview

Defining networks looks similar to the TensorFlow 2 / Keras functional API:

from jax import numpy as np, random, jit
from jaxnet import *

net = Sequential([Dense(10), relu, Dense(4)])

Sequential, Dense, Conv and RNN (with GRUCell) are already supported.
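For instance, a small convolutional stack can be composed in the same way (a sketch; the argument order out_chan, filter_shape for Conv is an assumption, check the API for the exact signature):

conv_net = Sequential([Conv(8, (3, 3)), relu, Dense(4)])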

To initialize parameter values for a network, call init_params on any module (with example inputs and a random key):

batch = np.zeros((3, 2))
params = net.init_params(random.PRNGKey(0), batch)

It initializes and returns all parameters, accessible via attributes:

print(params.layers[0].bias) # [0.00212132 0.01169001 0.00331698 0.00460713]

Invoke the network with:

output = net(params, batch)

For acceleration, use jit:

output = jit(net)(params, batch)
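Since parameters are explicit values rather than hidden state, training uses plain jax.grad. A minimal sketch of one gradient step, with an illustrative squared-error loss and zero targets:

from jax import grad

def loss(params, batch, targets):
    return np.mean((net(params, batch) - targets) ** 2)

targets = np.zeros((3, 4))  # net ends with Dense(4), so outputs have 4 features
grads = grad(loss)(params, batch, targets)  # gradients with the same structure as params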

Full examples are given in these interactive demos: MNIST VAE and OCR with RNNs.

Defining modules

Modules are functions decorated with @parameterized, with parameters defined through default values:

def Dense(out_dim, kernel_init=glorot(), bias_init=randn()):
    @parameterized
    def dense(inputs,
              kernel=Param(lambda inputs: (inputs.shape[-1], out_dim), kernel_init),
              bias=Param(lambda _: (out_dim,), bias_init)):
        return np.dot(inputs, kernel) + bias

    return dense

Param specifies parameter shape and initialization function. @parameterized transforms this function to allow usage as above.
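As another sketch in the same style, a per-feature scaling module (the module name and initializer choice here are illustrative):

def Scale(init=randn()):
    @parameterized
    def scale(inputs, factor=Param(lambda inputs: (inputs.shape[-1],), init)):
        return inputs * factor

    return scale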

Nesting modules

Modules can be used in other modules through default arguments:

@parameterized
def net(inputs, layer1=Dense(10), layer2=Dense(20)):
    inputs = layer1(inputs)
    return layer2(inputs)
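Initialization and application work as before. Assuming parameters are exposed under the argument names, access would look like:

batch = np.zeros((3, 2))
params = net.init_params(random.PRNGKey(0), batch)
print(params.layer1.bias)  # attribute name assumed to match the argument name
output = net(params, batch)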

Use many modules at once with collections:

def Sequential(layers):
    @parameterized
    def sequential(inputs, layers=layers):
        for module in layers:
            inputs = module(inputs)
        return inputs

    return sequential

Nested tuples/lists/dicts of modules work. The same is true for Params.
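For example, a module holding a dict of submodules (a sketch mirroring Sequential above):

@parameterized
def two_heads(inputs, heads={'small': Dense(2), 'large': Dense(3)}):
    return heads['small'](inputs), heads['large'](inputs)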

Using parameter-free functions is seamless:

def relu(inputs):
    return np.maximum(inputs, 0)

layer = Sequential([Dense(10), relu])

Parameter sharing

Parameters can be shared by using module or parameter objects multiple times (not yet implemented):

shared_net = Sequential([layer, layer])

This is equivalent to (already implemented):

@parameterized
def shared_net(input, layer=layer):
    return layer(layer(input))
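Initializing shared_net then yields a single set of parameters for layer, used in both applications (a sketch; since layer maps inputs to 10 features, shared inputs must also have 10 features):

params = shared_net.init_params(random.PRNGKey(0), np.zeros((1, 10)))
# layer's parameters appear once, even though layer is applied twice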

Parameter reuse

You can reuse parameters of submodules:

inputs = np.zeros((1, 2))

layer = Dense(5)
net1 = Sequential([layer, Dense(2)])
net2 = Sequential([layer, Dense(3)])

layer_params = layer.init_params(random.PRNGKey(0), inputs)
net1_params = net1.init_params(random.PRNGKey(0), inputs, reuse={layer: layer_params})
net2_params = net2.init_params(random.PRNGKey(1), inputs, reuse={layer: layer_params})

# Now net1_params.layers[0] equals net2_params.layers[0] equals layer_params

If all parameters are reused, you can use join_params instead of init_params:

inputs = np.zeros((1, 2))

net = Dense(5)
prediction = Sequential([net, softmax])

net_params = net.init_params(random.PRNGKey(0), inputs)
prediction_params = prediction.join_params({net: net_params})

# prediction_params.layers[0] is now equal to net_params

output = jit(prediction)(prediction_params, inputs)

If you just want to call the network with these joined parameters, you can use the shorthand:

output = prediction.apply_joined({net: net_params}, inputs, jit=True)

What about stax?

JAXnet is independent of stax. The main motivation over stax is to simplify the nesting of modules:

  • Automating init_params: delegation to submodules, output_shape inference, rng passing
  • Seamless use of parameter-free functions as modules
  • Allowing streamlined module/parameter-sharing

Alternative design ideas are discussed here.
