
Dynamic Sparse Attention with Landmark Tokens — High-performance Triton implementation

Project description

DSALT: Dynamic Sparse Attention with Landmark Tokens


DSALT is a high‑performance PyTorch library that implements Dynamic Sparse Attention with Landmark Tokens – a memory‑efficient attention mechanism for transformers. It relies on Triton kernels and supports distributed training.

Install: pip install dsalt
Source: https://github.com/LeonardoCofone/dsalt-pytorch
Paper: https://zenodo.org/records/19312826
Feature guide (FEATURE.md): https://github.com/LeonardoCofone/dsalt-pytorch/blob/main/FEATURE.md

🚀 Key Features

  • Memory‑efficient sparse attention – Triton‑accelerated kernels provide 4–8× memory savings compared to dense attention.
  • Adaptive local windows – Token‑wise window sizes that grow with sequence position.
  • Global landmark tokens – Top‑k informative tokens per head selected via a hybrid energy scoring function.
  • Production‑ready training – Mixed‑precision, gradient checkpointing, and validation support.
  • Distributed training – Full DDP and FSDP support for multi‑GPU setups.
  • Numerical verification – CPU/GPU equivalence tests and gradient stability checks.

📋 Table of Contents

  1. Installation
  2. Quick Start
  3. Architecture Overview
  4. Training & Generation
  5. API Reference
  6. Hyperparameter Guide
  7. Testing
  8. Performance & Benchmarks
  9. Citation
  10. Contributing
  11. License

🛠️ Installation

Requirements

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA 11.0+ (GPU) – CPU fallback is available
  • Triton 2.0+ (optional, enables GPU kernels)
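
Before installing, you can confirm that the optional GPU path is available using plain PyTorch/Triton introspection (no DSALT APIs involved):

# Environment check: verifies CUDA and (optionally) Triton before installing DSALT.
import torch

print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
try:
    import triton
    print("Triton:", triton.__version__, "- GPU kernels can be used")
except ImportError:
    print("Triton not installed - DSALT will fall back to the CPU path")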

From PyPI

pip install dsalt

From source

git clone https://github.com/LeonardoCofone/dsalt-pytorch.git
cd dsalt-pytorch
pip install -e .

Development setup

pip install -r requirements-dev.txt

🚀 Quick Start

1. Language‑model inference

import torch
from dsalt.model import DSALTLMHeadModel

model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,
    n_max=512,
    k_lmk=64,
)

input_ids = torch.randint(0, 32000, (1, 1024))  # [batch, seq_len]
logits = model(input_ids)                     # [1, 1024, 32000]
print(logits.shape)

# With labels – loss is computed internally
labels = torch.randint(0, 32000, (1, 1024))
outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()
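
A minimal greedy-decoding loop can be built on the forward pass alone. This is an illustrative sketch, not an official DSALT generation API, and it omits KV caching for clarity:

# Greedy decoding sketch using only model(input_ids) -> [batch, seq, vocab].
model.eval()
generated = input_ids[:, :16]                       # start from a short prompt slice
with torch.no_grad():
    for _ in range(32):
        logits = model(generated)                   # [1, T, 32000]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
print(generated.shape)                              # [1, 16 + 32]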

2. Single‑GPU training

import torch
from torch.utils.data import DataLoader, TensorDataset
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

vocab_size = 32000
seq_len = 512
train_dataset = TensorDataset(
    torch.randint(0, vocab_size, (1000, seq_len))
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

model = DSALTLMHeadModel(
    vocab_size=vocab_size,
    d_model=768,
    n_layers=12,
    n_heads=12,
    n_min=32,
    n_max=256,
    k_lmk=32,
)

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=10_000,
    save_dir="checkpoints",
    dtype=torch.bfloat16,
    log_every=50,
)
trainer.train()
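
Checkpoints written to save_dir can be reloaded with standard PyTorch calls. The file name and checkpoint layout below are assumptions for illustration; check what DSALTTrainer actually writes:

# Hypothetical checkpoint reload - adjust the path and key layout to the trainer's output.
import torch

ckpt = torch.load("checkpoints/latest.pt", map_location="cpu")
state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
model.load_state_dict(state_dict)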

3. Multi‑GPU with DataParallel

import torch
import torch.nn as nn
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

model = DSALTLMHeadModel(...).to("cuda")
model = nn.DataParallel(model)  # uses all available GPUs

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=100_000,
    dtype=torch.bfloat16,
    save_dir="checkpoints",
)
trainer.train()

4. Multi‑GPU with FSDP (model sharding)

torchrun --nproc_per_node=2 train.py

Then configure the trainer with fsdp=True.
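
A minimal train.py for the torchrun command above might look like the sketch below. The fsdp=True flag is the documented trainer option; the dataset, model sizes, and the assumption that the trainer handles distributed sampling are placeholders:

# train.py - minimal FSDP sketch, launched with: torchrun --nproc_per_node=2 train.py
import torch
from torch.utils.data import DataLoader, TensorDataset
from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

train_loader = DataLoader(
    TensorDataset(torch.randint(0, 32000, (1000, 512))),
    batch_size=4,
    shuffle=True,
)

model = DSALTLMHeadModel(
    vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
    n_min=32, n_max=256, k_lmk=32,
)

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=10_000,
    dtype=torch.bfloat16,
    save_dir="checkpoints",
    fsdp=True,            # shard parameters and optimizer state across ranks
)
trainer.train()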


🏗️ Architecture Overview

DSALT combines local causal windows (adaptive per token) with global landmark tokens (top‑k per head):

┌─ Local window (adaptive) ──┬─ Global landmarks ──┐
│ Recent N tokens            │ Top‑K informative   │
│ (window size grows)        │ tokens per head     │
└────────────────────────────┴─────────────────────┘
              ↓                         ↓
            Sparse attention output

Key components:

  1. DSALTAttention – multi‑head sparse attention with adaptive windows and landmark selection.
  2. WindowSizePredictor – learns per‑token window sizes.
  3. HybridEnergyScorer (kernel) – computes landmark scores.
  4. DSALTTransformer – stack of attention + feed‑forward layers.
  5. Triton kernels – fused forward and backward passes for speed and memory efficiency.
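
To make the sparse pattern concrete, here is a toy dense-mask reference (plain PyTorch, not the fused Triton kernel) showing how a per-token local window and a set of always-visible landmark columns combine under a causal constraint. All names and shapes here are illustrative:

# Toy reference only - the real kernels never materialise a dense [T, T] mask.
import torch

def sparse_mask(seq_len, window_sizes, landmark_idx):
    # window_sizes: [seq_len] ints; landmark_idx: [k] landmark column indices.
    q = torch.arange(seq_len).unsqueeze(1)              # query positions [T, 1]
    k = torch.arange(seq_len).unsqueeze(0)              # key positions   [1, T]
    causal = k <= q                                      # never attend to future tokens
    local = k >= (q - window_sizes.unsqueeze(1) + 1)     # recent-N window per query row
    mask = causal & local
    mask[:, landmark_idx] |= causal[:, landmark_idx]     # landmark columns stay visible
    return mask                                          # True = key may be attended

mask = sparse_mask(8, torch.tensor([1, 2, 2, 3, 3, 4, 4, 5]), torch.tensor([0, 3]))
print(mask.int())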

🎯 Training & Generation

See the code snippets above for full training loops. The DSALTTrainer handles:

  • Mixed‑precision (BF16 default)
  • Gradient checkpointing
  • Learning‑rate warm‑up and cosine decay
  • Optional window‑entropy regularisation (window_reg_coef)
  • Checkpointing and logging utilities
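
For example, the window-entropy regulariser mentioned above would be enabled by passing window_reg_coef to the trainer. The coefficient value is illustrative, and treating it as a constructor keyword argument is an assumption to verify against the trainer docstring:

# Illustrative: single-GPU setup from the Quick Start plus window-entropy regularisation.
trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,
    lr=3e-4,
    total_steps=10_000,
    dtype=torch.bfloat16,
    save_dir="checkpoints",
    window_reg_coef=0.01,   # example value, not a recommended default
)
trainer.train()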

📚 API Reference (excerpt)

from dsalt.model import DSALTLMHeadModel
model = DSALTLMHeadModel(vocab_size=32000, d_model=1024, n_layers=24,
                         n_heads=16, n_min=32, n_max=512, k_lmk=64)
logits, windows = model(input_ids, return_window=True)

Low‑level kernel call:

from dsalt.kernels import dsalt_attention
out = dsalt_attention(Q, K, V, window_sizes, landmark_idx)
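
The input layout below is an educated guess based on the multi-head design (per-token window sizes, top-k landmark indices per head); verify the exact shapes and dtypes against the kernel's docstring before relying on it:

# Illustrative shapes only - confirm against dsalt.kernels.dsalt_attention.
import torch

B, H, T, d_head, k_lmk = 1, 16, 1024, 64, 64
Q = torch.randn(B, H, T, d_head, device="cuda", dtype=torch.bfloat16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
window_sizes = torch.randint(32, 512, (B, H, T), device="cuda")   # per-token local window
landmark_idx = torch.randint(0, T, (B, H, k_lmk), device="cuda")  # landmark positions per head

out = dsalt_attention(Q, K, V, window_sizes, landmark_idx)        # expected: same shape as Q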

📊 Performance & Benchmarks (May 2026)

Attention type        Approx. memory (GB)   Relative speed
Dense (O(N²))         ~3.5                  1.0×
FlashAttention 2      ~1.8                  0.5×
DSALT                 ~0.6                  0.17×
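
To sanity-check memory on your own hardware, PyTorch's allocator statistics are sufficient. This is a rough sketch with an example configuration, not the project's benchmark script:

# Rough peak-memory check for a single forward pass.
import torch
from dsalt.model import DSALTLMHeadModel

model = DSALTLMHeadModel(vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
                         n_min=32, n_max=256, k_lmk=32).to("cuda", dtype=torch.bfloat16)
input_ids = torch.randint(0, 32000, (1, 4096), device="cuda")

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model(input_ids)
print(f"Peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")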

📖 Hyperparameter Guide

All hyperparameters are documented in FEATURE.md. Typical configurations are provided for:

  • Mobile / Edge – tiny models, low memory.
  • Consumer GPU – e.g., RTX 4090, 24 GB.
  • Enterprise – H100 80 GB, optional FSDP.
  • Research – multi‑node, large models.

🧪 Testing

make test-cov          # Full test suite with coverage report
pytest tests/ -v       # Run tests directly

Key test modules:

  • tests/test_sparse_attn.py – kernel equivalence and backward.
  • tests/test_hybrid_energy.py – landmark scoring.
  • tests/test_dsalt_lm.py – language‑model wrapper.
  • tests/test_main.py – end‑to‑end smoke test.

📄 License

See the LICENSE file in the repository: https://github.com/LeonardoCofone/dsalt-pytorch/blob/main/LICENSE


🤝 Contributing

Contributions are welcome! Please read CONTRIBUTING.md for guidelines. Areas where help is especially valuable:

  • Triton kernel optimisation
  • New model architectures (encoder, encoder‑decoder)
  • Additional training strategies and samplers
  • Documentation and tutorials
  • Bug reports and fixes

📞 Support & Questions


Last Updated: May 2026


Download files


Source Distribution

dsalt-0.2.8.tar.gz (38.6 kB)


Built Distribution


dsalt-0.2.8-py3-none-any.whl (32.2 kB)


File details

Details for the file dsalt-0.2.8.tar.gz.

File metadata

  • Download URL: dsalt-0.2.8.tar.gz
  • Upload date:
  • Size: 38.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.2.8.tar.gz
Algorithm Hash digest
SHA256 267d97ea3f971f331ab5b2c94fbf232a13aa735d2a43bbc979c595b7f0213e6a
MD5 607a0101bbd8805b6122a22720270cc3
BLAKE2b-256 1a71e48017b3c4ac4b80d9fac9deaf99a4d883bb54869a62fc55617a76b093ee


File details

Details for the file dsalt-0.2.8-py3-none-any.whl.

File metadata

  • Download URL: dsalt-0.2.8-py3-none-any.whl
  • Upload date:
  • Size: 32.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.2.8-py3-none-any.whl
Algorithm Hash digest
SHA256 dc055f71e9d917b680b2fd67b7800f81558e0e2a9fbaa1b5603d48e00913f409
MD5 9fea652c51f12a7845bb6153242b60ce
BLAKE2b-256 0978b9d30d9d10218ecaf0c925e0436667d8234e639f6e9f805e1af20ea530b6

