
onnx2fx

Yet another ONNX to PyTorch FX converter.

⚠️ Note: This project is under active development. The public API may change at any time.

onnx2fx converts ONNX models into PyTorch FX GraphModules, so they can be inspected, transformed, compiled, and deployed with standard PyTorch tooling.

Features

  • Simple API: Convert ONNX models with a single function call
  • Extensive Operator Support: Wide ONNX operator coverage including standard and Microsoft domain operators
  • Multi-Opset Version Support: Automatic selection of version-specific operator handlers based on model opset
  • Custom Operator Registration: Easily extend support for unsupported or custom ONNX operators
  • PyTorch FX Output: Get a torch.fx.GraphModule for easy inspection, optimization, and compilation
  • Dynamic Shape Support: Handle models with dynamic input dimensions
  • Quantization Support: Handles quantized operators (QLinear*, QuantizeLinear, DequantizeLinear, etc.)
  • Training Support: Convert models to trainable modules with make_trainable() utility

Tested Models

The following models have been tested and verified to work with onnx2fx:

  • PaddleOCRv5: Text detection and recognition models (mobile and server variants)
    • PP-OCRv5_mobile_det, PP-OCRv5_mobile_rec
    • PP-OCRv5_server_det, PP-OCRv5_server_rec
  • TorchVision Models: ResNet, VGG, MobileNet, etc. (via ONNX export)
  • LFM2: Liquid Foundation Model (LFM2-350M-ENJP-MT)
  • LFM2.5: Liquid Foundation Model 2.5
  • TinyLlama: TinyLlama-1.1B-Chat

Installation

Requirements

  • Python >= 3.11
  • PyTorch >= 2.9.0
  • ONNX >= 1.19.1

From Source

git clone https://github.com/mshr-h/onnx2fx.git
cd onnx2fx
uv sync

Development Installation

git clone https://github.com/mshr-h/onnx2fx.git
cd onnx2fx
uv sync --dev

Quick Start

Basic Conversion

import torch
import onnx
from onnx2fx import convert

# Load from file path
fx_module = convert("model.onnx")

# Or from onnx.ModelProto
onnx_model = onnx.load("model.onnx")
fx_module = convert(onnx_model)

# For models with external data, you can pass base_dir.
# memmap_external_data avoids loading external data into memory.
fx_module = convert("model.onnx", base_dir="/path/to/model_dir", memmap_external_data=True)

# Run inference
input_tensor = torch.randn(1, 3, 224, 224)
output = fx_module(input_tensor)

Inspecting the Converted Graph

from onnx2fx import convert

fx_module = convert("model.onnx")

# Print the FX graph
print(fx_module.graph)

# Get the graph code
print(fx_module.code)
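Printing is only the start: the graph can also be walked and edited node by node with the standard torch.fx API. The sketch below uses torch.fx.symbolic_trace on a toy module as a stand-in for a converted model. This is an assumption made so the example runs without an ONNX file. Any torch.fx.GraphModule, including one returned by convert(), exposes the same graph, node, and recompile() interface.

```python
from collections import Counter

import torch
from torch import fx, nn


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x)) + 1


# Stand-in for convert("model.onnx"); both produce a torch.fx.GraphModule.
gm = fx.symbolic_trace(Toy())

# Count nodes by kind and collect the functions the graph calls.
op_kinds = Counter(node.op for node in gm.graph.nodes)
call_targets = [node.target for node in gm.graph.nodes if node.op == "call_function"]

# A tiny graph rewrite: swap every torch.relu call for torch.sigmoid,
# then regenerate forward() from the edited graph.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        node.target = torch.sigmoid
gm.recompile()
print(gm.code)
```

Each node's `op` is one of `placeholder`, `get_attr`, `call_function`, `call_module`, `call_method`, or `output`. After you edit nodes, `recompile()` must be called before the changes take effect in `forward()`.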

Registering Custom Operators

For unsupported or custom ONNX operators, you can register your own handlers:

import torch
from onnx2fx import convert, register_op

# Using decorator
@register_op("MyCustomOp")
def my_custom_op(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.sigmoid, args=(x,))

# Or register directly
def my_handler(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.tanh, args=(x,))

register_op("TanhCustom", my_handler)

# For custom domains (e.g., Microsoft operators).
# A named helper reads better in fx_module.code than a lambda.
def _bias_gelu(t, b):
    return torch.nn.functional.gelu(t + b)

@register_op("BiasGelu", domain="com.microsoft")
def bias_gelu(builder, node):
    x = builder.get_value(node.input[0])
    bias = builder.get_value(node.input[1])
    return builder.call_function(_bias_gelu, args=(x, bias))

Note: ai.onnx.ml is treated as a distinct domain. If you register or query operators in that domain, pass domain="ai.onnx.ml" explicitly.

Multi-Opset Version Support

The library automatically selects the appropriate operator handler based on the model's opset version. Some operators changed behavior between opsets. For example, in opset 13 Softmax switched its default axis from 1 to -1 and stopped coercing the input to 2-D. For such operators, the matching implementation is used automatically:

from onnx2fx import convert

# Models with different opset versions are handled automatically
fx_module_v11 = convert("model_opset11.onnx")  # Uses opset 11 semantics
fx_module_v17 = convert("model_opset17.onnx")  # Uses opset 17 semantics
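The sketch below applies version-based dispatch to Softmax to show why the opset matters. The two functions follow the ONNX specification:

- Before opset 13, Softmax coerces the input to 2-D at `axis` (default 1) and normalizes over all trailing dimensions.
- From opset 13, it normalizes along a single axis (default -1).

The handler table and `select_handler` are hypothetical names. The rule "largest since_version not exceeding the model opset" is an assumption and does not describe onnx2fx's internal code.

```python
import torch


def softmax_v1(x: torch.Tensor, axis: int = 1) -> torch.Tensor:
    # Opset 1-12: coerce to 2-D at `axis` and normalize over all trailing dims.
    return x.flatten(start_dim=axis).softmax(dim=-1).reshape(x.shape)


def softmax_v13(x: torch.Tensor, axis: int = -1) -> torch.Tensor:
    # Opset 13+: normalize along a single axis.
    return x.softmax(dim=axis)


# Hypothetical handler table: since_version -> implementation.
SOFTMAX_HANDLERS = {1: softmax_v1, 13: softmax_v13}


def select_handler(handlers: dict, model_opset: int):
    """Pick the handler with the largest since_version <= model_opset."""
    eligible = [v for v in handlers if v <= model_opset]
    if not eligible:
        raise KeyError(f"no handler for opset {model_opset}")
    return handlers[max(eligible)]


x = torch.randn(2, 3, 4)
y11 = select_handler(SOFTMAX_HANDLERS, 11)(x)  # sums to 1 over dims (1, 2)
y17 = select_handler(SOFTMAX_HANDLERS, 17)(x)  # sums to 1 over the last dim
```

The same input and the same default attributes give different results under the two opsets. That is why a converter has to honor the model's declared opset rather than always using the newest semantics.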

Training Converted Models

By default, ONNX weights are loaded as non-trainable buffers. Use make_trainable() to enable training:

import torch
from onnx2fx import convert, make_trainable

# Convert and make trainable
fx_module = convert("model.onnx")
fx_module = make_trainable(fx_module)  # Convert buffers to trainable parameters

# Now you can train the model
optimizer = torch.optim.Adam(fx_module.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

# dataloader: any iterable of (inputs, targets) batches, e.g. a torch.utils.data.DataLoader
for inputs, targets in dataloader:
    optimizer.zero_grad()
    outputs = fx_module(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
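make_trainable() is described as converting buffers into trainable parameters. The sketch below shows one way to do that for an arbitrary torch.nn.Module. It is an assumption about the general technique, not make_trainable()'s actual implementation, and `buffers_to_parameters` is a hypothetical name. Only floating-point buffers are converted, because PyTorch refuses to enable gradients on integer tensors such as shape or index constants.

```python
import torch
from torch import nn


def buffers_to_parameters(module: nn.Module) -> nn.Module:
    """Hypothetical helper: re-register floating-point buffers as nn.Parameter."""
    # Snapshot the names first because the module is mutated inside the loop.
    for name, buf in list(module.named_buffers()):
        if not buf.is_floating_point():
            continue  # leave integer/bool (and complex) buffers untouched
        owner_path, _, attr = name.rpartition(".")
        owner = module.get_submodule(owner_path) if owner_path else module
        delattr(owner, attr)  # drop the buffer registration
        owner.register_parameter(attr, nn.Parameter(buf.detach().clone()))
    return module


# Usage with a toy module standing in for a converted GraphModule.
m = nn.Module()
m.register_buffer("weight", torch.ones(3))
m.register_buffer("indices", torch.tensor([0, 2]))
buffers_to_parameters(m)
```

Re-registering each tensor under its original attribute name keeps existing attribute lookups working. Only the registration kind changes, from buffer to parameter.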

Querying Supported Operators

from onnx2fx import (
    get_supported_ops,
    get_all_supported_ops,
    get_registered_domains,
    is_supported,
)

# Check if an operator is supported
print(is_supported("Conv"))  # True
print(is_supported("BiasGelu", domain="com.microsoft"))  # True

# Get all operators for a domain
standard_ops = get_supported_ops()  # Default ONNX domain
microsoft_ops = get_supported_ops("com.microsoft")

# Get all operators across all domains
all_ops = get_all_supported_ops()

# Get registered domains
domains = get_registered_domains()  # e.g. ['', 'com.microsoft']

Analyzing Model Compatibility

Before converting, you can analyze a model to check operator support:

from onnx2fx import analyze_model

# Analyze an ONNX model
result = analyze_model("model.onnx")

# Check results
print(f"Supported operators: {result.supported_ops}")
print(f"Unsupported operators: {result.unsupported_ops}")
print(f"Is fully supported: {result.is_fully_supported()}")

# Get detailed summary
print(result.summary())
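Most AnalysisResult attributes amount to counting over the graph's (op_type, domain) pairs. The stdlib-only sketch below reproduces that idea with a hypothetical `MiniAnalysis` dataclass and `analyze_nodes()` helper. It is not analyze_model()'s implementation, and the supported set is supplied by hand rather than queried from onnx2fx.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class MiniAnalysis:
    total_nodes: int
    op_counts: dict[tuple[str, str], int]
    unsupported_ops: list[tuple[str, str, int]]

    def is_fully_supported(self) -> bool:
        return not self.unsupported_ops


def analyze_nodes(nodes, supported, opset_versions):
    """nodes: iterable of (op_type, domain); supported: set of (op_type, domain)."""
    counts = Counter(nodes)
    unsupported = sorted(
        # 0 marks a domain with no declared opset.
        (op_type, domain, opset_versions.get(domain, 0))
        for (op_type, domain) in counts
        if (op_type, domain) not in supported
    )
    return MiniAnalysis(sum(counts.values()), dict(counts), unsupported)


nodes = [("Conv", ""), ("Relu", ""), ("Conv", ""), ("BiasGelu", "com.microsoft")]
result = analyze_nodes(nodes, {("Conv", ""), ("Relu", "")}, {"": 17, "com.microsoft": 1})
print(result)
```

For a real onnx.ModelProto, you could collect:

- the pairs with `((n.op_type, n.domain) for n in model.graph.node)`
- the opset map with `{imp.domain: imp.version for imp in model.opset_import}`

This only visits the top-level graph, not subgraphs nested inside If or Loop.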

Exception Handling

Handle conversion errors gracefully. Catch the specific exceptions before Onnx2FxError, which is the base class of all onnx2fx errors:

from onnx2fx import (
    convert,
    Onnx2FxError,
    UnsupportedOpError,
    ConversionError,
)

try:
    fx_module = convert("model.onnx")
except UnsupportedOpError as e:
    print(f"Unsupported operator: {e}")
except ConversionError as e:
    print(f"Conversion failed: {e}")
except Onnx2FxError as e:
    print(f"onnx2fx error: {e}")

Supported Operators

Standard ONNX Domain

This is a short list of representative operators. For the full list, call get_supported_ops() or get_all_supported_ops().

  • Core tensor & shape: Reshape, Transpose, Concat, Split, Slice, Gather, Pad, Resize, Shape, Cast
  • Math & activations: Add, Mul, MatMul, Gemm, Relu, Gelu, SiLU, Softmax, LogSoftmax
  • Normalization & pooling: BatchNormalization, LayerNormalization, InstanceNormalization, GroupNormalization, MaxPool, AveragePool, GlobalAveragePool
  • Reductions & indexing: ReduceSum, ReduceMean, ArgMax, ArgMin, TopK
  • Control flow & sequence: If, Loop, SequenceConstruct, SplitToSequence, ConcatFromSequence
  • Quantization: QuantizeLinear, DequantizeLinear, QLinearConv, QLinearMatMul
  • Other: Einsum, NonMaxSuppression, StringNormalizer

Attention & Normalization Extensions

  • Attention (opset 24+)
  • RotaryEmbedding (opset 23+)
  • GroupQueryAttention
  • EmbedLayerNormalization
  • SkipLayerNormalization
  • SimplifiedLayerNormalization
  • SkipSimplifiedLayerNormalization

Microsoft Domain (com.microsoft)

Note: Some operators are available in both the standard and Microsoft domains (e.g., Attention, RotaryEmbedding, SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention, SkipLayerNormalization, EmbedLayerNormalization).

  • Attention
  • RotaryEmbedding
  • SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization
  • SkipLayerNormalization, EmbedLayerNormalization
  • GroupQueryAttention

API Reference

convert(model)

Converts an ONNX model to a PyTorch FX GraphModule.

Parameters:

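  • model (Union[onnx.ModelProto, str]): Either an in-memory onnx.ModelProto or a file path to an ONNX model.
  • base_dir (str, optional): Directory used to locate external data files for models stored with external data.
  • memmap_external_data (bool, optional): If True, memory-maps external data instead of loading it into memory.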

Returns:

  • torch.fx.GraphModule: The converted model as a PyTorch FX GraphModule.

register_op(op_type, handler=None, domain="", since_version=1)

Register a custom ONNX operator handler.

Parameters:

  • op_type (str): The ONNX operator type name.
  • handler (OpHandler, optional): The handler function. If not provided, returns a decorator.
  • domain (str, optional): The ONNX domain. Default is "" (standard ONNX domain).
  • since_version (int, optional): The minimum opset version for this handler. Default is 1.

unregister_op(op_type, domain="", since_version=None)

Unregister an operator handler.

Parameters:

  • op_type (str): The ONNX operator type name.
  • domain (str, optional): The ONNX domain.
  • since_version (int, optional): The specific opset handler to remove. If None, removes all versions.

Returns:

  • bool: True if the operator was unregistered.
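register_op() works as a decorator when `handler` is omitted. Handlers are keyed by domain, op type, and since_version. The stdlib-only sketch below shows how such an API can behave. `OpRegistry` and its `lookup()` method are hypothetical names for illustration. The lookup rule ("largest since_version not above the model opset") is an assumption, not a description of onnx2fx's internals.

```python
from collections import defaultdict
from typing import Callable, Optional

Handler = Callable[..., object]


class OpRegistry:
    """Hypothetical registry mirroring the documented register_op/unregister_op signatures."""

    def __init__(self) -> None:
        # (domain, op_type) -> {since_version: handler}
        self._table: dict[tuple[str, str], dict[int, Handler]] = defaultdict(dict)

    def register_op(self, op_type: str, handler: Optional[Handler] = None,
                    domain: str = "", since_version: int = 1):
        def _register(fn: Handler) -> Handler:
            self._table[(domain, op_type)][since_version] = fn
            return fn
        # With no handler, return a decorator; otherwise register immediately.
        return _register if handler is None else _register(handler)

    def unregister_op(self, op_type: str, domain: str = "",
                      since_version: Optional[int] = None) -> bool:
        key = (domain, op_type)
        versions = self._table.get(key)
        if not versions:
            return False
        if since_version is None:  # remove every version
            del self._table[key]
            return True
        removed = versions.pop(since_version, None) is not None
        if not versions:
            del self._table[key]
        return removed

    def lookup(self, op_type: str, domain: str = "", opset: int = 1) -> Optional[Handler]:
        # Assumed rule: largest since_version that does not exceed the model opset.
        versions = self._table.get((domain, op_type), {})
        eligible = [v for v in versions if v <= opset]
        return versions[max(eligible)] if eligible else None


registry = OpRegistry()


@registry.register_op("MyOp")  # decorator form
def my_op_v1(builder, node):
    return "v1"


registry.register_op("MyOp", lambda builder, node: "v13", since_version=13)  # direct form
```

Because the key includes the domain, a handler registered in the default `""` domain is not found by a lookup in another domain such as `"ai.onnx.ml"`.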

is_supported(op_type, domain="")

Return True if op_type is supported in the given domain ("" is the standard ONNX domain).

get_supported_ops(domain="")

Return the list of supported ONNX operators for a domain.

get_all_supported_ops()

Get all supported operators across all domains.

get_registered_domains()

Return the list of registered domains (e.g., ['', 'com.microsoft']).

analyze_model(model)

Analyze an ONNX model for operator support.

Parameters:

  • model (Union[onnx.ModelProto, str]): Either an in-memory onnx.ModelProto or a file path.

Returns:

  • AnalysisResult: Analysis results with supported/unsupported operators.

AnalysisResult

Dataclass containing model analysis results.

Attributes:

  • total_nodes (int): Total number of nodes in the model graph.
  • unique_ops (Set[Tuple[str, str]]): Set of unique (op_type, domain) tuples.
  • supported_ops (List[Tuple[str, str]]): List of supported (op_type, domain) tuples.
  • unsupported_ops (List[Tuple[str, str, int]]): List of unsupported (op_type, domain, opset_version) tuples.
  • opset_versions (Dict[str, int]): Mapping of domain to opset version.
  • op_counts (Dict[Tuple[str, str], int]): Count of each (op_type, domain) in the model.

Methods:

  • is_fully_supported(): Returns True if all operators are supported.
  • summary(): Returns a human-readable summary string.

Exceptions

  • Onnx2FxError: Base exception for all onnx2fx errors.
  • UnsupportedOpError: Raised when an operator is not supported.
  • ConversionError: Raised when conversion fails.
  • ValueNotFoundError: Raised when a value is not found in the environment.
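Onnx2FxError is the base class, so an `except` clause for it also matches every more specific error. It must therefore come last, as in the Exception Handling example. The sketch below mirrors the hierarchy with stand-in classes. It assumes UnsupportedOpError, ConversionError, and ValueNotFoundError each derive directly from Onnx2FxError; the documentation implies this but does not spell it out.

```python
# Stand-in classes; in real code, import them from onnx2fx instead.
class Onnx2FxError(Exception): pass
class UnsupportedOpError(Onnx2FxError): pass
class ConversionError(Onnx2FxError): pass
class ValueNotFoundError(Onnx2FxError): pass


def classify(exc: Exception) -> str:
    # Specific handlers first, base class last.
    try:
        raise exc
    except UnsupportedOpError:
        return "unsupported"
    except ConversionError:
        return "conversion"
    except Onnx2FxError:
        return "other onnx2fx error"


def classify_base_first(exc: Exception) -> str:
    # Anti-pattern: the base-class clause shadows every subclass after it.
    try:
        raise exc
    except Onnx2FxError:
        return "other onnx2fx error"
    except UnsupportedOpError:  # never reached
        return "unsupported"


print(classify(UnsupportedOpError("MyCustomOp")))
print(classify_base_first(UnsupportedOpError("MyCustomOp")))
```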

Development

Running Tests

# Run all tests
pytest

# Run all tests in parallel for faster execution (the -n option comes from pytest-xdist)
pytest -n auto

# Run specific test file
pytest tests/test_activation.py

# Skip slow tests
pytest -m "not slow"

Code Formatting

# Format code with ruff
ruff format .

# Check linting
ruff check .

License

This project is licensed under the Apache License 2.0. See the LICENSE file for details.

Author

Masahiro Hiramori (contact@mshr-h.com)
