Skip to main content

Dynamic Sparse Attention with Landmark Tokens — High-performance Triton implementation

Project description

DSALT: Dynamic Sparse Attention with Landmark Tokens

PyPI License

A high-performance PyTorch library implementing DSALT (Dynamic Sparse Attention with Landmark Tokens)—a memory-efficient sparse attention mechanism for transformers, built with Triton kernels and optimized for distributed training.

Install from PyPI: pip install dsalt
GitHub: dsalt-pytorch
Paper: Zenodo Preprint
Feature Roadmap: See FEATURE.md

🚀 Key Features

  • Memory-Efficient Sparse Attention: Triton-accelerated kernels (4–8× memory savings vs. dense attention)
  • Adaptive Local Windows: Token-by-token dynamic window sizing that grows with sequence position
  • Global Landmark Tokens: Top-k informative tokens selected per head via hybrid energy scoring
  • Production-Ready Training: Complete trainer with mixed precision, gradient checkpointing, and validation
  • Distributed Training: Full support for DDP and FSDP (model sharding across 2+ GPUs)
  • Numerically Verified: CPU/GPU equivalence tests ensure correctness; gradient stability validated

Recent Optimizations (2024)

Eliminated Silent Memory Replication

  • Landmark tensor shape: [B, H, K] (was [B, H, N, K]) — saves O(N) allocation
  • Hidden state input: [B, N, D] (was [B, H, N, D]) — eliminates copy per head
  • Combined effect: 4–8× memory reduction in landmark computation

Fixed Correctness Issues

  • Gradient checkpointing now properly checkpoints full attention block (not lambda-wrapped)
  • Backward kernel signatures cleaned: removed dead code and unused parameters
  • Distributed training fixed: _is_main no longer silently defined twice

Enhanced Distributed Training

  • FSDP support for 2+ GPU model sharding: torchrun --nproc_per_node=2 train.py --fsdp
  • Gradient accumulation optimized: no_sync() eliminates intermediate all-reduce cost
  • DataParallel removed: unsuitable overhead for sparse patterns

📋 Table of Contents

  1. Installation
  2. Quick Start
  3. Architecture
  4. Training Examples
  5. API Reference
  6. Testing
  7. Citation
  8. Contributing
  9. License

🛠️ Installation

Requirements

  • Python: 3.8+
  • PyTorch: 2.0+
  • CUDA: 11.0+ (for GPU acceleration; CPU fallback available)
  • Triton: 2.0+ (optional; enables GPU kernels; CPU fallback via PyTorch)

From PyPI

pip install dsalt

From PyPI with GPU Acceleration

# Includes Triton for GPU kernels
pip install dsalt

From Source

git clone https://github.com/LeonardoCofone/dsalt-pytorch.git
cd dsalt-pytorch
pip install -e .

Development Setup

pip install -r requirements-dev.txt

🚀 Quick Start

1. Minimal Example: Language Model Inference

import torch
from dsalt.model import DSALTLMHeadModel

# Create a DSALT language model
model = DSALTLMHeadModel(
    vocab_size=32000,      # Size of vocabulary
    d_model=1024,          # Hidden dimension
    n_layers=24,           # Depth: 24 transformer blocks
    n_heads=16,            # Multi-head attention heads
    n_min=32,              # Minimum local window
    n_max=512,             # Maximum local window
    k_lmk=64,              # Number of landmark tokens per head
)

# Forward pass (inference)
input_ids = torch.randint(0, 32000, (1, 1024))  # [batch=1, seq_len=1024]
logits = model(input_ids)                        # [1, 1024, 32000]
print(f"Output shape: {logits.shape}")

# With labels: direct loss computation
input_ids = torch.randint(0, 32000, (4, 512))  # [batch=4, seq_len=512]
labels = torch.randint(0, 32000, (4, 512))
outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()

2. Training: Single GPU

import torch
from torch.utils.data import DataLoader, TensorDataset
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

# Prepare dataset
vocab_size = 32000
seq_len = 512
train_dataset = TensorDataset(
    torch.randint(0, vocab_size, (1000, seq_len)),  # 1000 sequences
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Create model
model = DSALTLMHeadModel(
    vocab_size=vocab_size,
    d_model=768,
    n_layers=12,
    n_heads=12,
    n_min=32,
    n_max=256,
    k_lmk=32,
)

# Train
trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=10000,
    save_dir="checkpoints",
    dtype=torch.bfloat16,  # Mixed precision: BF16
    log_every=50,
)
trainer.train()

3. Training: Multi-GPU with FSDP (Fully Sharded Data Parallel)

# Command: torchrun with FSDP enabled
# torchrun --nproc_per_node=2 train.py

import torch
from dsalt.training import DSALTTrainer

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    lr=3e-4,
    total_steps=100000,
    fsdp=True,              # ← Enable FSDP for 2+ GPU sharding
    dtype=torch.bfloat16,
    save_dir="checkpoints",
)
trainer.train()

🏗️ Architecture

Overview

DSALT combines local causal windows (adaptive, growing with position) with global landmark tokens (top-k per head):

┌─ Local Attention ──────┬─ Global Landmarks ────┐
│ Recent N tokens        │ Top-K informative     │
│ (adaptive window)      │ (hybrid energy score) │
└───────────────────────┴──────────────────────┘
         ↓                         ↓
         └─────────────┬───────────┘
                       ↓
            Sparse Attention Output

Key Components

  1. DSALTAttention: Multi-head sparse attention module

    • Adaptive window size prediction per token
    • Landmark token selection (no gradient)
    • Sparse kernel computation (Triton or CPU fallback)
  2. WindowSizePredictor: Learned dynamic window module

    • Predicts continuous window sizes
    • Enables attention scope to adapt to token importance
    • Regularization: entropy loss on window decisions
  3. HybridEnergyScorer (in kernels): Landmark selection

    • Computes energy scores per token (norm-based)
    • Z-score normalization
    • Top-k selection per head
    • Excludes tokens in local window (redundancy-aware)
  4. DSALTTransformer: Decoder-only stack

    • Pre-norm RMSNorm for stability
    • SwiGLU feed-forward networks
    • Residual connections and dropout
  5. Sparse Attention Kernel (Triton)

    • Fused forward pass: avoids materializing full attention matrix
    • Backward pass: gradient stability for Q, K, V
    • CPU fallback: functional equivalence on devices without GPU

Memory Profile

Dense Attention (standard transformer):

  • Attention matrix: O(N²) memory, O(N²) compute

DSALT Sparse Attention:

  • Local window: O(w·N) memory (w = adaptive window)
  • Landmarks: O(K·N) memory (K = constant landmark count, independent of N)
  • Total: O((w+K)·N) ≪ O(N²) for long sequences

🎯 Training

Configuration

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    
    # Optimization
    lr=3e-4,
    weight_decay=0.1,
    max_grad_norm=1.0,
    grad_accum=2,                  # Gradient accumulation: effective batch 2x
    
    # Schedule
    warmup_steps=500,
    total_steps=100_000,
    
    # Checkpointing & Logging
    save_every=1000,
    log_every=50,
    val_every=500,
    save_dir="checkpoints",
    
    # Precision & Device
    dtype=torch.bfloat16,          # Mixed precision training
    device=torch.device("cuda"),
    
    # Parallelism (choose ONE)
    ddp=False,                     # Standard DDP
    fsdp=True,                     # FSDP: model sharding
    fsdp_cpu_offload=False,        # CPU offload (slow, for very large models)
    
    # Memory Optimization
    gradient_checkpointing=True,   # Gradient checkpointing: saves activation memory
    
    # Regularization
    window_reg_coef=0.01,          # Window entropy regularization
)

trainer.train()

Training with Gradient Accumulation + Distributed

# 2 GPUs, batch=4 per GPU, accumulate 2 steps = effective batch 16
torchrun --nproc_per_node=2 train.py \
    --batch_size 4 \
    --grad_accum 2 \
    --fsdp true \
    --dtype bfloat16

📚 API Reference

Core Classes

DSALTLMHeadModel

Language model wrapper for autoregressive training/inference.

model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,           # Min window size
    n_max=512,          # Max window size
    k_lmk=64,           # Landmarks per head
    norm_eps=1e-6,
    dropout=0.1,
    bias=False,
)

# Forward: returns logits or (loss, logits) if labels provided
outputs = model(input_ids, labels=None)
logits = outputs.logits

DSALTTransformer

Core transformer architecture (without LM head).

transformer = DSALTTransformer(
    d_model=1024,
    n_heads=16,
    n_layers=24,
    n_min=32,
    n_max=512,
    k_lmk=64,
)

# Forward: returns [batch, seq_len, d_model]
x = transformer(input_embeddings)

DSALTAttention

Single attention layer.

attn = DSALTAttention(
    d_model=1024,
    n_heads=16,
    n_min=32,
    n_max=512,
    k_lmk=64,
    dropout=0.1,
    gradient_checkpointing=False,
)

# Returns (output, window_sizes) if return_window=True
out, _ = attn(x, x_prev=None, return_window=True)

DSALTTrainer

Training loop with mixed precision, DDP/FSDP, checkpointing.

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    lr=3e-4,
    total_steps=100_000,
    dtype=torch.bfloat16,
    ddp=False,
    fsdp=True,
    gradient_checkpointing=True,
)

trainer.train()  # Blocking: runs until total_steps

Kernel Functions

dsalt_attention(Q, K, V, window_sizes, landmark_idx)

Low-level sparse attention computation.

from dsalt.kernels import dsalt_attention

# Q, K, V: [batch, n_heads, seq_len, d_head]
# window_sizes: [batch, n_heads, seq_len] int32
# landmark_idx: [batch, n_heads, k_landmarks] int32
# Returns: [batch, n_heads, seq_len, d_head]

out = dsalt_attention(Q, K, V, window_sizes, landmark_idx)

compute_hybrid_energy_scores(X, WV, window_sizes, k, alpha)

Compute landmark scores and select top-k.

from dsalt.kernels import compute_hybrid_energy_scores

# X: [batch, seq_len, d_model]
# WV: [n_heads, d_model, d_head]
# Returns: [batch, n_heads, k]

landmark_idx = compute_hybrid_energy_scores(
    X=hidden_states,
    WV=value_projections,
    window_sizes=window_sizes,
    k=64,
    alpha=torch.tensor([0.6] * n_heads),
)

🧪 Testing

Run Full Test Suite

# All tests with coverage
make test-cov

# Or directly with pytest
pytest tests/ -v --cov=dsalt --cov-report=html

# View coverage report
open htmlcov/index.html

Test Modules

  • tests/test_sparse_attn.py: Attention kernel CPU/GPU equivalence, backward pass
  • tests/test_hybrid_energy.py: Landmark scoring and selection
  • tests/test_dsalt_lm.py: Language model wrapper, loss computation
  • tests/test_main.py: End-to-end training smoke test

CI/CD

Tests automatically run on:

  • Push to any branch
  • Pull requests
  • Scheduled nightly builds

📊 Performance & Benchmarks

Memory Usage (Approximate)

Model: d_model=1024, n_heads=16, n_layers=12, seq_len=1024, batch=4

Attention Type Memory (GB) Relative
Dense (Q×K^T) ~3.5 1.0×
FlashAttention 2 ~1.8 0.51×
DSALT ~0.6 0.17×

Compute Efficiency

  • Forward: ~95% of time spent in Triton kernels (minimal Python overhead)
  • Backward: Full gradient support with automatic differentiation
  • Mixed Precision (BF16): 1.5–2× speedup vs. FP32 on modern GPUs

📖 Citation

If DSALT is useful in your research, please cite:

@article{dsalt2024,
  title={Noise Accumulation and Rank Collapse in Dense Self-Attention: DSALT},
  author={Leonardo Cofone},
  journal={Zenodo Preprint},
  year={2026},
  url={https://zenodo.org/records/19312827},
  note={Dynamic Sparse Attention with Landmark Tokens}
}

🤝 Contributing

Contributions are welcome! Please see CONTRIBUTING.md for guidelines.

Areas for contribution:

  • Performance tuning (Triton kernel optimization)
  • Additional model architectures (encoder, encoder-decoder)
  • New training strategies and samplers
  • Documentation and tutorials
  • Bug reports and fixes

📄 License

Licensed under the Apache License 2.0. See LICENSE for details.


🙏 Acknowledgments

  • Triton: GPU kernel framework by OpenAI
  • FlashAttention: Inspiration for fused kernel design (Dao et al.)
  • PyTorch: Deep learning framework and distributed training infrastructure

📞 Support & Questions


Last Updated: May 2026
Version: 1.2.0
Status: ✅ Production-Ready

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

dsalt-0.1.12.tar.gz (41.7 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

dsalt-0.1.12-py3-none-any.whl (32.6 kB view details)

Uploaded Python 3

File details

Details for the file dsalt-0.1.12.tar.gz.

File metadata

  • Download URL: dsalt-0.1.12.tar.gz
  • Upload date:
  • Size: 41.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.1.12.tar.gz
Algorithm Hash digest
SHA256 def68f4d13dfaf4f8c81f5309f6afe85c44bc9d9682fe6410b965122948ead55
MD5 f562f566a071d13481dac0bf7ebff39e
BLAKE2b-256 42082b74a788f341456f29d9f5749d35f4dc0f3bc1c7c5cc5023f26e230c83dc

See more details on using hashes here.

File details

Details for the file dsalt-0.1.12-py3-none-any.whl.

File metadata

  • Download URL: dsalt-0.1.12-py3-none-any.whl
  • Upload date:
  • Size: 32.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.1.12-py3-none-any.whl
Algorithm Hash digest
SHA256 422638b45a9214d2e78c71ab415b0f8e15393e6b30e9625825f291c3cbe60fbe
MD5 b16ffbd5b6b583634845e4dd66c8a717
BLAKE2b-256 2c82c8aa687df9cb4169ab29052a10275b54d76f22e24b61655aeeb69b72ee4c

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page