export JAX to ONNX - focus on flax nnx
Project description
jax2onnx

jax2onnx converts your JAX/Flax (NNX) functions directly into the ONNX format.

Key Features
- Simple API: Easily convert JAX callables, including Flax (NNX) models, into ONNX format using `to_onnx(...)`.
- Model structure preserved: With `@onnx_function`, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.
- Dynamic input support: Use abstract dimensions like `'B'` or pass scalars as runtime inputs. Models stay flexible without retracing.
- Plugin-based extensibility: Add support for new primitives by writing small, local plugins.
- Netron-friendly outputs: All generated ONNX graphs include shape/type annotations and are structured for clear visualization.
Quickstart
Convert your JAX callable to ONNX in just a few lines:
```python
import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from the Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate the model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX; "B" marks a dynamic batch dimension
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")
```
See it visualized: my_callable.onnx
ONNX Functions: A Minimal Example

ONNX functions encapsulate reusable subgraphs. Simply apply the @onnx_function decorator to turn your callable into an ONNX function.
```python
from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# Just an @onnx_function decorator makes your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
        self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)

    def __call__(self, x):
        return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.block1 = MLPBlock(dim, rngs=rngs)
        self.block2 = MLPBlock(dim, rngs=rngs)

    def __call__(self, x):
        return self.block2(self.block1(x))

my_model = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(my_model, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")
```
See it visualized: model_with_function.onnx
Roadmap and Releases
Planned Versions
- Ongoing:
  - Expanding coverage of JAX and Flax (NNX) components
  - Enhancing support for physics-based simulations
- Under Evaluation:
  - Integrating `onnx-ir` as a backend to improve ONNX model construction, memory efficiency, and performance
- Upcoming:
  - Advanced ONNX function support, including function reuse, optimized internal graph structure, and improved variable naming for clarity and readability
  - Support for `equinox` models
Current Production Version

- 0.7.0 (PyPI):
  - Added a GPT-2 model example based on nanoGPT, featuring ONNX function support and attention masking
  - New support for `jnp.concatenate`, `jnp.take`, `nnx.Embed`
  - ONNX models are now hosted on Hugging Face
Past Versions

- 0.6.5: Improved support for `nnx.batch_norm`, `nnx.group_norm`, `nnx.layer_norm`, `nnx.rms_norm`, `lax.broadcast_in_dim`, `lax.cond`, `lax.fori_loop`, `lax.integer_pow`, `lax.scan`, `lax.scatter`, `lax.scatter_add`, `lax.scatter_mul` and `lax.while_loop`; added support for `lax.and`, `lax.rem` and `lax.remat2`.
- 0.6.4: Improved support for `lax.scatter_mul`.
- 0.6.3: Double-precision fixes for `lax.fori_loop` and `lax.while_loop`. Fixed bugs in `lax.scan` and `jnp.where`.
- 0.6.2: Fixed bugs in `nnx.conv` and `lax.reshape`; added new primitive `jnp.prod`.
- 0.6.1: Improved support for `lax.cond` and `lax.select_n`; added new primitives (`lax.reduce_and`, `lax.reduce_or`, `lax.reduce_prod`, `lax.reduce_xor`); and introduced new examples for `jnp.select` and `jnp.sort`.
- 0.6.0: Introduced the `enable_double_precision` parameter (default: `False`) to support physics simulations, and enhanced handling of `lax.scatter`.
- 0.5.2: Added support for additional primitives: `jnp.where`, `jnp.arange`, `jnp.linspace`.
- 0.5.1: Added support for subgraph-using primitives: `lax.while_loop`, `lax.cond`, `lax.fori_loop`, `lax.scan`.
- 0.5.0: Improved dynamic batch dimension handling by leveraging shape polymorphism for more robust and flexible model export. Added support for the `jnp.sign`, `jnp.abs` and `jnp.iota` primitives.
- 0.4.4: Added support for the `lax.cos`, `lax.cosh`, `lax.sin`, `lax.sinh` and `lax.scatter` primitives.
- 0.4.3: Fixed a bug in the validation of JAX callable outputs against their ONNX counterparts. This fix exposed previously hidden failing tests, which are now fixed.
- 0.4.2: Cleanup and fixes to the basic ONNX function release.
- 0.4.1 (ONNX functions): Introduced simple ONNX function support. Using ONNX functions is easy: just add a `@onnx_function` decorator to make a callable an ONNX function. Each `@onnx_function` decorator creates a new ONNX function instance on the call graph.
- 0.3.2: Relaxed the minimum Python version to 3.10.
- 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
- 0.2.0 (First PyPI Release): Rebased the implementation on `jaxpr`, improving usability and adding low-level `lax` components.
- 0.1.0 (Initial Approach, Not Released to PyPI): Produced ONNX exports for some `nnx` components and `nnx`-based examples, including a VisualTransformer.
Troubleshooting

If conversion doesn't work out of the box, it could be due to:

- Non-dynamic function references:
  JAXPR-based conversion requires function references to be resolved dynamically at call time.
  Solution: wrap your function call inside a lambda to enforce dynamic resolution:
  `my_dynamic_callable_function = lambda x: original_function(x)`
- Unsupported primitives:
  The callable may use a primitive not yet, or not fully, supported by jax2onnx.
  Solution: write a plugin to handle the unsupported function (this is straightforward!).
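Why the lambda helps can be seen with plain Python name resolution, independent of jax2onnx: a direct reference freezes the current function object, while a lambda looks the name up again at call time, so a later rebinding (e.g. re-running a notebook cell) is picked up.

```python
def original_function(x):
    return x * 2

static_ref = original_function  # direct reference: bound to the object right now

# Dynamic reference: `original_function` is resolved each time the lambda is called
my_dynamic_callable_function = lambda x: original_function(x)

def original_function(x):  # rebinding, as happens when a cell is re-executed
    return x + 1

print(static_ref(10))                    # 20 -- still the old function
print(my_dynamic_callable_function(10))  # 11 -- picks up the new definition
```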
Supported JAX/ONNX Components

Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet
Examples

Dependencies
Versions of Major Dependencies:

| Library | Version |
|---|---|
| JAX | 0.6.2 |
| Flax | 0.10.6 |
| onnx | 1.18.0 |
| onnxruntime | 1.22.0 |

Note: For more details, check pyproject.toml.
Limitations
- Currently not all JAX/Flax components are supported (you can easily help expand this coverage!).
- Function references need dynamic resolution at call-time.
- ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.
How to Contribute

We warmly welcome contributions!

How you can help:

- Add a plugin: Extend jax2onnx by writing a simple Python file in jax2onnx/plugins for a custom primitive or an example.
- Bug fixes & improvements: PRs and issues are always welcome.
Installation

Install from PyPI:

```shell
pip install jax2onnx
```
License
This project is licensed under the Apache License, Version 2.0. See LICENSE for details.
Special Thanks

Special thanks for example contributions to @burakssen, @Cadynum and @clementpoiret.
Special thanks for plugin contributions to @burakssen.
Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.
Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.
Special thanks to the community members involved in:
A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.
Happy converting!
File details
Details for the file jax2onnx-0.7.0.tar.gz.
File metadata
- Download URL: jax2onnx-0.7.0.tar.gz
- Upload date:
- Size: 259.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.1.3 CPython/3.12.5 Linux/6.6.87.2-microsoft-standard-WSL2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 139cf58ca148a0bd028c2f382b91b31ab51db8af16529720aeb8b45460eaf2c9 |
| MD5 | 4464e901a79b52cb6da2cc638f729d88 |
| BLAKE2b-256 | bfff204618e1f23b64b612eefbf3b6c3954d461fa35a83ac02b28c6285cf9ca3 |
File details
Details for the file jax2onnx-0.7.0-py3-none-any.whl.
File metadata
- Download URL: jax2onnx-0.7.0-py3-none-any.whl
- Upload date:
- Size: 389.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.1.3 CPython/3.12.5 Linux/6.6.87.2-microsoft-standard-WSL2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4877aaf489b6cea4cc94eb011855e55797ade4bde529c2b49df81a5fd63d3500 |
| MD5 | acf0e5a030f08361bcd181313e285fdf |
| BLAKE2b-256 | 1c6b1e7a527c909ebc72dbddecc135c4cd96749e43436fa20b04c57cbd7d8eb6 |