Accurate and efficient 8-bit plug-and-play attention.
Project description
SageAttention
This repository provides the official implementation of SageAttention.
SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration
Paper: https://arxiv.org/abs/2410.02367
Jintao Zhang, Jia Wei, Haofeng Huang, Pengle Zhang, Jun Zhu, Jianfei Chen
SageAttention2 Technical Report: Accurate 4-Bit Attention for Plug-and-play Inference Acceleration
Paper: https://arxiv.org/abs/2411.10958
Jintao Zhang, Haofeng Huang, Pengle Zhang, Jia Wei, Jun Zhu, Jianfei Chen
Project Updates
- News 2024-11-19: SageAttention2 will be released soon.
- News 2024-11-12: Support for sageattn_varlen is available now.
- News 2024-11-11: Support for different sequence lengths between q and k, v, for (batch_size, head_num, seq_len, head_dim) or (batch_size, seq_len, head_num, head_dim) input shapes, and for group-query attention is available now.
Base environment
python>=3.9
torch>=2.3.0
triton>=2.3.0
We recommend installing the following (the kernel will be slightly faster):
python>=3.11
torch>=2.4.0
triton-nightly
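Before installing, the base environment can be sanity-checked with a snippet like the following (a minimal sketch; it only assumes torch and triton expose __version__):

import sys
import torch
import triton

# Illustrative version check, not part of SageAttention itself.
assert sys.version_info >= (3, 9), "python>=3.9 is required"
print("torch:", torch.__version__)    # expect >= 2.3.0
print("triton:", triton.__version__)  # expect >= 2.3.0
print("CUDA available:", torch.cuda.is_available())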
Installation
Install using pip:
pip install sageattention
Or compile from source:
cd sageattention
pip install .
Note: SageAttention is currently optimized for RTX4090 and RTX3090 GPUs. Performance improvements may not be significant on other GPU architectures. We will progressively extend support to other GPUs.
How to use
from sageattention import sageattn
attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False, smooth_k=True)
q, k, v are FP16/BF16/FP32 tensors with shape (batch_size, head_num, seq_len, head_dim) under the default tensor_layout="HND". For the shape (batch_size, seq_len, head_num, head_dim), set tensor_layout="NHD". is_causal determines whether a causal mask is applied. smooth_k is a technique we proposed to ensure accuracy. Disabling smooth_k might slightly increase speed, but could compromise accuracy if the distribution of q, k, v is irregular. In rare cases, setting smooth_k to False may result in better accuracy.
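For instance, the following minimal sketch (tensor shapes here are illustrative) calls sageattn on random FP16 CUDA tensors in both layouts:

import torch
from sageattention import sageattn

# Illustrative shapes; head_dim must be a supported value (64, 96, or 128).
batch_size, head_num, seq_len, head_dim = 2, 16, 1024, 128

# Default layout "HND": (batch_size, head_num, seq_len, head_dim)
q = torch.randn(batch_size, head_num, seq_len, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = sageattn(q, k, v, tensor_layout="HND", is_causal=False, smooth_k=True)

# Layout "NHD": (batch_size, seq_len, head_num, head_dim)
q_nhd, k_nhd, v_nhd = (t.transpose(1, 2).contiguous() for t in (q, k, v))
out_nhd = sageattn(q_nhd, k_nhd, v_nhd, tensor_layout="NHD", is_causal=False)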
Note: sageattn is an accurate implementation that integrates smoothing of K, INT8 per-block quantization for q, k, and an FP16 accumulator for the matmul of $PV$. Support for head_dim values of 64, 96, and 128 is currently available; extended support for 48, 72, and 256 will be available soon. Different sequence lengths between q and k, v and group-query attention are supported. Different sequence lengths within the same batch are supported through sageattn_varlen.
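As a hedged sketch of group-query attention, assuming it is expressed by passing k and v with fewer heads than q:

import torch
from sageattention import sageattn

# q has 32 heads; k and v share 8 heads (4 query heads per key/value head).
q = torch.randn(1, 32, 2048, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 2048, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 2048, 128, dtype=torch.float16, device="cuda")
out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)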
Plug-and-play Example
We can replace scaled_dot_product_attention easily.
We will take CogVideo as an example. Just add the following code and run!
from sageattention import sageattn
import torch.nn.functional as F
F.scaled_dot_product_attention = sageattn
Specifically,
cd example
python sageattn_cogvideo.py
You will get a lossless video in ./example, generated faster than with python original_cogvideo.py.
Note: Not all models use F.scaled_dot_product_attention, so you may need to replace the original attention by modifying the Attention class of the target model (as shown in another example in ./example).
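For models that implement their own attention, a hypothetical sketch of such a modification (the module and variable names below are illustrative, not taken from any specific model) could look like this:

import torch
import torch.nn as nn
from sageattention import sageattn

class PatchedAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads  # should be 64, 96, or 128 for sageattn
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (batch_size, head_num, seq_len, head_dim) for tensor_layout="HND".
        q = q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        # Replace the original attention call with sageattn.
        out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.proj(out)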
Performance
Speed of Kernels
Note: The TOPS results refer only to the attention kernel, excluding quantization and smoothing of K.
End-to-end performance
Citation
If you use this code or find our work valuable, please cite:
@misc{zhang2024sageattention,
title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration},
author={Jintao Zhang and Jia Wei and Haofeng Huang and Pengle Zhang and Jun Zhu and Jianfei Chen},
year={2024},
eprint={2410.02367},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.02367},
}
@misc{zhang2024sageattention2,
title={SageAttention2 Technical Report: Accurate 4-Bit Attention for Plug-and-play Inference Acceleration},
author={Jintao Zhang and Haofeng Huang and Pengle Zhang and Jia Wei and Jun Zhu and Jianfei Chen},
year={2024},
eprint={2411.10958},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2411.10958},
}
Download files
Source Distribution
Built Distribution
File details
Details for the file sageattention-1.0.6.tar.gz.
File metadata
- Download URL: sageattention-1.0.6.tar.gz
- Upload date:
- Size: 12.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b0398a5877222ee1abaeccd27d1029309c8b57994794a11c1278e0bfd176512c |
| MD5 | 53c4688ce05a3aa44f311052f2c23815 |
| BLAKE2b-256 | a5c90445157fd09c8e5722503588db2a25afe9ab362d927b344dbbe7ee3ce84b |
File details
Details for the file sageattention-1.0.6-py3-none-any.whl.
File metadata
- Download URL: sageattention-1.0.6-py3-none-any.whl
- Upload date:
- Size: 20.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.11.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fafc66569bed62a16839e820c2612141b5a20accf55b876d941bab9c0ac5d888 |
| MD5 | 0f3980efe3f6bead7652a58f7d8838ef |
| BLAKE2b-256 | 5306f7b47adb766bcb38b3f88763374a3e8dffea05ee9b556bc24dbcbd60fd29 |