
Multi-Level Triton Runner supporting Python, IR, PTX, and cubin.


Multi-Level Triton Runner

Runner Docs | Dump Docs | Benchmark Docs | Chinese Docs | triton-runner.org


Triton Runner is a lightweight execution and debugging layer for Triton. It lets you launch kernels from multiple compilation stages, inspect intermediate IR, and reuse compiled artifacts directly during performance tuning.

Compatibility summary:

  • Supported Triton versions: v3.0.0 through v3.6.0
  • Primary target: v3.5.x
  • Supported runner inputs: Python Triton, Gluon, TTIR, TTGIR, LLIR, PTX, cubin, AMDGCN, and hsaco
  • Dump support: Python, TTIR, and TTGIR
  • Optional CUDA bridge: TVM-FFI on Triton v3.3+
  • MLIR split output: set MLIR_ENABLE_DUMP=1 to expand all.mlir into per-pass files in the cache directory
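Before installing, you may want to check that your Triton version falls inside the supported range. The helpers below are a hypothetical sketch, not part of Triton Runner's API. They read the bounds from the compatibility summary, treat v3.6.0 as an inclusive upper bound, and ignore pre-release or local suffixes such as `+git1a2b`.

```python
import re
from typing import Tuple

# Bounds from the compatibility summary (both inclusive).
MIN_SUPPORTED = (3, 0, 0)
MAX_SUPPORTED = (3, 6, 0)
# The optional TVM-FFI bridge needs Triton v3.3+.
TVM_FFI_MIN = (3, 3, 0)


def parse_version(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like '3.5.1' or 'v3.3.0+git1a2b'."""
    match = re.match(r"v?(\d+)\.(\d+)(?:\.(\d+))?", version.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def is_supported_triton(version: str) -> bool:
    """True if the version lies in the documented v3.0.0 to v3.6.0 range."""
    return MIN_SUPPORTED <= parse_version(version) <= MAX_SUPPORTED


def tvm_ffi_possible(version: str) -> bool:
    """True if the version meets the v3.3+ requirement of the TVM-FFI bridge."""
    return parse_version(version) >= TVM_FFI_MIN
```

With Triton installed, call `is_supported_triton(triton.__version__)`.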

✨ Features

📦 Installation

Install from PyPI

pip install triton-runner

Install from source

git clone https://github.com/toyaix/triton-runner
cd triton-runner
pip install -e .

Optional: TVM-FFI

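Triton Runner also provides a CUDA/cubin-only bridge to TVM-FFI for Triton v3.3+. Install the extra, then enable the bridge with an environment variable:

# From PyPI (quote the extra so shells such as zsh do not expand the brackets)
pip install "triton-runner[tvm-ffi]"

# From a source checkout
pip install -e ".[tvm-ffi]"

# Enable the bridge at run time
export TRITON_RUNNER_ENABLE_TVM_FFI=1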

🚀 Quick Start

I. Multi-Level Runner

Triton Runner can launch kernels from multiple points in the Triton compilation pipeline.

---
title: Triton Compilation Pipeline
---
flowchart LR

    subgraph Triton
        A["Python<br>Triton"]:::supported --> B["TTIR<br>Triton IR"]:::supported
        B --> C["TTGIR<br>Triton GPU IR"]:::supported
        C --> D["LLIR<br>LLVM IR"]:::supported

        Gluon["Python<br>Gluon"]:::supported --> C
        TLX["Python<br>TLX"]:::supported --> B
    end

    subgraph Backend
        D --> E["PTX"]:::supported
        D --> G["GCN"]:::supported
        E --> F["cubin<br>CUDA Binary"]:::supported
        G --> H["hsaco<br>HIP Binary"]:::supported
    end

    classDef supported fill:#AED6F1,stroke:#2E86C1,stroke-width:2px,color:#000000;
    classDef unsupported fill:#F5B7B1,stroke:#C0392B,stroke-width:2px,color:#000000;
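The runner inputs in this diagram line up with the file extensions Triton writes into its kernel cache (`.ttir`, `.ttgir`, `.llir`, `.ptx`, `.cubin`, `.amdgcn`, `.hsaco`). The sketch below is hypothetical and not part of Triton Runner's API. It assumes those extensions, maps a file to its pipeline stage, and lists the stages still ahead of it on the CUDA or HIP path. Roughly, the later the starting stage, the less of the compiler runs before launch.

```python
from pathlib import Path

# Stage order per backend, following the flowchart above.
PIPELINES = {
    "cuda": ["ttir", "ttgir", "llir", "ptx", "cubin"],
    "hip": ["ttir", "ttgir", "llir", "amdgcn", "hsaco"],
}
KNOWN_STAGES = {stage for stages in PIPELINES.values() for stage in stages}


def classify_artifact(path: str) -> str:
    """Return the pipeline stage of a compiled artifact, judged by its file extension."""
    ext = Path(path).suffix.lstrip(".").lower()
    if ext not in KNOWN_STAGES:
        raise ValueError(f"not a recognized Triton artifact: {path}")
    return ext


def remaining_stages(stage: str, backend: str = "cuda") -> list:
    """Stages still to run after `stage` before a GPU binary exists on `backend`."""
    pipeline = PIPELINES[backend]
    if stage not in pipeline:
        raise ValueError(f"stage {stage!r} is not on the {backend} pipeline")
    return pipeline[pipeline.index(stage) + 1:]
```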

TLX support targets TLX commit 9a7a23d; see examples/runner/tlx/README.md for details.

1. Python Runner

Triton Runner supports two integration styles for Python kernels; either one works:

  1. Replace @triton.jit with @triton_runner.jit
  2. Monkey-patch Triton's decorators and keep using @triton.jit

If the module also uses @triton.autotune, call triton_runner.configure_autotune_backend() when using the monkey-patch style.

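Style 1: replace the decorator.

import triton.language as tl
import triton_runner

@triton_runner.jit
def kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Ordinary Triton code: each program adds one BLOCK_SIZE-wide slice.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

Style 2: monkey-patch Triton's decorators.

import triton
import triton.language as tl
import triton_runner

# Patch Triton before defining kernels so @triton.jit uses the runner backend.
triton_runner.configure_jit_backend()
# Optional when using @triton.autotune
# triton_runner.configure_autotune_backend()

@triton.jit
def kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    ...  # same body as in style 1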

Example:

python examples/runner/v3.5.x/python/matmul.py

On success, Triton Runner prints the kernel launch banner. When the kernel cache is reused, it also prints the cache location.

2. TTIR Runner

Provide the .ttir file and point the runner at its directory, typically with ttir_dir=triton_runner.get_file_dir(__file__). See examples/runner/v3.5.x/ttir/matmul/matmul.py.

You can also reuse the Triton cache generated by the Python runner.

python examples/runner/v3.5.x/ttir/matmul/matmul.py
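To reuse the cache, you first have to find the `.ttir` file that the Python runner produced. The sketch below assumes Triton's usual layout, `<cache root>/<hash>/<kernel_name>.<ext>`, where the root is `$TRITON_CACHE_DIR` or `~/.triton/cache` by default. The helper names are hypothetical.

```python
import os
from pathlib import Path


def triton_cache_dir() -> Path:
    """Triton's cache root: $TRITON_CACHE_DIR if set, else ~/.triton/cache."""
    override = os.environ.get("TRITON_CACHE_DIR")
    return Path(override) if override else Path.home() / ".triton" / "cache"


def find_cached_artifacts(kernel_name: str, ext: str, cache_dir=None) -> list:
    """List <cache>/<hash>/<kernel_name>.<ext> files, newest first."""
    root = Path(cache_dir) if cache_dir is not None else triton_cache_dir()
    if not root.is_dir():
        return []
    matches = root.glob(f"*/{kernel_name}.{ext.lstrip('.')}")
    # Newest first: the most recent compile is usually the one you want to reuse.
    return sorted(matches, key=lambda p: p.stat().st_mtime, reverse=True)
```

For example, `find_cached_artifacts("matmul_kernel", "ttir")` (the kernel name is a placeholder) returns candidate files. Copy the newest one next to your script and pass that directory as `ttir_dir`.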

3. TTGIR Runner

TTGIR is architecture-aware. Provide the matching .ttgir file and, when needed, the corresponding metadata JSON. See examples/runner/v3.5.x/ttgir/sm90/matmul-with-tma-v4.py.

If you hit torch.AcceleratorError: CUDA error: an illegal instruction was encountered, the selected TTGIR artifact likely does not match the target GPU, or the metadata JSON is missing.
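You can catch a mismatch before launch by comparing the target recorded in the `.ttgir` module attributes with the GPU's compute capability. This sketch assumes the attribute is printed as `ttg.target = "cuda:90"` (or as `triton_gpu.target` in older Triton releases). The helpers are hypothetical and only cover CUDA targets.

```python
import re

# Matches `ttg.target = "cuda:90"` or `"triton_gpu.target" = "cuda:80"`.
TARGET_ATTR = re.compile(r'"?(?:ttg|triton_gpu)\.target"?\s*=\s*"([^"]+)"')


def ttgir_target(ttgir_text: str):
    """Return the target string from the TTGIR module attributes, or None."""
    match = TARGET_ATTR.search(ttgir_text)
    return match.group(1) if match else None


def cuda_target_for(capability) -> str:
    """Build the target string for a CUDA compute capability such as (9, 0)."""
    major, minor = capability
    return f"cuda:{major}{minor}"


def check_ttgir_matches_gpu(ttgir_text: str, capability) -> None:
    """Fail early with a readable message when the TTGIR was built for another GPU."""
    found = ttgir_target(ttgir_text)
    expected = cuda_target_for(capability)
    if found != expected:
        raise RuntimeError(f"TTGIR was compiled for {found!r}, but this GPU is {expected!r}")
```

With PyTorch available, call `check_ttgir_matches_gpu(text, torch.cuda.get_device_capability())`.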

4. LLIR / PTX / cubin Runner

For LLIR, PTX, and cubin launches, provide the input file plus the matching metadata JSON.
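A missing metadata JSON is easy to overlook, so it can help to load the artifact and its metadata together. This minimal sketch assumes the convention used in Triton's cache: the JSON sits next to the artifact and shares its base name (for example `matmul_kernel.cubin` and `matmul_kernel.json`). The helper is hypothetical.

```python
import json
from pathlib import Path

# Inputs that need a metadata JSON, per the paragraph above.
NEEDS_METADATA = {".llir", ".ptx", ".cubin"}


def load_artifact_with_metadata(artifact_path: str):
    """Return (artifact bytes, metadata dict) for an LLIR, PTX, or cubin file."""
    artifact = Path(artifact_path)
    if artifact.suffix not in NEEDS_METADATA:
        raise ValueError(f"expected one of {sorted(NEEDS_METADATA)}, got {artifact.suffix!r}")
    # Assumed naming: same stem, .json suffix, same directory.
    metadata_path = artifact.with_suffix(".json")
    if not metadata_path.is_file():
        raise FileNotFoundError(f"missing metadata JSON for {artifact.name}: {metadata_path}")
    metadata = json.loads(metadata_path.read_text())
    return artifact.read_bytes(), metadata
```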

5. Gluon Runner

Gluon uses the same compiler stack as Triton but exposes a lower-level programming model. The repository currently includes two Gluon examples:

python examples/runner/v3.5.x/gluon/01-intro.py
python examples/runner/v3.5.x/gluon/02-layouts.py

6. Architecture Examples

Architecture-specific examples are collected in examples/runner/README.md.

Representative commands:

python examples/runner/v3.5.x/python/matmul-with-tma-v4.py
python examples/runner/v3.5.x/ttgir/sm90/matmul-with-tma-v4.py
python examples/runner/v3.5.x/cubin/sm90/matmul-with-tma-v4.py
python examples/runner/amd/v3.6.0/hsaco/matmul.py

If your GPU does not match one of the bundled examples, set TRITON_CACHE_DIR=$PWD/.cache, compile once on the target machine, and then reuse the generated kernel cache.
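From inside a Python script, the equivalent of `TRITON_CACHE_DIR=$PWD/.cache` is to set the variable in `os.environ`. When Triton reads it is version-dependent, so the safest assumption is to do this before importing `triton`. The helper below is a hypothetical sketch.

```python
import os
from pathlib import Path


def use_project_cache(project_dir=".") -> Path:
    """Point Triton's kernel cache at <project_dir>/.cache and return that path."""
    cache = Path(project_dir).resolve() / ".cache"
    cache.mkdir(parents=True, exist_ok=True)
    # Triton reads TRITON_CACHE_DIR to choose where compiled kernels are stored.
    os.environ["TRITON_CACHE_DIR"] = str(cache)
    return cache
```

After one compile on the target machine, the generated `.ttgir`, `.cubin`, and metadata files sit under that directory, one subdirectory per kernel hash.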

7. Versioned Examples

Use the example set that matches your Triton version:

II. Multi-Level Dump

Triton Runner supports dump workflows at the Python, TTIR, and TTGIR levels.

---
title: Triton Dump Coverage
---
flowchart LR

    subgraph Triton
        A["Python<br>Triton"]:::supported --> B["TTIR<br>Triton IR"]:::supported
        B --> C["TTGIR<br>Triton GPU IR"]:::supported
        C --> D["LLIR<br>LLVM IR"]:::unsupported

        Gluon["Python<br>Gluon"]:::unsupported --> C
    end

    subgraph Backend
        D --> E["PTX"]:::unsupported
        E --> F["cubin<br>CUDA Binary"]:::unsupported
    end

    classDef supported fill:#AED6F1,stroke:#2E86C1,stroke-width:2px,color:#000000;
    classDef unsupported fill:#F5B7B1,stroke:#C0392B,stroke-width:2px,color:#000000;

The full dump guide lives in examples/dump/README.md.

1. Python Dump

Inside a Triton kernel, use triton_runner.language.dump() to inspect a block. You can also use triton_runner.language.dump_boundary() for boundary blocks and triton_runner.language.dump_grids() for grid inspection.

Representative examples:

python examples/dump/python/01-vec_add/dump_output.py
python examples/dump/python/03-matrix_multiplication/dump_acc.py
python examples/dump/python/04-softmax/dump_max_in_loop.py
python examples/dump/python/06-attention/dump_out.py
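To read a dumped block, you need to know which elements a given program instance covered. For a 1-D kernel that uses the common `pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)` pattern, this is plain integer arithmetic. The sketch below reproduces it on the host. It illustrates standard Triton blocking, not the dump API, and the helper names are hypothetical.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used to size a Triton launch grid."""
    return (a + b - 1) // b


def block_span(pid: int, block_size: int, n_elements: int):
    """Half-open [start, stop) range of valid elements handled by program `pid`."""
    start = pid * block_size
    stop = max(start, min(start + block_size, n_elements))
    return start, stop


def masked_lanes(pid: int, block_size: int, n_elements: int) -> int:
    """Lanes of program `pid` that the `offsets < n_elements` mask switches off."""
    start, stop = block_span(pid, block_size, n_elements)
    return block_size - (stop - start)
```

With 1000 elements and `BLOCK_SIZE=256`, the grid has 4 programs. The last one covers elements 768 to 999, and 24 of its lanes are masked.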

2. TTIR Dump

TTIR dump examples cover common ops such as tt.load, arith.addf, and tt.trans.

python examples/dump/ttir/01-vector_add/dump_addf.py
python examples/dump/ttir/03-matrix_multiplication/dump_acc.py
python examples/dump/ttir/04-softmax/dump_maxnumf.py
python examples/dump/ttir/06-attention/dump_out.py
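To choose which value to dump, you first need the SSA result that an op such as `arith.addf` produces in the `.ttir` text. The minimal sketch below scans the textual IR line by line with a hypothetical helper. It assumes one op per line, which matches Triton's printed TTIR in common cases, but it is not a full MLIR parser.

```python
import re

# `%result = dialect.op ...`; the generic form may quote the op name.
OP_LINE = re.compile(r'^\s*(%[\w#:]+)\s*=\s*"?([A-Za-z_]+\.[A-Za-z_.]+)')


def find_op_results(ir_text: str, op_name: str) -> list:
    """Return (line_number, ssa_result) for every op named `op_name`."""
    hits = []
    for lineno, line in enumerate(ir_text.splitlines(), start=1):
        match = OP_LINE.match(line)
        if match and match.group(2) == op_name:
            hits.append((lineno, match.group(1)))
    return hits
```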

3. TTGIR Dump

TTGIR dump examples cover the same class of operations at the GPU IR level.

python examples/dump/ttgir/01-vec_add/dump_addf.py
python examples/dump/ttgir/03-matrix_multiplication/dump_acc.py
python examples/dump/ttgir/04-softmax/dump_maxnumf.py
python examples/dump/ttgir/06-attention/dump_out.py
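TTGIR attaches a layout encoding to each tensor type, so a TTGIR dump is easier to interpret once you know what each layout alias means. The hypothetical helper below collects the alias definitions (such as `#blocked` or `#mma`) printed at the top of a `.ttgir` file. The encoding names are version-dependent (`#ttg.blocked` in recent releases, `#triton_gpu.blocked` in older ones), so treat the pattern as an assumption.

```python
import re

# e.g. #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], ...}>
ALIAS_LINE = re.compile(r"^(#\w+)\s*=\s*(#[\w.]+)<(.*)>\s*$")


def layout_aliases(ttgir_text: str) -> dict:
    """Map alias name (e.g. '#blocked') to (encoding kind, parameter text)."""
    aliases = {}
    for line in ttgir_text.splitlines():
        match = ALIAS_LINE.match(line.strip())
        if match:
            name, kind, params = match.groups()
            aliases[name] = (kind, params)
    return aliases
```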

III. Benchmarks

Benchmarks live under benchmark/ and are described in benchmark/README.md. The repository currently includes:

  • launch_latency: kernel launch overhead
  • matmul: matrix multiplication performance
  • flash_attention: attention benchmark cases

Example commands:

python benchmark/launch_latency/bench.py
python benchmark/matmul/mma/bench.py
python benchmark/attn/flash_attention/bench.py

benchmark/launch_latency/bench.py requires Triton v3.3.0+.
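A launch-latency measurement can be approximated on the host by timing repeated calls after a warmup. Because GPU launches are asynchronous, this mostly captures launch overhead rather than kernel run time. The function below is a generic, hypothetical sketch, not the logic of benchmark/launch_latency/bench.py.

```python
import statistics
import time


def median_call_time_us(fn, *args, warmup: int = 10, iters: int = 100) -> float:
    """Median host-side wall-clock time per call of fn(*args), in microseconds."""
    # Warmup absorbs one-off costs such as the first compile.
    for _ in range(warmup):
        fn(*args)
    samples = []
    for _ in range(iters):
        start = time.perf_counter_ns()
        fn(*args)
        samples.append((time.perf_counter_ns() - start) / 1_000)
    return statistics.median(samples)
```

Here `fn` would be any zero-argument callable that launches your kernel. Using the median keeps occasional scheduler hiccups from skewing the result.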

IV. Solving Triton Issues

The case studies in doc/solving_triton_issues/README.md show how to reproduce and work around Triton regressions with Triton Runner, especially by reusing cubin artifacts.

Current documented cases include:

⚙️ Environment Variables

| Variable | Default | Description |
|---|---|---|
| TRITON_RUNNER_ENABLE_TVM_FFI | 0 | Enable the TVM-FFI CUDA bridge (requires triton-runner[tvm-ffi] and Triton v3.3+) |
| TRITON_RUNNER_QUIET | 0 | Suppress verbose kernel cache path output |
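Both variables default to 0, so setting either one to 1 turns the behavior on. A script might read the same switches, for example to decide whether to print its own diagnostics. The helper below is a hypothetical sketch. Accepting spellings other than `1` (such as `true` or `on`) is this sketch's own choice, not documented Triton Runner behavior.

```python
import os

# "1" is the documented enabling value; the other spellings are an assumption.
TRUTHY = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean switch such as TRITON_RUNNER_QUIET from the environment."""
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in TRUTHY
```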

📄 License

This project is licensed under the MIT License. See LICENSE for details.

This project includes code from:

Download files

Source Distributions

No source distribution files are available for this release.

Built Distribution


triton_runner-0.3.6-py3-none-any.whl (79.7 kB, Python 3)

File details

Details for the file triton_runner-0.3.6-py3-none-any.whl.

File metadata

  • Download URL: triton_runner-0.3.6-py3-none-any.whl
  • Upload date:
  • Size: 79.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for triton_runner-0.3.6-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 255715a1e59865b75955cd502ec5b203040e6de396044ad07ba37ca7e7f31464 |
| MD5 | 09d3ea7585e05b6923c0a8fdab3c0eeb |
| BLAKE2b-256 | f65140e7c5eb97188d52e8b69040161a5fb1bac0220cb6dfebd6d3dda1a3894e |

