
jax2onnx 🌟


jax2onnx converts your JAX, Flax NNX, Flax Linen, and Equinox functions directly into the ONNX format.


✨ Key Features

  • simple API
    Easily convert JAX callables—including Flax NNX, Flax Linen and Equinox models—into ONNX format using to_onnx(...).

  • model structure preserved
    With @onnx_function, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.

  • dynamic input support
    Use abstract dimensions like 'B' or pass scalars as runtime inputs. Models stay flexible without retracing.

  • plugin-based extensibility
    Add support for new primitives by writing small, local plugins.

  • onnx-ir native pipeline
    Conversion, optimization, and post-processing all run on the typed onnx_ir toolkit—no protobuf juggling—and stay memory-lean before the final ONNX serialization.

  • Netron-friendly outputs
    Generated graphs carry shape/type annotations and a clean hierarchy, so tools like Netron stay easy to read.


🚀 Quickstart

Install and export your first model in minutes:

pip install jax2onnx

Convert your JAX callable to ONNX in just a few lines:

from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs): 
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) 
    def __call__(self, x): 
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Export straight to disk without keeping the proto in memory
to_onnx(
    my_callable,
    [("B", 30)],  # "B" marks a dynamic (symbolic) batch dimension
    return_mode="file",
    output_path="my_callable.onnx",
)

🔎 See it visualized: my_callable.onnx
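
Because the first axis was exported as the abstract dimension "B", the same file accepts any batch size at inference time. A quick sanity check with onnxruntime (querying the input name from the session instead of assuming it):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("my_callable.onnx")
input_name = session.get_inputs()[0].name  # ask the graph for its input name

for batch in (1, 32):  # "B" is dynamic, so any batch size works
    x = np.random.rand(batch, 30).astype(np.float32)
    (y,) = session.run(None, {input_name: x})
    print(y.shape)  # (1, 10), then (32, 10)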


🧠 ONNX Functions — Minimal Example

ONNX functions encapsulate reusable subgraphs. Just add the @onnx_function decorator to turn your callable into an ONNX function.

from flax import nnx
from jax2onnx import onnx_function, to_onnx

# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)
  def __call__(self, x):
    return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.block1 = MLPBlock(dim, rngs=rngs)
    self.block2 = MLPBlock(dim, rngs=rngs)
  def __call__(self, x):
    return self.block2(self.block1(x))

my_model = MyModel(256, rngs=nnx.Rngs(0))
to_onnx(
    my_model,
    [(100, 256)],
    return_mode="file",
    output_path="model_with_function.onnx",
)

🔎 See it visualized: model_with_function.onnx
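
To confirm that the decorated blocks really were exported as reusable ONNX functions, inspect the saved file with the onnx package (the function names are assigned by jax2onnx, so we print them rather than assert specific strings):

import onnx

model = onnx.load("model_with_function.onnx")
# Each @onnx_function submodule is stored as a FunctionProto in the model
print([f.name for f in model.functions])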



🧩 Coverage & Examples (Interactive)

Tip: JAX · Flax · Equinox — explore everything that’s supported and see it in action.

  • Support matrix: status per component
  • 🧪 Exact regression testcase for each entry
  • 🔍 One-click Netron graph to inspect nodes, shapes, attributes
  • 🧩 Examples that compose multiple components (Conv→Norm→Activation→Pool, MLP w/ LayerNorm+Dropout, reshape/transpose/concat, scan/while_loop, gather/scatter, …)

Links: Open support matrix ↗ · Browse examples ↗


📅 Roadmap and Releases

Planned

  • Broaden coverage of JAX, Flax NNX/Linen, and Equinox components.
  • Expand SotA example support for vision and language models.
  • Improve support for physics-based simulations.

Current Production Version

  • 0.11.0:
    • Initial Flax Linen support:
      • Core layers: Dense/DenseGeneral, Conv/ConvTranspose/ConvLocal, pooling, and BatchNorm/LayerNorm/GroupNorm/RMSNorm/InstanceNorm.
      • Dropout, Einsum/Embed, and spectral/weight-norm wrappers.
      • Activations: GELU plus glu/hard_*/log_*/relu6/silu (swish)/tanh/normalize/one_hot.
      • Attention stack: dot_product_attention, dot_product_attention_weights, make_attention_mask/make_causal_mask, SelfAttention, MultiHeadDotProductAttention, and MultiHeadAttention.
      • Recurrent stack: SimpleCell, GRUCell, MGUCell, LSTMCell, OptimizedLSTMCell, ConvLSTMCell, RNN, and Bidirectional.
      • Linen examples: MLP, CNN, and Sequential.
    • Modernized IR optimization pipeline: adopted the standard onnx_ir CSE pass, removed legacy helpers and getattr patterns, and simplified tests with direct graph iteration.
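
Since 0.11.0 introduces Linen support, a minimal Linen export can be sketched as follows (illustrative only: here the params are bound via apply so the callable becomes a plain JAX function; see the support matrix for what the Linen plugins cover):

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax2onnx import to_onnx

class Tiny(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.relu(nn.Dense(4)(x))

model = Tiny()
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 8)))

# Bind the params so jax2onnx sees a plain function of x
to_onnx(
    lambda x: model.apply(params, x),
    [("B", 8)],
    return_mode="file",
    output_path="tiny_linen.onnx",
)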

Past Versions

See past_versions for the full release archive.


❓ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

  • Non-dynamic function references:
    JAXPR-based conversion requires function references to be resolved dynamically at call-time.
    Solution: Wrap your function call inside a lambda to enforce dynamic resolution (a runnable sketch follows this list):

    my_dynamic_callable_function = lambda x: original_function(x)
    
  • Unsupported primitives:
    The callable may use a primitive not yet or not fully supported by jax2onnx.
    Solution: Write a plugin to handle the unsupported function (this is straightforward!).
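
As a fuller, runnable sketch of the first point (original_function here is an arbitrary stand-in for whichever function failed to convert):

import jax.numpy as jnp
from jax2onnx import to_onnx

def original_function(x):
    return jnp.tanh(x) * 2.0  # stand-in for the problematic function

# The lambda resolves original_function at trace time, not at definition time
to_onnx(
    lambda x: original_function(x),
    [("B", 4)],
    return_mode="file",
    output_path="wrapped.onnx",
)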

Looking for provenance details while debugging? Check out the new Stacktrace Metadata guide.


🤝 How to Contribute

We warmly welcome contributions!

How you can help:

  • Add a plugin: Extend jax2onnx by writing a simple Python file in jax2onnx/plugins: a primitive or an example. The Plugin Quickstart walks through the process step-by-step.
  • Bug fixes & improvements: PRs and issues are always welcome.

📌 Dependencies

Latest supported versions of major dependencies:

Library       Version
JAX           0.8.2
Flax          0.12.2
Equinox       0.13.2
onnx-ir       0.1.13
onnx          1.20.0
onnxruntime   1.23.2

For exact pins and extras, see pyproject.toml.


📜 License

This project is licensed under the Apache License, Version 2.0. See LICENSE for details.


🌟 Special Thanks

✨ Special thanks to @clementpoiret for initiating Equinox support and for Equimo, which brings modern vision models—such as DINOv3—to JAX/Equinox.

✨ Special thanks to @justinchuby for introducing onnx-ir as a scalable and more efficient way to handle ONNX model construction.

✨ Special thanks to @atveit for introducing us to gpt-oss-jax-vs-torch-numerical-comparison.

✨ Special thanks for example contributions to @burakssen, @Cadynum, @clementpoiret, and @PVirie.

✨ Special thanks for plugin contributions to @burakssen, @clementpoiret, @Clouder0, @rakadam, and @benmacadam64.

✨ Special thanks to @benmacadam64 for championing the complex-number handling initiative.

✨ Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.

✨ Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.


✨ Special thanks to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.


Happy converting! 🎉



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jax2onnx-0.11.0.tar.gz (480.6 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

jax2onnx-0.11.0-py3-none-any.whl (753.5 kB)

Uploaded Python 3

File details

Details for the file jax2onnx-0.11.0.tar.gz.

File metadata

  • Download URL: jax2onnx-0.11.0.tar.gz
  • Upload date:
  • Size: 480.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.5 Linux/6.6.87.2-microsoft-standard-WSL2

File hashes

Hashes for jax2onnx-0.11.0.tar.gz:

  • SHA256: ad155689d3de14629f7324fdf3a9bb8a11e2c0c9af7c775f5f0053f74033fa8d
  • MD5: 501be9da2980d7c613752651a409af4e
  • BLAKE2b-256: 502d39ffe12ec93ff9024818e734b447077fb165ae3da1c81279ee077c0de5f2


File details

Details for the file jax2onnx-0.11.0-py3-none-any.whl.

File metadata

  • Download URL: jax2onnx-0.11.0-py3-none-any.whl
  • Upload date:
  • Size: 753.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.5 Linux/6.6.87.2-microsoft-standard-WSL2

File hashes

Hashes for jax2onnx-0.11.0-py3-none-any.whl:

  • SHA256: cef5bdd7388bb586e8f334b97b540a4bae6e227d0af7354d048b7bcc70ed0df3
  • MD5: fe8c5b2e2df64e5107fccfb6cb3fb55d
  • BLAKE2b-256: 3615c80600c189b945c9e2aeb94b1d9cf62f9b758275a905fe1ca147c947f4db


Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page