
Export JAX to ONNX, with a focus on Flax NNX


jax2onnx 🌟

jax2onnx converts your JAX/Flax functions directly into the ONNX format.


✨ Key Features

  • Simple API
    Convert any JAX/Flax model to ONNX using to_onnx(...)

  • Model structure preserved
    With @onnx_function, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.

  • Dynamic input support
    Use abstract dimensions like 'B' or pass scalars as runtime inputs. Models stay flexible without retracing.

  • Plugin-based extensibility
    Add support for new primitives by writing small, local plugins.

  • Netron-friendly outputs
    All generated ONNX graphs include shape/type annotations and are structured for clear visualization.


🚀 Quickstart

Convert your JAX callable to ONNX in just a few lines:

import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs): 
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) 
    def __call__(self, x): 
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")

🔎 See it visualized: my_callable.onnx


🧠 ONNX Functions - Minimal Example

ONNX functions help encapsulate reusable subgraphs. Simply add the @onnx_function decorator to turn your callable into an ONNX function.

from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)
  def __call__(self, x):
    return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.block1 = MLPBlock(dim, rngs=rngs)
    self.block2 = MLPBlock(dim, rngs=rngs)
  def __call__(self, x):
    return self.block2(self.block1(x))

my_model = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(my_model, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")

🔎 See it visualized: model_with_function.onnx


📅 Roadmap and Releases

Planned Versions

  • In Progress: Expanding coverage of core JAX primitives.
  • Under Evaluation: Initial support for flax.linen models within jax2onnx.
  • Upcoming 0.7.x: Advanced ONNX function support, including function reuse, optimized internal graph generation, and improved variable naming for better readability.
  • Planned 0.6.3: Additional fixes to improve support for physics simulations.

Current Production Version

  • 0.6.2 (PyPI): Fixed bugs in nnx.conv and lax.reshape; added new primitive jnp.prod.

Past Versions

  • 0.6.1: Improved support for lax.cond and lax.select_n; added new primitives (lax.reduce_and, lax.reduce_or, lax.reduce_prod, lax.reduce_xor); and introduced new examples for jnp.select and jnp.sort.
  • 0.6.0: Introduced the enable_double_precision parameter (default: False) to support physics simulations, and enhanced handling of lax.scatter.
  • 0.5.2: Added support for additional primitives: jnp.where, jnp.arange, jnp.linspace.
  • 0.5.1: Added support for subgraph-based primitives: lax.while_loop, lax.cond, lax.fori_loop, lax.scan.
  • 0.5.0: Improved dynamic batch dimension handling by leveraging shape polymorphism for more robust and flexible model export. Added support for jnp.sign, jnp.abs, jnp.iota primitives.
  • 0.4.4: Added support for lax.cos, lax.cosh, lax.sin, lax.sinh and lax.scatter primitives.
  • 0.4.3: Fixed a bug in the validation of JAX callable outputs against their ONNX counterparts. This fix exposed previously hidden failing tests, which are now fixed.
  • 0.4.2: Cleanup and fixes to the basic ONNX function release.
  • 0.4.1 (ONNX functions): Introduced basic ONNX function support. Using ONNX functions is easy: a @onnx_function decorator turns a callable into an ONNX function, and each decorated callable creates a new ONNX function instance in the call graph.
  • 0.3.2: Relaxed the minimum Python version to 3.10.
  • 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
  • 0.2.0 (First PyPI Release): Rebased the implementation on jaxpr, improving usability and adding low-level lax components.
  • 0.1.0 (Initial Approach, Not Released to PyPI): Produced ONNX exports for some nnx components and nnx-based examples, including a VisualTransformer.

โ“ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

  • Non-dynamic function references:
    JAXPR-based conversion requires function references to be resolved dynamically at call-time.
    Solution: Wrap your function call inside a lambda to enforce dynamic resolution:

    my_dynamic_callable_function = lambda x: original_function(x)
    
  • Unsupported primitives:
    The callable may use a primitive not yet or not fully supported by jax2onnx.
    Solution: Write a plugin to handle the unsupported function (this is straightforward!).
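The lambda wrapping above works because of Python's late binding: the lambda body looks the name up each time it is called, while a direct reference binds the function object once. A plain-Python illustration (original_function is a hypothetical stand-in for a function that later gets redefined or wrapped):

```python
def original_function(x):
    return x + 1

direct_ref = original_function                # binds the function object now
dynamic_ref = lambda x: original_function(x)  # resolves the name at call time

def original_function(x):  # redefinition, e.g. by a transform or module reload
    return x * 2

assert direct_ref(3) == 4   # still the old object
assert dynamic_ref(3) == 6  # picks up the new definition
```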


🧩 Supported JAX/ONNX Components

JAX Component ONNX Components Testcases Since
core.dim_as_value Cast
Gather
Reshape
Shape
dim_as_value_dynamic โœ…
dim_as_value_dynamic_f64 โœ…
dim_as_value โœ…
dim_as_value_f64 โœ…
v0.5.0
jnp.add Add add โœ…
add_f64 โœ…
v0.1.0
jnp.arange Range arange_stop_only_concrete_input_val โœ…
arange_stop_only_concrete_input_val_f64 โœ…
arange_start_stop_concrete_input_val โœ…
arange_start_stop_concrete_input_val_f64 โœ…
arange_start_stop_step_concrete_input_val โœ…
arange_start_stop_step_concrete_input_val_f64 โœ…
arange_float_concrete_input_val โœ…
arange_float_concrete_input_val_f64 โœ…
arange_static_stop_only_int โœ…
arange_static_stop_only_int_f64 โœ…
arange_static_stop_only_float โœ…
arange_static_stop_only_float_f64 โœ…
arange_static_start_stop_int โœ…
arange_static_start_stop_int_f64 โœ…
arange_static_start_stop_step_int โœ…
arange_static_start_stop_step_int_f64 โœ…
arange_static_empty_result_pos_step โœ…
arange_static_empty_result_pos_step_f64 โœ…
arange_static_empty_result_neg_step โœ…
arange_static_empty_result_neg_step_f64 โœ…
arange_static_negative_step โœ…
arange_static_negative_step_f64 โœ…
arange_static_float_step_explicit_dtype โœ…
arange_static_float_step_explicit_dtype_f64 โœ…
arange_static_float_step_inferred_dtype โœ…
arange_static_float_step_inferred_dtype_f64 โœ…
arange_static_stop_zero โœ…
arange_static_stop_zero_f64 โœ…
arange_static_start_equals_stop โœ…
arange_static_start_equals_stop_f64 โœ…
arange_static_large_numbers_int โœ…
arange_static_large_numbers_int_f64 โœ…
v0.5.2
jnp.concatenate Concat concatenate โœ…
concatenate_f64 โœ…
concatenate_abstract_middle_dim_dynamic โœ…
concatenate_abstract_middle_dim_dynamic_f64 โœ…
concatenate_abstract_middle_dim โœ…
concatenate_abstract_middle_dim_f64 โœ…
concatenate_tile_and_symbolic_dynamic โœ…
concatenate_tile_and_symbolic_dynamic_f64 โœ…
concatenate_tile_and_symbolic โœ…
concatenate_tile_and_symbolic_f64 โœ…
v0.1.0
jnp.einsum Einsum einsum_vector_dot โœ…
einsum_vector_dot_f64 โœ…
einsum_matrix_vector โœ…
einsum_matrix_vector_f64 โœ…
einsum_matrix_matrix_dynamic โœ…
einsum_matrix_matrix_dynamic_f64 โœ…
einsum_matrix_matrix โœ…
einsum_matrix_matrix_f64 โœ…
einsum_transpose โœ…
einsum_transpose_f64 โœ…
einsum_batch_transpose_dynamic โœ…
einsum_batch_transpose_dynamic_f64 โœ…
einsum_batch_transpose โœ…
einsum_batch_transpose_f64 โœ…
einsum_diag โœ…
einsum_diag_f64 โœ…
einsum_sum_reduce โœ…
einsum_sum_reduce_f64 โœ…
einsum_multi_operand โœ…
einsum_multi_operand_f64 โœ…
einsum_attention_logits_orig_dynamic โœ…
einsum_attention_logits_orig_dynamic_f64 โœ…
einsum_attention_logits_orig โœ…
einsum_attention_logits_orig_f64 โœ…
einsum_attention_output_orig_dynamic โœ…
einsum_attention_output_orig_dynamic_f64 โœ…
einsum_attention_output_orig โœ…
einsum_attention_output_orig_f64 โœ…
einsum_attention_logits_batched_dynamic โœ…
einsum_attention_logits_batched_dynamic_f64 โœ…
einsum_attention_logits_batched โœ…
einsum_attention_logits_batched_f64 โœ…
einsum_attention_output_batched_dynamic โœ…
einsum_attention_output_batched_dynamic_f64 โœ…
einsum_attention_output_batched โœ…
einsum_attention_output_batched_f64 โœ…
einsum_ellipsis_rank_mismatch โœ…
einsum_ellipsis_rank_mismatch_f64 โœ…
v0.1.0
jnp.linspace Constant linspace_static_basic โœ…
linspace_static_basic_f64 โœ…
linspace_static_endpoint_false โœ…
linspace_static_endpoint_false_f64 โœ…
linspace_static_num_1 โœ…
linspace_static_num_1_f64 โœ…
linspace_static_num_0 โœ…
linspace_static_num_0_f64 โœ…
linspace_static_int_inputs_default_dtype โœ…
linspace_static_int_inputs_default_dtype_f64 โœ…
v0.5.2
jnp.matmul MatMul matmul_2d โœ…
matmul_2d_f64 โœ…
matmul_1d_2d โœ…
matmul_1d_2d_f64 โœ…
matmul_2d_1d โœ…
matmul_2d_1d_f64 โœ…
matmul_dynamic_dynamic โœ…
matmul_dynamic_dynamic_f64 โœ…
matmul_dynamic โœ…
matmul_dynamic_f64 โœ…
matmul_dynamic_a_dynamic โœ…
matmul_dynamic_a_dynamic_f64 โœ…
matmul_dynamic_a โœ…
matmul_dynamic_a_f64 โœ…
matmul_1d โœ…
matmul_1d_f64 โœ…
matmul_3d โœ…
matmul_3d_f64 โœ…
v0.1.0
jnp.prod ReduceProd basic_prod โœ…
basic_prod_f64 โœ…
prod_with_axis โœ…
prod_with_axis_f64 โœ…
prod_with_keepdims โœ…
prod_with_keepdims_f64 โœ…
v0.6.2
jnp.reshape Reshape reshape_1 โœ…
reshape_1_f64 โœ…
reshape_2 โœ…
reshape_2_f64 โœ…
reshape_3 โœ…
reshape_3_f64 โœ…
reshape_4_dynamic โœ…
reshape_4_dynamic_f64 โœ…
reshape_4 โœ…
reshape_4_f64 โœ…
reshape_to_scalar โœ…
reshape_to_scalar_f64 โœ…
reshape_from_scalar โœ…
reshape_from_scalar_f64 โœ…
reshape_cnn_dynamic โœ…
reshape_cnn_dynamic_f64 โœ…
reshape_cnn โœ…
reshape_cnn_f64 โœ…
reshape_valid_flatten_trailing โœ…
reshape_valid_flatten_trailing_f64 โœ…
reshape_with_target_shape_from_symbolic_dim_computation โœ…
reshape_with_target_shape_from_symbolic_dim_computation_f64 โœ…
v0.1.0
jnp.shape Shape shape_basic โœ…
shape_basic_f64 โœ…
shape_dynamic_dynamic โœ…
shape_dynamic_dynamic_f64 โœ…
shape_dynamic โœ…
shape_dynamic_f64 โœ…
v0.4.0
jnp.sort TopK sort_1d โœ…
sort_1d_f64 โœ…
sort_2d_axis0_dynamic โœ…
sort_2d_axis0_dynamic_f64 โœ…
sort_2d_axis0 โœ…
sort_2d_axis0_f64 โœ…
v0.5.2
jnp.squeeze Squeeze squeeze_single_dim โœ…
squeeze_single_dim_f64 โœ…
squeeze_multiple_dims โœ…
squeeze_multiple_dims_f64 โœ…
squeeze_vit_output โœ…
squeeze_vit_output_f64 โœ…
squeeze_dynamic_batch_dynamic โœ…
squeeze_dynamic_batch_dynamic_f64 โœ…
squeeze_dynamic_batch โœ…
squeeze_dynamic_batch_f64 โœ…
squeeze_all_dims โœ…
squeeze_all_dims_f64 โœ…
squeeze_negative_axis โœ…
squeeze_negative_axis_f64 โœ…
squeeze_negative_axis_tuple โœ…
squeeze_negative_axis_tuple_f64 โœ…
squeeze_dynamic_and_negative_axis_dynamic โœ…
squeeze_dynamic_and_negative_axis_dynamic_f64 โœ…
squeeze_dynamic_and_negative_axis โœ…
squeeze_dynamic_and_negative_axis_f64 โœ…
v0.1.0
jnp.tile Tile tile_repeats โœ…
tile_repeats_f64 โœ…
tile_a โœ…
tile_a_f64 โœ…
tile_b โœ…
tile_b_f64 โœ…
tile_c โœ…
tile_c_f64 โœ…
tile_d โœ…
tile_d_f64 โœ…
tile_dynamic_input_static โœ…
tile_dynamic_input_static_f64 โœ…
tile_dynamic_input_dynamic โœ…
tile_dynamic_input_dynamic_f64 โœ…
tile_dynamic_input โœ…
tile_dynamic_input_f64 โœ…
tile_pad โœ…
tile_pad_f64 โœ…
tile_with_symbolic_repeats_static โœ…
tile_with_symbolic_repeats_static_f64 โœ…
tile_with_symbolic_repeats_dynamic โœ…
tile_with_symbolic_repeats_dynamic_f64 โœ…
tile_with_symbolic_repeats โœ…
tile_with_symbolic_repeats_f64 โœ…
tile_param_symbolic_dynamic โœ…
tile_param_symbolic_dynamic_f64 โœ…
tile_param_symbolic โœ…
tile_param_symbolic_f64 โœ…
v0.1.0
jnp.transpose Transpose transpose_basic โœ…
transpose_basic_f64 โœ…
transpose_reverse โœ…
transpose_reverse_f64 โœ…
transpose_4d_dynamic โœ…
transpose_4d_dynamic_f64 โœ…
transpose_4d โœ…
transpose_4d_f64 โœ…
transpose_square_matrix โœ…
transpose_square_matrix_f64 โœ…
transpose_high_dim โœ…
transpose_high_dim_f64 โœ…
transpose_no_axes โœ…
transpose_no_axes_f64 โœ…
transpose_3d_dynamic โœ…
transpose_3d_dynamic_f64 โœ…
transpose_3d โœ…
transpose_3d_f64 โœ…
v0.1.0
jnp.where Where where_simple โœ…
where_simple_f64 โœ…
where_broadcast โœ…
where_broadcast_f64 โœ…
where_multidim_condition_scalar_branches_broadcast โœ…
where_multidim_condition_scalar_branches_broadcast_f64 โœ…
where_A โœ…
where_A_f64 โœ…
where_B โœ…
where_B_f64 โœ…
where_jax_int_literals_broadcast_f64_mode โœ…
v0.5.2
lax.abs Abs abs โœ…
abs_f64 โœ…
v0.5.0
lax.add Add add โœ…
add_f64 โœ…
v0.2.0
lax.argmax ArgMax argmax_float_axis0 โœ…
argmax_float_axis0_f64 โœ…
argmax_float_axis1 โœ…
argmax_float_axis1_f64 โœ…
argmax_boolean_input_axis0_specific_values โœ…
argmax_boolean_input_axis0_specific_values_f64 โœ…
argmax_boolean_input_axis1_specific_values โœ…
argmax_boolean_input_axis1_specific_values_f64 โœ…
argmax_boolean_random_input_axis0 โœ…
argmax_boolean_random_input_axis0_f64 โœ…
v0.2.0
lax.argmin ArgMin argmin_test1 โœ…
argmin_test1_f64 โœ…
argmin_test2 โœ…
argmin_test2_f64 โœ…
v0.2.0
lax.broadcast_in_dim Expand
Identity
Reshape
broadcast_in_dim โœ…
broadcast_in_dim_f64 โœ…
broadcast_in_dim_2d_to_3d โœ…
broadcast_in_dim_2d_to_3d_f64 โœ…
broadcast_in_dim_scalar โœ…
broadcast_in_dim_scalar_f64 โœ…
broadcast_in_dim_batch_dynamic โœ…
broadcast_in_dim_batch_dynamic_f64 โœ…
broadcast_in_dim_batch โœ…
broadcast_in_dim_batch_f64 โœ…
v0.2.0
lax.concatenate Concat concatenate โœ…
concatenate_f64 โœ…
concatenate_axis1_dynamic โœ…
concatenate_axis1_dynamic_f64 โœ…
concatenate_axis1 โœ…
concatenate_axis1_f64 โœ…
concatenate_axis0 โœ…
concatenate_axis0_f64 โœ…
concatenate_3d โœ…
concatenate_3d_f64 โœ…
v0.2.0
lax.cond If cond_scalar โœ…
cond_scalar_f64 โœ…
cond_multiple_operands_in_tuple โœ…
cond_multiple_operands_in_tuple_f64 โœ…
cond_my_new_complex_scenario โœ…
cond_my_new_complex_scenario_f64 โœ…
cond_nested_conditional โœ…
cond_nested_conditional_f64 โœ…
cond_variables โœ…
cond_variables_f64 โœ…
cond_internal_constant_f64 โœ…
v0.5.1
lax.conv Conv conv โœ…
conv_f64 โœ…
conv2 โœ…
conv2_f64 โœ…
v0.2.0
lax.convert_element_type Cast convert_element_type โœ…
convert_element_type_f64 โœ…
v0.2.0
lax.copy (handles the primitive lax.copy_p; the public jax.lax.copy API has been removed) Identity copy_float32_array ✅
copy_int64_scalar ✅
<your_current_version>
lax.cos Cos cos โœ…
cos_f64 โœ…
v0.4.4
lax.cosh Cosh cosh โœ…
cosh_f64 โœ…
v0.4.4
lax.device_put Identity device_put_array โœ…
device_put_array_f64 โœ…
device_put_scalar โœ…
device_put_scalar_f64 โœ…
v0.4.0
lax.div Div div โœ…
div_f64 โœ…
v0.2.0
lax.dot_general MatMul dot_general โœ…
dot_general_f64 โœ…
v0.2.0
lax.dynamic_slice Slice dynamic_slice_test1 โœ…
dynamic_slice_test1_f64 โœ…
dynamic_slice_2d โœ…
dynamic_slice_2d_f64 โœ…
dynamic_slice_3d โœ…
dynamic_slice_3d_f64 โœ…
dynamic_slice_vit_like โœ…
dynamic_slice_vit_like_f64 โœ…
dynamic_slice_vit_like_dynamic_dynamic โœ…
dynamic_slice_vit_like_dynamic_dynamic_f64 โœ…
dynamic_slice_vit_like_dynamic โœ…
dynamic_slice_vit_like_dynamic_f64 โœ…
v0.1.0
lax.eq Equal eq โœ…
eq_f64 โœ…
v0.2.0
lax.exp Exp exp โœ…
exp_f64 โœ…
v0.2.0
lax.fori_loop Loop fori_loop_counter โœ…
fori_loop_counter_f64 โœ…
fori_loop_zero โœ…
fori_loop_zero_f64 โœ…
fori_loop_vector โœ…
fori_loop_vector_f64 โœ…
fori_loop_example โœ…
fori_loop_example_f64 โœ…
v0.5.1
lax.gather GatherND gather_static โœ…
gather_static_f64 โœ…
gather_dynamic_batch_simple_index_dynamic โœ…
gather_dynamic_batch_simple_index_dynamic_f64 โœ…
gather_dynamic_batch_simple_index โœ…
gather_dynamic_batch_simple_index_f64 โœ…
v0.2.0
lax.gt Greater gt โœ…
gt_f64 โœ…
v0.2.0
lax.integer_pow Pow integer_pow โœ…
integer_pow_f64 โœ…
v0.2.0
lax.iota Range iota_int32 โœ…
iota_int32_f64 โœ…
iota_float32 โœ…
iota_float32_f64 โœ…
broadcasted_iota โœ…
broadcasted_iota_f64 โœ…
v0.5.0
lax.log Log log โœ…
log_f64 โœ…
v0.2.0
lax.lt Less lt โœ…
lt_f64 โœ…
v0.2.0
lax.max Max max โœ…
max_f64 โœ…
v0.2.0
lax.min Min min_test1 โœ…
min_test1_f64 โœ…
v0.1.0
lax.mul Mul mul_test1 โœ…
mul_test1_f64 โœ…
mul_test2 โœ…
mul_test2_f64 โœ…
v0.1.0
lax.ne Equal
Not
ne โœ…
ne_f64 โœ…
v0.2.0
lax.neg Neg neg โœ…
neg_f64 โœ…
v0.2.0
lax.reduce_and Cast
ReduceMin
reduce_and_all_true โœ…
reduce_and_all_true_f64 โœ…
reduce_and_one_false โœ…
reduce_and_one_false_f64 โœ…
reduce_and_keepdims โœ…
reduce_and_keepdims_f64 โœ…
v0.6.1
lax.reduce_max ReduceMax reduce_max โœ…
reduce_max_f64 โœ…
reduce_max_allaxes โœ…
reduce_max_allaxes_f64 โœ…
reduce_max_keepdims โœ…
reduce_max_keepdims_f64 โœ…
v0.2.0
lax.reduce_min ReduceMin reduce_min โœ…
reduce_min_f64 โœ…
reduce_min_allaxes โœ…
reduce_min_allaxes_f64 โœ…
reduce_min_keepdims โœ…
reduce_min_keepdims_f64 โœ…
v0.2.0
lax.reduce_or Cast
ReduceMax
reduce_or_all_false โœ…
reduce_or_all_false_f64 โœ…
reduce_or_one_true โœ…
reduce_or_one_true_f64 โœ…
reduce_or_keepdims โœ…
reduce_or_keepdims_f64 โœ…
v0.6.1
lax.reduce_prod ReduceProd reduce_prod โœ…
reduce_prod_f64 โœ…
reduce_prod_allaxes โœ…
reduce_prod_allaxes_f64 โœ…
reduce_prod_keepdims โœ…
reduce_prod_keepdims_f64 โœ…
reduce_prod_dtype_f64 โœ…
reduce_prod_dtype โœ…
v0.6.1
lax.reduce_sum ReduceSum reduce_sum โœ…
reduce_sum_f64 โœ…
reduce_sum_allaxes โœ…
reduce_sum_allaxes_f64 โœ…
reduce_sum_keepdims โœ…
reduce_sum_keepdims_f64 โœ…
reduce_sum_dtype_f64 โœ…
reduce_sum_dtype โœ…
v0.2.0
lax.reduce_xor Cast
Mod
ReduceSum
reduce_xor_all_false โœ…
reduce_xor_all_false_f64 โœ…
reduce_xor_one_true โœ…
reduce_xor_one_true_f64 โœ…
reduce_xor_two_true โœ…
reduce_xor_two_true_f64 โœ…
reduce_xor_keepdims โœ…
reduce_xor_keepdims_f64 โœ…
v0.6.1
lax.reshape Reshape reshape โœ…
reshape_f64 โœ…
reshape_valid_squeeze_middle_dim_from_problematic_source โœ…
reshape_valid_squeeze_middle_dim_from_problematic_source_f64 โœ…
reshape_valid_flatten_trailing โœ…
reshape_valid_flatten_trailing_f64 โœ…
reshape_with_target_shape_from_symbolic_dim_computation โœ…
reshape_with_target_shape_from_symbolic_dim_computation_f64 โœ…
reshape_with_inferred_dimension_from_input_dynamic_dynamic โœ…
reshape_with_inferred_dimension_from_input_dynamic_dynamic_f64 โœ…
reshape_with_inferred_dimension_from_input_dynamic โœ…
reshape_with_inferred_dimension_from_input_dynamic_f64 โœ…
reshape_with_inferred_dimension_from_input โœ…
reshape_with_inferred_dimension_from_input_f64 โœ…
v0.2.0
lax.scan Scan scan_cumsum โœ…
scan_cumsum_f64 โœ…
scan_carry_only โœ…
scan_carry_only_f64 โœ…
scan_multiple_sequences โœ…
scan_multiple_sequences_f64 โœ…
scan_multiple_carry โœ…
scan_multiple_carry_f64 โœ…
scan_matrix_carry_multidim_xs โœ…
scan_matrix_carry_multidim_xs_f64 โœ…
v0.5.1
lax.scatter ScatterND scatter_set_axis0 โœ…
scatter_set_axis0_f64 โœ…
scatter_set_middle โœ…
scatter_set_middle_f64 โœ…
scatter_correct_axis_determination โœ…
scatter_correct_axis_determination_f64 โœ…
scatter_updates_slice_needed_axis0 โœ…
scatter_updates_slice_needed_axis0_f64 โœ…
scatter_from_user_warning_shapes_valid_jax โœ…
scatter_from_user_warning_shapes_valid_jax_f64 โœ…
scatter_user_error_scenario_precise โœ…
scatter_user_error_scenario_precise_f64 โœ…
v0.4.4
lax.scatter_add ScatterND scatter_add_simple_1d โœ…
scatter_add_simple_1d_f64 โœ…
scatter_add_window_2d_operand_1d_indices โœ…
scatter_add_window_2d_operand_1d_indices_f64 โœ…
scatter_add_batch_updates_1d_operand โœ…
scatter_add_batch_updates_1d_operand_f64 โœ…
scatter_add_mismatched_window_dims_from_user_report โœ…
scatter_add_mismatched_window_dims_from_user_report2 โœ…
scatter_add_mismatched_window_dims_from_user_report3 โœ…
scatter_add_fluids_pattern_updates_5_4_1_1 โœ…
v0.5.3
lax.select_n Where select_n_bool_predicate_two_cases_float โœ…
select_n_bool_predicate_two_cases_float_f64 โœ…
select_n_bool_predicate_two_cases_int โœ…
select_n_bool_predicate_two_cases_int_f64 โœ…
select_n_bool_predicate_scalar_broadcast โœ…
select_n_bool_predicate_scalar_broadcast_f64 โœ…
select_n_int_indices_three_cases โœ…
select_n_int_indices_three_cases_f64 โœ…
select_n_int_indices_four_cases โœ…
select_n_int_indices_four_cases_f64 โœ…
v0.2.0
lax.sign Sign sign โœ…
sign_f64 โœ…
v0.5.0
lax.sin Sin sin โœ…
sin_f64 โœ…
v0.4.4
lax.sinh Sinh sinh โœ…
sinh_f64 โœ…
v0.4.4
lax.slice Slice slice_test1 โœ…
slice_test1_f64 โœ…
slice_3d_none_strides โœ…
slice_3d_none_strides_f64 โœ…
v0.1.0
lax.sort TopK sort_1d โœ…
sort_1d_f64 โœ…
sort_2d โœ…
sort_2d_f64 โœ…
v0.2.0
lax.sqrt Sqrt sqrt โœ…
sqrt_f64 โœ…
v0.2.0
lax.square Mul square โœ…
square_f64 โœ…
v0.2.0
lax.squeeze Squeeze lax_squeeze_specific_axis_0 โœ…
lax_squeeze_specific_axis_0_f64 โœ…
lax_squeeze_multiple_axes โœ…
lax_squeeze_multiple_axes_f64 โœ…
lax_squeeze_no_op_empty_dims โœ…
lax_squeeze_no_op_empty_dims_f64 โœ…
lax_squeeze_problem_case_input_squeeze_only_axis_0 โœ…
lax_squeeze_problem_case_input_squeeze_only_axis_0_f64 โœ…
lax_squeeze_problem_case_input_squeeze_axes_0_2 โœ…
lax_squeeze_problem_case_input_squeeze_axes_0_2_f64 โœ…
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly โœ…
lax_squeeze_problem_case_input_squeeze_all_dims_explicitly_f64 โœ…
v0.2.0
lax.stop_gradient Identity stop_gradient โœ…
stop_gradient_f64 โœ…
v0.2.0
lax.sub Sub sub_test1 โœ…
sub_test1_f64 โœ…
sub_test2 โœ…
sub_test2_f64 โœ…
v0.1.0
lax.tanh Tanh tanh โœ…
tanh_f64 โœ…
v0.2.0
lax.transpose Transpose transpose_basic โœ…
transpose_basic_f64 โœ…
v0.2.0
lax.while_loop Loop while_loop_counter โœ…
while_loop_counter_f64 โœ…
while_loop_vector โœ…
while_loop_vector_f64 โœ…
v0.5.1
nn.dot_product_attention Cast
Div
Einsum
Gather
Shape
Softmax
Sqrt
dpa_basic โœ…
dpa_basic_f64 โœ…
dpa_diff_heads_embed โœ…
dpa_diff_heads_embed_f64 โœ…
dpa_batch4_seq16 โœ…
dpa_batch4_seq16_f64 โœ…
dpa_float64 โœ…
dpa_float64_f64 โœ…
dpa_heads1_embed4 โœ…
dpa_heads1_embed4_f64 โœ…
dpa_heads8_embed8 โœ…
dpa_heads8_embed8_f64 โœ…
dpa_batch1_seq2 โœ…
dpa_batch1_seq2_f64 โœ…
dpa_batch8_seq4 โœ…
dpa_batch8_seq4_f64 โœ…
dpa_axis1 โœ…
dpa_axis1_f64 โœ…
v0.1.0
nn.softmax Softmax softmax โœ…
softmax_f64 โœ…
softmax_2d โœ…
softmax_2d_f64 โœ…
softmax_3d โœ…
softmax_3d_f64 โœ…
v0.1.0
nnx.avg_pool AveragePool
Transpose
avg_pool_dynamic โœ…
avg_pool_dynamic_f64 โœ…
avg_pool โœ…
avg_pool_f64 โœ…
avg_pool_same_padding_dynamic โœ…
avg_pool_same_padding_dynamic_f64 โœ…
avg_pool_same_padding โœ…
avg_pool_same_padding_f64 โœ…
avg_pool_default_padding_dynamic โœ…
avg_pool_default_padding_dynamic_f64 โœ…
avg_pool_default_padding โœ…
avg_pool_default_padding_f64 โœ…
avg_pool_stride1_dynamic โœ…
avg_pool_stride1_dynamic_f64 โœ…
avg_pool_stride1 โœ…
avg_pool_stride1_f64 โœ…
avg_pool_win3x3_stride2_dynamic โœ…
avg_pool_win3x3_stride2_dynamic_f64 โœ…
avg_pool_win3x3_stride2 โœ…
avg_pool_win3x3_stride2_f64 โœ…
avg_pool_stride_none_dynamic โœ…
avg_pool_stride_none_dynamic_f64 โœ…
avg_pool_stride_none โœ…
avg_pool_stride_none_f64 โœ…
avg_pool_count_include_pad_false_dynamic โœ…
avg_pool_count_include_pad_false_dynamic_f64 โœ…
avg_pool_count_include_pad_false โœ…
avg_pool_count_include_pad_false_f64 โœ…
v0.1.0
nnx.batch_norm BatchNormalization batch_norm_simple_dynamic โœ…
batch_norm_simple_dynamic_f64 โœ…
batch_norm_simple โœ…
batch_norm_simple_f64 โœ…
batch_norm_2d_dynamic โœ…
batch_norm_2d_dynamic_f64 โœ…
batch_norm_2d โœ…
batch_norm_2d_f64 โœ…
batch_norm_2d_use_bias_false_dynamic โœ…
batch_norm_2d_use_bias_false_dynamic_f64 โœ…
batch_norm_2d_use_bias_false โœ…
batch_norm_2d_use_bias_false_f64 โœ…
batch_norm_2d_use_scale_false_dynamic โœ…
batch_norm_2d_use_scale_false_dynamic_f64 โœ…
batch_norm_2d_use_scale_false โœ…
batch_norm_2d_use_scale_false_f64 โœ…
batch_norm_4d_dynamic โœ…
batch_norm_4d_dynamic_f64 โœ…
batch_norm_4d โœ…
batch_norm_4d_f64 โœ…
batch_norm_4d_use_bias_false_dynamic โœ…
batch_norm_4d_use_bias_false_dynamic_f64 โœ…
batch_norm_4d_use_bias_false โœ…
batch_norm_4d_use_bias_false_f64 โœ…
batch_norm_4d_use_scale_false_dynamic โœ…
batch_norm_4d_use_scale_false_dynamic_f64 โœ…
batch_norm_4d_use_scale_false โœ…
batch_norm_4d_use_scale_false_f64 โœ…
batch_norm_minimal โœ…
batch_norm_minimal_f64 โœ…
v0.1.0
nnx.conv Conv
Transpose
conv_basic_bias_dynamic โœ…
conv_basic_bias_dynamic_f64 โœ…
conv_basic_bias โœ…
conv_basic_bias_f64 โœ…
conv_basic_bias_2 โœ…
conv_basic_bias_2_f64 โœ…
conv_basic_bias_3 โœ…
conv_basic_bias_3_f64 โœ…
conv_stride2_bias โœ…
conv_stride2_bias_f64 โœ…
conv_no_bias_dynamic โœ…
conv_no_bias_dynamic_f64 โœ…
conv_no_bias โœ…
conv_no_bias_f64 โœ…
conv_valid_padding โœ…
conv_valid_padding_f64 โœ…
conv_stride1 โœ…
conv_stride1_f64 โœ…
conv_stride2 โœ…
conv_stride2_f64 โœ…
conv_different_kernel โœ…
conv_different_kernel_f64 โœ…
conv_float64 โœ…
conv_float64_f64 โœ…
conv_single_batch โœ…
conv_single_batch_f64 โœ…
conv_large_batch โœ…
conv_large_batch_f64 โœ…
v0.1.0
nnx.dot_product_attention Cast
Div
Einsum
Gather
Shape
Softmax
Sqrt
dpa_basic โœ…
dpa_basic_f64 โœ…
dpa_diff_heads_embed โœ…
dpa_diff_heads_embed_f64 โœ…
dpa_batch4_seq16 โœ…
dpa_batch4_seq16_f64 โœ…
dpa_float64 โœ…
dpa_float64_f64 โœ…
dpa_heads1_embed4 โœ…
dpa_heads1_embed4_f64 โœ…
dpa_heads8_embed8 โœ…
dpa_heads8_embed8_f64 โœ…
dpa_batch1_seq2 โœ…
dpa_batch1_seq2_f64 โœ…
dpa_batch8_seq4 โœ…
dpa_batch8_seq4_f64 โœ…
dpa_axis1 โœ…
dpa_axis1_f64 โœ…
v0.1.0
nnx.dropout Dropout dropout_init_params_dynamic โœ…
dropout_init_params_dynamic_f64 โœ…
dropout_init_params โœ…
dropout_init_params_f64 โœ…
dropout_call_params_dynamic โœ…
dropout_call_params_dynamic_f64 โœ…
dropout_call_params โœ…
dropout_call_params_f64 โœ…
v0.1.0
nnx.einsum Add
Einsum
einsum_module_with_bias โœ…
einsum_module_with_bias_f64 โœ…
einsum_module_no_bias โœ…
einsum_module_no_bias_f64 โœ…
v0.4.2
nnx.elu Elu elu โœ…
elu_f64 โœ…
v0.1.0
nnx.gelu Gelu gelu โœ…
gelu_f64 โœ…
gelu_1 โœ…
gelu_1_f64 โœ…
gelu_2 โœ…
gelu_2_f64 โœ…
gelu_3_dynamic โœ…
gelu_3_dynamic_f64 โœ…
gelu_3 โœ…
gelu_3_f64 โœ…
v0.1.0
nnx.group_norm GroupNormalization group_norm โœ…
group_norm_f64 โœ…
group_norm_2 โœ…
group_norm_2_f64 โœ…
v0.3.0
nnx.layer_norm LayerNormalization layer_norm_dynamic โœ…
layer_norm_dynamic_f64 โœ…
layer_norm โœ…
layer_norm_f64 โœ…
layer_norm_multiaxis_dynamic โœ…
layer_norm_multiaxis_dynamic_f64 โœ…
layer_norm_multiaxis โœ…
layer_norm_multiaxis_f64 โœ…
v0.1.0
nnx.leaky_relu LeakyRelu leaky_relu โœ…
leaky_relu_f64 โœ…
v0.1.0
nnx.linear Gemm
Reshape
linear_symbolic_batch_dynamic โœ…
linear_symbolic_batch_dynamic_f64 โœ…
linear_symbolic_batch โœ…
linear_symbolic_batch_f64 โœ…
linear_high_rank โœ…
linear_high_rank_f64 โœ…
v0.1.0
nnx.linear_general Gemm
Reshape
linear_general_dynamic โœ…
linear_general_dynamic_f64 โœ…
linear_general โœ…
linear_general_f64 โœ…
linear_general_2 โœ…
linear_general_2_f64 โœ…
linear_general_3 โœ…
linear_general_3_f64 โœ…
linear_general_4 โœ…
linear_general_4_f64 โœ…
linear_general_abstract_eval_axes โœ…
linear_general_abstract_eval_axes_f64 โœ…
linear_general_abstract_eval_axes_pair โœ…
linear_general_abstract_eval_axes_pair_f64 โœ…
v0.1.0
nnx.log_softmax LogSoftmax log_softmax โœ…
log_softmax_f64 โœ…
v0.1.0
nnx.max_pool MaxPool
Transpose
max_pool โœ…
max_pool_f64 โœ…
max_pool_same_padding โœ…
max_pool_same_padding_f64 โœ…
v0.1.0
nnx.relu Relu relu_1d โœ…
relu_1d_f64 โœ…
relu_4d_dynamic โœ…
relu_4d_dynamic_f64 โœ…
relu_4d โœ…
relu_4d_f64 โœ…
v0.1.0
nnx.rms_norm RMSNormalization rms_norm โœ…
rms_norm_f64 โœ…
rms_norm_2 โœ…
rms_norm_2_f64 โœ…
v0.3.0
nnx.sigmoid Sigmoid sigmoid โœ…
sigmoid_f64 โœ…
v0.1.0
nnx.softmax Softmax softmax_dynamic โœ…
softmax_dynamic_f64 โœ…
softmax โœ…
softmax_f64 โœ…
v0.1.0
nnx.softplus Softplus softplus โœ…
softplus_f64 โœ…
v0.1.0
nnx.tanh Tanh tanh โœ…
tanh_f64 โœ…
v0.1.0

Legend:
✅ = Passed
❌ = Failed
➖ = No testcase yet


🎯 Examples

Component Description Testcases Since
select_test Demonstrates jnp.select with a dynamic condition based on an input array. select_test_all_options ➖
select_test_scalar_select_option_0 ➖
select_test_scalar_select_option_1 ➖
select_test_scalar_select_option_2 ➖
select_test_default_case ➖
v0.6.1
sort_test Demonstrates jnp.sort on slices of an input array. sort_test_basic ➖ v0.6.1
AutoEncoder A simple autoencoder example. simple_autoencoder โœ…
simple_autoencoder_f64 โœ…
v0.2.0
CNN A simple convolutional neural network (CNN). simple_cnn_explicit_dimensions โœ…
simple_cnn_explicit_dimensions_f64 โœ…
simple_cnn_dynamic โœ…
simple_cnn_dynamic_f64 โœ…
simple_cnn โœ…
simple_cnn_f64 โœ…
v0.1.0
ClassificationHead Classification head for Vision Transformer classification_head_dynamic โœ…
classification_head_dynamic_f64 โœ…
classification_head โœ…
classification_head_f64 โœ…
v0.4.0
ClassificationHeadFlatten Classification head for Vision Transformer classification_head_flat_dynamic โœ…
classification_head_flat_dynamic_f64 โœ…
classification_head_flat โœ…
classification_head_flat_f64 โœ…
v0.4.0
ConcatClsToken Concatenate CLS token to the input embedding concat_cls_token_dynamic โœ…
concat_cls_token_dynamic_f64 โœ…
concat_cls_token โœ…
concat_cls_token_f64 โœ…
v0.4.0
ConcatClsTokenFlatten Concatenate CLS token to the input embedding concat_cls_token_flat_dynamic โœ…
concat_cls_token_flat_dynamic_f64 โœ…
concat_cls_token_flat โœ…
concat_cls_token_flat_f64 โœ…
v0.4.0
ConvEmbedding Convolutional Token Embedding for MNIST with hierarchical downsampling. mnist_conv_embedding_dynamic โœ…
mnist_conv_embedding_dynamic_f64 โœ…
mnist_conv_embedding โœ…
mnist_conv_embedding_f64 โœ…
v0.1.0
ConvEmbeddingFlatten Convolutional Token Embedding for MNIST with hierarchical downsampling. mnist_conv_embedding_flat_dynamic โœ…
mnist_conv_embedding_flat_dynamic_f64 โœ…
mnist_conv_embedding_flat โœ…
mnist_conv_embedding_flat_f64 โœ…
v0.1.0
FeedForward MLP in Transformer feed_forward_dynamic โœ…
feed_forward_dynamic_f64 โœ…
feed_forward โœ…
feed_forward_f64 โœ…
v0.1.0
FeedForwardFlatten MLP in Transformer feed_forward_flat_dynamic โœ…
feed_forward_flat_dynamic_f64 โœ…
feed_forward_flat โœ…
feed_forward_flat_f64 โœ…
v0.1.0
ForiLoop fori_loop example fori_loop_counter โœ…
fori_loop_counter_f64 โœ…
v0.5.1
GetToken Get the CLS token from the input embedding get_token_dynamic โœ…
get_token_dynamic_f64 โœ…
get_token โœ…
get_token_f64 โœ…
v0.4.0
GetTokenFlatten Get the CLS token from the input embedding get_token_flat_dynamic โœ…
get_token_flat_dynamic_f64 โœ…
get_token_flat โœ…
get_token_flat_f64 โœ…
v0.4.0
MLP A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. simple_mlp_dynamic โœ…
simple_mlp_dynamic_f64 โœ…
simple_mlp โœ…
simple_mlp_f64 โœ…
simple_mlp_with_call_params_dynamic โœ…
simple_mlp_with_call_params_dynamic_f64 โœ…
simple_mlp_with_call_params โœ…
simple_mlp_with_call_params_f64 โœ…
v0.1.0
MultiHeadAttention A multi-head attention module implemented in Flax NNX that has no direct ONNX counterpart at the same granularity. multihead_attention_nn_dynamic ✅
multihead_attention_nn_dynamic_f64 โœ…
multihead_attention_nn โœ…
multihead_attention_nn_f64 โœ…
multihead_attention_nnx_dynamic โœ…
multihead_attention_nnx_dynamic_f64 โœ…
multihead_attention_nnx โœ…
multihead_attention_nnx_f64 โœ…
v0.2.0
PatchEmbedding Cutting the image into patches and linearly embedding them. patch_embedding_dynamic โœ…
patch_embedding_dynamic_f64 โœ…
patch_embedding โœ…
patch_embedding_f64 โœ…
v0.1.0
PatchEmbeddingFlatten Cutting the image into patches and linearly embedding them. patch_embedding_flat_dynamic โœ…
patch_embedding_flat_dynamic_f64 โœ…
patch_embedding_flat โœ…
patch_embedding_flat_f64 โœ…
v0.1.0
PositionalEmbedding Add positional embedding to the input embedding positional_embedding_dynamic โœ…
positional_embedding_dynamic_f64 โœ…
positional_embedding โœ…
positional_embedding_f64 โœ…
v0.4.0
PositionalEmbeddingFlatten Add positional embedding to the input embedding positional_embedding_flat_dynamic โœ…
positional_embedding_flat_dynamic_f64 โœ…
positional_embedding_flat โœ…
positional_embedding_flat_f64 โœ…
v0.4.0
TransformerBlock Transformer from 'Attention Is All You Need.' transformer_block_dynamic โœ…
transformer_block_dynamic_f64 โœ…
transformer_block โœ…
transformer_block_f64 โœ…
v0.1.0
TransformerBlockFlatten Transformer from 'Attention Is All You Need.' transformer_block_flat_dynamic โœ…
transformer_block_flat_dynamic_f64 โœ…
transformer_block_flat โœ…
transformer_block_flat_f64 โœ…
v0.1.0
TransformerStack Stack of Transformer blocks transformer_stack_dynamic โœ…
transformer_stack_dynamic_f64 โœ…
transformer_stack โœ…
transformer_stack_f64 โœ…
v0.1.0
TransformerStackFlatten Stack of Transformer blocks transformer_stack_flat_dynamic โœ…
transformer_stack_flat_dynamic_f64 โœ…
transformer_stack_flat โœ…
transformer_stack_flat_f64 โœ…
v0.1.0
VisionTransformer A Vision Transformer (ViT) model for MNIST with configurable embedding type. vit_conv_embedding_dynamic โœ…
vit_conv_embedding_dynamic_f64 โœ…
vit_conv_embedding โœ…
vit_conv_embedding_f64 โœ…
vit_patch_embedding โœ…
vit_patch_embedding_f64 โœ…
v0.2.0
VisionTransformerFlatten A Vision Transformer (ViT) model for MNIST with configurable embedding type. vit_conv_embedding_flat_dynamic โœ…
vit_conv_embedding_flat_dynamic_f64 โœ…
vit_conv_embedding_flat โœ…
vit_conv_embedding_flat_f64 โœ…
vit_patch_embedding_flat_dynamic โœ…
vit_patch_embedding_flat_dynamic_f64 โœ…
vit_patch_embedding_flat โœ…
vit_patch_embedding_flat_f64 โœ…
v0.2.0
onnx_functions_000 one function on an outer layer. 000_one_function_on_outer_layer_dynamic โœ…
000_one_function_on_outer_layer_dynamic_f64 โœ…
000_one_function_on_outer_layer โœ…
000_one_function_on_outer_layer_f64 โœ…
v0.4.0
onnx_functions_001 one function on an inner layer. 001_one_function_inner_dynamic โœ…
001_one_function_inner_dynamic_f64 โœ…
001_one_function_inner โœ…
001_one_function_inner_f64 โœ…
v0.4.0
onnx_functions_002 two nested functions. 002_two_nested_functions_dynamic โœ…
002_two_nested_functions_dynamic_f64 โœ…
002_two_nested_functions โœ…
002_two_nested_functions_f64 โœ…
v0.4.0
onnx_functions_003 two nested functions. 003_two_simple_nested_functions_dynamic โœ…
003_two_simple_nested_functions_dynamic_f64 โœ…
003_two_simple_nested_functions โœ…
003_two_simple_nested_functions_f64 โœ…
v0.4.0
onnx_functions_004 nested function plus component 004_nested_function_plus_component_dynamic โœ…
004_nested_function_plus_component_dynamic_f64 โœ…
004_nested_function_plus_component โœ…
004_nested_function_plus_component_f64 โœ…
v0.4.0
onnx_functions_005 nested function plus more components 005_nested_function_plus_component_dynamic โœ…
005_nested_function_plus_component_dynamic_f64 โœ…
005_nested_function_plus_component โœ…
005_nested_function_plus_component_f64 โœ…
v0.4.0
onnx_functions_006 one function on an outer layer. 006_one_function_outer_dynamic โœ…
006_one_function_outer_dynamic_f64 โœ…
006_one_function_outer โœ…
006_one_function_outer_f64 โœ…
v0.4.0
onnx_functions_007 transformer block with nested mlp block with call parameter 007_transformer_block_dynamic โœ…
007_transformer_block_dynamic_f64 โœ…
007_transformer_block โœ…
007_transformer_block_f64 โœ…
v0.4.0
onnx_functions_008 transformer block with nested mlp block no call parameter 008_transformer_block_dynamic โœ…
008_transformer_block_dynamic_f64 โœ…
008_transformer_block โœ…
008_transformer_block_f64 โœ…
v0.4.0
onnx_functions_009 transformer block using decorator on class and function 009_transformer_block_dynamic โœ…
009_transformer_block_dynamic_f64 โœ…
009_transformer_block โœ…
009_transformer_block_f64 โœ…
v0.4.0
onnx_functions_010 transformer stack 010_transformer_stack_dynamic โœ…
010_transformer_stack_dynamic_f64 โœ…
010_transformer_stack โœ…
010_transformer_stack_f64 โœ…
v0.4.0
onnx_functions_012 Vision Transformer (ViT) 012_vit_conv_embedding_dynamic โœ…
012_vit_conv_embedding_dynamic_f64 โœ…
012_vit_conv_embedding โœ…
012_vit_conv_embedding_f64 โœ…
v0.4.0
onnx_functions_013 Vision Transformer (ViT) 013_vit_conv_embedding_with_call_params_dynamic โœ…
013_vit_conv_embedding_with_call_params_dynamic_f64 โœ…
013_vit_conv_embedding_with_call_params โœ…
013_vit_conv_embedding_with_call_params_f64 โœ…
013_vit_conv_embedding_with_internal_call_params_dynamic โœ…
013_vit_conv_embedding_with_internal_call_params_dynamic_f64 โœ…
013_vit_conv_embedding_with_internal_call_params โœ…
013_vit_conv_embedding_with_internal_call_params_f64 โœ…
v0.4.0
onnx_functions_014 one function on an outer layer. 014_one_function_with_input_param_with_default_value โœ…
014_one_function_with_input_param_with_default_value_f64 โœ…
014_one_function_without_input_param_with_default_value_dynamic โœ…
014_one_function_without_input_param_with_default_value_dynamic_f64 โœ…
014_one_function_without_input_param_with_default_value โœ…
014_one_function_without_input_param_with_default_value_f64 โœ…
v0.4.0
onnx_functions_015 one function on an outer layer. 015_one_function_with_input_param_without_default_value_dynamic โœ…
015_one_function_with_input_param_without_default_value_dynamic_f64 โœ…
015_one_function_with_input_param_without_default_value โœ…
015_one_function_with_input_param_without_default_value_f64 โœ…
v0.4.0
onnx_functions_016 nested function plus more components 016_internal_function_with_input_param_with_default_value_dynamic โœ…
016_internal_function_with_input_param_with_default_value_dynamic_f64 โœ…
016_internal_function_with_input_param_with_default_value โœ…
016_internal_function_with_input_param_with_default_value_f64 โœ…
v0.4.0

๐Ÿ“Œ Dependencies

Versions of Major Dependencies:

  • JAX 0.6.1
  • Flax 0.10.6
  • onnx 1.18.0
  • onnxruntime 1.22.0

Note: For more details, check pyproject.toml.


โš ๏ธ Limitations

  • Not all JAX/Flax components are supported yet (you can easily help expand the coverage!).
  • Function references need dynamic resolution at call time.
  • The ONNX graph is composed in memory before it is saved to disk, which can cause memory pressure for very large models.

๐Ÿค How to Contribute

We warmly welcome contributions!

How you can help:

  • Add a plugin: extend jax2onnx by writing a small, self-contained Python file in jax2onnx/plugins that adds support for a custom primitive or contributes an example.
  • Bug fixes & improvements: PRs and issues are always welcome.
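To illustrate the general shape of such a plugin mechanism, here is a hypothetical registry-plus-decorator sketch; the names below (`PLUGIN_REGISTRY`, `register_primitive`, `lower_tanh`) are invented for illustration and are not jax2onnx's actual API, which lives in jax2onnx/plugins:

```python
from typing import Callable, Dict

# Hypothetical registry mapping a primitive name to its lowering handler.
PLUGIN_REGISTRY: Dict[str, Callable] = {}

def register_primitive(name: str):
    """Register a handler that lowers one primitive to ONNX node(s)."""
    def decorator(fn: Callable) -> Callable:
        PLUGIN_REGISTRY[name] = fn
        return fn
    return decorator

@register_primitive("tanh")
def lower_tanh(inputs, outputs):
    # A real plugin would emit an ONNX Tanh node into the graph under
    # construction; here we just describe it as a dict.
    return {"op_type": "Tanh", "inputs": inputs, "outputs": outputs}

node = PLUGIN_REGISTRY["tanh"](["x"], ["y"])
print(node["op_type"])  # -> Tanh
```

The converter can then dispatch on the primitive name at trace time and call the registered handler, which keeps each primitive's lowering logic in one small local file.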

๐Ÿ’พ Installation

Install from PyPI:

pip install jax2onnx  

๐Ÿ“œ License

This project is licensed under the Apache License, Version 2.0. See LICENSE for details.


๐ŸŒŸ Special Thanks

Special thanks for example contributions to @burakksen and @Cadynum

Special thanks for plugin contributions to @burakksen

Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.

Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.

A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.


Happy converting! ๐ŸŽ‰

Project details


Download files

Download the file for your platform.

Source Distribution

jax2onnx-0.6.2.tar.gz (236.3 kB)

Uploaded Source

Built Distribution

jax2onnx-0.6.2-py3-none-any.whl (357.9 kB)

Uploaded Python 3

File details

Details for the file jax2onnx-0.6.2.tar.gz.

File metadata

  • Download URL: jax2onnx-0.6.2.tar.gz
  • Size: 236.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.5 Linux/6.6.87.1-microsoft-standard-WSL2

File hashes

Hashes for jax2onnx-0.6.2.tar.gz

  • SHA256: bbdae81c7444a8bf73e0b8dfa03e8cc51ec33199d50d40a15d47d988cb1ecb35
  • MD5: e135281cd20cb73fbe03cb67767f7f46
  • BLAKE2b-256: 71d662dd6c264eda4815be5728df210aed532089bcfc1ce314f9491906df4bde


File details

Details for the file jax2onnx-0.6.2-py3-none-any.whl.

File metadata

  • Download URL: jax2onnx-0.6.2-py3-none-any.whl
  • Size: 357.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.5 Linux/6.6.87.1-microsoft-standard-WSL2

File hashes

Hashes for jax2onnx-0.6.2-py3-none-any.whl

  • SHA256: c3e7f92c8ab9e1b6be042a39c20aa23e0158b7ff380be9b315a1a9bd4b2fcb79
  • MD5: 3a61688a7fcd4d1e501efd05ae5196cc
  • BLAKE2b-256: f3339bbf1cf0c66ae6e9a00fbdb4671913d16645d1046f933991a71c7dfbc650

