A neural network library using jax
Jynx
A straightforward neural network library written in JAX. No hidden mechanisms, no black magic. Requires only jax and optax.
This library provides three components: (1) standard neural network layers, (2) a `fit` function to train models, and (3) a collection of basic callbacks for checkpointing and logging. The `fit` function doesn't know anything about how layers are implemented and can be used with other frameworks if desired; it relies only on optax.
TL;DR
```python
import jax
import jax.numpy as jnp
import jax.random as rnd
import optax

import jynx
import jynx.callbacks as cb
import jynx.layers as nn
from jax import Array


def data_iter(key: Array):
    from itertools import repeat

    # Noisy samples of cos(x); the same batch is yielded forever.
    x = jnp.linspace(-1, 1, 100).reshape(-1, 1)
    y = jnp.cos(x) + 0.05 * rnd.normal(key, x.shape)
    return repeat((x, y))


def loss_fn(
    net: nn.Module[[Array], Array],
    batch: tuple[Array, Array],  # any type you want
    key: Array,
) -> Array:
    x, y = batch
    y_pred = net(x, key=key)
    return optax.l2_loss(y_pred, y).mean()


def make_model(key: Array) -> nn.Module[[Array], Array]:
    k1, k2, k3 = rnd.split(key, 3)
    net = nn.sequential(
        nn.linear(1, 32, key=k1),
        jax.nn.relu,
        nn.linear(32, 32, key=k2),
        jax.nn.relu,
        nn.linear(32, 1, key=k3),
    )
    # or use:
    # net = nn.mlp(
    #     [1, 32, 32, 1],
    #     activation=jax.nn.silu,
    #     key=key,
    # )
    return net


k1, k2, k3 = rnd.split(rnd.PRNGKey(0), 3)
res = jynx.fit(
    make_model(k1),
    loss_fn=loss_fn,
    data_iter=data_iter(k2),
    optimizer=optax.adam(1e-3),
    key=k3,
    callbacks=[
        cb.ConsoleLogger(metrics=["loss"]),
        cb.TensorBoardLogger("tboard"),  # requires tensorboardx
        cb.MlflowLogger(),  # requires mlflow
        cb.CheckPoint("latest.pk"),  # requires cloudpickle
        cb.EarlyStopping(
            monitor="loss",
            steps_without_improvement=500,
        ),
    ],
)
print("final loss", res.logs["loss"])
net = res.params  # the trained model
```
Layers
Currently implemented modules:
- `Sequential`
- `Parallel`
- `Recurrent`: like `Sequential` but passes state, used for RNNs
- `DenselyConnected`: DenseNet
- `Linear`
- `Conv` and `ConvTranspose`
- `Embedding`
- `Fn`: activation function
- `StarFn`: equivalent of `fn(*x)`
- `Static`: wraps an object to be ignored by jax
- `Reshape`
- `Dropout`
- Pooling layers: `AvgPooling`, `MaxPooling`, `MinPooling`
- `Norm`: layer norm, batch norm, etc.
- `SkipConnection`
- RNN layers: `RNNCell`, `GRUCell`, `LSTMCell`
- `Attention`, `TransformerEncoderBlock`, `TransformerDecoderBlock`

Constructors for full networks:
- `mlp`
- `transformer_encoder`
- `transformer_decoder`
- `rnn`, `lstm`, and `gru`
- More to come...
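To illustrate how container layers such as `Sequential` can compose out of plain pytrees, here is a hypothetical sketch built only from NamedTuples. `MySequential`, `MyLinear` and `MyReLU` are illustrative names, not jynx's implementation; wrapping the activation in a field-less NamedTuple is an assumption made here so that every layer accepts the same `key=` keyword and the activation contributes no leaves.

```python
from typing import Any, NamedTuple, Optional

import jax
import jax.numpy as jnp
import jax.random as rnd
from jax import Array


class MyLinear(NamedTuple):
    weight: Array
    bias: Array

    def __call__(self, x: Array, *, key: Optional[Array] = None) -> Array:
        return x @ self.weight + self.bias


class MyReLU(NamedTuple):
    # No fields, so this layer adds no leaves to the pytree.
    def __call__(self, x: Array, *, key: Optional[Array] = None) -> Array:
        return jax.nn.relu(x)


class MySequential(NamedTuple):
    """Hypothetical sequential container: applies its layers in order."""

    layers: tuple[Any, ...]

    def __call__(self, x: Array, *, key: Optional[Array] = None) -> Array:
        n = len(self.layers)
        # One subkey per layer so stochastic layers (e.g. dropout) stay independent.
        keys = list(rnd.split(key, n)) if key is not None else [None] * n
        for layer, k in zip(self.layers, keys):
            x = layer(x, key=k)
        return x


k1, k2 = rnd.split(rnd.PRNGKey(0))
net = MySequential((
    MyLinear(rnd.normal(k1, (1, 8)), jnp.zeros(8)),
    MyReLU(),
    MyLinear(rnd.normal(k2, (8, 1)), jnp.zeros(1)),
))
y = net(jnp.ones((4, 1)), key=rnd.PRNGKey(1))

# The whole network is one nested pytree: two weights and two biases.
n_leaves = len(jax.tree_util.tree_leaves(net))
grads = jax.grad(lambda m: m(jnp.ones((4, 1))).sum())(net)
```

The gradients come back as a `MySequential` with exactly the same nesting, which is what lets a generic training loop treat the model as an opaque pytree.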
How layers work
Layers are simple pytree containers with a `__call__` method. To make defining new modules easy, we provide a base `PyTree` class. Using it is not a requirement; it just makes most definitions simpler. Layers that don't require any static data can just as easily be defined as `NamedTuple`s:
```python
from typing import NamedTuple

from jax import Array


class MyLinear(NamedTuple):
    weight: Array
    bias: Array

    def __call__(self, x, *, key=None):
        return x @ self.weight + self.bias
```
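Because a NamedTuple is already a jax pytree, a layer defined this way works with jax transformations without any registration. A small sketch (the shapes and the sum-of-outputs objective are arbitrary choices for illustration):

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp
from jax import Array


class MyLinear(NamedTuple):
    weight: Array
    bias: Array

    def __call__(self, x: Array, *, key: Optional[Array] = None) -> Array:
        return x @ self.weight + self.bias


layer = MyLinear(jnp.ones((3, 2)), jnp.zeros((2,)))
x = jnp.ones((5, 3))

# The NamedTuple fields are the pytree leaves.
leaves = jax.tree_util.tree_leaves(layer)

# Gradients come back as a MyLinear; with all-ones inputs and a batch of 5,
# every entry of both gradients equals 5.0.
grads = jax.grad(lambda m: jnp.sum(m(x)))(layer)
```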
Initialization is done in factory functions rather than in the constructor. This makes flattening and unflattening pytrees much simpler. For example:
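```python
import jax
import jax.numpy as jnp


def my_linear(size_in, size_out, *, key):
    # He (Kaiming) normal initialization for the weights, zeros for the bias.
    w_init = jax.nn.initializers.kaiming_normal()
    return MyLinear(
        w_init(key, (size_in, size_out)),
        jnp.zeros((size_out,)),
    )
```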
So layers are just 'dumb' containers.
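The benefit of keeping initialization out of the constructor shows up whenever jax rebuilds a layer from its leaves: unflattening only needs the arrays, not sizes or a PRNG key. The sketch below repeats the `MyLinear`/`my_linear` definitions so it runs on its own; the 0.1 step size and all-ones "gradients" are placeholders for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import jax.random as rnd
from jax import Array


class MyLinear(NamedTuple):
    weight: Array
    bias: Array

    def __call__(self, x, *, key=None):
        return x @ self.weight + self.bias


def my_linear(size_in: int, size_out: int, *, key: Array) -> MyLinear:
    w_init = jax.nn.initializers.kaiming_normal()
    return MyLinear(w_init(key, (size_in, size_out)), jnp.zeros((size_out,)))


layer = my_linear(3, 2, key=rnd.PRNGKey(0))

# Flattening yields the arrays; unflattening rebuilds a MyLinear directly from
# them, so no initializer or PRNG key is involved.
leaves, treedef = jax.tree_util.tree_flatten(layer)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# tree_map also rebuilds the container, e.g. for a hand-written SGD-style update.
grads = jax.tree_util.tree_map(jnp.ones_like, layer)
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, layer, grads)
```

If the constructor instead took `(size_in, size_out, key)` and ran the initializer, jax would have no way to rebuild the layer from a list of arrays.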
The `PyTree` base class converts the inheriting class to a dataclass and registers the type as a jax pytree:
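```python
from typing import Callable

import jax
from jax import Array

# PyTree and static are provided by jynx (their import path is not shown here).


class MyDense(PyTree):
    weight: Array
    bias: Array
    # static(...) presumably marks this field as static data rather than a pytree leaf.
    activation: Callable[[Array], Array] = static(default=jax.nn.relu)

    def __call__(self, x, *, key=None):
        return self.activation(x @ self.weight + self.bias)
```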
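What "convert to a dataclass and register as a pytree" involves can be sketched with the standard library and `jax.tree_util.register_pytree_node`. This is a hypothetical illustration, not jynx's implementation: `pytree_dataclass` and `static_field` are made-up names, written as a decorator instead of a base class, and non-leaf fields are assumed to be stored in the pytree's auxiliary data.

```python
import dataclasses
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
from jax import Array


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: mark a dataclass field as static (not a pytree leaf)."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


def pytree_dataclass(cls: type) -> type:
    """Hypothetical decorator: make `cls` a frozen dataclass and register it as a pytree."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    leaf_names = [f.name for f in fields if not f.metadata.get("static", False)]
    static_names = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = [getattr(obj, n) for n in leaf_names]
        aux = tuple(getattr(obj, n) for n in static_names)  # must be hashable
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(leaf_names, children))
        kwargs.update(zip(static_names, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class MyDense:
    weight: Array
    bias: Array
    activation: Callable[[Array], Array] = static_field(default=jax.nn.relu)

    def __call__(self, x: Array, *, key: Optional[Array] = None) -> Array:
        return self.activation(x @ self.weight + self.bias)


dense = MyDense(jnp.ones((3, 2)), jnp.array([-10.0, 1.0]))
x = jnp.ones((4, 3))
out = dense(x)  # relu([3 - 10, 3 + 1]) = [0, 4] for every row

leaves = jax.tree_util.tree_leaves(dense)  # only weight and bias
grads = jax.grad(lambda m: m(x).sum())(dense)
# The static activation is part of the tree structure, so jit does not trace it.
jitted_out = jax.jit(lambda m, x: m(x))(dense, x)
```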
The `fit` function
TODO
Callbacks
TODO
File details
Details for the file jax_jynx-0.2.2.tar.gz.
File metadata
- Download URL: jax_jynx-0.2.2.tar.gz
- Upload date:
- Size: 31.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.3 CPython/3.10.12 Linux/6.5.0-26-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 25a32b311f4fda637bcda77af3dbe5804db6ae6e161cc04735fa1531012e3445
MD5 | 2977170975882a256f71a1af328ad1bc
BLAKE2b-256 | 12100c6e817e31b1c033306e3fbca3370ad8ed617a08f5280b305e1f46b4325c
File details
Details for the file jax_jynx-0.2.2-py3-none-any.whl.
File metadata
- Download URL: jax_jynx-0.2.2-py3-none-any.whl
- Upload date:
- Size: 36.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.3 CPython/3.10.12 Linux/6.5.0-26-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | f26b96d8ee42d35d4715da720107af2d856c01aca35b56452a4b0bfe3c9b54d9
MD5 | 384ebf8c38a19bfcbb1e011c498cacf0
BLAKE2b-256 | 6a852deedf7860f15de9f4ec3fe9bbf26831c0f6816b5134c375029d8f62f86d