Symbolic shape inference for ONNX models

onnx-shape-inference

Experimental symbolic shape inference for ONNX models. Built on top of ONNX IR, this library performs shape inference directly on the IR without serialization overhead, using SymPy for symbolic dimension arithmetic.

Features

  • Symbolic shape inference — propagates shapes through the graph using SymPy expressions for symbolic dimensions
  • Shape data propagation — tracks known element values of shape tensors (e.g. through Shape → Slice → Concat → Reshape chains) to resolve concrete output shapes that standard shape inference cannot
  • Extensible registry — register custom shape inference functions for custom operators
  • Merge policies — choose between strict and permissive shape merging strategies

Installation

pip install onnx-shape-inference

Or install from source (main branch):

pip install git+https://github.com/justinchuby/onnx-shape-inference.git

Command line

Run shape inference on a model and see how many new shapes were inferred:

onnx-shape-inference model.onnx

Save the inferred model to a file:

onnx-shape-inference model.onnx -o model_inferred.onnx

Overwrite the input model in place:

onnx-shape-inference model.onnx --in-place

Select a different merge policy:

onnx-shape-inference model.onnx --policy strict

Usage

import onnx_ir as ir
from onnx_shape_inference import infer_symbolic_shapes

# Load a model
model = ir.load("model.onnx")

# Run shape inference
model = infer_symbolic_shapes(model)

# Or with a strict merge policy
model = infer_symbolic_shapes(model, policy="strict")
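The exact semantics of the merge policies are not spelled out here. As a rough illustration only, the sketch below shows one plausible way a strict and a permissive merge could differ when an existing shape meets a newly inferred one. The names merge_dim and merge_shapes and the rule "prefer a concrete integer" are assumptions made for this example, not the library's API or behavior.

```python
from typing import List, Union

# A dimension is a concrete size, a symbolic name, or unknown (None).
Dim = Union[int, str, None]


def merge_dim(existing: Dim, inferred: Dim, strict: bool) -> Dim:
    """Merge one existing dimension with a newly inferred one (hypothetical rule)."""
    if existing is None:
        return inferred
    if inferred is None or existing == inferred:
        return existing
    if strict:
        # Assumed strict behavior: any disagreement is an error.
        raise ValueError(f"conflicting dimensions: {existing!r} vs {inferred!r}")
    # Assumed permissive behavior: prefer a concrete integer over a symbolic name.
    if isinstance(inferred, int) and not isinstance(existing, int):
        return inferred
    return existing


def merge_shapes(existing: List[Dim], inferred: List[Dim], strict: bool = False) -> List[Dim]:
    """Merge two shapes of equal rank dimension by dimension."""
    if len(existing) != len(inferred):
        raise ValueError("rank mismatch")
    return [merge_dim(e, i, strict) for e, i in zip(existing, inferred)]


merged = merge_shapes(["N", None, 768], [8, 3, 768])
print(merged)  # [8, 3, 768]
```

Under these assumptions, the same call with strict=True would raise a ValueError, because "N" and 8 disagree.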

Use with onnxscript optimizer

You can run symbolic shape inference on the model to help the optimizer discover more optimization opportunities.

import onnx_shape_inference
import onnx_ir as ir
import onnxscript.optimizer

model = ir.load("model.onnx")

# Provide more shape information with infer_symbolic_shapes
model = onnx_shape_inference.infer_symbolic_shapes(model)

# onnxscript optimizer can leverage this information to better optimize the model
onnxscript.optimizer.optimize(model)

ir.save(model, "model_optimized.onnx")

Per-node inference

You can run shape inference on individual nodes by using the ShapeInferenceContext and registry directly. This is useful for debugging, testing, or integrating into custom graph passes.

import onnx_ir as ir
from onnx_shape_inference import ShapeInferenceContext, registry

# Populate the registry with all built-in ops
registry.collect()

# Create a context with the model's opset imports
ctx = ShapeInferenceContext(opset_imports={"": 21})

# Look up the inference function for the op
infer_func = registry.get("", "Relu", version=21)

# Build a node (or get one from an existing graph)
x = ir.Value(name="x", shape=ir.Shape([2, 3]), type=ir.TensorType(ir.DataType.FLOAT))
y = ir.Value(name="y")
node = ir.Node("", "Relu", inputs=[x], outputs=[y])

# Run inference
infer_func(ctx, node)

print(y.shape)  # [2,3]
print(y.dtype)  # FLOAT

Registering custom operators

import onnx_ir as ir
from onnx_shape_inference import registry

@registry.register("com.custom", "MyOp", since_version=1)
def infer_my_op(ctx, node):
    input_shape = node.inputs[0].shape
    # Placeholder: replace [...] with the output dimensions computed from input_shape
    output_shape = ir.Shape([...])
    ctx.set_shape(node.outputs[0], output_shape)

Shape data propagation (pkg.onnx_shape_inference.sym_data)

Shape inference alone cannot resolve output shapes when ops like Reshape consume non-constant shape tensors that were computed at runtime (e.g. Shape → Slice → Concat → Reshape). The sym_data feature bridges this gap by tracking the known element values of 1-D integer tensors as they flow through the graph.
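To make the idea concrete, here is a toy, library-independent sketch of tracking element values through a Shape → Slice → Concat → Reshape chain. It uses plain Python lists, with strings standing in for symbolic dimensions. The helper names are hypothetical, and the Reshape step is simplified: it ignores the special 0 and -1 entries that the real operator supports.

```python
from typing import List, Union

# A dimension is a concrete size or a symbolic name such as "N".
Dim = Union[int, str]


def shape_op(tensor_shape: List[Dim]) -> List[Dim]:
    """Shape: the output's element values are the input's dimensions."""
    return list(tensor_shape)


def slice_op(data: List[Dim], start: int, end: int) -> List[Dim]:
    """Slice on a 1-D tensor: keep the known elements in [start, end)."""
    return data[start:end]


def concat_op(*parts: List[Dim]) -> List[Dim]:
    """Concat of 1-D tensors: join the known elements in order."""
    out: List[Dim] = []
    for part in parts:
        out.extend(part)
    return out


def reshape_output_shape(target: List[Dim]) -> List[Dim]:
    """Reshape: once every target element is tracked, the output shape is known."""
    return list(target)


# x has shape [N, 3, 16, 16]; the graph flattens the two trailing dimensions.
x_shape: List[Dim] = ["N", 3, 16, 16]
leading = slice_op(shape_op(x_shape), 0, 2)  # ["N", 3]
target = concat_op(leading, [256])           # ["N", 3, 256]
result = reshape_output_shape(target)
print(result)  # ['N', 3, 256]
```

Without tracking the elements of the intermediate 1-D tensors, only the length of target (3) would be known, so the Reshape output would have a known rank but unknown dimensions.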

After inference, each value that carries propagated data has a pkg.onnx_shape_inference.sym_data entry in its metadata_props. You can read it directly or use the SYM_DATA_KEY constant:

import json
import onnx_ir as ir
from onnx_shape_inference import SYM_DATA_KEY, infer_symbolic_shapes

# Load a model and run shape inference
model = ir.load("model.onnx")
model = infer_symbolic_shapes(model)

for node in model.graph:
    for value in node.inputs:
        # Optional inputs that are omitted appear as None
        if value is None:
            continue
        if SYM_DATA_KEY in value.metadata_props:
            text = value.metadata_props[SYM_DATA_KEY]  # e.g. '["N",3,768]'
            elements = json.loads(text)                # ["N", 3, 768]

            # You can create an ir.Shape from it
            shape = ir.Shape(elements)

            # Then you can replace this input with a constant value

When all elements are concrete integers, the value is also stored as a constant tensor, so downstream consumers that read constants directly can access it without parsing metadata_props.
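If you do parse the metadata yourself, a small check can tell the fully concrete case apart from the symbolic one. The sketch below assumes the serialized form shown in the example comment above (a JSON list in which symbolic dimensions appear as strings). The helper name concrete_elements is hypothetical and not part of the library.

```python
import json
from typing import List, Optional


def concrete_elements(sym_data_text: str) -> Optional[List[int]]:
    """Return the elements as ints if every one is concrete, else None."""
    elements = json.loads(sym_data_text)
    # bool is a subclass of int in Python, so exclude it explicitly.
    if all(isinstance(e, int) and not isinstance(e, bool) for e in elements):
        return elements
    return None


print(concrete_elements('[2, 3, 768]'))    # [2, 3, 768]
print(concrete_elements('["N", 3, 768]'))  # None
```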

Development

pip install pytest parameterized
pip install -e .
pytest

License

onnx-shape-inference is distributed under the terms of the Apache-2.0 license.
