
PyTorch IR extraction framework for compiler backends



IR Extraction Framework


A framework for extracting compiler-backend IR (Intermediate Representation) from PyTorch models.

Quick Start

Installation

# Using uv (recommended)
uv sync

# Or using pip
pip install -e .

Basic Usage

import torch
import torch.nn as nn
from torch_ir import extract_ir, ir_to_mermaid

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# 1. Create model on meta device (no actual weights loaded)
with torch.device('meta'):
    model = SimpleMLP()
model.eval()

# 2. Extract IR
example_inputs = (torch.randn(1, 4, device='meta'),)
ir = extract_ir(model, example_inputs)

# 3. Save IR
ir.save("model_ir.json")

# 4. Visualize IR
print(ir_to_mermaid(ir))

Extracted IR

The example above produces the following JSON. Each node records its ATen op type, input/output tensor metadata, and producer-consumer relationships. Weight values are not included; only weight names, shapes, and dtypes are stored.

{
  "model_name": "SimpleMLP",
  "graph_inputs":  [{"name": "x", "shape": [1, 4], "dtype": "float32"}],
  "graph_outputs": [{"name": "linear_1", "shape": [1, 2], "dtype": "float32"}],
  "weights": [
    {"name": "fc1.weight", "shape": [8, 4], "dtype": "float32"},
    {"name": "fc1.bias",   "shape": [8],    "dtype": "float32"},
    {"name": "fc2.weight", "shape": [2, 8], "dtype": "float32"},
    {"name": "fc2.bias",   "shape": [2],    "dtype": "float32"}
  ],
  "nodes": [
    {
      "name": "linear", "op_type": "aten.linear.default",
      "inputs":  [{"name": "x", "shape": [1, 4]}, {"name": "p_fc1_weight", "shape": [8, 4]}, {"name": "p_fc1_bias", "shape": [8]}],
      "outputs": [{"name": "linear", "shape": [1, 8]}]
    },
    {
      "name": "relu", "op_type": "aten.relu.default",
      "inputs":  [{"name": "linear", "shape": [1, 8]}],
      "outputs": [{"name": "relu", "shape": [1, 8]}]
    },
    {
      "name": "linear_1", "op_type": "aten.linear.default",
      "inputs":  [{"name": "relu", "shape": [1, 8]}, {"name": "p_fc2_weight", "shape": [2, 8]}, {"name": "p_fc2_bias", "shape": [2]}],
      "outputs": [{"name": "linear_1", "shape": [1, 2]}]
    }
  ]
}
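
The producer-consumer relationships in this JSON can be recovered with a few lines of standard-library Python. The following is a minimal sketch, not part of the framework: it assumes only the field names visible in the JSON above (`nodes`, `inputs`, `outputs`, `name`), and the helper names `build_producers` and `build_consumers` are hypothetical.

```python
import json

# A trimmed copy of the JSON above (weights list, shapes, and dtypes omitted).
IR_JSON = """
{
  "model_name": "SimpleMLP",
  "graph_inputs":  [{"name": "x"}],
  "graph_outputs": [{"name": "linear_1"}],
  "nodes": [
    {"name": "linear", "op_type": "aten.linear.default",
     "inputs":  [{"name": "x"}, {"name": "p_fc1_weight"}, {"name": "p_fc1_bias"}],
     "outputs": [{"name": "linear"}]},
    {"name": "relu", "op_type": "aten.relu.default",
     "inputs":  [{"name": "linear"}],
     "outputs": [{"name": "relu"}]},
    {"name": "linear_1", "op_type": "aten.linear.default",
     "inputs":  [{"name": "relu"}, {"name": "p_fc2_weight"}, {"name": "p_fc2_bias"}],
     "outputs": [{"name": "linear_1"}]}
  ]
}
"""


def build_producers(ir_dict):
    """Map each tensor name to the node that writes it (hypothetical helper)."""
    producers = {}
    for node in ir_dict["nodes"]:
        for tensor in node["outputs"]:
            producers[tensor["name"]] = node["name"]
    return producers


def build_consumers(ir_dict):
    """Map each tensor name to the nodes that read it (hypothetical helper)."""
    consumers = {}
    for node in ir_dict["nodes"]:
        for tensor in node["inputs"]:
            consumers.setdefault(tensor["name"], []).append(node["name"])
    return consumers


ir_dict = json.loads(IR_JSON)
producers = build_producers(ir_dict)
consumers = build_consumers(ir_dict)

for node in ir_dict["nodes"]:
    # Inputs with no producer node are graph inputs or weights.
    sources = [producers.get(t["name"], "external") for t in node["inputs"]]
    print(node["name"], "reads from", sources)
```

A tensor that appears in `consumers` but not in `producers` is either a graph input or a weight, which is how a backend could tell which buffers it must be handed from outside the graph.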

IR Visualization

ir_to_mermaid() renders the IR as a Mermaid flowchart. Weight inputs are shown as dashed edges:

flowchart TD
    input_x[/"Input: x<br/>1x4"/]
    op_linear["linear<br/>1x8"]
    input_x -->|"1x4"| op_linear
    w_p_fc1_weight[/"p_fc1_weight<br/>8x4"/]
    w_p_fc1_weight -.->|"8x4"| op_linear
    w_p_fc1_bias[/"p_fc1_bias<br/>8"/]
    w_p_fc1_bias -.->|"8"| op_linear
    op_relu["relu<br/>1x8"]
    op_linear -->|"1x8"| op_relu
    op_linear_1["linear<br/>1x2"]
    op_relu -->|"1x8"| op_linear_1
    w_p_fc2_weight[/"p_fc2_weight<br/>2x8"/]
    w_p_fc2_weight -.->|"2x8"| op_linear_1
    w_p_fc2_bias[/"p_fc2_bias<br/>2"/]
    w_p_fc2_bias -.->|"2"| op_linear_1
    output_0[\"Output<br/>1x2"/]
    op_linear_1 --> output_0
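
To make the mapping from IR to diagram concrete, here is a simplified sketch of a renderer that emits the same kinds of lines. It is not the implementation of `ir_to_mermaid()`; the function names `shape_label` and `to_mermaid` are hypothetical, the input is a plain dictionary with the JSON field names shown earlier, and node labels use the node name rather than whatever label the real renderer chooses.

```python
def shape_label(shape):
    """Render [1, 4] as '1x4', the label style used in the diagram."""
    return "x".join(str(dim) for dim in shape)


def to_mermaid(ir_dict):
    """Return Mermaid flowchart text for an IR dictionary (simplified sketch)."""
    graph_inputs = {t["name"] for t in ir_dict["graph_inputs"]}
    # Map each produced tensor name to the node that produced it.
    producers = {
        t["name"]: n["name"] for n in ir_dict["nodes"] for t in n["outputs"]
    }

    lines = ["flowchart TD"]
    for t in ir_dict["graph_inputs"]:
        name, label = t["name"], shape_label(t["shape"])
        lines.append(f'    input_{name}[/"Input: {name}<br/>{label}"/]')

    declared_weights = set()
    for node in ir_dict["nodes"]:
        op = node["name"]
        out_label = shape_label(node["outputs"][0]["shape"])
        lines.append(f'    op_{op}["{op}<br/>{out_label}"]')
        for t in node["inputs"]:
            name, label = t["name"], shape_label(t["shape"])
            if name in graph_inputs:
                # Solid edge from a graph input.
                lines.append(f'    input_{name} -->|"{label}"| op_{op}')
            elif name in producers:
                # Solid edge from the node that produced this tensor.
                lines.append(f'    op_{producers[name]} -->|"{label}"| op_{op}')
            else:
                # Anything else is treated as a weight: dashed edge.
                if name not in declared_weights:
                    declared_weights.add(name)
                    lines.append(f'    w_{name}[/"{name}<br/>{label}"/]')
                lines.append(f'    w_{name} -.->|"{label}"| op_{op}')

    for index, t in enumerate(ir_dict["graph_outputs"]):
        label = shape_label(t["shape"])
        lines.append(f'    output_{index}[\\"Output<br/>{label}"/]')
        lines.append(f"    op_{producers[t['name']]} --> output_{index}")
    return "\n".join(lines)


demo_ir = {
    "graph_inputs": [{"name": "x", "shape": [1, 4]}],
    "graph_outputs": [{"name": "relu", "shape": [1, 8]}],
    "nodes": [
        {"name": "linear",
         "inputs": [{"name": "x", "shape": [1, 4]},
                    {"name": "p_fc1_weight", "shape": [8, 4]},
                    {"name": "p_fc1_bias", "shape": [8]}],
         "outputs": [{"name": "linear", "shape": [1, 8]}]},
        {"name": "relu",
         "inputs": [{"name": "linear", "shape": [1, 8]}],
         "outputs": [{"name": "relu", "shape": [1, 8]}]},
    ],
}

diagram = to_mermaid(demo_ir)
print(diagram)
```

The sketch classifies an input as a weight whenever it is neither a graph input nor the output of another node, which is enough for the example here but is an assumption about how the real renderer decides.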

Verification

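# Compare the original model's output with the result of executing the IR.
# Assumption: verify_ir_with_state_dict is importable from torch_ir, like extract_ir.
from torch_ir import verify_ir_with_state_dict

# Reuses SimpleMLP and ir from Basic Usage; weights.pt holds a saved state_dict.
original_model = SimpleMLP()
original_model.load_state_dict(torch.load('weights.pt'))
original_model.eval()

test_input = torch.randn(1, 4)
is_valid, report = verify_ir_with_state_dict(
    ir=ir,
    state_dict=original_model.state_dict(),
    original_model=original_model,
    test_inputs=(test_input,),
)

print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")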
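
Conceptually, verification means executing the IR node by node with real weights and checking that the result agrees with the original model within a tolerance. The sketch below illustrates that idea in plain Python on nested lists. It is not how `verify_ir_with_state_dict` is implemented: the node format is simplified, all helper names are hypothetical, and the rule that maps `fc1.weight` to `p_fc1_weight` is an assumption inferred from the names in the JSON above.

```python
import math


def linear(x, weight, bias):
    """y = x @ weight.T + bias for a batch of row vectors (nested lists)."""
    return [[sum(a * w for a, w in zip(row, w_row)) + b
             for w_row, b in zip(weight, bias)]
            for row in x]


def relu(x):
    return [[max(0.0, v) for v in row] for row in x]


# Executors for the two op types that appear in the SimpleMLP IR.
OP_TABLE = {"aten.linear.default": linear, "aten.relu.default": relu}

# Simplified node list: inputs and outputs are tensor names only.
NODES = [
    {"op_type": "aten.linear.default",
     "inputs": ["x", "p_fc1_weight", "p_fc1_bias"], "output": "linear"},
    {"op_type": "aten.relu.default",
     "inputs": ["linear"], "output": "relu"},
    {"op_type": "aten.linear.default",
     "inputs": ["relu", "p_fc2_weight", "p_fc2_bias"], "output": "linear_1"},
]


def run_ir(nodes, state_dict, x, output_name):
    """Execute nodes in order, binding weights from a state_dict-like mapping."""
    # Assumed naming rule: "fc1.weight" in the state_dict is "p_fc1_weight" in the IR.
    env = {"p_" + key.replace(".", "_"): value for key, value in state_dict.items()}
    env["x"] = x
    for node in nodes:
        args = [env[name] for name in node["inputs"]]
        env[node["output"]] = OP_TABLE[node["op_type"]](*args)
    return env[output_name]


def reference_forward(state_dict, x):
    """Stand-in for the original model's forward pass."""
    hidden = relu(linear(x, state_dict["fc1.weight"], state_dict["fc1.bias"]))
    return linear(hidden, state_dict["fc2.weight"], state_dict["fc2.bias"])


def outputs_match(a, b, rel_tol=1e-5, abs_tol=1e-8):
    """Element-wise comparison with tolerances; shapes must agree."""
    if len(a) != len(b) or any(len(ra) != len(rb) for ra, rb in zip(a, b)):
        return False
    return all(math.isclose(u, v, rel_tol=rel_tol, abs_tol=abs_tol)
               for ra, rb in zip(a, b) for u, v in zip(ra, rb))


# A tiny 2 -> 2 -> 1 network with hand-picked weights.
toy_state_dict = {
    "fc1.weight": [[1.0, -1.0], [0.5, 0.5]],
    "fc1.bias": [0.0, -1.0],
    "fc2.weight": [[2.0, 1.0]],
    "fc2.bias": [0.5],
}
toy_input = [[1.0, 3.0]]

ir_result = run_ir(NODES, toy_state_dict, toy_input, "linear_1")
expected = reference_forward(toy_state_dict, toy_input)
toy_is_valid = outputs_match(ir_result, expected)
print(ir_result, "PASSED" if toy_is_valid else "FAILED")  # [[1.5]] PASSED
```

With tensors, the same comparison would typically be done with a tolerance-based check such as `torch.allclose` rather than exact equality, since floating-point results can differ slightly between execution paths.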


Dependencies

  • Python >= 3.10
  • PyTorch >= 2.1

Running Tests

# Basic tests
uv run pytest tests/ -v

# Comprehensive tests (all test models)
uv run pytest tests/test_comprehensive.py -v

# Generate reports
uv run pytest tests/test_comprehensive.py --generate-reports --output reports/

# Filter by category
uv run pytest tests/test_comprehensive.py -k "attention" -v

# Run via CLI
uv run python -m tests --output reports/
uv run python -m tests --list-models
uv run python -m tests --category attention

Features

  • Weight-free extraction: Uses meta tensors to extract only graph structure without loading actual weights into memory
  • Built on torch.export: Captures the graph with TorchDynamo-based tracing, the export path PyTorch officially recommends
  • Complete metadata: Automatically extracts shape and dtype information for all tensors
  • IR execution & verification: Execute the extracted IR and verify results match the original model
  • Extensible design: Provides a custom operator registration mechanism
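
A custom operator registration mechanism is commonly built as a registry keyed by op type. The sketch below shows that general pattern only. The framework's actual registration API is not shown in this README, so `register_op`, `dispatch`, and the op name `mylib.scale.default` are all hypothetical.

```python
from typing import Callable

# Hypothetical registry: op type string -> function that executes the op.
_EXECUTORS: dict[str, Callable] = {}


def register_op(op_type: str):
    """Decorator that registers an executor for one op type (hypothetical API)."""
    def decorator(fn: Callable) -> Callable:
        if op_type in _EXECUTORS:
            raise ValueError(f"executor already registered for {op_type!r}")
        _EXECUTORS[op_type] = fn
        return fn
    return decorator


def dispatch(op_type: str, *args):
    """Look up and run the executor for an op type."""
    try:
        executor = _EXECUTORS[op_type]
    except KeyError:
        raise NotImplementedError(f"no executor registered for {op_type!r}") from None
    return executor(*args)


@register_op("mylib.scale.default")
def scale(values, factor):
    """Example custom op: multiply every element by a constant."""
    return [v * factor for v in values]


print(dispatch("mylib.scale.default", [1.0, 2.0], 3.0))  # [3.0, 6.0]
```

Raising on duplicate registration and on unknown op types keeps mistakes visible early, which matters when an IR contains ops that a backend has not yet implemented.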

License

MIT License
