
Multi-Level Triton Runner supporting Python, IR, PTX, and cubin.


Multi-Level Triton Runner (Debugging) 🔧

Documentation 🔗 triton-runner.org


Triton Runner is a lightweight, multi-level execution engine for OpenAI/Triton, designed to support IR/PTX/cubin launches in complex pass pipelines.

Triton Runner is compatible with Triton v3.5.0, v3.4.0 (primary), v3.3.x, v3.2.0, v3.1.0 or v3.0.0.

Triton Runner supports multi-level debugging across Python/TTIR/TTGIR on Triton v3.4.0.

✨ Features

📦 Installation

Quick Installation

You can install the latest stable release of Triton Runner from pip.

pip install triton-runner

Install from source

You can install from source to access the latest features and developments.

git clone https://github.com/toyaix/triton-runner
cd triton-runner

pip install -e .

🚀 Quick Start

See the examples provided in the repository (documented at triton-runner.org) for your first run.

I. Multi-Level Runner

All of Triton’s compilation levels are supported by Triton Runner.

---
title: Triton Compilation Pipeline
---
flowchart LR

    subgraph Triton
        A["Python<br>Triton"]:::supported --> B["TTIR<br>Triton IR"]:::supported
        B --> C["TTGIR<br>Triton GPU IR"]:::supported
        C --> D["LLIR<br>LLVM IR"]:::supported

        Gluon["Python<br>Gluon"]:::supported --> C
    end

    subgraph Backend
        D --> E["PTX"]:::supported
        E --> F["cubin<br>CUDA Binary"]:::supported
    end

    classDef supported fill:#AED6F1,stroke:#2E86C1,stroke-width:2px,color:#000000;
    classDef unsupported fill:#F5B7B1,stroke:#C0392B,stroke-width:2px,color:#000000;

Support for TLX (Minimally Invasive Paths to Performance Portability) from Meta/Triton will be added once it is uploaded to PyPI.

1. Python Runner

You can run your Triton code using @triton_runner.jit instead of @triton.jit. See an example in examples/runner/v3.4.0/python/matmul.py

You can run the example with python examples/runner/v3.4.0/python/matmul.py. After running successfully, you should see output like [Triton Runner] Triton kernel.

If the kernel cache is hit, the following message will be displayed: [Triton Runner] Triton kernel cache hit and saved at. This indicates that the kernel was compiled and cached during a previous run.
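To see what the Python runner left in the cache, the sketch below lists the per-level files for one kernel. It assumes the Triton 3.x cache layout, where each compilation gets a hash-named subdirectory of TRITON_CACHE_DIR (default ~/.triton/cache) containing files such as <kernel_name>.ttir, .ttgir, .llir, .ptx, .cubin and .json. The helper names are hypothetical and not part of Triton Runner.

```python
import os
from pathlib import Path
from typing import Dict, List, Optional

# Per-level files Triton usually writes for each compiled kernel
# (assumption based on the Triton 3.x cache layout).
LEVEL_SUFFIXES = (".ttir", ".ttgir", ".llir", ".ptx", ".cubin", ".json")


def default_cache_dir() -> Path:
    """TRITON_CACHE_DIR if set, otherwise Triton's usual default ~/.triton/cache."""
    env = os.environ.get("TRITON_CACHE_DIR")
    return Path(env) if env else Path.home() / ".triton" / "cache"


def find_kernel_artifacts(kernel_name: str, cache_dir: Optional[Path] = None) -> Dict[str, List[Path]]:
    """Map each level suffix to the cached files named <kernel_name><suffix>.

    Every compilation lives in its own hash-named subdirectory, so a kernel
    compiled with different constexprs or launch options can appear several times.
    """
    root = Path(cache_dir) if cache_dir is not None else default_cache_dir()
    found: Dict[str, List[Path]] = {suffix: [] for suffix in LEVEL_SUFFIXES}
    if not root.is_dir():
        return found  # nothing compiled yet
    for suffix in LEVEL_SUFFIXES:
        found[suffix] = sorted(root.glob(f"*/{kernel_name}{suffix}"))
    return found
```

For example, `find_kernel_artifacts("matmul_kernel")` after running the matmul example; "matmul_kernel" is a placeholder, so use the actual name of the @triton_runner.jit function.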

2. TTIR Runner

In addition to using @triton_runner.jit instead of @triton.jit, you also need to provide the TTIR file. You can place it in the same directory as the current Python file and use ttir_dir=triton_runner.get_file_dir(__file__). See an example in examples/runner/v3.4.0/ttir/matmul/matmul.py. Alternatively, you can use the Triton cache directory generated by the Python runner (previous step).

You can run the example with python examples/runner/v3.4.0/ttir/matmul/matmul.py.
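One way to prepare the TTIR file is to copy it out of the cache produced by the Python runner into the script's own directory, which is where ttir_dir=triton_runner.get_file_dir(__file__) points (assuming get_file_dir returns the directory containing the given file). The helper below is a hypothetical sketch, not part of Triton Runner; it refuses ambiguous matches because picking the wrong cached variant would silently run a different kernel.

```python
import shutil
from pathlib import Path


def copy_ir_next_to_script(cache_dir: str, kernel_name: str, suffix: str, script_file: str) -> Path:
    """Copy <kernel_name><suffix> from a Triton cache into the script's directory."""
    matches = sorted(Path(cache_dir).glob(f"*/{kernel_name}{suffix}"))
    if not matches:
        raise FileNotFoundError(f"no {kernel_name}{suffix} under {cache_dir}")
    if len(matches) > 1:
        # Several compilations of the same kernel (e.g. different constexprs).
        raise RuntimeError(f"{len(matches)} cached variants: {[m.parent.name for m in matches]}")
    dest = Path(script_file).resolve().parent / matches[0].name
    shutil.copyfile(matches[0], dest)
    return dest
```

A call such as `copy_ir_next_to_script(cache, "matmul_kernel", ".ttir", __file__)` would place matmul_kernel.ttir beside the script; the same helper works for ".ttgir".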

3. TTGIR Runner

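TTGIR (Triton GPU IR) is architecture-aware and upwardly compatible: a .ttgir file generated for an older architecture can generally also run on newer GPUs. In the .ttgir file, you might see a target annotation like ttg.target = "cuda:90", which specifies the GPU backend and compute capability.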

Similar to the TTIR Runner, you need to provide a .ttgir file and specify its location in the program. See an example in examples/runner/v3.4.0/ttgir/sm90/matmul-with-tma-v4.py.

Because TTGIR is upwardly compatible, the sm75 example can also be run with the TTGIR Runner on newer GPUs: python examples/runner/v3.4.0/ttgir/sm75/matmul.py.
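Before launching a .ttgir file, it can help to read its target annotation and compare it with the local GPU. The sketch below parses ttg.target and applies a simple "device arch >= target arch" rule; that rule is our assumption for illustration, not Triton's own check, and TTGIR using architecture-specific ops (for example Hopper TMA) may still fail on other GPUs.

```python
import re
from typing import Tuple

# Matches both `ttg.target = "cuda:90"` and `"ttg.target" = "cuda:90"`.
_TARGET_RE = re.compile(r'"?ttg\.target"?\s*=\s*"(?P<backend>[a-z]+):(?P<arch>\d+)"')


def parse_ttgir_target(ttgir_text: str) -> Tuple[str, int]:
    """Return (backend, arch) from the module's ttg.target attribute."""
    m = _TARGET_RE.search(ttgir_text)
    if m is None:
        raise ValueError("no ttg.target annotation found")
    return m.group("backend"), int(m.group("arch"))


def ttgir_may_run_on(ttgir_text: str, device_arch: int) -> bool:
    """Heuristic upward-compatibility check: same backend and device arch >= target arch."""
    backend, arch = parse_ttgir_target(ttgir_text)
    return backend == "cuda" and device_arch >= arch
```

With PyTorch available, the device arch can be computed as major * 10 + minor from torch.cuda.get_device_capability().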

4. LLIR/PTX/cubin Runner

In addition to using @triton_runner.jit instead of @triton.jit, you also need to provide the corresponding file. Like the TTGIR runner, you can place it in the same directory as the current Python file and use ttgir_dir=triton_runner.get_file_dir(__file__). Since LLIR, PTX, and cubin are all architecture-specific, be sure to use the matching metadata JSON file. See an example in examples/runner/v3.4.0/llir/sm90/matmul-with-tma-v4.py.

If your architecture is sm90 (Hopper), you can run the example using the LLIR runner with python examples/runner/v3.4.0/llir/sm90/matmul-with-tma-v4.py.
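Because LLIR, PTX, and cubin only run on the architecture they were built for, a quick sanity check on the metadata JSON can save a confusing launch failure. This sketch assumes the Triton 3.x metadata layout, where "target" is an object such as {"backend": "cuda", "arch": 90, "warp_size": 32}; the function name is hypothetical.

```python
import json


def check_metadata_matches_device(metadata_json: str, device_arch: int) -> dict:
    """Load a kernel's metadata JSON and require an exact CUDA arch match."""
    meta = json.loads(metadata_json)
    target = meta.get("target", {})
    arch = target.get("arch")
    if target.get("backend") != "cuda" or arch != device_arch:
        # Unlike TTGIR, these levels are not upwardly compatible.
        raise RuntimeError(
            f"metadata targets {target.get('backend')}:{arch}, device is cuda:{device_arch}"
        )
    return meta
```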

5. Gluon Runner

Gluon is a GPU programming language based on the same compiler stack as Triton. Unlike Triton, however, Gluon is a lower-level language that gives the user more control and responsibility when implementing kernels.

Currently, only two examples are supported. Triton v3.5.0 has just been released, so please wait for future updates.

python examples/runner/v3.4.0/gluon/01-intro.py
python examples/runner/v3.4.0/gluon/02-layouts.py

6. Hopper Examples

I provide examples for different architectures and Triton versions. Here are example commands for every level, targeting sm90 (H100, H200, H20, etc.) with Triton v3.4.0.

python examples/runner/v3.4.0/python/matmul-with-tma-v4.py
python examples/runner/v3.4.0/ttir/matmul-with-tma/matmul-with-tma-v4.py
python examples/runner/v3.4.0/ttgir/sm90/matmul-with-tma-v4.py
python examples/runner/v3.4.0/llir/sm90/matmul-with-tma-v4.py
python examples/runner/v3.4.0/ptx/sm90/matmul-with-tma-v4.py
python examples/runner/v3.4.0/cubin/sm90/matmul-with-tma-v4.py
python examples/runner/v3.4.0/gluon/01-intro.py
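To walk through all the levels in one go, a small driver can run these commands in order and stop at the first failure, so a broken level is easy to spot. This is a hypothetical convenience script, not shipped with Triton Runner; it assumes it is run from the repository root.

```python
import subprocess
import sys
from typing import List

HOPPER_EXAMPLES = [
    "examples/runner/v3.4.0/python/matmul-with-tma-v4.py",
    "examples/runner/v3.4.0/ttir/matmul-with-tma/matmul-with-tma-v4.py",
    "examples/runner/v3.4.0/ttgir/sm90/matmul-with-tma-v4.py",
    "examples/runner/v3.4.0/llir/sm90/matmul-with-tma-v4.py",
    "examples/runner/v3.4.0/ptx/sm90/matmul-with-tma-v4.py",
    "examples/runner/v3.4.0/cubin/sm90/matmul-with-tma-v4.py",
    "examples/runner/v3.4.0/gluon/01-intro.py",
]


def run_examples(paths: List[str], dry_run: bool = False, python: str = sys.executable) -> List[List[str]]:
    """Run each example in order; stop at the first non-zero exit code.

    Returns the commands that were attempted (all of them if every run succeeds).
    """
    attempted = []
    for path in paths:
        cmd = [python, path]
        attempted.append(cmd)
        if dry_run:
            continue  # only collect the commands
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print(f"FAILED ({result.returncode}): {' '.join(cmd)}", file=sys.stderr)
            break
    return attempted
```

Calling run_examples(HOPPER_EXAMPLES) runs the whole pipeline; pass dry_run=True to print nothing and just inspect the command list.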

7. More Architecture Examples

For architecture-specific example commands, please refer to the examples/runner directory:

  • sm90: Hopper (H100, H200, H20, etc.)
  • sm80: Ampere (A100, A30)
  • sm120: Blackwell (RTX PRO 6000, RTX 5090, etc.)
  • sm86: Ampere (A10, RTX 3090, etc.)
  • sm75: Turing (T4, RTX 2080, etc.)

If your GPU does not have one of the above compute capabilities, set TRITON_CACHE_DIR=$PWD/.cache to write the Triton cache to the current directory, then use the kernel files in this cache directory to run your program.
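The choice between a ready-made architecture directory and the local-cache fallback can be scripted. The sketch below maps a compute capability (as returned by torch.cuda.get_device_capability()) to the directory names listed above and builds a subprocess environment equivalent to TRITON_CACHE_DIR=$PWD/.cache. The function names are hypothetical.

```python
import os
from typing import Dict, Optional, Tuple

# Architecture directories listed under examples/runner.
ARCH_DIRS = {
    (7, 5): "sm75",
    (8, 0): "sm80",
    (8, 6): "sm86",
    (9, 0): "sm90",
    (12, 0): "sm120",
}


def pick_arch_dir(capability: Tuple[int, int]) -> Optional[str]:
    """Return the matching example directory name, or None if there is none."""
    return ARCH_DIRS.get(tuple(capability))


def local_cache_env(cwd: Optional[str] = None) -> Dict[str, str]:
    """Environment equivalent of `TRITON_CACHE_DIR=$PWD/.cache` for a subprocess."""
    env = dict(os.environ)
    env["TRITON_CACHE_DIR"] = os.path.join(cwd or os.getcwd(), ".cache")
    return env
```

When pick_arch_dir returns None, run the Python-level example with env=local_cache_env() and reuse the generated files from ./.cache.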

8. More Triton Version Examples

Please refer to the appropriate examples directory based on your Triton version:

II. Multi-Level Debugging

Python/TTIR/TTGIR now support debugging on Triton v3.4.0.

---
title: Triton Compilation Pipeline
---
flowchart LR

    subgraph Triton
        A["Python<br>Triton"]:::supported --> B["TTIR<br>Triton IR"]:::supported
        B --> C["TTGIR<br>Triton GPU IR"]:::supported
        C --> D["LLIR<br>LLVM IR"]:::unsupported

        Gluon["Python<br>Gluon"]:::unsupported --> C
    end

    subgraph Backend
        D --> E["PTX"]:::unsupported
        E --> F["cubin<br>CUDA Binary"]:::unsupported
    end

    classDef supported fill:#AED6F1,stroke:#2E86C1,stroke-width:2px,color:#000000;
    classDef unsupported fill:#F5B7B1,stroke:#C0392B,stroke-width:2px,color:#000000;

1. Python Debug

In addition to using @triton_runner.jit instead of @triton.jit, you also need to use triton_runner.language.dump() in your Triton kernel. You then allocate a temporary tensor called debug_tensor and pass it to the kernel through the debug_tensor parameter. Here are some example commands for debugging; see more in examples/debugging/README.md.

python examples/debugging/python/01-vec_add/debug_output.py
python examples/debugging/python/03-matrix_multiplication/debug_acc.py
python examples/debugging/python/04-softmax/debug_max_in_loop.py
python examples/debugging/python/05-softmax_lse/debug_log_acc.py
python examples/debugging/python/06-attention/debug_out.py
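Dumped values are only useful if you know what they should be. For the softmax examples, a host-side reference of the online-softmax recurrence gives the running max and log-sum-exp after each loop iteration, which you can compare against the debug_tensor contents copied back to the host. This is an illustrative sketch: the block size and the exact quantity each example dumps are assumptions, so check the corresponding script before comparing.

```python
import math
from typing import List, Tuple


def blockwise_max_and_lse(row: List[float], block: int) -> Tuple[List[float], List[float]]:
    """Running max and log-sum-exp after each block of an online softmax loop."""
    m = -math.inf
    s = 0.0  # running sum of exp(x - m)
    maxes, lses = [], []
    for start in range(0, len(row), block):
        chunk = row[start:start + block]
        m_new = max(m, max(chunk))
        # Rescale the old sum to the new max, then add this block's terms.
        s = s * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
        maxes.append(m)
        lses.append(m + math.log(s))
    return maxes, lses
```

The last log-sum-exp equals log(sum(exp(row))), which is a quick way to validate a final dump.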

2. TTIR Debug

Debugging is supported for TTIR ops like tt.load, arith.addf, and tt.trans in Triton v3.4.0. Here are some example commands for debugging. See more in examples/debugging/README.md.

python examples/debugging/ttir/01-vector_add/debug_addf.py
python examples/debugging/ttir/03-matrix_multiplication/debug_acc.py
python examples/debugging/ttir/04-softmax/debug_maxnumf.py
python examples/debugging/ttir/05-softmax_lse/debug_more.py
python examples/debugging/ttir/06-attention/debug_out.py
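To decide which value to dump, it helps to list the SSA results produced by the ops of interest in a .ttir file. The scanner below is a hypothetical helper using a simple regex over MLIR text; it only handles single-result lines of the form `%x = dialect.op ...`.

```python
import re
from typing import Iterable, List, Tuple

# `%name = dialect.op ...` at the start of a line (leading spaces/tabs allowed).
_OP_RE = re.compile(r"^[ \t]*(?P<result>%[\w.#]+)\s*=\s*(?P<op>[a-z_]+\.[a-z_]+)\b", re.MULTILINE)


def find_op_results(ir_text: str, ops: Iterable[str]) -> List[Tuple[int, str, str]]:
    """Return (line_number, ssa_result, op_name) for each line producing one of `ops`."""
    wanted = set(ops)
    hits = []
    for m in _OP_RE.finditer(ir_text):
        if m.group("op") in wanted:
            line_no = ir_text.count("\n", 0, m.start()) + 1  # 1-based line
            hits.append((line_no, m.group("result"), m.group("op")))
    return hits
```

For example, find_op_results(text, ["tt.load", "arith.addf", "tt.trans"]) lists candidate values in a vector-add or attention kernel.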

3. TTGIR Debug

Debugging is supported for TTGIR ops such as tt.load, arith.addf, and tt.trans in Triton v3.4.0. Here are some example commands for debugging. See more in examples/debugging/README.md.

python examples/debugging/ttgir/01-vec_add/debug_addf.py
python examples/debugging/ttgir/03-matrix_multiplication/debug_acc.py
python examples/debugging/ttgir/04-softmax/debug_maxnumf.py
python examples/debugging/ttgir/05-softmax_lse/debug_more.py
python examples/debugging/ttgir/06-attention/debug_out.py

III. Benchmarks

Benchmarks referencing TritonBench:

  • launch_latency: Measures kernel launch overhead.
  • matmul: Provides a benchmark for matrix multiplication performance.

python benchmark/launch_latency/bench.py
python benchmark/static_shape/matmul.py
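Launch overhead is usually measured as the host-side time to enqueue a launch, without waiting for the GPU. The generic harness below illustrates that idea with warmup and a median over many calls; it is a hypothetical sketch and is not the code in benchmark/launch_latency/bench.py.

```python
import statistics
import time
from typing import Callable


def median_call_latency_us(fn: Callable[[], object], warmup: int = 10, iters: int = 200) -> float:
    """Median wall-clock time of fn() in microseconds.

    For GPU kernels, fn should only enqueue the launch (no synchronize),
    so the result approximates launch overhead rather than kernel runtime.
    """
    for _ in range(warmup):
        fn()  # let caches and JIT compilation settle
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e6)
    return statistics.median(samples)
```

The median is used instead of the mean so that occasional OS scheduling hiccups do not dominate the result.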

IV. Solving Triton Issues

We use the cubin Runner to solve the Triton performance and shared memory issues documented in the doc/solving_triton_issues folder.

📄 License

This project is licensed under the MIT License. See the LICENSE file for more details.

This project includes code from:

