# onnx-shape-inference
Experimental symbolic shape inference for ONNX models. Built on top of ONNX IR, this library performs shape inference directly on the IR without serialization overhead, using SymPy for symbolic dimension arithmetic.
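For intuition, here is a minimal plain-SymPy sketch (illustrative only, assuming nothing about the library's internals) of how symbolic dimensions can be carried through shape arithmetic:

```python
import sympy

# A symbolic batch dimension and two concrete dimensions
N = sympy.Symbol("N", integer=True, positive=True)
shape = [N, 3, 768]

# Flattening the trailing dims multiplies them, keeping N symbolic
flat = [shape[0], shape[1] * shape[2]]
print(flat)  # [N, 2304]

# Once the batch size is known, substitution yields a concrete shape
concrete = [d.subs(N, 8) if isinstance(d, sympy.Expr) else d for d in flat]
print(concrete)  # [8, 2304]
```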
## Features

- Symbolic shape inference — propagates shapes through the graph using SymPy expressions for symbolic dimensions
- Shape data propagation — tracks known element values of shape tensors (e.g. through Shape → Slice → Concat → Reshape chains) to resolve concrete output shapes that standard shape inference cannot
- Extensible registry — register custom shape inference functions for custom operators
- Merge policies — choose between strict and permissive shape merging strategies
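The difference between the two merge policies can be sketched in plain Python. The merge_dim helper below is hypothetical (not the library's API): the idea is that strict merging rejects conflicting dimensions, while permissive merging lets a concrete dimension refine a symbolic one.

```python
def merge_dim(existing, inferred, policy="permissive"):
    # Hypothetical helper, not the library's API: merge one dimension of an
    # existing shape with a newly inferred dimension. Symbolic dims are
    # represented as strings (e.g. "N"), concrete dims as ints.
    if existing == inferred:
        return existing
    if isinstance(existing, str) or isinstance(inferred, str):
        if policy == "permissive":
            # Permissive: prefer the more concrete of the two dims
            return inferred if isinstance(existing, str) else existing
    if policy == "strict":
        raise ValueError(f"dimension mismatch: {existing!r} vs {inferred!r}")
    return inferred


print(merge_dim("N", 4))  # 4: the concrete dim refines the symbolic "N"
print(merge_dim(4, "N"))  # 4
# merge_dim("N", 4, policy="strict") would raise ValueError under this sketch
```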
## Installation

```shell
pip install onnx-shape-inference
```

Or install from source (main branch):

```shell
pip install git+https://github.com/justinchuby/onnx-shape-inference.git
```
## Command line

Run shape inference on a model and see how many new shapes were inferred:

```shell
onnx-shape-inference model.onnx
```

Save the inferred model to a file:

```shell
onnx-shape-inference model.onnx -o model_inferred.onnx
```

Overwrite the input model in place:

```shell
onnx-shape-inference model.onnx --in-place
```

Select a different merge policy:

```shell
onnx-shape-inference model.onnx --policy strict
```
## Usage

```python
import onnx_ir as ir
from onnx_shape_inference import infer_symbolic_shapes

# Load a model
model = ir.load("model.onnx")

# Run shape inference
model = infer_symbolic_shapes(model)

# Or with a strict merge policy
model = infer_symbolic_shapes(model, policy="strict")
```
## Use with onnxscript optimizer

You can run symbolic shape inference on the model to help the optimizer discover more optimization opportunities.

```python
import onnx_ir as ir
import onnxscript.optimizer

import onnx_shape_inference

model = ir.load("model.onnx")

# Provide more shape information with infer_symbolic_shapes
model = onnx_shape_inference.infer_symbolic_shapes(model)

# onnxscript optimizer can leverage this information to better optimize the model
onnxscript.optimizer.optimize(model)
ir.save(model, "model_optimized.onnx")
```
## Per-node inference

You can run shape inference on individual nodes by using the
ShapeInferenceContext and registry directly. This is useful for
debugging, testing, or integrating into custom graph passes.

```python
import onnx_ir as ir
from onnx_shape_inference import ShapeInferenceContext, registry

# Populate the registry with all built-in ops
registry.collect()

# Create a context with the model's opset imports
ctx = ShapeInferenceContext(opset_imports={"": 21})

# Look up the inference function for the op
infer_func = registry.get("", "Relu", version=21)

# Build a node (or get one from an existing graph)
x = ir.Value(name="x", shape=ir.Shape([2, 3]), type=ir.TensorType(ir.DataType.FLOAT))
y = ir.Value(name="y")
node = ir.Node("", "Relu", inputs=[x], outputs=[y])

# Run inference
infer_func(ctx, node)
print(y.shape)  # [2, 3]
print(y.dtype)  # FLOAT
```
## Registering custom operators

```python
import onnx_ir as ir

from onnx_shape_inference import registry


@registry.register("com.custom", "MyOp", since_version=1)
def infer_my_op(ctx, node):
    input_shape = node.inputs[0].shape
    output_shape = ir.Shape([...])
    ctx.set_shape(node.outputs[0], output_shape)
```
## Shape data propagation (pkg.onnx_shape_inference.sym_data)
Shape inference alone cannot resolve output shapes when ops like Reshape consume
non-constant shape tensors that were computed at runtime (e.g. Shape → Slice → Concat → Reshape).
The sym_data feature bridges this gap by tracking the known element values of
1-D integer tensors as they flow through the graph.
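To make the mechanism concrete, here is a plain-Python trace (illustrative only, not the library's implementation) of the element values such a chain carries:

```python
# Track the element values of a 1-D shape tensor through a
# Shape -> Slice -> Concat -> Reshape chain. Symbolic dims are strings.
input_shape = ["N", 3, 224, 224]  # the data tensor's (symbolic) shape

shape_out = list(input_shape)     # Shape: elements equal the input's dims
sliced = shape_out[0:2]           # Slice(starts=0, ends=2) -> ["N", 3]
target = sliced + [768]           # Concat with the constant [768]

# Standard shape inference sees Reshape's shape input only as an unknown
# runtime tensor; tracking element values resolves the output shape anyway.
print(target)  # ['N', 3, 768]
```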
After inference, each value that carries propagated data has a
pkg.onnx_shape_inference.sym_data entry in its metadata_props. You can read
it directly or use the SYM_DATA_KEY constant:

```python
import ast

import onnx_ir as ir
from onnx_shape_inference import SYM_DATA_KEY, infer_symbolic_shapes

model = ir.load("model.onnx")
model = infer_symbolic_shapes(model)
for node in model.graph:
    for value in node.inputs:
        if SYM_DATA_KEY in value.metadata_props:
            text = value.metadata_props[SYM_DATA_KEY]  # e.g. '["N", 3, 768]'
            elements = ast.literal_eval(text)  # ["N", 3, 768]
            # You can create an ir.Shape from it
            shape = ir.Shape(elements)
            # Then you can replace this input with a constant value
```
When all elements are concrete integers the value is also stored as a constant
tensor, so downstream consumers that read constants directly can access it
without parsing metadata_props.
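The concrete/symbolic distinction can be sketched with a small check (illustrative only, not the library's code): a propagated value is eligible to be materialized as a constant tensor only when every element parsed from the sym_data string is an integer.

```python
import ast


def is_fully_concrete(sym_data_text):
    # Parse the sym_data string and check whether every element is a
    # concrete integer (symbolic dims are stored as strings like "N").
    elements = ast.literal_eval(sym_data_text)
    return all(isinstance(e, int) for e in elements)


print(is_fully_concrete("[2, 3, 768]"))    # True: could become a constant tensor
print(is_fully_concrete('["N", 3, 768]'))  # False: stays metadata-only
```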
## Development

```shell
pip install pytest parameterized
pip install -e .
pytest
```
## License

onnx-shape-inference is distributed under the terms of the Apache-2.0 license.