Jynx

A straightforward neural network library written in JAX. No hidden mechanisms, no black magic. Requires only jax and optax.

This library provides three components: (1) standard neural network layers, (2) a fit function to train models, and (3) a collection of basic callbacks for checkpointing and logging. The fit function knows nothing about how the layers are implemented and can be used with other frameworks if desired; it relies only on optax.

TL;DR

import jax
import jax.numpy as jnp
import jax.random as rnd
import optax

import jynx
import jynx.callbacks as cb
import jynx.layers as nn

from jax import Array


def data_iter(key: Array):
    from itertools import repeat
    x = jnp.linspace(-1, 1, 100).reshape(-1, 1)
    y = jnp.cos(x) + 0.05 * rnd.normal(key, x.shape)
    return repeat((x, y))


def loss_fn(
    net: nn.Module[[Array], Array],
    batch: tuple[Array, Array],  # any type you want
    key: Array
) -> Array:
    x, y = batch
    y_pred = net(x, key=key)
    return optax.l2_loss(y_pred, y).mean()


def make_model(key: Array) -> nn.Module[[Array], Array]:
    k1, k2, k3 = rnd.split(key, 3)
    net = nn.sequential(
        nn.linear(1, 32, key=k1),
        jax.nn.relu,
        nn.linear(32, 32, key=k2),
        jax.nn.relu,
        nn.linear(32, 1, key=k3),
    )
    # or use:
    # net = nn.mlp(
    #     [1, 32, 32, 1],
    #     activation=jax.nn.silu,
    #     key=key,
    # )
    return net


k1, k2, k3 = rnd.split(rnd.PRNGKey(0), 3)
res = jynx.fit(
    make_model(k1),
    loss_fn=loss_fn,
    data_iter=data_iter(k2),
    optimizer=optax.adam(1e-3),
    key=k3,
    callbacks=[
        cb.ConsoleLogger(metrics=["loss"]),
        cb.TensorBoardLogger("tboard"),  # requires tensorboardx
        cb.MlflowLogger(),  # requires mlflow
        cb.CheckPoint("latest.pk"),  # requires cloudpickle
        cb.EarlyStopping(
            monitor="loss",
            steps_without_improvement=500,
        ),
    ],
)
print("final loss", res.logs["loss"])
net = res.params

Layers

Currently implemented modules:

  • Sequential
  • Parallel
  • Recurrent: like Sequential but passes state, used for RNNs
  • DenselyConnected: DenseNet
  • Linear
  • Conv and ConvTranspose
  • Embedding
  • Fn: activation function
  • StarFn: equivalent of fn(*x)
  • Static: wraps an object to be ignored by jax
  • Reshape
  • Dropout
  • Pooling layers: AvgPooling, MaxPooling, MinPooling
  • Norm: layer norm, batch norm, etc.
  • SkipConnection
  • RNN layers: RNNCell, GRUCell, LSTMCell
  • Attention, TransformerEncoderBlock, TransformerDecoderBlock

Constructors for full networks:

  • mlp
  • transformer_encoder
  • transformer_decoder
  • rnn, lstm, and gru
  • More to come...

How layers work

Layers are simple pytree containers with a __call__ method. To make defining new modules easy, we provide a PyTree base class. Using it is not a requirement; it just makes most definitions simpler. Layers that don't need any static data can just as easily be defined as NamedTuples.

from typing import NamedTuple

from jax import Array


class MyLinear(NamedTuple):
    weight: Array
    bias: Array

    def __call__(self, x, *, key=None):
        return x @ self.weight + self.bias

We provide initialization via factory functions instead of in the constructor. This makes flattening and unflattening pytrees much simpler. For example:

def my_linear(size_in, size_out, *, key):
    w_init = jax.nn.initializers.kaiming_normal()
    return MyLinear(
        w_init(key, (size_in, size_out)),
        jnp.zeros((size_out,)),
    )
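Because MyLinear is a NamedTuple of arrays, JAX already treats it as a pytree: it flattens, jits, and differentiates like any other container, with no registration needed. A small sketch reusing the definitions above:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import jax.random as rnd
from jax import Array


class MyLinear(NamedTuple):
    weight: Array
    bias: Array

    def __call__(self, x, *, key=None):
        return x @ self.weight + self.bias


def my_linear(size_in, size_out, *, key):
    w_init = jax.nn.initializers.kaiming_normal()
    return MyLinear(
        w_init(key, (size_in, size_out)),
        jnp.zeros((size_out,)),
    )


layer = my_linear(3, 2, key=rnd.PRNGKey(0))

# NamedTuples are pytrees out of the box.
leaves = jax.tree_util.tree_leaves(layer)  # [weight, bias]

# ...so jax.grad can differentiate with respect to the whole layer.
def loss(layer, x):
    return jnp.sum(layer(x) ** 2)

grads = jax.grad(loss)(layer, jnp.ones((3,)))  # grads is also a MyLinear
```

This is why keeping construction out of __init__ pays off: unflattening a pytree only ever needs to repack existing arrays, never rerun initialization.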

So layers are just 'dumb' containers. The PyTree base class converts the inheriting class to a dataclass and registers the type as a jax pytree:

class MyDense(PyTree):
    weight: Array
    bias: Array
    activation: Callable[[Array], Array] = static(default=jax.nn.relu)

    def __call__(self, x, *, key=None):
        return self.activation(x @ self.weight + self.bias)
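jynx's actual PyTree and static implementations may differ, but one way such a base class can work is to hook __init_subclass__, apply the dataclass decorator, and register the class with jax.tree_util, treating fields marked static as auxiliary data. A self-contained sketch under those assumptions:

```python
import dataclasses
from typing import Callable

import jax
import jax.numpy as jnp
from jax import Array


def static(**kwargs):
    """Mark a dataclass field as static (kept out of jax transformations)."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


class PyTree:
    """Sketch: auto-registers every subclass as a jax pytree."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        dataclasses.dataclass(cls)
        fields = dataclasses.fields(cls)
        dynamic = [f.name for f in fields if not f.metadata.get("static")]
        stat = [f.name for f in fields if f.metadata.get("static")]

        def flatten(obj):
            # Dynamic fields become leaves; static fields become aux data.
            return (
                [getattr(obj, n) for n in dynamic],
                tuple(getattr(obj, n) for n in stat),
            )

        def unflatten(aux, children):
            return cls(**dict(zip(dynamic, children)), **dict(zip(stat, aux)))

        jax.tree_util.register_pytree_node(cls, flatten, unflatten)


class MyDense(PyTree):
    weight: Array
    bias: Array
    activation: Callable[[Array], Array] = static(default=jax.nn.relu)

    def __call__(self, x, *, key=None):
        return self.activation(x @ self.weight + self.bias)


layer = MyDense(jnp.ones((3, 2)), jnp.zeros((2,)))
out = layer(jnp.ones((3,)))                # works like any callable
leaves = jax.tree_util.tree_leaves(layer)  # only weight and bias
```

The activation never appears among the pytree leaves, so jax.grad and optax only ever see the arrays.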

The fit function

TODO

Callbacks

TODO
