
PyTorch IR extraction framework for compiler backends



IR Extraction Framework


A framework for extracting compiler-backend IR (Intermediate Representation) from PyTorch models.

Quick Start

Installation

# From the repository root, using uv (recommended)
uv sync

# Or using pip (editable install from the repository root)
pip install -e .

Basic Usage

import torch
import torch.nn as nn
from torch_ir import extract_ir, ir_to_mermaid

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# 1. Create the model on the meta device (parameters carry shape and dtype but no data)
with torch.device('meta'):
    model = SimpleMLP()
model.eval()

# 2. Extract IR
example_inputs = (torch.randn(1, 4, device='meta'),)
ir = extract_ir(model, example_inputs)

# 3. Save IR
ir.save("model_ir.json")

# 4. Visualize IR
print(ir_to_mermaid(ir))

Extracted IR

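Saving the IR above produces the following JSON. Each node records its ATen op type, its input/output tensor metadata, and its producer-consumer relationships (a node's input names match the output names of the nodes that produce them). Weight values are not included: weights are listed only by name, shape, and dtype.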

{
  "model_name": "SimpleMLP",
  "graph_inputs":  [{"name": "x", "shape": [1, 4], "dtype": "float32"}],
  "graph_outputs": [{"name": "linear_1", "shape": [1, 2], "dtype": "float32"}],
  "weights": [
    {"name": "fc1.weight", "shape": [8, 4], "dtype": "float32"},
    {"name": "fc1.bias",   "shape": [8],    "dtype": "float32"},
    {"name": "fc2.weight", "shape": [2, 8], "dtype": "float32"},
    {"name": "fc2.bias",   "shape": [2],    "dtype": "float32"}
  ],
  "nodes": [
    {
      "name": "linear", "op_type": "aten.linear.default",
      "inputs":  [{"name": "x", "shape": [1, 4]}, {"name": "p_fc1_weight", "shape": [8, 4]}, {"name": "p_fc1_bias", "shape": [8]}],
      "outputs": [{"name": "linear", "shape": [1, 8]}]
    },
    {
      "name": "relu", "op_type": "aten.relu.default",
      "inputs":  [{"name": "linear", "shape": [1, 8]}],
      "outputs": [{"name": "relu", "shape": [1, 8]}]
    },
    {
      "name": "linear_1", "op_type": "aten.linear.default",
      "inputs":  [{"name": "relu", "shape": [1, 8]}, {"name": "p_fc2_weight", "shape": [2, 8]}, {"name": "p_fc2_bias", "shape": [2]}],
      "outputs": [{"name": "linear_1", "shape": [1, 2]}]
    }
  ]
}
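Because nodes refer to each other only by tensor name, the producer-consumer edges can be recovered by matching each node input against the outputs of other nodes. The sketch below does this with the standard `json` module on a trimmed copy of the JSON above (only the `nodes`, `name`, `inputs`, and `outputs` fields). The helper `build_edges` is hypothetical and not part of torch_ir.

```python
import json

# Trimmed copy of the saved IR above: only the fields this sketch needs.
IR_JSON = """
{
  "nodes": [
    {"name": "linear",
     "inputs": [{"name": "x"}, {"name": "p_fc1_weight"}, {"name": "p_fc1_bias"}],
     "outputs": [{"name": "linear"}]},
    {"name": "relu",
     "inputs": [{"name": "linear"}],
     "outputs": [{"name": "relu"}]},
    {"name": "linear_1",
     "inputs": [{"name": "relu"}, {"name": "p_fc2_weight"}, {"name": "p_fc2_bias"}],
     "outputs": [{"name": "linear_1"}]}
  ]
}
"""


def build_edges(ir):
    """Return (producer, consumer, tensor) triples for node-to-node edges.

    Hypothetical helper: inputs that no node produces (graph inputs and
    weights) are skipped, since they have no producer node.
    """
    # Map each tensor name to the node that outputs it.
    producer_of = {}
    for node in ir["nodes"]:
        for out in node["outputs"]:
            producer_of[out["name"]] = node["name"]

    edges = []
    for node in ir["nodes"]:
        for inp in node["inputs"]:
            src = producer_of.get(inp["name"])
            if src is not None:
                edges.append((src, node["name"], inp["name"]))
    return edges


ir = json.loads(IR_JSON)
for src, dst, tensor in build_edges(ir):
    print(f"{src} -> {dst} via {tensor}")
```

For this graph it prints `linear -> relu via linear` and `relu -> linear_1 via relu`; the weight placeholders such as `p_fc1_weight` produce no edge because no node outputs them.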

IR Visualization

ir_to_mermaid() renders the IR as a Mermaid flowchart. Weight inputs are shown as dashed edges:

flowchart TD
    input_x[/"Input: x<br/>1x4"/]
    op_linear["linear<br/>1x8"]
    input_x -->|"1x4"| op_linear
    w_p_fc1_weight[/"p_fc1_weight<br/>8x4"/]
    w_p_fc1_weight -.->|"8x4"| op_linear
    w_p_fc1_bias[/"p_fc1_bias<br/>8"/]
    w_p_fc1_bias -.->|"8"| op_linear
    op_relu["relu<br/>1x8"]
    op_linear -->|"1x8"| op_relu
    op_linear_1["linear<br/>1x2"]
    op_relu -->|"1x8"| op_linear_1
    w_p_fc2_weight[/"p_fc2_weight<br/>2x8"/]
    w_p_fc2_weight -.->|"2x8"| op_linear_1
    w_p_fc2_bias[/"p_fc2_bias<br/>2"/]
    w_p_fc2_bias -.->|"2"| op_linear_1
    output_0[\"Output<br/>1x2"/]
    op_linear_1 --> output_0
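To make the mapping from IR to diagram concrete, here is a simplified re-implementation of the idea. It is not torch_ir's own `ir_to_mermaid()`, and its node ids and labels differ. Inputs produced by another node or coming from the graph inputs become solid edges; every other input is assumed to be a weight and drawn with a dashed edge, matching the convention in the diagram above. The function name `to_mermaid` and the `IR` dict (a trimmed copy of the example JSON) are illustrative.

```python
def to_mermaid(ir):
    """Render a simplified Mermaid flowchart from a torch_ir-style IR dict.

    Illustrative only: ids, labels, and ordering differ from ir_to_mermaid().
    """
    graph_inputs = {t["name"] for t in ir["graph_inputs"]}
    # tensor name -> name of the node that produces it
    produced = {o["name"]: n["name"] for n in ir["nodes"] for o in n["outputs"]}
    lines = ["flowchart TD"]
    for name in sorted(graph_inputs):
        lines.append(f'    in_{name}[/"Input: {name}"/]')
    for node in ir["nodes"]:
        op = f'op_{node["name"]}'
        lines.append(f'    {op}["{node["name"]}"]')
        for inp in node["inputs"]:
            name = inp["name"]
            shape = "x".join(str(d) for d in inp["shape"])
            if name in produced:          # activation from another node: solid edge
                lines.append(f'    op_{produced[name]} -->|"{shape}"| {op}')
            elif name in graph_inputs:    # model input: solid edge
                lines.append(f'    in_{name} -->|"{shape}"| {op}')
            else:                         # anything else is treated as a weight: dashed edge
                lines.append(f'    w_{name}[/"{name}"/]')
                lines.append(f'    w_{name} -.->|"{shape}"| {op}')
    for i, out in enumerate(ir["graph_outputs"]):
        lines.append(f'    out_{i}(["Output"])')
        lines.append(f'    op_{produced[out["name"]]} --> out_{i}')
    return "\n".join(lines)


IR = {
    "graph_inputs": [{"name": "x", "shape": [1, 4]}],
    "graph_outputs": [{"name": "linear_1", "shape": [1, 2]}],
    "nodes": [
        {"name": "linear",
         "inputs": [{"name": "x", "shape": [1, 4]},
                    {"name": "p_fc1_weight", "shape": [8, 4]},
                    {"name": "p_fc1_bias", "shape": [8]}],
         "outputs": [{"name": "linear", "shape": [1, 8]}]},
        {"name": "relu",
         "inputs": [{"name": "linear", "shape": [1, 8]}],
         "outputs": [{"name": "relu", "shape": [1, 8]}]},
        {"name": "linear_1",
         "inputs": [{"name": "relu", "shape": [1, 8]},
                    {"name": "p_fc2_weight", "shape": [2, 8]},
                    {"name": "p_fc2_bias", "shape": [2]}],
         "outputs": [{"name": "linear_1", "shape": [1, 2]}]},
    ],
}

print(to_mermaid(IR))
```

Classifying weights as "whatever is neither a graph input nor a node output" avoids relying on placeholder naming conventions such as the `p_` prefix.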

Verification

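# Assumed to be importable alongside extract_ir; adjust the module path if it differs
from torch_ir import verify_ir_with_state_dict

# Compare the original model's output with the result of executing the IR
original_model = SimpleMLP()  # regular (non-meta) device this time, so real weights can be loaded
original_model.load_state_dict(torch.load('weights.pt'))  # 'weights.pt': path to your trained checkpoint
original_model.eval()

test_input = torch.randn(1, 4)
is_valid, report = verify_ir_with_state_dict(
    ir=ir,
    state_dict=original_model.state_dict(),
    original_model=original_model,
    test_inputs=(test_input,),
)

print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")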
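Conceptually, verification means executing the IR's nodes in order against a table of op implementations, feeding weights from a state dict, and comparing the outputs with a reference. The pure-Python sketch below illustrates that idea on the SimpleMLP graph; it is not torch_ir's verifier. The `OPS` registry, the `register_op` decorator, the `PARAM_NAMES` mapping from placeholder names to state-dict keys, and the toy weights are all hypothetical, and tensors are nested lists rather than torch tensors. A real check would use the PyTorch model as the reference instead of a hand-written composition.

```python
import math

OPS = {}  # hypothetical registry: ATen op name -> Python implementation


def register_op(op_type):
    """Hypothetical decorator that registers an implementation for op_type."""
    def wrap(fn):
        OPS[op_type] = fn
        return fn
    return wrap


@register_op("aten.linear.default")
def linear(x, weight, bias):
    # y = x @ weight.T + bias; x is [batch][in], weight is [out][in]
    return [[sum(a * w for a, w in zip(row, w_row)) + b
             for w_row, b in zip(weight, bias)] for row in x]


@register_op("aten.relu.default")
def relu(x):
    return [[max(0.0, v) for v in row] for row in x]


def run_ir(ir, state_dict, param_names, inputs):
    """Execute IR nodes in order and return the graph outputs.

    param_names maps placeholder names (p_fc1_weight) to state-dict keys
    (fc1.weight). Each node is assumed to have exactly one output.
    """
    env = {t["name"]: v for t, v in zip(ir["graph_inputs"], inputs)}
    env.update({p: state_dict[k] for p, k in param_names.items()})
    for node in ir["nodes"]:
        args = [env[t["name"]] for t in node["inputs"]]
        env[node["outputs"][0]["name"]] = OPS[node["op_type"]](*args)
    return [env[t["name"]] for t in ir["graph_outputs"]]


IR = {
    "graph_inputs": [{"name": "x"}],
    "graph_outputs": [{"name": "linear_1"}],
    "nodes": [
        {"op_type": "aten.linear.default",
         "inputs": [{"name": "x"}, {"name": "p_fc1_weight"}, {"name": "p_fc1_bias"}],
         "outputs": [{"name": "linear"}]},
        {"op_type": "aten.relu.default",
         "inputs": [{"name": "linear"}], "outputs": [{"name": "relu"}]},
        {"op_type": "aten.linear.default",
         "inputs": [{"name": "relu"}, {"name": "p_fc2_weight"}, {"name": "p_fc2_bias"}],
         "outputs": [{"name": "linear_1"}]},
    ],
}
PARAM_NAMES = {"p_fc1_weight": "fc1.weight", "p_fc1_bias": "fc1.bias",
               "p_fc2_weight": "fc2.weight", "p_fc2_bias": "fc2.bias"}

# Toy deterministic weights with SimpleMLP's shapes (8x4, 8, 2x8, 2).
state_dict = {
    "fc1.weight": [[0.1 * (i - j) for j in range(4)] for i in range(8)],
    "fc1.bias": [0.01 * i for i in range(8)],
    "fc2.weight": [[0.05 * (i + j) for j in range(8)] for i in range(2)],
    "fc2.bias": [0.5, -0.5],
}
x = [[1.0, -2.0, 3.0, 0.5]]

(ir_out,) = run_ir(IR, state_dict, PARAM_NAMES, [x])

# Reference: the same MLP composed directly, standing in for the PyTorch model.
sd = state_dict
expected = linear(relu(linear(x, sd["fc1.weight"], sd["fc1.bias"])), sd["fc2.weight"], sd["fc2.bias"])
is_valid = all(math.isclose(a, b, rel_tol=1e-6, abs_tol=1e-9)
               for row_a, row_b in zip(ir_out, expected) for a, b in zip(row_a, row_b))
print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")
```

Keeping op implementations in a registry keyed by the IR's `op_type` string means a new operator only needs one registered function, which mirrors the idea behind custom operator registration.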


Dependencies

  • Python >= 3.10
  • PyTorch >= 2.1

Running Tests

# Basic tests
uv run pytest tests/ -v

# Comprehensive tests (all test models)
uv run pytest tests/test_comprehensive.py -v

# Generate reports
uv run pytest tests/test_comprehensive.py --generate-reports --output reports/

# Filter by category
uv run pytest tests/test_comprehensive.py -k "attention" -v

# Run via CLI
uv run python -m tests --output reports/
uv run python -m tests --list-models
uv run python -m tests --category attention

Features

  • Weight-free extraction: Uses meta tensors, so only the graph structure is captured and no weights are loaded into memory
  • Built on torch.export: Tracing goes through TorchDynamo, the graph-capture approach PyTorch recommends for exporting models
  • Complete metadata: Records shape and dtype information for every tensor automatically
  • IR execution & verification: Executes the extracted IR and checks that its results match the original model
  • Extensible design: Provides a registration mechanism for custom operators
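Because the IR records every weight's shape and dtype, the memory that weight-free extraction avoids allocating can be computed from the metadata alone. A minimal sketch over the `weights` list from the example JSON; the `DTYPE_BYTES` table covers only a few common dtype strings and assumes the same spelling as `"float32"` in the IR.

```python
import math

# Byte sizes for a few dtype strings (assumed to be spelled like "float32" in the IR).
DTYPE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2, "int64": 8, "int32": 4, "bool": 1}

WEIGHTS = [
    {"name": "fc1.weight", "shape": [8, 4], "dtype": "float32"},
    {"name": "fc1.bias",   "shape": [8],    "dtype": "float32"},
    {"name": "fc2.weight", "shape": [2, 8], "dtype": "float32"},
    {"name": "fc2.bias",   "shape": [2],    "dtype": "float32"},
]


def weight_footprint(weights):
    """Return (parameter count, bytes) implied by the IR's weight metadata."""
    numel = [math.prod(w["shape"]) for w in weights]  # prod of [] is 1 (scalar)
    total_bytes = sum(n * DTYPE_BYTES[w["dtype"]] for n, w in zip(numel, weights))
    return sum(numel), total_bytes


params, nbytes = weight_footprint(WEIGHTS)
print(f"{params} parameters, {nbytes} bytes")
```

For SimpleMLP this prints `58 parameters, 232 bytes` (32 + 8 + 16 + 2 float32 values); the same arithmetic on a large model shows how much memory the meta device saves.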

License

MIT License


Download files


Source Distribution

pytorch_ir-0.2.2.tar.gz (170.2 kB view details)

Uploaded Source

Built Distribution


pytorch_ir-0.2.2-py3-none-any.whl (31.0 kB view details)

Uploaded Python 3

File details

Details for the file pytorch_ir-0.2.2.tar.gz.

File metadata

  • Download URL: pytorch_ir-0.2.2.tar.gz
  • Upload date:
  • Size: 170.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pytorch_ir-0.2.2.tar.gz
Algorithm Hash digest
SHA256 46e6faaf61e7e988d13b6ff79b3328e752d575abfd07373772cd6cecd2c4b7f0
MD5 3a136c556e0b9ef7aa5512eb4f6b10ac
BLAKE2b-256 22dc5ac8393dab10bb153207f0bce818e74cb836218b4c0c3359e7ec7cb2e8b4


Provenance

The following attestation bundles were made for pytorch_ir-0.2.2.tar.gz:

Publisher: publish.yml on sweetcocoa/pytorch-ir

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file pytorch_ir-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: pytorch_ir-0.2.2-py3-none-any.whl
  • Upload date:
  • Size: 31.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pytorch_ir-0.2.2-py3-none-any.whl
Algorithm Hash digest
SHA256 c2e772a0d488a30a34369b4f086594dee9304aedfe569bc077d15fb3040e7e65
MD5 1a73b5458b449b9d9d213f29fd757f0a
BLAKE2b-256 d614ad7517882c5365e098d3ae68e88833623d47b899546549e39f222cf6619e


Provenance

The following attestation bundles were made for pytorch_ir-0.2.2-py3-none-any.whl:

Publisher: publish.yml on sweetcocoa/pytorch-ir

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
