AdaSplash: Efficient Adaptive Sparse Attention in Triton
AdaSplash: Adaptive Sparse Flash Attention
AdaSplash, also known as flash entmax attention, is an efficient adaptive sparse attention mechanism implemented in Triton. It is described in the AdaSplash paper: https://arxiv.org/abs/2502.12082. AdaSplash-2 is based on the follow-up paper: https://arxiv.org/abs/2604.15180.
Installation
You can install AdaSplash via pip:
pip install adasplash
Alternatively, install the latest development version directly from GitHub:
pip install git+https://github.com/deep-spin/adasplash.git
Usage
AdaSplash uses v2 as the default implementation while keeping explicit v1 entry points.
All functions can be imported from the top-level package: from adasplash import ...
Triton Entmax (Optimized Entmax Activation)
import torch
from adasplash import triton_entmax, triton_entmax_v1
x = torch.randn(128, 256).cuda()
y = triton_entmax(x, alpha=1.5, n_iter=2, use_histogram=True)
y_v1 = triton_entmax_v1(x, alpha=1.5, n_iter=10, fast_math=True)
- triton_entmax uses the v2 histogram initialization and hybrid solver.
- triton_entmax_v1 preserves the original Halley/bisection implementation.
- Faster and more efficient than traditional entmax implementations.
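To make clear what these functions compute, here is a slow, pure-Python reference of the α-entmax mapping for a single row of scores. It is a sketch for sanity-checking outputs on small inputs, not the library's Triton kernel: it swaps in plain bisection on the threshold τ instead of the histogram initialization or Halley steps, and the function name entmax_bisect is hypothetical.

```python
def entmax_bisect(z, alpha=1.5, n_iter=50):
    """Reference alpha-entmax of a 1-D list of scores, solved by bisection on tau.

    p_i = max((alpha - 1) * z_i - tau, 0) ** (1 / (alpha - 1)),
    where tau is chosen so that the p_i sum to 1. Requires alpha > 1.
    """
    if alpha <= 1.0:
        raise ValueError("alpha must be > 1 (alpha -> 1 recovers softmax)")
    am1 = alpha - 1.0
    x = [am1 * zi for zi in z]
    x_max = max(x)

    def mass(tau):
        # Total probability mass for a candidate threshold; decreasing in tau.
        return sum(max(xi - tau, 0.0) ** (1.0 / am1) for xi in x)

    # Bracket: at tau_lo the largest entry alone contributes 1, so mass >= 1;
    # at tau_hi every entry contributes at most 1/n, so mass <= 1.
    tau_lo = x_max - 1.0
    tau_hi = x_max - (1.0 / len(x)) ** am1
    for _ in range(n_iter):
        tau = 0.5 * (tau_lo + tau_hi)
        if mass(tau) >= 1.0:
            tau_lo = tau
        else:
            tau_hi = tau
    tau = 0.5 * (tau_lo + tau_hi)
    p = [max(xi - tau, 0.0) ** (1.0 / am1) for xi in x]
    total = sum(p)
    # Renormalize to absorb the small residual error left by bisection.
    return [pi / total for pi in p]


probs = entmax_bisect([1.0, 0.5, -3.0], alpha=1.5)
print(probs)  # roughly [0.674, 0.326, 0.0]: the lowest score gets exactly zero
```

Unlike softmax, the result contains an exact zero, which is the sparsity that AdaSplash exploits.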
AdaSplash
import torch
from adasplash import adasplash, adasplash_v1, adasplash_v2
q = torch.randn(1, 8, 128, 64, device="cuda")
k = torch.randn(1, 8, 128, 64, device="cuda")
v = torch.randn(1, 8, 128, 64, device="cuda")
output = adasplash(q, k, v, is_causal=True, varlen=None)
output_v1 = adasplash_v1(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)
output_v2 = adasplash_v2(q, k, v, niter=1, varlen=None)
- adasplash dispatches supported causal alpha=1.5 calls to AdaSplash-2.
- Calls requesting v1-only behavior, such as alpha != 1.5 or is_causal=False, fall back to v1.
- adasplash_v1 and adasplash_v2 select a version explicitly.
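The dispatch rule above can be summarized as a tiny decision function. This is a hypothetical illustration of the two conditions stated here, not code from the package; the real adasplash may check further conditions before deciding that a call is "supported" by AdaSplash-2.

```python
def choose_backend(alpha: float, is_causal: bool) -> str:
    """Hypothetical mirror of the dispatch rule described for `adasplash`.

    Only the two documented conditions are modeled here.
    """
    if is_causal and alpha == 1.5:
        return "v2"  # supported case: routed to AdaSplash-2
    return "v1"      # v1-only request, e.g. alpha != 1.5 or non-causal


for alpha, causal in [(1.5, True), (2.0, True), (1.5, False)]:
    print(alpha, causal, "->", choose_backend(alpha, causal))
```

If you need a specific version regardless of arguments, calling adasplash_v1 or adasplash_v2 directly avoids relying on this rule.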
AdaSplash without Block Masking
from adasplash import adasplash_no_block_mask
output = adasplash_no_block_mask(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)
- Does not use block masking but still benefits from tiling and fused ops for efficiency.
- Requires less memory than the block-masked version.
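To make the trade-off concrete, here is a toy sketch of what a block mask records, assuming block masking works by skipping (query-block, key-block) tiles in which every entry falls at or below the per-query entmax threshold τ, so the whole tile receives zero probability. The function block_mask and the threshold values are hypothetical; the actual kernels build and store their masks differently.

```python
def block_mask(scores, tau, alpha=1.5, block_q=2, block_k=2):
    """Toy (query-block, key-block) mask for entmax attention.

    scores: list of rows; scores[i][j] is the score of query i against key j.
    tau:    per-query entmax threshold, assumed already solved for each row.
    A tile is kept (True) if at least one entry has (alpha - 1) * s > tau[i],
    i.e. receives nonzero probability; all-False tiles could be skipped.
    """
    n_q, n_k = len(scores), len(scores[0])
    mask = []
    for q0 in range(0, n_q, block_q):
        row = []
        for k0 in range(0, n_k, block_k):
            row.append(any(
                (alpha - 1.0) * scores[i][j] > tau[i]
                for i in range(q0, min(q0 + block_q, n_q))
                for j in range(k0, min(k0 + block_k, n_k))
            ))
        mask.append(row)
    return mask


scores = [[4, 3, 0, 0], [4, 3, 0, 0], [0, 0, 4, 3], [0, 0, 4, 3]]
print(block_mask(scores, tau=[1.0, 1.0, 1.0, 1.0]))  # [[True, False], [False, True]]
```

Storing such a mask costs extra memory, which is the overhead adasplash_no_block_mask avoids at the price of visiting every tile.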
Dense Entmax Attention Utility
from adasplash import entmax_attention
output = entmax_attention(q, k, v, is_causal=True, padding="right", varlen=None)
- Uses the public v2 triton_entmax activation over dense attention scores.
- Supports causal masking, variable lengths, left/right padding, and ALiBi slopes.
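The padding and causal options determine which keys each query may attend to before the activation is applied. The sketch below builds such a mask for one sequence under common conventions (right padding keeps real tokens at the start, left padding at the end). The helper attention_mask is hypothetical, and how entmax_attention treats padded query rows is an assumption here.

```python
def attention_mask(seq_len, valid_len, is_causal=True, padding="right"):
    """Boolean mask: mask[i][j] is True when query i may attend to key j.

    With right padding the real tokens occupy positions [0, valid_len);
    with left padding they occupy [seq_len - valid_len, seq_len).
    Rows for padded query positions are all False; their outputs are
    assumed to be discarded.
    """
    if padding == "right":
        valid = [j < valid_len for j in range(seq_len)]
    elif padding == "left":
        valid = [j >= seq_len - valid_len for j in range(seq_len)]
    else:
        raise ValueError(f"unknown padding: {padding!r}")
    return [
        [valid[i] and valid[j] and (not is_causal or j <= i)
         for j in range(seq_len)]
        for i in range(seq_len)
    ]


for row in attention_mask(4, 2, is_causal=True, padding="left"):
    print(["x" if allowed else "." for allowed in row])
```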
Key Features
Variable Length Sequences:
varlen = torch.tensor([34, 128], device='cuda') # Actual sequence lengths, one per sequence in the batch
output = adasplash(q, k, v, varlen=varlen)
Adaptive Sparsity Control:
# Control sparsity via alpha parameter
output = adasplash(q, k, v, alpha=1.333) # Denser (closer to softmax)
output = adasplash(q, k, v, alpha=2.0) # Sparser
Because alpha != 1.5, these calls use the v1 compatibility path. For strict v2 behavior, call adasplash_v2.
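For reference, the α-entmax mapping that alpha controls can be written as below. As α approaches 1 it recovers softmax (fully dense), α = 2 gives sparsemax, and intermediate values such as 1.5 can already assign exact zeros; raising α generally pushes more entries below the threshold τ.

```latex
\alpha\text{-entmax}(\mathbf{z})_i = \big[(\alpha - 1)\, z_i - \tau\big]_+^{1/(\alpha - 1)},
\qquad \tau \text{ chosen so that } \sum_j \alpha\text{-entmax}(\mathbf{z})_j = 1 .
```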
Causal and Non-causal Masking:
output = adasplash(q, k, v, is_causal=True) # Causal masking
output = adasplash(q, k, v, is_causal=False) # Non-causal masking
Benchmarks
Efficiency
Single-vector retrieval
Check the Sparse ModernBERT repo.
Testing
To check that the library works as expected, install the development dependencies and run the tests:
pip install -r requirements-dev.txt
TRITON_INTERPRET=1 pytest # Triton interpreter mode, runs kernels on the CPU
pytest -m "not slow and not stress" # on a CUDA machine
Citation
If you use AdaSplash in your research, please cite:
@inproceedings{goncalves2025adasplash,
title={AdaSplash: Adaptive Sparse Flash Attention},
author={Nuno Gon{\c{c}}alves and Marcos V Treviso and Andre Martins},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=OWIPDWhUcO}
}
Acknowledgements
We thank Vlad Niculae for his insightful and constructive comments throughout this work. We also thank the SARDINE Lab members for reviewing this paper and providing helpful feedback. This work was supported by the Portuguese Recovery and Resilience Plan through project C645008882-00000055 (Center for Responsible AI), by the EU’s Horizon Europe Research and Innovation Actions (UTTER, contract 101070631), by the project DECOLLAGE (ERC-2022-CoG 101088763), and by FCT/MECI through national funds and when applicable co-funded EU funds under UID/50008: Instituto de Telecomunicações.
License
AdaSplash is licensed under the MIT License. See the LICENSE file for details.
Acknowledgements
We would like to thank the SARDINE Lab team for the helpful discussions. This work was supported by the project DECOLLAGE (ERC-2022-CoG 101088763), by the Portuguese Recovery and Resilience Plan through project C645008882-00000055 (Center for Responsible AI), and by FCT/MECI through national funds and when applicable co-funded EU funds under UID/50008: Instituto de Telecomunicações. Vlad Niculae is supported by the Dutch Research Council (NWO) via VI.Veni.212.228. Edoardo M. Ponti is supported by the ERC Starting Grant AToM-FM (101222956).
Download files
File details
Details for the file adasplash-0.2.0.tar.gz.
File metadata
- Download URL: adasplash-0.2.0.tar.gz
- Upload date:
- Size: 36.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4351e9ab7f6ca60779cbdead35c4421b78655e08080687f2115118c09c7bd1fd |
| MD5 | fa6bb68895ca24a11703e581097b4797 |
| BLAKE2b-256 | 98e81cbf98306d6f499d1a2a1b645824c13b2bb0c901a23092c4e7a07c8d0070 |
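To check a downloaded archive against the published SHA256 digest, a minimal sketch using only the standard library could look like this. The function names are illustrative, and the file name assumes the archive sits in the current directory.

```python
import hashlib
import hmac


def sha256_of(path, chunk_size=1 << 16):
    """Stream a file through SHA-256 and return its hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected_hex):
    """Return True if the file's SHA-256 digest matches expected_hex."""
    return hmac.compare_digest(sha256_of(path), expected_hex.strip().lower())


EXPECTED = "4351e9ab7f6ca60779cbdead35c4421b78655e08080687f2115118c09c7bd1fd"
# print(verify("adasplash-0.2.0.tar.gz", EXPECTED))
```

pip can enforce the same check at install time through its hash-checking mode (a requirements line such as adasplash==0.2.0 --hash=sha256:<digest>, installed with --require-hashes); in that mode every dependency must be pinned with a hash as well.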
File details
Details for the file adasplash-0.2.0-py3-none-any.whl.
File metadata
- Download URL: adasplash-0.2.0-py3-none-any.whl
- Upload date:
- Size: 34.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0be8424bb47b7a3edd3a0ea052a96fbf77f1970e05a98e5d7cabd30f332855b5 |
| MD5 | ddd0c89f4c83730b4f4c041020d4fcca |
| BLAKE2b-256 | bb8ca3c8eb748b371d7229f64c394773c354a62548afd0925f8da5471074bedf |