
PyTorch IR extraction framework for compiler backends

Project description


IR Extraction Framework


A framework for extracting compiler-backend IR (Intermediate Representation) from PyTorch models.

Quick Start

Installation

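# From a clone of the repository, using uv (recommended)
uv sync

# Or using pip (editable install from the repository root)
pip install -e .

# Or install the released package from PyPI
pip install pytorch-ir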

Basic Usage

import torch
import torch.nn as nn
from torch_ir import extract_ir, ir_to_mermaid

class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# 1. Create model on meta device (no actual weights loaded)
with torch.device('meta'):
    model = SimpleMLP()
model.eval()

# 2. Extract IR
example_inputs = (torch.randn(1, 4, device='meta'),)
ir = extract_ir(model, example_inputs)

# 3. Save IR
ir.save("model_ir.json")

# 4. Visualize IR
print(ir_to_mermaid(ir))

Extracted IR

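The example above produces the following JSON. Each node records its ATen op type, its input/output tensor metadata, and its producer-consumer relationships (a node's inputs name the outputs of earlier nodes). Weight values are not included: weights are listed under weights by their state_dict names (e.g. fc1.weight) and appear as node inputs under their exported placeholder names (e.g. p_fc1_weight).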

{
  "model_name": "SimpleMLP",
  "graph_inputs":  [{"name": "x", "shape": [1, 4], "dtype": "float32"}],
  "graph_outputs": [{"name": "linear_1", "shape": [1, 2], "dtype": "float32"}],
  "weights": [
    {"name": "fc1.weight", "shape": [8, 4], "dtype": "float32"},
    {"name": "fc1.bias",   "shape": [8],    "dtype": "float32"},
    {"name": "fc2.weight", "shape": [2, 8], "dtype": "float32"},
    {"name": "fc2.bias",   "shape": [2],    "dtype": "float32"}
  ],
  "nodes": [
    {
      "name": "linear", "op_type": "aten.linear.default",
      "inputs":  [{"name": "x", "shape": [1, 4]}, {"name": "p_fc1_weight", "shape": [8, 4]}, {"name": "p_fc1_bias", "shape": [8]}],
      "outputs": [{"name": "linear", "shape": [1, 8]}]
    },
    {
      "name": "relu", "op_type": "aten.relu.default",
      "inputs":  [{"name": "linear", "shape": [1, 8]}],
      "outputs": [{"name": "relu", "shape": [1, 8]}]
    },
    {
      "name": "linear_1", "op_type": "aten.linear.default",
      "inputs":  [{"name": "relu", "shape": [1, 8]}, {"name": "p_fc2_weight", "shape": [2, 8]}, {"name": "p_fc2_bias", "shape": [2]}],
      "outputs": [{"name": "linear_1", "shape": [1, 2]}]
    }
  ]
}
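Because every node names its inputs and outputs, the producer-consumer graph can be rebuilt from the JSON alone. The following is a minimal stdlib-only sketch, assuming the schema shown above (an abridged copy is embedded, with the weights list omitted). It maps each tensor name to the node that produces it and checks that every consumer sees the shape its producer recorded. The helpers build_edges and check_shapes are hypothetical and not part of torch_ir.

```python
import json

# Abridged copy of the IR JSON above (assumed schema; the weights list is omitted).
IR_JSON = """
{
  "graph_inputs":  [{"name": "x", "shape": [1, 4], "dtype": "float32"}],
  "graph_outputs": [{"name": "linear_1", "shape": [1, 2], "dtype": "float32"}],
  "nodes": [
    {"name": "linear", "op_type": "aten.linear.default",
     "inputs":  [{"name": "x", "shape": [1, 4]}, {"name": "p_fc1_weight", "shape": [8, 4]}, {"name": "p_fc1_bias", "shape": [8]}],
     "outputs": [{"name": "linear", "shape": [1, 8]}]},
    {"name": "relu", "op_type": "aten.relu.default",
     "inputs":  [{"name": "linear", "shape": [1, 8]}],
     "outputs": [{"name": "relu", "shape": [1, 8]}]},
    {"name": "linear_1", "op_type": "aten.linear.default",
     "inputs":  [{"name": "relu", "shape": [1, 8]}, {"name": "p_fc2_weight", "shape": [2, 8]}, {"name": "p_fc2_bias", "shape": [2]}],
     "outputs": [{"name": "linear_1", "shape": [1, 2]}]}
  ]
}
"""


def build_edges(ir):
    """Return (producer, consumers): tensor name -> producing node, node -> consuming nodes."""
    producer = {}
    for node in ir["nodes"]:
        for out in node["outputs"]:
            producer[out["name"]] = node["name"]
    consumers = {node["name"]: [] for node in ir["nodes"]}
    for node in ir["nodes"]:
        for inp in node["inputs"]:
            src = producer.get(inp["name"])
            if src is not None:  # graph inputs and weights have no producing node
                consumers[src].append(node["name"])
    return producer, consumers


def check_shapes(ir):
    """List (node, tensor, producer_shape, consumer_shape) for every shape mismatch."""
    known = {t["name"]: t["shape"] for n in ir["nodes"] for t in n["outputs"]}
    known.update({t["name"]: t["shape"] for t in ir["graph_inputs"]})
    problems = []
    for node in ir["nodes"]:
        for inp in node["inputs"]:
            expected = known.get(inp["name"])
            if expected is not None and expected != inp["shape"]:
                problems.append((node["name"], inp["name"], expected, inp["shape"]))
    return problems


ir = json.loads(IR_JSON)
producer, consumers = build_edges(ir)
print(consumers)         # {'linear': ['relu'], 'relu': ['linear_1'], 'linear_1': []}
print(check_shapes(ir))  # []
```

Weights are skipped by the shape check because their state_dict names and placeholder names differ, and this sketch does not assume a mapping between them.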

IR Visualization

ir_to_mermaid() renders the IR as a Mermaid flowchart. For SimpleMLP it prints the chart below; activations flow along solid edges (-->) and weight inputs along dashed edges (-.->):

flowchart TD
    input_x[/"Input: x<br/>1x4"/]
    op_linear["linear<br/>1x8"]
    input_x -->|"1x4"| op_linear
    w_p_fc1_weight[/"p_fc1_weight<br/>8x4"/]
    w_p_fc1_weight -.->|"8x4"| op_linear
    w_p_fc1_bias[/"p_fc1_bias<br/>8"/]
    w_p_fc1_bias -.->|"8"| op_linear
    op_relu["relu<br/>1x8"]
    op_linear -->|"1x8"| op_relu
    op_linear_1["linear<br/>1x2"]
    op_relu -->|"1x8"| op_linear_1
    w_p_fc2_weight[/"p_fc2_weight<br/>2x8"/]
    w_p_fc2_weight -.->|"2x8"| op_linear_1
    w_p_fc2_bias[/"p_fc2_bias<br/>2"/]
    w_p_fc2_bias -.->|"2"| op_linear_1
    output_0[\"Output<br/>1x2"/]
    op_linear_1 --> output_0
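The chart above follows a simple recipe: one box per op, one parallelogram per graph input or weight, solid edges for activations, and dashed edges for weights. A stdlib-only sketch of that recipe is shown below, assuming the JSON schema from the previous section and treating any node input that is neither a graph input nor an earlier node's output as a weight. The to_mermaid and fmt helpers are hypothetical, not the library's ir_to_mermaid; for this IR they happen to reproduce the chart above line for line.

```python
# SimpleMLP IR as a Python dict (assumed schema; dtypes omitted).
IR = {
    "graph_inputs": [{"name": "x", "shape": [1, 4]}],
    "graph_outputs": [{"name": "linear_1", "shape": [1, 2]}],
    "nodes": [
        {"name": "linear", "op_type": "aten.linear.default",
         "inputs": [{"name": "x", "shape": [1, 4]}, {"name": "p_fc1_weight", "shape": [8, 4]},
                    {"name": "p_fc1_bias", "shape": [8]}],
         "outputs": [{"name": "linear", "shape": [1, 8]}]},
        {"name": "relu", "op_type": "aten.relu.default",
         "inputs": [{"name": "linear", "shape": [1, 8]}],
         "outputs": [{"name": "relu", "shape": [1, 8]}]},
        {"name": "linear_1", "op_type": "aten.linear.default",
         "inputs": [{"name": "relu", "shape": [1, 8]}, {"name": "p_fc2_weight", "shape": [2, 8]},
                    {"name": "p_fc2_bias", "shape": [2]}],
         "outputs": [{"name": "linear_1", "shape": [1, 2]}]},
    ],
}


def fmt(shape):
    """Format a shape such as [1, 4] as '1x4'."""
    return "x".join(str(d) for d in shape)


def to_mermaid(ir):
    """Render the IR dict as Mermaid flowchart text (hypothetical helper)."""
    lines = ["flowchart TD"]
    source = {}  # tensor name -> Mermaid node id that provides it
    for t in ir["graph_inputs"]:
        node_id = "input_" + t["name"]
        lines.append(f'    {node_id}[/"Input: {t["name"]}<br/>{fmt(t["shape"])}"/]')
        source[t["name"]] = node_id
    for node in ir["nodes"]:
        op_id = "op_" + node["name"]
        op_short = node["op_type"].split(".")[1]  # 'aten.relu.default' -> 'relu'
        lines.append(f'    {op_id}["{op_short}<br/>{fmt(node["outputs"][0]["shape"])}"]')
        for inp in node["inputs"]:
            label = fmt(inp["shape"])
            if inp["name"] in source:  # activation: solid edge
                lines.append(f'    {source[inp["name"]]} -->|"{label}"| {op_id}')
            else:  # not produced inside the graph: drawn as a weight with a dashed edge
                w_id = "w_" + inp["name"]
                lines.append(f'    {w_id}[/"{inp["name"]}<br/>{label}"/]')
                lines.append(f'    {w_id} -.->|"{label}"| {op_id}')
        for out in node["outputs"]:
            source[out["name"]] = op_id
    for i, t in enumerate(ir["graph_outputs"]):
        lines.append(f'    output_{i}[\\"Output<br/>{fmt(t["shape"])}"/]')
        lines.append(f'    {source[t["name"]]} --> output_{i}')
    return "\n".join(lines)


print(to_mermaid(IR))
```

A weight shared by several ops would be declared once per use in this sketch; Mermaid accepts repeated node declarations, so the chart still renders.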

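Verification

Because the IR carries no weight values, verification supplies them from a state_dict: verify_ir_with_state_dict() executes the IR with those weights and checks that the result matches the original model's output on the same test inputs.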

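import torch
from torch_ir import verify_ir_with_state_dict  # assumed import path

# Compare the original model's output with the IR execution result.
# Assumes 'weights.pt' holds a SimpleMLP state_dict saved earlier.
original_model = SimpleMLP()
original_model.load_state_dict(torch.load('weights.pt'))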
original_model.eval()

test_input = torch.randn(1, 4)
is_valid, report = verify_ir_with_state_dict(
    ir=ir,
    state_dict=original_model.state_dict(),
    original_model=original_model,
    test_inputs=(test_input,),
)

print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")
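Conceptually, executing the IR means walking the nodes in order, binding the graph input and the weights by name, and applying each op. The stdlib-only sketch below does this for SimpleMLP's two op types on nested Python lists and compares the result with an independently written reference. The run_ir function, the OPS table, and the rule mapping state_dict keys to placeholder names are hypothetical illustrations inferred from the example JSON, not torch_ir's executor or the internals of verify_ir_with_state_dict.

```python
import math
import random

# SimpleMLP's nodes in the JSON IR's shape (abridged: shapes and dtypes omitted).
NODES = [
    {"name": "linear", "op_type": "aten.linear.default",
     "inputs": [{"name": "x"}, {"name": "p_fc1_weight"}, {"name": "p_fc1_bias"}],
     "outputs": [{"name": "linear"}]},
    {"name": "relu", "op_type": "aten.relu.default",
     "inputs": [{"name": "linear"}], "outputs": [{"name": "relu"}]},
    {"name": "linear_1", "op_type": "aten.linear.default",
     "inputs": [{"name": "relu"}, {"name": "p_fc2_weight"}, {"name": "p_fc2_bias"}],
     "outputs": [{"name": "linear_1"}]},
]


def linear(x, weight, bias):
    """aten.linear semantics on nested lists: y = x @ weight^T + bias."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) + b
             for w_row, b in zip(weight, bias)]
            for row in x]


def relu(x):
    return [[max(v, 0.0) for v in row] for row in x]


# Hypothetical op table: op_type string -> Python implementation.
OPS = {"aten.linear.default": linear, "aten.relu.default": relu}


def run_ir(nodes, env):
    """Execute nodes in order, storing each output in env (tensor name -> value)."""
    for node in nodes:
        args = [env[t["name"]] for t in node["inputs"]]
        env[node["outputs"][0]["name"]] = OPS[node["op_type"]](*args)
    return env


def reference_mlp(x, sd):
    """Independent reference for fc2(relu(fc1(x))) with explicit loops (batch size 1)."""
    hidden = []
    for j, b in enumerate(sd["fc1.bias"]):
        acc = b
        for k, v in enumerate(x[0]):
            acc += v * sd["fc1.weight"][j][k]
        hidden.append(max(acc, 0.0))
    out = []
    for j, b in enumerate(sd["fc2.bias"]):
        acc = b
        for k, v in enumerate(hidden):
            acc += v * sd["fc2.weight"][j][k]
        out.append(acc)
    return [out]


rng = random.Random(0)


def rand(*shape):
    """Nested list of uniform random floats with the given shape."""
    if len(shape) == 1:
        return [rng.uniform(-1.0, 1.0) for _ in range(shape[0])]
    return [rand(*shape[1:]) for _ in range(shape[0])]


state_dict = {"fc1.weight": rand(8, 4), "fc1.bias": rand(8),
              "fc2.weight": rand(2, 8), "fc2.bias": rand(2)}
x = rand(1, 4)

# Bind the graph input and weights by name. The rule 'fc1.weight' -> 'p_fc1_weight'
# is an assumption read off the example JSON, not a documented torch_ir guarantee.
env = {"x": x}
env.update({"p_" + k.replace(".", "_"): v for k, v in state_dict.items()})

ir_out = run_ir(NODES, env)["linear_1"]
ref_out = reference_mlp(x, state_dict)
# Summation order differs between the two paths, so compare with a tolerance.
is_valid = len(ir_out[0]) == len(ref_out[0]) and all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
    for a, b in zip(ir_out[0], ref_out[0]))
print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")
```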


Dependencies

  • Python >= 3.10
  • PyTorch >= 2.1

Running Tests

# Basic tests
uv run pytest tests/ -v

# Comprehensive tests (all test models)
uv run pytest tests/test_comprehensive.py -v

# Generate reports
uv run pytest tests/test_comprehensive.py --generate-reports --output reports/

# Filter by category
uv run pytest tests/test_comprehensive.py -k "attention" -v

# Run via CLI
uv run python -m tests --output reports/
uv run python -m tests --list-models
uv run python -m tests --category attention

Features

  • Weight-free extraction: Uses meta tensors to extract only graph structure without loading actual weights into memory
  • torch.export based: Captures the graph with torch.export (TorchDynamo-based tracing), the approach PyTorch recommends for whole-graph export
  • Complete metadata: Automatically extracts shape and dtype information for all tensors
  • IR execution & verification: Executes the extracted IR and verifies that its results match the original model
  • Extensible design: Provides a custom operator registration mechanism
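This README does not show the custom operator registration API itself. As an illustration of the general pattern only, here is a hypothetical stdlib-only registry in which handlers are registered under an op_type string with a decorator and looked up at execution time. The names register_op and get_op and the op type mylib.scale.default are invented for this example and are not torch_ir's API.

```python
from typing import Callable, Dict

# Hypothetical registry: op_type string -> handler.
_REGISTRY: Dict[str, Callable] = {}


def register_op(op_type: str):
    """Decorator that registers a handler for op_type; duplicate names are rejected."""
    def decorator(fn: Callable) -> Callable:
        if op_type in _REGISTRY:
            raise ValueError(f"op already registered: {op_type}")
        _REGISTRY[op_type] = fn
        return fn
    return decorator


def get_op(op_type: str) -> Callable:
    """Look up a handler, failing loudly for ops that nobody registered."""
    try:
        return _REGISTRY[op_type]
    except KeyError:
        raise NotImplementedError(f"no handler registered for {op_type}") from None


@register_op("aten.relu.default")
def relu(xs):
    return [max(v, 0.0) for v in xs]


# A user-defined operator, registered the same way as the built-in ones.
@register_op("mylib.scale.default")
def scale(xs, factor):
    return [v * factor for v in xs]


print(get_op("mylib.scale.default")(get_op("aten.relu.default")([-1.0, 2.0]), 3.0))  # [0.0, 6.0]
```

Rejecting duplicate registrations keeps a user extension from silently replacing a built-in handler.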

License

MIT License

Download files


Source Distribution

pytorch_ir-0.2.3.tar.gz (170.2 kB)

Built Distribution


pytorch_ir-0.2.3-py3-none-any.whl (31.0 kB)

File details

Details for the file pytorch_ir-0.2.3.tar.gz.

File metadata

  • Download URL: pytorch_ir-0.2.3.tar.gz
  • Size: 170.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pytorch_ir-0.2.3.tar.gz
Algorithm Hash digest
SHA256 be6409d19c5d9cc3b6e564b29106f1c282b01497555307c89f8a2ce42590f435
MD5 c98e421a7637ee8a57121d9d37aa02d3
BLAKE2b-256 55548624f819c785216d3040e104d381a6cc9401ffda95df1a5d41a66fa16d3b


Provenance

The following attestation bundles were made for pytorch_ir-0.2.3.tar.gz:

Publisher: publish.yml on sweetcocoa/pytorch-ir


File details

Details for the file pytorch_ir-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: pytorch_ir-0.2.3-py3-none-any.whl
  • Size: 31.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for pytorch_ir-0.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 a136b8f84afe1b26c7a02021d6e1b9796a6ce054136443004c675c3dc769dbbe
MD5 44184d1cc2bd1c071c819cd5ef922bc5
BLAKE2b-256 5830eb93ef5d2d2a2b1b9dba91c20425453626e07ea7c59cdd0207a7550f0871


Provenance

The following attestation bundles were made for pytorch_ir-0.2.3-py3-none-any.whl:

Publisher: publish.yml on sweetcocoa/pytorch-ir

