
AdaSplash: Efficient Adaptive Sparse Attention in Triton


AdaSplash: Adaptive Sparse Flash Attention


AdaSplash, also known as flash entmax attention, is an efficient adaptive sparse attention mechanism implemented in Triton. Check out our paper: https://arxiv.org/abs/2502.12082.

Installation

You can install AdaSplash via pip:

pip install adasplash

Alternatively, install the latest development version directly from GitHub:

pip install git+https://github.com/deep-spin/adasplash.git

Usage

AdaSplash provides three main functions, all available via from adasplash import ...:

Triton Entmax (Optimized Entmax Activation)

import torch
from adasplash import triton_entmax

x = torch.randn(128, 256).cuda()
y = triton_entmax(x, alpha=1.5, n_iter=10, fast_math=True)
  • Uses Halley's method + bisection instead of pure bisection.
  • Faster and more efficient than traditional Entmax implementations.
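
Beyond speed, entmax with alpha > 1 produces genuinely sparse outputs, unlike softmax. A minimal sketch, reusing the triton_entmax call shown above, that makes the difference visible:

import torch
from adasplash import triton_entmax

x = torch.randn(128, 256).cuda()
p_entmax = triton_entmax(x, alpha=1.5, n_iter=10, fast_math=True)
p_softmax = torch.softmax(x, dim=-1)

# Entmax with alpha > 1 assigns exactly zero probability to low-scoring
# entries; softmax keeps every entry strictly positive.
print((p_entmax == 0).float().mean().item())   # typically a sizable fraction
print((p_softmax == 0).float().mean().item())  # 0.0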

AdaSplash with Block Masking

import torch
from adasplash import adasplash

q = torch.randn(1, 8, 128, 64, device="cuda")
k = torch.randn(1, 8, 128, 64, device="cuda")
v = torch.randn(1, 8, 128, 64, device="cuda")

output = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)
  • Leverages adaptive sparsity for efficiency in both forward and backward passes.
  • Requires O(Tr × Tc) bits of extra memory to store a binary mask per block, where Tr and Tc are the numbers of query and key/value tiles.
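
To see where this slots into a model, here is a hypothetical drop-in use inside an attention layer; the attention_forward wrapper and its defaults are illustrative, not part of the library API:

import torch
from adasplash import adasplash

# Hypothetical wrapper showing where adasplash would replace standard
# softmax attention in a transformer layer; tensors use the
# (batch, heads, seq_len, head_dim) layout from the example above.
def attention_forward(q, k, v, alpha=1.5):
    # No explicit attention mask is built here: causal masking and entmax
    # normalization happen inside the fused kernel.
    return adasplash(q, k, v, alpha=alpha, niter=10, is_causal=True, varlen=None)

q = k = v = torch.randn(1, 8, 128, 64, device="cuda")
out = attention_forward(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])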

AdaSplash without Block Masking

from adasplash import adasplash_no_block_mask

output = adasplash_no_block_mask(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)
  • Does not use block masking but still benefits from tiling and fused ops for efficiency.
  • Requires less memory than the block-masked version.
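
Both variants should compute the same attention output; the block-masked kernel spends the extra mask memory to skip fully pruned blocks. A quick sanity check, assuming agreement up to kernel-level floating-point tolerance:

import torch
from adasplash import adasplash, adasplash_no_block_mask

q = torch.randn(1, 8, 128, 64, device="cuda")
k = torch.randn(1, 8, 128, 64, device="cuda")
v = torch.randn(1, 8, 128, 64, device="cuda")

out_masked = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)
out_plain = adasplash_no_block_mask(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)

# The two kernels should agree up to floating-point tolerance.
print(torch.allclose(out_masked, out_plain, atol=1e-3))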

Key Features

Variable Length Sequences:

varlen = torch.tensor([34, 128], device='cuda')  # true length of each sequence in the batch
output = adasplash(q, k, v, varlen=varlen)
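
A self-contained sketch, assuming one varlen entry per batch element and right-padding up to the maximum length:

import torch
from adasplash import adasplash

# Padded batch of two sequences with true lengths 34 and 128; positions
# beyond each true length are padding.
q = torch.randn(2, 8, 128, 64, device="cuda")
k = torch.randn(2, 8, 128, 64, device="cuda")
v = torch.randn(2, 8, 128, 64, device="cuda")
varlen = torch.tensor([34, 128], device="cuda")

output = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=varlen)
print(output.shape)  # torch.Size([2, 8, 128, 64])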

Adaptive Sparsity Control:

# Control sparsity via alpha parameter
output = adasplash(q, k, v, alpha=1.333)  # closer to 1: denser, nearer softmax
output = adasplash(q, k, v, alpha=2.0)    # alpha = 2 is sparsemax: sparser
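
The effect is easy to inspect directly on the activation: as alpha grows, more scores are pushed to exactly zero. A small sketch using triton_entmax (the zero fraction will vary with the random scores):

import torch
from adasplash import triton_entmax

scores = torch.randn(4, 256).cuda()
for alpha in (1.333, 1.5, 2.0):
    p = triton_entmax(scores, alpha=alpha, n_iter=10, fast_math=True)
    # The fraction of weights that are exactly zero grows with alpha.
    print(alpha, (p == 0).float().mean().item())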

Causal and Non-causal Masking:

output = adasplash(q, k, v, is_causal=True)  # Causal masking
output = adasplash(q, k, v, is_causal=False)  # Non-causal masking
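
One way to sanity-check the causal flag: perturbing a future key/value pair must not change outputs at earlier positions. A hedged sketch, not part of the library's test suite:

import torch
from adasplash import adasplash

q = torch.randn(1, 8, 128, 64, device="cuda")
k = torch.randn(1, 8, 128, 64, device="cuda")
v = torch.randn(1, 8, 128, 64, device="cuda")

out1 = adasplash(q, k, v, alpha=1.5, niter=10, is_causal=True, varlen=None)

# Perturb the last key/value pair; with causal masking, only the final
# position can attend to it, so earlier outputs should be unchanged.
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 1.0
v2[:, :, -1] += 1.0
out2 = adasplash(q, k2, v2, alpha=1.5, niter=10, is_causal=True, varlen=None)

print(torch.allclose(out1[:, :, :-1], out2[:, :, :-1], atol=1e-4))  # should print True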

Benchmark

[Benchmark figure]

Testing

To ensure the library works as expected, install the development dependencies and run tests:

pip install -r requirements-dev.txt
pytest

Citation

If you use AdaSplash in your research, please cite:

@article{goncalves2025adasplash,
  title={AdaSplash: Adaptive Sparse Flash Attention},
  author={Nuno Gonçalves and Marcos Treviso and André F. T. Martins},
  journal={arXiv preprint arXiv:2502.12082},
  url={https://arxiv.org/abs/2502.12082},
  year={2025}
}

License

AdaSplash is licensed under the MIT License. See the LICENSE file for details.
