Dynamic Sparse Attention with Landmark Tokens - High-performance Triton implementation

DSALT: Dynamic Sparse Attention with Landmark Tokens

DSALT is a PyTorch library implementing Dynamic Sparse Attention with Landmark Tokens, a memory-efficient attention mechanism for transformers. Each query attends to an adaptive local causal window plus a small set of global landmark tokens instead of every preceding token, avoiding the O(N²) cost of dense attention. On CUDA it runs custom Triton kernels; everywhere else it falls back to a masked SDPA path, so the package stays importable and runnable on any platform (including CPU and Windows).
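To make the cost comparison concrete, here is a rough pair-count sketch. It assumes each query i sees at most min(i + 1, w + k_lmk) keys (its causal window of size w plus up to k_lmk distinct earlier landmarks), with w = (n_min + n_max) // 2 as in this release and the Quick Start values below. The helper names are hypothetical and not part of the dsalt API.

```python
def dense_causal_pairs(n: int) -> int:
    """Query/key pairs scored by full causal attention over n tokens."""
    return n * (n + 1) // 2


def sparse_pairs_upper_bound(n: int, window: int, k_lmk: int) -> int:
    """Upper bound on pairs when query i attends to its local causal window
    plus up to k_lmk distinct earlier landmark tokens (assumption)."""
    per_query_cap = window + k_lmk
    # Query i has only i + 1 causal positions available.
    return sum(min(i + 1, per_query_cap) for i in range(n))


n_min, n_max, k_lmk, seq_len = 32, 512, 64, 2048   # Quick Start values
window = (n_min + n_max) // 2                      # frozen window size

dense = dense_causal_pairs(seq_len)
sparse = sparse_pairs_upper_bound(seq_len, window, k_lmk)
print(window, dense, sparse, round(sparse / dense, 3))  # → 272 2098176 631848 0.301
```

Once N exceeds w + k_lmk the bound grows linearly in N, while the dense count keeps growing quadratically.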

Install: pip install dsalt
Source: https://github.com/LeonardoCofone/dsalt-library
Paper: https://zenodo.org/records/19312826

✨ Key Features

  • Sparse attention: local adaptive window ∪ top-k landmark tokens per head.
  • GPU-portable: Triton kernels on CUDA, transparent SDPA fallback otherwise; correct AMP autodetection across GPU generations (bf16 only where natively supported, fp16 on T4-class cards).
  • One-shot autotune: Triton block sizes are benchmarked once per (head_dim, GPU) at the first launch, then reused for the whole run; a heuristic is used if benchmarking is not possible.
  • Packed-sequence training: concatenated sequences + cu_seqlens, fused forward/backward with online softmax.
  • Fused cross-entropy: optional Liger fused linear cross-entropy, or a memory-frugal chunked pure-PyTorch loss.
  • DDP training: single- and multi-GPU via DistributedDataParallel, gradient accumulation, cosine schedule with warm-up, checkpointing, and rich representation-health diagnostics.


๐Ÿ› ๏ธ Installation

Requirements

  • Python 3.10+ (the codebase uses X | None / tuple[...] syntax)
  • PyTorch 2.0+
  • CUDA 11.0+ for the GPU path (CPU fallback always available)
  • Triton 2.0+ (optional; enables the GPU kernels, Linux/CUDA)

From PyPI

pip install dsalt                 # core
pip install "dsalt[triton]"       # + Triton GPU kernels
pip install "dsalt[dev]"          # + lint/type/test tooling

From source

git clone https://github.com/LeonardoCofone/dsalt-library.git
cd dsalt-library
pip install -e .

🚀 Quick Start

Inference

import torch
from dsalt.model import DSALTLMHeadModel

model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,
    n_max=512,
    k_lmk=64,
    max_seq_len=2048,   # required
)

input_ids = torch.randint(0, 32000, (1, 1024))   # [batch, seq_len]
out = model(input_ids)                           # dict
logits = out["logits"]                           # [1, 1024, 32000]
print(logits.shape)

Computing the loss

The forward computes the loss internally (fused) when labels are given, and returns a dict {"loss", "logits", "aux_loss"}:

labels = torch.randint(0, 32000, (1, 1024))
out  = model(input_ids, labels=labels)
loss = out["loss"]          # logits is None here (loss is fused)
loss.backward()
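The default loss_fn="chunked" is described only as a memory-frugal pure-PyTorch loss. A common way to get that property is to compute logits for a slice of positions at a time, so the full [seq_len, vocab_size] logit matrix is never held at once. The dependency-free sketch below illustrates that general technique on small Python lists, skipping positions labelled -100; it is an assumption about the approach, not dsalt's implementation, and the function names are hypothetical.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # positions with this label do not contribute to the loss


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def chunked_linear_cross_entropy(
    hidden: Sequence[Sequence[float]],   # [T][d] final hidden states
    weight: Sequence[Sequence[float]],   # [V][d] LM-head weight
    labels: Sequence[int],               # [T] target ids or IGNORE_INDEX
    chunk_size: int = 2,
) -> float:
    """Mean cross-entropy over non-ignored positions, chunk_size rows at a time."""
    total, count = 0.0, 0
    for start in range(0, len(hidden), chunk_size):
        h_chunk = hidden[start:start + chunk_size]
        y_chunk = labels[start:start + chunk_size]
        # Only this chunk's logits ([len(h_chunk)][V]) exist at any moment.
        logits_chunk: List[List[float]] = [[_dot(h, w) for w in weight] for h in h_chunk]
        for logits, y in zip(logits_chunk, y_chunk):
            if y == IGNORE_INDEX:
                continue
            m = max(logits)  # subtract the max for a numerically stable log-sum-exp
            lse = m + math.log(sum(math.exp(z - m) for z in logits))
            total += lse - logits[y]
            count += 1
    return total / count if count else 0.0
```

The chunk size only trades peak memory for loop overhead; the returned mean is the same for any chunk_size.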

Building from a config

from dsalt.model import DSALTConfig, DSALTLMHeadModel

cfg = DSALTConfig(
    vocab_size=50257, d_model=512, n_layers=6, n_heads=8,
    n_min=64, n_max=256, k_lmk=16, max_seq_len=1024,
)
model = DSALTLMHeadModel.from_config(cfg)
cfg.save("config.json")               # reload with DSALTConfig.load(...)

๐Ÿ—๏ธ Architecture Overview

DSALT combines a per-token local causal window with global landmark tokens selected per head:

┌─ Local window (adaptive) ──┬─ Global landmarks ──┐
│  Recent tokens up to       │  Top-k informative  │
│  window size               │  tokens per head    │
└────────────────────────────┴─────────────────────┘
                 ↓                      ↓
              Sparse attention output  (W(i) ∪ L(i))
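The key set W(i) ∪ L(i) can be sketched as below, under stated assumptions: W(i) is the last `window` positions up to i (the window is frozen to (n_min + n_max) // 2 in this release), landmarks are chosen once per head as the top-k of a per-token score, and query i only sees landmarks with j ≤ i. The real score comes from hybrid_scores_per_head and is not reproduced here; the scores in the usage line are made-up placeholders and both function names are hypothetical.

```python
from typing import List, Sequence


def select_landmarks(scores: Sequence[float], k_lmk: int) -> List[int]:
    """Indices of the k_lmk highest-scoring tokens for one head, ascending."""
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:k_lmk])


def attended_keys(i: int, window: int, landmarks: Sequence[int]) -> List[int]:
    """Key positions for query i: causal window W(i) plus causal landmarks L(i)."""
    local = set(range(max(0, i - window + 1), i + 1))  # W(i): recent positions up to i
    causal_lmk = {j for j in landmarks if j <= i}      # L(i): never attend to the future
    return sorted(local | causal_lmk)


lmk = select_landmarks([0.9, 0.1, 0.2, 0.05, 0.3, 0.8, 0.0, 0.4], k_lmk=2)
print(lmk, attended_keys(6, 3, lmk))  # → [0, 5] [0, 4, 5, 6]
```

A landmark that already falls inside the window is simply counted once, which is why the union (not a concatenation) is taken.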

Components:

  1. DSALTAttention: multi-head sparse attention over W(i) ∪ L(i) with RoPE/YaRN positions.
  2. hybrid_scores_per_head: the single source of the hybrid-energy landmark score (§4.3), shared by both the SDPA path and the Triton kernel.
  3. DSALTTransformerBlock / SwiGLUFFN: pre-norm block with a gated SwiGLU FFN.
  4. DSALTLMHeadModel: embeddings + block stack + RMSNorm + (tied) LM head.
  5. Triton kernels: fused forward (dsalt_triton_attention) and backward with online softmax and one-shot autotuned block sizes.
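The online softmax used by the Triton kernels is a general technique; the pure-Python sketch below shows it for a single query. Keys are processed in blocks while a running max, a running normaliser and a rescaled accumulator are kept, so the result matches an ordinary softmax without holding all scores at once. This is illustrative only: the block layout, masking and landmark bias of dsalt_triton_attention are not modelled, and the function names are hypothetical.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def _scores(q: Vec, keys: Sequence[Vec]) -> List[float]:
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(a * b for a, b in zip(q, k)) for k in keys]


def reference_attention(q: Vec, keys: Sequence[Vec], values: Sequence[Vec]) -> List[float]:
    """Ordinary softmax attention for one query (all scores held at once)."""
    s = _scores(q, keys)
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z for d in range(len(values[0]))]


def online_softmax_attention(q: Vec, keys: Sequence[Vec], values: Sequence[Vec],
                             block: int = 2) -> List[float]:
    """Same result, streaming over key blocks with a running max and normaliser."""
    m = -math.inf                  # running max score
    denom = 0.0                    # running sum of exp(score - m)
    acc = [0.0] * len(values[0])   # running un-normalised weighted sum of values
    for start in range(0, len(keys), block):
        s = _scores(q, keys[start:start + block])
        v_blk = values[start:start + block]
        m_new = max(m, max(s))
        corr = math.exp(m - m_new)  # rescales old stats; exp(-inf) == 0.0 on the first block
        p = [math.exp(x - m_new) for x in s]
        denom = denom * corr + sum(p)
        acc = [a * corr + sum(pj * v[d] for pj, v in zip(p, v_blk)) for d, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]
```

Because every block rescales the earlier partial sums by exp(m_old - m_new), the block size changes only the order of accumulation, not the result.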

Note. The local window is frozen to (n_min + n_max) // 2 in this release (no learnable window predictor). The learned adaptivity is the per-head


🎯 Training

DSALTTrainer drives single- and multi-GPU (DDP) training. It expects packed batches of the form (input_ids, labels, cu_seqlens, max_seqlen), where cu_seqlens is an int32 offset tensor of shape [num_seqs + 1]; label positions set to -100 are ignored by the loss.

from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

model = DSALTLMHeadModel(
    vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
    n_min=32, n_max=256, k_lmk=32, max_seq_len=1024,
)

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,   # yields (ids, labels, cu_seqlens, max_seqlen)
    val_loader=val_loader,
    lr=3e-4,
    total_steps=10_000,
    warmup_steps=1_000,
    mixed_precision="auto",      # bf16 on sm_80+, fp16 on T4-class, none on CPU
    save_dir="./checkpoints_dsalt",
    log_every=100,
)
trainer.train()
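The packed batch layout described above can be built as in the stdlib-only sketch below: sequences are concatenated, cu_seqlens holds num_seqs + 1 running offsets, and max_seqlen is the longest sequence. In a real loader you would convert the offsets with torch.tensor(cu, dtype=torch.int32). How labels should be aligned at sequence boundaries is not specified here, so the sketch leaves labels out; pack_sequences is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def pack_sequences(seqs: Sequence[Sequence[int]]) -> Tuple[List[int], List[int], int]:
    """Concatenate token sequences and build cu_seqlens offsets + max_seqlen."""
    input_ids: List[int] = []
    cu_seqlens = [0]
    for s in seqs:
        input_ids.extend(s)
        cu_seqlens.append(cu_seqlens[-1] + len(s))  # running end offset
    max_seqlen = max((len(s) for s in seqs), default=0)
    return input_ids, cu_seqlens, max_seqlen


ids, cu, max_len = pack_sequences([[5, 6, 7], [8, 9], [1, 2, 3, 4]])
# Sequence n occupies ids[cu[n]:cu[n + 1]]; len(cu) == num_seqs + 1.
print(ids, cu, max_len)  # → [5, 6, 7, 8, 9, 1, 2, 3, 4] [0, 3, 5, 9] 4
```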

Mixed precision

mixed_precision="auto" selects the dtype from the GPU's compute capability: bf16 on sm_80+ (A100/H100/L4/…), fp16 (with a GradScaler) below that (e.g. T4, sm_75), and no autocast on CPU. You can force "bf16", "fp16", or "none" explicitly.
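The documented "auto" rule is easy to restate as a function. This is a sketch assuming only the thresholds given above; pick_amp_mode is hypothetical and not part of DSALTTrainer. On a real machine the capability tuple would come from torch.cuda.get_device_capability(), which returns (major, minor).

```python
from typing import Optional, Tuple


def pick_amp_mode(device_type: str, capability: Optional[Tuple[int, int]] = None) -> str:
    """Restates the documented rule: bf16 on sm_80+, fp16 below, none on CPU."""
    if device_type != "cuda" or capability is None:
        return "none"
    major, _minor = capability
    # fp16 has a narrow exponent range, so it should be paired with a GradScaler.
    return "bf16" if major >= 8 else "fp16"


print(pick_amp_mode("cuda", (8, 0)), pick_amp_mode("cuda", (7, 5)), pick_amp_mode("cpu"))
# → bf16 fp16 none
```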

Multi-GPU (DDP)

Launch one process per GPU and pass the distributed identity through; the trainer wraps the model in DistributedDataParallel when world_size > 1:

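torchrun --nproc_per_node=2 your_train_script.py

# inside your_train_script.py; torchrun sets these environment variables
import os
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

trainer = DSALTTrainer(
    model=model, train_loader=train_loader, val_loader=val_loader,
    rank=rank, local_rank=local_rank, world_size=world_size,
    ddp_backend="nccl", total_steps=100_000,
)
trainer.train()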

Only DDP is supported in this release (no FSDP). The trainer also handles gradient accumulation, gradient clipping, cosine LR decay with warm-up, checkpointing (checkpoint_best/step_N/final.pt), and per-layer representation-health metrics. Resume with trainer.load_checkpoint(path).
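The cosine schedule with warm-up can be pictured with the function below, using the earlier example's lr=3e-4, warmup_steps=1_000 and total_steps=10_000. Assumptions: linear warm-up from zero, then cosine decay to a hypothetical min_lr (0 here). The trainer's exact curve, including any LR floor, may differ, and lr_at is not a dsalt function.

```python
import math


def lr_at(step: int, peak_lr: float = 3e-4, warmup_steps: int = 1_000,
          total_steps: int = 10_000, min_lr: float = 0.0) -> float:
    """Linear warm-up to peak_lr, then cosine decay to min_lr (assumed shape)."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)  # clamp past total_steps
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for s in (0, 500, 1_000, 5_500, 10_000):
    print(s, lr_at(s))
```

Halfway through the decay phase (step 5_500) the rate is halfway between peak_lr and min_lr.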


📚 API Reference

# Top-level exports
from dsalt import (
    DSALTConfig, DSALTLMHeadModel,
    DSALTAttention, DSALTTransformerBlock, SwiGLUFFN,
    DSALTTrainer,
    dsalt_triton_attention,            # None when Triton is unavailable
    hybrid_scores_per_head,            # single source of the landmark score
    sparse_attention_forward, sparse_attention_forward_packed,
    RMSENorm, compute_window_sizes, apply_rotary_emb, build_rope_cache,
)

# Low-level Triton kernel (packed sequences, CUDA + Triton only)
from dsalt.kernels import dsalt_triton_attention
out = dsalt_triton_attention(q, k, v, lmk_indices, lmk_bias, w_sizes, cu_seqlens)

q, k, v are [total_len, n_heads, head_dim]; cu_seqlens is the int32 sequence-offset tensor. See FEATURE.md for the full signature and semantics of every component.


📖 Hyperparameter Guide

Full, source-verified defaults for DSALTLMHeadModel, DSALTConfig, DSALTAttention, and DSALTTrainer live in FEATURE.md. Highlights:

| Component | Required | Notable defaults |
|---|---|---|
| DSALTLMHeadModel | vocab_size, d_model, n_layers, n_heads, n_min, n_max, k_lmk, max_seq_len | d_ff=None (→ 8/3·d_model), loss_fn="chunked", tie_weights=True, yarn_scale=1.0 |
| DSALTTrainer | model, train_loader, val_loader | lr=3e-4, max_grad_norm=0.5, warmup_steps=1000, mixed_precision="auto" |

alpha is a learnable per-head parameter (initialised so that sigmoid(alpha) ≈ 0.6), not a constructor argument. The auxiliary loss term is inert in this release (the window is frozen) and is kept only for signature compatibility.
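If "init sigmoid ≈ 0.6" means the raw per-head alpha starts at the value whose sigmoid is 0.6 (an interpretation, not confirmed by the source), that raw value is logit(0.6) = ln(0.6 / 0.4) = ln 1.5 ≈ 0.405. A quick check, with hypothetical helper names:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def logit(p: float) -> float:
    """Inverse of sigmoid: the raw value x with sigmoid(x) == p."""
    return math.log(p / (1.0 - p))


alpha_raw = logit(0.6)  # ln(1.5)
print(round(alpha_raw, 4), round(sigmoid(alpha_raw), 4))  # → 0.4055 0.6
```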


📄 License

Apache 2.0; see https://github.com/LeonardoCofone/dsalt-library/blob/main/LICENSE.


๐Ÿค Contributing

Contributions are welcome; see CONTRIBUTING.md. Especially valuable areas are Triton kernel optimisation, new architectures (encoder / encoder-decoder), additional training strategies, documentation, and bug fixes.


๐Ÿ“ Citation

If you use DSALT in your research, please cite the paper:

@software{dsalt,
  author  = {Cofone, Leonardo},
  title   = {DSALT: Dynamic Sparse Attention with Landmark Tokens},
  url     = {https://github.com/LeonardoCofone/dsalt-library},
  note    = {https://zenodo.org/records/19312826},
}


Download files

Source Distribution

dsalt-0.4.11.tar.gz (71.1 kB)

Built Distribution

dsalt-0.4.11-py3-none-any.whl (75.9 kB)

File details

Details for the file dsalt-0.4.11.tar.gz.

File metadata

  • Download URL: dsalt-0.4.11.tar.gz
  • Upload date:
  • Size: 71.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.4.11.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | cae6e5366fa6c3db79c7a8911cfff212fd788f8ce74a99b7d42b1a57389bdc4f |
| MD5 | 1b16dc86b98012ab410c204fa9a92f38 |
| BLAKE2b-256 | 5b9329c7b91153761551f68c2ac334f7f2aab26e2ba3b61c8d7c8b84cb8e71b2 |
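To check a downloaded file against the SHA256 digests listed here, hashing it in chunks with Python's standard hashlib is enough. A minimal sketch; the function names are illustrative, and the usage comment simply reuses the sdist file name and digest listed above.

```python
import hashlib
import io
from typing import BinaryIO


def sha256_of_stream(stream: BinaryIO, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 of a binary stream, read chunk_size bytes at a time."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def verify_file(path: str, expected_hex: str) -> bool:
    """True if the file at `path` matches the published SHA256 digest."""
    with open(path, "rb") as f:
        return sha256_of_stream(f) == expected_hex.strip().lower()


# Sanity check against the standard test vector SHA-256("abc").
assert sha256_of_stream(io.BytesIO(b"abc")) == (
    "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
)
# Usage: verify_file("dsalt-0.4.11.tar.gz",
#                    "cae6e5366fa6c3db79c7a8911cfff212fd788f8ce74a99b7d42b1a57389bdc4f")
```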

File details

Details for the file dsalt-0.4.11-py3-none-any.whl.

File metadata

  • Download URL: dsalt-0.4.11-py3-none-any.whl
  • Upload date:
  • Size: 75.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.4.11-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | f72f37ab71324b5b70bbed3fe85334d0d560dab2af92a1490eff1d76b864ed11 |
| MD5 | eff8cac237840b2a49e9dac999bbff0e |
| BLAKE2b-256 | 5430b1ad5e17d0a545186d2c0496a615e32a462552746bac7907192dadfe6024 |
