onnx2fx
Yet another ONNX to PyTorch FX converter.
⚠️ Note: This project is under active development. The public API may change at any time.
onnx2fx converts ONNX models into PyTorch FX GraphModules, enabling seamless integration with PyTorch's ecosystem for optimization, analysis, and deployment.
Features
- Simple API: Convert ONNX models with a single function call
- Extensive Operator Support: Wide ONNX operator coverage including standard and Microsoft domain operators
- Multi-Opset Version Support: Automatic selection of version-specific operator handlers based on model opset
- Custom Operator Registration: Easily extend support for unsupported or custom ONNX operators
- PyTorch FX Output: Get a torch.fx.GraphModule for easy inspection, optimization, and compilation
- Dynamic Shape Support: Handle models with dynamic input dimensions
- Quantization Support: Support for quantized operators (QLinear*, DequantizeLinear, etc.)
- Training Support: Convert models to trainable modules with the make_trainable() utility
Tested Models
The following models have been tested and verified to work with onnx2fx:
- PaddleOCRv5: Text detection and recognition models (mobile and server variants)
- PP-OCRv5_mobile_det, PP-OCRv5_mobile_rec
- PP-OCRv5_server_det, PP-OCRv5_server_rec
- TorchVision Models: ResNet, VGG, MobileNet, etc. (via ONNX export)
- LFM2: Liquid Foundation Model (LFM2-350M-ENJP-MT)
- LFM2.5: Liquid Foundation Model 2.5
- TinyLlama: TinyLlama-1.1B-Chat
Installation
Requirements
- Python >= 3.11
- PyTorch >= 2.9.0
- ONNX >= 1.19.1
From PyPI
pip install onnx2fx
# or
uv pip install onnx2fx
From Source
git clone https://github.com/mshr-h/onnx2fx.git
cd onnx2fx
uv sync
Development Installation
git clone https://github.com/mshr-h/onnx2fx.git
cd onnx2fx
uv sync --dev
Quick Start
Basic Conversion
import torch
import onnx
from onnx2fx import convert
# Load from file path
fx_module = convert("model.onnx")
# Or from onnx.ModelProto
onnx_model = onnx.load("model.onnx")
fx_module = convert(onnx_model)
# For models with external data, you can pass base_dir.
# memmap_external_data avoids loading external data into memory.
fx_module = convert("model.onnx", base_dir="/path/to/model_dir", memmap_external_data=True)
# Run inference
input_tensor = torch.randn(1, 3, 224, 224)
output = fx_module(input_tensor)
Inspecting the Converted Graph
from onnx2fx import convert
fx_module = convert("model.onnx")
# Print the FX graph
print(fx_module.graph)
# Get the graph code
print(fx_module.code)
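Because convert() returns a standard torch.fx.GraphModule, the usual FX tooling applies to it. The sketch below uses a small stand-in module traced with torch.fx.symbolic_trace, so it runs without an ONNX file. TinyNet and count_call_functions are hypothetical names. The sketch counts the call_function targets in the graph, and the same loop should work on a module returned by convert().

```python
import collections

import torch
import torch.fx


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for a converted model; not produced by onnx2fx."""

    def forward(self, x):
        return torch.relu(x) + torch.sigmoid(x)


def count_call_functions(gm: torch.fx.GraphModule) -> collections.Counter:
    """Count how often each function target appears in the graph."""
    counts = collections.Counter()
    for node in gm.graph.nodes:
        # node.op is one of: placeholder, get_attr, call_function,
        # call_method, call_module, output
        if node.op == "call_function":
            counts[getattr(node.target, "__name__", str(node.target))] += 1
    return counts


gm = torch.fx.symbolic_trace(TinyNet())
print(count_call_functions(gm))
```

Walking graph.nodes like this is a quick way to see which PyTorch functions the converted ONNX operators were lowered to.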
Registering Custom Operators
For unsupported or custom ONNX operators, you can register your own handlers:
import torch
from onnx2fx import convert, register_op
# Using decorator
@register_op("MyCustomOp")
def my_custom_op(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.sigmoid, args=(x,))
# Or register directly
def my_handler(builder, node):
    x = builder.get_value(node.input[0])
    return builder.call_function(torch.tanh, args=(x,))
register_op("TanhCustom", my_handler)
# For custom domains (e.g., Microsoft operators)
@register_op("BiasGelu", domain="com.microsoft")
def bias_gelu(builder, node):
    x = builder.get_value(node.input[0])
    bias = builder.get_value(node.input[1])
    # Emit two named FX nodes (add, then gelu) instead of an anonymous
    # lambda target, so each op appears by name in the generated graph code
    added = builder.call_function(torch.add, args=(x, bias))
    return builder.call_function(torch.nn.functional.gelu, args=(added,))
Note: ai.onnx.ml is treated as a distinct domain. If you register or query operators in that domain, pass domain="ai.onnx.ml" explicitly.
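A registry can support both registration styles above (decorator and direct call) by checking whether a handler was passed. The following is a minimal, self-contained sketch of such a registry keyed by (domain, op_type). MiniRegistry and its methods are hypothetical and do not reflect onnx2fx's internals. The sketch also shows why ai.onnx.ml needs an explicit domain: a lookup with the default domain "" never matches it.

```python
from typing import Callable, Dict, Optional, Tuple

Handler = Callable[..., object]


class MiniRegistry:
    """Hypothetical registry sketch; not onnx2fx's actual implementation."""

    def __init__(self) -> None:
        self._handlers: Dict[Tuple[str, str], Handler] = {}

    def register(self, op_type: str, handler: Optional[Handler] = None, domain: str = ""):
        # Decorator form: register("Op") returns a function that receives the handler.
        if handler is None:
            def decorator(fn: Handler) -> Handler:
                self._handlers[(domain, op_type)] = fn
                return fn
            return decorator
        # Direct form: register("Op", fn) stores the handler immediately.
        self._handlers[(domain, op_type)] = handler
        return handler

    def is_supported(self, op_type: str, domain: str = "") -> bool:
        return (domain, op_type) in self._handlers


registry = MiniRegistry()


@registry.register("MyCustomOp")
def my_custom_op(builder, node):
    return "sigmoid"


registry.register("TanhCustom", lambda builder, node: "tanh")
# LabelEncoder is an ai.onnx.ml operator; used here only as a registry key.
registry.register("LabelEncoder", lambda builder, node: "encode", domain="ai.onnx.ml")

print(registry.is_supported("MyCustomOp"))                         # True
print(registry.is_supported("LabelEncoder"))                       # False: default domain is ""
print(registry.is_supported("LabelEncoder", domain="ai.onnx.ml"))  # True
```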
Multi-Opset Version Support
The library automatically selects the appropriate operator handler based on the model's opset version. For operators with version-specific behavior (e.g., Softmax changed default axis in opset 13), the correct implementation is used automatically:
from onnx2fx import convert
# Models with different opset versions are handled automatically
fx_module_v11 = convert("model_opset11.onnx") # Uses opset 11 semantics
fx_module_v17 = convert("model_opset17.onnx") # Uses opset 17 semantics
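register_op's since_version parameter (the minimum opset version for a handler) suggests a natural selection rule: use the handler with the largest since_version that does not exceed the model's opset. The sketch below assumes that rule. The handler table and select_handler are hypothetical, and onnx2fx's actual lookup may differ. Softmax does have ONNX versions 1, 11, and 13, and version 13 is the one that changed the default axis.

```python
import bisect
from typing import Callable, Dict, List

# Hypothetical table: since_version -> handler for a single op type.
SOFTMAX_HANDLERS: Dict[int, Callable[[], str]] = {
    1: lambda: "Softmax-1 semantics",
    11: lambda: "Softmax-11 semantics",
    13: lambda: "Softmax-13 semantics",  # default axis changed here
}


def select_handler(handlers: Dict[int, Callable[[], str]], model_opset: int) -> Callable[[], str]:
    """Return the handler with the largest since_version <= model_opset."""
    versions: List[int] = sorted(handlers)
    idx = bisect.bisect_right(versions, model_opset) - 1
    if idx < 0:
        raise KeyError(f"no handler registered for opset {model_opset}")
    return handlers[versions[idx]]


print(select_handler(SOFTMAX_HANDLERS, 11)())  # Softmax-11 semantics
print(select_handler(SOFTMAX_HANDLERS, 17)())  # Softmax-13 semantics
```

With this rule, an opset 17 model falls back to the newest handler at or below 17, which is why a single handler can cover many opset versions.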
Training Converted Models
By default, ONNX weights are loaded as non-trainable buffers. Use make_trainable() to enable training:
import torch
from onnx2fx import convert, make_trainable
# Convert and make trainable
fx_module = convert("model.onnx")
fx_module = make_trainable(fx_module) # Convert buffers to trainable parameters
# Now you can train the model
optimizer = torch.optim.Adam(fx_module.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
# dataloader is your own iterable of (inputs, targets) batches
for inputs, targets in dataloader:
    optimizer.zero_grad()
    outputs = fx_module(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
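Plain torch.nn APIs can illustrate what a buffer-to-parameter conversion involves. The sketch below is an assumption, not make_trainable()'s actual implementation. It re-registers every floating-point buffer of a module as an nn.Parameter so that the optimizer can see it. Integer buffers are skipped because they cannot require gradients. FrozenLinear and buffers_to_parameters are hypothetical names.

```python
import torch
from torch import nn


class FrozenLinear(nn.Module):
    """Hypothetical stand-in whose weights are held as buffers."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("weight", torch.randn(2, 3))
        self.register_buffer("bias", torch.zeros(2))
        self.register_buffer("step", torch.tensor(0))  # integer buffer: left as-is

    def forward(self, x):
        return x @ self.weight.T + self.bias


def buffers_to_parameters(module: nn.Module) -> nn.Module:
    """Hypothetical helper: turn floating-point buffers into Parameters in place."""
    for _, submodule in list(module.named_modules()):
        for name, buf in list(submodule.named_buffers(recurse=False)):
            # Integer/bool tensors cannot require gradients, so keep them as buffers.
            if not torch.is_floating_point(buf):
                continue
            delattr(submodule, name)  # drop the buffer registration
            submodule.register_parameter(name, nn.Parameter(buf.detach().clone()))
    return module


model = buffers_to_parameters(FrozenLinear())
print(sorted(n for n, _ in model.named_parameters()))  # ['bias', 'weight']
```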
Querying Supported Operators
from onnx2fx import (
    get_supported_ops,
    get_all_supported_ops,
    get_registered_domains,
    is_supported,
)
# Check if an operator is supported
print(is_supported("Conv")) # True
print(is_supported("BiasGelu", domain="com.microsoft")) # True
# Get all operators for a domain
standard_ops = get_supported_ops() # Default ONNX domain
microsoft_ops = get_supported_ops("com.microsoft")
# Get all operators across all domains
all_ops = get_all_supported_ops()
# Get registered domains
domains = get_registered_domains()  # e.g. ['', 'com.microsoft', ...]
Analyzing Model Compatibility
Before converting, you can analyze a model to check operator support:
from onnx2fx import analyze_model
# Analyze an ONNX model
result = analyze_model("model.onnx")
# Check results
print(f"Supported operators: {result.supported_ops}")
print(f"Unsupported operators: {result.unsupported_ops}")
print(f"Is fully supported: {result.is_fully_supported()}")
# Get detailed summary
print(result.summary())
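The bookkeeping behind an AnalysisResult can be sketched with the standard library alone. In the sketch below, the node list and the supported-operator set are hypothetical stand-ins. In practice, they would come from the ONNX graph's nodes and from onnx2fx's operator registry. MiniAnalysis mirrors a few of the AnalysisResult attributes documented in the API reference, in simplified form: its unsupported entries omit the opset version that the real attribute carries. The computation is an assumption, not onnx2fx's code.

```python
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, List, Set, Tuple

OpKey = Tuple[str, str]  # (op_type, domain)


@dataclass
class MiniAnalysis:
    """Hypothetical, simplified mirror of a few AnalysisResult fields."""

    total_nodes: int
    op_counts: Dict[OpKey, int]
    supported_ops: List[OpKey] = field(default_factory=list)
    unsupported_ops: List[OpKey] = field(default_factory=list)

    def is_fully_supported(self) -> bool:
        return not self.unsupported_ops


def analyze(nodes: List[OpKey], supported: Set[OpKey]) -> MiniAnalysis:
    counts = Counter(nodes)
    result = MiniAnalysis(total_nodes=len(nodes), op_counts=dict(counts))
    for key in sorted(counts):  # sorted for deterministic output
        if key in supported:
            result.supported_ops.append(key)
        else:
            result.unsupported_ops.append(key)
    return result


# Hypothetical model: (op_type, domain) for each node, in graph order.
nodes = [("Conv", ""), ("Relu", ""), ("Conv", ""), ("MyCustomOp", "com.example")]
supported = {("Conv", ""), ("Relu", "")}

result = analyze(nodes, supported)
print(result.total_nodes)           # 4
print(result.unsupported_ops)       # [('MyCustomOp', 'com.example')]
print(result.is_fully_supported())  # False
```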
Exception Handling
Handle conversion errors gracefully:
from onnx2fx import (
    convert,
    Onnx2FxError,
    UnsupportedOpError,
    ConversionError,
)
try:
    fx_module = convert("model.onnx")
except UnsupportedOpError as e:
    print(f"Unsupported operator: {e}")
except ConversionError as e:
    print(f"Conversion failed: {e}")
except Onnx2FxError as e:
    print(f"onnx2fx error: {e}")
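Onnx2FxError is documented as the base exception for all onnx2fx errors, so its except clause must come after the more specific ones: Python runs only the first clause that matches. The sketch below demonstrates this with local stand-in classes of the same names, assumed to subclass Onnx2FxError as the API reference implies, so it runs without onnx2fx installed. The two handler functions are hypothetical.

```python
class Onnx2FxError(Exception):
    """Local stand-in for onnx2fx's base exception."""


class UnsupportedOpError(Onnx2FxError):
    """Local stand-in; assumed to subclass the base error."""


class ConversionError(Onnx2FxError):
    """Local stand-in; assumed to subclass the base error."""


def handle_specific_first(exc: Exception) -> str:
    try:
        raise exc
    except UnsupportedOpError:
        return "unsupported"
    except ConversionError:
        return "conversion"
    except Onnx2FxError:
        return "generic"


def handle_base_first(exc: Exception) -> str:
    try:
        raise exc
    except Onnx2FxError:  # matches every subclass, so the clause below never runs
        return "generic"
    except UnsupportedOpError:
        return "unsupported"


print(handle_specific_first(UnsupportedOpError("Foo")))  # unsupported
print(handle_base_first(UnsupportedOpError("Foo")))      # generic
```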
Supported Operators
Standard ONNX Domain
This is a short list of representative operators. For the full list, call get_supported_ops() or get_all_supported_ops().
- Core tensor & shape: Reshape, Transpose, Concat, Split, Slice, Gather, Pad, Resize, Shape, Cast
- Math & activations: Add, Mul, MatMul, Gemm, Relu, Gelu, SiLU, Softmax, LogSoftmax
- Normalization & pooling: BatchNormalization, LayerNormalization, InstanceNormalization, GroupNormalization, MaxPool, AveragePool, GlobalAveragePool
- Reductions & indexing: ReduceSum, ReduceMean, ArgMax, ArgMin, TopK
- Control flow & sequence: If, Loop, SequenceConstruct, SplitToSequence, ConcatFromSequence
- Quantization: QuantizeLinear, DequantizeLinear, QLinearConv, QLinearMatMul
- Other: Einsum, NonMaxSuppression, StringNormalizer
Attention & Normalization Extensions
- Attention (opset 24+)
- RotaryEmbedding (opset 23+)
- GroupQueryAttention
- EmbedLayerNormalization
- SkipLayerNormalization
- SimplifiedLayerNormalization
- SkipSimplifiedLayerNormalization
Microsoft Domain (com.microsoft)
Note: Some operators are available in both the standard and Microsoft domains (e.g., Attention, RotaryEmbedding, SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization, GroupQueryAttention, SkipLayerNormalization, EmbedLayerNormalization).
- Attention
- RotaryEmbedding
- SimplifiedLayerNormalization, SkipSimplifiedLayerNormalization
- SkipLayerNormalization, EmbedLayerNormalization
- GroupQueryAttention
API Reference
convert(model)
Converts an ONNX model to a PyTorch FX GraphModule.
Parameters:
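- model (Union[onnx.ModelProto, str]): Either an in-memory onnx.ModelProto or a file path to an ONNX model. The Quick Start examples also pass base_dir and memmap_external_data as keyword arguments for models with external data.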
Returns:
torch.fx.GraphModule: A PyTorch FX Graph module.
register_op(op_type, handler=None, domain="", since_version=1)
Register a custom ONNX operator handler.
Parameters:
- op_type (str): The ONNX operator type name.
- handler (OpHandler, optional): The handler function. If not provided, returns a decorator.
- domain (str, optional): The ONNX domain. Default is "" (standard ONNX domain).
- since_version (int, optional): The minimum opset version for this handler. Default is 1.
unregister_op(op_type, domain="", since_version=None)
Unregister an operator handler.
Parameters:
- op_type (str): The ONNX operator type name.
- domain (str, optional): The ONNX domain.
- since_version (int, optional): The specific opset handler to remove. If None, removes all versions.
Returns:
bool: True if the operator was unregistered.
is_supported(op_type, domain="")
Check if an operator is supported.
get_supported_ops(domain="")
Get list of supported ONNX operators for a domain.
get_all_supported_ops()
Get all supported operators across all domains.
get_registered_domains()
Get list of registered domains.
analyze_model(model)
Analyze an ONNX model for operator support.
Parameters:
- model (Union[onnx.ModelProto, str]): Either an in-memory onnx.ModelProto or a file path.
Returns:
AnalysisResult: Analysis results with supported/unsupported operators.
AnalysisResult
Dataclass containing model analysis results.
Attributes:
- total_nodes (int): Total number of nodes in the model graph.
- unique_ops (Set[Tuple[str, str]]): Set of unique (op_type, domain) tuples.
- supported_ops (List[Tuple[str, str]]): List of supported (op_type, domain) tuples.
- unsupported_ops (List[Tuple[str, str, int]]): List of unsupported (op_type, domain, opset_version) tuples.
- opset_versions (Dict[str, int]): Mapping of domain to opset version.
- op_counts (Dict[Tuple[str, str], int]): Count of each (op_type, domain) in the model.
Methods:
- is_fully_supported(): Returns True if all operators are supported.
- summary(): Returns a human-readable summary string.
Exceptions
- Onnx2FxError: Base exception for all onnx2fx errors.
- UnsupportedOpError: Raised when an operator is not supported.
- ConversionError: Raised when conversion fails.
- ValueNotFoundError: Raised when a value is not found in the environment.
Development
Running Tests
# Run all tests
pytest
# Run all tests in parallel for faster execution
pytest -n auto
# Run specific test file
pytest tests/test_activation.py
# Skip slow tests
pytest -m "not slow"
Code Formatting
# Format code with ruff
ruff format .
# Check linting
ruff check .
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
Author
Masahiro Hiramori (contact@mshr-h.com)