Dynamic Sparse Attention with Landmark Tokens - High-performance Triton implementation

DSALT: Dynamic Sparse Attention with Landmark Tokens

DSALT is a PyTorch library implementing Dynamic Sparse Attention with Landmark Tokens, a memory-efficient attention mechanism for transformers. Each query attends to an adaptive local causal window plus a small set of global landmark tokens, instead of the full O(N²) set. On CUDA it runs custom Triton kernels; everywhere else it falls back to a masked SDPA path, so the package stays importable and runnable on any platform (including CPU and Windows).

Install: pip install dsalt
Source: https://github.com/LeonardoCofone/dsalt-library
Paper: https://zenodo.org/records/19312826

✨ Key Features

  • Sparse attention: each query attends to an adaptive local causal window plus the top-k landmark tokens of its head (A(i) = W(i) ∪ L(i)).
  • Fully differentiable selectors: although the hard window/landmark selection itself is non-differentiable, gradients still reach both predictors through soft weights. A soft window edge trains the per-token window size, and a soft landmark re-weight trains the per-head balance α.
  • GPU-portable: Triton kernels on CUDA, with a transparent SDPA fallback everywhere else. The AMP dtype is auto-selected from the GPU's compute capability (bf16 on sm_80+, fp16 on T4-class cards, none on CPU).
  • One-shot autotune: Triton block sizes are benchmarked once per (head_dim, GPU) at the first launch and then reused for the whole run, with portable heuristics as a fallback when benchmarking is impossible.
  • Packed-sequence training: concatenated sequences plus cu_seqlens, a fused FlashAttention-2-style forward/backward with online softmax, and a key-parallel, atomic-free dk/dv backward.
  • Flexible loss: memory-frugal chunked cross-entropy (default), optional Liger fused linear cross-entropy, or "auto" to pick the fastest per GPU.
  • DDP + torch.compile: single- and multi-GPU training via DistributedDataParallel, co-existing with torch.compile, plus gradient accumulation, a cosine schedule with warm-up, checkpointing, and rich representation-health diagnostics.
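The "online softmax" behind the fused FlashAttention-2-style kernels can be illustrated in plain PyTorch. This is a minimal single-head, non-causal sketch (the function name is hypothetical; there is no Triton and no sparsity), not DSALT's kernel. Keys are visited block by block while a running row maximum and denominator are rescaled, so the full [Tq, Tk] score matrix is never materialised, yet the result equals ordinary softmax attention.

```python
import torch


def online_softmax_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                             block: int = 4) -> torch.Tensor:
    """Single-head softmax(q kᵀ / sqrt(d)) v, visiting the keys block by block.

    q: [Tq, d], k: [Tk, d], v: [Tk, dv]. No causal or sparse mask is applied.
    """
    scale = q.size(-1) ** -0.5
    row_max = q.new_full((q.size(0), 1), float("-inf"))   # running max of each score row
    denom = q.new_zeros((q.size(0), 1))                    # running softmax denominator
    acc = q.new_zeros((q.size(0), v.size(-1)))             # running unnormalised output
    for start in range(0, k.size(0), block):
        scores = (q @ k[start:start + block].T) * scale    # [Tq, block] scores only
        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        p = torch.exp(scores - new_max)                    # block weights w.r.t. the new max
        rescale = torch.exp(row_max - new_max)             # shrink sums built with the old max
        denom = denom * rescale + p.sum(dim=-1, keepdim=True)
        acc = acc * rescale + p @ v[start:start + block]
        row_max = new_max
    return acc / denom
```

On the first block the running max is -inf, so the rescale factor is exp(-inf) = 0 and the empty accumulators stay zero; afterwards each new block only ever shrinks the earlier partial sums, which keeps the exponentials numerically safe.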


🛠️ Installation

Requirements

  • Python 3.10+ (the codebase uses X | None / tuple[...] syntax)
  • PyTorch 2.0+ (the only required dependency)
  • CUDA 11.0+ for the GPU path (CPU fallback always available)
  • Triton 2.0+ (optional; enables the GPU kernels, Linux/CUDA)
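The package makes these decisions itself (its API exports dsalt_triton_attention as None when Triton is unavailable). The sketch below only mirrors the documented rules so you can predict which path a machine will take. The function names are hypothetical, and treating every pre-sm_80 GPU as fp16 is an assumption based on the "T4-class" wording.

```python
import importlib.util
from typing import Optional

import torch


def triton_kernels_usable() -> bool:
    """True when Triton is installed and a CUDA device is visible."""
    return importlib.util.find_spec("triton") is not None and torch.cuda.is_available()


def amp_dtype_for(capability: Optional[tuple[int, int]]) -> Optional[torch.dtype]:
    """Map a CUDA compute capability to an autocast dtype; None means no AMP."""
    if capability is None:              # CPU-only machine
        return None
    major, _minor = capability
    # sm_80+ (Ampere and newer) -> bf16; older cards such as the T4 (sm_75) -> fp16
    return torch.bfloat16 if major >= 8 else torch.float16


cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
print("backend:", "Triton kernels" if triton_kernels_usable() else "masked-SDPA fallback")
print("AMP dtype:", amp_dtype_for(cap))
```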

From PyPI

pip install dsalt                 # core (torch only)
pip install "dsalt[triton]"       # + Triton GPU kernels
pip install "dsalt[dev]"          # + lint/type/test tooling
pip install "dsalt[all]"          # everything

From source

git clone https://github.com/LeonardoCofone/dsalt-library.git
cd dsalt-library
pip install -e .

🚀 Quick Start

Inference

import torch
from dsalt.model import DSALTLMHeadModel

model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,
    n_max=512,
    k_lmk=64,
    max_seq_len=2048,   # required: sizes the RoPE cache
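)
model.eval()                                     # disable training-only behaviour such as dropout

input_ids = torch.randint(0, 32000, (1, 1024))   # [batch, seq_len]
with torch.no_grad():                            # no autograd graph needed for inference
    out = model(input_ids)                       # dict
logits = out["logits"]                           # [1, 1024, 32000]
print(logits.shape)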

Computing the loss

When labels are given, the forward pass computes the loss internally (fused) and returns a dict with the keys "loss", "logits", and "aux_loss". Here logits is None, because the loss is fused to save memory:

labels = torch.randint(0, 32000, (1, 1024))
out  = model(input_ids, labels=labels)
loss = out["loss"]
loss.backward()
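The chunked cross-entropy listed under Key Features avoids materialising the full [tokens, vocab] logits tensor at once. The following is a generic sketch of that idea, not DSALT's internal loss (which the model already applies when labels are passed). The function names and chunk_size are hypothetical, and torch.utils.checkpoint is used so each chunk's logits are recomputed during backward instead of being kept alive.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def _chunk_loss_sum(h: torch.Tensor, weight: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    logits = h @ weight.t()                      # [chunk, vocab]: only this chunk is materialised
    return F.cross_entropy(logits.float(), y, ignore_index=-100, reduction="sum")


def chunked_cross_entropy(
    hidden: torch.Tensor,        # [N, d_model] final hidden states, flattened over batch/time
    lm_weight: torch.Tensor,     # [vocab, d_model] LM-head weight
    labels: torch.Tensor,        # [N] target ids, -100 = ignored
    chunk_size: int = 1024,
) -> torch.Tensor:
    total = hidden.new_zeros((), dtype=torch.float32)
    for start in range(0, hidden.size(0), chunk_size):
        h = hidden[start:start + chunk_size]
        y = labels[start:start + chunk_size]
        # checkpoint() frees the chunk logits after forward and recomputes them in backward
        total = total + checkpoint(_chunk_loss_sum, h, lm_weight, y, use_reentrant=False)
    n_valid = (labels != -100).sum().clamp(min=1)
    return total / n_valid                       # mean over non-ignored tokens
```

Summing per chunk and dividing once by the number of valid tokens reproduces the usual mean reduction of F.cross_entropy with ignore_index, even when a chunk contains only ignored labels.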

Building from a config

from dsalt.model import DSALTConfig, DSALTLMHeadModel

cfg = DSALTConfig(
    vocab_size=50257, d_model=512, n_layers=6, n_heads=8,
    n_min=64, n_max=256, k_lmk=16, max_seq_len=1024,
)
model = DSALTLMHeadModel.from_config(cfg)
cfg.save("config.json")               # reload with DSALTConfig.load("config.json")

🏗️ Architecture Overview

Each query's attention set is the union of two sparse sets:

┌─ Local window W(i) (adaptive) ─┬─ Global landmarks L(i) ─┐
│  Recent causal tokens up to    │  Top-k informative      │
│  a per-token learned size      │  tokens per head        │
└────────────────────────────────┴─────────────────────────┘
                 ↓                            ↓
              Sparse attention output over A(i) = W(i) ∪ L(i)

  • Adaptive window (§4.2). A small learned projection win_gate predicts a per-token continuous window w̃(i) = n_min + σ(f(x_i))·(n_max − n_min) from the block input. The window core is a hard mask (so the cost stays sub-quadratic), but a thin differentiable band at the boundary lets gradients train win_gate.
  • Landmark tokens (§4.3). A per-head hybrid-energy score s = α·z(‖x·W_V‖₂) + (1−α)·z(‖x‖₂) ranks tokens; the top-k are admitted as landmarks. The selection is hard (detached), while a soft re-weight on the admitted tokens' logits trains the per-head balance α = σ(α̃).
  • Blocks. Pre-norm DSALTTransformerBlock = RMSNorm → DSALTAttention (with RoPE/YaRN positions) → residual, then RMSNorm → SwiGLUFFN → residual.
  • Model. DSALTLMHeadModel = token embeddings → block stack → RMSNorm → (optionally tied) LM head, with a fused/chunked cross-entropy loss.
  • Kernels. On CUDA, fused Triton forward/backward with online softmax and one-shot autotuned block sizes; a masked-SDPA path mirrors the exact same math on CPU / no-Triton environments and serves as the correctness reference.
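The window-size formula and the hybrid-energy score above can be sketched in a few lines of PyTorch. This is a minimal sketch under stated assumptions, not DSALT's DSALTAttention or HybridEnergyLandmarkSelector: the class name SketchSelectors, the single linear layer standing in for f, the temperature tau of the soft edge, and reading z(·) as a z-score over the sequence are all assumptions. The landmark selection is detached as described; the soft re-weight that trains α is not shown.

```python
import torch
import torch.nn as nn


def zscore(x: torch.Tensor, dim: int = 0, eps: float = 1e-6) -> torch.Tensor:
    return (x - x.mean(dim, keepdim=True)) / (x.std(dim, keepdim=True) + eps)


class SketchSelectors(nn.Module):
    """Hypothetical stand-in for win_gate plus the per-head landmark scorer."""

    def __init__(self, d_model: int, n_heads: int, n_min: int, n_max: int, k_lmk: int):
        super().__init__()
        self.n_heads, self.n_min, self.n_max, self.k_lmk = n_heads, n_min, n_max, k_lmk
        self.win_gate = nn.Linear(d_model, 1)                 # f(x_i)
        self.w_v = nn.Linear(d_model, d_model, bias=False)    # value projection W_V
        self.alpha_raw = nn.Parameter(torch.zeros(n_heads))   # α̃, one per head

    def window_sizes(self, x: torch.Tensor) -> torch.Tensor:
        # w̃(i) = n_min + σ(f(x_i)) · (n_max − n_min): continuous, shape [T]
        gate = torch.sigmoid(self.win_gate(x)).squeeze(-1)
        return self.n_min + gate * (self.n_max - self.n_min)

    def soft_edge(self, w_tilde: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
        # [T, T] weight for distance d = i - j: ~1 well inside the window, ~0 far beyond,
        # with a smooth band around d ≈ w̃ that carries gradient back to win_gate.
        # Future keys (d < 0) must still be removed by the causal mask.
        T = w_tilde.numel()
        pos = torch.arange(T, device=w_tilde.device)
        d = (pos.unsqueeze(1) - pos.unsqueeze(0)).to(w_tilde.dtype)
        return torch.sigmoid((w_tilde.unsqueeze(1) - d) / tau)

    def landmark_scores(self, x: torch.Tensor) -> torch.Tensor:
        # s = α·z(‖x·W_V‖₂) + (1−α)·z(‖x‖₂), per head, shape [T, H]
        T = x.size(0)
        v_norm = self.w_v(x).view(T, self.n_heads, -1).norm(dim=-1)   # [T, H]
        x_norm = x.norm(dim=-1, keepdim=True)                          # [T, 1]
        alpha = torch.sigmoid(self.alpha_raw)                          # [H]
        return alpha * zscore(v_norm) + (1 - alpha) * zscore(x_norm)

    def select_landmarks(self, x: torch.Tensor) -> torch.Tensor:
        # hard top-k per head on detached scores -> landmark indices [k, H]
        scores = self.landmark_scores(x).detach()
        return scores.topk(self.k_lmk, dim=0).indices
```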

For the engineering rationale (differentiable approximations, DDP + torch.compile graph integrity, the key-parallel backward, and profiling evidence), see DESIGN_NOTES.md.
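Combining the two sets, the attention set A(i) = W(i) ∪ L(i) from the diagram can be expressed as a boolean mask for one head of one sequence. This helper is hypothetical (it is not the library's build_local_window_mask) and assumes that W(i) counts the query token itself and that a landmark is only visible from query positions at or after it, so causality holds.

```python
import torch


def attention_set_mask(window: torch.Tensor, landmarks: torch.Tensor) -> torch.Tensor:
    """Boolean [T, T] mask: mask[i, j] is True iff key j is in A(i) = W(i) ∪ L(i).

    window:    [T] integer window size per query position
    landmarks: [k] indices of one head's landmark tokens
    """
    T = window.numel()
    i = torch.arange(T, device=window.device).unsqueeze(1)   # query positions (column)
    j = torch.arange(T, device=window.device).unsqueeze(0)   # key positions (row)
    causal = j <= i
    # W(i): the last window[i] tokens up to and including position i
    in_window = causal & (i - j < window.unsqueeze(1))
    is_landmark = torch.zeros(T, dtype=torch.bool, device=window.device)
    is_landmark[landmarks] = True
    # L(i): landmark tokens in the causal past of i
    in_landmarks = causal & is_landmark.unsqueeze(0)
    return in_window | in_landmarks
```

Each row then holds at most window[i] + k admitted keys instead of i + 1, which is where the sub-quadratic cost comes from.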


🎯 Training

DSALTTrainer drives single- and multi-GPU (DDP) training. It expects packed batches (input_ids, labels, cu_seqlens, max_seqlen), where cu_seqlens is an int32 offset tensor of shape [num_seqs + 1] and -100 labels are ignored.

from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

model = DSALTLMHeadModel(
    vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
    n_min=32, n_max=256, k_lmk=32, max_seq_len=1024,
)

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,   # yields (ids, labels, cu_seqlens, max_seqlen)
    val_loader=val_loader,
    lr=3e-4,
    total_steps=10_000,
    warmup_steps=1_000,
    mixed_precision="auto",      # bf16 on sm_80+, fp16 on T4-class, none on CPU
    save_dir="./checkpoints_dsalt",
    log_every=100,
)
trainer.train()
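A packed batch in the (input_ids, labels, cu_seqlens, max_seqlen) layout described above can be built from plain token lists roughly as follows. The helper name is hypothetical, and two details are assumptions to check against FEATURE.md: that the tensors are 1-D over the packed length, and whether labels are expected pre-shifted (shift_labels=True) or shifted inside the model (shift_labels=False, which masks each sequence's first label so no target crosses a boundary).

```python
import torch


def pack_sequences(seqs: list[list[int]], shift_labels: bool = True):
    """Concatenate token lists into (input_ids, labels, cu_seqlens, max_seqlen)."""
    ids: list[int] = []
    labels: list[int] = []
    lengths: list[int] = []
    for seq in seqs:
        ids.extend(seq)
        if shift_labels:
            # labels[t] is the target for input_ids[t]; the last token has no target
            labels.extend(seq[1:] + [-100])
        else:
            # model shifts internally: hide each first token so no target spans two sequences
            labels.extend([-100] + seq[1:])
        lengths.append(len(seq))
    # int32 offsets of shape [num_seqs + 1], e.g. lengths [3, 2] -> [0, 3, 5]
    cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(lengths), dim=0)
    return (
        torch.tensor(ids, dtype=torch.long),
        torch.tensor(labels, dtype=torch.long),
        cu_seqlens,
        max(lengths),
    )
```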

Multi-GPU (DDP). Launch one process per GPU; the trainer wraps the model in DistributedDataParallel when world_size > 1, and (optionally) applies torch.compile after the DDP wrap:

torchrun --nproc_per_node=2 your_train_script.py

Only DDP is supported in this release (no FSDP). The trainer handles gradient accumulation, gradient clipping, cosine LR decay with warm-up, AMP autodetect, checkpointing (checkpoint_best.pt / checkpoint_step_<n>.pt / checkpoint_final.pt), and per-layer representation-health metrics. Resume with trainer.load_checkpoint(path).
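For intuition about the cosine LR decay with warm-up, here is a minimal schedule function. The exact warm-up shape and LR floor DSALTTrainer uses are documented in FEATURE.md, not here; this sketch assumes a linear warm-up and a cosine decay to a hypothetical min_lr reached at total_steps.

```python
import math


def lr_at(step: int, base_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Linear warm-up to base_lr, then cosine decay to min_lr at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)   # 0 -> 1, clamped after the end
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

With lr=3e-4, warmup_steps=1_000, and total_steps=10_000 as in the example above, this gives 3e-4 at step 1_000 and 1.5e-4 halfway through the decay (step 5_500). To drive a plain PyTorch optimizer with it, pass base_lr=1.0 so it returns a multiplier for torch.optim.lr_scheduler.LambdaLR.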

Every constructor argument, default, and the full metric list are documented in FEATURE.md.


📚 API Reference

# Top-level exports
from dsalt import (
    DSALTConfig, DSALTLMHeadModel,
    DSALTAttention, DSALTTransformerBlock, SwiGLUFFN,
    DSALTTrainer,
    dsalt_triton_attention,            # None when Triton is unavailable
    hybrid_scores_per_head,            # single source of the landmark score (§4.3)
    compute_hybrid_scores, select_landmarks, soft_landmark_weights,
    HybridEnergyLandmarkSelector,
    sparse_attention_forward, sparse_attention_forward_packed,
    RMSNorm, compute_window_sizes, apply_rotary_emb, build_rope_cache,
    build_local_window_mask, build_local_window_mask_packed,
    LigerFusedLinearCrossEntropyFunction,
)

The low-level kernel takes q, k, v of shape [total_len, n_heads, head_dim], with cu_seqlens as the int32 sequence-offset tensor that delimits the packed sequences. The complete, source-verified signatures and semantics of every component are documented in FEATURE.md.
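To make the [total_len, n_heads, head_dim] plus cu_seqlens layout concrete, here is a dense causal reference in plain PyTorch that keeps packed sequences from attending to each other. The helper is hypothetical and shows the tensor layout only: it applies plain causal attention, not DSALT's sparse window-plus-landmark pattern, which would additionally require combining a per-head A(i) mask with the block-diagonal mask below.

```python
import torch
import torch.nn.functional as F


def packed_causal_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                       cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Dense causal attention over packed sequences.

    q, k, v: [total_len, n_heads, head_dim]; cu_seqlens: int32 [num_seqs + 1].
    """
    total_len = q.size(0)
    # sequence id of every packed position, e.g. cu_seqlens [0, 3, 5] -> [0, 0, 0, 1, 1]
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).long().to(q.device)
    seq_id = torch.repeat_interleave(torch.arange(lengths.numel(), device=q.device), lengths)
    pos = torch.arange(total_len, device=q.device)
    # True = may attend: same sequence and not in the future (block-diagonal causal)
    mask = (seq_id[:, None] == seq_id[None, :]) & (pos[None, :] <= pos[:, None])
    # SDPA expects [..., L, E], so move the head axis in front
    out = F.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), attn_mask=mask
    )
    return out.transpose(0, 1)   # back to [total_len, n_heads, head_dim]
```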


📖 Documentation


📄 License

Apache 2.0; see LICENSE.


🤝 Contributing

Contributions are welcome; see CONTRIBUTING.md. Especially valuable are Triton kernel optimisation, new architectures (encoder / encoder-decoder), additional training strategies, documentation, and bug fixes.


📝 Citation

If you use DSALT in your research, please cite the paper:

@software{dsalt,
  author  = {Cofone, Leonardo},
  title   = {DSALT: Dynamic Sparse Attention with Landmark Tokens},
  url     = {https://github.com/LeonardoCofone/dsalt-library},
  note    = {https://zenodo.org/records/19312826},
}

Download files

Download the file for your platform.

Source Distribution

dsalt-0.4.28.tar.gz (82.7 kB view details)

Uploaded Source

Built Distribution

dsalt-0.4.28-py3-none-any.whl (87.2 kB view details)

Uploaded Python 3

File details

Details for the file dsalt-0.4.28.tar.gz.

File metadata

  • Download URL: dsalt-0.4.28.tar.gz
  • Upload date:
  • Size: 82.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.4.28.tar.gz
Algorithm Hash digest
SHA256 1fe5fbfe37be63700c08052d60c0d9ac88def21a6ffe9954edd0a2a3fd4d65df
MD5 65c89199daae009bc5ed1aed1dca48a5
BLAKE2b-256 e9560814f47c93a5c986694055c61dea3503d53ab7fdce0d7915995366c80dd0

File details

Details for the file dsalt-0.4.28-py3-none-any.whl.

File metadata

  • Download URL: dsalt-0.4.28-py3-none-any.whl
  • Upload date:
  • Size: 87.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.4.28-py3-none-any.whl
Algorithm Hash digest
SHA256 6b08f7fbfbc8d836571fc1463bc498f85098c9ae207f228130ea48880d275d89
MD5 daa2603d4ffc481031257d161d48da95
BLAKE2b-256 742a26047aa0043ea421338b10699b82b1a690ea985743bc9de1fd7d590a57a0
