Dynamic Sparse Attention with Landmark Tokens - High-performance Triton implementation

DSALT: Dynamic Sparse Attention with Landmark Tokens

DSALT is a PyTorch library implementing Dynamic Sparse Attention with Landmark Tokens, a memory-efficient attention mechanism for transformers. Each query attends to an adaptive local causal window plus a small set of global landmark tokens, instead of the full O(N²) set. On CUDA it runs custom Triton kernels; everywhere else it falls back to a masked SDPA path, so the package stays importable and runnable on any platform (including CPU and Windows).

Install: pip install dsalt
Source: https://github.com/LeonardoCofone/dsalt-library
Paper: https://zenodo.org/records/19312826

✨ Key Features

  • Sparse attention: local adaptive window ∪ top-k landmark tokens per head.
  • GPU-portable: Triton kernels on CUDA, transparent SDPA fallback otherwise; the AMP dtype is auto-detected per GPU generation (bf16 only where natively supported, fp16 on T4-class cards).
  • One-shot autotune: Triton block sizes are benchmarked once per (head_dim, GPU) at the first launch, then reused for the whole run; a heuristic fallback is used if benchmarking is impossible.
  • Packed-sequence training: concatenated sequences + cu_seqlens, fused forward/backward with online softmax.
  • Fused cross-entropy: optional Liger fused linear cross-entropy, or a memory-frugal chunked pure-PyTorch loss.
  • DDP training: single- and multi-GPU via DistributedDataParallel, gradient accumulation, cosine schedule with warm-up, checkpointing, and rich representation-health diagnostics.

🛠️ Installation

Requirements

  • Python 3.10+ (the codebase uses X | None / tuple[...] syntax)
  • PyTorch 2.0+
  • CUDA 11.0+ for the GPU path (CPU fallback always available)
  • Triton 2.0+ (optional; enables the GPU kernels, Linux/CUDA)

From PyPI

pip install dsalt                 # core
pip install "dsalt[triton]"       # + Triton GPU kernels
pip install "dsalt[dev]"          # + lint/type/test tooling

From source

git clone https://github.com/LeonardoCofone/dsalt-library.git
cd dsalt-library
pip install -e .

🚀 Quick Start

Inference

import torch
from dsalt.model import DSALTLMHeadModel

model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,
    n_max=512,
    k_lmk=64,
    max_seq_len=2048,   # required
)

input_ids = torch.randint(0, 32000, (1, 1024))   # [batch, seq_len]
out = model(input_ids)                           # dict
logits = out["logits"]                           # [1, 1024, 32000]
print(logits.shape)

Computing the loss

When labels are given, the forward pass computes the loss internally (fused) and returns a dict with the keys "loss", "logits", and "aux_loss":

labels = torch.randint(0, 32000, (1, 1024))
out  = model(input_ids, labels=labels)
loss = out["loss"]          # logits is None here (loss is fused)
loss.backward()
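The idea behind a chunked loss (the memory-frugal option listed under Key Features) can be sketched in plain Python. This is an illustration only, not the library's implementation: the names `token_nll` and `chunked_cross_entropy` and the `chunk_size` parameter are hypothetical, and the sketch assumes ignored targets carry the label -100. It processes a few rows of logits at a time and accumulates a running loss sum and a count of valid targets.

```python
import math

IGNORE_INDEX = -100  # assumed label value for positions excluded from the loss


def token_nll(logits_row, target):
    """Negative log-likelihood of `target` under softmax(logits_row)."""
    m = max(logits_row)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return lse - logits_row[target]


def chunked_cross_entropy(logits, labels, chunk_size=2):
    """Mean NLL over valid targets, visiting `chunk_size` rows at a time."""
    total, count = 0.0, 0
    for start in range(0, len(labels), chunk_size):
        rows = logits[start:start + chunk_size]
        targets = labels[start:start + chunk_size]
        for row, y in zip(rows, targets):
            if y == IGNORE_INDEX:
                continue  # ignored positions contribute nothing
            total += token_nll(row, y)
            count += 1
    return total / max(count, 1)


uniform_logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
loss = chunked_cross_entropy(uniform_logits, [1, IGNORE_INDEX, 2])
print(loss)  # uniform logits over 4 classes give log(4)
```

In a tensor framework the saving would come from materialising only one chunk of logits at a time; the result does not depend on the chunk size.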

Building from a config

from dsalt.model import DSALTConfig, DSALTLMHeadModel

cfg = DSALTConfig(
    vocab_size=50257, d_model=512, n_layers=6, n_heads=8,
    n_min=64, n_max=256, k_lmk=16, max_seq_len=1024,
)
model = DSALTLMHeadModel.from_config(cfg)
cfg.save("config.json")               # reload with DSALTConfig.load(...)

🏗️ Architecture Overview

DSALT combines a per-token local causal window with global landmark tokens selected per head:

┌─ Local window (adaptive) ──┬─ Global landmarks ──┐
│  Recent tokens up to       │  Top-k informative  │
│  window size               │  tokens per head    │
└────────────────────────────┴─────────────────────┘
                 ↓                      ↓
              Sparse attention output  (W(i) ∪ L(i))
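As a rough illustration of the attended set W(i) ∪ L(i), the sketch below lists the key positions one query may see and compares the number of attended pairs with full causal attention. The helper `attended_positions` is hypothetical, and two details are assumptions rather than facts from the source: the window is taken to include position i itself, and landmarks are restricted to positions at or before i.

```python
def attended_positions(i, window, landmarks):
    """Key positions visible to query i: local causal window plus landmarks.

    Assumptions (not confirmed by the source): the window includes position i,
    and only landmarks at positions <= i are used, to stay causal.
    """
    local = set(range(max(0, i - window + 1), i + 1))   # W(i)
    global_ = {j for j in landmarks if j <= i}          # L(i)
    return sorted(local | global_)                      # W(i) ∪ L(i)


seq_len, window, landmarks = 16, 4, [0, 5, 9]

# Count attended (query, key) pairs for sparse vs. full causal attention.
sparse_pairs = sum(len(attended_positions(i, window, landmarks)) for i in range(seq_len))
dense_pairs = seq_len * (seq_len + 1) // 2

print(attended_positions(10, window, landmarks))  # [0, 5, 7, 8, 9, 10]
print(sparse_pairs, dense_pairs)                  # 80 136
```

With a fixed window and a fixed number of landmarks, the sparse count grows linearly in the sequence length, while the dense count grows quadratically.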

Components:

  1. DSALTAttention: multi-head sparse attention over W(i) ∪ L(i) with RoPE/YaRN positions.
  2. hybrid_scores_per_head: the single source of the hybrid-energy landmark score (§4.3), shared by both the SDPA path and the Triton kernel.
  3. DSALTTransformerBlock / SwiGLUFFN: pre-norm block with a gated SwiGLU FFN.
  4. DSALTLMHeadModel: embeddings + block stack + RMSNorm + (tied) LM head.
  5. Triton kernels: fused forward (dsalt_triton_attention) and backward with online softmax and one-shot autotuned block sizes.
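The online softmax named in the kernel description can be sketched for a single query with scalar values. This is the generic one-pass technique, not the library's Triton code; the function names are hypothetical. A running maximum, a running normaliser, and a running weighted sum are kept, and the older statistics are rescaled whenever a larger score appears.

```python
import math


def online_softmax_weighted_sum(scores, values):
    """Softmax-weighted sum of `values`, computed in one pass over the keys."""
    m = float("-inf")   # running max of the scores seen so far
    denom = 0.0         # running sum of exp(score - m)
    acc = 0.0           # running sum of exp(score - m) * value
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new)   # rescales old stats; 0.0 on the first step
        p = math.exp(s - m_new)
        denom = denom * scale + p
        acc = acc * scale + p * v
        m = m_new
    return acc / denom


def two_pass_reference(scores, values):
    """Standard softmax-weighted sum, needing all scores up front."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


scores = [1.0, 3.0, -2.0, 0.5]
values = [10.0, 20.0, 30.0, 40.0]
online = online_softmax_weighted_sum(scores, values)
reference = two_pass_reference(scores, values)
print(online, reference)
```

Because only three running quantities are carried between steps, the keys can be processed block by block without materialising the full score row.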

Note. The local window is frozen to (n_min + n_max) // 2 in this release (no learnable window predictor). The learned adaptivity is the per-head


🎯 Training

DSALTTrainer drives single- and multi-GPU (DDP) training. It expects packed batches: (input_ids, labels, cu_seqlens, max_seqlen), where cu_seqlens is an int32 offset tensor of shape [num_seqs + 1] and -100 labels are ignored.

from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

model = DSALTLMHeadModel(
    vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
    n_min=32, n_max=256, k_lmk=32, max_seq_len=1024,
)

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,   # yields (ids, labels, cu_seqlens, max_seqlen)
    val_loader=val_loader,
    lr=3e-4,
    total_steps=10_000,
    warmup_steps=1_000,
    mixed_precision="auto",      # bf16 on sm_80+, fp16 on T4-class, none on CPU
    save_dir="./checkpoints_dsalt",
    log_every=100,
)
trainer.train()
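The packed batch layout that the trainer expects can be sketched with plain lists. The helper `pack_sequences` is hypothetical, and the next-token shift of the labels is an assumption made for illustration: the source does not state whether labels should be shifted before they reach the model.

```python
from itertools import accumulate

IGNORE_INDEX = -100  # label value the trainer ignores, per the text above


def pack_sequences(sequences):
    """Concatenate token-id lists and build labels plus cumulative offsets."""
    input_ids, labels = [], []
    for seq in sequences:
        input_ids.extend(seq)
        # Assumed convention: next-token targets inside each sequence; the
        # last token of a sequence has no target and is ignored.
        labels.extend(seq[1:] + [IGNORE_INDEX])
    lengths = [len(seq) for seq in sequences]
    cu_seqlens = [0] + list(accumulate(lengths))   # shape [num_seqs + 1]
    max_seqlen = max(lengths)
    return input_ids, labels, cu_seqlens, max_seqlen


ids, labels, cu_seqlens, max_seqlen = pack_sequences([[5, 6, 7], [8, 9], [1, 2, 3, 4]])
print(ids)         # [5, 6, 7, 8, 9, 1, 2, 3, 4]
print(labels)      # [6, 7, -100, 9, -100, 2, 3, 4, -100]
print(cu_seqlens)  # [0, 3, 5, 9]
print(max_seqlen)  # 4
```

For real training, the lists would be converted to tensors, for example `torch.tensor(cu_seqlens, dtype=torch.int32)` for the offsets.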

Mixed precision

mixed_precision="auto" selects the dtype from the GPU's compute capability: bf16 on sm_80+ (A100/H100/L4/…), fp16 (with a GradScaler) below that (e.g. T4 sm_75), and no autocast on CPU. You can force "bf16", "fp16", or "none" explicitly.
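The selection rule described above can be written as a small decision function. `pick_amp_dtype` is a hypothetical helper for illustration, not the library's own code; it only encodes the mapping stated in the text.

```python
def pick_amp_dtype(cuda_available, capability=None):
    """Map a (major, minor) compute capability to a precision mode.

    Hypothetical helper: mirrors the rule in the text, nothing more.
    """
    if not cuda_available or capability is None:
        return "none"                  # CPU: no autocast
    major, _minor = capability
    if major >= 8:
        return "bf16"                  # sm_80 and newer
    return "fp16"                      # older GPUs, paired with a GradScaler


print(pick_amp_dtype(True, (8, 0)))    # bf16
print(pick_amp_dtype(True, (7, 5)))    # fp16
print(pick_amp_dtype(False))           # none
```

With PyTorch installed, the inputs could come from `torch.cuda.is_available()` and `torch.cuda.get_device_capability()`, which returns a `(major, minor)` tuple.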

Multi-GPU (DDP)

Launch one process per GPU and pass the distributed identity through; the trainer wraps the model in DistributedDataParallel when world_size > 1:

torchrun --nproc_per_node=2 your_train_script.py

trainer = DSALTTrainer(
    model=model, train_loader=train_loader, val_loader=val_loader,
    rank=rank, local_rank=local_rank, world_size=world_size,
    ddp_backend="nccl", total_steps=100_000,
)
trainer.train()
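The snippet above uses rank, local_rank, and world_size without defining them. One common way to obtain them, sketched here under the assumption that the script is launched with torchrun, is to read the RANK, LOCAL_RANK, and WORLD_SIZE environment variables that torchrun exports to each worker. The helper name `distributed_identity` is hypothetical.

```python
import os


def distributed_identity(env=os.environ):
    """Read the process identity that torchrun exports to each worker.

    Falls back to a single-process identity when the variables are absent.
    """
    rank = int(env.get("RANK", 0))              # global process index
    local_rank = int(env.get("LOCAL_RANK", 0))  # index on this machine (GPU id)
    world_size = int(env.get("WORLD_SIZE", 1))  # total number of processes
    return rank, local_rank, world_size


rank, local_rank, world_size = distributed_identity()
print(rank, local_rank, world_size)
```

Run without torchrun, the defaults describe a single process, which matches the non-DDP path (world_size of 1).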

Only DDP is supported in this release (no FSDP). The trainer also handles gradient accumulation, gradient clipping, cosine LR decay with warm-up, checkpointing (checkpoint_best/step_N/final.pt), and per-layer representation-health metrics. Resume with trainer.load_checkpoint(path).
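A cosine schedule with warm-up can be sketched as a function of the step. The exact shape the trainer uses is not given in the source; this version assumes a linear warm-up from zero followed by cosine decay to a floor `min_lr`, and both `lr_at` and `min_lr` are hypothetical names.

```python
import math


def lr_at(step, base_lr=3e-4, warmup_steps=1_000, total_steps=10_000, min_lr=0.0):
    """Learning rate at `step`: linear warm-up, then cosine decay (assumed shape)."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for step in (0, 500, 1_000, 5_500, 10_000):
    print(step, lr_at(step))
```

With the values from the training example (lr=3e-4, warmup_steps=1_000, total_steps=10_000), the rate peaks at step 1,000 and is halved at the midpoint of the decay phase, step 5,500.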


📚 API Reference

# Top-level exports
from dsalt import (
    DSALTConfig, DSALTLMHeadModel,
    DSALTAttention, DSALTTransformerBlock, SwiGLUFFN,
    DSALTTrainer,
    dsalt_triton_attention,            # None when Triton is unavailable
    hybrid_scores_per_head,            # single source of the landmark score
    sparse_attention_forward, sparse_attention_forward_packed,
    RMSENorm, compute_window_sizes, apply_rotary_emb, build_rope_cache,
)

# Low-level Triton kernel (packed sequences, CUDA + Triton only)
from dsalt.kernels import dsalt_triton_attention
out = dsalt_triton_attention(q, k, v, lmk_indices, lmk_bias, w_sizes, cu_seqlens)

q, k, v are [total_len, n_heads, head_dim]; cu_seqlens is the int32 sequence-offset tensor. See FEATURE.md for the full signature and semantics of every component.
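To make the packed layout concrete, the sketch below maps a flat token index in [0, total_len) back to its sequence and its position within that sequence, using only the offsets in cu_seqlens. The helpers are hypothetical and use plain lists; a packed kernel needs an equivalent lookup so that a query never attends across a sequence boundary, which is an assumption about the intended semantics rather than a statement from the source.

```python
from bisect import bisect_right


def sequence_of(token_index, cu_seqlens):
    """Index of the packed sequence that contains `token_index`."""
    return bisect_right(cu_seqlens, token_index) - 1


def position_in_sequence(token_index, cu_seqlens):
    """Offset of `token_index` from the start of its own sequence."""
    return token_index - cu_seqlens[sequence_of(token_index, cu_seqlens)]


cu_seqlens = [0, 3, 8, 10]   # three sequences of lengths 3, 5, 2
total_len = cu_seqlens[-1]   # first dimension of q, k, v

for t in range(total_len):
    print(t, sequence_of(t, cu_seqlens), position_in_sequence(t, cu_seqlens))
```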


📖 Hyperparameter Guide

Full, source-verified defaults for DSALTLMHeadModel, DSALTConfig, DSALTAttention, and DSALTTrainer live in FEATURE.md. Highlights:

Component | Required | Notable defaults
DSALTLMHeadModel | vocab_size, d_model, n_layers, n_heads, n_min, n_max, k_lmk, max_seq_len | d_ff=None (→ 8/3·d_model), loss_fn="chunked", tie_weights=True, yarn_scale=1.0
DSALTTrainer | model, train_loader, val_loader | lr=3e-4, max_grad_norm=0.5, warmup_steps=1000, mixed_precision="auto"

alpha is a learnable per-head parameter (init sigmoid ≈ 0.6), not a constructor flag. The auxiliary loss term is inert in this release (frozen window) and kept only for signature compatibility.
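A few of these numbers can be worked out by hand. The sketch below uses the Quick Start values (n_min=32, n_max=512, d_model=1024) and the frozen-window formula from the architecture note. Two points are assumptions: that the stored alpha is a raw value passed through a sigmoid, and that any rounding of the default FFN width happens inside the library in a way the source does not describe.

```python
import math


def frozen_window(n_min, n_max):
    """Window size used in this release: the midpoint of n_min and n_max."""
    return (n_min + n_max) // 2


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    """Inverse of the sigmoid."""
    return math.log(p / (1.0 - p))


window = frozen_window(32, 512)   # 272
d_ff_raw = 8 * 1024 / 3           # about 2730.67, before any rounding
alpha_raw = logit(0.6)            # about 0.405; assumed raw value behind alpha

print(window, round(d_ff_raw, 2), round(alpha_raw, 3), sigmoid(alpha_raw))
```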


📄 License

Apache 2.0, see https://github.com/LeonardoCofone/dsalt-library/blob/main/LICENSE.


🤝 Contributing

Contributions are welcome; see CONTRIBUTING.md. Especially valuable are Triton kernel optimisation, new architectures (encoder / encoder-decoder), additional training strategies, documentation, and bug fixes.


📝 Citation

If you use DSALT in your research, please cite the paper:

@software{dsalt,
  author  = {Cofone, Leonardo},
  title   = {DSALT: Dynamic Sparse Attention with Landmark Tokens},
  url     = {https://github.com/LeonardoCofone/dsalt-library},
  note    = {https://zenodo.org/records/19312826},
}
