Ion

A simple neural network library for JAX



Ion is a simple neural network library for JAX. The core is three concepts (Module, Param, Optimizer) in <1000 lines of code. Models are pytrees that always work directly with jax.grad, jax.jit, and jax.vmap. Ion also ships with standard neural network layers (linear, convolution, attention, normalization, recurrent, and more) built on the core.

pip install ion-nn

Core Concepts

Param

Param wraps an array and marks it as a model parameter, either trainable or frozen.

w = nn.Param(jax.random.normal(shape=(3, 16), key=key))      # trainable
b = nn.Param(jax.numpy.zeros(shape=(16,)), trainable=False)  # frozen

Params work directly in arithmetic (x @ w works without unwrapping). Frozen params produce zero gradients under jax.grad.
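The zero-gradient behaviour for frozen params can be illustrated with the underlying JAX mechanism, jax.lax.stop_gradient. This is a minimal hand-rolled sketch of the idea, not Ion's actual implementation:

```python
import jax
import jax.numpy as jnp

def loss(w, b):
    # Routing b through stop_gradient mimics a frozen param:
    # jax.grad reports a zero gradient for it.
    b = jax.lax.stop_gradient(b)
    return jnp.sum(w * 2.0 + b)

gw, gb = jax.grad(loss, argnums=(0, 1))(jnp.ones(3), jnp.ones(3))
# gw == [2. 2. 2.]  (d/dw of 2w + b)
# gb == [0. 0. 0.]  (b is "frozen")
```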

Module

Inherit from nn.Module to define a layer. Subclasses are registered as JAX pytrees and become immutable after __init__.

import jax
import ion.nn as nn

class Linear(nn.Module):
    w: nn.Param
    b: nn.Param

    def __init__(self, in_dim, out_dim, *, key):
        self.w = nn.Param(jax.random.normal(shape=(in_dim, out_dim), key=key))
        self.b = nn.Param(jax.numpy.zeros(shape=(out_dim,)))

    def __call__(self, x):
        return x @ self.w + self.b

Non-array fields (ints, strings, callables) are treated as static config. Store num_heads, use_bias, or activation functions directly on the module.
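The mechanics here are plain JAX pytree registration. A minimal stand-in for the idea (hand-rolled for illustration, not Ion's implementation): array fields become pytree leaves, while other fields travel as static auxiliary data that jax.jit and jax.grad never trace.

```python
import dataclasses
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Affine:
    w: jax.Array
    b: jax.Array
    name: str  # static config: not a leaf, ignored by tracing

    def tree_flatten(self):
        # Arrays are children (leaves); `name` rides along as aux data.
        return (self.w, self.b), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        return cls(*children, name)

m = Affine(jnp.ones((2,)), jnp.zeros((2,)), "affine")
leaves = jax.tree_util.tree_leaves(m)  # two array leaves; "affine" is aux data
```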

Optimizer

Optimizer wraps an optax optimizer with Param-aware updates. Frozen params are automatically partitioned out, so no manual filtering is needed.

optimizer = ion.Optimizer(optax.adam(3e-4), model)
model, optimizer = optimizer.update(model, grads)
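The pattern under the hood is the standard functional update loop. A plain-SGD stand-in (used here purely for illustration in place of optax) looks like:

```python
import jax
import jax.numpy as jnp

def sgd_update(params, grads, lr=0.1):
    # Pure function: returns new params rather than mutating in place,
    # which is what lets the whole training step live inside jax.jit.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.ones((2,)), "b": jnp.zeros((2,))}
grads = {"w": jnp.full((2,), 2.0), "b": jnp.ones((2,))}
params = sgd_update(params, grads)
# params["w"] == [0.8 0.8], params["b"] == [-0.1 -0.1]
```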

That's the entire core. See Internals for design details and sharp edges.

Example

Putting it all together with a model built from Ion's standard layers:

import jax, optax, typing

import ion
import ion.nn as nn


class MLP(nn.Module):
    layer_1: nn.Linear
    layer_2: nn.Linear
    activation: typing.Callable

    def __init__(self, activation=jax.nn.relu, *, key):
        keys = jax.random.split(key, 2)
        self.layer_1 = nn.Linear(784, 128, key=keys[0])
        self.layer_2 = nn.Linear(128, 10, key=keys[1])
        self.activation = activation

    def __call__(self, x):
        return self.layer_2(self.activation(self.layer_1(x)))


def loss_fn(model, x, y):
    logits = model(x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()


@jax.jit
def train_step(model, optimizer, x, y):
    grads = jax.grad(loss_fn)(model, x, y)
    model, optimizer = optimizer.update(model, grads)
    return model, optimizer


model = MLP(key=jax.random.key(0))

optimizer = ion.Optimizer(optax.adam(3e-4), model)

for x, y in data:
    model, optimizer = train_step(model, optimizer, x, y)
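The data iterable above is assumed; a hypothetical stand-in yielding random MNIST-shaped batches would be:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `data`: ten random batches of 32 flat images.
key = jax.random.key(1)
x_key, y_key = jax.random.split(key)
xs = jax.random.normal(x_key, (10, 32, 784))     # inputs
ys = jax.random.randint(y_key, (10, 32), 0, 10)  # integer labels in [0, 10)
data = list(zip(xs, ys))
```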

Utilities

nn.Module provides convenience methods and properties for common operations. Methods return new instances, as modules are immutable.

model.replace(activation=jax.nn.tanh)    # create a modified copy
model.freeze()                           # freeze all params
model.unfreeze()                         # unfreeze all params
model.replace(base=model.base.freeze())  # freeze a sub-module
model.astype(jax.numpy.bfloat16)         # cast params to a different dtype
model.params                             # pytree of Param leaves
model.num_params                         # total parameter count
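Behind astype is a simple tree_map over the array leaves. A minimal pure-JAX sketch of the same idea (not Ion's code):

```python
import jax
import jax.numpy as jnp

def astype(tree, dtype):
    # Cast every array leaf; tree structure and static config are untouched.
    return jax.tree_util.tree_map(lambda x: x.astype(dtype), tree)

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
half = astype(params, jnp.bfloat16)
# half["w"].dtype == bfloat16
```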

Layers

Ion ships with standard neural network layers. Each is a Module with trainable Param leaves.

Category        Layers
--------        ------
Linear          Linear, Identity, LoRALinear
Convolution     Conv, ConvTranspose
Attention       SelfAttention, CrossAttention
Normalization   LayerNorm, RMSNorm, GroupNorm
Recurrent       RNNCell, LSTMCell, GRUCell, RNN, LSTM, GRU
SSM             LRUCell, S4DCell, S5Cell, LRU, S4D, S5
Pooling         MaxPool, AvgPool
Embedding       Embedding, LearnedPositionalEmbedding
Positional      sinusoidal, rope, apply_rope, alibi
Regularization  Dropout
Blocks          Sequential, MLP, TransformerBlock, CrossTransformerBlock
GNN             GCNConv, GATConv, GATv2Conv

See Layer Conventions for data format, weight init, spatial layer usage, and SSM conventions. See GNN Conventions for graph layer usage.

Pretty Printing

In notebooks, Treescope provides interactive, color-coded visualization of Ion Modules and Params. It is enabled by default on import and can be configured:

ion.enable_treescope()                 # Ion Modules and Params only (default)
ion.enable_treescope(everything=True)  # all types
ion.disable_treescope()                # turn off

Modules also have built-in text formatting for terminal output.

>>> model = MLP(key=jax.random.key(0))
>>> model

MLP(
  layer_1=Linear(
    w=Param(f32[784, 128], trainable=True),
    b=Param(f32[128], trainable=True),
  ),
  layer_2=Linear(
    w=Param(f32[128, 10], trainable=True),
    b=Param(f32[10], trainable=True),
  ),
  activation=relu,
)

Serialization

Save and load any pytree as .npz files. Works with models, optimizers, or any other pytree. load requires a reference pytree as a template to reconstruct the tree structure.

ion.save("model.npz", model)
model = ion.load("model.npz", model)

ion.save("snapshot.npz", (model, optimizer))
model, optimizer = ion.load("snapshot.npz", (model, optimizer))
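The template-based round-trip can be sketched in a few lines of plain JAX and NumPy (an illustrative reimplementation, not Ion's actual save/load):

```python
import jax
import jax.numpy as jnp
import numpy as np

def save(path, tree):
    # Flatten to leaves; leaf order is deterministic for a given structure.
    np.savez(path, *[np.asarray(leaf) for leaf in jax.tree_util.tree_leaves(tree)])

def load(path, template):
    # The template supplies the tree structure the flat arrays slot back into.
    structure = jax.tree_util.tree_structure(template)
    with np.load(path) as f:
        leaves = [jnp.asarray(f[name]) for name in f.files]
    return jax.tree_util.tree_unflatten(structure, leaves)
```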

FAQ

Why do I need a neural network library in JAX?

Building simple neural network models from scratch in JAX is straightforward. As models grow more complex, however, two things become painful: managing parameters (initializing them, tracking which are trainable, freezing some for fine-tuning) and composing modules (reusing layers, wiring them through JAX transforms, not reimplementing things like convolution padding from scratch for every project). A neural network library takes care of both so you can focus on model building and training.


Who is Ion for?

Ion is for JAX users who want a neural network library that is small, easy to learn, and easy to understand.

The core introduces three concepts (Module, Param, and Optimizer), and from there JAX does everything else. There are no custom transforms, no special contexts, and no framework-specific calling conventions. If you already know JAX, you can learn Ion in an hour.

Because the core is <1000 lines with not much happening behind the scenes, it's straightforward to reason about what JAX is doing. This matters most in complex training setups like multi-stage fine-tuning or custom gradient flows.


How does Ion compare to Equinox and Flax?

Equinox is an excellent pytree-based library for scientific computing where neural networks are one of several possible use cases. It provides filtered transforms, partition/combine utilities, and general pytree tools. Equinox treats all JAX arrays equally, so users must apply lax.stop_gradient or manually filter trainable parameters when computing gradients and applying optimizer updates. In Ion, Param tracks trainability so jax.grad returns zero gradients for frozen params automatically, and Optimizer handles the partition internally. Relative to Equinox, Ion trades off flexibility for simplicity and ease of use.

Flax NNX takes a different approach. NNX models are mutable graph objects with reference semantics, and custom transforms (nnx.jit, nnx.grad) bridge mutability with JAX's functional model. Ion leans into JAX's philosophy of functional programming and immutability, building on native JAX transforms rather than replacing them, trading Flax's more powerful but more opaque machinery for transparency and simplicity.

Both Equinox and Flax are battle-tested and have existing model hubs. If you need a broader pytree toolkit for scientific computing, Equinox is excellent. If you want PyTorch-like mutability, Flax NNX is a great choice.

License

Released under the Apache License 2.0.

Citation

To cite this repository:

@software{ion,
  title = {Ion: Simple Neural Networks in JAX},
  author = {Alex Goddard},
  url = {https://github.com/auxeno/ion},
  year = {2026}
}
