jax2onnx 🌟
jax2onnx converts your JAX, Flax NNX, Flax Linen, and Equinox functions directly into the ONNX format.
✨ Key Features
- **Simple API**: Easily convert JAX callables, including Flax NNX, Flax Linen, and Equinox models, into ONNX format using `to_onnx(...)`.
- **Model structure preserved**: With `@onnx_function`, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.
- **Dynamic input support**: Use abstract dimensions like `'B'` or pass scalars as runtime inputs. Models stay flexible without retracing.
- **Plugin-based extensibility**: Add support for new primitives by writing small, local plugins.
- **onnx-ir native pipeline**: Conversion, optimization, and post-processing all run on the typed `onnx_ir` toolkit (no protobuf juggling) and stay memory-lean before the final ONNX serialization.
- **Netron-friendly outputs**: Generated graphs carry shape/type annotations and a clean hierarchy, so tools like Netron stay easy to read.
🚀 Quickstart
Install and export your first model in minutes:
```bash
pip install jax2onnx
```
Convert your JAX callable to ONNX in just a few lines:
```python
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from the Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate the model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Export straight to disk without keeping the proto in memory
to_onnx(
    my_callable,
    [("B", 30)],
    return_mode="file",
    output_path="my_callable.onnx",
)
```
🔎 See it visualized: my_callable.onnx
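To sanity-check an export, you can load the file with onnxruntime and run a batch through it. A minimal sketch, assuming onnxruntime is installed; the input name is queried from the graph rather than hard-coded, and the symbolic `'B'` dimension is bound at call time:

```python
import numpy as np
import onnxruntime as ort

# Minimal smoke test for the exported model.
sess = ort.InferenceSession("my_callable.onnx")
input_name = sess.get_inputs()[0].name  # query the exporter-assigned input name

x = np.random.randn(4, 30).astype(np.float32)  # the dynamic 'B' dimension becomes 4
(onnx_out,) = sess.run(None, {input_name: x})
print(onnx_out.shape)  # expected: (4, 10)
```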
🧠 ONNX Functions — Minimal Example
ONNX functions encapsulate reusable subgraphs. Simply add the @onnx_function decorator to make your callable an ONNX function.
```python
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# Just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
        self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)

    def __call__(self, x):
        return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.block1 = MLPBlock(dim, rngs=rngs)
        self.block2 = MLPBlock(dim, rngs=rngs)

    def __call__(self, x):
        return self.block2(self.block1(x))

callable = MyModel(256, rngs=nnx.Rngs(0))

to_onnx(
    callable,
    [(100, 256)],
    return_mode="file",
    output_path="model_with_function.onnx",
)
```
🔎 See it visualized: model_with_function.onnx
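To confirm that the decorated blocks really landed as functions in the ONNX file, you can list the model's FunctionProtos with the onnx package. A small sketch using plain onnx, not part of the jax2onnx API:

```python
import onnx

model = onnx.load("model_with_function.onnx")
# Modules decorated with @onnx_function should show up as FunctionProtos.
for fn in model.functions:
    print(f"{fn.domain}.{fn.name}: {len(fn.node)} nodes")
```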
SotA examples 🚀
- Language: GPT-OSS (open-source MoE Transformer)
  - Architecture: Flax/NNX + Equinox reference stacks with gating/routing capture, MoE MLP rebuilds, and deterministic ONNX exporters (see `jax2onnx/plugins/examples/nnx/gpt_oss_flax.py` and `jax2onnx/plugins/examples/eqx/gpt_oss.py`).
  - Structural graph:
  - How-to: Getting GPT-OSS weights into jax2onnx
  - Equivalence check: Routing parity harness · Flax parity tests · Equinox parity tests
  - Optional pretrained weights: openai/gpt-oss-20b · openai/gpt-oss-120b (weights and model cards list license: apache-2.0)
- Vision: DINOv3
  - Architecture: Equimo's clean-room Equinox/JAX implementation, following Meta AI's DINOv3 paper. Flax/NNX parity modules now live under `jax2onnx/plugins/examples/nnx/dinov3.py` (randomly initialised example stack for IR-only exports).
  - Structural graphs (selected examples):
  - How-to: Getting Meta weights into jax2onnx
  - Equivalence check: Comparing Meta vs jax2onnx ONNX
  - Optional pretrained weights (Meta AI): facebook/dinov3-vitb16-pretrain-lvd1689m (other variants live under the same namespace); the DINOv3 license applies, so review it before downloading or redistributing.
🧩 Coverage & Examples (Interactive)
> [!TIP]
> JAX · Flax · Equinox — explore everything that’s supported and see it in action.
- ✅ Support matrix: status per component
- 🧪 Exact regression testcase for each entry
- 🔍 One-click Netron graph to inspect nodes, shapes, attributes
- 🧩 Examples that compose multiple components (Conv→Norm→Activation→Pool, MLP w/ LayerNorm+Dropout, `reshape`/`transpose`/`concat`, `scan`/`while_loop`, `gather`/`scatter`, …)

Links: Open support matrix ↗ · Browse examples ↗
📅 Roadmap and Releases
Planned
- Broaden coverage of JAX, Flax NNX/Linen, and Equinox components.
- Expand SotA example support for vision and language models.
- Improve support for physics-based simulations.
Current Production Version
- 0.11.0:
  - Initial Flax Linen support: core layers (Dense/DenseGeneral, Conv/ConvTranspose/ConvLocal, pooling, BatchNorm/LayerNorm/GroupNorm/RMSNorm/InstanceNorm), Dropout, Einsum/Embed, spectral/weight norm wrappers, activation coverage (GELU plus glu/hard_/log_/relu6/silu-swish/tanh/normalize/one_hot), the attention stack (dot_product_attention, dot_product_attention_weights, make_attention_mask/make_causal_mask, SelfAttention, MultiHeadDotProductAttention, MultiHeadAttention), the recurrent stack (SimpleCell, GRUCell, MGUCell, LSTMCell, OptimizedLSTMCell, ConvLSTMCell, RNN, Bidirectional), and Linen examples (MLP/CNN/Sequential).
  - Modernized IR optimization pipeline: adopted the standard onnx_ir CSE pass, removed legacy helpers/getattr patterns, and simplified tests with direct graph iteration.
Past Versions
See past_versions for the full release archive.
❓ Troubleshooting
If conversion doesn't work out of the box, it could be due to:
- **Non-dynamic function references**: JAXPR-based conversion requires function references to be resolved dynamically at call time.
  Solution: wrap your function call inside a lambda to enforce dynamic resolution (see the sketch after this list):

  ```python
  my_dynamic_callable_function = lambda x: original_function(x)
  ```

- **Unsupported primitives**: The callable may use a primitive not yet, or not fully, supported by `jax2onnx`.
  Solution: write a plugin to handle the unsupported function (this is straightforward!).
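For context, here is the lambda workaround end to end. A hedged sketch: `original_function` is a stand-in for whatever callable failed to convert directly, and the `to_onnx` call mirrors the Quickstart above:

```python
import jax.numpy as jnp
from jax2onnx import to_onnx

def original_function(x):  # stand-in for the callable that failed to convert
    return jnp.tanh(x) * 2.0

# The lambda defers the function lookup to call time, which the
# JAXPR-based tracer needs in order to resolve the reference dynamically.
my_dynamic_callable_function = lambda x: original_function(x)

to_onnx(
    my_dynamic_callable_function,
    [("B", 8)],
    return_mode="file",
    output_path="wrapped_fn.onnx",
)
```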
Looking for provenance details while debugging? Check out the new Stacktrace Metadata guide.
🤝 How to Contribute
We warmly welcome contributions!
How you can help:
- **Add a plugin**: Extend `jax2onnx` by writing a simple Python file in `jax2onnx/plugins`: a primitive or an example. The Plugin Quickstart walks through the process step by step.
- **Bug fixes & improvements**: PRs and issues are always welcome.
📌 Dependencies
Latest supported versions of major dependencies:

| Library | Version |
|---|---|
| JAX | 0.8.2 |
| Flax | 0.12.2 |
| Equinox | 0.13.2 |
| onnx-ir | 0.1.13 |
| onnx | 1.20.0 |
| onnxruntime | 1.23.2 |
For exact pins and extras, see pyproject.toml.
📜 License
This project is licensed under the Apache License, Version 2.0. See LICENSE for details.
🌟 Special Thanks
✨ Special thanks to @clementpoiret for initiating Equinox support and for Equimo, which brings modern vision models—such as DINOv3—to JAX/Equinox.
✨ Special thanks to @justinchuby for introducing onnx-ir as a scalable and more efficient way to handle ONNX model construction.
✨ Special thanks to @atveit for introducing us to gpt-oss-jax-vs-torch-numerical-comparison.
✨ Special thanks for example contributions to @burakssen, @Cadynum, @clementpoiret, and @PVirie.
✨ Special thanks for plugin contributions to @burakssen, @clementpoiret, @Clouder0, @rakadam, and @benmacadam64.
✨ Special thanks to @benmacadam64 for championing the complex-number handling initiative.
✨ Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.
✨ Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.
✨ Special thanks to the community members involved in:
✨ Special thanks to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.
Happy converting! 🎉