AdaSplash: Adaptive Sparse Flash Attention
AdaSplash, also known as flash entmax attention, is an efficient adaptive sparse attention mechanism implemented in Triton. Check out our paper: https://arxiv.org/abs/2502.12082.
Installation
You can install AdaSplash via pip:
pip install adasplash
Alternatively, install the latest development version directly from GitHub:
pip install git+https://github.com/deep-spin/adasplash.git
Usage
AdaSplash provides three main functions, all available via from adasplash import ...:
Triton Entmax (Optimized Entmax Activation)
from adasplash import triton_entmax
import torch
x = torch.randn(128, 256).cuda()
y = triton_entmax(x, alpha=1.5, n_iter=10, fast_math=True)
- Uses Halley's method + bisection instead of pure bisection.
- Faster and more efficient than existing PyTorch entmax implementations.
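As a quick sanity check (a minimal sketch, assuming a CUDA-capable GPU), triton_entmax returns row-wise probability distributions in which low-scoring entries are exactly zero, unlike softmax:

import torch
from adasplash import triton_entmax

x = torch.randn(4, 16, device="cuda")
p = triton_entmax(x, alpha=1.5, n_iter=10, fast_math=True)

print(p.sum(dim=-1))            # each row sums to (approximately) 1
print((p == 0).float().mean())  # fraction of entries that are exactly zero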
AdaSplash with Block Masking
from adasplash import adasplash
import torch
q = torch.randn(1, 8, 128, 64, device="cuda")
k = torch.randn(1, 8, 128, 64, device="cuda")
v = torch.randn(1, 8, 128, 64, device="cuda")
output = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)
- Leverages adaptive sparsity for efficiency in both forward and backward passes.
- Requires O(Tr × Tc) bits of extra memory to store one binary mask bit per (query block, key block) pair, where Tr and Tc are the numbers of row and column blocks.
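To show where this fits in practice, here is a minimal sketch of a toy multi-head attention module that calls adasplash in place of softmax attention. The module name ToyAttention, its dimensions, and the projection layers are illustrative assumptions, not part of the library:

import torch
import torch.nn as nn
from adasplash import adasplash

class ToyAttention(nn.Module):  # hypothetical wrapper, not provided by adasplash
    def __init__(self, dim=512, n_heads=8, alpha=1.5):
        super().__init__()
        self.n_heads, self.head_dim, self.alpha = n_heads, dim // n_heads, alpha
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (batch, heads, seq_len, head_dim), the layout used in the examples above
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # note: check whether your adasplash version scales scores by 1/sqrt(head_dim) internally;
        # if it does not, multiply q by self.head_dim ** -0.5 before the call
        o = adasplash(q, k, v, alpha=self.alpha, niter=10, is_causal=True, varlen=None)
        return self.out(o.transpose(1, 2).reshape(B, T, D))

x = torch.randn(2, 128, 512, device="cuda")
print(ToyAttention().cuda()(x).shape)  # torch.Size([2, 128, 512])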
AdaSplash without Block Masking
from adasplash import adasplash_no_block_mask
output = adasplash_no_block_mask(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)
- Does not use block masking but still benefits from tiling and fused ops for efficiency.
- Requires less memory than the block-masked version.
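Both variants should compute the same entmax attention (block masking only skips blocks that are entirely masked out), so a quick consistency check is possible. A minimal sketch, assuming a CUDA GPU:

import torch
from adasplash import adasplash, adasplash_no_block_mask

q = torch.randn(1, 8, 256, 64, device="cuda")
k = torch.randn(1, 8, 256, 64, device="cuda")
v = torch.randn(1, 8, 256, 64, device="cuda")

out_block = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)
out_plain = adasplash_no_block_mask(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)

# small numerical differences between the two kernels are expected
print(torch.allclose(out_block, out_plain, atol=1e-2, rtol=1e-2))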
Key Features
Variable Length Sequences:
varlen = torch.tensor([34, 128], device='cuda')  # true length of each sequence in the batch
output = adasplash(q, k, v, varlen=varlen)
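Based on the example above, varlen appears to hold one true length per sequence in the batch, with positions beyond each length treated as padding. A minimal sketch (an assumption about the intended usage) with two sequences right-padded to a common length of 128:

import torch
from adasplash import adasplash

B, H, T, D = 2, 8, 128, 64
q = torch.randn(B, H, T, D, device="cuda")
k = torch.randn(B, H, T, D, device="cuda")
v = torch.randn(B, H, T, D, device="cuda")

varlen = torch.tensor([34, 128], device="cuda")  # true lengths; the remaining positions are padding
out = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=varlen)
print(out.shape)  # torch.Size([2, 8, 128, 64])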
Adaptive Sparsity Control:
# Control sparsity via alpha parameter
output = adasplash(q, k, v, alpha=1.333)  # closer to softmax (denser attention)
output = adasplash(q, k, v, alpha=2.0)    # sparsemax (sparser attention)
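To see the effect of alpha directly, one option (a sketch assuming a CUDA GPU; the exact percentages depend on the score distribution) is to apply triton_entmax to a row of attention scores and count the exact zeros it produces:

import torch
from adasplash import triton_entmax

scores = torch.randn(1, 512, device="cuda")  # one row of attention logits
for alpha in (1.333, 1.5, 2.0):
    p = triton_entmax(scores, alpha=alpha, n_iter=10, fast_math=True)
    zeros = (p == 0).float().mean().item()
    print(f"alpha={alpha}: {zeros:.0%} of the weights are exactly zero")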
Causal and Non-causal Masking:
output = adasplash(q, k, v, is_causal=True) # Causal masking
output = adasplash(q, k, v, is_causal=False) # Non-causal masking
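A simple way to sanity-check causal masking (a sketch, assuming a CUDA GPU) is to perturb the last key/value position and verify that the outputs at all earlier positions are unchanged:

import torch
from adasplash import adasplash

q = torch.randn(1, 8, 64, 64, device="cuda")
k = torch.randn(1, 8, 64, 64, device="cuda")
v = torch.randn(1, 8, 64, 64, device="cuda")

out = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)

# perturb only the last position; with causal masking, earlier positions cannot attend to it
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] = torch.randn_like(k2[:, :, -1])
v2[:, :, -1] = torch.randn_like(v2[:, :, -1])
out2 = adasplash(q, k2, v2, alpha=1.5, niter=10, is_causal=True, varlen=None)

print(torch.allclose(out[:, :, :-1], out2[:, :, :-1], atol=1e-4))  # expected: True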
Benchmark
Benchmark results (speed and memory across sequence lengths and sparsity levels) are reported in the paper: https://arxiv.org/abs/2502.12082.
Testing
To verify that the library works as expected, install the development dependencies and run the tests:
pip install -r requirements-dev.txt
pytest
Citation
If you use AdaSplash in your research, please cite:
@article{goncalves2025adasplash,
  title={AdaSplash: Adaptive Sparse Flash Attention},
  author={Nuno Gonçalves and Marcos Treviso and André F. T. Martins},
  journal={arXiv preprint arXiv:2502.12082},
  url={https://arxiv.org/abs/2502.12082},
  year={2025}
}
License
AdaSplash is licensed under the MIT License. See the LICENSE file for details.