A neural network library using jax

Project description

Jynx

A straightforward neural network library written in jax. No hidden mechanisms, no black magic. Requires only jax and optax.

This library provides three components: (1) standard neural network layers, (2) a fit function to train models, and (3) a collection of basic callbacks for checkpointing and logging. The fit function knows nothing about how layers are implemented and can be used with other frameworks if desired; it relies only on optax.

TL;DR

import jax
import jax.numpy as jnp
import jax.random as rnd
import optax

import jynx
import jynx.callbacks as cb
import jynx.layers as nn

from jax import Array


def data_iter(key: Array):
    from itertools import repeat
    x = jnp.linspace(-1, 1, 100).reshape(-1, 1)
    y = jnp.cos(x) + 0.05 * rnd.normal(key, x.shape)
    return repeat((x, y))


def loss_fn(
    net: nn.Module[[Array], Array],
    batch: tuple[Array, Array],  # any type you want
    key: Array
) -> Array:
    x, y = batch
    y_pred = net(x, key=key)
    return optax.l2_loss(y_pred, y).mean()


def make_model(key: Array) -> nn.Module[[Array], Array]:
    k1, k2, k3 = rnd.split(key, 3)
    net = nn.sequential(
        nn.linear(1, 32, key=k1),
        jax.nn.relu,
        nn.linear(32, 32, key=k2),
        jax.nn.relu,
        nn.linear(32, 1, key=k3),
    )
    # or use:
    # net = nn.mlp(
    #     [1, 32, 32, 1],
    #     activation=jax.nn.silu,
    #     key=key,
    # )
    return net


k1, k2, k3 = rnd.split(rnd.PRNGKey(0), 3)
res = jynx.fit(
    make_model(k1),
    loss_fn=loss_fn,
    data_iter=data_iter(k2),
    optimizer=optax.adam(1e-3),
    key=k3,
    callbacks=[
        cb.ConsoleLogger(metrics=["loss"]),
        cb.TensorBoardLogger("tboard"),  # requires tensorboardx
        cb.MlflowLogger(),  # requires mlflow
        cb.CheckPoint("latest.pk"),  # requires cloudpickle
        cb.EarlyStopping(
            monitor="loss",
            steps_without_improvement=500,
        ),
    ],
)
print("final loss", res.logs["loss"])
net = res.params
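
Since res.params is just the trained layer pytree, it can be called directly on new inputs and passed through jax transformations. A minimal sketch of prediction with the trained net (the test inputs and the jit wrapper are illustrative, not part of the library API):

x_test = jnp.linspace(-1, 1, 10).reshape(-1, 1)
# the model is a pytree of arrays, so it can be passed as an argument to a jitted function
predict = jax.jit(lambda model, x, key: model(x, key=key))
y_pred = predict(net, x_test, rnd.PRNGKey(42))
print(y_pred.shape)  # (10, 1)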

Layers

Currently implemented modules:

  • Sequential
  • Parallel
  • Recurrent: like Sequential but passes state, used for RNNs
  • DenselyConnected: DenseNet
  • Linear
  • Conv and ConvTranspose
  • Embedding
  • Fn: activation function
  • StarFn: equivalent of fn(*x)
  • Static: wraps an object to be ignored by jax
  • Reshape
  • Dropout
  • Pooling layers: AvgPooling, MaxPooling, MinPooling
  • Norm: layer norm, batch norm, etc.
  • SkipConnection
  • RNN layers: RNNCell, GRUCell, LSTMCell
  • Attention, TransformerEncoderBlock, TransformerDecoderBlock

Constructors for full networks:

  • mlp
  • transformer_encoder
  • transformer_decoder
  • rnn, lstm, and gru
  • More to come...
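
For instance, nn.mlp builds the same stack of linear layers and activations as the hand-written sequential model in the TL;DR. A minimal sketch (the sizes and activation choice are illustrative):

key = rnd.PRNGKey(0)
# a two-hidden-layer MLP mapping 1 input feature to 1 output
net = nn.mlp([1, 32, 32, 1], activation=jax.nn.relu, key=key)
y = net(jnp.ones((8, 1)), key=key)  # shape (8, 1)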

How layers work

Layers are simple pytree containers with a __call__ method. To make defining new modules easy, we provide a PyTree base class. Using it is not a requirement; it just makes most definitions simpler. Layers that don't need any static data can just as easily be defined as NamedTuples:

from typing import NamedTuple


class MyLinear(NamedTuple):
    weight: Array
    bias: Array

    def __call__(self, x, *, key=None):
        # a plain NamedTuple is already a jax pytree, so no registration is needed
        return x @ self.weight + self.bias

We provide initialization with factory functions instead of in the constructor. This makes flattening and unflattening pytrees much simpler. For example:

def my_linear(size_in, size_out, *, key):
    # He/Kaiming-normal weights, zero bias
    w_init = jax.nn.initializers.kaiming_normal()
    return MyLinear(
        w_init(key, (size_in, size_out)),
        jnp.zeros((size_out,)),
    )
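
The factory above can then be used like this (the sizes are arbitrary):

layer = my_linear(3, 5, key=rnd.PRNGKey(0))
y = layer(jnp.ones((3,)))  # shape (5,)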

So layers are just 'dumb' containers. The PyTree base class converts the inheriting class into a dataclass and registers the type as a jax pytree:

from typing import Callable

# PyTree and the static() field marker are provided by jynx
class MyDense(PyTree):
    weight: Array
    bias: Array
    # static fields are treated as auxiliary data, not pytree leaves
    activation: Callable[[Array], Array] = static(default=jax.nn.relu)

    def __call__(self, x, *, key=None):
        return self.activation(x @ self.weight + self.bias)
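
Because PyTree subclasses are registered with jax, instances behave like any other pytree under jax transformations. A minimal sketch, assuming the MyDense class above and the imports from the TL;DR (the shapes are arbitrary):

layer = MyDense(rnd.normal(rnd.PRNGKey(0), (4, 4)), jnp.zeros((4,)))

# weight and bias are leaves; the static activation field is auxiliary data
print(len(jax.tree_util.tree_leaves(layer)))  # expected: 2

# gradients come back as another MyDense with the same structure
grads = jax.grad(lambda m, x: m(x).sum())(layer, jnp.ones((4,)))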

The fit function

TODO

Callbacks

TODO

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jax_jynx-0.2.2.tar.gz (31.7 kB)

Uploaded Source

Built Distribution

jax_jynx-0.2.2-py3-none-any.whl (36.5 kB)

Uploaded Python 3

File details

Details for the file jax_jynx-0.2.2.tar.gz.

File metadata

  • Download URL: jax_jynx-0.2.2.tar.gz
  • Upload date:
  • Size: 31.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.12 Linux/6.5.0-26-generic

File hashes

Hashes for jax_jynx-0.2.2.tar.gz:

  • SHA256: 25a32b311f4fda637bcda77af3dbe5804db6ae6e161cc04735fa1531012e3445
  • MD5: 2977170975882a256f71a1af328ad1bc
  • BLAKE2b-256: 12100c6e817e31b1c033306e3fbca3370ad8ed617a08f5280b305e1f46b4325c


File details

Details for the file jax_jynx-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: jax_jynx-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 36.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.12 Linux/6.5.0-26-generic

File hashes

Hashes for jax_jynx-0.2.2-py3-none-any.whl:

  • SHA256: f26b96d8ee42d35d4715da720107af2d856c01aca35b56452a4b0bfe3c9b54d9
  • MD5: 384ebf8c38a19bfcbb1e011c498cacf0
  • BLAKE2b-256: 6a852deedf7860f15de9f4ec3fe9bbf26831c0f6816b5134c375029d8f62f86d

