Dynamic Sparse Attention with Landmark Tokens - High-performance Triton implementation
Project description
DSALT: Dynamic Sparse Attention with Landmark Tokens
DSALT is a high‑performance PyTorch library that implements Dynamic Sparse Attention with Landmark Tokens – a memory‑efficient attention mechanism for transformers. It relies on Triton kernels and supports distributed training.
Install:
pip install dsalt
Source: https://github.com/LeonardoCofone/dsalt-library
Paper: https://zenodo.org/records/19312826
Feature guide: see FEATURE.md at https://github.com/LeonardoCofone/dsalt-library/blob/main/FEATURE.md
🚀 Key Features
- Memory‑efficient sparse attention – Triton‑accelerated kernels provide 4–8× memory savings compared to dense attention.
- Adaptive local windows – Token‑wise window sizes that grow with sequence position.
- Global landmark tokens – Top‑k informative tokens per head selected via a hybrid energy scoring function.
- Production‑ready training – Mixed‑precision, gradient checkpointing, and validation support.
- Distributed training – Full DDP and FSDP support for multi‑GPU setups.
- Numerical verification – CPU/GPU equivalence tests and gradient stability checks.
📋 Table of Contents
- Installation
- Quick Start
- Architecture Overview
- Training & Generation
- API Reference
- Hyperparameter Guide
- Testing
- Performance & Benchmarks
- Citation
- Contributing
- License
🛠️ Installation
Requirements
- Python 3.8+
- PyTorch 2.0+
- CUDA 11.0+ (GPU) – CPU fallback is available
- Triton 2.0+ (optional, enables GPU kernels)
From PyPI
pip install dsalt
From source
git clone https://github.com/LeonardoCofone/dsalt-library.git
cd dsalt-library
pip install -e .
Development setup
pip install -r requirements-dev.txt
🚀 Quick Start
1. Language‑model inference
import torch
from dsalt.model import DSALTLMHeadModel
model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,
    n_max=512,
    k_lmk=64,
)
input_ids = torch.randint(0, 32000, (1, 1024)) # [batch, seq_len]
logits = model(input_ids) # [1, 1024, 32000]
print(logits.shape)
# With labels – loss is computed internally
labels = torch.randint(0, 32000, (1, 1024))
outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()
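Because the forward pass returns logits, a minimal greedy decoding loop can be sketched directly on top of it. This is only an illustration (no KV cache) and not a documented generation API of the library:
# Minimal greedy decoding sketch built only on the forward pass shown above
# (illustrative; not the library's own generation/sampling API).
@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens=32):
    model.eval()
    for _ in range(max_new_tokens):
        logits = model(input_ids)                              # [batch, seq_len, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)  # append the greedy token
    return input_ids
generated = greedy_generate(model, input_ids[:, :16], max_new_tokens=16)
print(generated.shape)  # torch.Size([1, 32])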
2. Single‑GPU training
import torch
from torch.utils.data import DataLoader, TensorDataset
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer
vocab_size = 32000
seq_len = 512
train_dataset = TensorDataset(
    torch.randint(0, vocab_size, (1000, seq_len))
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
model = DSALTLMHeadModel(
    vocab_size=vocab_size,
    d_model=768,
    n_layers=12,
    n_heads=12,
    n_min=32,
    n_max=256,
    k_lmk=32,
)
trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=10_000,
    save_dir="checkpoints",
    dtype=torch.bfloat16,
    log_every=50,
)
trainer.train()
3. Multi‑GPU with DataParallel
import torch
import torch.nn as nn
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer
model = DSALTLMHeadModel(...).to("cuda")
model = nn.DataParallel(model) # uses all available GPUs
trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=100_000,
    dtype=torch.bfloat16,
    save_dir="checkpoints",
)
trainer.train()
4. Multi‑GPU with FSDP (model sharding)
torchrun --nproc_per_node=2 train.py
Then configure the trainer with fsdp=True.
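A minimal train.py sketch for the torchrun command above is shown here. The fsdp=True flag comes from the note above; the explicit process-group init and device pinning are assumptions and may be redundant if DSALTTrainer handles them internally, and a DistributedSampler may be needed for real datasets.
# train.py - minimal FSDP sketch (launch with: torchrun --nproc_per_node=2 train.py)
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

def main():
    # Assumption: explicit setup here may be redundant if the trainer does it.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    vocab_size, seq_len = 32000, 512
    train_dataset = TensorDataset(torch.randint(0, vocab_size, (1000, seq_len)))
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    model = DSALTLMHeadModel(
        vocab_size=vocab_size, d_model=768, n_layers=12,
        n_heads=12, n_min=32, n_max=256, k_lmk=32,
    )
    trainer = DSALTTrainer(
        model=model,
        train_loader=train_loader,
        lr=3e-4,
        total_steps=100_000,
        dtype=torch.bfloat16,
        save_dir="checkpoints",
        fsdp=True,  # enable FSDP sharding, as described above
    )
    trainer.train()

if __name__ == "__main__":
    main()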
🏗️ Architecture Overview
DSALT combines local causal windows (adaptive per token) with global landmark tokens (top‑k per head):
┌─ Local window (adaptive) ──┬─ Global landmarks ───┐
│  Recent N tokens           │  Top‑K informative   │
│  (window size grows)       │  tokens per head     │
└────────────────────────────┴──────────────────────┘
              ↓                          ↓
               Sparse attention output
Key components:
- DSALTAttention – multi‑head sparse attention with adaptive windows and landmark selection.
- WindowSizePredictor – learns per‑token window sizes.
- HybridEnergyScorer (kernel) – computes landmark scores.
- DSALTTransformer – stack of attention + feed‑forward layers.
- Triton kernels – fused forward and backward passes for speed and memory efficiency.
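To make the sparsity pattern concrete, here is a self-contained conceptual sketch in plain PyTorch. It is not the library's Triton kernel, and the key-norm "energy" used here is only a stand-in for HybridEnergyScorer; it simply shows how a per-query causal local window combines with globally visible landmark columns.
# Conceptual sketch of DSALT-style sparsity (NOT the library's kernel).
# The key L2 norm below is a stand-in energy score for HybridEnergyScorer.
import torch

def sketch_dsalt_mask(K, window_sizes, k_lmk):
    # K: [seq_len, d_head] keys for one head
    # window_sizes: [seq_len] per-query local window lengths
    seq_len = K.shape[0]
    # 1) Landmark selection: top-k key positions by a simple energy score.
    energy = K.norm(dim=-1)                       # [seq_len]
    landmark_idx = energy.topk(k_lmk).indices     # [k_lmk]
    # 2) Causal local window: query i sees keys in [i - w_i + 1, i].
    q_pos = torch.arange(seq_len).unsqueeze(1)    # [seq_len, 1]
    k_pos = torch.arange(seq_len).unsqueeze(0)    # [1, seq_len]
    causal = k_pos <= q_pos
    local = causal & (q_pos - k_pos < window_sizes.unsqueeze(1))
    # 3) Global landmarks: visible to every query (still causal).
    is_landmark = torch.zeros(seq_len, dtype=torch.bool)
    is_landmark[landmark_idx] = True
    return local | (causal & is_landmark.unsqueeze(0))

mask = sketch_dsalt_mask(
    K=torch.randn(16, 64),
    window_sizes=torch.clamp(torch.arange(16) // 2 + 2, max=8),  # grows with position
    k_lmk=3,
)
print(mask.int())  # 1 = attended position, 0 = skipped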
🎯 Training & Generation
See the code snippets above for full training loops. The DSALTTrainer handles:
- Mixed‑precision (BF16 default)
- Gradient checkpointing
- Learning‑rate warm‑up and cosine decay
- Optional window‑entropy regularisation via window_reg_coef (see the example below)
- Checkpointing and logging utilities
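For example, enabling the window-entropy regulariser only requires passing the extra keyword to the trainer from the single-GPU example above; the 0.01 coefficient is an illustrative guess, not a tuned recommendation.
# window_reg_coef turns on window-entropy regularisation; 0.01 is an
# illustrative coefficient, not a recommended setting.
trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=50_000,
    dtype=torch.bfloat16,
    save_dir="checkpoints",
    window_reg_coef=0.01,
)
trainer.train()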
📚 API Reference (excerpt)
from dsalt.model import DSALTLMHeadModel
model = DSALTLMHeadModel(vocab_size=32000, d_model=1024, n_layers=24,
                         n_heads=16, n_min=32, n_max=512, k_lmk=64)
logits, windows = model(input_ids, return_window=True)
Low‑level kernel call:
from dsalt.kernels import dsalt_attention
out = dsalt_attention(Q, K, V, window_sizes, landmark_idx)
📊 Performance & Benchmarks (May 2026)
| Attention type | Approx. memory (GB) | Relative runtime (lower is faster) |
|---|---|---|
| Dense (O(N²)) | ~3.5 | 1.0× |
| FlashAttention 2 | ~1.8 | 0.5× |
| DSALT | ~0.6 | 0.17× |
📖 Hyperparameter Guide
All hyperparameters are documented in FEATURE.md. Typical configurations are provided for the following tiers (an illustrative starting point for the consumer-GPU tier is sketched after this list):
- Mobile / Edge – tiny models, low memory.
- Consumer GPU – e.g., RTX 4090, 24 GB.
- Enterprise – H100 80 GB, optional FSDP.
- Research – multi‑node, large models.
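As a rough illustration of the consumer-GPU tier, the values from the single-GPU example above make a reasonable template; the numbers below are illustrative only, and FEATURE.md remains the authoritative reference.
from dsalt.model import DSALTLMHeadModel
# Illustrative consumer-GPU starting point (mirrors the single-GPU training
# example above); consult FEATURE.md for tuned per-tier configurations.
consumer_gpu_config = dict(
    vocab_size=32000,
    d_model=768,
    n_layers=12,
    n_heads=12,
    n_min=32,    # smallest adaptive local window
    n_max=256,   # largest adaptive local window
    k_lmk=32,    # landmark tokens per head
)
model = DSALTLMHeadModel(**consumer_gpu_config)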
🧪 Testing
make test-cov # Full test suite with coverage report
pytest tests/ -v # Run tests directly
Key test modules:
- tests/test_sparse_attn.py – kernel equivalence and backward.
- tests/test_hybrid_energy.py – landmark scoring.
- tests/test_dsalt_lm.py – language‑model wrapper.
- tests/test_main.py – end‑to‑end smoke test.
📄 License
See the LICENSE file in the repository: https://github.com/LeonardoCofone/dsalt-library/blob/main/LICENSE
🤝 Contributing
Contributions are welcome! Please read CONTRIBUTING.md for guidelines. Areas where help is especially valuable:
- Triton kernel optimisation
- New model architectures (encoder, encoder‑decoder)
- Additional training strategies and samplers
- Documentation and tutorials
- Bug reports and fixes
📞 Support & Questions
- Issues: https://github.com/LeonardoCofone/dsalt-library/issues
- Discussions: https://github.com/LeonardoCofone/dsalt-library/discussions
- Paper: https://zenodo.org/records/19312826
Last Updated: May 2026
File details
Details for the file dsalt-0.2.32.tar.gz.
File metadata
- Download URL: dsalt-0.2.32.tar.gz
- Upload date:
- Size: 37.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ac7284230b9e8a1b769018ec04b04af74b69ca368a1e1b045741a2a5549a6b16 |
| MD5 | 393a343c3c0d0e5f06f816cdc568d66e |
| BLAKE2b-256 | 6c6986c142ee68f86a3d260832cab9d23253f7a83142a6a5cd666860332cf965 |
File details
Details for the file dsalt-0.2.32-py3-none-any.whl.
File metadata
- Download URL: dsalt-0.2.32-py3-none-any.whl
- Upload date:
- Size: 30.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a68582ba35ba69db69c01fb540973e1e166fc74b82a60384a1cadc17681a547d |
| MD5 | ab3c4f96c023d48bc3e5ab698be12521 |
| BLAKE2b-256 | a89db6a31529a47905a5436c152571d77b281257dca19b6f950fc6f1345656b9 |