A simple neural network library for JAX.
Ion is a simple neural network library for JAX. The core is three concepts (Module, Param, Optimizer) in <1000 lines of code. Models are pytrees that always work directly with jax.grad, jax.jit, and jax.vmap. Ion also ships with standard neural network layers (linear, convolution, attention, normalization, recurrent, and more) built on the core.
pip install ion-nn
Core Concepts
Param
Param wraps an array and marks it as a model parameter, either trainable or frozen.
w = nn.Param(jax.random.normal(shape=(3, 16), key=key)) # trainable
b = nn.Param(jax.numpy.zeros(shape=(16,)), trainable=False) # frozen
Params work directly in arithmetic (x @ w works without unwrapping). Frozen params produce zero gradients under jax.grad.
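The frozen-param behavior can be sketched in plain JAX. This is illustrative, not Ion's actual internals: one way to make a value produce zero gradients is to route it through `lax.stop_gradient` in the forward pass.

```python
import jax
import jax.numpy as jnp

# Sketch of the frozen-param idea (illustrative, not Ion's internals):
# a value passed through lax.stop_gradient gets zero gradients from jax.grad.
def loss(w, b, x):
    b = jax.lax.stop_gradient(b)  # treat b as frozen
    return jnp.sum(x @ w + b)

x = jnp.ones((2, 3))
w = jnp.ones((3, 4))
b = jnp.zeros((4,))
grad_w, grad_b = jax.grad(loss, argnums=(0, 1))(w, b, x)
# grad_b is all zeros; grad_w is nonzero
```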
Module
Inherit from nn.Module to define a layer. Subclasses are registered as JAX pytrees and become immutable after __init__.
import ion.nn as nn
class Linear(nn.Module):
w: nn.Param
b: nn.Param
def __init__(self, in_dim, out_dim, *, key):
self.w = nn.Param(jax.random.normal(shape=(in_dim, out_dim), key=key))
self.b = nn.Param(jax.numpy.zeros(shape=(out_dim,)))
def __call__(self, x):
return x @ self.w + self.b
Non-array fields (ints, strings, callables) are treated as static config. Store num_heads, use_bias, or activation functions directly on the module.
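The leaf/static split can be sketched with JAX's own pytree registration. The class and helper names below are illustrative, not Ion's implementation: array fields become pytree leaves, and non-array config travels as static aux data, so the whole object works under `jax.grad`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Illustrative sketch (not Ion's implementation) of a module as a pytree:
# arrays are leaves, non-array config (use_bias) is static aux data.
class TinyLinear:
    def __init__(self, w, b, use_bias=True):
        self.w, self.b, self.use_bias = w, b, use_bias

    def __call__(self, x):
        y = x @ self.w
        return y + self.b if self.use_bias else y

def _flatten(m):
    return (m.w, m.b), m.use_bias          # (leaves, static config)

def _unflatten(use_bias, leaves):
    return TinyLinear(*leaves, use_bias=use_bias)

tree_util.register_pytree_node(TinyLinear, _flatten, _unflatten)

layer = TinyLinear(jnp.ones((3, 2)), jnp.zeros((2,)))
grads = jax.grad(lambda m, x: jnp.sum(m(x)))(layer, jnp.ones((1, 3)))
# grads is itself a TinyLinear whose leaves are gradient arrays
```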
Optimizer
Wraps an optax optimizer with Param-aware updates. Frozen params are automatically partitioned out, so no manual filtering is needed.
optimizer = ion.Optimizer(optax.adam(3e-4), model)
model, optimizer = optimizer.update(model, grads)
That's the entire core. See Internals for design details and sharp edges.
Example
Putting it all together with a model built from Ion's standard layers:
import jax, optax, typing
import ion
import ion.nn as nn
class MLP(nn.Module):
layer_1: nn.Linear
layer_2: nn.Linear
activation: typing.Callable
def __init__(self, activation=jax.nn.relu, *, key):
keys = jax.random.split(key, 2)
self.layer_1 = nn.Linear(784, 128, key=keys[0])
self.layer_2 = nn.Linear(128, 10, key=keys[1])
self.activation = activation
def __call__(self, x):
return self.layer_2(self.activation(self.layer_1(x)))
def loss_fn(model, x, y):
logits = model(x)
return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
@jax.jit
def train_step(model, optimizer, x, y):
grads = jax.grad(loss_fn)(model, x, y)
model, optimizer = optimizer.update(model, grads)
return model, optimizer
model = MLP(key=jax.random.key(0))
optimizer = ion.Optimizer(optax.adam(3e-4), model)
for x, y in data:
model, optimizer = train_step(model, optimizer, x, y)
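To make the loop above self-contained, `data` can be any iterable of `(x, y)` batches. A hypothetical synthetic stand-in (random 784-dim inputs, integer labels in `[0, 10)`):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `data`: random inputs and labels matching the
# MLP's 784-in / 10-class shape.
def synthetic_batches(key, num_batches=3, batch_size=32):
    for batch_key in jax.random.split(key, num_batches):
        kx, ky = jax.random.split(batch_key)
        x = jax.random.normal(kx, (batch_size, 784))
        y = jax.random.randint(ky, (batch_size,), 0, 10)
        yield x, y

data = list(synthetic_batches(jax.random.key(1)))
```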
Utilities
nn.Module provides convenience methods and properties for common operations. Methods return new instances, as modules are immutable.
model.replace(activation=jax.nn.tanh) # create a modified copy
model.freeze() # freeze all params
model.unfreeze() # unfreeze all params
model.replace(base=model.base.freeze()) # freeze a sub-module
model.astype(jax.numpy.bfloat16) # cast params to a different dtype
model.params # pytree of Param leaves
model.num_params # total parameter count
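A parameter count like `num_params` amounts to summing leaf sizes over a pytree; a minimal sketch in plain JAX (not Ion's code):

```python
import jax
import jax.numpy as jnp

# Illustrative: count parameters by summing array sizes over pytree leaves.
def count_params(tree):
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))

params = {"w": jnp.zeros((784, 128)), "b": jnp.zeros((128,))}
total = count_params(params)  # 784 * 128 + 128 = 100480
```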
Layers
Ion ships with standard neural network layers. Each is a Module with trainable Param leaves.
| Category | Layers |
|---|---|
| Linear | Linear, Identity, LoRALinear |
| Convolution | Conv, ConvTranspose |
| Attention | SelfAttention, CrossAttention |
| Normalization | LayerNorm, RMSNorm, GroupNorm |
| Recurrent | RNNCell, LSTMCell, GRUCell, RNN, LSTM, GRU |
| SSM | LRUCell, S4DCell, S5Cell, LRU, S4D, S5 |
| Pooling | MaxPool, AvgPool |
| Embedding | Embedding, LearnedPositionalEmbedding |
| Positional | sinusoidal, rope, apply_rope, alibi |
| Regularization | Dropout |
| Blocks | Sequential, MLP, TransformerBlock, CrossTransformerBlock |
| GNN | GCNConv, GATConv, GATv2Conv |
See Layer Conventions for data format, weight init, spatial layer usage, and SSM conventions. See GNN Conventions for graph layer usage.
Pretty Printing
In notebooks, Treescope provides interactive, color-coded visualization of Ion Modules and Params. Treescope is enabled by default on import, and can be configured:
ion.enable_treescope() # Ion Modules and Params only (default)
ion.enable_treescope(everything=True) # all types
ion.disable_treescope() # turn off
Modules also have built-in text formatting for terminal output.
>>> model = MLP(key=jax.random.key(0))
>>> model
MLP(
layer_1=Linear(
w=Param(f32[784, 128], trainable=True),
b=Param(f32[128], trainable=True),
),
layer_2=Linear(
w=Param(f32[128, 10], trainable=True),
b=Param(f32[10], trainable=True),
),
activation=relu,
)
Serialization
Save and load any pytree as .npz files. Works with models, optimizers, or any other pytree. load requires a reference pytree as a template to reconstruct the tree structure.
ion.save("model.npz", model)
model = ion.load("model.npz", model)
ion.save("snapshot.npz", (model, optimizer))
model, optimizer = ion.load("snapshot.npz", (model, optimizer))
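The save/load pattern can be sketched with NumPy and `jax.tree_util` (illustrative, not Ion's exact code): flatten the pytree to arrays, store them in an `.npz`, and rebuild on load using the reference pytree's structure.

```python
import os
import tempfile
import jax
import jax.numpy as jnp
import numpy as np

# Sketch of npz pytree serialization (illustrative, not Ion's exact code).
def save_npz(path, tree):
    leaves = jax.tree_util.tree_leaves(tree)
    np.savez(path, *[np.asarray(leaf) for leaf in leaves])

def load_npz(path, reference):
    treedef = jax.tree_util.tree_structure(reference)  # template structure
    with np.load(path) as data:
        leaves = [jnp.asarray(data[name]) for name in data.files]
    return jax.tree_util.tree_unflatten(treedef, leaves)

model = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
path = os.path.join(tempfile.mkdtemp(), "model.npz")
save_npz(path, model)
restored = load_npz(path, model)
```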
Examples
- Ion Tour: Hands-on walkthrough of the core API
- CNN Demo: Image classification with convolutional networks
- RNN Demo: Sequence classification with recurrent networks
- GPT Demo Notebook: Character-level GPT on TinyStories
- VAE Demo Notebook: Variational autoencoder for image generation
- GNN Node Classification: Node classification on Cora with graph neural networks
- GNN Molecular Property Prediction: Blood-brain barrier prediction with graph attention networks
- SSM Pathfinder: Exploring state space models on the Pathfinder task
- PPO Demo: Reinforcement learning with Gymnax
- PQN Demo: Parallelized Q-Networks with Gymnax
- DQN Atari: Deep Q-networks on Atari with replay buffers and target networks
FAQ
Why do I need a neural network library in JAX?
Building simple neural network models from scratch in JAX is straightforward. As models grow more complex, however, two things become painful: managing parameters (initializing them, tracking which are trainable, freezing some for fine-tuning) and composing modules (reusing layers, wiring them through JAX transforms, not reimplementing things like convolution padding from scratch for every project). A neural network library takes care of this so you can focus on model building and training.
Who is Ion for?
Ion is for JAX users who want a neural network library that is small, easy to learn, and easy to understand.
The core introduces three concepts, Module, Param, and Optimizer, and from there JAX does everything else. There are no custom transforms, no special contexts, no framework-specific calling conventions. If you already know JAX, you can learn Ion in an hour.
Because the core is <1000 lines with not much happening behind the scenes, it's straightforward to reason about what JAX is doing. This matters most in complex training setups like multi-stage fine-tuning or custom gradient flows.
How does Ion compare to Equinox and Flax?
Equinox is an excellent pytree-based library for scientific computing where neural networks are one of several possible use-cases. It provides filtered transforms, partition/combine utilities, and general pytree tools. Equinox treats all JAX arrays equally, so users must apply lax.stop_gradient or manually filter trainable parameters when computing gradients and applying optimizer updates. In Ion, Param tracks trainability, so jax.grad returns zero gradients for frozen params automatically and Optimizer handles the partition internally. Relative to Equinox, Ion trades off flexibility for simplicity and ease of use.
Flax NNX takes a different approach. NNX models are mutable graph objects with reference semantics, and custom transforms (nnx.jit, nnx.grad) bridge mutability with JAX's functional model. Ion leans into JAX's philosophy of functional programming and immutability, building on native JAX transforms rather than replacing them. The trade-off is transparent, simple machinery over powerful but opaque machinery.
Both Equinox and Flax are well battle-tested and have existing model hubs. If you need a broader pytree toolkit for scientific computing, Equinox is excellent. If you want PyTorch-like mutability, Flax NNX is a great choice.
License
Released under the Apache License 2.0.
Citation
To cite this repository:
@software{ion,
title = {Ion: Simple Neural Networks in JAX},
author = {Alex Goddard},
url = {https://github.com/auxeno/ion},
year = {2026}
}