Dynamic Sparse Attention with Landmark Tokens — High-performance Triton implementation

DSALT, Dynamic Sparse Attention with Landmark Tokens

The source repository is at https://github.com/LeonardoCofone/dsalt-pytorch and contains all the .md files mentioned in this file. See the full feature catalog in FEATURE.md.

DSALT (Dynamic Sparse Attention with Landmark Tokens) is a high-performance PyTorch library for sparse-attention transformers, with Triton kernels for efficient GPU training.

Published on PyPI: pip install dsalt

🚀 Key Features

  • Efficient Sparse Attention: Triton-accelerated kernels for GPU-optimized sparse causal self-attention
  • Dynamic Window Sizing: Adaptive local attention windows that grow with sequence position
  • Landmark Token Selection: Global landmark tokens selected via hybrid energy scoring
  • Mixed Precision Training: Full support for BF16/FP16 training with gradient scaling
  • Distributed Training: DDP (DistributedDataParallel) support for multi-GPU training
  • Production Ready: Complete training harness with checkpointing, logging, and validation

🛠️ Installation

Requirements

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA 11.0+ (for GPU acceleration)
  • Triton 2.0+ (optional, for GPU kernels)

Install from PyPI

pip install dsalt

Install with Triton support

pip install "dsalt[triton]"

Install with Flash Attention fallback

pip install "dsalt[flash-attn]"

Install from source

git clone https://github.com/LeonardoCofone/dsalt-pytorch.git
cd dsalt-pytorch
pip install -e .

Developer setup

pip install -r requirements-dev.txt

🚀 Quick Start

import torch
from dsalt.model import DSALTLMHeadModel

# Create a DSALT language model
model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,      # Minimum window size
    n_max=512,     # Maximum window size
    k_lmk=64,      # Number of landmark tokens
)

# Forward pass
input_ids = torch.randint(0, 32000, (1, 1024))
logits = model(input_ids)
print(f"Output shape: {logits.shape}")  # [1, 1024, 32000]
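
To get a feel for what the window and landmark settings above imply, the following framework-free sketch counts query-key pairs for one causal sequence. It assumes each query attends to at most n_max local keys plus k_lmk landmarks; that budget is inferred from the parameter names and is an upper bound, not the library's actual accounting. Both helper names are hypothetical.

```python
def dense_causal_pairs(seq_len: int) -> int:
    """Query-key pairs in dense causal attention: query i sees keys 0..i."""
    return seq_len * (seq_len + 1) // 2


def sparse_pairs_upper_bound(seq_len: int, n_max: int, k_lmk: int) -> int:
    """Upper bound on pairs when each query sees at most n_max + k_lmk keys.

    Assumption: the per-query budget is n_max local keys plus k_lmk
    landmarks, and a query never sees more keys than its causal prefix.
    """
    budget = n_max + k_lmk
    return sum(min(i + 1, budget) for i in range(seq_len))


dense = dense_causal_pairs(1024)
sparse = sparse_pairs_upper_bound(1024, n_max=512, k_lmk=64)
print(dense, sparse)  # 524800 424224
```

At a sequence length of 1024 the bound stays close to the dense count because n_max is already half the sequence; the gap widens once sequences grow well beyond n_max + k_lmk.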

🏗️ Architecture

DSALT combines local causal windows with global landmark tokens:

  • Local Attention: Each token attends to a dynamic window of recent tokens
  • Landmark Selection: Top-k informative tokens selected globally via energy scoring
  • Sparse Computation: Only compute attention for relevant token pairs
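
The combination of local windows and landmarks described above can be sketched without any framework. The function below is an illustration only: its name, its inputs, and the rule that a query may only see landmarks at or before its own position are assumptions based on the "sparse causal" description, not the library's kernel.

```python
from typing import List, Sequence


def attended_keys(
    window_sizes: Sequence[int], landmarks: Sequence[int]
) -> List[List[int]]:
    """Return, for each query position, the sorted key positions it attends to.

    window_sizes[i] is the local window length for query i (hypothetical
    input; in the library this would come from the window-size predictor).
    landmarks holds globally selected positions (also hypothetical).
    """
    pattern = []
    for i, w in enumerate(window_sizes):
        start = max(0, i - w + 1)                    # causal window ends at i
        keys = set(range(start, i + 1))              # local keys
        keys.update(p for p in landmarks if p <= i)  # causal landmarks only
        pattern.append(sorted(keys))
    return pattern


pattern = attended_keys(window_sizes=[2] * 8, landmarks=[0, 3])
print(pattern[6])  # [0, 3, 5, 6]
```

Query 6 sees its two-token local window (positions 5 and 6) plus both landmarks; every other position is skipped, which is where the savings over dense attention come from.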

Key Components

  • DSALTTransformer: Main transformer architecture
  • DSALTAttention: Multi-head sparse attention layer
  • WindowSizePredictor: Learned adaptive window sizing
  • HybridEnergyScorer: Landmark token selection
  • SparseAttentionKernel: Triton-accelerated attention computation

🎯 Training

Single GPU Training

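import torch
from dsalt.training import DSALTTrainer
from torch.utils.data import DataLoader

# `model` is the DSALTLMHeadModel from the Quick Start.
# `train_dataloader` and `val_dataloader` are DataLoader instances
# that you build for your own dataset.
trainer = DSALTTrainer(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    lr=3e-4,
    total_steps=100000,
    save_dir="checkpoints",
    dtype=torch.bfloat16,
)

trainer.train()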

Multi-GPU Distributed Training

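import torch.distributed as dist
from dsalt.training import DSALTTrainer

# Initialize the process group. With the default env:// init method this
# expects a launcher such as torchrun to set RANK and WORLD_SIZE.
dist.init_process_group(backend='nccl')

trainer = DSALTTrainer(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    ddp=True,  # Enable DDP
    # ... other args
)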
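
Launchers such as torchrun export RANK, LOCAL_RANK, and WORLD_SIZE to every worker process, and `dist.init_process_group` reads rank and world size from the environment by default. The helper below is a hypothetical sketch of reading those variables with single-process defaults; how DSALTTrainer consumes them is not documented here.

```python
import os
from typing import Dict, Mapping


def read_dist_env(env: Mapping[str, str] = os.environ) -> Dict[str, int]:
    """Read the rank variables that torchrun exports, with single-process defaults."""
    return {
        "rank": int(env.get("RANK", "0")),
        "local_rank": int(env.get("LOCAL_RANK", "0")),
        "world_size": int(env.get("WORLD_SIZE", "1")),
    }


info = read_dist_env({"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "4"})
print(info)  # {'rank': 3, 'local_rank': 1, 'world_size': 4}
```

In a typical DDP script the local rank selects the GPU for the process (for example via `torch.cuda.set_device`) before the model is built.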

📚 API Reference

Core Classes

  • DSALTLMHeadModel: Language model wrapper with LM head
  • DSALTTransformer: Base transformer architecture
  • DSALTAttention: Sparse attention module
  • DSALTTrainer: Training harness

Kernel Functions

  • dsalt_attention(): Main sparse attention function
  • compute_hybrid_energy_scores(): Landmark scoring
  • select_landmarks(): Landmark selection
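
As an illustration of what top-k landmark selection involves, here is a minimal standard-library sketch. It is not the signature or behavior of `select_landmarks()`; the function name and the flat list of per-token scores are assumptions, and the hybrid energy scoring itself is not reproduced.

```python
import heapq
from typing import List, Sequence


def top_k_landmarks(scores: Sequence[float], k: int) -> List[int]:
    """Return the positions of the k highest scores, in ascending position order."""
    k = min(k, len(scores))  # never ask for more landmarks than tokens
    best = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
    return sorted(best)


scores = [0.1, 0.9, 0.3, 0.7, 0.2]
print(top_k_landmarks(scores, k=2))  # [1, 3]
```

Returning positions in ascending order keeps the result easy to combine with a causal constraint, since a query at position i can simply ignore any landmark greater than i.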

🧪 Testing

Run the full test suite:

python tests/test.py

Run specific tests:

python tests/test_sparse_attn.py  # Attention kernels
python tests/test_dsalt_lm.py     # LM wrapper

📖 Citation

If you use DSALT in your research, please cite our paper:

@article{dsalt2026,
  title={Noise Accumulation and Rank Collapse in Dense Self-Attention: DSALT},
  author={Leonardo et al.},
  journal={Zenodo preprint},
  year={2026}
}

Paper: https://zenodo.org/records/19312827

🤝 Contributing

We welcome contributions! Please see our contributing guidelines.

📄 License

This project is licensed under the Apache 2.0 License - see the LICENSE file for details.

🙏 Acknowledgments

  • Built on top of Triton for GPU kernels
  • Inspired by Flash Attention
  • Thanks to the PyTorch team for the excellent deep learning framework
