Export JAX to ONNX, with a focus on Flax NNX.
jax2onnx
jax2onnx converts your JAX/Flax (NNX) functions directly into the ONNX format.
Key Features
- Simple API: Easily convert JAX callables, including Flax (NNX) models, into ONNX format using to_onnx(...).
- Model structure preserved: With @onnx_function, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.
- Dynamic input support: Use abstract dimensions like 'B' or pass scalars as runtime inputs. Models stay flexible without retracing.
- Plugin-based extensibility: Add support for new primitives by writing small, local plugins.
- Netron-friendly outputs: All generated ONNX graphs include shape/type annotations and are structured for clear visualization.
Quickstart
Convert your JAX callable to ONNX in just a few lines:
```python
import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate the model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")
```
See it visualized: my_callable.onnx
ONNX Functions - Minimal Example
ONNX functions help encapsulate reusable subgraphs. Simply add the @onnx_function decorator to make your callable an ONNX function.
```python
from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# Just an @onnx_function decorator makes your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
        self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
        self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)

    def __call__(self, x):
        return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
    def __init__(self, dim, *, rngs):
        self.block1 = MLPBlock(dim, rngs=rngs)
        self.block2 = MLPBlock(dim, rngs=rngs)

    def __call__(self, x):
        return self.block2(self.block1(x))

my_model = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(my_model, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")
```
See it visualized: model_with_function.onnx
Roadmap and Releases
Planned Versions
- Ongoing
  - Expanding coverage of JAX and Flax (NNX) components
  - Enhancing support for physics-based simulations
- Upcoming
  - Advanced ONNX function support, including function reuse, optimized internal graph structure, and improved variable naming for clarity and readability
  - Support for equinox models
  - Integrating onnx-ir as a backend to improve ONNX model construction, memory efficiency, and performance
Current Productive Version
- 0.7.5 (PyPI):
  - Fixed tests for functions without arguments.
  - Added support for new primitives: lax.bitwise_not, lax.clamp, lax.ge, jnp.clip, lax.rev.
  - Enhanced support for nnx.dot_product_attention, nnx.conv, nnx.batch_norm, lax.mul, lax.reduce_max, lax.scan, lax.slice, lax.while_loop, nn.gelu, jnp.arange, jnp.cumsum, jnp.select, jnp.where, and jnp.concatenate.
Past Versions
- 0.7.4: Adds support for lax.cumsum and jnp.cumsum, and improves lax.scatter.
- 0.7.3: Improves polymorphism handling for transformers.
- 0.7.2: Adds support for jnp.split, lax.split, lax.logistic, includes an example for nnx.GRUCell, and improves lax.scatter and lax.while_loop.
- 0.7.1: Fixes a numeric-equivalence bug in the test system, and adds support for core.custom_jvp_generic, eqx.identity, jnp.select, jnp.stack, jnp.unstack, lax.select, plus multiple nn.* activations (identity, celu, elu, gelu, relu, leaky_relu, mish, selu, sigmoid, soft-sign, softmax, truncated_normal).
- 0.7.0: Introduces a GPT-2 example based on nanoGPT with ONNX function support and attention masking, adds support for jnp.concatenate, jnp.take, nnx.Embed, and starts hosting ONNX models on Hugging Face.
- 0.6.5: Improves support for nnx.batch_norm, nnx.group_norm, nnx.layer_norm, nnx.rms_norm, lax.broadcast_in_dim, lax.cond, lax.fori_loop, lax.integer_pow, lax.scan, lax.scatter, lax.scatter_add, lax.scatter_mul, and lax.while_loop; and adds support for lax.and, lax.rem, and lax.remat2.
- 0.6.4: Improved support for lax.scatter_mul.
- 0.6.3: Double-precision fixes for lax.fori_loop and lax.while_loop. Fixed bugs in lax.scan and jnp.where.
- 0.6.2: Fixed bugs in nnx.conv and lax.reshape; added new primitive jnp.prod.
- 0.6.1: Improved support for lax.cond and lax.select_n; added new primitives (lax.reduce_and, lax.reduce_or, lax.reduce_prod, lax.reduce_xor); and introduced new examples for jnp.select and jnp.sort.
- 0.6.0: Introduced the enable_double_precision parameter (default: False) to support physics simulations, and enhanced handling of lax.scatter.
- 0.5.2: Added support for additional primitives: jnp.where, jnp.arange, jnp.linspace.
- 0.5.1: Added support for subgraphs using the primitives lax.while_loop, lax.cond, lax.fori_loop, and lax.scan.
- 0.5.0: Improved dynamic batch dimension handling by leveraging shape polymorphism for more robust and flexible model export. Added support for the jnp.sign, jnp.abs, and jnp.iota primitives.
- 0.4.4: Added support for the lax.cos, lax.cosh, lax.sin, lax.sinh, and lax.scatter primitives.
- 0.4.3: Fixed a bug in the validation of JAX callable outputs against their ONNX counterparts. This fix exposed previously hidden failing tests, which are now fixed.
- 0.4.2: Cleanup and fixes to the basic ONNX function release.
- 0.4.1 (ONNX functions): Introduced simple ONNX function support. Using ONNX functions is easy: just add a @onnx_function decorator to make a callable an ONNX function. Each @onnx_function decorator creates a new ONNX function instance on the call graph.
- 0.3.2: Relaxed the minimum Python version to 3.10.
- 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
- 0.2.0 (First PyPI Release): Rebased the implementation on jaxpr, improving usability and adding low-level lax components.
- 0.1.0 (Initial Approach, Not Released to PyPI): Produced ONNX exports for some nnx components and nnx-based examples, including a Vision Transformer.
Troubleshooting
If conversion doesn't work out of the box, it could be due to:
- Non-dynamic function references: JAXPR-based conversion requires function references to be resolved dynamically at call-time. Solution: wrap your function call inside a lambda to enforce dynamic resolution: my_dynamic_callable_function = lambda x: original_function(x)
- Unsupported primitives: The callable may use a primitive not yet, or not fully, supported by jax2onnx. Solution: write a plugin to handle the unsupported function (this is straightforward!).
Supported JAX/ONNX Components
Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet
Examples
| Component | Description | Testcases | Since |
|---|---|---|---|
| SimpleLinearExample | A simple linear layer example using Equinox. | simple_linear_dynamic ✅, simple_linear_dynamic_f64 ✅, simple_linear ✅, simple_linear_f64 ✅, nn_linear_dynamic ✅, nn_linear_dynamic_f64 ✅, nn_linear ✅, nn_linear_f64 ✅ | v0.7.1 |
| GPT | A simple GPT model that reuses nnx.MultiHeadAttention. | gpt_dynamic ✅, gpt ✅ | v0.7.0 |
| GPT_Attention | A multi-head attention layer. | gpt_attention ✅ | v0.7.1 |
| GPT_CausalSelfAttention | A causal self-attention module. | causal_self_attention_dynamic ✅, causal_self_attention ✅ | v0.7.0 |
| GPT_Embeddings | Combines token and position embeddings with dropout. | gpt_embeddings_dynamic ✅, gpt_embeddings ✅ | v0.7.0 |
| GPT_Head | The head of the GPT model. | gpt_head_dynamic ✅, gpt_head ✅ | v0.7.0 |
| GPT_MLP | An MLP block with GELU activation from nanoGPT. | gpt_mlp_dynamic ✅, gpt_mlp ✅ | v0.7.0 |
| GPT_PositionEmbedding | A positional embedding layer using nnx.Embed. | position_embedding ✅ | v0.7.0 |
| GPT_TokenEmbedding | A token embedding layer using nnx.Embed. | token_embedding_dynamic ✅, token_embedding ✅ | v0.7.0 |
| GPT_TransformerBlock | A transformer block combining attention and MLP. | gpt_block_dynamic ✅, gpt_block ✅ | v0.7.0 |
| GPT_TransformerStack | A stack of transformer blocks. | transformer_stack_dynamic ✅, transformer_stack ✅ | v0.7.0 |
| broadcast_add | Simple dynamic broadcast + add | broadcast_add_dynamic_dynamic ✅, broadcast_add_dynamic_dynamic_f64 ✅, broadcast_add_dynamic ✅, broadcast_add_dynamic_f64 ✅ | v0.7.0 |
| cfl_timestep | Tests the CFL condition timestep calculation. | cfl_timestep_f64 ✅ | v0.6.5 |
| weno_reconstruction | Tests the complex arithmetic pattern found in WENO schemes. | weno_reconstruction_f64 ✅ | v0.6.5 |
| fori_loop_test | Demonstrates jax.lax.fori_loop with a simple loop. | fori_loop_test ✅, fori_loop_test_f64 ✅ | v0.6.3 |
| issue18_abs | Test jnp.abs from issue 18 | abs_fn ✅, abs_fn_f64 ✅ | v0.6.3 |
| issue18_arange | Test arange from issue 18 | arange_fn ✅, arange_fn_f64 ✅ | v0.6.3 |
| issue18_fori_loop | Test fori_loop from issue 18 | fori_loop_fn ✅, fori_loop_fn_f64 ✅ | v0.6.3 |
| issue18_linspace | Test linspace from issue 18 | linspace_fn ✅, linspace_fn_f64 ✅ | v0.6.3 |
| issue18_scan | Test scan from issue 18 (no xs) | scan_fn ✅, scan_fn_f64 ✅ | v0.6.3 |
| issue18_sign | Test jnp.sign from issue 18 | sign_fn ✅, sign_fn_f64 ✅ | v0.6.3 |
| issue18_where | Test where from issue 18 | where_fn ✅, where_fn_f64 ✅ | v0.6.3 |
| issue18_while_loop | Test while_loop from issue 18 | while_loop_fn ✅ | v0.6.3 |
| select_test | Demonstrates jnp.select with a dynamic condition based on an input array. | select_test_all_options ✅, select_test_scalar_select_option_0 ✅, select_test_scalar_select_option_1 ✅, select_test_scalar_select_option_2 ✅, select_test_default_case ✅ | v0.6.1 |
| sort_test | Demonstrates jnp.sort on slices of an input array. | sort_test_basic ✅ | v0.6.1 |
| cond_scatter_add_mul | Tests scatter_add/mul inside jnp.where branches | cond_scatter_add_mul_f64_a ✅, cond_scatter_add_mul_f64_b ✅ | v0.6.4 |
| cond_scatter_repro | Reproduces a bug where lax.cond subgraphs do not inherit parent initializers. | cond_scatter_repro_f64 ✅ | v0.6.4 |
| remat2 | Tests a simple case of jax.checkpoint (also known as jax.remat2). | checkpoint_scalar_f32 ✅, checkpoint_scalar_f32_f64 ✅ | v0.6.5 |
| scatter_window | Window-scatter (H×W patch) with implicit batch (depth-3 path). Exercises GatherScatterMode.FILL_OR_DROP and double precision. Regression of a prior conversion failure. | scatter_window_update_f64_example ✅ | v0.7.4 |
| AutoEncoder | A simple autoencoder example. | simple_autoencoder ✅, simple_autoencoder_f64 ✅ | v0.2.0 |
| CNN | A simple convolutional neural network (CNN). | simple_cnn_explicit_dimensions ✅, simple_cnn_dynamic ✅, simple_cnn ✅ | v0.1.0 |
| ForiLoop | fori_loop example | fori_loop_counter ✅, fori_loop_counter_f64 ✅ | v0.5.1 |
| GRUCell | Vanilla gated-recurrent-unit cell from Flax/nnx. There is no 1-to-1 ONNX operator, so the converter decomposes it into MatMul, Add, Sigmoid, Tanh, etc. | gru_cell_basic ✅ | v0.7.2 |
| MLP | A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. | simple_mlp_dynamic ✅, simple_mlp_dynamic_f64 ✅, simple_mlp ✅, simple_mlp_f64 ✅, simple_mlp_with_call_params_dynamic ✅, simple_mlp_with_call_params_dynamic_f64 ✅, simple_mlp_with_call_params ✅, simple_mlp_with_call_params_f64 ✅ | v0.1.0 |
| MultiHeadAttention | A multi-head attention module implemented by Flax/nnx that has no ONNX correspondent at the same granularity. | multihead_attention_nn_dynamic ✅, multihead_attention_nn ✅, multihead_attention_nnx_dynamic ✅, multihead_attention_nnx ✅, multihead_attention_2_nnx_dynamic ✅, multihead_attention_2_nnx ✅ | v0.2.0 |
| SequentialReLU | Two ReLU activations chained with nnx.Sequential (no parameters). | sequential_double_relu ✅, sequential_double_relu_f64 ✅ | v0.7.1 |
| SequentialWithResidual | Tests nnx.Sequential nested inside a module with a residual connection. | sequential_nested_with_residual ✅ | v0.7.1 |
| TransformerDecoderWithSequential | A single-layer Transformer decoder built with nnx primitives (MHA, LayerNorm, Feed-Forward, Dropout). | tiny_decoder_with_sequential_dynamic ✅, tiny_decoder_with_sequential ✅, tiny_decoder_with_sequential_and_full_dynamic_shapes_dynamic ✅ | v0.7.1 |
| TransformerDecoderWithoutSequential | A single-layer Transformer decoder built with nnx primitives (MHA, LayerNorm, Feed-Forward, Dropout). | tiny_decoder_without_sequential_dynamic ✅, tiny_decoder_without_sequential ✅ | v0.7.1 |
| onnx_functions_000 | one function on an outer layer. | 000_one_function_on_outer_layer_dynamic ✅, 000_one_function_on_outer_layer ✅ | v0.4.0 |
| onnx_functions_001 | one function on an inner layer. | 001_one_function_inner_dynamic ✅, 001_one_function_inner ✅ | v0.4.0 |
| onnx_functions_002 | two nested functions. | 002_two_nested_functions_dynamic ✅, 002_two_nested_functions ✅ | v0.4.0 |
| onnx_functions_003 | two nested functions. | 003_two_simple_nested_functions_dynamic ✅, 003_two_simple_nested_functions ✅ | v0.4.0 |
| onnx_functions_004 | nested function plus component | 004_nested_function_plus_component_dynamic ✅, 004_nested_function_plus_component ✅ | v0.4.0 |
| onnx_functions_005 | nested function plus more components | 005_nested_function_plus_component_dynamic ✅, 005_nested_function_plus_component ✅ | v0.4.0 |
| onnx_functions_006 | one function on an outer layer. | 006_one_function_outer_dynamic ✅, 006_one_function_outer ✅ | v0.4.0 |
| onnx_functions_007 | transformer block with nested mlp block with call parameter | 007_transformer_block_dynamic ✅, 007_transformer_block ✅ | v0.4.0 |
| onnx_functions_008 | transformer block with nested mlp block, no call parameter | 008_transformer_block_dynamic ✅, 008_transformer_block ✅ | v0.4.0 |
| onnx_functions_009 | transformer block using decorator on class and function | 009_transformer_block_dynamic ✅, 009_transformer_block ✅ | v0.4.0 |
| onnx_functions_010 | transformer stack | 010_transformer_stack_dynamic ✅, 010_transformer_stack ✅ | v0.4.0 |
| onnx_functions_012 | Vision Transformer (ViT) | 012_vit_conv_embedding_dynamic ✅, 012_vit_conv_embedding ✅ | v0.4.0 |
| onnx_functions_013 | Vision Transformer (ViT) | 013_vit_conv_embedding_with_call_params_dynamic ✅, 013_vit_conv_embedding_with_call_params ✅, 013_vit_conv_embedding_with_internal_call_params_dynamic ✅, 013_vit_conv_embedding_with_internal_call_params ✅ | v0.4.0 |
| onnx_functions_014 | one function on an outer layer. | 014_one_function_with_input_param_with_default_value ✅, 014_one_function_without_input_param_with_default_value_dynamic ✅, 014_one_function_without_input_param_with_default_value ✅ | v0.4.0 |
| onnx_functions_015 | one function on an outer layer. | 015_one_function_with_input_param_without_default_value_dynamic ✅, 015_one_function_with_input_param_without_default_value ✅ | v0.4.0 |
| onnx_functions_016 | nested function plus more components | 016_internal_function_with_input_param_with_default_value_dynamic ✅, 016_internal_function_with_input_param_with_default_value ✅ | v0.4.0 |
| ClassificationHead | Classification head for Vision Transformer | classification_head_dynamic ✅, classification_head ✅ | v0.4.0 |
| ClassificationHeadFlatten | Classification head for Vision Transformer | classification_head_flat_dynamic ✅, classification_head_flat ✅ | v0.4.0 |
| ConcatClsToken | Concatenate CLS token to the input embedding | concat_cls_token_dynamic ✅, concat_cls_token ✅ | v0.4.0 |
| ConcatClsTokenFlatten | Concatenate CLS token to the input embedding | concat_cls_token_flat_dynamic ✅, concat_cls_token_flat ✅ | v0.4.0 |
| ConvEmbedding | Convolutional Token Embedding for MNIST with hierarchical downsampling. | mnist_conv_embedding_dynamic ✅, mnist_conv_embedding ✅ | v0.1.0 |
| ConvEmbeddingFlatten | Convolutional Token Embedding for MNIST with hierarchical downsampling. | mnist_conv_embedding_flat_dynamic ✅, mnist_conv_embedding_flat ✅ | v0.1.0 |
| FeedForward | MLP in Transformer | feed_forward_dynamic ✅, feed_forward ✅ | v0.1.0 |
| FeedForwardFlatten | MLP in Transformer | feed_forward_flat_dynamic ✅, feed_forward_flat ✅ | v0.1.0 |
| GetToken | Get the CLS token from the input embedding | get_token_dynamic ✅, get_token ✅ | v0.4.0 |
| GetTokenFlatten | Get the CLS token from the input embedding | get_token_flat_dynamic ✅, get_token_flat ✅ | v0.4.0 |
| PatchEmbedding | Cutting the image into patches and linearly embedding them. | patch_embedding_dynamic ✅, patch_embedding ✅ | v0.1.0 |
| PatchEmbeddingFlatten | Cutting the image into patches and linearly embedding them. | patch_embedding_flat_dynamic ✅, patch_embedding_flat ✅ | v0.1.0 |
| PositionalEmbedding | Add positional embedding to the input embedding | positional_embedding_dynamic ✅, positional_embedding ✅ | v0.4.0 |
| PositionalEmbeddingFlatten | Add positional embedding to the input embedding | positional_embedding_flat_dynamic ✅, positional_embedding_flat ✅ | v0.4.0 |
| TransformerBlock | Transformer from 'Attention Is All You Need.' | transformer_block_dynamic ✅, transformer_block ✅ | v0.1.0 |
| TransformerBlockFlatten | Transformer from 'Attention Is All You Need.' | transformer_block_flat_dynamic ✅, transformer_block_flat ✅ | v0.1.0 |
| TransformerStack | Stack of Transformer blocks | transformer_stack_dynamic ✅, transformer_stack ✅ | v0.1.0 |
| TransformerStackFlatten | Stack of Transformer blocks | transformer_stack_flat_dynamic ✅, transformer_stack_flat ✅ | v0.1.0 |
| VisionTransformer | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_conv_embedding_dynamic ✅, vit_conv_embedding ✅, vit_patch_embedding ✅ | v0.2.0 |
| VisionTransformerFlatten | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_conv_embedding_flat_dynamic ✅, vit_conv_embedding_flat ✅, vit_patch_embedding_flat_dynamic ✅, vit_patch_embedding_flat ✅ | v0.2.0 |
Dependencies
Versions of Major Dependencies:

| Library | Versions |
|---|---|
| JAX | 0.6.2 |
| Flax | 0.11.1 |
| onnx | 1.18.0 |
| onnxruntime | 1.22.1 |

Note: For more details, check pyproject.toml.
Limitations
- Currently not all JAX/Flax components are supported (you can easily help expand this coverage!).
- Function references need dynamic resolution at call-time.
- ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.
How to Contribute
We warmly welcome contributions!
How you can help:
- Add a plugin: Extend jax2onnx by writing a simple Python file in jax2onnx/plugins that handles a custom primitive or adds an example.
- Bug fixes & improvements: PRs and issues are always welcome.
Installation
Install from PyPI:
pip install jax2onnx
License
This project is licensed under the Apache License, Version 2.0. See LICENSE for details.
Special Thanks
Special thanks for example contributions to @burakssen, @Cadynum, @clementpoiret and @PVirie
Special thanks for plugin contributions to @burakssen, @clementpoiret and @Clouder0
Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.
Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.
Special thanks to the community members involved in:
A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.
Happy converting!