Dynamic Sparse Attention with Landmark Tokens — High-performance Triton implementation
Project description
DSALT, Dynamic Sparse Attention with Landmark Tokens
See the repository on GitHub: https://github.com/LeonardoCofone/dsalt-pytorch. All of the .md files mentioned on this page are available there; see FEATURE.md for the full feature catalog.
A high-performance PyTorch library implementing DSALT (Dynamic Sparse Attention with Landmark Tokens): a sparse-attention transformer built for efficient training with Triton and PyTorch.
Published on PyPI:
pip install dsalt
🚀 Key Features
- Efficient Sparse Attention: Triton-accelerated kernels for GPU-optimized sparse causal self-attention
- Dynamic Window Sizing: Adaptive local attention windows that grow with sequence position
- Landmark Token Selection: Global landmark tokens selected via hybrid energy scoring
- Mixed Precision Training: Full support for BF16/FP16 training with gradient scaling
- Distributed Training: DDP (DistributedDataParallel) support for multi-GPU training
- Production Ready: Complete training harness with checkpointing, logging, and validation
📋 Table of Contents
- Installation
- Quick Start
- Architecture
- Training
- API Reference
- Testing
- Citation
- Contributing
- License
- Acknowledgments
🛠️ Installation
Requirements
- Python 3.8+
- PyTorch 2.0+
- CUDA 11.0+ (for GPU acceleration)
- Triton 2.0+ (optional, for GPU kernels)
Install from PyPI
pip install dsalt
Install with Triton support
pip install dsalt[triton]
Install with Flash Attention fallback
pip install dsalt[flash-attn]
Install from source
git clone https://github.com/LeonardoCofone/dsalt-pytorch.git
cd dsalt-pytorch
pip install -e .
Developer setup
pip install -r requirements-dev.txt
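After installing, you can quickly confirm whether the optional GPU path is usable. The check below is a minimal sketch using only standard PyTorch and standard-library calls; nothing DSALT-specific is assumed:

```python
# Quick environment check (standard PyTorch / stdlib only)
import importlib.util
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Triton installed:", importlib.util.find_spec("triton") is not None)
```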
🚀 Quick Start
```python
import torch
from dsalt.model import DSALTLMHeadModel

# Create a DSALT language model
model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,    # Minimum window size
    n_max=512,   # Maximum window size
    k_lmk=64,    # Number of landmark tokens
)

# Forward pass
input_ids = torch.randint(0, 32000, (1, 1024))
logits = model(input_ids)
print(f"Output shape: {logits.shape}")  # [1, 1024, 32000]
```
🏗️ Architecture
DSALT combines local causal windows with global landmark tokens (a toy sketch of the resulting attention mask follows the list below):
- Local Attention: Each token attends to a dynamic window of recent tokens
- Landmark Selection: Top-k informative tokens selected globally via energy scoring
- Sparse Computation: Only compute attention for relevant token pairs
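To make the sparsity pattern concrete, here is a minimal, illustrative sketch of how such a mask could be assembled in plain PyTorch. The helper `toy_dsalt_mask`, the linear window-growth rule, and the random energy scores are simplified stand-ins for exposition only, not DSALT's actual learned WindowSizePredictor, HybridEnergyScorer, or Triton kernels:

```python
import torch

def toy_dsalt_mask(seq_len: int, n_min: int, n_max: int, k_lmk: int,
                   energy: torch.Tensor) -> torch.Tensor:
    """Build a boolean [seq_len, seq_len] mask (True = attend).

    Combines (1) a causal local window whose width grows with position
    between n_min and n_max, and (2) k_lmk globally selected landmark
    positions (top-k by a precomputed per-token energy score).
    """
    pos = torch.arange(seq_len)
    # Window grows with position (simplified stand-in for the learned predictor).
    win = torch.clamp(n_min + (n_max - n_min) * pos / max(seq_len - 1, 1),
                      max=n_max).long()
    q = pos.unsqueeze(1)   # query positions, column vector
    k = pos.unsqueeze(0)   # key positions, row vector
    causal = k <= q
    local = causal & (q - k < win.unsqueeze(1))
    # Global landmarks: top-k tokens by energy, visible to every later query.
    lmk = torch.topk(energy, k=min(k_lmk, seq_len)).indices
    landmark = torch.zeros(seq_len, dtype=torch.bool)
    landmark[lmk] = True
    return local | (causal & landmark.unsqueeze(0))

mask = toy_dsalt_mask(seq_len=1024, n_min=32, n_max=512, k_lmk=64,
                      energy=torch.rand(1024))
print("kept pairs:", mask.sum().item(), "of", 1024 * 1024)
```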
Key Components
- DSALTTransformer: Main transformer architecture
- DSALTAttention: Multi-head sparse attention layer
- WindowSizePredictor: Learned adaptive window sizing
- HybridEnergyScorer: Landmark token selection
- SparseAttentionKernel: Triton-accelerated attention computation
🎯 Training
Single GPU Training
```python
from dsalt.training import DSALTTrainer
from torch.utils.data import DataLoader

trainer = DSALTTrainer(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    lr=3e-4,
    total_steps=100000,
    save_dir="checkpoints",
    dtype=torch.bfloat16,
)
trainer.train()
```
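The trainer call above assumes train_dataloader and val_dataloader already exist. Below is a minimal toy setup with random token ids; the exact batch format DSALTTrainer expects is an assumption here, so substitute your real tokenized dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy corpora of random token ids, shaped [num_sequences, seq_len].
# Assumed batch format: LongTensor of token ids per sequence.
train_ids = torch.randint(0, 32000, (256, 1024))
val_ids = torch.randint(0, 32000, (32, 1024))

train_dataloader = DataLoader(TensorDataset(train_ids), batch_size=8, shuffle=True)
val_dataloader = DataLoader(TensorDataset(val_ids), batch_size=8)
```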
Multi-GPU Distributed Training
```python
import torch.distributed as dist

# Initialize process group
dist.init_process_group(backend='nccl')

trainer = DSALTTrainer(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    ddp=True,  # Enable DDP
    # ... other args
)
```
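A script built around this setup is normally launched with PyTorch's standard torchrun launcher, which sets the environment variables that init_process_group reads. The script name and GPU count below are illustrative:

```bash
# Launch one process per GPU on a single 4-GPU node (illustrative values)
torchrun --nproc_per_node=4 train.py
```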
📚 API Reference
Core Classes
- DSALTLMHeadModel: Language model wrapper with LM head
- DSALTTransformer: Base transformer architecture
- DSALTAttention: Sparse attention module
- DSALTTrainer: Training harness
Kernel Functions
- dsalt_attention(): Main sparse attention function
- compute_hybrid_energy_scores(): Landmark scoring
- select_landmarks(): Landmark selection
🧪 Testing
Run the full test suite:
python tests/test.py
Run specific tests:
python tests/test_sparse_attn.py # Attention kernels
python tests/test_dsalt_lm.py # LM wrapper
📖 Citation
If you use DSALT in your research, please cite our paper:
```bibtex
@article{dsalt2024,
  title={Noise Accumulation and Rank Collapse in Dense Self-Attention: DSALT},
  author={Leonardo et al.},
  journal={Zenodo preprint},
  year={2026}
}
```
Paper: https://zenodo.org/records/19312827
🤝 Contributing
We welcome contributions! Please see our contributing guidelines.
📄 License
This project is licensed under the Apache 2.0 License - see the LICENSE file for details.
🙏 Acknowledgments
- Built on top of Triton for GPU kernels
- Inspired by Flash Attention
- Thanks to the PyTorch team for the excellent deep learning framework
Download files
File details
Details for the file dsalt-0.1.4.tar.gz.
File metadata
- Download URL: dsalt-0.1.4.tar.gz
- Upload date:
- Size: 37.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 0a9eb3b74f7c0551feb4a7f268f74ef880ef48299ad5868e549faf2713e86ece |
| MD5 | 2b16cdecb530e130098424aed72be775 |
| BLAKE2b-256 | 7df6992d18b3d7ca895b85b1c8e5fda14ad3e2727b5742973c42139717a500f7 |
File details
Details for the file dsalt-0.1.4-py3-none-any.whl.
File metadata
- Download URL: dsalt-0.1.4-py3-none-any.whl
- Upload date:
- Size: 31.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 7628be8a238e34f1d10afcb693793a8004ebeeedcb17e7fd007e104c50ae2c3e |
| MD5 | 055811fd98ce93228c8237e4a5248858 |
| BLAKE2b-256 | a4fb8efaf435f9c07909333a2b49c286f2ca80f1755ac4004b5c63d5018a55a6 |