jax2onnx
Export JAX to ONNX, with a focus on Flax NNX.
jax2onnx converts your JAX/Flax functions directly into the ONNX format.
Key Features
- **Simple API**: Convert any JAX/Flax model to ONNX using `to_onnx(...)` or `save_onnx(...)`.
- **Model structure preserved**: With `@onnx_function`, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.
- **Dynamic input support**: Use abstract dimensions like `'B'` or pass scalars as runtime inputs. Models stay flexible without retracing.
- **Plugin-based extensibility**: Add support for new primitives by writing small, local plugins.
- **Netron-friendly outputs**: All generated ONNX graphs include shape/type annotations and are structured for clear visualization.
Quickstart
Convert your JAX callable to ONNX in just a few lines:

```python
import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from the Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate the model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX; "B" marks a dynamic batch dimension
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")
```
See it visualized: my_callable.onnx
ONNX Functions: A Minimal Example
ONNX functions help encapsulate reusable subgraphs. Just apply the @onnx_function decorator to make your callable an ONNX function.

```python
from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# Just an @onnx_function decorator makes the callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
        self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)

    def __call__(self, x):
        return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.block1 = MLPBlock(dim, rngs=rngs)
        self.block2 = MLPBlock(dim, rngs=rngs)

    def __call__(self, x):
        return self.block2(self.block1(x))

# "my_callable" avoids shadowing the built-in "callable"
my_callable = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(my_callable, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")
```
See it visualized: model_with_function.onnx
Roadmap and Releases

Planned Versions
- Ongoing: Expanding JAX component coverage.
- 0.5.0: More ONNX function support: batch dims, function reuse, making the graph optimizer work within functions, and user-friendly variable names.

Current Production Version
- 0.4.2 (PyPI): Cleanup and fixes to the basic ONNX function release.

Past Versions
- 0.4.1 (ONNX functions): Introduced simple ONNX function support. Using ONNX functions is easy: just add an @onnx_function decorator to make a callable an ONNX function. Each @onnx_function decorator creates a new ONNX function instance on the call graph.
- 0.3.2: Relaxed the minimum Python version to 3.10.
- 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
- 0.2.0 (first PyPI release): Rebased the implementation on jaxpr, improving usability and adding low-level lax components.
- 0.1.0 (initial approach, not released to PyPI): Produced ONNX exports for some nnx components and nnx-based examples, including a Vision Transformer.
Troubleshooting
If conversion doesn't work out of the box, it could be due to:
- **Non-dynamic function references**: JAXPR-based conversion requires function references to be resolved dynamically at call time. Solution: wrap your function call inside a lambda to enforce dynamic resolution:
  `my_dynamic_callable_function = lambda x: original_function(x)`
- **Unsupported primitives**: The callable may use a primitive not yet (or not fully) supported by jax2onnx. Solution: write a plugin to handle the unsupported function (this is straightforward!).
Supported JAX/ONNX Components
Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet
Examples

| Component | Description | Testcases | Since |
|---|---|---|---|
| AutoEncoder | A simple autoencoder example. | simple_autoencoder ✅ | v0.2.0 |
| CNN | A simple convolutional neural network (CNN). | simple_cnn ✅ | v0.1.0 |
| ClassificationHead | Classification head for Vision Transformer. | classification_head ✅ | v0.4.0 |
| ConcatClsToken | Concatenate CLS token to the input embedding. | concat_cls_token ✅ | v0.4.0 |
| ConvEmbedding | Convolutional token embedding for MNIST with hierarchical downsampling. | mnist_conv_embedding ✅ | v0.1.0 |
| FeedForward | MLP in Transformer. | feed_forward ✅ | v0.1.0 |
| MLP | A simple multi-layer perceptron (MLP) with BatchNorm, Dropout, and GELU activation. | simple_mlp_dynamic ✅<br>simple_mlp ✅<br>simple_mlp_with_call_params_dynamic ✅<br>simple_mlp_with_call_params ✅ | v0.1.0 |
| MultiHeadAttention | A multi-head attention module implemented in Flax NNX that has no ONNX counterpart at the same granularity. | multihead_attention ✅ | v0.2.0 |
| PatchEmbedding | Cuts the image into patches and linearly embeds them. | patch_embedding ✅ | v0.1.0 |
| PositionalEmbedding | Add positional embedding to the input embedding. | positional_embedding ✅ | v0.4.0 |
| TransformerBlock | Transformer block from 'Attention Is All You Need'. | transformer_block ✅ | v0.1.0 |
| TransformerStack | Stack of Transformer blocks. | transformer_stack ✅ | v0.1.0 |
| VisionTransformer | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_conv_embedding ✅<br>vit_patch_embedding ✅ | v0.2.0 |
| onnx_functions_000 | One function on an outer layer. | 000_one_function_on_outer_layer ✅ | v0.4.0 |
| onnx_functions_001 | One function on an inner layer. | 001_one_function_inner ✅ | v0.4.0 |
| onnx_functions_002 | Two nested functions. | 002_two_nested_functions ✅ | v0.4.0 |
| onnx_functions_003 | Two nested functions. | 003_two_simple_nested_functions ✅ | v0.4.0 |
| onnx_functions_004 | Nested function plus component. | 004_nested_function_plus_component ✅ | v0.4.0 |
| onnx_functions_005 | Nested function plus more components. | 005_nested_function_plus_component ✅ | v0.4.0 |
| onnx_functions_006 | One function on an outer layer. | 006_one_function_outer ✅ | v0.4.0 |
| onnx_functions_008 | Transformer block with nested MLP block, no call parameter. | 008_transformer_block ✅ | v0.4.0 |
| onnx_functions_009 | Transformer block using decorator on class and function. | 009_transformer_block ✅ | v0.4.0 |
| onnx_functions_010 | Transformer stack. | 010_transformer_stack ✅ | v0.4.0 |
| onnx_functions_012 | Vision Transformer (ViT). | 012_vit_conv_embedding ✅ | v0.4.0 |
| onnx_functions_013 | Vision Transformer (ViT). | 013_vit_conv_embedding_with_call_params ✅<br>013_vit_conv_embedding_with_internal_call_params ✅ | v0.4.0 |
| onnx_functions_014 | One function on an outer layer. | 014_one_function_with_input_param_with_default_value ✅<br>014_one_function_without_input_param_with_default_value ✅ | v0.4.0 |
| onnx_functions_015 | One function on an outer layer. | 015_one_function_with_input_param_without_default_value ✅ | v0.4.0 |
| onnx_functions_016 | Nested function plus more components. | 016_internal_function_with_input_param_with_default_value ✅ | v0.4.0 |
Dependencies
Versions of major dependencies:

| Library | Versions |
|---|---|
| JAX | 0.5.3 |
| Flax | 0.10.5 |
| onnx | 1.17.0 |
| onnxruntime | 1.21.0 |
Note: For more details, check pyproject.toml.
Limitations
- Currently not all JAX/Flax components are supported (you can easily help expand this coverage!).
- Function references need dynamic resolution at call-time.
- ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.
How to Contribute
We warmly welcome contributions!
How you can help:
- **Add a plugin**: Extend jax2onnx by writing a simple Python file in jax2onnx/plugins for a custom primitive or an example.
- **Bug fixes & improvements**: PRs and issues are always welcome.
Installation
Install from PyPI:

```shell
pip install jax2onnx
```

Or get the latest development version from TestPyPI:

```shell
pip install -i https://test.pypi.org/simple/ jax2onnx
```
License
This project is licensed under the Apache License, Version 2.0. See LICENSE for details.
Special Thanks
Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.
A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.
Happy converting!
File details
Details for the file jax2onnx-0.4.2.tar.gz.
File metadata
- Download URL: jax2onnx-0.4.2.tar.gz
- Upload date:
- Size: 113.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.0.1 CPython/3.12.5 Linux/5.15.167.4-microsoft-standard-WSL2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 733a191ecbf7e4c90b6e80b7d1656c0cbe828341151d84774eb43923eabf06a2 |
| MD5 | 6303bda068377d8faac64881a3cb2491 |
| BLAKE2b-256 | 92c9d82406283575a939a791cf48aa76ff8c42a76d36e7eda62b6b1efe0290a4 |
File details
Details for the file jax2onnx-0.4.2-py3-none-any.whl.
File metadata
- Download URL: jax2onnx-0.4.2-py3-none-any.whl
- Upload date:
- Size: 202.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.0.1 CPython/3.12.5 Linux/5.15.167.4-microsoft-standard-WSL2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a88dab3f7dd151de76e637643fdd0b6c617d7629b1c804095f84c006ec73b82a |
| MD5 | 29916f47966fe8b691019b27bc870439 |
| BLAKE2b-256 | 316a59819a433c03ab5dcb8e0f9682f1bae297b639cd8558ae0111bf1bdb1303 |