
Dynamic Sparse Attention with Landmark Tokens - High-performance Triton implementation


DSALT: Dynamic Sparse Attention with Landmark Tokens


DSALT is a PyTorch library implementing Dynamic Sparse Attention with Landmark Tokens, a memory-efficient attention mechanism for transformers. Each query attends to an adaptive local causal window plus a small set of global landmark tokens rather than to every previous token, which avoids the O(N²) cost of dense attention. On CUDA it runs custom Triton kernels; everywhere else it falls back to a masked SDPA (scaled_dot_product_attention) path, so the package stays importable and runnable on any platform, including CPU and Windows.
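To put the O(N²) claim in numbers, the plain-Python sketch below counts the (query, key) pairs scored by dense causal attention and an upper bound for the sparse pattern. It assumes each query sees at most window + k_lmk keys and never more than the i + 1 causal positions available to it, and it takes window = (n_min + n_max) // 2, the frozen value noted under Architecture Overview. The helper names are hypothetical and not part of the dsalt API.

```python
# Hypothetical helpers (not dsalt API): compare dense causal attention with an
# upper bound for a sparse "local window + k landmarks" pattern.

def dense_causal_pairs(n: int) -> int:
    """Number of (query, key) pairs with key <= query in a length-n sequence."""
    return n * (n + 1) // 2


def sparse_pairs_upper_bound(n: int, window: int, k_lmk: int) -> int:
    """Upper bound when query i attends to at most min(i + 1, window + k_lmk) keys."""
    budget = window + k_lmk
    return sum(min(i + 1, budget) for i in range(n))


if __name__ == "__main__":
    n_min, n_max, k_lmk = 32, 512, 64      # Quick Start settings
    window = (n_min + n_max) // 2          # frozen window in this release
    for n in (1024, 8192):
        dense = dense_causal_pairs(n)
        sparse = sparse_pairs_upper_bound(n, window, k_lmk)
        print(f"N={n}: dense={dense}, sparse<={sparse}, ratio<={sparse / dense:.3f}")
```

With the Quick Start settings the bound is roughly 55% of the dense pair count at N = 1024 and roughly 8% at N = 8192: the per-query budget stays fixed while dense attention keeps growing with N.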

Install: pip install dsalt
Source: https://github.com/LeonardoCofone/dsalt-library
Paper: https://zenodo.org/records/19312826

✨ Key Features

  • Sparse attention: a local adaptive window ∪ the top-k landmark tokens, per head.
  • GPU-portable: Triton kernels on CUDA with a transparent SDPA fallback elsewhere; automatic mixed-precision (AMP) detection that is correct across GPU generations (bf16 only where natively supported, fp16 on T4-class cards).
  • One-shot autotune: Triton block sizes are benchmarked once per (head_dim, GPU) at the first launch and reused for the rest of the run, with a heuristic fallback if benchmarking is impossible.
  • Packed-sequence training: concatenated sequences plus cu_seqlens offsets, with a fused forward/backward pass using online softmax.
  • Fused cross-entropy: optional Liger fused linear cross-entropy, or a memory-frugal chunked pure-PyTorch loss.
  • DDP training: single- and multi-GPU training via DistributedDataParallel, with gradient accumulation, a cosine schedule with warm-up, checkpointing, and per-layer representation-health diagnostics.
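The "one-shot autotune" idea (benchmark once per key, then reuse) can be sketched as a small cache. Everything here is hypothetical: the candidate block sizes, the heuristic rule, and the function names are illustrations, not DSALT's actual autotuner, which lives inside its Triton kernels. The benchmark is passed in as a callable returning a runtime.

```python
from typing import Callable

CANDIDATES = (32, 64, 128)                 # hypothetical block-size choices
_CACHE: dict[tuple[int, str], int] = {}    # (head_dim, gpu_name) -> block size


def heuristic_block_size(head_dim: int) -> int:
    """Hypothetical fallback rule used when benchmarking fails."""
    return 64 if head_dim <= 64 else 32


def autotuned_block_size(head_dim: int, gpu_name: str,
                         bench: Callable[[int], float]) -> int:
    """Benchmark candidates on the first call for a key; reuse the winner afterwards."""
    key = (head_dim, gpu_name)
    if key not in _CACHE:
        try:
            # bench(block) returns a runtime; the fastest candidate wins.
            _CACHE[key] = min(CANDIDATES, key=bench)
        except Exception:
            # Benchmarking impossible: fall back to the heuristic and cache it too.
            _CACHE[key] = heuristic_block_size(head_dim)
    return _CACHE[key]
```

Caching the heuristic result as well means a failing benchmark is not retried on every launch; whether DSALT does the same is not stated here.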


🛠️ Installation

Requirements

  • Python 3.10+ (the codebase uses X | None / tuple[...] syntax)
  • PyTorch 2.0+
  • CUDA 11.0+ for the GPU path (CPU fallback always available)
  • Triton 2.0+ (optional; enables the GPU kernels on Linux with CUDA)

From PyPI

pip install dsalt                 # core
pip install "dsalt[triton]"       # + Triton GPU kernels
pip install "dsalt[dev]"          # + lint/type/test tooling

From source

git clone https://github.com/LeonardoCofone/dsalt-library.git
cd dsalt-library
pip install -e .

🚀 Quick Start

Inference

import torch
from dsalt.model import DSALTLMHeadModel

model = DSALTLMHeadModel(
    vocab_size=32000,
    d_model=1024,
    n_layers=24,
    n_heads=16,
    n_min=32,
    n_max=512,
    k_lmk=64,
    max_seq_len=2048,   # required
)
model.eval()                                      # evaluation mode for inference

input_ids = torch.randint(0, 32000, (1, 1024))   # [batch, seq_len]
with torch.no_grad():                             # no autograd graph needed
    out = model(input_ids)                        # dict
logits = out["logits"]                           # [1, 1024, 32000]
print(logits.shape)

Computing the loss

When labels are given, the forward pass computes the loss internally (fused) and returns a dict with the keys "loss", "logits", and "aux_loss":

model.train()                # training mode before computing gradients
labels = torch.randint(0, 32000, (1, 1024))
out  = model(input_ids, labels=labels)
loss = out["loss"]          # logits is None here (loss is fused)
loss.backward()
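Because the loss is fused, the full [tokens, vocab] logits matrix never has to exist at once. The plain-Python sketch below shows the general chunked cross-entropy technique behind the default loss_fn="chunked": logits are produced for a few tokens at a time, reduced to a log-sum-exp, and discarded. It is an illustration only, assuming labels are already aligned with positions and that -100 marks ignored tokens; it is not DSALT's implementation, and the names are hypothetical.

```python
import math

IGNORE_INDEX = -100  # label value that the loss skips


def _logsumexp(xs: list[float]) -> float:
    m = max(xs)  # subtract the max for numerical stability
    return m + math.log(sum(math.exp(x - m) for x in xs))


def chunked_cross_entropy(hiddens: list[list[float]], head: list[list[float]],
                          labels: list[int], chunk_size: int = 2) -> float:
    """Mean token NLL; `head` is a [vocab, d_model] LM-head weight matrix."""
    total, count = 0.0, 0
    for start in range(0, len(hiddens), chunk_size):
        chunk_h = hiddens[start:start + chunk_size]
        chunk_y = labels[start:start + chunk_size]
        # Only this chunk's [chunk, vocab] logits are ever materialised.
        logits = [[sum(a * b for a, b in zip(vec, row)) for row in head]
                  for vec in chunk_h]
        for z, y in zip(logits, chunk_y):
            if y == IGNORE_INDEX:
                continue
            total += _logsumexp(z) - z[y]   # -log softmax(z)[y]
            count += 1
    return total / count if count else 0.0  # this sketch returns 0.0 if all are ignored
```

The chunk size changes only peak memory, not the result: any chunk_size yields the same mean loss up to floating-point rounding.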

Building from a config

from dsalt.model import DSALTConfig, DSALTLMHeadModel

cfg = DSALTConfig(
    vocab_size=50257, d_model=512, n_layers=6, n_heads=8,
    n_min=64, n_max=256, k_lmk=16, max_seq_len=1024,
)
model = DSALTLMHeadModel.from_config(cfg)
cfg.save("config.json")               # reload with DSALTConfig.load(...)

🏗️ Architecture Overview

DSALT combines a per-token local causal window with global landmark tokens selected per head:

┌─ Local window (adaptive) ──┬─ Global landmarks ──┐
│  Recent tokens up to       │  Top-k informative  │
│  window size               │  tokens per head    │
└────────────────────────────┴─────────────────────┘
                 ↓                      ↓
              Sparse attention output  (W(i) ∪ L(i))
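For a single query position i, the union W(i) ∪ L(i) in the diagram can be sketched as a set of key positions. This sketch assumes the window holds the `window` most recent positions (including i) and that landmarks located after i are dropped to keep attention causal; `landmarks` stands in for one head's top-k selection, and the function name is hypothetical.

```python
def attended_keys(i: int, window: int, landmarks: list[int]) -> list[int]:
    """Key positions visible to query i: local window W(i) plus causal landmarks L(i)."""
    local = set(range(max(0, i - window + 1), i + 1))   # W(i): recent tokens
    global_ = {j for j in landmarks if j <= i}          # L(i): landmarks not in the future
    return sorted(local | global_)                      # a position in both is counted once


if __name__ == "__main__":
    # Query 10, window of 4, landmarks at 0, 3, 8, 15 (15 is in the future).
    print(attended_keys(10, 4, [0, 3, 8, 15]))
```

Here position 8 is both a landmark and inside the window, so the union keeps a single copy, and the future landmark at 15 is excluded.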

Components:

  1. DSALTAttention: multi-head sparse attention over W(i) ∪ L(i), the local window and the landmark set of query i, with RoPE/YaRN positions.
  2. hybrid_scores_per_head: the single source of the hybrid-energy landmark score (§4.3), shared by the SDPA path and the Triton kernel.
  3. DSALTTransformerBlock / SwiGLUFFN: a pre-norm block with a gated SwiGLU FFN.
  4. DSALTLMHeadModel: embeddings + block stack + RMSNorm + (tied) LM head.
  5. Triton kernels: fused forward (dsalt_triton_attention) and backward passes with online softmax and one-shot autotuned block sizes.
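The online softmax used by the fused kernels lets attention be computed block by block over keys without ever storing the full score row. The sketch below shows the standard technique for one query and scalar values (think of one head_dim component): keep a running max, a running denominator, and a running weighted sum, and rescale both sums whenever the max grows. It is a plain-Python illustration of the method, not DSALT's Triton code.

```python
import math


def online_softmax_attention(scores: list[float], values: list[float],
                             block: int = 2) -> float:
    """softmax(scores) . values, streamed over key blocks of size `block`."""
    m = -math.inf   # running max of the scores seen so far
    denom = 0.0     # running sum of exp(score - m)
    acc = 0.0       # running sum of exp(score - m) * value
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # rescale old partial sums to the new max
        denom = denom * scale + sum(math.exp(s - m_new) for s in s_blk)
        acc = acc * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_blk, v_blk))
        m = m_new
    return acc / denom
```

On the first block the running max is -inf, so the rescale factor is exp(-inf) = 0 and the empty partial sums contribute nothing. The result is identical to a direct softmax for any block size.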

Note: the local window is frozen to (n_min + n_max) // 2 in this release (there is no learnable window predictor); the learned adaptivity lies in the per-head alpha of the landmark-energy score. See COSE_CAMBIATE_DALLA_TEORIA.md (Italian for "changes from the theory").


🎯 Training

DSALTTrainer drives single- and multi-GPU (DDP) training. It expects packed batches of the form (input_ids, labels, cu_seqlens, max_seqlen), where cu_seqlens is an int32 offset tensor of shape [num_seqs + 1]; label positions set to -100 are ignored by the loss.

from dsalt.model import DSALTLMHeadModel
from dsalt.training import DSALTTrainer

model = DSALTLMHeadModel(
    vocab_size=32000, d_model=768, n_layers=12, n_heads=12,
    n_min=32, n_max=256, k_lmk=32, max_seq_len=1024,
)

trainer = DSALTTrainer(
    model=model,
    train_loader=train_loader,   # yields (ids, labels, cu_seqlens, max_seqlen)
    val_loader=val_loader,
    lr=3e-4,
    total_steps=10_000,
    warmup_steps=1_000,
    mixed_precision="auto",      # bf16 on sm_80+, fp16 on T4-class, none on CPU
    save_dir="./checkpoints_dsalt",
    log_every=100,
)
trainer.train()

Mixed precision

mixed_precision="auto" selects the dtype from the GPU's compute capability: bf16 on sm_80+ (A100/H100/L4/…), fp16 (with a GradScaler) below that (e.g. T4, sm_75), and no autocast on CPU. You can force "bf16", "fp16", or "none" explicitly.
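The "auto" rule just described reduces to a check on the major compute-capability version. The sketch below restates it as a pure function; the name select_amp_mode is hypothetical. Its `capability` argument mirrors the (major, minor) tuple returned by torch.cuda.get_device_capability(), with None standing for "no CUDA device".

```python
from typing import Optional, Tuple


def select_amp_mode(capability: Optional[Tuple[int, int]]) -> str:
    """Restates the described rule: bf16 on sm_80+, fp16 below, none on CPU."""
    if capability is None:
        return "none"                       # CPU: no autocast
    major, _minor = capability
    return "bf16" if major >= 8 else "fp16"  # fp16 is paired with a GradScaler
```

In a real script the argument would come from torch.cuda.get_device_capability() when torch.cuda.is_available() is true, and None otherwise.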

Multi-GPU (DDP)

Launch one process per GPU and pass the distributed identity through; the trainer wraps the model in DistributedDataParallel when world_size > 1:

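torchrun --nproc_per_node=2 your_train_script.py

Inside the script, read the distributed identity that torchrun exports as environment variables:

import os
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

trainer = DSALTTrainer(
    model=model, train_loader=train_loader, val_loader=val_loader,
    rank=rank, local_rank=local_rank, world_size=world_size,
    ddp_backend="nccl", total_steps=100_000,
)
trainer.train()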

Only DDP is supported in this release (no FSDP). The trainer also handles gradient accumulation, gradient clipping, cosine LR decay with warm-up, checkpointing (checkpoint_best/step_N/final.pt), and per-layer representation-health metrics. Resume with trainer.load_checkpoint(path).
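A cosine LR schedule with warm-up typically has the shape sketched below. The exact form DSALTTrainer uses (warm-up starting point, final LR) is not stated here, so this assumes a linear warm-up from 0 to lr over warmup_steps followed by a cosine decay to a hypothetical min_lr of 0 at total_steps.

```python
import math


def lr_at(step: int, base_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Assumed shape: linear warm-up to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)   # 0 -> 1
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

With the values from the training example (lr=3e-4, warmup_steps=1_000, total_steps=10_000), the LR is half its peak at step 500, peaks at step 1000, is back at half the peak at step 5500, and reaches 0 at step 10000.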


📚 API Reference

# Top-level exports
from dsalt import (
    DSALTConfig, DSALTLMHeadModel,
    DSALTAttention, DSALTTransformerBlock, SwiGLUFFN,
    DSALTTrainer,
    dsalt_triton_attention,            # None when Triton is unavailable
    hybrid_scores_per_head,            # single source of the landmark score
    sparse_attention_forward, sparse_attention_forward_packed,
    RMSENorm, compute_window_sizes, apply_rotary_emb, build_rope_cache,
)

# Low-level Triton kernel (packed sequences, CUDA + Triton only)
from dsalt.kernels import dsalt_triton_attention
out = dsalt_triton_attention(q, k, v, lmk_indices, lmk_bias, w_sizes, cu_seqlens)

q, k, v are [total_len, n_heads, head_dim]; cu_seqlens is the int32 sequence-offset tensor. See FEATURE.md for the full signature and semantics of every component.
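To make the cu_seqlens layout concrete: with packed q, k, v of shape [total_len, n_heads, head_dim], sequence s occupies rows cu_seqlens[s]:cu_seqlens[s + 1]. The sketch below maps a packed row to its sequence and checks the usual variable-length convention (assumed here, as in other varlen attention kernels) that a query row may only see key rows of its own sequence that are not in its future. DSALT's sparsity restricts the visible keys further, to the window and landmarks; the helper names are hypothetical.

```python
from bisect import bisect_right


def sequence_of(row: int, cu_seqlens: list[int]) -> int:
    """Index of the packed sequence that owns `row`."""
    if not 0 <= row < cu_seqlens[-1]:
        raise IndexError(row)
    return bisect_right(cu_seqlens, row) - 1


def may_attend(q_row: int, k_row: int, cu_seqlens: list[int]) -> bool:
    """True if both rows share a sequence and the key is not after the query."""
    same_seq = sequence_of(q_row, cu_seqlens) == sequence_of(k_row, cu_seqlens)
    return same_seq and k_row <= q_row
```

For example, with cu_seqlens = [0, 3, 5], rows 0 to 2 belong to the first sequence and rows 3 and 4 to the second, so row 3 cannot attend to row 2 even though row 2 precedes it in the packed tensor.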


📖 Hyperparameter Guide

Full, source-verified defaults for DSALTLMHeadModel, DSALTConfig, DSALTAttention, and DSALTTrainer live in FEATURE.md. Highlights:

  • DSALTLMHeadModel
      Required: vocab_size, d_model, n_layers, n_heads, n_min, n_max, k_lmk, max_seq_len
      Notable defaults: d_ff=None (→ 8/3·d_model), loss_fn="chunked", tie_weights=True, yarn_scale=1.0
  • DSALTTrainer
      Required: model, train_loader, val_loader
      Notable defaults: lr=3e-4, max_grad_norm=0.5, warmup_steps=1000, mixed_precision="auto"
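For context on the d_ff default: assuming SwiGLUFFN follows the usual three-matrix SwiGLU design, a hidden width of 8/3·d_model gives the same weight count as a conventional two-matrix FFN of width 4·d_model (biases ignored). This is the standard rationale for the 8/3 factor; it is assumed here, not quoted from the DSALT source.

```latex
\underbrace{3 \cdot d_{\text{model}} \cdot \tfrac{8}{3}\, d_{\text{model}}}_{\text{SwiGLU: three matrices}}
\;=\; 8\, d_{\text{model}}^{2}
\;=\; \underbrace{2 \cdot d_{\text{model}} \cdot 4\, d_{\text{model}}}_{\text{standard FFN: two matrices}}
```

For d_model = 1024 this gives 8/3 · 1024 ≈ 2730.7, so d_ff must be rounded to an integer; the rounding DSALT applies is not stated in this README.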

alpha is a learnable per-head parameter (initialized so that sigmoid(alpha) ≈ 0.6), not a constructor flag. The auxiliary loss term (aux_loss) is inert in this release because the window is frozen; it is kept only for signature compatibility.


📄 License

Apache 2.0; see https://github.com/LeonardoCofone/dsalt-library/blob/main/LICENSE.


🤝 Contributing

Contributions are welcome; see CONTRIBUTING.md. Especially valuable are Triton kernel optimisation, new architectures (encoder / encoder-decoder), additional training strategies, documentation, and bug fixes.


📝 Citation

If you use DSALT in your research, please cite the paper:

@software{dsalt,
  author  = {Cofone, Leonardo},
  title   = {DSALT: Dynamic Sparse Attention with Landmark Tokens},
  url     = {https://github.com/LeonardoCofone/dsalt-library},
  note    = {https://zenodo.org/records/19312826},
}


Download files


Source Distribution

dsalt-0.3.90.tar.gz (53.7 kB)

Built Distribution


dsalt-0.3.90-py3-none-any.whl (54.9 kB)

File details

Details for the file dsalt-0.3.90.tar.gz.

File metadata

  • Download URL: dsalt-0.3.90.tar.gz
  • Upload date:
  • Size: 53.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.3.90.tar.gz
  • SHA256: e030b683cdaeafb6d65fe9299bb44f9b5785a2d87ad554c915a061aa5570d0fb
  • MD5: cd50ecbaebf5db5f4d9eae69d66f4feb
  • BLAKE2b-256: bda393045f8a01d9c8d44b4cb38a77aef9fb48860af6d1e19cf15847d41d4dd6


File details

Details for the file dsalt-0.3.90-py3-none-any.whl.

File metadata

  • Download URL: dsalt-0.3.90-py3-none-any.whl
  • Upload date:
  • Size: 54.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for dsalt-0.3.90-py3-none-any.whl
  • SHA256: 64ee30914ace5b2b7bb6fb827362b38447c4476deedd161b24b447a5b116bd36
  • MD5: f0b8cb8abbf5b3078bdb8ebd4d249a5a
  • BLAKE2b-256: fba1b2d3a887b60994a190ebbc5de7f58ce9824cae9f1793ed26970388f142aa

