
A mini machine learning framework à la JAX.


minijax

Jax to minijax

A mini deep learning framework à la Jax. Like Jax, minijax is based on function transformations. It implements vmap, grad and jit (well, about one third of it) in fewer than 500 lines of code.

from minijax.core import relu
from minijax.eval import Array
from minijax.grad import grad
from minijax.jit import jit
from minijax.vmap import vmap

def model(x, params):
    for w, b in params:
        x = relu(w @ x + b)
    return x

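# x = Array(...); define params
# vmap batches model over axis 0 of x and shares params (None);
# grad differentiates the batched model; jit wraps the whole composition.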
y = jit(grad(vmap(model, (0, None))))(x, params)
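The second argument to `vmap`, `(0, None)`, reads like Jax's `in_axes` convention: `x` is batched along its first axis, while `params` is shared by every example in the batch. As a semantic reference only, the batched function should compute the same thing as the plain Python loop below. A real vmap (in Jax, and presumably in minijax) transforms the function instead of looping. `vmap_reference`, `relu` and `scalar_model` are hypothetical helpers, and the scalar model stands in for the matrix version so that the sketch runs without numpy.

```python
def vmap_reference(f, in_axes):
    """Loop-based stand-in for vmap: call f once per slice along axis 0.

    in_axes[i] == 0 means argument i is batched along its first axis;
    in_axes[i] is None means argument i is shared by every call.
    Only 0 and None are supported in this sketch.
    """
    def batched(*args):
        sizes = {len(a) for a, ax in zip(args, in_axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("batched arguments must share one batch size")
        (n,) = sizes
        results = []
        for i in range(n):
            # Slice batched arguments, pass shared ones through unchanged.
            call_args = [a[i] if ax == 0 else a for a, ax in zip(args, in_axes)]
            results.append(f(*call_args))
        return results
    return batched


def relu(v):
    return max(0.0, v)


def scalar_model(x, params):
    # Same structure as model above, with scalars instead of matrices.
    for w, b in params:
        x = relu(w * x + b)
    return x


params = [(2.0, -1.0), (0.5, 0.25)]
xs = [0.0, 1.0, 3.0]
print(vmap_reference(scalar_model, (0, None))(xs, params))  # [0.25, 0.75, 2.75]
```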

Each of minijax.core, minijax.grad, minijax.vmap and minijax.jit is under 100 lines of code. Good places to get started are minijax/core.py and the demo.ipynb and train_mnist.ipynb notebooks.

  • demo.ipynb trains a small neural network classifier on a 2D dataset using stochastic gradient descent.
  • train_mnist.ipynb trains a multi-layer perceptron on MNIST using Adam.
  • a little extra: derivatives.ipynb demonstrates computing higher-order derivatives using minijax.

To get started, clone this repository and run

pip install -e .[demos]

and have fun with the code!

Acknowledgements

This repo is inspired by the awesome micrograd repository, which implements an even smaller PyTorch-style deep learning framework. Unlike micrograd, minijax is meant to demonstrate composable function transformations. I also wanted something that scales at least to MNIST, so minijax uses numpy arrays instead of pure Python scalars.

Autodidax from the Jax docs taught me the core ideas behind Jax. This repo is essentially Autodidax in the micrograd format rather than as a tutorial. I simplified some parts further and tried to use less jargon.
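Autodidax builds its transformations on one core idea: run the user's ordinary Python function on special values that intercept every primitive operation. A much simplified sketch of that idea for forward-mode differentiation, using dual numbers with operator overloading, follows. It is not minijax's code; `Dual`, `sin`, `jvp` and `derivative` are hypothetical names chosen for this example, and minijax's own `grad` may well work differently (for example in reverse mode).

```python
import math


class Dual:
    """A primal value paired with a tangent (its directional derivative)."""

    def __init__(self, primal, tangent=0.0):
        self.primal = primal
        self.tangent = tangent

    @staticmethod
    def lift(value):
        # Plain numbers are constants: their tangent is zero.
        return value if isinstance(value, Dual) else Dual(value, 0.0)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: d(a * b) = da * b + a * db
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    # Addition and multiplication are commutative, so 3 * x reuses x * 3.
    __radd__ = __add__
    __rmul__ = __mul__


def sin(x):
    """A primitive with its own derivative rule; plain numbers pass through."""
    if isinstance(x, Dual):
        return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)
    return math.sin(x)


def jvp(f, x, v):
    """Return f(x) and the derivative of f at x in direction v, in one pass."""
    out = Dual.lift(f(Dual(x, v)))
    return out.primal, out.tangent


def derivative(f):
    """Transform a scalar function f into the function x -> f'(x)."""
    return lambda x: jvp(f, x, 1.0)[1]


print(jvp(lambda x: x * x + 3 * x, 2.0, 1.0))  # (10.0, 7.0)
print(derivative(sin)(0.0))                     # 1.0
```

Note that `derivative` is itself a function transformation: it takes a function and returns a new one, which is the style minijax's `grad`, `vmap` and `jit` follow.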

I wrote minijax over the Christmas break in 2025 and thank my family for still having me around and doing 95% of the cooking.


From minijax to Jax

This repo uses slightly different terminology than Jax:

  • Nested containers (minijax.nested_containers) are PyTrees in Jax.
  • Compute graphs (minijax.compute_grad) are Jaxprs in Jax.
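In Jax, transformations accept arbitrarily nested lists, tuples and dicts (PyTrees) by flattening them into a flat list of leaves plus a description of the structure, and rebuilding the container afterwards; dict entries are visited in sorted key order. A minimal sketch of that round trip for plain lists, tuples and dicts follows. It assumes nothing about how minijax.nested_containers is actually written, and `flatten`, `unflatten` and `LEAF` are hypothetical names.

```python
LEAF = None  # structure marker meaning "this position holds one leaf"


def flatten(tree):
    """Split a nested list/tuple/dict into (leaves, structure)."""
    if isinstance(tree, dict):
        keys = tuple(sorted(tree))  # deterministic order, as in Jax
        children = [tree[k] for k in keys]
        kind = (dict, keys)
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
        kind = (type(tree), None)
    else:
        return [tree], LEAF
    leaves, child_structures = [], []
    for child in children:
        child_leaves, child_structure = flatten(child)
        leaves.extend(child_leaves)
        child_structures.append(child_structure)
    return leaves, (kind, child_structures)


def unflatten(structure, leaves):
    """Rebuild a container with the given structure from a flat leaf list."""
    leaves_iter = iter(leaves)

    def build(s):
        if s is LEAF:
            return next(leaves_iter)
        (container_type, keys), child_structures = s
        children = [build(c) for c in child_structures]
        if container_type is dict:
            return dict(zip(keys, children))
        return container_type(children)  # plain list or tuple only

    return build(structure)


params = [(2.0, -1.0), (0.5, 0.25)]
leaves, structure = flatten(params)
print(leaves)  # [2.0, -1.0, 0.5, 0.25]
print(unflatten(structure, [2 * v for v in leaves]))  # [(4.0, -2.0), (1.0, 0.5)]
```

Operating on the flat leaf list and then calling `unflatten` is what lets a transformation treat a whole parameter list like the `params` of `model` as if it were a single value.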

Apart from many edge cases that minijax does not handle, it also lacks most of Jax's jitting magic: minijax does not run code on GPUs or TPUs and cannot split computation across multiple devices. In fact, minijax can only use a single CPU core. Beyond that, minijax supports far fewer computation primitives (no convolutions, ...), has very unhelpful error messages, leaves broadcasting to numpy, has no dtypes (imagine that!), and and and...
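In Jax, jit starts from the compute graphs listed above (Jaxprs): the function is run once on placeholder values that record each primitive operation, and the recorded graph can then be evaluated again, or compiled, without re-running the Python code. A minimal sketch of that tracing step follows, with hypothetical names (`Tracer`, `Graph`, `trace`) and only addition and multiplication as primitives. It does not reflect minijax.jit's actual data structures, and, like real tracing, it cannot follow Python control flow that depends on input values.

```python
import operator

# The only primitives in this sketch.
PRIMITIVES = {"add": operator.add, "mul": operator.mul}


class Tracer:
    """Stands in for a value while tracing; arithmetic records graph nodes."""

    def __init__(self, graph, index):
        self.graph = graph
        self.index = index  # position of this value's node in graph.nodes

    def _binary(self, op, other):
        if not isinstance(other, Tracer):
            other = self.graph.record("const", other)  # constants become nodes
        return self.graph.record(op, self.index, other.index)

    def __add__(self, other):
        return self._binary("add", other)

    def __mul__(self, other):
        return self._binary("mul", other)

    # add and mul are commutative, so 3 * x can reuse x * 3
    __radd__ = __add__
    __rmul__ = __mul__


class Graph:
    """A flat list of (op, args) nodes; node i produces the i-th value."""

    def __init__(self):
        self.nodes = []
        self.output = None

    def record(self, op, *args):
        self.nodes.append((op, args))
        return Tracer(self, len(self.nodes) - 1)

    def evaluate(self, *inputs):
        """Replay the recorded operations on concrete inputs."""
        values = []
        for op, args in self.nodes:
            if op == "input":
                values.append(inputs[args[0]])
            elif op == "const":
                values.append(args[0])
            else:
                lhs, rhs = args
                values.append(PRIMITIVES[op](values[lhs], values[rhs]))
        return values[self.output]


def trace(f, num_inputs):
    """Run f once on tracers and return the recorded compute graph."""
    graph = Graph()
    inputs = [graph.record("input", i) for i in range(num_inputs)]
    graph.output = f(*inputs).index  # assumes f returns a traced value
    return graph


def f(x, y):
    return x * y + 3 * x


g = trace(f, 2)
print(g.evaluate(2, 5))  # 16, the same as f(2, 5)
```

The Python function `f` runs only once, inside `trace`; every later call to `g.evaluate` just walks the recorded list of nodes.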



Download files


Source Distribution

minijax-0.1.1.tar.gz (8.5 kB)

Built Distribution

minijax-0.1.1-py3-none-any.whl (13.6 kB)

File details

Details for the file minijax-0.1.1.tar.gz.

File metadata

  • Download URL: minijax-0.1.1.tar.gz
  • Upload date:
  • Size: 8.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.18 {"installer":{"name":"uv","version":"0.9.18","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for minijax-0.1.1.tar.gz
Algorithm Hash digest
SHA256 fee7bcbf1de9454a4814e2db39a94e602d145cc4a610ec8e97d2a62298445f20
MD5 42674f112b91ca41297794d4c823d4b4
BLAKE2b-256 6a88ebb758743373b80acc5a905fa713514309f2b49d66f9f74794f025becdfc


File details

Details for the file minijax-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: minijax-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 13.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.18 {"installer":{"name":"uv","version":"0.9.18","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Ubuntu","version":"24.04","id":"noble","libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for minijax-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 a957e2d581a6e4374d0b7f97420109469eb949b48f49e22c772d9fba167709cd
MD5 e3c16b6532f4fb18ac42abefb3046dfe
BLAKE2b-256 ce60650d6aae5bf69af8e9aac8fbac5a738db54bb0d70df5b44d7ab4cf9acbfb

