
Export JAX to ONNX


jax2onnx 🌟

jax2onnx converts your JAX, Flax (NNX), and Equinox functions directly into the ONNX format.


✨ Key Features

  • Simple API
    Easily convert JAX callables, including Flax (NNX) and Equinox models, into the ONNX format using to_onnx(...).

  • Model structure preserved
    With @onnx_function, submodules appear as named functions in the ONNX graph (e.g. in Netron). Useful for readability and reuse.

  • Dynamic input support
    Use abstract dimensions like 'B' or pass scalars as runtime inputs. Models stay flexible without retracing.

  • Plugin-based extensibility
    Add support for new primitives by writing small, local plugins.

  • Netron-friendly outputs
    All generated ONNX graphs include shape/type annotations and are structured for clear visualization.
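To illustrate the idea behind abstract dimensions: a symbolic shape such as ('B', 30) is matched against the concrete shape seen at runtime, binding the named dimension on the fly. The helper below is a simplified, hypothetical stand-in for this matching logic, not jax2onnx's internal API:

```python
def bind_symbolic_shape(symbolic, concrete):
    """Match a symbolic shape like ('B', 30) against a concrete runtime shape,
    returning the binding for each named dimension."""
    if len(symbolic) != len(concrete):
        raise ValueError("rank mismatch")
    bindings = {}
    for sym, dim in zip(symbolic, concrete):
        if isinstance(sym, str):
            # Named dimension (e.g. 'B'): bind it, or check consistency
            # if it was already bound elsewhere in the shape.
            if bindings.setdefault(sym, dim) != dim:
                raise ValueError(f"inconsistent size for {sym!r}")
        elif sym != dim:
            # Fixed dimension must match exactly.
            raise ValueError(f"expected {sym}, got {dim}")
    return bindings

# The same exported model accepts any batch size:
print(bind_symbolic_shape(("B", 30), (8, 30)))    # {'B': 8}
print(bind_symbolic_shape(("B", 30), (128, 30)))  # {'B': 128}
```

This is why a model exported with [("B", 30)] does not need retracing for different batch sizes: 'B' stays symbolic in the graph and is only resolved against the actual input.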


🚀 Quickstart

Convert your JAX callable to ONNX in just a few lines:

import onnx
from flax import nnx
from jax2onnx import to_onnx

# Define a simple MLP (from Flax docs)
class MLP(nnx.Module):
    def __init__(self, din, dmid, dout, *, rngs): 
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) 
    def __call__(self, x): 
        x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
        return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert to ONNX
onnx_model = to_onnx(my_callable, [("B", 30)])

# Save the model
onnx.save_model(onnx_model, "my_callable.onnx")

🔎 See it visualized: my_callable.onnx


🧠 ONNX Functions: Minimal Example

ONNX functions help encapsulate reusable subgraphs. Simply add the @onnx_function decorator to turn your callable into an ONNX function.

from onnx import save_model
from flax import nnx
from jax2onnx import onnx_function, to_onnx

# just an @onnx_function decorator to make your callable an ONNX function
@onnx_function
class MLPBlock(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)
    self.batchnorm = nnx.BatchNorm(dim, rngs=rngs)
  def __call__(self, x):
    return nnx.gelu(self.linear2(self.batchnorm(nnx.gelu(self.linear1(x)))))

# Use it inside another module
class MyModel(nnx.Module):
  def __init__(self, dim, *, rngs):
    self.block1 = MLPBlock(dim, rngs=rngs)
    self.block2 = MLPBlock(dim, rngs=rngs)
  def __call__(self, x):
    return self.block2(self.block1(x))

callable = MyModel(256, rngs=nnx.Rngs(0))
model = to_onnx(callable, [(100, 256)])
save_model(model, "docs/onnx/model_with_function.onnx")

🔎 See it visualized: model_with_function.onnx


📅 Roadmap and Releases

Planned Versions

  • Ongoing

    • Expanding coverage of JAX, Flax (NNX) and Equinox components
    • Enhancing support for physics-based simulations
  • Upcoming

    • Advanced ONNX function support, including function reuse, optimized internal graph structure, and improved variable naming for clarity and readability
    • Integrating onnx-ir as a backend to improve ONNX model construction, memory efficiency, and performance

Current Production Version

  • 0.8.0 (initial Equinox support, PyPI):

    • Added initial support for Equinox components: eqx.dropout, eqx.layer_norm, and eqx.linear. Introduced an Equinox MlpExample showcasing Linear + Dropout composition.
    • Stabilized SSA and shape handling across lax.scan and lax.fori_loop bodies, preventing dtype leaks.
    • Improved dtype propagation in lax.gather, lax.concatenate.
    • Added plugin support for lax.pad.

Past Versions

  • 0.7.5 fixes tests for functions without arguments, adds support for lax.bitwise_not, lax.clamp, lax.ge, jnp.clip, lax.rev, and enhances support for nnx.dot_product_attention, nnx.conv, nnx.batch_norm, lax.mul, lax.reduce_max, lax.scan, lax.slice, lax.while_loop, nn.gelu, jnp.arange, jnp.cumsum, jnp.select, jnp.where, and jnp.concatenate.
  • 0.7.4 adds support for lax.cumsum and jnp.cumsum, and improves lax.scatter.
  • 0.7.3 improves polymorphism handling for transformers.
  • 0.7.2 adds support for jnp.split, lax.split, lax.logistic, includes an example for nnx.GRUCell, and improves lax.scatter and lax.while_loop.
  • 0.7.1 fixes a numeric equivalence bug in the test system, and adds support for core.custom_jvp_generic, eqx.identity, jnp.select, jnp.stack, jnp.unstack, lax.select, plus multiple nn.* activations (identity, celu, elu, gelu, relu, leaky_relu, mish, selu, sigmoid, soft-sign, softmax, truncated_normal).
  • 0.7.0 introduces a GPT-2 example based on nanoGPT with ONNX function support and attention masking, adds support for jnp.concatenate, jnp.take, nnx.Embed, and starts hosting ONNX models on Hugging Face.
  • 0.6.5 improves support for nnx.batch_norm, nnx.group_norm, nnx.layer_norm, nnx.rms_norm, lax.broadcast_in_dim, lax.cond, lax.fori_loop, lax.integer_pow, lax.scan, lax.scatter, lax.scatter_add, lax.scatter_mul, and lax.while_loop; and adds support for lax.and, lax.rem, and lax.remat2.
  • 0.6.4: Improved support for lax.scatter_mul.
  • 0.6.3: Double precision fixes for lax.fori_loop and lax.while_loop. Fixed bugs in lax.scan and jnp.where.
  • 0.6.2: Fixed bugs in nnx.conv and lax.reshape; added new primitive jnp.prod.
  • 0.6.1: Improved support for lax.cond and lax.select_n; added new primitives (lax.reduce_and, lax.reduce_or, lax.reduce_prod, lax.reduce_xor); and introduced new examples for jnp.select and jnp.sort.
  • 0.6.0: Introduced the enable_double_precision parameter (default: False) to support physics simulations, and enhanced handling of lax.scatter.
  • 0.5.2: Add support for additional primitives: jnp.where, jnp.arange, jnp.linspace.
  • 0.5.1: Add support for subgraph using primitives: lax.while_loop, lax.cond, lax.fori_loop, lax.scan.
  • 0.5.0: Improved dynamic batch dimension handling by leveraging shape polymorphism for more robust and flexible model export. Added support for jnp.sign, jnp.abs, jnp.iota primitives.
  • 0.4.4: Added support for lax.cos, lax.cosh, lax.sin, lax.sinh and lax.scatter primitives.
  • 0.4.3: Fixed a bug in the validation of JAX callable outputs against their ONNX counterparts. This fix exposed previously hidden failing tests, which are now fixed.
  • 0.4.2: Cleanup and fixes to the basic ONNX function release.
  • 0.4.1 (ONNX functions): Introduced simple ONNX function support: just add an @onnx_function decorator to make a callable an ONNX function. Each @onnx_function decorator creates a new ONNX function instance in the call graph.
  • 0.3.2: relaxed the minimum Python version to 3.10.
  • 0.3.0: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.
  • 0.2.0 (First PyPI Release): Rebased the implementation on jaxpr, improving usability and adding low-level lax components.
  • 0.1.0 (Initial Approach, Not Released to PyPI): Produced ONNX exports for some nnx components and nnx-based examples, including a VisualTransformer.

โ“ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

  • Non-dynamic function references:
    JAXPR-based conversion requires function references to be resolved dynamically at call-time.
    Solution: Wrap your function call inside a lambda to enforce dynamic resolution:

    my_dynamic_callable_function = lambda x: original_function(x)
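    The lambda works because it defers the lookup of original_function to call time. A plain-Python sketch of the difference (hypothetical names, independent of JAX):

    ```python
    def original_function(x):
        return x + 1

    direct_ref = original_function               # bound now: later rebinding is invisible
    dynamic_ref = lambda x: original_function(x)  # name resolved fresh at each call

    def original_function(x):                     # rebind the name, as tracing machinery may do
        return x * 10

    print(direct_ref(3))   # 4  (still the originally captured function)
    print(dynamic_ref(3))  # 30 (picks up the current binding)
    ```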
    
  • Unsupported primitives:
    The callable may use a primitive not yet or not fully supported by jax2onnx.
    Solution: Write a plugin to handle the unsupported function (this is straightforward!).
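jax2onnx's actual plugin interface lives in its source tree; as a rough illustration of the decorator-registry pattern such plugin systems typically use (all names below are hypothetical):

```python
PLUGIN_REGISTRY = {}

def register_plugin(primitive_name):
    """Decorator that registers a conversion handler for a primitive."""
    def wrap(handler):
        PLUGIN_REGISTRY[primitive_name] = handler
        return handler
    return wrap

@register_plugin("my_lib.softclip")
def convert_softclip(node_inputs):
    # A real handler would emit ONNX nodes; here we just describe one.
    return {"op": "Clip", "inputs": node_inputs}

# The converter looks up the handler by primitive name during tracing:
handler = PLUGIN_REGISTRY["my_lib.softclip"]
print(handler(["x"]))  # {'op': 'Clip', 'inputs': ['x']}
```

Because registration happens at import time, a plugin is just a small local module: define the handler, decorate it, and the converter can find it.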


🧩 Supported JAX/ONNX Components

All testcases listed below pass (✅). A trailing * on a testcase name means its double-precision (_f64) variant also passes; exceptions are noted inline.

| JAX Component | ONNX Components | Testcases | Since |
|---|---|---|---|
| core.custom_jvp_generic (generic passthrough for custom JVP calls) | CustomJvp | custom_jvp_square* | v0.7.1 |
| core.dim_as_value | Cast, Gather, Reshape, Shape | dim_as_value_dynamic, dim_as_value | v0.5.0 |
| eqx.dropout | Dropout, Not | eqx_dropout_inference_mode*, eqx_dropout_training_mode*, eqx_dropout_dynamic_inference*, eqx_dropout_batched_inference_dynamic*, eqx_dropout_batched_inference* | v0.8.0 |
| eqx.identity | Identity | eqx_identity_static*, eqx_identity_symbolic_batch_dynamic*, eqx_identity_symbolic_batch* | v0.7.1 |
| eqx.layer_norm | LayerNormalization | layer_norm*, layer_norm_multiaxis*, batched_layer_norm_dynamic*, batched_layer_norm*, layer_norm_no_bias_no_scale* | v0.8.0 |
| eqx.linear | Gemm, Reshape | eqx_linear_symbolic_batch_dynamic*, eqx_linear_symbolic_batch*, eqx_linear_high_rank*, eqx_linear_vector* | v0.8.0 |
| jnp.add | Add | add* | v0.1.0 |
| jnp.arange | Range | arange_data_dependent_indices, arange_stop_only_concrete_input_val*, arange_start_stop_concrete_input_val*, arange_start_stop_step_concrete_input_val*, arange_float_concrete_input_val*, arange_static_stop_only_int*, arange_static_stop_only_float*, arange_static_start_stop_int*, arange_static_start_stop_step_int*, arange_static_empty_result_pos_step*, arange_static_empty_result_neg_step*, arange_static_negative_step*, arange_static_float_step_explicit_dtype*, arange_static_float_step_inferred_dtype*, arange_static_stop_zero*, arange_static_start_equals_stop*, arange_static_large_numbers_int* | v0.5.2 |
| jnp.clip | Max, Min | clip_i32_scalar_bounds*, clip_f32_scalar_bounds_no_upcast_f64_mode, clip_only_upper*, clip_only_lower*, clip_broadcast_bounds | v0.7.5 |
| jnp.concatenate | Cast, Concat | concatenate_basic*, concatenate_mixed_dtypes*, concatenate_with_explicit_dtype*, concatenate_with_explicit_dtype_casts_inputs, concatenate_abstract_middle_dim_dynamic*, concatenate_abstract_middle_dim*, concatenate_tile_and_symbolic_dynamic*, concatenate_tile_and_symbolic* | v0.2.0 |
| jnp.cumsum | CumSum | cumsum_axis2_i32*, cumsum_axis2_reverse_i32* | v0.7.4 |
| jnp.einsum | Einsum | einsum_vector_dot*, einsum_matrix_vector*, einsum_matrix_matrix_dynamic*, einsum_matrix_matrix*, einsum_transpose*, einsum_batch_transpose_dynamic*, einsum_batch_transpose*, einsum_diag*, einsum_sum_reduce*, einsum_multi_operand*, einsum_attention_logits_orig_dynamic*, einsum_attention_logits_orig*, einsum_attention_output_orig_dynamic*, einsum_attention_output_orig*, einsum_attention_logits_batched_dynamic*, einsum_attention_logits_batched*, einsum_attention_output_batched_dynamic*, einsum_attention_output_batched*, einsum_ellipsis_rank_mismatch*, einsum_attention_logits_batched_rank_mismatch* | v0.1.0 |
| jnp.linspace | Constant | linspace_static_basic*, linspace_static_endpoint_false*, linspace_static_num_1*, linspace_static_num_0*, linspace_static_int_inputs_default_dtype* | v0.5.2 |
| jnp.matmul | MatMul | matmul_2d*, matmul_1d_2d*, matmul_2d_1d*, matmul_dynamic_dynamic*, matmul_dynamic*, matmul_dynamic_a_dynamic*, matmul_dynamic_a*, matmul_1d*, matmul_3d* | v0.1.0 |
| jnp.prod | ReduceProd | basic_prod*, prod_with_axis*, prod_with_keepdims* | v0.6.2 |
| jnp.reshape | Reshape | reshape_1*, reshape_2*, reshape_3*, reshape_4_dynamic*, reshape_4*, reshape_to_scalar*, reshape_from_scalar*, reshape_cnn_dynamic*, reshape_cnn*, reshape_valid_flatten_trailing*, reshape_with_target_shape_from_symbolic_dim_computation* | v0.1.0 |
| jnp.select | Where | select_simple*, select_broadcast*, select_gpt2_attention_mask_dynamic*, select_gpt2_attention_mask* | v0.7.1 |
| jnp.shape | Shape | shape_basic*, shape_dynamic_dynamic*, shape_dynamic* | v0.4.0 |
| jnp.sort | TopK | sort_1d*, sort_2d_axis0_dynamic*, sort_2d_axis0* | v0.5.2 |
| jnp.split | Split | split_by_sections*, split_by_indices*, split_by_indices_symbolic_dynamic*, split_by_indices_symbolic* | v0.7.2 |
| jnp.squeeze | Squeeze | squeeze_single_dim*, squeeze_multiple_dims*, squeeze_vit_output*, squeeze_dynamic_batch_dynamic*, squeeze_dynamic_batch*, squeeze_all_dims*, squeeze_negative_axis*, squeeze_negative_axis_tuple*, squeeze_dynamic_and_negative_axis_dynamic*, squeeze_dynamic_and_negative_axis* | v0.1.0 |
| jnp.stack | Concat, Unsqueeze | stack_axis_0*, stack_axis_1*, stack_negative_axis*, stack_scalars* | v0.7.1 |
| jnp.take | Gather | take_data_dependent_indices | v0.7.0 |
| jnp.tile | Tile | tile_repeats*, tile_a*, tile_b*, tile_c*, tile_d*, tile_dynamic_input_static*, tile_dynamic_input_dynamic*, tile_dynamic_input*, tile_pad*, tile_with_symbolic_repeats_static*, tile_with_symbolic_repeats_dynamic*, tile_with_symbolic_repeats*, tile_param_symbolic_dynamic*, tile_param_symbolic* | v0.1.0 |
| jnp.transpose | Transpose | transpose_basic*, transpose_reverse*, transpose_4d_dynamic*, transpose_4d*, transpose_square_matrix*, transpose_high_dim*, transpose_no_axes*, transpose_3d_dynamic*, transpose_3d* | v0.1.0 |
| jnp.unstack | Split, Squeeze | unstack_axis_0*, unstack_axis_1*, unstack_negative_axis* | v0.7.1 |
| jnp.where | Where | where_gpt_mask_scores_literal_else_dynamic*, where_gpt_mask_scores_literal_else*, where_simple*, where_broadcast*, where_multidim_condition_scalar_branches_broadcast*, where_A*, where_B*, where_gpt_mask_scores_scalar_else_dynamic*, where_gpt_mask_scores_scalar_else*, where_int_condition_cast*, where_literal_else_pyfloat*, where_jax_int_literals_broadcast_f64_mode, where_dtype_mismatch_f64_vs_i32_promote | v0.5.2 |
| lax.abs | Abs | abs* | v0.5.0 |
| lax.add | Add | add* | v0.2.0 |
| lax.and | And, BitwiseAnd | and_bool*, and_int* | v0.6.5 |
| lax.argmax | ArgMax | argmax_float_axis0*, argmax_float_axis1*, argmax_boolean_input_axis0_specific_values*, argmax_boolean_input_axis1_specific_values*, argmax_boolean_random_input_axis0* | v0.2.0 |
| lax.argmin | ArgMin | argmin_test1*, argmin_test2* | v0.2.0 |
| lax.bitwise_not | BitwiseNot | bitwise_not_bool*, bitwise_not_i32* | v0.7.5 |
| lax.broadcast_in_dim | Expand, Identity, Reshape | broadcast_in_dim*, broadcast_in_dim_2d_to_3d*, broadcast_in_dim_scalar*, broadcast_in_dim_batch_dynamic*, broadcast_in_dim_batch*, broadcast_in_dim_dynamic_B_dynamic*, broadcast_in_dim_dynamic_B* | v0.2.0 |
| lax.clamp | Max, Min | clamp_i32_scalar_bounds*, clamp_scalar_float_bounds_match_x*, clamp_vector_bounds_match, clamp_pyint_bounds_promote_to_x_dtype* | v0.7.5 |
| lax.concatenate | Cast, Concat | concatenate*, concatenate_axis1_dynamic*, concatenate_axis1*, concatenate_axis0*, concatenate_3d*, concatenate_internal_int32_then_cast_to_f32_zeroarg | v0.2.0 |
| lax.cond | If | cond_scalar*, cond_multiple_operands_in_tuple*, cond_my_new_complex_scenario*, cond_nested_conditional*, cond_variables*, cond_internal_constant_f64, cond_passthrough_identity*, cond_with_scatter* | v0.5.1 |
| lax.conv | Conv | conv, conv2, conv_general_dilated_nhwc_output | v0.2.0 |
| lax.convert_element_type | Cast | convert_element_type* | v0.2.0 |
| lax.copy (handles the JAX primitive lax.copy_p; the jax.lax.copy API itself is removed) | Identity | copy_float32_array, copy_int64_scalar | <your_current_version> |
| lax.cos | Cos | cos* | v0.4.4 |
| lax.cosh | Cosh | cosh* | v0.4.4 |
| lax.cumsum | CumSum | cumsum_i32_axis2*, cumsum_f32_axism1_reverse* | v0.7.4 |
| lax.device_put | Identity | device_put_array*, device_put_scalar* | v0.4.0 |
| lax.div | Div | div* | v0.2.0 |
| lax.dot_general | MatMul | dot_general*, dot_general_lhs1_rhs1* | v0.2.0 |
| lax.dynamic_slice | Slice | dynamic_slice_test1*, dynamic_slice_2d*, dynamic_slice_3d*, dynamic_slice_vit_like_dynamic*, dynamic_slice_vit_like* | v0.1.0 |
| lax.eq | Equal | eq* | v0.2.0 |
| lax.exp | Exp | exp* | v0.2.0 |
| lax.fori_loop | Loop | fori_loop_counter*, fori_loop_zero*, fori_loop_vector*, fori_loop_example*, fori_loop_test (its _f64 variant is marked ➖) | v0.5.1 |
| lax.gather | GatherND | gather_trig_where_pipeline_f64_indices_i64, gather_trig_where_pipeline_f64_indices_i32, gather_f64_data_i64_indices_output_is_f64, gather_f64_data_i32_indices_cast_and_output_is_f64, gather_static*, gather_dynamic_batch_simple_index_dynamic*, gather_dynamic_batch_simple_index* | v0.2.0 |
| lax.greater_equal | GreaterOrEqual | greater_equal* | v0.7.5 |
| lax.gt | Greater | gt* | v0.2.0 |
| lax.integer_pow | Pow | integer_pow* | v0.2.0 |
| lax.iota | Range | iota_int32*, iota_float32*, broadcasted_iota* | v0.5.0 |
| lax.log | Log | log* | v0.2.0 |
| lax.logistic | Sigmoid | lax_logistic_basic* | v0.7.2 |
| lax.lt | Less | lt* | v0.2.0 |
| lax.max | Max | max* | v0.2.0 |
| lax.min | Min | min_test1* | v0.1.0 |
| lax.mul | Mul | mul_test1*, mul_test2*, mul_pyfloat_promotes_to_array_dtype_f64, mul_scalar_broadcast_promote_to_f64 | v0.1.0 |
| lax.ne | Equal, Not | ne* | v0.2.0 |
| lax.neg | Neg | neg* | v0.2.0 |
| lax.pad | Pad | pad_const_1d*, pad_const_2d*, pad_const_2d_cval*, pad_inside_scan_smoke_f64, pad_inside_nested_scan_smoke_f64 | v0.8.0 |
| lax.reduce_and | Cast, ReduceMin | reduce_and_all_true*, reduce_and_one_false*, reduce_and_keepdims* | v0.6.1 |
| lax.reduce_max | ReduceMax | reduce_max*, reduce_max_allaxes*, reduce_max_keepdims*, reduce_max_axes_input* | v0.2.0 |
| lax.reduce_min | ReduceMin | reduce_min*, reduce_min_allaxes*, reduce_min_keepdims* | v0.2.0 |
| lax.reduce_or | Cast, ReduceMax | reduce_or_all_false*, reduce_or_one_true*, reduce_or_keepdims* | v0.6.1 |
| lax.reduce_prod | ReduceProd | reduce_prod*, reduce_prod_allaxes*, reduce_prod_keepdims*, reduce_prod_dtype_f64, reduce_prod_dtype | v0.6.1 |
| lax.reduce_sum | ReduceSum | reduce_sum*, reduce_sum_allaxes*, reduce_sum_keepdims*, reduce_sum_dtype_f64, reduce_sum_dtype | v0.2.0 |
| lax.reduce_xor | Cast, Mod, ReduceSum | reduce_xor_all_false*, reduce_xor_one_true*, reduce_xor_two_true*, reduce_xor_keepdims* | v0.6.1 |
| lax.rem | Div, Mod | rem_int*, rem_float*, rem_int_neg*, rem_float_neg* | v0.6.5 |
| lax.reshape | Reshape | reshape*, reshape_valid_squeeze_middle_dim_from_problematic_source*, reshape_valid_flatten_trailing*, reshape_with_target_shape_from_symbolic_dim_computation*, reshape_with_inferred_dimension_from_input_dynamic_dynamic*, reshape_with_inferred_dimension_from_input_dynamic*, reshape_with_inferred_dimension_from_input*, reshape_merge_symbolic_with_static_and_check_name_dynamic, reshape_merge_symbolic_with_static_and_check_name | v0.2.0 |
| lax.rev | Flip | rev_vector*, rev_matrix_axes01* | v0.7.5 |
| lax.scan | Scan | scan_cumsum*, scan_carry_only*, scan_multiple_sequences*, scan_multiple_carry*, scan_matrix_carry_multidim_xs*, scan_no_xs*, scan_fn*, scan_jit_no_xs*, scan_captured_scalar*, scan_rank0_sequence_vectorized*, scan_two_diff_lengths*, scan_nested_len_mismatch*, scan_two_diff_lengths_broadcast*, scan_two_diff_lengths_with_broadcast, scan_captured_scalar_with_xs, scan_captured_vector_with_xs_f64 | v0.5.1 |
| lax.scatter | ScatterND | scatter_set_axis0*, scatter_set_middle*, scatter_correct_axis_determination*, scatter_updates_slice_needed_axis0*, scatter_from_user_warning_shapes_valid_jax*, scatter_user_error_scenario_precise*, scatter_window_update_f64, scatter_window_update_depth3_shapes_ok, scatter_static_slice_set_f64, scatter_depth2_fp64_type_mismatch, scatter_simple_2d_window_out_of_bounds, scatter_clip_2d_window_at_edge, scatter_depth2_mixed_dtypes_fp_mismatch_f64, scatter_depth2_mixed_dtypes_fp_mismatch | v0.4.4 |
| lax.scatter_add | ScatterND | scatter_add_simple_1d*, scatter_add_window_2d_operand_1d_indices*, scatter_add_batch_updates_1d_operand*, scatter_add_mismatched_window_dims_from_user_report, scatter_add_mismatched_window_dims_from_user_report2, scatter_add_mismatched_window_dims_from_user_report3, scatter_add_fluids_pattern_updates_5_4_1_1, scatter_add_in_cond_float64, scatter_add_fp64_dtype_mismatch, scatter_add_depth2_depth2_helper_regression, scatter_depth2_fp64_type_mismatch | v0.5.3 |
| lax.scatter_max | ScatterND | scatter_max_simple_1d*, scatter_max_window_2d_operand_1d_indices*, scatter_max_batch_updates_1d_operand*, scatter_max_fp64_dtype_path_check, scatter_max_depth2_helper_regression_fp64 | v0.7.5 |
| lax.scatter_min | ScatterND | scatter_min_simple_1d*, scatter_min_window_2d_operand_1d_indices*, scatter_min_batch_updates_1d_operand*, scatter_min_fp64_dtype_path_check, scatter_min_depth2_helper_regression_fp64 | v0.7.5 |
| lax.scatter_mul | ScatterND | scatter_mul_simple_1d*, scatter_mul_window_2d_operand_1d_indices*, scatter_mul_batch_updates_1d_operand*, scatter_mul_mismatched_window_dims_from_user_report, scatter_mul_mismatched_window_dims_from_user_report2, scatter_mul_mismatched_window_dims_from_user_report3, scatter_mul_fluids_pattern_updates_5_4_1_1, scatter_mul_in_cond_float64 | v0.6.4 |
| lax.select | Where | select_simple*, select_mask_scores_tensor_else_dynamic*, select_mask_scores_tensor_else* | v0.7.1 |
| lax.select_n | Where | select_n_bool_predicate_two_cases_float*, select_n_bool_predicate_two_cases_int*, select_n_bool_predicate_scalar_broadcast*, select_n_int_indices_three_cases*, select_n_int_indices_four_cases* | v0.2.0 |
| lax.sign | Sign | sign* | v0.5.0 |
| lax.sin | Sin | sin* | v0.4.4 |
| lax.sinh | Sinh | sinh* | v0.4.4 |
| lax.slice | Slice | slice_test1*, slice_3d_none_strides*, slice_scan_axis_drop* | v0.1.0 |
| lax.sort | TopK | sort_1d*, sort_2d* | v0.2.0 |
| lax.split | Split | lax_split_equal_parts*, lax_split_unequal_parts* | v0.7.2 |
| lax.sqrt | Sqrt | sqrt* | v0.2.0 |
| lax.square | Mul | square* | v0.2.0 |
| lax.squeeze | Squeeze | lax_squeeze_specific_axis_0*, lax_squeeze_multiple_axes*, lax_squeeze_no_op_empty_dims*, lax_squeeze_problem_case_input_squeeze_only_axis_0*, lax_squeeze_problem_case_input_squeeze_axes_0_2*, lax_squeeze_problem_case_input_squeeze_all_dims_explicitly* | v0.2.0 |
| lax.stop_gradient | Identity | stop_gradient* | v0.2.0 |
| lax.sub | Sub | sub_test1*, sub_test2* | v0.1.0 |
| lax.tanh | Tanh | tanh* | v0.2.0 |
| lax.transpose | Transpose | transpose_basic* | v0.2.0 |
| lax.while_loop | Loop | while_loop_counter*, while_loop_vector*, while_loop_f64, while_loop_multi_state_f32, while_loop_multi_state_f64, while_loop_with_closure*, while_loop_basic, while_loop_two_state, while_loop_captured_tracer, while_loop_with_scalar_state, while_loop_renamed_passthrough, while_loop_closure_topo, while_loop_mixed_rank, while_loop_tracer_passthrough, while_loop_no_loop_output_reused_as_input, while_loop_4d_and_scalar_state*, while_loop_cnn_scalar_state_bug*, while_loop_nnx_repro* | v0.5.1 |
| nn.celu | Celu | jaxnn_celu*, jaxnn_celu_1* | v0.7.1 |
| nn.dot_product_attention | Add, Cast, MatMul, Mul, Not, Softmax, Transpose, Where | dpa_basic*, dpa_positional_bias_mask*, dpa_diff_heads_embed*, dpa_batch4_seq16*, dpa_float64*, dpa_heads1_embed4*, dpa_heads8_embed8*, dpa_batch1_seq2*, dpa_batch8_seq4*, dpa_axis1*, dpa_with_tensor_mask*, dpa_tiny_mask_all_valid*, dpa_tiny_mask_mixed*, dpa_one_false*, dpa_mostly_false*, dpa_with_causal_mask*, dpa_with_padding_mask*, dpa_with_local_window_mask* | |
dpa_mask_none โœ…
v0.1.0
nn.elu Elu jaxnn_elu โœ…
jaxnn_elu_f64 โœ…
jaxnn_elu_1 โœ…
jaxnn_elu_1_f64 โœ…
v0.7.1
nn.gelu Gelu jaxnn_gelu โœ…
jaxnn_gelu_f64 โœ…
jaxnn_gelu_1 โœ…
jaxnn_gelu_1_f64 โœ…
jaxnn_gelu_approx โœ…
jaxnn_gelu_approx_f64 โœ…
v0.7.1
nn.identity Identity jaxnn_identity โœ…
jaxnn_identity_f64 โœ…
jaxnn_identity_1 โœ…
jaxnn_identity_1_f64 โœ…
v0.7.1
nn.leaky_relu LeakyRelu jaxnn_leaky_relu โœ…
jaxnn_leaky_relu_f64 โœ…
jaxnn_leaky_relu_1 โœ…
jaxnn_leaky_relu_1_f64 โœ…
v0.7.1
nn.mish Mish jaxnn_mish โœ…
jaxnn_mish_f64 โœ…
jaxnn_mish_1 โœ…
jaxnn_mish_1_f64 โœ…
v0.7.1
nn.relu Relu jaxnn_relu โœ…
jaxnn_relu_f64 โœ…
jaxnn_relu_1 โœ…
jaxnn_relu_1_f64 โœ…
v0.7.1
nn.selu Selu jaxnn_selu โœ…
jaxnn_selu_f64 โœ…
jaxnn_selu_1 โœ…
jaxnn_selu_1_f64 โœ…
v0.7.1
nn.sigmoid Sigmoid jaxnn_sigmoid โœ…
jaxnn_sigmoid_f64 โœ…
jaxnn_sigmoid_1 โœ…
jaxnn_sigmoid_1_f64 โœ…
v0.7.1
nn.soft_sign Softsign jaxnn_soft_sign โœ…
jaxnn_soft_sign_f64 โœ…
jaxnn_soft_sign_1 โœ…
jaxnn_soft_sign_1_f64 โœ…
v0.7.1
nn.softmax Softmax softmax โœ…
softmax_f64 โœ…
softmax_2d โœ…
softmax_2d_f64 โœ…
softmax_3d โœ…
softmax_3d_f64 โœ…
v0.1.0
nn.softplus Softplus jaxnn_softplus โœ…
jaxnn_softplus_f64 โœ…
jaxnn_softplus_1 โœ…
jaxnn_softplus_1_f64 โœ…
v0.7.1
nn.truncated_normal โž– initializer โœ…
random_truncated_normal_positional โœ…
flax_dense_like_init โœ…
v0.7.1
nnx.avg_pool AveragePool
Transpose
avg_pool_dynamic โœ…
avg_pool โœ…
avg_pool_same_padding_dynamic โœ…
avg_pool_same_padding โœ…
avg_pool_default_padding_dynamic โœ…
avg_pool_default_padding โœ…
avg_pool_stride1_dynamic โœ…
avg_pool_stride1 โœ…
avg_pool_win3x3_stride2_dynamic โœ…
avg_pool_win3x3_stride2 โœ…
avg_pool_stride_none_dynamic โœ…
avg_pool_stride_none โœ…
avg_pool_count_include_pad_false_dynamic โœ…
avg_pool_count_include_pad_false โœ…
v0.1.0
nnx.batch_norm BatchNormalization batch_norm_no_bias_no_scale_dynamic โœ…
batch_norm_no_bias_no_scale โœ…
batch_norm_bias_no_scale_dynamic โœ…
batch_norm_bias_no_scale โœ…
batch_norm_no_bias_scale_dynamic โœ…
batch_norm_no_bias_scale โœ…
batch_norm_bias_scale_dynamic โœ…
batch_norm_bias_scale โœ…
batch_norm_3d_dynamic โœ…
batch_norm_3d โœ…
batch_norm_4d_dynamic โœ…
batch_norm_4d โœ…
batch_norm_4d_no_bias_no_scale_dynamic โœ…
batch_norm_4d_no_bias_no_scale โœ…
v0.1.0
nnx.conv Conv
Transpose
conv_basic_bias_dynamic โœ…
conv_basic_bias โœ…
conv_basic_bias_2 โœ…
conv_basic_bias_3 โœ…
conv_stride2_bias โœ…
conv_no_bias_dynamic โœ…
conv_no_bias โœ…
conv_valid_padding โœ…
conv_stride1 โœ…
conv_stride2 โœ…
conv_different_kernel โœ…
conv_float64 โœ…
conv_single_batch โœ…
conv_large_batch โœ…
conv_1d โœ…
conv_1d_more_1d_inputs โœ…
conv_1d_more_2d_inputs โœ…
conv_1d_large_kernel โœ…
conv_1d_dilation โœ…
conv_1d_stride_dilation โœ…
conv_2d_asymmetric_kernel โœ…
conv_2d_asymmetric_stride โœ…
conv_2d_asymmetric_dilation โœ…
conv_2d_large_dilation โœ…
conv_2d_large_stride โœ…
conv_2d_mixed_params โœ…
conv_3d_basic โœ…
conv_3d_stride โœ…
conv_3d_asymmetric โœ…
conv_3d_dilation โœ…
conv_2d_small_input โœ…
conv_2d_many_channels โœ…
conv_1d_wide_input โœ…
conv_2d_kernel_1x1 โœ…
conv_1d_kernel_1 โœ…
conv_2d_group_conv โœ…
conv_1d_group_conv_more_dims โœ…
conv_2d_depthwise โœ…
conv_1d_complex_on_4d โœ…
conv_2d_complex_on_5d โœ…
conv_1d_high_dilation_on_3d โœ…
conv_2d_group_stride_dilation โœ…
conv_1d_group_on_higher_dim โœ…
conv_1d_same_padding_on_3d โœ…
conv_2d_same_padding_mixed_dilation โœ…
conv_1d_large_kernel_on_4d โœ…
conv_2d_asymmetric_on_5d โœ…
conv_3d_group_complex โœ…
conv_1d_unit_group_on_multi_dim โœ…
v0.1.0
nnx.dot_product_attention Cast
Div
Einsum
Gather
Shape
Softmax
Sqrt
dpa_basic โœ…
dpa_basic_f64 โœ…
dpa_with_tensor_mask โœ…
dpa_with_bias โœ…
dpa_with_causal_mask โœ…
dpa_with_causal_mask_f64 โœ…
dpa_with_mask_and_bias โœ…
v0.1.0
nnx.dropout Dropout dropout_init_params_dynamic โœ…
dropout_init_params_dynamic_f64 โœ…
dropout_init_params โœ…
dropout_init_params_f64 โœ…
dropout_call_params_dynamic โœ…
dropout_call_params_dynamic_f64 โœ…
dropout_call_params โœ…
dropout_call_params_f64 โœ…
v0.1.0
nnx.einsum Add
Einsum
einsum_module_with_bias โœ…
einsum_module_with_bias_f64 โœ…
einsum_module_no_bias โœ…
einsum_module_no_bias_f64 โœ…
v0.4.2
nnx.elu Elu elu โœ… v0.1.0
nnx.embed Gather token_embedding_dynamic โœ…
token_embedding_dynamic_f64 โœ…
token_embedding โœ…
token_embedding_f64 โœ…
positional_embedding_dynamic โœ…
positional_embedding_dynamic_f64 โœ…
positional_embedding โœ…
positional_embedding_f64 โœ…
v0.7.0
nnx.gelu Gelu gelu โœ…
gelu_1 โœ…
gelu_2 โœ…
gelu_2_f64 โœ…
gelu_3_dynamic โœ…
gelu_3_dynamic_f64 โœ…
gelu_3 โœ…
gelu_3_f64 โœ…
v0.1.0
nnx.group_norm GroupNormalization group_norm โœ…
group_norm_no_bias_no_scale_dynamic โœ…
group_norm_no_bias_no_scale โœ…
group_norm_bias_no_scale_dynamic โœ…
group_norm_bias_no_scale โœ…
group_norm_no_bias_scale_dynamic โœ…
group_norm_no_bias_scale โœ…
group_norm_bias_scale_dynamic โœ…
group_norm_bias_scale โœ…
v0.3.0
nnx.layer_norm LayerNormalization layer_norm_dynamic โœ…
layer_norm โœ…
layer_norm_no_bias_no_scale_dynamic โœ…
layer_norm_no_bias_no_scale โœ…
layer_norm_bias_no_scale_dynamic โœ…
layer_norm_bias_no_scale โœ…
layer_norm_no_bias_scale_dynamic โœ…
layer_norm_no_bias_scale โœ…
layer_norm_bias_scale_dynamic โœ…
layer_norm_bias_scale โœ…
layer_norm_multiaxis_dynamic โœ…
layer_norm_multiaxis โœ…
layer_norm_symbolic_batch_dynamic โœ…
layer_norm_symbolic_batch โœ…
layer_norm_negative_axis_no_div_dynamic โœ…
layer_norm_negative_axis_no_div โœ…
v0.1.0
nnx.leaky_relu LeakyRelu leaky_relu โœ… v0.1.0
nnx.linear Gemm
Reshape
linear_symbolic_batch_dynamic โœ…
linear_symbolic_batch_dynamic_f64 โœ…
linear_symbolic_batch โœ…
linear_symbolic_batch_f64 โœ…
linear_high_rank_dynamic โœ…
linear_high_rank_dynamic_f64 โœ…
linear_high_rank โœ…
linear_high_rank_f64 โœ…
linear_no_bias_dynamic โœ…
linear_no_bias_dynamic_f64 โœ…
linear_no_bias โœ…
linear_no_bias_f64 โœ…
linear_high_rank_no_bias_dynamic โœ…
linear_high_rank_no_bias_dynamic_f64 โœ…
linear_high_rank_no_bias โœ…
linear_high_rank_no_bias_f64 โœ…
linear_merge_symbolic_dim_dynamic โœ…
v0.1.0
nnx.linear_general Gemm
Reshape
linear_general_merge_symbolic_dim_dynamic โœ…
linear_general_dynamic โœ…
linear_general โœ…
linear_general_2 โœ…
linear_general_3 โœ…
linear_general_4 โœ…
linear_general_abstract_eval_axes โœ…
linear_general_abstract_eval_axes_pair โœ…
dynamic_batch_and_feature_dims_dynamic โœ…
v0.1.0
nnx.log_softmax LogSoftmax log_softmax โœ…
log_softmax_f64 โœ…
v0.1.0
nnx.max_pool MaxPool
Transpose
max_pool โœ…
max_pool_same_padding โœ…
v0.1.0
nnx.relu Relu relu_1d โœ…
relu_1d_f64 โœ…
relu_4d_dynamic โœ…
relu_4d_dynamic_f64 โœ…
relu_4d โœ…
relu_4d_f64 โœ…
v0.1.0
nnx.rms_norm RMSNormalization rms_norm_basic โœ…
rms_norm_use_scale_false โœ…
rms_norm_4d_dynamic_dynamic โœ…
rms_norm_4d_dynamic โœ…
rms_norm_4d_dynamic_no_scale_dynamic โœ…
rms_norm_4d_dynamic_no_scale โœ…
v0.3.0
nnx.sigmoid Sigmoid sigmoid โœ…
sigmoid_f64 โœ…
v0.1.0
nnx.softmax Softmax softmax_dynamic โœ…
softmax_dynamic_f64 โœ…
softmax โœ…
softmax_f64 โœ…
v0.1.0
nnx.softplus Softplus softplus โœ… v0.1.0
nnx.tanh Tanh tanh โœ…
tanh_f64 โœ…
v0.1.0

Legend:
โœ… = Passed
โŒ = Failed
โž– = No testcase yet


๐ŸŽฏ Examples

Component Description Testcases Since
MlpExample A simple MLP example using Equinox. mlp_training_mode โœ…
mlp_training_mode_f64 โœ…
mlp_inference_mode โœ…
mlp_inference_mode_f64 โœ…
mlp_batched_training_mode_dynamic โœ…
mlp_batched_training_mode_dynamic_f64 โœ…
mlp_batched_training_mode โœ…
mlp_batched_training_mode_f64 โœ…
v0.8.0
SimpleLinearExample A simple linear layer example using Equinox. simple_linear_dynamic โœ…
simple_linear_dynamic_f64 โœ…
simple_linear โœ…
simple_linear_f64 โœ…
nn_linear_dynamic โœ…
nn_linear_dynamic_f64 โœ…
nn_linear โœ…
nn_linear_f64 โœ…
v0.7.1
GPT A simple GPT model that reuses nnx.MultiHeadAttention. gpt_dynamic โœ…
gpt โœ…
v0.7.0
GPT_Attention A multi-head attention layer. gpt_attention โœ… v0.7.1
GPT_CausalSelfAttention A causal self-attention module. causal_self_attention_dynamic โœ…
causal_self_attention โœ…
v0.7.0
GPT_Embeddings Combines token and position embeddings with dropout. gpt_embeddings_dynamic โœ…
gpt_embeddings โœ…
v0.7.0
GPT_Head The head of the GPT model. gpt_head_dynamic โœ…
gpt_head โœ…
v0.7.0
GPT_MLP An MLP block with GELU activation from nanoGPT. gpt_mlp_dynamic โœ…
gpt_mlp โœ…
v0.7.0
GPT_PositionEmbedding A positional embedding layer using nnx.Embed. position_embedding โœ… v0.7.0
GPT_TokenEmbedding A token embedding layer using nnx.Embed. token_embedding_dynamic โœ…
token_embedding โœ…
v0.7.0
GPT_TransformerBlock A transformer block combining attention and MLP. gpt_block_dynamic โœ…
gpt_block โœ…
v0.7.0
GPT_TransformerStack A stack of transformer blocks. transformer_stack_dynamic โœ…
transformer_stack โœ…
v0.7.0
broadcast_add Simple dynamic broadcast + add broadcast_add_dynamic_dynamic โœ…
broadcast_add_dynamic_dynamic_f64 โœ…
broadcast_add_dynamic โœ…
broadcast_add_dynamic_f64 โœ…
v0.7.0
cfl_timestep Tests the CFL condition timestep calculation. cfl_timestep_f64 โœ… v0.6.5
weno_reconstruction Tests the complex arithmetic pattern found in WENO schemes. weno_reconstruction_f64 โœ… v0.6.5
fori_loop_test fori_loop_test: Demonstrates jax.lax.fori_loop with a simple loop. fori_loop_test โœ…
fori_loop_test_f64 โž–
v0.6.3
issue18_abs Test jnp.abs from issue 18 abs_fn โœ…
abs_fn_f64 โœ…
v0.6.3
issue18_arange Test arange from issue 18 arange_fn โœ…
arange_fn_f64 โœ…
v0.6.3
issue18_fori_loop Test fori_loop from issue 18 fori_loop_fn โœ…
fori_loop_fn_f64 โœ…
v0.6.3
issue18_linspace Test linspace from issue 18 linspace_fn โœ…
linspace_fn_f64 โœ…
v0.6.3
issue18_scan Test scan from issue 18 (no xs) scan_fn โœ…
scan_fn_f64 โœ…
v0.6.3
issue18_sign Test jnp.sign from issue 18 sign_fn โœ…
sign_fn_f64 โœ…
v0.6.3
issue18_where Test where from issue 18 where_fn โœ…
where_fn_f64 โœ…
v0.6.3
issue18_while_loop Test while_loop from issue 18 while_loop_fn โœ… v0.6.3
select_test select_test: Demonstrates jnp.select with a dynamic condition based on an input array. select_test_all_options โž–
select_test_scalar_select_option_0 โž–
select_test_scalar_select_option_1 โž–
select_test_scalar_select_option_2 โž–
select_test_default_case โž–
v0.6.1
sort_test sort_test: Demonstrates jnp.sort on slices of an input array. sort_test_basic โž– v0.6.1
cond_scatter_add_mul Tests scatter_add/mul inside jnp.where branches cond_scatter_add_mul_f64_a โœ…
cond_scatter_add_mul_f64_b โœ…
v0.6.4
cond_scatter_repro Reproduces a bug where lax.cond subgraphs do not inherit parent initializers. cond_scatter_repro_f64 โœ… v0.6.4
remat2 Tests a simple case of jax.checkpoint (also known as jax.remat). checkpoint_scalar_f32 โœ…
checkpoint_scalar_f32_f64 โœ…
v0.6.5
scatter_window Window-scatter (Hร—W patch) with implicit batch (depth-3 path). Exercises GatherScatterMode.FILL_OR_DROP and double precision. Regression of a prior conversion failure. scatter_window_update_f64_example โœ… v0.7.4
AutoEncoder A simple autoencoder example. simple_autoencoder โœ…
simple_autoencoder_f64 โœ…
v0.2.0
CNN A simple convolutional neural network (CNN). simple_cnn_explicit_dimensions โœ…
simple_cnn_dynamic โœ…
simple_cnn โœ…
v0.1.0
ForiLoop fori_loop example fori_loop_counter โœ…
fori_loop_counter_f64 โœ…
v0.5.1
GRUCell Vanilla gated-recurrent-unit cell from Flax/nnx. There is no 1-to-1 ONNX operator, so the converter decomposes it into MatMul, Add, Sigmoid, Tanh, etc. gru_cell_basic โœ… v0.7.2
MLP A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. simple_mlp_dynamic โœ…
simple_mlp_dynamic_f64 โœ…
simple_mlp โœ…
simple_mlp_f64 โœ…
simple_mlp_with_call_params_dynamic โœ…
simple_mlp_with_call_params_dynamic_f64 โœ…
simple_mlp_with_call_params โœ…
simple_mlp_with_call_params_f64 โœ…
v0.1.0
MultiHeadAttention A multi-head attention module implemented by Flax/nnx with no ONNX counterpart at the same granularity. multihead_attention_nn_dynamic โœ…
multihead_attention_nn โœ…
multihead_attention_nnx_dynamic โœ…
multihead_attention_nnx โœ…
multihead_attention_2_nnx_dynamic โœ…
multihead_attention_2_nnx โœ…
v0.2.0
SequentialReLU Two ReLU activations chained with nnx.Sequential (no parameters). sequential_double_relu โœ…
sequential_double_relu_f64 โœ…
v0.7.1
SequentialWithResidual Tests nnx.Sequential nested inside a module with a residual connection. sequential_nested_with_residual โœ… v0.7.1
TransformerDecoderWithSequential A single-layer Transformer decoder built with nnx primitives (MHA, LayerNorm, Feed-Forward, Dropout). tiny_decoder_with_sequential_dynamic โœ…
tiny_decoder_with_sequential โœ…
tiny_decoder_with_sequential_and_full_dynamic_shapes_dynamic โœ…
v0.7.1
TransformerDecoderWithoutSequential A single-layer Transformer decoder built with nnx primitives (MHA, LayerNorm, Feed-Forward, Dropout). tiny_decoder_without_sequential_dynamic โœ…
tiny_decoder_without_sequential โœ…
v0.7.1
onnx_functions_000 one function on an outer layer. 000_one_function_on_outer_layer_dynamic โœ…
000_one_function_on_outer_layer โœ…
v0.4.0
onnx_functions_001 one function on an inner layer. 001_one_function_inner_dynamic โœ…
001_one_function_inner โœ…
v0.4.0
onnx_functions_002 two nested functions. 002_two_nested_functions_dynamic โœ…
002_two_nested_functions โœ…
v0.4.0
onnx_functions_003 two nested functions. 003_two_simple_nested_functions_dynamic โœ…
003_two_simple_nested_functions โœ…
v0.4.0
onnx_functions_004 nested function plus component 004_nested_function_plus_component_dynamic โœ…
004_nested_function_plus_component โœ…
v0.4.0
onnx_functions_005 nested function plus more components 005_nested_function_plus_component_dynamic โœ…
005_nested_function_plus_component โœ…
v0.4.0
onnx_functions_006 one function on an outer layer. 006_one_function_outer_dynamic โœ…
006_one_function_outer โœ…
v0.4.0
onnx_functions_007 transformer block with nested mlp block with call parameter 007_transformer_block_dynamic โœ…
007_transformer_block โœ…
v0.4.0
onnx_functions_008 transformer block with nested mlp block no call parameter 008_transformer_block_dynamic โœ…
008_transformer_block โœ…
v0.4.0
onnx_functions_009 transformer block using decorator on class and function 009_transformer_block_dynamic โœ…
009_transformer_block โœ…
v0.4.0
onnx_functions_010 transformer stack 010_transformer_stack_dynamic โœ…
010_transformer_stack โœ…
v0.4.0
onnx_functions_012 Vision Transformer (ViT) 012_vit_conv_embedding_dynamic โœ…
012_vit_conv_embedding โœ…
v0.4.0
onnx_functions_013 Vision Transformer (ViT) 013_vit_conv_embedding_with_call_params_dynamic โœ…
013_vit_conv_embedding_with_call_params โœ…
013_vit_conv_embedding_with_internal_call_params_dynamic โœ…
013_vit_conv_embedding_with_internal_call_params โœ…
v0.4.0
onnx_functions_014 one function on an outer layer. 014_one_function_with_input_param_with_default_value โœ…
014_one_function_without_input_param_with_default_value_dynamic โœ…
014_one_function_without_input_param_with_default_value โœ…
v0.4.0
onnx_functions_015 one function on an outer layer. 015_one_function_with_input_param_without_default_value_dynamic โœ…
015_one_function_with_input_param_without_default_value โœ…
v0.4.0
onnx_functions_016 nested function plus more components 016_internal_function_with_input_param_with_default_value_dynamic โœ…
016_internal_function_with_input_param_with_default_value โœ…
v0.4.0
ClassificationHead Classification head for Vision Transformer classification_head_dynamic โœ…
classification_head โœ…
v0.4.0
ClassificationHeadFlatten Classification head for Vision Transformer classification_head_flat_dynamic โœ…
classification_head_flat โœ…
v0.4.0
ConcatClsToken Concatenate CLS token to the input embedding concat_cls_token_dynamic โœ…
concat_cls_token โœ…
v0.4.0
ConcatClsTokenFlatten Concatenate CLS token to the input embedding concat_cls_token_flat_dynamic โœ…
concat_cls_token_flat โœ…
v0.4.0
ConvEmbedding Convolutional Token Embedding for MNIST with hierarchical downsampling. mnist_conv_embedding_dynamic โœ…
mnist_conv_embedding โœ…
v0.1.0
ConvEmbeddingFlatten Convolutional Token Embedding for MNIST with hierarchical downsampling. mnist_conv_embedding_flat_dynamic โœ…
mnist_conv_embedding_flat โœ…
v0.1.0
FeedForward MLP in Transformer feed_forward_dynamic โœ…
feed_forward โœ…
v0.1.0
FeedForwardFlatten MLP in Transformer feed_forward_flat_dynamic โœ…
feed_forward_flat โœ…
v0.1.0
GetToken Get the CLS token from the input embedding get_token_dynamic โœ…
get_token โœ…
v0.4.0
GetTokenFlatten Get the CLS token from the input embedding get_token_flat_dynamic โœ…
get_token_flat โœ…
v0.4.0
PatchEmbedding Cutting the image into patches and linearly embedding them. patch_embedding_dynamic โœ…
patch_embedding โœ…
v0.1.0
PatchEmbeddingFlatten Cutting the image into patches and linearly embedding them. patch_embedding_flat_dynamic โœ…
patch_embedding_flat โœ…
v0.1.0
PositionalEmbedding Add positional embedding to the input embedding positional_embedding_dynamic โœ…
positional_embedding โœ…
v0.4.0
PositionalEmbeddingFlatten Add positional embedding to the input embedding positional_embedding_flat_dynamic โœ…
positional_embedding_flat โœ…
v0.4.0
TransformerBlock Transformer from 'Attention Is All You Need.' transformer_block_dynamic โœ…
transformer_block โœ…
v0.1.0
TransformerBlockFlatten Transformer from 'Attention Is All You Need.' transformer_block_flat_dynamic โœ…
transformer_block_flat โœ…
v0.1.0
TransformerStack Stack of Transformer blocks transformer_stack_dynamic โœ…
transformer_stack โœ…
v0.1.0
TransformerStackFlatten Stack of Transformer blocks transformer_stack_flat_dynamic โœ…
transformer_stack_flat โœ…
v0.1.0
VisionTransformer A Vision Transformer (ViT) model for MNIST with configurable embedding type. vit_conv_embedding_dynamic โœ…
vit_conv_embedding โœ…
vit_patch_embedding โœ…
v0.2.0
VisionTransformerFlatten A Vision Transformer (ViT) model for MNIST with configurable embedding type. vit_conv_embedding_flat_dynamic โœ…
vit_conv_embedding_flat โœ…
vit_patch_embedding_flat_dynamic โœ…
vit_patch_embedding_flat โœ…
v0.2.0

๐Ÿ“Œ Dependencies

Versions of Major Dependencies:

Library Versions
JAX 0.6.2
Flax 0.11.1
onnx 1.18.0
onnxruntime 1.22.1

Note: For more details, check pyproject.toml.


โš ๏ธ Limitations

  • Not all JAX/Flax components are supported yet (you can easily help expand this coverage!).
  • Function references need dynamic resolution at call-time.
  • ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.

๐Ÿค How to Contribute

We warmly welcome contributions!

How you can help:

  • Add a plugin: Extend jax2onnx by writing a small, self-contained Python file in jax2onnx/plugins that adds support for a custom primitive or contributes an example.
  • Bug fixes & improvements: PRs and issues are always welcome.

๐Ÿ’พ Installation

Install from PyPI:

pip install jax2onnx  

๐Ÿ“œ License

This project is licensed under the Apache License, Version 2.0. See LICENSE for details.


๐ŸŒŸ Special Thanks

Special thanks for example contributions to @burakssen, @Cadynum, @clementpoiret and @PVirie

Special thanks for plugin contributions to @burakssen, @clementpoiret and @Clouder0

Special thanks to tumaer/JAXFLUIDS for contributing valuable insights rooted in physics simulation use cases.

Special thanks to @lutzroeder for making shapes internal to ONNX functions visible in his great Netron viewer.


A huge thanks especially to @limarta, whose elegant jaxpr-to-ONNX demonstration significantly inspired this project.


Happy converting! ๐ŸŽ‰

Project details


Download files


Source Distribution

jax2onnx-0.8.0.tar.gz (372.1 kB)

Uploaded Source

Built Distribution


jax2onnx-0.8.0-py3-none-any.whl (542.2 kB)

Uploaded Python 3

File details

Details for the file jax2onnx-0.8.0.tar.gz.

File metadata

  • Download URL: jax2onnx-0.8.0.tar.gz
  • Upload date:
  • Size: 372.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.5 Linux/6.6.87.2-microsoft-standard-WSL2

File hashes

Hashes for jax2onnx-0.8.0.tar.gz
Algorithm Hash digest
SHA256 cb7959c573f23a5ec3b3357aba5b1bad165f7de5feb69146de52cac7029f2d16
MD5 315267c0a0ff3e16392f6cbe662db1c3
BLAKE2b-256 b50d787556fc8c77f349d5f00d17171c2657bbd00bcab747ac26f46b7a674307


File details

Details for the file jax2onnx-0.8.0-py3-none-any.whl.

File metadata

  • Download URL: jax2onnx-0.8.0-py3-none-any.whl
  • Upload date:
  • Size: 542.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.5 Linux/6.6.87.2-microsoft-standard-WSL2

File hashes

Hashes for jax2onnx-0.8.0-py3-none-any.whl
Algorithm Hash digest
SHA256 b9b8a4a11032625e72d96d664b112b779f54ed22a8453ff79db682b2495fc4cd
MD5 459a84a3d05787ec530525724e999399
BLAKE2b-256 47e0558316216e624add02e1f213851d301f2a98d3ec89ffa1342e9c49736179

