
Jynx

A straightforward neural network library written in jax. No hidden mechanisms, no black magic. Requires only jax and optax.

This library provides three components: (1) standard neural network layers, (2) a fit function to train models, and (3) a collection of basic callbacks for checkpointing and logging. The fit function knows nothing about how the layers are implemented and can be used with other frameworks if desired; it relies only on optax.

TL;DR

import jax
import jax.numpy as jnp
import jax.random as rnd
import optax

import jynx
import jynx.callbacks as cb
import jynx.layers as nn

from jax import Array


def data_iter(key: Array):
    # An infinite stream of the same (x, y) batch: a noisy cosine curve.
    from itertools import repeat
    x = jnp.linspace(-1, 1, 100).reshape(-1, 1)
    y = jnp.cos(x) + 0.05 * rnd.normal(key, x.shape)
    return repeat((x, y))


def loss_fn(
    net: nn.Module[[Array], Array],
    batch: tuple[Array, Array],  # any type you want
    key: Array
) -> Array:
    x, y = batch
    y_pred = net(x, key=key)
    return optax.l2_loss(y_pred, y).mean()


def make_model(key: Array) -> nn.Module[[Array], Array]:
    k1, k2, k3 = rnd.split(key, 3)
    net = nn.sequential(
        nn.linear(1, 32, key=k1),
        jax.nn.relu,
        nn.linear(32, 32, key=k2),
        jax.nn.relu,
        nn.linear(32, 1, key=k3),
    )
    # or use:
    # net = nn.mlp(
    #     [1, 32, 32, 1],
    #     activation=jax.nn.silu,
    #     key=key,
    # )
    return net


k1, k2, k3 = rnd.split(rnd.PRNGKey(0), 3)
res = jynx.fit(
    make_model(k1),
    loss_fn=loss_fn,
    data_iter=data_iter(k2),
    optimizer=optax.adam(1e-3),
    key=k3,
    callbacks=[
        cb.ConsoleLogger(metrics=["loss"]),
        cb.TensorBoardLogger("tboard"),  # requires tensorboardx
        cb.MlflowLogger(),  # requires mlflow
        cb.CheckPoint("latest.pk"),  # requires cloudpickle
        cb.EarlyStopping(
            monitor="loss",
            steps_without_improvement=500,
        ),
    ],
)
print("final loss", res.logs["loss"])
net = res.params

Layers

Currently implemented modules:

  • Sequential
  • Parallel
  • Recurrent: like Sequential but passes state, used for RNNs
  • DenselyConnected: DenseNet
  • Linear
  • Conv and ConvTranspose
  • Embedding
  • Fn: activation function
  • StarFn: equivalent of fn(*x)
  • Static: wraps an object to be ignored by jax
  • Reshape
  • Dropout
  • Pooling layers: AvgPooling, MaxPooling, MinPooling
  • Norm: layer norm, batch norm, etc.
  • SkipConnection
  • RNN layers: RNNCell, GRUCell, LSTMCell
  • Attention, TransformerEncoderBlock, TransformerDecoderBlock

Constructors for full networks (a short usage sketch follows this list):

  • mlp
  • transformer_encoder
  • transformer_decoder
  • rnn, lstm, and gru
  • More to come...
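
For instance, the mlp constructor builds the same kind of stack as the nn.sequential example in the TL;DR. A minimal sketch (sizes are illustrative):

import jax
import jax.numpy as jnp
import jax.random as rnd

import jynx.layers as nn

k1, k2 = rnd.split(rnd.PRNGKey(0))
# 1 -> 32 -> 32 -> 1 MLP, equivalent to stacking nn.linear layers
# with activations via nn.sequential.
net = nn.mlp([1, 32, 32, 1], activation=jax.nn.silu, key=k1)
y = net(jnp.ones((8, 1)), key=k2)  # layers take an optional key, as in loss_fn above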

How layers work

Layers are simple pytree containers with a __call__ method. To define new modules easily, we provide a base PyTree class. Using it is not a requirement; it just makes most definitions simpler. Layers that don't require any static data can just as easily be defined as NamedTuples:

from typing import NamedTuple
from jax import Array

class MyLinear(NamedTuple):
    weight: Array
    bias: Array

    def __call__(self, x, *, key=None):
        # key is unused here; it keeps the call signature uniform across layers
        return x @ self.weight + self.bias

We provide initialization via factory functions instead of in the constructor. This makes flattening and unflattening pytrees much simpler. For example:

def my_linear(size_in, size_out, *, key):
    # Kaiming/He-normal weights, zero biases.
    w_init = jax.nn.initializers.kaiming_normal()
    return MyLinear(
        w_init(key, (size_in, size_out)),
        jnp.zeros((size_out,)),
    )
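
As a quick sanity check (a minimal sketch; jax treats NamedTuples as pytrees automatically, so gradients flow through MyLinear without any registration):

import jax
import jax.numpy as jnp
import jax.random as rnd

layer = my_linear(3, 2, key=rnd.PRNGKey(0))

def loss(layer, x):
    return (layer(x) ** 2).mean()

# grads has the same MyLinear structure as the parameters
grads = jax.grad(loss)(layer, jnp.ones((3,)))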

So layers are just 'dumb' containers. The PyTree base class converts the inheriting class to a dataclass and registers the type as a jax pytree:

class MyDense(PyTree):
    weight: Array
    bias: Array
    activation: Callable[[Array], Array] = static(default=jax.nn.relu)

    def __call__(self, x, *, key=None):
        return self.activation(x @ self.weight + self.bias)
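
A minimal sketch of what that registration buys, assuming (per the Static entry above) that static fields are stored as pytree metadata rather than leaves:

import jax
import jax.numpy as jnp
import jax.random as rnd

layer = MyDense(
    weight=rnd.normal(rnd.PRNGKey(0), (4, 8)),
    bias=jnp.zeros((8,)),
)

# Only the arrays are leaves; the activation is static metadata,
# so jax.grad, jax.jit, etc. never touch it.
assert len(jax.tree_util.tree_leaves(layer)) == 2
y = layer(jnp.ones((4,)))  # relu(x @ weight + bias)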

The fit function

TODO

Callbacks

TODO
