Export JAX to ONNX, with a focus on Flax nnx


jax2onnx 🌟

jax2onnx converts your JAX/Flax functions directly into the ONNX format.


✨ Key Features

  • Simple API
    Convert any JAX/Flax model to ONNX using to_onnx(...) or save_onnx(...)

  • Model structure preserved
    With @onnx_function, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.

  • Dynamic input support
    Use abstract dimensions like 'B' or pass scalars as runtime inputs. Models stay flexible without retracing.

  • Plugin-based extensibility
    Add support for new primitives by writing small, local plugins.

  • Netron-friendly outputs
    All generated ONNX graphs include shape/type annotations and are structured for clear visualization.


🚀 Quickstart

Convert your JAX callable to ONNX in just a few lines:

import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs): 
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) 
    def __call__(self, x): 
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")

🔎 See it visualized: jax_callable.onnx


🧠 ONNX Functions – Minimal Example

ONNX functions help encapsulate reusable subgraphs. Simply add the @onnx_function decorator to turn your callable into an ONNX function.

from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)
  def __call__(self, x):
    return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.block1 = MLPBlock(dim, rngs=rngs)
    self.block2 = MLPBlock(dim, rngs=rngs)
  def __call__(self, x):
    return self.block2(self.block1(x))

callable = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(callable, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")

🔎 See it visualized: model_with_function.onnx


📅 Roadmap and Releases

Planned Versions

  • Ongoing: Expanding JAX component coverage.
  • 0.5.0: Expanded ONNX function support: batch dimensions, function reuse, a graph optimizer that works within functions, and user-friendly variable names.

Current Productive Version

  • 0.4.2 (PyPI): Cleanup and fixes to the basic ONNX function release.

Past Versions

  • 0.4.1 (ONNX functions): Introduced basic ONNX function support. Using ONNX functions is easy: a single @onnx_function decorator turns a callable into an ONNX function, and each decorated callable creates a new ONNX function instance in the call graph.
  • 0.3.2: Relaxed the minimum Python version to 3.10.
  • 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
  • 0.2.0 (First PyPI Release): Rebased the implementation on jaxpr, improving usability and adding low-level lax components.
  • 0.1.0 (Initial Approach, Not Released to PyPI): Produced ONNX exports for some nnx components and nnx-based examples, including a VisualTransformer.

โ“ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

  • Non-dynamic function references:
    JAXPR-based conversion requires function references to be resolved dynamically at call-time.
    Solution: Wrap your function call inside a lambda to enforce dynamic resolution:

    my_dynamic_callable_function = lambda x: original_function(x)
    
  • Unsupported primitives:
    The callable may use a primitive not yet or not fully supported by jax2onnx.
    Solution: Write a plugin to handle the unsupported function (this is straightforward!).


🧩 Supported JAX/ONNX Components

| JAX Component | ONNX Components | Testcases | Since |
|---|---|---|---|
| add | Add | add ✅ | v0.1.0 |
| concat | Concat | concat ✅, concat_abstract_middle_dim_dynamic ✅, concat_abstract_middle_dim ✅ | v0.1.0 |
| einsum | Einsum | einsum ✅, einsum_preferred_element_type ✅, einsum_matmul ✅, einsum_dynamic_dynamic ✅, einsum_dynamic ✅, einsum_dynamic_matmul_dynamic ✅, einsum_dynamic_matmul ✅, einsum_transpose ✅, einsum_dynamic_transpose_dynamic ✅, einsum_dynamic_transpose ✅, einsum_dynamic_matmul2_dynamic ✅, einsum_dynamic_matmul2 ✅, einsum_dynamic_matmul3_dynamic ✅, einsum_dynamic_matmul3 ✅, einsum_outer_product ✅, einsum_trace ✅, einsum_sum ✅, einsum_broadcast ✅, einsum_reduce ✅, einsum_permute ✅, einsum_dynamic_outer_dynamic ✅, einsum_dynamic_outer ✅, einsum_dynamic_reduce_dynamic ✅, einsum_dynamic_reduce ✅ | v0.1.0 |
| matmul | MatMul | matmul_2d ✅, matmul_1d_2d ✅, matmul_2d_1d ✅, matmul_dynamic_dynamic ✅, matmul_dynamic ✅, matmul_dynamic_a_dynamic ✅, matmul_dynamic_a ✅, matmul_dynamic_b_dynamic ✅, matmul_dynamic_b ✅, matmul_1d ✅, matmul_3d ✅ | v0.1.0 |
| reshape | Reshape | reshape_1 ✅, reshape_2 ✅, reshape_3 ✅, reshape_4_dynamic ✅, reshape_4 ✅, reshape_to_scalar ✅, reshape_from_scalar ✅ | v0.1.0 |
| shape | Shape | shape_basic ✅, shape_dynamic_dynamic ✅, shape_dynamic ✅ | v0.4.0 |
| squeeze | Squeeze | squeeze_single_dim ✅, squeeze_multiple_dims ✅, squeeze_vit_output ✅, squeeze_dynamic_batch_dynamic ✅, squeeze_dynamic_batch ✅, squeeze_all_dims ✅, squeeze_negative_axis ✅, squeeze_negative_axis_tuple ✅, squeeze_dynamic_and_negative_axis_dynamic ✅, squeeze_dynamic_and_negative_axis ✅ | v0.1.0 |
| tile | Tile | tile_repeats_tensor ✅, tile_a ✅, tile_b ✅, tile_c ✅, tile_d ✅, tile_dynamic_dynamic ✅, tile_dynamic ✅, tile_pad ✅ | v0.1.0 |
| transpose | Transpose | transpose_basic ✅, transpose_reverse ✅, transpose_4d ✅, transpose_square_matrix ✅, transpose_high_dim ✅, transpose_no_axes ✅, transpose_dynamic_dynamic ✅, transpose_dynamic ✅ | v0.1.0 |
| add | Add | add ✅ | v0.2.0 |
| argmax | ArgMax | argmax_test1 ✅, argmax_test2 ✅ | v0.2.0 |
| argmin | ArgMin | argmin_test1 ✅, argmin_test2 ✅ | v0.2.0 |
| broadcast_in_dim | Expand | broadcast_in_dim ✅, broadcast_in_dim_2d_to_3d ✅, broadcast_in_dim_scalar ✅ | v0.2.0 |
| concatenate | Concat | concatenate ✅, concatenate_axis1 ✅, concatenate_dynamic_dynamic ✅, concatenate_dynamic ✅, concatenate_3d ✅ | v0.2.0 |
| conv | Conv | conv ✅, conv2 ✅ | v0.2.0 |
| convert_element_type | Cast | convert_element_type ✅ | v0.2.0 |
| device_put | Identity | device_put_array ✅, device_put_scalar ✅ | v0.4.0 |
| div | Div | div ✅ | v0.2.0 |
| dot_general | MatMul | dot_general ✅ | v0.2.0 |
| dynamic_slice | Slice | dynamic_slice_test1 ✅, dynamic_slice_2d ✅, dynamic_slice_3d ✅ | v0.1.0 |
| eq | Equal | eq ✅ | v0.2.0 |
| exp | Exp | exp ✅ | v0.2.0 |
| gather | Gather | gather ✅ | v0.2.0 |
| gt | Greater | gt ✅ | v0.2.0 |
| integer_pow | Pow | integer_pow ✅ | v0.2.0 |
| log | Log | log ✅ | v0.2.0 |
| lt | Less | lt ✅ | v0.2.0 |
| max | Max | max ✅ | v0.2.0 |
| min | Min | min_test1 ✅ | v0.1.0 |
| mul | Mul | mul_test1 ✅, mul_test2 ✅ | v0.1.0 |
| ne | Equal, Not | ne ✅ | v0.2.0 |
| neg | Neg | neg ✅ | v0.2.0 |
| reduce_max | ReduceMax | reduce_max ✅ | v0.2.0 |
| reduce_min | ReduceMin | reduce_min ✅ | v0.2.0 |
| reduce_sum | ReduceSum | reduce_sum ✅ | v0.2.0 |
| reshape | Reshape | reshape ✅ | v0.2.0 |
| slice | Slice | slice_test1 ✅ | v0.1.0 |
| sort | TopK | sort_1d ✅, sort_1d_empty ✅, sort_1d_single ✅, sort_1d_larger ✅, sort_1d_specific_values ✅ | v0.2.0 |
| sqrt | Sqrt | sqrt ✅ | v0.2.0 |
| square | Mul | square ✅ | v0.2.0 |
| squeeze | Squeeze | squeeze ✅ | v0.2.0 |
| stop_gradient | Identity | stop_gradient ✅ | v0.2.0 |
| sub | Sub | sub_test1 ✅, sub_test2 ✅ | v0.1.0 |
| tanh | Tanh | tanh ✅ | v0.2.0 |
| transpose | Transpose | transpose_basic ✅ | v0.2.0 |
| softmax | Softmax | softmax ✅, softmax_2d ✅, softmax_3d ✅ | v0.1.0 |
| avg_pool | AveragePool, Transpose | avg_pool ✅, avg_pool_same_padding ✅, avg_pool_default_padding ✅, avg_pool_stride1 ✅, avg_pool_large_window ✅, avg_pool_single_batch ✅, avg_pool_dynamic_batch_dynamic ✅, avg_pool_dynamic_batch ✅, avg_pool_stride_none ✅, avg_pool_count_include_pad_false ✅ | v0.1.0 |
| batch_norm | BatchNormalization | batch_norm ✅, batch_norm_2 ✅, batch_norm_3d ✅, batch_norm_float64 ✅, batch_norm_single_batch ✅, batch_norm_2d_train ✅, batch_norm_4d_use_bias ✅, batch_norm_4d_use_scale ✅, batch_norm_momentum ✅, batch_norm_epsilon ✅, batch_norm_float32 ✅, batch_norm_3d_train ✅, batch_norm_single_batch_train ✅ | v0.1.0 |
| conv | Conv, Transpose | conv_basic_bias_dynamic ✅, conv_basic_bias ✅, conv_basic_bias_2 ✅, conv_basic_bias_3 ✅, conv_stride2_bias ✅, conv_no_bias_dynamic ✅, conv_no_bias ✅, conv_valid_padding ✅, conv_stride1 ✅, conv_stride2 ✅, conv_different_kernel ✅, conv_float64 ✅, conv_single_batch ✅, conv_large_batch ✅ | v0.1.0 |
| conv_transpose | ConvTranspose | conv_transpose_valid_padding ✅, conv_transpose_circular_padding ✅ | v0.3.0 |
| dot_product_attention | Cast, Div, Einsum, Gather, Shape, Softmax, Sqrt | dpa_basic ✅, dpa_diff_heads_embed ✅, dpa_batch4_seq16 ✅, dpa_float64 ✅, dpa_heads1_embed4 ✅, dpa_heads8_embed8 ✅, dpa_batch1_seq2 ✅, dpa_batch8_seq4 ✅, dpa_axis1 ✅ | v0.1.0 |
| dropout | Dropout | dropout_init_params ✅, dropout_call_params ✅ | v0.1.0 |
| einsum | Add, Einsum | einsum_module_with_bias ✅, einsum_module_no_bias ✅ | v0.4.2 |
| elu | Elu | elu ✅ | v0.1.0 |
| gelu | Gelu | gelu ✅, gelu_1 ✅, gelu_2 ✅, gelu_3 ✅ | v0.1.0 |
| group_norm | GroupNormalization | group_norm ✅, group_norm_2 ✅ | v0.3.0 |
| layer_norm | LayerNormalization | layer_norm ✅, layer_norm_multiaxis ✅ | v0.1.0 |
| leaky_relu | LeakyRelu | leaky_relu ✅ | v0.1.0 |
| linear | Gemm, Reshape | linear_2d ✅, linear_dynamic ✅, linear ✅ | v0.1.0 |
| linear_general | Gemm, Reshape | linear_general_dynamic ✅, linear_general ✅, linear_general_2 ✅, linear_general_3 ✅, linear_general_4 ✅ | v0.1.0 |
| log_softmax | LogSoftmax | log_softmax ✅ | v0.1.0 |
| max_pool | MaxPool, Transpose | max_pool ✅, max_pool_same_padding ✅ | v0.1.0 |
| relu | Relu | relu ✅, relu_2 ✅ | v0.1.0 |
| rms_norm | RMSNormalization | rms_norm ✅, rms_norm_2 ✅ | v0.3.0 |
| sigmoid | Sigmoid | sigmoid ✅ | v0.1.0 |
| softmax | Softmax | softmax ✅ | v0.1.0 |
| softplus | Softplus | softplus ✅ | v0.1.0 |
| tanh | Tanh | tanh ✅ | v0.1.0 |

Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet


🎯 Examples

| Component | Description | Testcases | Since |
|---|---|---|---|
| AutoEncoder | A simple autoencoder example. | simple_autoencoder ✅ | v0.2.0 |
| CNN | A simple convolutional neural network (CNN). | simple_cnn ✅ | v0.1.0 |
| ClassificationHead | Classification head for Vision Transformer. | classification_head ✅ | v0.4.0 |
| ConcatClsToken | Concatenate CLS token to the input embedding. | concat_cls_token ✅ | v0.4.0 |
| ConvEmbedding | Convolutional token embedding for MNIST with hierarchical downsampling. | mnist_conv_embedding ✅ | v0.1.0 |
| FeedForward | MLP in Transformer. | feed_forward ✅ | v0.1.0 |
| MLP | A simple multi-layer perceptron (MLP) with BatchNorm, Dropout, and GELU activation. | simple_mlp_dynamic ✅, simple_mlp ✅, simple_mlp_with_call_params_dynamic ✅, simple_mlp_with_call_params ✅ | v0.1.0 |
| MultiHeadAttention | A multi-head attention module implemented in Flax/nnx that has no ONNX correspondent at the same granularity. | multihead_attention ✅ | v0.2.0 |
| PatchEmbedding | Cuts the image into patches and linearly embeds them. | patch_embedding ✅ | v0.1.0 |
| PositionalEmbedding | Adds positional embedding to the input embedding. | positional_embedding ✅ | v0.4.0 |
| TransformerBlock | Transformer block from "Attention Is All You Need". | transformer_block ✅ | v0.1.0 |
| TransformerStack | Stack of Transformer blocks. | transformer_stack ✅ | v0.1.0 |
| VisionTransformer | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | vit_conv_embedding ✅, vit_patch_embedding ✅ | v0.2.0 |
| onnx_functions_000 | One function on an outer layer. | 000_one_function_on_outer_layer ✅ | v0.4.0 |
| onnx_functions_001 | One function on an inner layer. | 001_one_function_inner ✅ | v0.4.0 |
| onnx_functions_002 | Two nested functions. | 002_two_nested_functions ✅ | v0.4.0 |
| onnx_functions_003 | Two nested functions. | 003_two_simple_nested_functions ✅ | v0.4.0 |
| onnx_functions_004 | Nested function plus component. | 004_nested_function_plus_component ✅ | v0.4.0 |
| onnx_functions_005 | Nested function plus more components. | 005_nested_function_plus_component ✅ | v0.4.0 |
| onnx_functions_006 | One function on an outer layer. | 006_one_function_outer ✅ | v0.4.0 |
| onnx_functions_008 | Transformer block with nested MLP block, no call parameter. | 008_transformer_block ✅ | v0.4.0 |
| onnx_functions_009 | Transformer block using decorator on class and function. | 009_transformer_block ✅ | v0.4.0 |
| onnx_functions_010 | Transformer stack. | 010_transformer_stack ✅ | v0.4.0 |
| onnx_functions_012 | Vision Transformer (ViT). | 012_vit_conv_embedding ✅ | v0.4.0 |
| onnx_functions_013 | Vision Transformer (ViT). | 013_vit_conv_embedding_with_call_params ✅, 013_vit_conv_embedding_with_internal_call_params ✅ | v0.4.0 |
| onnx_functions_014 | One function on an outer layer. | 014_one_function_with_input_param_with_default_value ✅, 014_one_function_without_input_param_with_default_value ✅ | v0.4.0 |
| onnx_functions_015 | One function on an outer layer. | 015_one_function_with_input_param_without_default_value ✅ | v0.4.0 |
| onnx_functions_016 | Nested function plus more components. | 016_internal_function_with_input_param_with_default_value ✅ | v0.4.0 |

📌 Dependencies

Versions of Major Dependencies:

| Library | Version |
|---|---|
| JAX | 0.5.3 |
| Flax | 0.10.5 |
| onnx | 1.17.0 |
| onnxruntime | 1.21.0 |

Note: For more details, check pyproject.toml.


โš ๏ธ Limitations

  • Currently not all JAX/Flax components are supported (you can easily help expand this coverage!).
  • Function references need dynamic resolution at call-time.
  • ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.

๐Ÿค How to Contribute

We warmly welcome contributions!

How you can help:

  • Add a plugin: Extend jax2onnx by writing a simple Python file in jax2onnx/plugins, covering a custom primitive or an example.
  • Bug fixes & improvements: PRs and issues are always welcome.

💾 Installation

Install from PyPI:

pip install jax2onnx  

Or get the latest development version from TestPyPI:

pip install -i https://test.pypi.org/simple/ jax2onnx

📜 License

This project is licensed under the Apache License, Version 2.0. See LICENSE for details.


🌟 Special Thanks

Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.

Special thanks to all community members involved, and a huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.


Happy converting! 🎉
