Triton-based backend for Flash Attention 2
Project description
Flash Attention Triton
This repository provides a wrapper around the Triton implementation of the Flash Attention algorithm, exposed through a Flash Attention 2 compatible API. It can serve as a drop-in replacement for the original Flash Attention 2 package for the supported functionality, and it adds support for Turing GPUs (e.g. 2080 Ti, T4), which the original FA2 CUDA package does not support.
Installation
You can install the package directly from GitHub:
pip install git+https://github.com/rationalism/flash-attn-triton.git
Or from PyPI:
pip install flash-attn-triton
Requirements
- PyTorch 2.6 or later
- Triton 3.2 or later
- CUDA-compatible GPU (compute capability 7.5+)
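A quick way to confirm your environment meets these requirements is a small sanity check like the one below (purely illustrative; the version thresholds simply mirror the list above):

import torch
import triton

# Compute capability 7.5 is Turing; 8.0+ is Ampere (needed for BF16).
major, minor = torch.cuda.get_device_capability()
print(f"torch {torch.__version__}, triton {triton.__version__}, "
      f"compute capability {major}.{minor}")
assert (major, minor) >= (7, 5), "GPU is older than Turing and is not supported"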
Usage
The API is designed to be compatible with Flash Attention 2. You can use it in the same way:
import torch
from flash_attn_triton import flash_attn_func, flash_attn_qkvpacked_func, FlashAttention

# Example inputs in the FA2 layout (batch, seqlen, nheads, headdim), fp16 on GPU
q, k, v = (torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Basic usage
out = flash_attn_func(q, k, v, causal=True)

# Packed QKV, layout (batch, seqlen, 3, nheads, headdim)
out = flash_attn_qkvpacked_func(torch.stack([q, k, v], dim=2), causal=True)

# Module interface
flash_attn = FlashAttention()
out = flash_attn(q, k, v, causal=True)
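Because the API mirrors Flash Attention 2, existing code can usually be switched over just by changing the import. If you cannot edit the downstream code, one possible workaround (untested and purely illustrative, not a documented feature of this package) is to alias the module before anything imports flash_attn:

import sys
import flash_attn_triton

# Hypothetical shim: make `import flash_attn` resolve to the Triton wrapper.
# Only works for code that sticks to the top-level FA2-style functions above.
sys.modules.setdefault("flash_attn", flash_attn_triton)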
Currently Supported Features
- Basic attention mechanism (forward and backward)
- FP16 and BF16 (BF16 only on Ampere and above)
- Causal masking
- Softmax scaling
- Basic MQA/GQA support (via tensor repetition; see the example after this list)
- Head dims 16, 32, 64, 128
- Ampere, Turing cards
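For MQA/GQA, the sketch below assumes the usual FA2 convention that k and v may carry fewer heads than q, as long as the number of query heads is a multiple of the number of key/value heads; the wrapper then handles the head repetition:

import torch
from flash_attn_triton import flash_attn_func

# GQA example: 16 query heads against 4 key/value heads (16 % 4 == 0), head dim 64
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 4, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 4, 64, device="cuda", dtype=torch.float16)
out = flash_attn_func(q, k, v, causal=True)  # out: (2, 1024, 16, 64)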
Limitations
This implementation does not currently support:
- Non-causal attention for sequence lengths not divisible by 128 (see the fallback sketch after this list)
- Dropout (in progress)
- Volta, Pascal, and earlier cards (in progress)
- varlen/unpadded support
- Attention bias
- Sliding window attention
- ALiBi
- KV caching with in-place updates
- Softcapping
- Deterministic backward pass
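If your workload hits the first limitation above (non-causal attention with a sequence length that is not a multiple of 128), one simple option is to dispatch those shapes to PyTorch's built-in scaled_dot_product_attention. The guard below is only a sketch and is not part of this package; note that SDPA expects (batch, nheads, seqlen, headdim), so the tensors are transposed around the call:

import torch.nn.functional as F
from flash_attn_triton import flash_attn_func

def attention(q, k, v, causal=False):
    # q, k, v: (batch, seqlen, nheads, headdim); assumes q and k share a sequence length
    if causal or q.shape[1] % 128 == 0:
        return flash_attn_func(q, k, v, causal=causal)
    # Unsupported non-causal length: fall back to PyTorch SDPA, which expects
    # (batch, nheads, seqlen, headdim), so transpose on the way in and out.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False
    )
    return out.transpose(1, 2)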
Benchmarks
RTX 3090 (Ampere)
fused-attention-batch4-head32-d64-fwd-causal=True-dropout=0.0:

| N_CTX | Triton [FP16] |
|---|---|
| 1024 | 48.049147 |
| 2048 | 61.062769 |
| 4096 | 68.363188 |
| 8192 | 70.768167 |
| 16384 | 72.332634 |

fused-attention-batch4-head32-d64-fwd-causal=False-dropout=0.0:

| N_CTX | Triton [FP16] |
|---|---|
| 1024 | 60.190653 |
| 2048 | 71.126662 |
| 4096 | 69.049310 |
| 8192 | 74.579215 |
| 16384 | 73.911621 |

fused-attention-batch4-head32-d64-bwd-causal=True-dropout=0.0:

| N_CTX | Triton [FP16] |
|---|---|
| 1024 | 33.531732 |
| 2048 | 40.884683 |
| 4096 | 45.627974 |
| 8192 | 47.449394 |
| 16384 | 48.993511 |

fused-attention-batch4-head32-d64-bwd-causal=False-dropout=0.0:

| N_CTX | Triton [FP16] |
|---|---|
| 1024 | 42.834959 |
| 2048 | 46.382862 |
| 4096 | 49.984253 |
| 8192 | 51.358497 |
| 16384 | 49.913040 |
RTX 2080 Ti (Turing)
fused-attention-batch4-head32-d64-fwd-causal=True-dropout=0.0:

| N_CTX | Triton [FP16] |
|---|---|
| 1024 | 29.258471 |
| 2048 | 41.382117 |
| 4096 | 46.972266 |
| 8192 | 49.315714 |
| 16384 | 50.443531 |

fused-attention-batch4-head32-d64-fwd-causal=False-dropout=0.0:

| N_CTX | Triton [FP16] |
|---|---|
| 1024 | 38.110175 |
| 2048 | 47.640577 |
| 4096 | 50.301599 |
| 8192 | 51.136501 |
| 16384 | 51.826783 |

fused-attention-batch4-head32-d64-bwd-causal=True-dropout=0.0:

| N_CTX | Triton [FP16] |
|---|---|
| 1024 | 22.085938 |
| 2048 | 26.173398 |
| 4096 | 28.565586 |
| 8192 | 30.030201 |
| 16384 | 31.082861 |

fused-attention-batch4-head32-d64-bwd-causal=False-dropout=0.0:

| N_CTX | Triton [FP16] |
|---|---|
| 1024 | 27.756566 |
| 2048 | 30.274265 |
| 4096 | 31.471025 |
| 8192 | 32.253811 |
| 16384 | 32.614130 |
Acknowledgements
This implementation is based on the Triton attention implementation from the original Flash Attention 2 repository by Tri Dao and on the official Triton tutorial on fused attention.
License
This project is licensed under the BSD 3-Clause License - see the LICENSE file for details.
File details
Details for the file flash_attn_triton-0.1.0.tar.gz.
File metadata
- Download URL: flash_attn_triton-0.1.0.tar.gz
- Upload date:
- Size: 26.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.16
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7a418ff88dbd1dabd9eb1f2734e8b90ea5bd08f40304821203f40aae1621e2a3 |
| MD5 | 190f462718469dd33e9768252dbf7132 |
| BLAKE2b-256 | aab6de613da06cbaaff34900a1c3f7297fc7aab55714e2f7634aec86cacaabbf |
File details
Details for the file flash_attn_triton-0.1.0-py3-none-any.whl.
File metadata
- Download URL: flash_attn_triton-0.1.0-py3-none-any.whl
- Upload date:
- Size: 27.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.16
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 606f9cc02d60c4a6cbb84a0529bf601b09f21c56f4c1d2ecd00fa3e6091776b8 |
| MD5 | 65b110a2951b02ad5c3e3ade9ae9de7f |
| BLAKE2b-256 | e428c1f6249d80e259dc4e1f12e7aa1b2b04c41b0496cbfc4b4b7afccbe83a55 |