onnx_toolkit

A unified ONNX parsing and pattern-matching library. onnx_toolkit provides two tightly integrated components:

  • ONNXParser / ONNXQuery — load a model and query its graph with a fluent, chainable API
  • Pattern / Pattern.detect — describe structural subgraph patterns and match them against the graph using DFS with commutativity support

Installation

pip install onnx_toolkit

Quick Start

from onnx_toolkit import ONNXParser, Pattern

# Load a model and query it
parser = ONNXParser("model.onnx")

# Find all Conv nodes that have weights
convs = parser.find().find_by_op_type("Conv").has_params()
print(convs)

# Check if a node matches a Swish/SiLU activation pattern
x = Pattern.any()
swish = Pattern.silu(x)
detector = Pattern.detect(parser.model, start_node="Mul_0")  # SiLU is rooted at a Mul node
print(detector.match(swish))  # True or False

Logging

onnx_toolkit uses Python's standard logging module. To enable debug output:

import logging
logging.basicConfig(level=logging.DEBUG)

To enable output only for this library:

logging.getLogger("onnx_toolkit").setLevel(logging.DEBUG)
logging.getLogger("onnx_toolkit").addHandler(logging.StreamHandler())

Sub-loggers

Logger name            Covers
onnx_toolkit           Top-level / parser messages
onnx_toolkit.query     ONNXQuery filter and traversal steps
onnx_toolkit.pattern   Pattern construction messages
onnx_toolkit.detect    Pattern.detect DFS walk and decisions
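
Because these are ordinary `logging` loggers, the usual hierarchy rules apply. The sketch below uses only the standard library and the logger names from the table above; it keeps the library quiet overall while turning on debug output for the detection walk only.

```python
import logging

# Keep the library as a whole at WARNING...
lib_log = logging.getLogger("onnx_toolkit")
lib_log.setLevel(logging.WARNING)
if not lib_log.handlers:
    lib_log.addHandler(logging.StreamHandler())

# ...but enable verbose output for the pattern-detection walk.
detect_log = logging.getLogger("onnx_toolkit.detect")
detect_log.setLevel(logging.DEBUG)

# Sub-loggers without an explicit level inherit from "onnx_toolkit".
query_log = logging.getLogger("onnx_toolkit.query")

print(detect_log.isEnabledFor(logging.DEBUG))  # True
print(query_log.isEnabledFor(logging.DEBUG))   # False
```

Records from `onnx_toolkit.detect` still propagate to the handler attached to `onnx_toolkit`, so no second handler is needed.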

ONNXParser

Loads an ONNX model from disk and provides an entry point for graph queries.

parser = ONNXParser("model.onnx")

Methods

find() → ONNXQuery

Returns an ONNXQuery over all nodes in the graph. This is the starting point for every query chain.

query = parser.find()

pattern_detect(pattern, start_node=None, end_node=None) → bool

Convenience shortcut: creates a Pattern.detect bound to this model and immediately calls .match(pattern).

x = Pattern.any()
result = parser.pattern_detect(Pattern.relu(x), start_node="Relu_0")

Parameter    Type               Description
pattern      Pattern            The pattern to match
start_node   str or NodeProto   Seed node (name string or proto object)
end_node     str or NodeProto   Stop traversal at this node (optional)

summary() → str

Returns a human-readable text summary of the loaded model, including node count, tensor count, graph I/O counts, and a ranked frequency table of op types.

print(parser.summary())
# ONNX model summary
#   Nodes       : 152
#   Tensors     : 104
#   ...

ONNXQuery

A lazy, chainable view over a subset of ONNX graph nodes. Every filter and traversal method returns a new ONNXQuery, so calls can be freely chained without mutating the receiver.

# Parentheses let the chain span several lines
relus_after_convs = (
    parser.find()
    .find_by_op_type("Conv")
    .has_params()
    .children()
    .find_by_op_type("Relu")
)
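
To illustrate why chaining never mutates the receiver, here is a minimal sketch of an immutable selection. `ToyNode` and `ToyQuery` are hypothetical stand-ins, not the library's classes, and the sketch is eager rather than lazy.

```python
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True)
class ToyNode:
    name: str
    op_type: str


@dataclass(frozen=True)
class ToyQuery:
    """Hypothetical stand-in for ONNXQuery: an immutable selection of nodes."""
    nodes: Tuple[ToyNode, ...]

    def where(self, keep: Callable[[ToyNode], bool]) -> "ToyQuery":
        # Build a new selection instead of modifying self.
        return ToyQuery(tuple(n for n in self.nodes if keep(n)))

    def find_by_op_type(self, op_type: str) -> "ToyQuery":
        return self.where(lambda n: n.op_type == op_type)

    def count(self) -> int:
        return len(self.nodes)


everything = ToyQuery((
    ToyNode("Conv_0", "Conv"),
    ToyNode("Relu_0", "Relu"),
    ToyNode("Conv_1", "Conv"),
))
convs = everything.find_by_op_type("Conv")

# The original selection is unchanged after filtering.
print(everything.count(), convs.count())  # 3 2
```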

Filter Methods

find_by_op_type(op_type: str) → ONNXQuery

Filters nodes whose op_type matches exactly (case-sensitive).

convs = parser.find().find_by_op_type("Conv")

find_by_name(name: str, *, exact: bool = False) → ONNXQuery

Filters nodes by name. By default performs a case-insensitive substring match. Pass exact=True for strict equality.

# Substring match (default)
nodes = parser.find().find_by_name("block2")

# Exact match
node = parser.find().find_by_name("Conv_14", exact=True)
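
The two matching modes can be sketched as a small predicate. `name_matches` is a hypothetical helper written for illustration; it only mirrors the behavior described above.

```python
def name_matches(node_name: str, query: str, *, exact: bool = False) -> bool:
    """Case-insensitive substring match by default, strict equality if exact."""
    if exact:
        return node_name == query
    return query.lower() in node_name.lower()


print(name_matches("features.Block2.conv", "block2"))   # True
print(name_matches("Conv_140", "Conv_14"))              # True
print(name_matches("Conv_140", "Conv_14", exact=True))  # False
```

The second call shows why exact=True is the safer choice when selecting one specific node: a substring query for "Conv_14" also hits "Conv_140".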

find_by_tensor(tensor_name: str) → ONNXQuery

Filters nodes that either consume or produce a tensor with the given name.

nodes = parser.find().find_by_tensor("input.1")

find_by_param_name(name: str, *, exact: bool = False) → ONNXQuery

Filters nodes whose weight (initializer) inputs match the given name. Supports substring (default) or exact matching.

# All nodes with a weight whose name contains "bias"
nodes = parser.find().find_by_param_name("bias")

find_by_attribute(attr_name: str, value: Any = None) → ONNXQuery

Filters nodes that have the given ONNX attribute, optionally requiring it to equal a specific value.

# Any node with a "group" attribute
parser.find().find_by_attribute("group")

# Nodes with group=2 (a grouped conv with two groups)
parser.find().find_by_attribute("group", value=2)

has_params() → ONNXQuery

Filters to nodes that have at least one weight tensor (initializer input).

weighted = parser.find().find_by_op_type("Conv").has_params()

Traversal Methods

children() → ONNXQuery

Returns all nodes that consume any output produced by the current selection (one hop forward).

next_layer = query.children()

parents() → ONNXQuery

Returns all nodes that produce any input consumed by the current selection (one hop backward).

prev_layer = query.parents()

ancestors(max_depth: int = 100) → ONNXQuery

Returns all transitive parent nodes up to max_depth hops away. The current nodes are not included.

all_predecessors = query.ancestors()
close_predecessors = query.ancestors(max_depth=3)

descendants(max_depth: int = 100) → ONNXQuery

Returns all transitive child nodes up to max_depth hops away. The current nodes are not included.

all_successors = query.descendants()
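
The traversal semantics (one hop via shared tensors, bounded depth, start nodes excluded) can be sketched on a toy graph. `GRAPH` and both functions are hypothetical and operate on plain names rather than NodeProto objects.

```python
# Hypothetical toy graph: node name -> (input tensors, output tensors).
GRAPH = {
    "Conv_0": (["x"], ["t0"]),
    "Relu_0": (["t0"], ["t1"]),
    "Conv_1": (["t1"], ["t2"]),
    "Add_0": (["t2", "t0"], ["y"]),
}


def children(names):
    """Nodes consuming any output of the given nodes (one hop forward)."""
    produced = {t for n in names for t in GRAPH[n][1]}
    return [n for n, (ins, _) in GRAPH.items() if produced & set(ins)]


def descendants(names, max_depth=100):
    """Transitive children up to max_depth hops; start nodes are excluded."""
    seen, frontier, found = set(names), list(names), []
    for _ in range(max_depth):
        frontier = [n for n in children(frontier) if n not in seen]
        if not frontier:
            break
        seen.update(frontier)
        found.extend(frontier)
    return found


print(children(["Conv_0"]))                  # ['Relu_0', 'Add_0']
print(descendants(["Conv_0"], max_depth=1))  # ['Relu_0', 'Add_0']
print(descendants(["Conv_0"]))               # ['Relu_0', 'Add_0', 'Conv_1']
```

parents() and ancestors() follow the same idea with inputs and outputs swapped.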

Entry / Exit Shortcuts

entry_nodes → ONNXQuery (property)

Returns nodes whose inputs include at least one graph-level input.

first_ops = parser.find().entry_nodes

output_nodes → ONNXQuery (property)

Returns nodes whose outputs include at least one graph-level output.

last_ops = parser.find().output_nodes

Tensor / Parameter Access

tensor(name: str = None) → Any

Returns weight tensor(s) for the selected nodes as NumPy arrays.

  • If name is provided: returns that specific tensor (or None if not found).
  • If exactly one node is selected: returns {param_name: array, ...}.
  • If multiple nodes are selected: returns {node_name: {param_name: array}, ...}.

# Get a tensor by name directly
w = query.tensor("conv1.weight")

# Get all weights for a single selected node
params = parser.find().find_by_name("Conv_0", exact=True).tensor()
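
The three return shapes are easier to see side by side. The sketch below is a hypothetical re-implementation over plain dictionaries (lists stand in for NumPy arrays); it assumes the named lookup searches only the selected nodes, which the description above does not state.

```python
from typing import Any, Dict, List, Optional

Params = Dict[str, List[float]]  # plain lists stand in for NumPy arrays


def tensor(selection: Dict[str, Params], name: Optional[str] = None) -> Any:
    """Hypothetical sketch of the three return shapes of tensor()."""
    if name is not None:
        # Named lookup: the tensor itself, or None when nothing matches.
        for params in selection.values():
            if name in params:
                return params[name]
        return None
    if len(selection) == 1:
        # One node selected: flat {param_name: array}.
        return next(iter(selection.values()))
    # Several nodes selected: nested {node_name: {param_name: array}}.
    return selection


one = {"Conv_0": {"conv1.weight": [0.1, 0.2], "conv1.bias": [0.0]}}
two = {**one, "Conv_1": {"conv2.weight": [0.3]}}

print(tensor(one, "conv1.bias"))  # [0.0]
print(sorted(tensor(one)))        # ['conv1.bias', 'conv1.weight']
print(sorted(tensor(two)))        # ['Conv_0', 'Conv_1']
print(tensor(two, "missing"))     # None
```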

single_tensor → np.ndarray (property)

Returns the weight tensor of the single selected node. Raises ValueError unless the selection contains exactly one node and that node has exactly one tensor.

weight = parser.find().find_by_name("Conv_0", exact=True).single_tensor

Pattern Matching Integration

matches(pattern: Pattern) → ONNXQuery

Returns the subset of nodes from the current selection whose subgraphs match pattern. This is the preferred way to do bulk pattern matching across many nodes.

x = Pattern.any()
silu_nodes = parser.find().find_by_op_type("Mul").matches(Pattern.silu(x))

Set Operations

union(other: ONNXQuery) → ONNXQuery

Returns all nodes present in self or other, deduplicated.

intersection(other: ONNXQuery) → ONNXQuery

Returns nodes present in both self and other.

difference(other: ONNXQuery) → ONNXQuery

Returns nodes present in self but not in other.

a = parser.find().find_by_op_type("Relu")
b = parser.find().find_by_op_type("Conv").children().find_by_op_type("Relu")

only_in_a = a.difference(b)
all_relus = a.union(b)
shared = a.intersection(b)
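
As a sketch of the semantics, the same three operations on plain lists of node names look like this. These helper functions are hypothetical, and preserving first-seen order is an assumption of the sketch, not a documented guarantee.

```python
def union(a, b):
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return list(dict.fromkeys([*a, *b]))


def intersection(a, b):
    in_b = set(b)
    return [n for n in a if n in in_b]


def difference(a, b):
    in_b = set(b)
    return [n for n in a if n not in in_b]


all_relus = ["Relu_0", "Relu_1", "Relu_2"]
relus_after_conv = ["Relu_0", "Relu_2"]

print(difference(all_relus, relus_after_conv))    # ['Relu_1']
print(intersection(all_relus, relus_after_conv))  # ['Relu_0', 'Relu_2']
print(union(relus_after_conv, all_relus))         # ['Relu_0', 'Relu_2', 'Relu_1']
```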

Accessors and Helpers

Method / Property   Returns             Description
count()             int                 Number of nodes in the current selection
is_empty()          bool                True if the selection is empty
op_types()          list[str]           Deduplicated list of op types in the selection
first()             NodeProto or None   First node, or None if empty
last()              NodeProto or None   Last node, or None if empty
single_node         NodeProto           The single selected node; raises if count ≠ 1
__iter__            iterator            Iterate over selected NodeProto objects
__len__             int                 Same as count()
__getitem__(i)      ONNXQuery           Slice or index the selection

Pattern

A lightweight DSL for describing ONNX subgraph structures. Patterns are composable and can be built using arithmetic operators, class methods, and helper constructors.

Named Constructors

Pattern.any() → Pattern

Wildcard that matches any node regardless of op type.

x = Pattern.any()

Pattern.const(value) → Pattern

Matches a constant or initializer tensor whose value is approximately equal to value (tolerance 1e-3).

half = Pattern.const(0.5)
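
The text above does not say whether the 1e-3 tolerance is absolute or relative. The sketch below assumes an absolute tolerance; `const_matches` is a hypothetical helper, not part of the library.

```python
import math


def const_matches(actual: float, expected: float, tol: float = 1e-3) -> bool:
    """Assumed semantics: |actual - expected| <= tol (absolute tolerance)."""
    return math.isclose(actual, expected, rel_tol=0.0, abs_tol=tol)


print(const_matches(0.5004, 0.5))            # True
print(const_matches(0.502, 0.5))             # False
print(const_matches(1.4142, 1.41421356237))  # True
```

A tolerance like this lets a pattern written with a rounded constant still match the value an exporter stored in the model.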

Pattern.op(op_type, *input_patterns) → Pattern

Matches a node with the given op_type, optionally constraining its parent nodes via input_patterns.

erf_of_x = Pattern.op("Erf", Pattern.any())

Patterns created with Pattern.op can also be used as unary functions by calling them:

relu = Pattern.op("Relu")
applied = relu(Pattern.any())   # same as Pattern.op("Relu", Pattern.any())

Note: A pattern can only be called as a unary function if it has no existing inputs. Use Pattern.op(op_type, *inputs) directly for multi-input patterns.

Operator Overloads (DSL Sugar)

You can compose patterns using standard Python operators:

Expression   Resulting Pattern
a + b        Add(a, b)
a * b        Mul(a, b)
a ** b       Pow(a, b)
a - b        Sub(a, b)
a / b        Div(a, b)
-a           Neg(a)

Raw Python scalars are automatically coerced to Pattern.const(value):

x = Pattern.any()
cube = x ** 3.0    # equivalent to x ** Pattern.const(3.0)
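
For readers curious how such sugar can work, here is a miniature of the idea built on Python's operator hooks. The class `P` is hypothetical and far simpler than the library's Pattern; it only shows how operators build a tree and how scalars get wrapped as constants.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class P:
    """Hypothetical miniature of the pattern DSL, not the library class."""
    op: str
    inputs: Tuple["P", ...] = ()
    value: float = 0.0

    @staticmethod
    def any() -> "P":
        return P("Any")

    @staticmethod
    def const(value: float) -> "P":
        return P("Const", value=float(value))

    @staticmethod
    def _wrap(other) -> "P":
        # Raw Python scalars become constant patterns.
        return other if isinstance(other, P) else P.const(other)

    def __add__(self, other): return P("Add", (self, P._wrap(other)))
    def __sub__(self, other): return P("Sub", (self, P._wrap(other)))
    def __mul__(self, other): return P("Mul", (self, P._wrap(other)))
    def __truediv__(self, other): return P("Div", (self, P._wrap(other)))
    def __pow__(self, other): return P("Pow", (self, P._wrap(other)))
    def __neg__(self): return P("Neg", (self,))

    def __str__(self) -> str:
        if self.op == "Any":
            return "x"
        if self.op == "Const":
            return repr(self.value)
        return f"{self.op}({', '.join(map(str, self.inputs))})"


x = P.any()
print(x ** 3.0)            # Pow(x, 3.0)
print(-(x * 0.5) + 1.0)    # Add(Neg(Mul(x, 0.5)), 1.0)
```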

Built-in Activation Patterns

Commonly used activation functions are available as class methods. All accept a Pattern argument representing the input.

Method                        Activation
Pattern.relu(x)               ReLU
Pattern.relu6(x)              ReLU6 (Clip 0–6)
Pattern.sigmoid(x)            Sigmoid
Pattern.tanh(x)               Tanh
Pattern.leaky_relu(x)         LeakyReLU
Pattern.elu(x)                ELU
Pattern.selu(x)               SELU
Pattern.softplus(x)           Softplus
Pattern.softsign(x)           Softsign
Pattern.hardsigmoid(x)        HardSigmoid
Pattern.hardswish(x)          HardSwish (x * σ_hard(x))
Pattern.silu(x)               SiLU (x * sigmoid(x))
Pattern.swish(x)              Swish (alias of SiLU)
Pattern.gelu(x)               GELU (erf approximation)
Pattern.gelu_tanh(x)          GELU (tanh approximation)
Pattern.mish(x)               Mish (x * tanh(softplus(x)))
Pattern.softmax(x)            Softmax
Pattern.log_softmax(x)        LogSoftmax
Pattern.prelu(x, slope)       PReLU
Pattern.thresholded_relu(x)   ThresholdedReLU

Example: Building a Custom Pattern

x = Pattern.any()

# Manual GELU (erf form)
gelu = x * (Pattern.op("Erf")(x / Pattern.const(1.41421356237)) + Pattern.const(1.0)) * Pattern.const(0.5)

# Or use the built-in helper
gelu = Pattern.gelu(x)
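
The manual pattern mirrors the erf form of GELU, x * 0.5 * (1 + erf(x / √2)), where 1.41421356237 is √2. As a numeric sanity check using only the standard library, the sketch below evaluates that formula next to the usual tanh approximation. Both functions are plain math, independent of the library.

```python
import math


def gelu_erf(x: float) -> float:
    """Erf form: 0.5 * x * (1 + erf(x / sqrt(2)))."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# The constant used in the manual pattern is sqrt(2) to 11 decimal places.
print(abs(math.sqrt(2.0) - 1.41421356237) < 1e-9)  # True

# The two forms stay close to each other across a typical input range.
worst = max(abs(gelu_erf(i / 10) - gelu_tanh(i / 10)) for i in range(-50, 51))
print(worst < 5e-3)  # True
```

The two forms compute nearly the same values but produce structurally different graphs, which is why they need separate patterns.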

Pattern.detect

Matches a Pattern against a specific subgraph rooted at a given node, using DFS with commutativity support for Add and Mul nodes.

detector = Pattern.detect(model, start_node="MatMul_0")
result = detector.match(pattern)
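
The idea of a DFS match with commutativity can be sketched on small expression trees. This is a toy illustration under simplifying assumptions, not the library's algorithm: nodes and patterns are hypothetical `(op_type, inputs)` tuples, wildcards are not bound to a specific node, and end_node is ignored.

```python
from itertools import permutations

COMMUTATIVE = {"Add", "Mul"}
ANY = ("Any", ())  # wildcard pattern


def match(node, pattern) -> bool:
    """Depth-first structural match of a node tree against a pattern tree."""
    if pattern == ANY:
        return True
    op, inputs = node
    p_op, p_inputs = pattern
    if op != p_op or len(inputs) != len(p_inputs):
        return False
    # Commutative ops may match with their inputs in any order.
    orders = permutations(inputs) if op in COMMUTATIVE else [tuple(inputs)]
    return any(
        all(match(child, sub) for child, sub in zip(order, p_inputs))
        for order in orders
    )


x = ("Input", ())
sigmoid_x = ("Sigmoid", (x,))
silu_pattern = ("Mul", (ANY, ("Sigmoid", (ANY,))))
sub_pattern = ("Sub", (ANY, ("Sigmoid", (ANY,))))

print(match(("Mul", (x, sigmoid_x)), silu_pattern))  # True
print(match(("Mul", (sigmoid_x, x)), silu_pattern))  # True
print(match(("Sub", (sigmoid_x, x)), sub_pattern))   # False
```

Swapping the inputs of Mul still matches because Mul is treated as commutative, while the same swap on Sub does not.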

Constructor

Pattern.detect(model, start_node=None, end_node=None)

Parameter    Type                      Description
model        ModelProto or _GraphShim  The ONNX model to search in
start_node   str or NodeProto          Root of the subgraph to check (name string or proto object)
end_node     str or NodeProto          Optional stop condition; a branch succeeds when this node is reached

match(pattern: Pattern) → bool

Returns True if the subgraph rooted at start_node matches pattern.

x = Pattern.any()
detector = Pattern.detect(model, start_node="Mul_5")
print(detector.match(Pattern.silu(x)))  # True or False

find_all(pattern: Pattern) → list[NodeProto]

Scans nodes for all subgraph matches against pattern. If start_node was provided, only descendants of that node are scanned; otherwise the entire graph is searched.

detector = Pattern.detect(model)
all_gelus = detector.find_all(Pattern.gelu(Pattern.any()))

Recipes

Find all depthwise convolutions

parser = ONNXParser("model.onnx")
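# Depthwise convs have group == in_channels, so the value to match depends on the layer.
# Example for 32 input channels (32 is a placeholder; substitute your model's value):
dw_convs = parser.find().find_by_op_type("Conv").find_by_attribute("group", value=32)
# Omit the value to list every Conv that carries a "group" attribute:
grouped_convs = parser.find().find_by_op_type("Conv").find_by_attribute("group")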
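
Since depthwise means group equals in_channels, the check can also be made from a Conv's group attribute and its weight shape. ONNX lays Conv weights out as (out_channels, in_channels / group, kernel dims...). `is_depthwise` below is a hypothetical helper, not part of the library.

```python
from typing import Sequence


def is_depthwise(group: int, weight_shape: Sequence[int]) -> bool:
    """True when every input channel forms its own group (group == in_channels)."""
    # weight_shape[1] holds in_channels // group, so recover in_channels from it.
    in_channels = weight_shape[1] * group
    return group > 1 and group == in_channels


print(is_depthwise(32, (32, 1, 3, 3)))   # True
print(is_depthwise(1, (64, 32, 3, 3)))   # False
print(is_depthwise(2, (64, 16, 3, 3)))   # False
```

The last case is a grouped convolution with two groups over 32 input channels, which is not depthwise.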

Retrieve weights of a specific layer

conv_node = parser.find().find_by_name("features.0.0", exact=True)
params = conv_node.tensor()
# {"features.0.0.weight": array(...), "features.0.0.bias": array(...)}

Find all SiLU / Swish activations in the graph

x = Pattern.any()
swish_muls = (
    parser.find()
    .find_by_op_type("Mul")
    .matches(Pattern.silu(x))
)
print(f"Found {swish_muls.count()} SiLU/Swish activation(s)")

Walk from graph inputs to the first Conv

entry = parser.find().entry_nodes
# descendants() excludes the entry nodes themselves, so union them back in
first_conv = entry.union(entry.descendants()).find_by_op_type("Conv").first()

Check whether any GELU variants exist

x = Pattern.any()
detector = Pattern.detect(parser.model)

has_gelu = bool(detector.find_all(Pattern.gelu(x)))
has_gelu_approx = bool(detector.find_all(Pattern.gelu_tanh(x)))

API Reference Summary

ONNXParser

Method                                          Returns
ONNXParser(onnx_path)                           instance
find()                                          ONNXQuery
pattern_detect(pattern, start_node, end_node)   bool
summary()                                       str

ONNXQuery

Method / Property                          Returns
find_by_op_type(op_type)                   ONNXQuery
find_by_name(name, exact=False)            ONNXQuery
find_by_tensor(tensor_name)                ONNXQuery
find_by_param_name(name, exact=False)      ONNXQuery
find_by_attribute(attr_name, value=None)   ONNXQuery
has_params()                               ONNXQuery
children()                                 ONNXQuery
parents()                                  ONNXQuery
ancestors(max_depth=100)                   ONNXQuery
descendants(max_depth=100)                 ONNXQuery
entry_nodes                                ONNXQuery
output_nodes                               ONNXQuery
tensor(name=None)                          dict / array
single_tensor                              np.ndarray
matches(pattern)                           ONNXQuery
union(other)                               ONNXQuery
intersection(other)                        ONNXQuery
difference(other)                          ONNXQuery
count()                                    int
is_empty()                                 bool
op_types()                                 list[str]
first()                                    NodeProto or None
last()                                     NodeProto or None
single_node                                NodeProto

Pattern

Constructor / Method                    Description
Pattern.any()                           Wildcard
Pattern.const(value)                    Constant tensor matcher
Pattern.op(op_type, *inputs)            Op-type matcher with optional input constraints
Arithmetic operators (+, *, **, etc.)   Compose Add / Mul / Pow / Sub / Div / Neg nodes
Activation helpers (see table above)    Pre-built common activation patterns

Pattern.detect

Method              Returns
match(pattern)      bool
find_all(pattern)   list[NodeProto]
