DSALT: Dynamic Sparse Attention with Landmark Tokens
DSALT is a high‑performance PyTorch library that implements Dynamic Sparse Attention with Landmark Tokens – a memory‑efficient attention mechanism for transformers. It relies on Triton kernels and supports distributed training.
Install:
pip install dsalt
Source: https://github.com/LeonardoCofone/dsalt-library
Paper: https://zenodo.org/records/19312826
🚀 Key Features
- Memory‑efficient sparse attention – Triton‑accelerated kernels provide 4–8× memory savings compared to dense attention.
- Adaptive local windows – Token‑wise window sizes that grow with sequence position.
- Global landmark tokens – Top‑k informative tokens per head selected via a hybrid energy scoring function.
- Production‑ready training – Mixed‑precision, gradient checkpointing, and validation support.
- Distributed training – Full DDP and FSDP support for multi‑GPU setups.
- Numerical verification – CPU/GPU equivalence tests and gradient stability checks.
📋 Table of Contents
- Installation
- Quick Start
- Architecture Overview
- Training & Generation
- API Reference
- Hyperparameter Guide
- Testing
- Performance & Benchmarks
- Citation
- Contributing
- License
🛠️ Installation
Requirements
- Python 3.8+
- PyTorch 2.0+
- CUDA 11.0+ (GPU) – CPU fallback is available
- Triton 2.0+ (optional, enables GPU kernels)
From PyPI
pip install dsalt
From source
git clone https://github.com/LeonardoCofone/dsalt-library.git
cd dsalt-library
pip install -e .
Development setup
pip install -r requirements-dev.txt
🚀 Quick Start
1. Language‑model inference
import torch
from dsalt.model import DSALTLMHeadModel
model = DSALTLMHeadModel(
vocab_size=32000,
d_model=1024,
n_layers=24,
n_heads=16,
n_min=32,
n_max=512,
k_lmk=64,
)
input_ids = torch.randint(0, 32000, (1, 1024)) # [batch, seq_len]
logits = model(input_ids) # [1, 1024, 32000]
print(logits.shape)
# With labels – loss is computed internally
labels = torch.randint(0, 32000, (1, 1024))
outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()
2. Single‑GPU training
import torch
from torch.utils.data import DataLoader, TensorDataset
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer
vocab_size = 32000
seq_len = 512
train_dataset = TensorDataset(
torch.randint(0, vocab_size, (1000, seq_len))
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
model = DSALTLMHeadModel(
vocab_size=vocab_size,
d_model=768,
n_layers=12,
n_heads=12,
n_min=32,
n_max=256,
k_lmk=32,
)
trainer = DSALTTrainer(
model=model,
train_loader=train_loader,
lr=3e-4,
total_steps=10_000,
save_dir="checkpoints",
dtype=torch.bfloat16,
log_every=50,
)
trainer.train()
3. Multi‑GPU with DataParallel
import torch
import torch.nn as nn
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer
model = DSALTLMHeadModel(...).to("cuda")
model = nn.DataParallel(model) # uses all available GPUs
trainer = DSALTTrainer(
model=model,
train_loader=train_loader,
lr=3e-4,
total_steps=100_000,
dtype=torch.bfloat16,
save_dir="checkpoints",
)
trainer.train()
4. Multi‑GPU with FSDP (model sharding)
torchrun --nproc_per_node=2 train.py
Then configure the trainer with fsdp=True.
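A minimal train.py sketch is shown below. It uses only constructor arguments documented in the Hyperparameter Guide; the toy dataloader and the assumption that the trainer sets up the process group itself when fsdp=True are illustrative, so adapt them to your own data pipeline and launch setup.
# train.py – launch with: torchrun --nproc_per_node=2 train.py
import torch
from torch.utils.data import DataLoader, TensorDataset
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

def build_dataloader():
    # Toy data for illustration; in a real multi-process run you would
    # typically plug in your own dataset and a DistributedSampler.
    data = TensorDataset(torch.randint(0, 32000, (1000, 512)))
    return DataLoader(data, batch_size=8, shuffle=True)

model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=768,
    n_layers=12,
    n_heads=12,
    n_min=32,
    n_max=256,
    k_lmk=32,
)

trainer = DSALTTrainer(
    model=model,
    train_loader=build_dataloader(),
    lr=3e-4,
    total_steps=100_000,
    dtype=torch.bfloat16,
    save_dir="checkpoints",
    fsdp=True,                # shard parameters, gradients and optimiser state
    fsdp_cpu_offload=False,   # set True for very large models
)
trainer.train()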
🏗️ Architecture Overview
DSALT combines local causal windows (adaptive per token) with global landmark tokens (top‑k per head):
┌─ Local window (adaptive) ──┬─ Global landmarks ─┐
│ Recent N tokens │ Top‑K informative │
│ (window size grows) │ tokens per head │
└────────────────────────────┴────────────────────┘
↓ ↓
Sparse attention output
Key components:
- DSALTAttention – multi‑head sparse attention with adaptive windows and landmark selection.
- WindowSizePredictor – learns per‑token window sizes.
- HybridEnergyScorer (kernel) – computes landmark scores.
- DSALTTransformer – stack of attention + feed‑forward layers.
- Triton kernels – fused forward and backward passes for speed and memory efficiency.
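To make the sparsity pattern concrete, the snippet below builds the boolean attention mask that such a scheme implies for a single head. It is a conceptual sketch in plain PyTorch, not the library's Triton kernel, and the toy window rule and landmark positions are hypothetical.
import torch

seq_len, n_min = 16, 4

# Adaptive local window: each query attends to its most recent window[i] tokens (causal).
window = torch.clamp(torch.arange(1, seq_len + 1), max=n_min)  # toy rule: grows up to n_min

q_idx = torch.arange(seq_len).unsqueeze(1)   # [seq, 1]
k_idx = torch.arange(seq_len).unsqueeze(0)   # [1, seq]
local_mask = (k_idx <= q_idx) & (k_idx > q_idx - window.unsqueeze(1))

# Global landmarks: k_lmk key positions visible to every query (still causal).
landmark_idx = torch.tensor([0, seq_len // 2])             # hypothetical top-k positions
landmark_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
landmark_mask[:, landmark_idx] = True
landmark_mask &= (k_idx <= q_idx)                           # keep causality

sparse_mask = local_mask | landmark_mask                    # union of the two patterns
print(sparse_mask.int())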
🎯 Training & Generation
See the code snippets above for full training loops. The DSALTTrainer handles:
- Mixed‑precision (BF16 default)
- Gradient checkpointing to save activation memory at the cost of recomputing some layers.
- Learning‑rate warm‑up and cosine decay for stable and effective optimisation.
- Optional window‑entropy regularisation (window_reg_coef) to encourage diverse window‑size predictions.
- Comprehensive checkpointing and logging utilities.
- Support for various distributed training strategies (DDP, FSDP).
For detailed configuration, refer to the Hyperparameter Guide below, especially the DSALTTrainer section.
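As a sketch of wiring these options together, reusing the model and train_loader from the Quick Start and only the parameter names documented in the Hyperparameter Guide (the specific values are illustrative, not recommendations):
trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    warmup_steps=500,
    total_steps=50_000,
    grad_accum=4,                    # effective batch = batch_size * 4
    dtype=torch.bfloat16,
    gradient_checkpointing=True,     # trade recompute for activation memory
    window_reg_coef=0.01,            # entropy penalty on predicted window sizes
    save_dir="checkpoints",
)
trainer.train()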
📖 Hyperparameter Guide
This section provides a comprehensive reference for all key hyperparameters across DSALT components, their defaults, and recommended usage.
DSALTLMHeadModel (Language Model)
The main language‑model wrapper that combines embeddings, transformer blocks, and an output head.
Required Parameters
vocab_size: int # Vocabulary size (e.g., 32000 for a LLaMA‑style SentencePiece tokenizer)
d_model: int # Hidden dimension, must be divisible by `n_heads`
n_layers: int # Number of transformer blocks
n_heads: int # Number of attention heads (d_model // n_heads must be a power of two and ≥ 16)
Architecture Hyperparameters
d_ff: int | None = None # Feed‑forward hidden dim (default = 4 × d_model)
max_seq_len: int = 2048 # Maximum sequence length for positional embeddings
dropout: float = 0.0 # Dropout rate applied after attention and FFN
use_fa2: bool = True # Enable FlashAttention 2 when Triton is available
tie_weights: bool = True # Share embedding and output‑projection weights
Sparse‑Attention Hyperparameters
n_min: int = 32 # Minimum local window size (causal sliding window)
n_max: int = 256 # Maximum local window size (grows with token position)
k_lmk: int = 16 # Number of global landmark tokens per head
Note: alpha is a learnable per‑head weight automatically initialised; it is not exposed as a configuration flag.
DSALTAttention (Attention Module)
A multi‑head sparse‑attention layer with adaptive windows and landmark selection.
Required Parameters
d_model: int
n_heads: int
Sparse‑Attention Hyperparameters (inherited from the model)
n_min: int
n_max: int
k_lmk: int
alpha: float = 0.6 # Initial value for the learnable weight per head
Regularisation & Optimisation
dropout: float = 0.0
use_fa2: bool = True # FlashAttention 2 fallback when the whole sequence fits the local window
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory savings
compile_attention: bool = False # Enable `torch.compile` for the attention block (requires PyTorch 2.0+)
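A minimal sketch of instantiating the layer on its own, assuming a constructor that accepts exactly the parameters listed above and an output with the same shape as the input; the import path is an assumption, so adjust it to the actual module.
import torch
from dsalt.attention import DSALTAttention   # import path is an assumption

attn = DSALTAttention(
    d_model=768,
    n_heads=12,
    n_min=32,
    n_max=256,
    k_lmk=16,
    alpha=0.6,
    dropout=0.0,
    use_fa2=True,
)
x = torch.randn(2, 1024, 768)    # [batch, seq_len, d_model]
y = attn(x)                      # expected: [2, 1024, 768]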
WindowSizePredictor (Dynamic Window Module)
Learns a per‑token window size that adapts between n_min and n_max.
Embedded Parameters (no constructor arguments)
d_model, n_heads, n_min, n_max are automatically inferred from the parent DSALTAttention.
Output
output: [batch, n_heads, seq_len] # Predicted window size per token per head
The module also returns a continuous regularisation term used by the trainer when window_reg_coef > 0.
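Conceptually, a predictor like this maps a raw per‑token, per‑head score into the [n_min, n_max] range. The sketch below illustrates that mapping only; it is not the library's implementation, and the function name is hypothetical.
import torch

def predict_window_sizes(scores, n_min=32, n_max=256):
    # scores: [batch, n_heads, seq_len] – unnormalised per-token, per-head scores
    frac = torch.sigmoid(scores)                        # squash to (0, 1)
    sizes = n_min + frac * (n_max - n_min)              # interpolate between the bounds
    return sizes.round().clamp(n_min, n_max).to(torch.int32)

scores = torch.randn(1, 12, 1024)
window_sizes = predict_window_sizes(scores)             # [1, 12, 1024]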
DSALTTransformer (Core Stack)
Stack of DSALTAttention + feed‑forward blocks.
All architectural hyperparameters are inherited from DSALTLMHeadModel.
DSALTTrainer (Training Configuration)
High‑level training loop with mixed‑precision, distributed training, and checkpointing.
Optimisation Hyperparameters
lr: float = 3e-4 # Initial learning rate
weight_decay: float = 0.1 # L2 regularisation strength
max_grad_norm: float = 1.0 # Maximum gradient norm for clipping
grad_accum: int = 1 # Number of gradient accumulation steps
Learning‑Rate Schedule
warmup_steps: int = 500 # Number of steps for linear learning rate warm-up
total_steps: int = 10_000 # Total number of training steps for cosine decay
Logging & Checkpointing
log_every: int = 50 # Log training metrics every N steps
val_every: int = 500 # Run validation every N steps
save_every: int = 1000 # Save model checkpoint every N steps
save_dir: str = "checkpoints" # Directory to save checkpoints and logs
Precision & Device
dtype: torch.dtype = torch.bfloat16 # Data type for training (BF16 default for speed/stability)
device: torch.device = "cuda:0" # Device to run training on (e.g., "cuda:0", "cpu")
Distributed Training (choose one)
ddp: bool = False # Enable standard DistributedDataParallel
fsdp: bool = False # Enable Fully‑sharded Data Parallel (model sharding)
fsdp_cpu_offload: bool = False # Optional CPU off‑load for very large models with FSDP
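For plain data parallelism, the same trainer is launched with torchrun and ddp=True; the sketch below reuses the model and train_loader from the Quick Start and assumes the trainer wraps the model in DistributedDataParallel itself.
# launch with: torchrun --nproc_per_node=4 train.py
trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=100_000,
    dtype=torch.bfloat16,
    ddp=True,    # one process per GPU, gradients all-reduced each step
)
trainer.train()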
Memory Optimisation
gradient_checkpointing: bool = False # Save ~30 % activation memory at the cost of extra compute
Regularisation
window_reg_coef: float = 0.0 # Entropy penalty on the predicted window distribution (0.0 disables)
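As an illustration of what an entropy‑style penalty on window predictions can look like, here is a generic sketch (not necessarily the exact term DSALT computes; the discretisation into n_bins candidate sizes is an assumption):
import torch
import torch.nn.functional as F

def window_entropy_penalty(window_logits, coef=0.01):
    # window_logits: [batch, n_heads, seq_len, n_bins] scores over candidate window sizes
    probs = F.softmax(window_logits, dim=-1)
    entropy = -(probs * probs.clamp_min(1e-9).log()).sum(dim=-1)   # [batch, n_heads, seq_len]
    return -coef * entropy.mean()   # negative entropy: added to the loss, rewards diverse choices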
📚 API Reference (excerpt)
# Model creation
from dsalt.model import DSALTLMHeadModel
model = DSALTLMHeadModel(
vocab_size=32000,
d_model=1024,
n_layers=24,
n_heads=16,
n_min=32,
n_max=512,
k_lmk=64,
)
# Forward pass
input_ids = torch.randint(0, 32000, (1, 1024))
logits, windows = model(input_ids, return_window=True)
# Low‑level kernel call (for advanced users)
from dsalt.kernels import dsalt_attention
Q = torch.randn(1, 16, 1024, 64)  # [batch, n_heads, seq_len, head_dim]
K = torch.randn(1, 16, 1024, 64)
V = torch.randn(1, 16, 1024, 64)
window_sizes = torch.full((1, 16, 1024), 64, dtype=torch.int32)  # per-token window size [batch, n_heads, seq_len]
landmark_idx = torch.randint(0, 1024, (1, 16, 1024, 16), dtype=torch.int32)  # landmark indices [batch, n_heads, seq_len, k_lmk]
out = dsalt_attention(Q, K, V, window_sizes, landmark_idx)
🧪 Testing
Key test modules include:
- tests/test_sparse_attn.py – CPU/GPU equivalence and backward pass.
- tests/test_hybrid_energy.py – landmark scoring and selection.
- tests/test_dsalt_lm.py – language‑model wrapper and loss.
- tests/test_main.py – end‑to‑end smoke test.
📄 License
See here: https://github.com/LeonardoCofone/dsalt-library/blob/main/LICENSE
🤝 Contributing
Contributions are welcome! Please read CONTRIBUTING.md for guidelines. Areas where help is especially valuable:
- Triton kernel optimisation
- New model architectures (encoder, encoder‑decoder)
- Additional training strategies and samplers
- Documentation and tutorials
- Bug reports and fixes
📞 Support & Questions
- Issues: https://github.com/LeonardoCofone/dsalt-library/issues
- Discussions: https://github.com/LeonardoCofone/dsalt-library/discussions
- Paper: https://zenodo.org/records/19312826