PyTorch IR extraction framework for compiler backends

IR Extraction Framework

A framework for extracting compiler-backend IR (Intermediate Representation) from PyTorch models.

Quick Start

Installation

# Using uv (recommended); run from a clone of the repository
uv sync

# Or using pip (editable install from the same checkout)
pip install -e .

Basic Usage

import torch
import torch.nn as nn
from torch_ir import extract_ir, ir_to_mermaid

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# 1. Create model on meta device (no actual weights loaded)
with torch.device('meta'):
    model = SimpleMLP()
model.eval()

# 2. Extract IR
example_inputs = (torch.randn(1, 4, device='meta'),)
ir = extract_ir(model, example_inputs)

# 3. Save IR
ir.save("model_ir.json")

# 4. Visualize IR
print(ir_to_mermaid(ir))

Extracted IR

The example above produces the following JSON. Each node records its ATen op type, input/output tensor metadata, and producer-consumer relationships. Weight values are not included.

{
  "model_name": "SimpleMLP",
  "graph_inputs":  [{"name": "x", "shape": [1, 4], "dtype": "float32"}],
  "graph_outputs": [{"name": "linear_1", "shape": [1, 2], "dtype": "float32"}],
  "weights": [
    {"name": "fc1.weight", "shape": [8, 4], "dtype": "float32"},
    {"name": "fc1.bias",   "shape": [8],    "dtype": "float32"},
    {"name": "fc2.weight", "shape": [2, 8], "dtype": "float32"},
    {"name": "fc2.bias",   "shape": [2],    "dtype": "float32"}
  ],
  "nodes": [
    {
      "name": "linear", "op_type": "aten.linear.default",
      "inputs":  [{"name": "x", "shape": [1, 4]}, {"name": "p_fc1_weight", "shape": [8, 4]}, {"name": "p_fc1_bias", "shape": [8]}],
      "outputs": [{"name": "linear", "shape": [1, 8]}]
    },
    {
      "name": "relu", "op_type": "aten.relu.default",
      "inputs":  [{"name": "linear", "shape": [1, 8]}],
      "outputs": [{"name": "relu", "shape": [1, 8]}]
    },
    {
      "name": "linear_1", "op_type": "aten.linear.default",
      "inputs":  [{"name": "relu", "shape": [1, 8]}, {"name": "p_fc2_weight", "shape": [2, 8]}, {"name": "p_fc2_bias", "shape": [2]}],
      "outputs": [{"name": "linear_1", "shape": [1, 2]}]
    }
  ]
}
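Because the saved IR is plain JSON, the producer-consumer relationships can be recovered with the standard library alone. The snippet below is a minimal sketch that works on a trimmed copy of the JSON above; `build_consumers` is a hypothetical helper written for this illustration, not part of the package.

```python
import json

# Trimmed copy of the JSON shown above (weights, shapes and dtypes omitted).
IR_JSON = """
{
  "graph_inputs":  [{"name": "x"}],
  "graph_outputs": [{"name": "linear_1"}],
  "nodes": [
    {"name": "linear", "op_type": "aten.linear.default",
     "inputs": [{"name": "x"}, {"name": "p_fc1_weight"}, {"name": "p_fc1_bias"}],
     "outputs": [{"name": "linear"}]},
    {"name": "relu", "op_type": "aten.relu.default",
     "inputs": [{"name": "linear"}],
     "outputs": [{"name": "relu"}]},
    {"name": "linear_1", "op_type": "aten.linear.default",
     "inputs": [{"name": "relu"}, {"name": "p_fc2_weight"}, {"name": "p_fc2_bias"}],
     "outputs": [{"name": "linear_1"}]}
  ]
}
"""


def build_consumers(ir):
    """Map each tensor name to the names of the nodes that read it."""
    consumers = {}
    for node in ir["nodes"]:
        for tensor in node["inputs"]:
            consumers.setdefault(tensor["name"], []).append(node["name"])
    return consumers


ir_dict = json.loads(IR_JSON)
consumers = build_consumers(ir_dict)
# Map each tensor name to the node that writes it.
producers = {t["name"]: n["name"] for n in ir_dict["nodes"] for t in n["outputs"]}

for tensor_name, users in consumers.items():
    source = producers.get(tensor_name, "graph input or weight")
    print(f"{tensor_name}: produced by {source}, consumed by {users}")
```

A tensor with no entry in `producers` is either a graph input or a weight, and a tensor with no entry in `consumers` (here `linear_1`) is only read as a graph output.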

IR Visualization

ir_to_mermaid() renders the IR as a Mermaid flowchart. Weight inputs are shown as dashed edges:

flowchart TD
    input_x[/"Input: x<br/>1x4"/]
    op_linear["linear<br/>1x8"]
    input_x -->|"1x4"| op_linear
    w_p_fc1_weight[/"p_fc1_weight<br/>8x4"/]
    w_p_fc1_weight -.->|"8x4"| op_linear
    w_p_fc1_bias[/"p_fc1_bias<br/>8"/]
    w_p_fc1_bias -.->|"8"| op_linear
    op_relu["relu<br/>1x8"]
    op_linear -->|"1x8"| op_relu
    op_linear_1["linear<br/>1x2"]
    op_relu -->|"1x8"| op_linear_1
    w_p_fc2_weight[/"p_fc2_weight<br/>2x8"/]
    w_p_fc2_weight -.->|"2x8"| op_linear_1
    w_p_fc2_bias[/"p_fc2_bias<br/>2"/]
    w_p_fc2_bias -.->|"2"| op_linear_1
    output_0[\"Output<br/>1x2"/]
    op_linear_1 --> output_0
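To make the mapping from JSON to flowchart concrete, here is a simplified sketch of how such a diagram could be generated. It is not the implementation of ir_to_mermaid(); `to_mermaid_sketch` and `shape_label` are hypothetical names, the input dict is a hand-written subset of the JSON above, and output nodes are left out for brevity.

```python
def shape_label(shape):
    """Format a shape such as [1, 4] as '1x4'."""
    return "x".join(str(dim) for dim in shape)


def to_mermaid_sketch(ir):
    """Render a simplified Mermaid flowchart from an IR-like dict."""
    lines = ["flowchart TD"]
    graph_inputs = {t["name"] for t in ir["graph_inputs"]}
    # Map each produced tensor name to the node that produces it.
    producers = {t["name"]: n["name"] for n in ir["nodes"] for t in n["outputs"]}

    for t in ir["graph_inputs"]:
        name, label = t["name"], shape_label(t["shape"])
        lines.append(f'    input_{name}[/"Input: {name}<br/>{label}"/]')

    for node in ir["nodes"]:
        node_name = node["name"]
        node_id = "op_" + node_name
        out_label = shape_label(node["outputs"][0]["shape"])
        lines.append(f'    {node_id}["{node_name}<br/>{out_label}"]')
        for t in node["inputs"]:
            name, label = t["name"], shape_label(t["shape"])
            if name in graph_inputs:
                src, arrow = "input_" + name, "-->"
            elif name in producers:
                src, arrow = "op_" + producers[name], "-->"
            else:
                # Anything else is treated as a weight: declare it, dashed edge.
                src, arrow = "w_" + name, "-.->"
                lines.append(f'    {src}[/"{name}<br/>{label}"/]')
            lines.append(f'    {src} {arrow}|"{label}"| {node_id}')
    return "\n".join(lines)


example_ir = {
    "graph_inputs": [{"name": "x", "shape": [1, 4]}],
    "nodes": [
        {"name": "linear",
         "inputs": [{"name": "x", "shape": [1, 4]},
                    {"name": "p_fc1_weight", "shape": [8, 4]},
                    {"name": "p_fc1_bias", "shape": [8]}],
         "outputs": [{"name": "linear", "shape": [1, 8]}]},
        {"name": "relu",
         "inputs": [{"name": "linear", "shape": [1, 8]}],
         "outputs": [{"name": "relu", "shape": [1, 8]}]},
    ],
}

diagram = to_mermaid_sketch(example_ir)
print(diagram)
```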

Verification

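import torch
# Assumption: verify_ir_with_state_dict is exported from the top-level package,
# like extract_ir and ir_to_mermaid; adjust the import if it lives in a submodule.
from torch_ir import verify_ir_with_state_dict

# Compare the original model's output with the result of executing the IR.
# SimpleMLP and ir come from the Basic Usage example above.
original_model = SimpleMLP()
original_model.load_state_dict(torch.load('weights.pt'))
original_model.eval()

test_input = torch.randn(1, 4)
is_valid, report = verify_ir_with_state_dict(
    ir=ir,
    state_dict=original_model.state_dict(),
    original_model=original_model,
    test_inputs=(test_input,),
)

print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")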
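Conceptually, verification means executing the IR node by node with real weights and comparing the result with the original model's output within a tolerance. The following pure-Python sketch illustrates that idea on a tiny 2-2-1 MLP; it does not reflect how the package executes IR. The simplified node schema (`inputs` as a list of names, a single `output`), the `OPS` table, `run_ir`, `all_close`, and the tolerance values are all assumptions made for this example.

```python
import math


def linear(x, w, b):
    """x: [batch][in], w: [out][in], b: [out] -> [batch][out]."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) + bias
             for w_row, bias in zip(w, b)] for row in x]


def relu(x):
    return [[max(0.0, v) for v in row] for row in x]


# Hypothetical op table keyed by ATen op name.
OPS = {"aten.linear.default": linear, "aten.relu.default": relu}


def run_ir(nodes, env):
    """Execute nodes in order; env maps tensor names to values."""
    env = dict(env)
    for node in nodes:
        args = [env[name] for name in node["inputs"]]
        env[node["output"]] = OPS[node["op_type"]](*args)
    return env


def all_close(a, b, rel_tol=1e-5, abs_tol=1e-8):
    """Element-wise comparison of two [batch][features] nested lists."""
    return all(math.isclose(u, v, rel_tol=rel_tol, abs_tol=abs_tol)
               for row_a, row_b in zip(a, b) for u, v in zip(row_a, row_b))


def reference_model(x):
    """Hand-written forward pass that plays the role of the original model."""
    h0 = max(0.0, x[0][0] + 0.5)
    h1 = max(0.0, x[0][1] + 0.5)
    return [[2.0 * h0 + 3.0 * h1 - 1.0]]


nodes = [
    {"op_type": "aten.linear.default",
     "inputs": ["x", "p_fc1_weight", "p_fc1_bias"], "output": "linear"},
    {"op_type": "aten.relu.default", "inputs": ["linear"], "output": "relu"},
    {"op_type": "aten.linear.default",
     "inputs": ["relu", "p_fc2_weight", "p_fc2_bias"], "output": "linear_1"},
]
weights = {
    "p_fc1_weight": [[1.0, 0.0], [0.0, 1.0]], "p_fc1_bias": [0.5, 0.5],
    "p_fc2_weight": [[2.0, 3.0]], "p_fc2_bias": [-1.0],
}
x = [[1.0, -2.0]]

ir_output = run_ir(nodes, {"x": x, **weights})["linear_1"]
is_valid = all_close(ir_output, reference_model(x))
print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")
```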

Dependencies

  • Python >= 3.10
  • PyTorch >= 2.1

Running Tests

# Basic tests
uv run pytest tests/ -v

# Comprehensive tests (all test models)
uv run pytest tests/test_comprehensive.py -v

# Generate reports
uv run pytest tests/test_comprehensive.py --generate-reports --output reports/

# Filter by category
uv run pytest tests/test_comprehensive.py -k "attention" -v

# Run via CLI
uv run python -m tests --output reports/
uv run python -m tests --list-models
uv run python -m tests --category attention

Features

  • Weight-free extraction: Uses meta tensors to extract only graph structure without loading actual weights into memory
  • torch.export based: Captures the graph through torch.export, which traces with TorchDynamo and is PyTorch's recommended path for exporting whole-model graphs
  • Complete metadata: Automatically extracts shape and dtype information for all tensors
  • IR execution & verification: Execute the extracted IR and verify results match the original model
  • Extensible design: Provides a custom operator registration mechanism
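Weight-free extraction works because output shapes depend only on input shapes, never on tensor values. The sketch below illustrates that point with plain shape arithmetic for the SimpleMLP from Basic Usage; it is an analogy for what meta tensors provide, not PyTorch's or this package's implementation, and `linear_shape` and `relu_shape` are hypothetical helpers.

```python
def linear_shape(x_shape, weight_shape):
    """Output shape of a linear layer, given input and weight shapes only."""
    *batch, in_features = x_shape
    out_features, weight_in = weight_shape
    if in_features != weight_in:
        raise ValueError(
            f"input has {in_features} features, weight expects {weight_in}")
    return [*batch, out_features]


def relu_shape(x_shape):
    """Element-wise ops keep the input shape."""
    return list(x_shape)


# Propagate shapes through fc1 -> relu -> fc2 without allocating any data.
x_shape = [1, 4]
fc1_out = linear_shape(x_shape, [8, 4])
relu_out = relu_shape(fc1_out)
fc2_out = linear_shape(relu_out, [2, 8])

print(fc1_out, relu_out, fc2_out)
```

The resulting shapes match the `shape` fields recorded for `linear`, `relu`, and `linear_1` in the extracted JSON.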

License

MIT License
