Dynamic Sparse Attention with Landmark Tokens - High-performance Triton implementation

DSALT: Dynamic Sparse Attention with Landmark Tokens

DSALT is a PyTorch library implementing Dynamic Sparse Attention with Landmark Tokens, a memory-efficient attention mechanism for transformers. Each query attends to an adaptive local causal window plus a small set of global landmark tokens, instead of the full O(N²) set. On CUDA it runs custom Triton kernels; everywhere else it falls back to a masked SDPA path, so the package stays importable and runnable on any platform (including CPU and Windows).
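
To get a feel for the saving, the sketch below counts query-key pairs for dense causal attention and compares them with an upper bound for the sparse pattern. It is only back-of-the-envelope arithmetic: the helper names are hypothetical (not part of dsalt), and the bound assumes every query uses the largest window n_max plus all k_lmk landmarks, so real counts should be lower.

```python
def full_causal_pairs(n: int) -> int:
    """Query-key pairs scored by dense causal attention (query i sees keys 0..i)."""
    return n * (n + 1) // 2


def sparse_pairs_upper_bound(n: int, n_max: int, k_lmk: int) -> int:
    """Upper bound for the sparse pattern.

    Assumption: each query scores at most n_max window keys plus k_lmk
    landmark keys, regardless of its position in the sequence.
    """
    return n * (n_max + k_lmk)


def saving_ratio(n: int, n_max: int, k_lmk: int) -> float:
    """How many times fewer pairs the sparse bound scores than dense attention."""
    return full_causal_pairs(n) / sparse_pairs_upper_bound(n, n_max, k_lmk)


if __name__ == "__main__":
    # Hypothetical sequence lengths, with the Quick Start values n_max=512, k_lmk=64.
    for n in (2048, 8192, 32768):
        print(n, round(saving_ratio(n, n_max=512, k_lmk=64), 2))
```

The bound grows linearly in the sequence length while the dense count grows quadratically, so the gap widens as sequences get longer. For short sequences the bound can exceed the dense count, because it ignores that early queries have only a few keys available.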

Install: pip install dsalt
Source: https://github.com/LeonardoCofone/dsalt-library
Paper: https://zenodo.org/records/19312826

✨ Key Features

  • Sparse attention: an adaptive local causal window plus top-k landmark tokens per head (A(i) = W(i) ∪ L(i)).
  • Trainable selectors: the hard window/landmark selection itself is non-differentiable, but gradients still reach both predictors through soft weights. A soft window edge trains the per-token window size, and a soft landmark re-weight trains the per-head balance α.
  • GPU-portable: Triton kernels on CUDA, transparent SDPA fallback otherwise; the AMP dtype is auto-selected from the GPU's compute capability (bf16 on sm_80+, fp16 on T4-class cards, none on CPU).
  • One-shot autotune: Triton block sizes are benchmarked once per (head_dim, GPU) at the first launch, then reused for the whole run; portable heuristics are used if benchmarking is impossible.
  • Packed-sequence training: concatenated sequences + cu_seqlens, with a fused FlashAttention-2-style forward/backward that uses online softmax and a key-parallel, atomic-free dk/dv backward.
  • Flexible loss: memory-frugal chunked cross-entropy (default), optional Liger fused linear cross-entropy, or "auto" to pick the fastest per GPU.
  • DDP + torch.compile: single- and multi-GPU training via DistributedDataParallel, co-existing with torch.compile; gradient accumulation, cosine schedule with warm-up, checkpointing, and rich representation-health diagnostics.
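
The AMP rule in the feature list can be written as a small decision function. This is a minimal sketch, not the library's own code: pick_amp_dtype is a hypothetical name, and treating every CUDA GPU below sm_80 as "fp16" is an assumption (the text above only names T4-class cards).

```python
from typing import Optional, Tuple


def pick_amp_dtype(compute_capability: Optional[Tuple[int, int]]) -> Optional[str]:
    """Map a CUDA compute capability to an AMP dtype name.

    compute_capability is a (major, minor) pair, or None when no CUDA GPU
    is present. In a real script the pair could come from
    torch.cuda.get_device_capability().
    """
    if compute_capability is None:
        return None          # CPU: no autocast dtype
    major, _minor = compute_capability
    if major >= 8:
        return "bf16"        # sm_80 and newer
    return "fp16"            # assumption: all older CUDA GPUs, e.g. T4 (sm_75)
```

Keeping the rule in one pure function makes it easy to unit-test without a GPU.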

🛠️ Installation

Requirements

  • Python 3.10+ (the codebase uses X | None / tuple[...] syntax)
  • PyTorch 2.0+ (the only required dependency)
  • CUDA 11.0+ for the GPU path (CPU fallback always available)
  • Triton 2.0+ (optional; enables the GPU kernels, Linux/CUDA)

From PyPI

pip install dsalt                 # core (torch only)
pip install "dsalt[triton]"       # + Triton GPU kernels
pip install "dsalt[dev]"          # + lint/type/test tooling
pip install "dsalt[all]"          # everything

From source

git clone https://github.com/LeonardoCofone/dsalt-library.git
cd dsalt-library
pip install -e .

🚀 Quick Start

Inference

import torch
from dsalt.model import DSALTLMHeadModel

model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,
    n_max=512,
    k_lmk=64,
    max_seq_len=2048,   # required: sizes the RoPE cache
)

input_ids = torch.randint(0, 32000, (1, 1024))   # [batch, seq_len]
out = model(input_ids)                           # dict
logits = out["logits"]                           # [1, 1024, 32000]
print(logits.shape)

Computing the loss

When labels are given, the forward pass computes the loss internally (fused) and returns a dict with the keys "loss", "logits", and "aux_loss". In this case logits is None, because the loss is fused to save memory:

labels = torch.randint(0, 32000, (1, 1024))
out  = model(input_ids, labels=labels)
loss = out["loss"]
loss.backward()
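
To illustrate why a loss can be computed without ever holding the full logits tensor, here is a rough pure-Python sketch of chunked cross-entropy. It is not DSALT's implementation: the function name, the list-of-lists layout, and the chunk size are all assumptions chosen for readability. Logits are produced one chunk of tokens at a time from the hidden states and the LM-head weight, reduced to a scalar, and then discarded.

```python
import math


def chunked_cross_entropy(hidden, weight, labels, chunk_size=2, ignore_index=-100):
    """Mean cross-entropy over tokens whose label is not ignore_index.

    hidden: [num_tokens][d_model] floats
    weight: [vocab_size][d_model] floats (LM head)
    labels: [num_tokens] ints
    """
    total, count = 0.0, 0
    for start in range(0, len(hidden), chunk_size):
        h_chunk = hidden[start:start + chunk_size]
        y_chunk = labels[start:start + chunk_size]
        # Only this chunk's logits exist at any point in time.
        logits = [
            [sum(h * w for h, w in zip(h_row, w_row)) for w_row in weight]
            for h_row in h_chunk
        ]
        for row, y in zip(logits, y_chunk):
            if y == ignore_index:
                continue  # masked position contributes nothing
            m = max(row)  # subtract the max for a stable log-sum-exp
            lse = m + math.log(sum(math.exp(v - m) for v in row))
            total += lse - row[y]
            count += 1
    return total / max(count, 1)
```

The result does not depend on chunk_size; only the peak size of the temporary logits does, which is the point of chunking.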

Building from a config

from dsalt.model import DSALTConfig, DSALTLMHeadModel

cfg = DSALTConfig(
    vocab_size=50257, d_model=512, n_layers=6, n_heads=8,
    n_min=64, n_max=256, k_lmk=16, max_seq_len=1024,
)
model = DSALTLMHeadModel.from_config(cfg)
cfg.save("config.json")               # reload with DSALTConfig.load("config.json")

🏗️ Architecture Overview

Each query's attention set is the union of two sparse sets:

┌─ Local window W(i) (adaptive) ─┬─ Global landmarks L(i) ─┐
│  Recent causal tokens up to    │  Top-k informative      │
│  a per-token learned size      │  tokens per head        │
└────────────────────────────────┴─────────────────────────┘
                 ↓                            ↓
              Sparse attention output  over  A(i) = W(i) ∪ L(i)

  • Adaptive window (§4.2). A small learned projection win_gate predicts a per-token continuous window w̃(i) = n_min + σ(f(x_i))·(n_max − n_min) from the block input. The window core is a hard mask (so the cost stays sub-quadratic), but a thin differentiable band at the boundary lets gradients train win_gate.
  • Landmark tokens (§4.3). A per-head hybrid-energy score s = α·z(‖x·W_V‖₂) + (1−α)·z(‖x‖₂) ranks tokens; the top-k are admitted as landmarks. The selection is hard (detached), while a soft re-weight on the admitted tokens' logits trains the per-head balance α = σ(α̃).
  • Blocks. Pre-norm DSALTTransformerBlock = RMSNorm → DSALTAttention (with RoPE/YaRN positions) → residual, then RMSNorm → SwiGLUFFN → residual.
  • Model. DSALTLMHeadModel = token embeddings → block stack → RMSNorm → (optionally tied) LM head, with a fused/chunked cross-entropy loss.
  • Kernels. On CUDA, fused Triton forward/backward with online softmax and one-shot autotuned block sizes; a masked-SDPA path mirrors the exact same math on CPU / no-Triton environments and serves as the correctness reference.
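
The window and landmark formulas above can be traced on toy numbers. The sketch below is an illustration under stated assumptions, not the library's code: all function names are hypothetical, z(·) is read as z-score normalisation across tokens, the two norms are passed in as precomputed lists, and landmarks are filtered to positions at or before the query so the set stays causal.

```python
import math


def window_size(gate_logit, n_min, n_max):
    """Continuous window: n_min + sigmoid(f(x_i)) * (n_max - n_min)."""
    sig = 1.0 / (1.0 + math.exp(-gate_logit))
    return n_min + sig * (n_max - n_min)


def zscore(values):
    """Assumed meaning of z(.): subtract the mean, divide by the std."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    std = math.sqrt(var)
    if std == 0.0:
        return [0.0 for _ in values]
    return [(v - mean) / std for v in values]


def hybrid_scores(value_norms, input_norms, alpha):
    """s = alpha * z(||x W_V||) + (1 - alpha) * z(||x||), one score per token."""
    zv, zx = zscore(value_norms), zscore(input_norms)
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(zv, zx)]


def top_k_landmarks(scores, k):
    """Indices of the k highest-scoring tokens (hard selection)."""
    order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return set(order[:k])


def attention_set(i, window, landmarks):
    """A(i) = W(i) | L(i) for query position i and an integer window size."""
    local = set(range(max(0, i - window + 1), i + 1))
    # Assumption: only landmarks at or before i are usable by query i.
    return local | {j for j in landmarks if j <= i}
```

For example, attention_set(5, 2, {0, 9}) keeps the two most recent positions plus landmark 0, and drops landmark 9 because it lies in the future of query 5.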

For the engineering rationale (differentiable approximations, DDP + torch.compile graph integrity, the key-parallel backward, and profiling evidence) see DESIGN_NOTES.md.
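
The online softmax used by the fused kernels can be illustrated for a single query in plain Python. This is a generic sketch of the technique, not the Triton kernel: the function names and the scalar values are assumptions. Keys are consumed block by block while a running maximum, a running normaliser, and a running weighted sum are rescaled, so the full score row is never stored.

```python
import math


def online_softmax_weighted_sum(scores, values, block_size=2):
    """softmax(scores) . values computed in one streaming pass over blocks."""
    m = float("-inf")   # running max of the scores seen so far
    l = 0.0             # running sum of exp(score - m)
    acc = 0.0           # running sum of exp(score - m) * value
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)   # rescales the old statistics; 0.0 on the first block
        l = l * scale + sum(math.exp(s - m_new) for s in s_blk)
        acc = acc * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_blk, v_blk))
        m = m_new
    return acc / l


def softmax_weighted_sum_reference(scores, values):
    """Two-pass reference that materialises all weights."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)
```

Both functions agree up to floating-point rounding for any block size, which is the property a masked reference path can be checked against.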


🎯 Training

DSALTTrainer drives single- and multi-GPU (DDP) training. It expects packed batches (input_ids, labels, cu_seqlens, max_seqlen), where cu_seqlens is an int32 offset tensor of shape [num_seqs + 1] and positions whose label is -100 are ignored by the loss.
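
A packed batch of this shape could be assembled roughly as follows. The sketch uses plain Python lists so it runs anywhere; pack_batch is a hypothetical helper, and how labels are aligned with input_ids (shifted or not) is left to the caller because it is not specified here. In a real loader the lists would be converted to tensors, with cu_seqlens as int32.

```python
def pack_batch(sequences):
    """Concatenate variable-length sequences into one packed batch.

    sequences: list of (ids, labels) pairs, each pair having equal length.
    Returns (input_ids, labels, cu_seqlens, max_seqlen).
    """
    input_ids, labels, cu_seqlens = [], [], [0]
    max_seqlen = 0
    for ids, labs in sequences:
        if len(ids) != len(labs):
            raise ValueError("ids and labels must have the same length")
        input_ids.extend(ids)
        labels.extend(labs)
        # Each entry is the offset where the next sequence starts.
        cu_seqlens.append(cu_seqlens[-1] + len(ids))
        max_seqlen = max(max_seqlen, len(ids))
    return input_ids, labels, cu_seqlens, max_seqlen


if __name__ == "__main__":
    batch = pack_batch([
        ([11, 12, 13], [-100, 12, 13]),   # -100 marks a position the loss ignores
        ([21, 22], [21, 22]),
    ])
    print(batch)
```

The last entry of cu_seqlens always equals the total number of packed tokens, and it has one more entry than there are sequences.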

from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

model = DSALTLMHeadModel(
    vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
    n_min=32, n_max=256, k_lmk=32, max_seq_len=1024,
)

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,   # yields (ids, labels, cu_seqlens, max_seqlen)
    val_loader=val_loader,
    lr=3e-4,
    total_steps=10_000,
    warmup_steps=1_000,
    mixed_precision="auto",      # bf16 on sm_80+, fp16 on T4-class, none on CPU
    save_dir="./checkpoints_dsalt",
    log_every=100,
)
trainer.train()

Multi-GPU (DDP). Launch one process per GPU; the trainer wraps the model in DistributedDataParallel when world_size > 1, and (optionally) applies torch.compile after the DDP wrap:

torchrun --nproc_per_node=2 your_train_script.py
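
torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE to every process it starts, so a training script can tell whether it is running distributed. The helper below is a hypothetical sketch of reading them with single-process defaults; how DSALTTrainer itself detects world_size is not described here.

```python
import os


def read_dist_env(env=None):
    """Return rank, local_rank and world_size as ints.

    Falls back to a single-process setup when the variables are missing,
    e.g. when the script is started with plain `python`.
    """
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
    }


if __name__ == "__main__":
    info = read_dist_env()
    # Typically only rank 0 logs or writes checkpoints.
    if info["rank"] == 0:
        print("world size:", info["world_size"])
```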

Only DDP is supported in this release (no FSDP). The trainer handles gradient accumulation, gradient clipping, cosine LR decay with warm-up, AMP autodetect, checkpointing (checkpoint_best.pt / checkpoint_step_<n>.pt / checkpoint_final.pt), and per-layer representation-health metrics. Resume with trainer.load_checkpoint(path).
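
For orientation, a cosine schedule with linear warm-up is commonly computed as below. This is a generic sketch and not necessarily the exact formula the trainer uses: lr_at, the (step + 1) warm-up convention, and the min_lr floor are assumptions.

```python
import math


def lr_at(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at a 0-based optimizer step."""
    if step < warmup_steps:
        # Linear ramp that reaches base_lr on the last warm-up step.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    # Values from the trainer example: lr=3e-4, 1,000 warm-up steps, 10,000 total.
    for step in (0, 500, 999, 5_500, 10_000):
        print(step, lr_at(step, 3e-4, 1_000, 10_000))
```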

Every constructor argument, default, and the full metric list are documented in FEATURE.md.


📚 API Reference

# Top-level exports
from dsalt import (
    DSALTConfig, DSALTLMHeadModel,
    DSALTAttention, DSALTTransformerBlock, SwiGLUFFN,
    DSALTTrainer,
    dsalt_triton_attention,            # None when Triton is unavailable
    hybrid_scores_per_head,            # single source of the landmark score (§4.3)
    compute_hybrid_scores, select_landmarks, soft_landmark_weights,
    HybridEnergyLandmarkSelector,
    sparse_attention_forward, sparse_attention_forward_packed,
    RMSENorm, compute_window_sizes, apply_rotary_emb, build_rope_cache,
    build_local_window_mask, build_local_window_mask_packed,
    LigerFusedLinearCrossEntropyFunction,
)

q, k, v for the low-level kernel are [total_len, n_heads, head_dim]; cu_seqlens is the int32 sequence-offset tensor. The complete, source-verified signature and semantics of every component live in FEATURE.md.
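
To make the packed layout concrete, the sketch below expands cu_seqlens into a per-token sequence index and checks which (query, key) pairs are eligible under the usual variable-length convention: a key must belong to the same sequence as the query and must not lie in its future. The helper names are hypothetical, and the same-sequence rule is an assumption about the packed mask rather than a quote from the library.

```python
def sequence_ids(cu_seqlens):
    """Per-token sequence index, e.g. [0, 3, 5] -> [0, 0, 0, 1, 1]."""
    ids = []
    for seq, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        ids.extend([seq] * (end - start))
    return ids


def split_packed(tokens, cu_seqlens):
    """Recover the individual sequences along the total_len axis."""
    return [tokens[s:e] for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]


def can_attend(q_pos, k_pos, seq_ids):
    """Causal and within the same packed sequence (assumed convention)."""
    return k_pos <= q_pos and seq_ids[q_pos] == seq_ids[k_pos]
```

With cu_seqlens = [0, 3, 5], token 3 is the first token of the second sequence, so it may attend to itself but not to token 2.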


📄 License

Apache 2.0; see LICENSE.


🤝 Contributing

Contributions are welcome; see CONTRIBUTING.md. Especially valuable are Triton kernel optimisation, new architectures (encoder / encoder-decoder), additional training strategies, documentation, and bug fixes.


📝 Citation

If you use DSALT in your research, please cite the paper:

@software{dsalt,
  author  = {Cofone, Leonardo},
  title   = {DSALT: Dynamic Sparse Attention with Landmark Tokens},
  url     = {https://github.com/LeonardoCofone/dsalt-library},
  note    = {https://zenodo.org/records/19312826},
}
