
A tile-level programming language for generating high-performance code.


Tile Language


Tile Language (tile-lang) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM, tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance.

Latest News

  • 02/02/2026 🧩: Check out TileLang Puzzles, a fun and interactive way to learn TileLang programming with 10 progressively harder puzzles!
  • 12/18/2025 🚀: Added CuTeDSL backend support, enabling compilation to NVIDIA CUTLASS CuTe DSL! Join us in building and optimizing this exciting new backend: Issue #1454.
  • 12/17/2025 🔬: Integrated Z3 theorem prover into TVM Arith Analyzer, bringing SMT-based symbolic reasoning for enhanced optimizations and automatic correctness verification!
  • 10/31/2025 🔧: Migrated to apache-tvm-ffi, significantly reducing CPU overhead!
  • 10/30/2025 📦: We have released v0.1.6.post2, which is the last version compatible with Python 3.8.
  • 10/07/2025 🍎: Added Apple Metal device support; see Pull Request #799 for details.
  • 09/29/2025 🎉: Thrilled to announce that AscendC and Ascend NPU IR backends targeting Huawei Ascend chips are now supported! Check out the preview here: 🔗 link. This includes implementations across two branches: ascendc_pto and npuir. Feel free to explore and share your feedback!
  • 07/04/2025 🚀: Introduced T.gemm_sp for 2:4 sparse tensor core support; see Pull Request #526 for details.
  • 06/05/2025 ✨: Added NVRTC backend to significantly reduce compilation time for CuTe templates!
  • 04/14/2025 🚀: Added high-performance FlashMLA implementation for AMD MI300X, achieving performance parity with hand-optimized assembly kernels of Aiter! See example_mla_amd for details.
  • 03/03/2025 🚀: Added high-performance MLA Decoding support using only 80 lines of Python code, achieving performance on par with FlashMLA on H100 (see example_mla_decode.py)! We also provide documentation explaining how TileLang achieves this.
  • 02/15/2025 ✨: Added WebGPU Codegen support, see Pull Request #86!
  • 02/12/2025 ✨: Excited to announce the release of v0.1.0!
  • 02/10/2025 🚀: Added debug tools for TileLang: T.print for printing variables/buffers (docs) and a memory layout plotter (examples/plot_layout).
  • 01/20/2025 ✨: We are excited to announce that tile-lang, a DSL for high-performance AI workloads, is now open source and available to the public!
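The 07/04/2025 entry mentions T.gemm_sp for 2:4 sparse tensor cores. As a rough mental model of the 2:4 format (at most two nonzeros in every group of four consecutive weights, stored as the kept values plus a small position index per value), here is a plain-Python sketch. compress_2_4 and decompress_2_4 are hypothetical helpers, not TileLang APIs; magnitude-based selection and a plain list of positions are simplifying assumptions, and the metadata layout that sparse tensor cores actually consume is packed differently.

```python
from typing import List, Tuple


def compress_2_4(row: List[float]) -> Tuple[List[float], List[int]]:
    """Keep the two largest-magnitude values in each group of four.

    Returns the kept values and, for each one, its position (0-3) inside
    its group; a position fits in 2 bits.
    """
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    values: List[float] = []
    positions: List[int] = []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        # Rank positions by |value| (ties go to the lower position), keep two,
        # then restore positional order so the positions are ascending.
        ranked = sorted(range(4), key=lambda i: (-abs(group[i]), i))
        for i in sorted(ranked[:2]):
            values.append(group[i])
            positions.append(i)
    return values, positions


def decompress_2_4(values: List[float], positions: List[int], n: int) -> List[float]:
    """Rebuild the dense (pruned) row from the kept values and positions."""
    dense = [0.0] * n
    for k, (v, i) in enumerate(zip(values, positions)):
        group = k // 2  # exactly two kept values per group of four
        dense[group * 4 + i] = v
    return dense


row = [0.1, -3.0, 0.5, 2.0, 4.0, 0.0, -1.0, 0.2]
vals, pos = compress_2_4(row)
print(vals, pos)
print(decompress_2_4(vals, pos, len(row)))
```

Storing only half the values plus 2-bit positions is what lets sparse hardware skip the zeroed multiplies.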

Tested Devices

Although tile-lang aims to be portable across a range of devices, it has been specifically tested and validated on the following devices: for NVIDIA GPUs, this includes the H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000; for AMD GPUs, it includes the MI250 (with Auto MatrixCore support) and the MI300X (with Async Copy support).

OP Implementation Examples

tile-lang provides the building blocks to implement a wide variety of operators. Some examples include:

In the examples directory, you will also find more complex kernels, such as convolutions and forward/backward passes for FlashAttention. More operators will be added over time.

Benchmark Summary

TileLang achieves exceptional performance across a variety of computational patterns. Comprehensive benchmark scripts and settings are available at tilelang-benchmark. Below are selected results showcasing its capabilities:

  • MLA Decoding Performance on H100

    [Figure: MLA decode performance, batch size 64, on H100]
    [Figure: MLA decode performance, batch size 128, on H100]
  • Flash Attention Performance on H100

    [Figure: operator performance on H100]
  • Matmul Performance on GPUs (RTX 4090, A100, H100, MI300X)

    [Figure: GEMM FP16 performance on GPUs]
  • Dequantize Matmul Performance on A100

    [Figure: dequantize GEMV performance on A100]
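A common way to compare GEMM results across devices is to convert latency into achieved throughput, counting a matmul as 2*M*N*K floating-point operations (one multiply and one add per inner-product term). The minimal sketch below assumes the latency is in milliseconds, as printed by the profiler.do_bench() call in the Quick Start; gemm_tflops is a hypothetical helper, and the 0.01 ms latency is an illustrative value, not a measured result.

```python
def gemm_tflops(m: int, n: int, k: int, latency_ms: float) -> float:
    """Achieved TFLOPS for one (m x k) @ (k x n) matmul.

    Counts 2*m*n*k floating-point operations: one multiply and one add for
    every term of every output element's inner product.
    """
    if latency_ms <= 0:
        raise ValueError("latency_ms must be positive")
    flops = 2.0 * m * n * k
    seconds = latency_ms * 1e-3  # assumes the latency is given in ms
    return flops / seconds / 1e12


# Illustrative numbers only: the 1024^3 shape from the Quick Start and a
# hypothetical latency of 0.01 ms.
print(f"{gemm_tflops(1024, 1024, 1024, 0.01):.1f} TFLOPS")  # 214.7 TFLOPS
```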

Installation

Method 1: Install with Pip

The quickest way to get started is to install the latest release from PyPI:

pip install tilelang

Alternatively, you can install directly from the GitHub repository:

pip install git+https://github.com/tile-ai/tilelang

Or install locally from a checkout of the repository:

# install required system dependencies
sudo apt-get update
sudo apt-get install -y python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev

pip install -e . -v # remove -e option if you don't want to install in editable mode, -v for verbose output

Method 2: Build from Source

We currently provide three ways to install tile-lang from source:

Method 3: Install with Nightly Version

For users who want access to the latest features and improvements before official releases, we provide nightly builds of tile-lang.

pip install tilelang -f https://tile-ai.github.io/whl/nightly
# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly

Note: Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet.

Quick Start

In this section, you'll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cache–friendly swizzling.
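In the GEMM example in this section, T.Pipelined(..., num_stages=3) asks the compiler to software-pipeline the loop over K tiles, so that copies of upcoming tiles overlap with the T.gemm on the current one. The sketch below is a simplified model of such a schedule, assuming loads run num_stages - 1 tiles ahead of compute; pipeline_schedule is a hypothetical helper, and the code TileLang actually emits (asynchronous copies, barriers, buffer rotation) is more involved.

```python
from typing import List, Optional, Tuple


def pipeline_schedule(num_k_tiles: int,
                      num_stages: int) -> List[Tuple[Optional[int], Optional[int]]]:
    """Per step, return (tile being loaded, tile being computed).

    Simplified model: loads run num_stages - 1 tiles ahead of compute, so at
    most num_stages tiles are buffered at once. None means that slot is idle
    (pipeline fill at the start, drain at the end).
    """
    if num_stages < 1:
        raise ValueError("num_stages must be >= 1")
    lookahead = num_stages - 1
    steps = []
    for step in range(num_k_tiles + lookahead):
        load = step if step < num_k_tiles else None
        compute = step - lookahead if step >= lookahead else None
        steps.append((load, compute))
    return steps


for step, (load, compute) in enumerate(pipeline_schedule(num_k_tiles=4, num_stages=3)):
    print(f"step {step}: load tile {load}, compute tile {compute}")
```

With num_stages=1 every step loads and computes the same tile, i.e. no overlap; larger values hide more copy latency at the cost of more shared memory.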

GEMM Example with Annotations (Layout, L2 Cache Swizzling, Pipelining, etc.)

Below is an example that demonstrates more advanced features: software pipelining, parallelized copies, a fused ReLU epilogue, and optional block swizzling for improved L2 cache locality. It shows how a kernel can be adapted to make better use of the hardware.
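The optional T.use_swizzle(panel_size=10, enable=True) call in the example remaps thread-block indices so that blocks scheduled close together reuse the same tiles of A and B while they are still in L2. One common form of this idea is panel-wise rasterization, sketched here in plain Python; panel_rasterize is a hypothetical helper, and the exact mapping TileLang uses may differ.

```python
from typing import Tuple


def panel_rasterize(block_id: int, grid_x: int, grid_y: int,
                    panel_size: int) -> Tuple[int, int]:
    """Map a linear block id to (bx, by), walking panels of panel_size columns.

    Row-major order would sweep across all grid_x columns before advancing
    by; here consecutive ids stay inside a narrow panel of bx values, so
    blocks launched close together touch only a few distinct tiles of A and B.
    """
    if not 0 <= block_id < grid_x * grid_y:
        raise ValueError("block_id out of range")
    blocks_per_panel = panel_size * grid_y
    panel = block_id // blocks_per_panel
    offset = block_id % blocks_per_panel
    panel_start = panel * panel_size
    width = min(panel_size, grid_x - panel_start)  # the last panel may be narrower
    bx = panel_start + offset % width
    by = offset // width
    return bx, by


# A 4 x 2 grid of blocks with panels two columns wide.
print([panel_rasterize(i, grid_x=4, grid_y=2, panel_size=2) for i in range(8)])
```

In the kernel's terms, grid_x would be T.ceildiv(N, block_N) and grid_y would be T.ceildiv(M, block_M).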

import torch
import tilelang
import tilelang.language as T

# @tilelang.jit(target="cuda")
# target currently can be "cuda", "hip", or "cpu".
# If not specified, it is inferred from the input tensors at compile time.
@tilelang.jit
def matmul_relu(
    A, B,
    block_M: int = 64,
    block_N: int = 64,
    block_K: int = 64,
    dtype: T.dtype = T.float16,
    accum_dtype: T.dtype = T.float32,
):
    # declare compilation shape constant
    M, N, K = T.const('M, N, K')

    # annotate input tensor shape
    A: T.Tensor[[M, K], dtype]
    B: T.Tensor[[K, N], dtype]

    # allocate output tensor
    C = T.empty([M, N], dtype)

    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

        # Enable rasterization for better L2 cache locality (Optional)
        # T.use_swizzle(panel_size=10, enable=True)

        # Clear local accumulation
        T.clear(C_local)

        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            # Copy tile of A
            # This is syntactic sugar for a parallelized copy
            T.copy(A[by * block_M, ko * block_K], A_shared)

            # Copy tile of B
            T.copy(B[ko * block_K, bx * block_N], B_shared)

            # Perform a tile-level GEMM on the shared buffers
            # Currently dispatched to CuTe (NVIDIA) or HIP (AMD) implementations
            T.gemm(A_shared, B_shared, C_local)

        # relu
        for i, j in T.Parallel(block_M, block_N):
            C_local[i, j] = T.max(C_local[i, j], 0)

        # Copy result back to global memory
        T.copy(C_local, C[by * block_M, bx * block_N])

    # You can write multiple CUDA kernels in one function; they execute sequentially
    # with T.Kernel(...) as ...

    # Return the output tensor (multiple tensors can also be returned)
    return C


M, N, K = 1024, 1024, 1024

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c_ref = torch.relu(a @ b)

# Call the kernel
c = matmul_relu(a, b)
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2)

# Call the kernel with overwritten compilation constants
c = matmul_relu(a, b, block_M=128, block_N=128, block_K=64)
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=1e-2)

# Retrieve the compiled kernel
kernel = matmul_relu.compile(a, b) # use torch.Tensor
kernel = matmul_relu.compile(      # use T.Tensor as placeholder
  T.Tensor((M, K), T.float16),
  T.Tensor((K, N), T.float16)
)
kernel = matmul_relu.compile(      # directly specify the shape constants
  M=M, N=N, K=K,
  block_M=128, block_N=128, block_K=64
)
print(kernel.get_kernel_source())
c = kernel(a, b)

# Profile the kernel's latency
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
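To make the index arithmetic of matmul_relu concrete, the following plain-Python reference walks the same grid (T.ceildiv(N, block_N) by T.ceildiv(M, block_M) blocks), accumulates each output tile over K in block_K steps, and applies the ReLU before writing back. tiled_matmul_relu is a hypothetical, CPU-only illustration that skips shared-memory staging and pipelining, so it is meant for checking small cases by hand, not as a model of the generated code.

```python
from typing import List

Matrix = List[List[float]]


def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def tiled_matmul_relu(A: Matrix, B: Matrix, block_M: int = 2, block_N: int = 2,
                      block_K: int = 2) -> Matrix:
    """CPU reference for relu(A @ B), organized like the matmul_relu kernel."""
    M, K = len(A), len(A[0])
    N = len(B[0])
    C = [[0.0] * N for _ in range(M)]
    # Each (bx, by) pair plays the role of one thread block in T.Kernel(...).
    for bx in range(ceildiv(N, block_N)):
        for by in range(ceildiv(M, block_M)):
            acc = [[0.0] * block_N for _ in range(block_M)]  # like C_local
            for ko in range(ceildiv(K, block_K)):  # like the T.Pipelined loop
                for i in range(block_M):
                    for j in range(block_N):
                        r, c = by * block_M + i, bx * block_N + j
                        for kk in range(block_K):
                            k = ko * block_K + kk
                            # Skip elements past the matrix edges in partial tiles.
                            if r < M and c < N and k < K:
                                acc[i][j] += A[r][k] * B[k][c]
            # ReLU epilogue, then write the in-bounds part of the tile to C.
            for i in range(block_M):
                for j in range(block_N):
                    r, c = by * block_M + i, bx * block_N + j
                    if r < M and c < N:
                        C[r][c] = max(acc[i][j], 0.0)
    return C


A = [[1.0, -2.0, 3.0], [4.0, 5.0, -6.0], [-7.0, 8.0, 9.0]]
B = [[1.0, 0.5], [-1.0, 2.0], [0.0, -3.0]]
print(tiled_matmul_relu(A, B))
```

Comparing its output against a naive triple loop on shapes that are not multiples of the block sizes is a quick way to check that the edge tiles are handled.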

Dive Deep into TileLang Beyond GEMM

In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including:

  • Dequantize GEMM: Achieve high-performance dequantization through fine-grained control over per-thread operations. Many of these techniques have since been adopted as default behaviors in BitBLAS, which uses layout transformations and hardware intrinsics to accelerate dequantized GEMM.
  • FlashAttention: Enable cross-operator fusion with simple and intuitive syntax; an autotuning example is also provided.
  • LinearAttention: Examples include RetNet and Mamba implementations.
  • Convolution: Convolution implemented with IM2Col.
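The Dequantize GEMM item relies on unpacking low-bit weights and rescaling them before they enter the multiply. Below is a minimal CPU sketch of that data flow for a GEMV, assuming signed int4 weights packed two per byte (low nibble first) with one scale per group of group_size consecutive weights. unpack_int4 and dequant_gemv are hypothetical helpers; the packing order, scale granularity, and layout transformations used by the TileLang and BitBLAS kernels are not reflected here.

```python
from typing import List


def unpack_int4(packed: List[int]) -> List[int]:
    """Unpack bytes that each hold two signed 4-bit values (low nibble first)."""
    out: List[int] = []
    for byte in packed:
        for shift in (0, 4):
            nibble = (byte >> shift) & 0xF
            out.append(nibble - 16 if nibble >= 8 else nibble)  # two's-complement int4
    return out


def dequant_gemv(packed_rows: List[List[int]], scales: List[List[float]],
                 x: List[float], group_size: int) -> List[float]:
    """y = dequant(W) @ x, where W[r] is unpacked from packed_rows[r] and
    every group_size consecutive weights in a row share one scale."""
    y: List[float] = []
    for row_bytes, row_scales in zip(packed_rows, scales):
        weights = unpack_int4(row_bytes)
        acc = 0.0
        for k, (q, xk) in enumerate(zip(weights, x)):
            # Dequantize on the fly: integer weight times its group's scale.
            acc += q * row_scales[k // group_size] * xk
        y.append(acc)
    return y


# One row: bytes 0x21 and 0xF8 unpack to weights [1, 2, -8, -1].
print(unpack_int4([0x21, 0xF8]))
print(dequant_gemv([[0x21, 0xF8]], [[0.5, 0.25]], x=[1.0, 1.0, 1.0, 4.0], group_size=2))
```

On a GPU the same unpack-and-scale step happens per thread in registers, which is where the fine-grained control described above matters.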

Upcoming Features

Check our tilelang v0.2.0 release plan for upcoming features.


TileLang is now used in the BitBLAS and AttentionEngine projects.

Join the Discussion

You are welcome to join our Discord community for discussions, support, and collaboration!


Acknowledgments

We would like to express our gratitude to the TVM community for their invaluable contributions. The initial version of this project was mainly developed by LeiWang1999, chengyupku and nox-410 with supervision from Prof. Zhi Yang at Peking University. Part of this work was carried out during an internship at Microsoft Research, where Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang offered valuable advice and support. We deeply appreciate their mentorship and contributions.



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tilelang-0.1.8.tar.gz (93.2 MB)

Uploaded Source

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

tilelang-0.1.8-cp38-abi3-manylinux_2_34_aarch64.whl (40.4 MB)

Uploaded: CPython 3.8+, manylinux: glibc 2.34+ ARM64

tilelang-0.1.8-cp38-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (43.5 MB)

Uploaded: CPython 3.8+, manylinux: glibc 2.27+ x86-64, manylinux: glibc 2.28+ x86-64

tilelang-0.1.8-cp38-abi3-macosx_11_0_arm64.whl (36.0 MB)

Uploaded: CPython 3.8+, macOS 11.0+ ARM64

File details

Details for the file tilelang-0.1.8.tar.gz.

File metadata

  • Download URL: tilelang-0.1.8.tar.gz
  • Upload date:
  • Size: 93.2 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.6

File hashes

Hashes for tilelang-0.1.8.tar.gz
Algorithm Hash digest
SHA256 da967821698eb7a79a76d27fbe25e314a3273f2b12ba4833e981658139d0e6d9
MD5 266d34bc48211e08676352ffb5b853bc
BLAKE2b-256 e6276e363f48f878389078e2899756b8fecc326388b585122fd7f8a86590dfab

See more details on using hashes here.

File details

Details for the file tilelang-0.1.8-cp38-abi3-manylinux_2_34_aarch64.whl.


File hashes

Hashes for tilelang-0.1.8-cp38-abi3-manylinux_2_34_aarch64.whl
Algorithm Hash digest
SHA256 bcc2e28202cde516bdd59e1c25b7f6a139d1c52207d92576f4e711c6217e16ba
MD5 35285fd78b81d02c3c1638c19a4d0dfa
BLAKE2b-256 e6dbd130c8db9140bb21a2ef81a455614a4aeec3388088bf9af5df01ad0ba45d


File details

Details for the file tilelang-0.1.8-cp38-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.


File hashes

Hashes for tilelang-0.1.8-cp38-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 5a4018e581f55c852d98a42d3b4acf2dbcfb8b7d8b9156ba7c6b0ab61600a10c
MD5 a2d23dfe5f2a6d6f2c8fc257d6c8233e
BLAKE2b-256 5d0b96ba853aa9e4795020d183e0ca832e9e37d82d4f7f48896241323d1b5ece


File details

Details for the file tilelang-0.1.8-cp38-abi3-macosx_11_0_arm64.whl.


File hashes

Hashes for tilelang-0.1.8-cp38-abi3-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 83654bff38448b6b26e143f150c928360619e91bf431289cf6cd74a4b31c7eba
MD5 007036ce121349bba28ebf0957480e53
BLAKE2b-256 de1710ab5c8ccc58783edcc5392ba653f4732702e44a72065224b3d7a4971852

