
onnx-shape-inference


Experimental symbolic shape inference for ONNX models. Built on top of ONNX IR, this library performs shape inference directly on the IR without serialization overhead, using SymPy for symbolic dimension arithmetic.

Features

  • Symbolic shape inference — propagates shapes through the graph using SymPy expressions for symbolic dimensions
  • Shape data propagation — tracks known element values of shape tensors (e.g. through Shape → Slice → Concat → Reshape chains) to resolve concrete output shapes that standard shape inference cannot
  • Extensible registry — register custom shape inference functions for custom operators
  • Merge policies — choose between strict and permissive shape merging strategies
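To give a feel for what SymPy-backed dimension arithmetic buys you, here is a plain-SymPy sketch (this uses SymPy directly and is not the library's internal representation): a dynamic batch dimension stays symbolic, yet element-count identities such as those needed to validate a Reshape can still be checked.

```python
import math

import sympy

# A symbolic batch dimension, analogous to a dynamic dim named "N"
N = sympy.Symbol("N", positive=True, integer=True)

# Flattening [N, 3, 768] to [N, 2304]: the symbolic element counts
# must agree for the Reshape to be valid
in_shape = [N, 3, 768]
out_shape = [N, 2304]

# math.prod multiplies through the SymPy symbol; both sides reduce to 2304*N
assert math.prod(in_shape) == math.prod(out_shape)
```

Because both products simplify to the same expression, the equality holds even though `N` is never given a concrete value.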

Installation

pip install onnx-shape-inference

Or install from source (main branch):

pip install git+https://github.com/justinchuby/onnx-shape-inference.git

Command line

Run shape inference on a model and see how many new shapes were inferred:

onnx-shape-inference model.onnx

Save the inferred model to a file:

onnx-shape-inference model.onnx -o model_inferred.onnx

Overwrite the input model in place:

onnx-shape-inference model.onnx --in-place

Select a different merge policy:

onnx-shape-inference model.onnx --policy strict
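The policy controls what happens when a newly inferred dimension disagrees with one already recorded on the model. The sketch below is purely illustrative of the two behaviors (the function name and exact semantics are hypothetical, not the library's code): strict merging surfaces the conflict, permissive merging keeps going.

```python
def merge_dim(existing, inferred, policy="permissive"):
    """Illustrative dimension merge; not the library's actual implementation."""
    if existing is None or existing == inferred:
        # No recorded dim, or the two agree: take the inferred one
        return inferred
    if policy == "strict":
        # Strict merging refuses to silently override a conflicting dim
        raise ValueError(f"dimension conflict: {existing!r} vs {inferred!r}")
    # Permissive merging continues, preferring the newly inferred dim
    return inferred

print(merge_dim(None, 3))   # 3
print(merge_dim("N", "N"))  # N
```

Under this model, `--policy strict` would turn a mismatch between an annotated and an inferred shape into an error instead of quietly replacing the annotation.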

Usage

import onnx_ir as ir
from onnx_shape_inference import infer_symbolic_shapes

# Load a model
model = ir.load("model.onnx")

# Run shape inference
model = infer_symbolic_shapes(model)

# Or with a strict merge policy
model = infer_symbolic_shapes(model, policy="strict")

Use with onnxscript optimizer

You can run symbolic shape inference on the model to help the optimizer discover more optimization opportunities.

import onnx_shape_inference
import onnx_ir as ir
import onnxscript.optimizer

model = ir.load("model.onnx")

# Provide more shape information with infer_symbolic_shapes
model = onnx_shape_inference.infer_symbolic_shapes(model)

# onnxscript optimizer can leverage this information to better optimize the model
onnxscript.optimizer.optimize(model)

ir.save(model, "model_optimized.onnx")

Per-node inference

You can run shape inference on individual nodes by using the ShapeInferenceContext and registry directly. This is useful for debugging, testing, or integrating into custom graph passes.

import onnx_ir as ir
from onnx_shape_inference import ShapeInferenceContext, registry

# Populate the registry with all built-in ops
registry.collect()

# Create a context with the model's opset imports
ctx = ShapeInferenceContext(opset_imports={"": 21})

# Look up the inference function for the op
infer_func = registry.get("", "Relu", version=21)

# Build a node (or get one from an existing graph)
x = ir.Value(name="x", shape=ir.Shape([2, 3]), type=ir.TensorType(ir.DataType.FLOAT))
y = ir.Value(name="y")
node = ir.Node("", "Relu", inputs=[x], outputs=[y])

# Run inference
infer_func(ctx, node)

print(y.shape)  # [2,3]
print(y.dtype)  # FLOAT

Registering custom operators

import onnx_ir as ir
from onnx_shape_inference import registry

@registry.register("com.custom", "MyOp", since_version=1)
def infer_my_op(ctx, node):
    input_shape = node.inputs[0].shape
    # Compute the output shape from the input shape; as an example,
    # suppose MyOp preserves its input shape
    output_shape = ir.Shape(input_shape)
    ctx.set_shape(node.outputs[0], output_shape)

Shape data propagation (pkg.onnx_shape_inference.sym_data)

Shape inference alone cannot resolve output shapes when ops like Reshape consume non-constant shape tensors computed at runtime (e.g. a Shape → Slice → Concat → Reshape chain). The sym_data feature bridges this gap by tracking the known element values of 1-D integer tensors as they flow through the graph.
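The chain above can be mimicked with plain Python lists standing in for the tracked 1-D shape tensors (a conceptual sketch only; the library does this tracking internally on the IR, and these helper names are made up for illustration):

```python
def shape_op(tensor_shape):
    # Shape: the tracked elements are the input tensor's dimensions
    return list(tensor_shape)

def slice_op(elements, start, end):
    # Slice on a 1-D shape tensor selects a range of those dimensions
    return elements[start:end]

def concat_op(*parts):
    # Concat joins shape fragments into the Reshape target
    return [e for part in parts for e in part]

# Shape -> Slice -> Concat feeding a Reshape, for an input of shape [N, 3, 768]
dims = shape_op(["N", 3, 768])     # ["N", 3, 768]
batch = slice_op(dims, 0, 1)       # ["N"]
target = concat_op(batch, [2304])  # ["N", 2304]
print(target)
```

Because the element values (including the symbolic "N") are known at every step, the Reshape output resolves to `["N", 2304]` even though the shape tensor itself is not a graph constant.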

After inference, each value that carries propagated data has a pkg.onnx_shape_inference.sym_data entry in its metadata_props. You can read it directly or use the SYM_DATA_KEY constant:

import ast

import onnx_ir as ir
from onnx_shape_inference import SYM_DATA_KEY, infer_symbolic_shapes

model = ir.load("model.onnx")
model = infer_symbolic_shapes(model)

for node in model.graph:
    for value in node.inputs:
        if value is not None and SYM_DATA_KEY in value.metadata_props:
            text = value.metadata_props[SYM_DATA_KEY]  # e.g. '["N",3,768]'
            elements = ast.literal_eval(text)          # ["N", 3, 768]

            # You can create an ir.Shape from it
            shape = ir.Shape(elements)

            # Then you can replace this input with a constant value

When all elements are concrete integers, the value is also stored as a constant tensor, so downstream consumers that read constants directly can access it without parsing metadata_props.

Development

pip install pytest parameterized
pip install -e .
pytest

License

onnx-shape-inference is distributed under the terms of the Apache-2.0 license.
