
Symbolic shape inference for ONNX models


onnx-shape-inference


Experimental symbolic shape inference for ONNX models. Built on top of ONNX IR, this library performs shape inference directly on the IR without serialization overhead, using SymPy for symbolic dimension arithmetic.

Features

  • Symbolic shape inference — propagates shapes through the graph using SymPy expressions for symbolic dimensions
  • Shape data propagation — tracks known element values of shape tensors (e.g. through Shape → Slice → Concat → Reshape chains) to resolve concrete output shapes that standard shape inference cannot
  • Extensible registry — register custom shape inference functions for custom operators
  • Merge policies — choose between strict and permissive shape merging strategies
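As a rough illustration of what symbolic dimension arithmetic can look like, the sketch below uses SymPy directly. The helper names `concat_dim` and `reshape_infer_dim` are hypothetical and are not part of onnx-shape-inference; the library's own handling of these ops may differ.

```python
import sympy

# Symbolic dimensions, assumed to be positive integers
N = sympy.Symbol("N", integer=True, positive=True)
M = sympy.Symbol("M", integer=True, positive=True)


def concat_dim(dims):
    """Size of the concatenated axis: the sum of the input sizes on that axis."""
    return sympy.simplify(sum(dims))


def reshape_infer_dim(input_dims, target_dims):
    """Resolve a single -1 in a Reshape target from the total element count."""
    total = sympy.Mul(*input_dims)
    known = sympy.Mul(*[d for d in target_dims if d != -1])
    return [sympy.simplify(total / known) if d == -1 else d for d in target_dims]


concat = concat_dim([N, M, 2])                       # N + M + 2
reshaped = reshape_infer_dim([N, 3, 768], [N, -1])   # [N, 2304]
```

Because the dimensions are SymPy expressions, the unknown batch size `N` cancels during the Reshape computation and the remaining dimension resolves to a concrete integer.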

Installation

pip install onnx-shape-inference

Or install from source (main branch):

pip install git+https://github.com/justinchuby/onnx-shape-inference.git

Command line

Run shape inference on a model and see how many new shapes were inferred:

onnx-shape-inference model.onnx

Save the inferred model to a file:

onnx-shape-inference model.onnx -o model_inferred.onnx

Overwrite the input model in place:

onnx-shape-inference model.onnx --in-place

Select a different merge policy:

onnx-shape-inference model.onnx --policy strict

Usage

import onnx_ir as ir
from onnx_shape_inference import infer_symbolic_shapes

# Load a model
model = ir.load("model.onnx")

# Run shape inference
model = infer_symbolic_shapes(model)

# Or with a strict merge policy
model = infer_symbolic_shapes(model, policy="strict")

Use with onnxscript optimizer

You can run symbolic shape inference on the model to help the optimizer discover more optimization opportunities.

import onnx_shape_inference
import onnx_ir as ir
import onnxscript.optimizer

model = ir.load("model.onnx")

# Provide more shape information with infer_symbolic_shapes
model = onnx_shape_inference.infer_symbolic_shapes(model)

# onnxscript optimizer can leverage this information to better optimize the model
onnxscript.optimizer.optimize(model)

ir.save(model, "model_optimized.onnx")

Per-node inference

You can run shape inference on individual nodes by using the ShapeInferenceContext and registry directly. This is useful for debugging, testing, or integrating into custom graph passes.

import onnx_ir as ir
from onnx_shape_inference import ShapeInferenceContext, registry

# Populate the registry with all built-in ops
registry.collect()

# Create a context with the model's opset imports
ctx = ShapeInferenceContext(opset_imports={"": 21})

# Look up the inference function for the op
infer_func = registry.get("", "Relu", version=21)

# Build a node (or get one from an existing graph)
x = ir.Value(name="x", shape=ir.Shape([2, 3]), type=ir.TensorType(ir.DataType.FLOAT))
y = ir.Value(name="y")
node = ir.Node("", "Relu", inputs=[x], outputs=[y])

# Run inference
infer_func(ctx, node)

print(y.shape)  # [2,3]
print(y.dtype)  # FLOAT

Registering custom operators

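import onnx_ir as ir
from onnx_shape_inference import registry

@registry.register("com.custom", "MyOp", since_version=1)
def infer_my_op(ctx, node):
    # Shape of the first input (may be None if it is unknown)
    input_shape = node.inputs[0].shape
    # Placeholder: fill in the output dimensions computed from input_shape
    output_shape = ir.Shape([...])
    ctx.set_shape(node.outputs[0], output_shape)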
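To make the registration pattern more concrete, here is a minimal, standalone sketch of a decorator-based registry keyed by domain, op type, and `since_version`. `ToyRegistry` is a hypothetical class written for illustration only. It is not the library's implementation, and the lookup rule (pick the highest registered `since_version` that does not exceed the requested version) is an assumption based on the usual ONNX operator versioning convention.

```python
from typing import Callable, Dict, Optional, Tuple


class ToyRegistry:
    """Hypothetical registry; not the onnx-shape-inference implementation."""

    def __init__(self) -> None:
        # (domain, op_type) -> {since_version: inference function}
        self._funcs: Dict[Tuple[str, str], Dict[int, Callable]] = {}

    def register(self, domain: str, op_type: str, since_version: int = 1):
        def decorator(func: Callable) -> Callable:
            self._funcs.setdefault((domain, op_type), {})[since_version] = func
            return func

        return decorator

    def get(self, domain: str, op_type: str, version: int) -> Optional[Callable]:
        versions = self._funcs.get((domain, op_type), {})
        # Assumed rule: newest function whose since_version <= requested version
        eligible = [v for v in versions if v <= version]
        if not eligible:
            return None
        return versions[max(eligible)]


toy_registry = ToyRegistry()


@toy_registry.register("com.custom", "MyOp", since_version=1)
def infer_my_op_v1(input_dims):
    # Version 1 of the toy op keeps the input shape unchanged
    return list(input_dims)


@toy_registry.register("com.custom", "MyOp", since_version=3)
def infer_my_op_v3(input_dims):
    # Version 3 of the toy op appends a trailing dimension of size 1
    return list(input_dims) + [1]
```

With this sketch, a lookup at opset version 2 would return the version 1 function, while a lookup at version 5 would return the version 3 function.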

Shape data propagation (pkg.onnx_shape_inference.sym_data)

Shape inference alone cannot resolve output shapes when ops like Reshape consume non-constant shape tensors that were computed at runtime (e.g. Shape → Slice → Concat → Reshape). The sym_data feature bridges this gap by tracking the known element values of 1-D integer tensors as they flow through the graph.
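The idea can be sketched in plain Python by following the element values of a shape tensor through a Shape → Slice → Concat → Reshape chain. The helper functions below are hypothetical stand-ins for the ONNX ops and only model the tracked element values; they do not reflect how the library implements this internally.

```python
def shape_of(dims):
    """Shape op: the output's element values are the input's dimensions."""
    return list(dims)


def slice_elements(data, start, end):
    """Slice op on a 1-D tensor with known element values."""
    return data[start:end]


def concat_elements(*parts):
    """Concat op on 1-D tensors with known element values."""
    return [element for part in parts for element in part]


# Input tensor X with a symbolic batch dimension "N"
x_dims = ["N", 3, 768]

shape_data = shape_of(x_dims)                 # ["N", 3, 768]
batch = slice_elements(shape_data, 0, 1)      # ["N"]
target = concat_elements(batch, [2304])       # ["N", 2304]

# Reshape: with the target's element values known, the output shape follows
reshape_output_dims = list(target)            # ["N", 2304]
```

Without tracking the element values, the Reshape target would be an opaque runtime tensor and only the output rank could be inferred.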

After inference, each value that carries propagated data has a pkg.onnx_shape_inference.sym_data entry in its metadata_props. You can read it directly or use the SYM_DATA_KEY constant:

import json

import onnx_ir as ir
from onnx_shape_inference import SYM_DATA_KEY, infer_symbolic_shapes

# Load a model and run shape inference so that sym_data entries are populated
model = ir.load("model.onnx")
model = infer_symbolic_shapes(model)

for node in model.graph:
    for value in node.inputs:
        # Omitted optional inputs are represented as None
        if value is None:
            continue
        if SYM_DATA_KEY in value.metadata_props:
            text = value.metadata_props[SYM_DATA_KEY]  # e.g. '["N",3,768]'
            elements = json.loads(text)                # ["N", 3, 768]

            # You can create an ir.Shape from it
            shape = ir.Shape(elements)

            # Then you can replace this input with a constant value

When all elements are concrete integers, the value is also stored as a constant tensor, so downstream consumers that read constants directly can access it without parsing metadata_props.
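The distinction between fully concrete and partially symbolic data can be checked with the standard library alone. The sketch below assumes the metadata text is a JSON list of integers and symbol names, as in the example above; `parse_sym_data` is a hypothetical helper, not a function provided by the library.

```python
import json


def parse_sym_data(text):
    """Parse a sym_data JSON string and report whether it is fully concrete."""
    elements = json.loads(text)
    # Concrete means every element is an integer rather than a symbol name
    is_concrete = all(isinstance(element, int) for element in elements)
    return elements, is_concrete


symbolic_elements, symbolic_is_concrete = parse_sym_data('["N",3,768]')
concrete_elements, concrete_is_concrete = parse_sym_data('[2,3,768]')
```

Here the first string still contains the symbol `"N"`, so it is not concrete, while the second consists only of integers.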

Development

pip install pytest parameterized
pip install -e .
pytest

License

onnx-shape-inference is distributed under the terms of the Apache-2.0 license.
