A mini machine learning framework a la JAX.
minijax
A mini deep learning framework à la Jax.
Like Jax, minijax is based on function transformations.
It implements vmap, grad, and jit (well, about one third of jit) in fewer than 500 lines of code.
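```python
from minijax.core import relu
from minijax.eval import Array
from minijax.grad import grad
from minijax.jit import jit
from minijax.vmap import vmap

def model(x, params):
    # A multi-layer perceptron; params is a list of (w, b) pairs.
    for w, b in params:
        x = relu(w @ x + b)
    return x

# x = Array(...); define params
# vmap maps model over axis 0 of x while sharing params; grad and jit then wrap the batched function.
y = jit(grad(vmap(model, (0, None))))(x, params)
```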
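To make the `vmap(model, (0, None))` call above concrete: following Jax's `in_axes` convention (I'm assuming minijax reads the tuple the same way), `(0, None)` means "map over axis 0 of `x`, share `params` across the batch". The sketch below spells out that meaning with a plain Python loop in NumPy. `vmap_reference` and `batched_model` are hypothetical names for illustration, not part of minijax; a real vmap produces something equivalent to the loop-free `batched_model` directly.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def model(x, params):
    # The README's model, written against plain NumPy arrays.
    for w, b in params:
        x = relu(w @ x + b)
    return x


def vmap_reference(f, in_axes):
    """Hypothetical loop-based reference for what vmap(f, in_axes) computes."""
    def batched(*args):
        # Arguments with axis 0 are sliced per example; arguments with axis
        # None (here: params) are passed unchanged to every call.
        sizes = {a.shape[ax] for a, ax in zip(args, in_axes) if ax is not None}
        (n,) = sizes  # all mapped arguments must share one batch size
        outs = []
        for i in range(n):
            call_args = [a if ax is None else np.take(a, i, axis=ax)
                         for a, ax in zip(args, in_axes)]
            outs.append(f(*call_args))
        return np.stack(outs)  # per-example results stacked along axis 0
    return batched


def batched_model(xs, params):
    # Loop-free equivalent: each row of xs is one example.
    for w, b in params:
        xs = relu(xs @ w.T + b)
    return xs


rng = np.random.default_rng(0)
params = [(rng.normal(size=(4, 3)), np.zeros(4)),
          (rng.normal(size=(2, 4)), np.zeros(2))]
xs = rng.normal(size=(5, 3))  # 5 examples with 3 features each
ys = vmap_reference(model, (0, None))(xs, params)
print(ys.shape)  # (5, 2)
print(np.allclose(ys, batched_model(xs, params)))  # True
```

The loop version is slow but unambiguous, which makes it a handy reference when checking a vectorized implementation.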
Each of minijax.core, minijax.grad, minijax.vmap and minijax.jit is under 100 lines of code.
A good place to get started is minijax/core.py, the demo.ipynb notebook, or the train_mnist.ipynb notebook.
- demo.ipynb trains a small neural network classifier on a 2D dataset using stochastic gradient descent.
- train_mnist.ipynb trains a multi-layer perceptron on MNIST using Adam.
- A little extra: derivatives.ipynb demonstrates computing higher-order derivatives using minijax.
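For a parameter list of (w, b) pairs like the one `model` iterates over, one stochastic gradient descent step subtracts the learning rate times the gradient from every array. Below is a minimal NumPy sketch, assuming the gradient comes back with the same structure as `params`; `sgd_step` is a hypothetical helper, not something minijax is known to export, and demo.ipynb may organize its update differently.

```python
import numpy as np


def sgd_step(params, grads, lr=0.1):
    """One plain SGD update, params - lr * grads, for a list of (w, b) pairs."""
    # Assumes grads has exactly the same nested structure as params.
    return [(w - lr * gw, b - lr * gb)
            for (w, b), (gw, gb) in zip(params, grads)]


# Toy check with made-up gradients of all ones.
params = [(np.ones((2, 3)), np.zeros(2))]
grads = [(np.ones((2, 3)), np.ones(2))]
new_params = sgd_step(params, grads, lr=0.5)
print(new_params[0][0][0, 0], new_params[0][1][0])  # 0.5 -0.5
```

Adam keeps running averages of the gradient and its square per array, but it walks the same (w, b) structure.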
To get started, clone this repository, run
pip install -e .[demos]
and have fun with the code!
Acknowledgements
This repo is inspired by the awesome micrograd repository, which implements an even smaller PyTorch-style deep learning framework. Unlike micrograd, minijax is meant to demonstrate composable function transformations. I also wanted something that scales at least to MNIST, so minijax uses NumPy arrays instead of pure Python scalars.
Autodidax from the Jax docs taught me the core ideas behind Jax. This repo is essentially Autodidax in the micrograd format rather than as a tutorial. I simplified some parts further and tried to use less jargon.
I wrote minijax over the Christmas break in 2025 and thank my family for still having me around and doing 95% of the cooking.
From minijax to Jax
This repo uses slightly different terminology than Jax:
- Nested containers (minijax.nested_containers) are PyTrees in Jax.
- Compute graphs (minijax.compute_grad) are Jaxprs in Jax.
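Nested containers (PyTrees) let transformations work on structures like the list of (w, b) pairs that `model` takes: flatten the structure into a list of leaves, operate on the leaves, then rebuild the same structure. The sketch below illustrates that idea in plain Python for lists, tuples, and dicts. `flatten` and the returned `rebuild` function are hypothetical names; minijax.nested_containers may expose a different API.

```python
def flatten(tree):
    """Return (leaves, rebuild), where rebuild(new_leaves) restores the structure."""
    if isinstance(tree, (list, tuple, dict)):
        keys = list(tree) if isinstance(tree, dict) else range(len(tree))
        children = [flatten(tree[k]) for k in keys]
        leaves = [leaf for child_leaves, _ in children for leaf in child_leaves]
        counts = [len(child_leaves) for child_leaves, _ in children]

        def rebuild(new_leaves):
            # Hand each child exactly as many leaves as it originally had.
            rebuilt, start = [], 0
            for (_, child_rebuild), n in zip(children, counts):
                rebuilt.append(child_rebuild(new_leaves[start:start + n]))
                start += n
            if isinstance(tree, dict):
                return dict(zip(keys, rebuilt))
            return type(tree)(rebuilt)

        return leaves, rebuild
    # Anything that is not a list, tuple, or dict is treated as a leaf.
    return [tree], lambda new_leaves: new_leaves[0]


params = [(1.0, 2.0), (3.0, 4.0)]
leaves, rebuild = flatten(params)
print(leaves)                            # [1.0, 2.0, 3.0, 4.0]
print(rebuild([2 * v for v in leaves]))  # [(2.0, 4.0), (6.0, 8.0)]
```

Working on a flat list of leaves is what lets a transformation treat "all the parameters" uniformly without knowing how the model nests them.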
Besides the many edge cases it does not handle, minijax also lacks most of Jax's jitting magic: it does not run code on GPUs or TPUs and cannot split computation across multiple devices. In fact, minijax can only use a single CPU core. On top of that, minijax supports far fewer computation primitives (no convolutions, ...), has very unhelpful error messages, leaves broadcasting to NumPy, has no dtypes (imagine that!), and and and...
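Compute graphs (Jaxprs in Jax) are typically built by tracing: the function runs once on placeholder values that record each primitive they pass through, and the recorded equations can then be replayed or transformed. The sketch below shows only that record-and-replay step, with two primitives and hypothetical names (`Var`, `trace`, `evaluate`). It illustrates the general technique and is not minijax's implementation.

```python
import numpy as np


class Var:
    """Placeholder value that records every primitive applied to it."""

    def __init__(self, graph):
        self.graph = graph
        self.name = f"v{graph['counter']}"
        graph["counter"] += 1

    def _emit(self, prim, other):
        # Create a fresh output variable and append one equation to the graph.
        out = Var(self.graph)
        self.graph["eqns"].append((out.name, prim, self.name, _arg(other)))
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)

    # add and mul commute, so `2.0 * x` can be recorded as `x * 2.0`.
    __radd__ = __add__
    __rmul__ = __mul__


def _arg(x):
    # Variables are referenced by name; Python scalars are stored as constants.
    return x.name if isinstance(x, Var) else x


def trace(f, n_args):
    """Run f once on placeholders; return (input names, equations, output name)."""
    graph = {"counter": 0, "eqns": []}
    inputs = [Var(graph) for _ in range(n_args)]
    output = f(*inputs)
    return [v.name for v in inputs], graph["eqns"], output.name


PRIMITIVES = {"add": np.add, "mul": np.multiply}


def evaluate(traced, *args):
    """Replay a traced graph on concrete NumPy arrays."""
    in_names, eqns, out_name = traced
    env = dict(zip(in_names, args))

    def read(x):
        return env[x] if isinstance(x, str) else x

    for out, prim, a, b in eqns:
        env[out] = PRIMITIVES[prim](read(a), read(b))
    return env[out_name]


def f(x, y):
    return x * y + 2.0 * x


traced = trace(f, 2)
for eqn in traced[1]:
    print(eqn)
# ('v2', 'mul', 'v0', 'v1')
# ('v3', 'mul', 'v0', 2.0)
# ('v4', 'add', 'v2', 'v3')
print(evaluate(traced, np.array([1.0, 2.0]), np.array([3.0, 4.0])).tolist())  # [5.0, 12.0]
```

Jax goes further and hands such graphs to the XLA compiler, which is the kind of jitting machinery described above as missing from minijax.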
Download files
File details
Details for the file minijax-0.1.1.tar.gz.
File metadata
- Download URL: minijax-0.1.1.tar.gz
- Size: 8.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.18
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | fee7bcbf1de9454a4814e2db39a94e602d145cc4a610ec8e97d2a62298445f20 |
| MD5 | 42674f112b91ca41297794d4c823d4b4 |
| BLAKE2b-256 | 6a88ebb758743373b80acc5a905fa713514309f2b49d66f9f74794f025becdfc |
File details
Details for the file minijax-0.1.1-py3-none-any.whl.
File metadata
- Download URL: minijax-0.1.1-py3-none-any.whl
- Size: 13.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.18
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | a957e2d581a6e4374d0b7f97420109469eb949b48f49e22c772d9fba167709cd |
| MD5 | e3c16b6532f4fb18ac42abefb3046dfe |
| BLAKE2b-256 | ce60650d6aae5bf69af8e9aac8fbac5a738db54bb0d70df5b44d7ab4cf9acbfb |