α-Diffmax: sparse power-law normalizing operator for attention (0 < α < 1)

diffmax

α-Diffmax is a sparse, power-law normalizing operator for attention (0 < α < 1). It generalises softmax: lower α produces sparser attention weights, while the output remains a valid probability distribution. The threshold τ is found by bisection, which is numerically stable, and a dedicated backward op makes the operator differentiable end-to-end.
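This page does not spell out the exact α-Diffmax formula, so the following is only an illustration under an assumption: an entmax-style power-law map p_i = max((1 - α)·z_i - τ, 0)^(1/(1 - α)), with τ found by bisection so that the weights sum to 1, and gradients obtained by implicitly differentiating that constraint. The names power_threshold and power_threshold_vjp are hypothetical and are not part of the diffmax API.

```python
def power_threshold(z, alpha, n_iter=60):
    """Illustrative sparse map p_i = max((1 - alpha) * z_i - tau, 0) ** (1 / (1 - alpha)).

    Assumed entmax-style form, not necessarily diffmax's. tau is chosen by
    bisection so that the p_i sum to 1.
    """
    if not 0.0 < alpha < 1.0:
        raise ValueError("alpha must satisfy 0 < alpha < 1")
    k = 1.0 / (1.0 - alpha)            # power-law exponent, > 1
    u = [(1.0 - alpha) * zi for zi in z]
    hi = max(u)                        # tau = max(u): every p_i is 0, sum < 1
    lo = hi - 1.0                      # tau = max(u) - 1: largest p_i is 1, sum >= 1

    def total(tau):
        return sum(max(ui - tau, 0.0) ** k for ui in u)

    for _ in range(n_iter):            # total(tau) is non-increasing in tau
        mid = 0.5 * (lo + hi)
        if total(mid) >= 1.0:
            lo = mid
        else:
            hi = mid
    p = [max(ui - lo, 0.0) ** k for ui in u]
    s = sum(p)
    return [pi / s for pi in p]        # renormalise away residual bisection error


def power_threshold_vjp(p, grad_out, alpha):
    """Gradient w.r.t. z by implicit differentiation of sum(p) = 1.

    With s_i = p_i ** alpha (zero off the support):
    dL/dz_j = s_j * g_j - s_j * sum(s * g) / sum(s).
    """
    s = [pi ** alpha for pi in p]
    ratio = sum(si * gi for si, gi in zip(s, grad_out)) / sum(s)
    return [si * (gi - ratio) for si, gi in zip(s, grad_out)]


z = [2.0, 1.0, 0.1, -1.0]
print(power_threshold(z, alpha=0.5))  # the last two weights are exactly 0.0
```

For this z, α = 0.5 zeroes out the last two entries, while α = 0.9 keeps all four nonzero, matching the "lower α is sparser" behaviour described above.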

The figure below shows a comparison of DiffMax, EntMax, and SoftMax:

[Figure: comparison of DiffMax, EntMax, and SoftMax]

Figure 1: Comparison of DiffMax, EntMax, and SoftMax under different temperature parameters

Installation

pip install diffmax

Colab and CPU-only environments work out of the box. On a CUDA GPU with Triton installed, the fused Triton kernel is selected automatically.

from diffmax import diffmax_bisect
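To see up front whether the Triton path is even possible, one can check which packages are importable without importing them. The helper name optional_backends_present is hypothetical, and the check is only a necessary condition: the Triton kernel also needs a visible CUDA device, which PyTorch reports via torch.cuda.is_available().

```python
import importlib.util


def optional_backends_present() -> dict:
    """Report which optional packages are importable, without importing them."""
    return {
        name: importlib.util.find_spec(name) is not None
        for name in ("torch", "triton")
    }


print(optional_backends_present())
```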

Optional extras:

pip install "diffmax[monitoring]"   # swanlab metric logging
pip install "diffmax[dev]"          # pytest, scipy (for development)
pip install "diffmax[examples]"     # matplotlib, numpy, pandas

Quick start

import torch
from diffmax import diffmax_bisect, DiffmaxBisectModule

# Functional API — drop-in for torch.softmax
scores = torch.randn(2, 8, 64, 64)          # (B, H, L, L)
weights = diffmax_bisect(scores, alpha=0.85, dim=-1)
# weights.sum(dim=-1) == 1.0, many entries exactly zero

# Module API — use inside nn.Sequential or nn.Module
layer = DiffmaxBisectModule(alpha=0.85, dim=-1)
weights = layer(scores)
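Because fp32 sums rarely equal 1.0 exactly, the "sums to 1.0" comment in the quick start is best checked with a tolerance. The sketch below uses plain Python lists and a hypothetical helper name, check_attention_row; on tensors the same idea is torch.allclose(weights.sum(dim=-1), torch.ones_like(weights.sum(dim=-1)), atol=1e-5) plus (weights == 0).sum(dim=-1) for the zero count.

```python
import math
from typing import Sequence, Tuple


def check_attention_row(row: Sequence[float], atol: float = 1e-5) -> Tuple[bool, int]:
    """Return (is_valid_distribution, number_of_exact_zeros) for one row."""
    total = math.fsum(row)  # compensated summation, avoids accumulated rounding
    is_valid = abs(total - 1.0) <= atol and all(w >= 0.0 for w in row)
    n_zero = sum(1 for w in row if w == 0.0)
    return is_valid, n_zero


print(check_attention_row([0.7, 0.3, 0.0, 0.0]))  # (True, 2)
```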

API

| Symbol | Description |
|---|---|
| diffmax_bisect(X, alpha, dim, ...) | Functional forward pass |
| DiffmaxBisectModule(alpha, dim, ...) | nn.Module wrapper, optionally learnable α |
| HillMap(c, beta0) | Hill β→α map (recommended) |
| TanhMap(c, beta0) | Tanh β→α map |
| ExpMap(c, beta0) | Exponential β→α map |
| diffmax_bisect_monitored(...) | Monitored variant (requires swanlab) |
| __version__ | Package version string |
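The table lists three β→α maps for learnable α, but their formulas are not given on this page. Purely to illustrate the idea (mapping an unconstrained β into the open interval (0, 1) so that α stays valid during training), the sketch below uses textbook Hill, tanh, and exponential squashing functions. The function names and their parameters (k, n, rate) are hypothetical and do not describe the semantics of HillMap(c, beta0), TanhMap(c, beta0), or ExpMap(c, beta0).

```python
import math

EPS = 1e-6  # keeps alpha strictly inside (0, 1) when floats saturate


def _clip(a: float) -> float:
    return min(max(a, EPS), 1.0 - EPS)


def hill_alpha(beta: float, k: float = 1.0, n: float = 2.0) -> float:
    """Hill function beta**n / (k**n + beta**n); requires beta > 0."""
    if beta <= 0.0:
        raise ValueError("hill_alpha expects beta > 0")
    b = beta ** n
    return _clip(b / (k ** n + b))


def tanh_alpha(beta: float) -> float:
    """(1 + tanh(beta)) / 2 maps any real beta into (0, 1)."""
    return _clip(0.5 * (1.0 + math.tanh(beta)))


def exp_alpha(beta: float, rate: float = 1.0) -> float:
    """1 - exp(-rate * beta); requires beta > 0."""
    if beta <= 0.0:
        raise ValueError("exp_alpha expects beta > 0")
    return _clip(1.0 - math.exp(-rate * beta))


print(hill_alpha(1.0), tanh_alpha(0.0))  # 0.5 0.5
```

In a learnable setting, β would be the trainable parameter, and the map keeps α inside (0, 1) after every optimiser step.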

diffmax_bisect signature

diffmax_bisect(
    X: Tensor,
    alpha: float | Tensor | Callable = 0.9,
    dim: int = -1,
    n_iter: int = 50,
    ensure_sum_one: bool = True,
) -> Tensor
  • alpha: must satisfy 0 < α < 1. Scalars, broadcastable tensors, and zero-argument callables that return α are all accepted.
  • n_iter: 50 iterations are sufficient for fp32 convergence; raise it to 200 for fp64.
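The two bullets above can be made concrete with a small sketch. resolve_alpha is a hypothetical helper (not part of diffmax) that handles the scalar and zero-argument-callable cases and enforces 0 < α < 1; broadcastable tensors are left out to keep it dependency-free. bracket_width shows the arithmetic behind n_iter: every bisection step halves the interval known to contain τ.

```python
from typing import Callable, Union

AlphaLike = Union[float, Callable[[], float]]


def resolve_alpha(alpha: AlphaLike) -> float:
    """Call zero-argument callables, then check that 0 < alpha < 1."""
    value = float(alpha() if callable(alpha) else alpha)
    if not 0.0 < value < 1.0:
        raise ValueError(f"alpha must satisfy 0 < alpha < 1, got {value}")
    return value


def bracket_width(initial_width: float, n_iter: int) -> float:
    """Width of the bisection interval after n_iter halvings."""
    return initial_width / 2 ** n_iter


print(resolve_alpha(lambda: 0.85))  # 0.85
print(bracket_width(1.0, 50))       # 8.881784197001252e-16
```

For a unit-width starting bracket, 50 halvings leave about 8.9e-16, well below fp32 machine epsilon (about 1.2e-7) but still above fp64's (about 2.2e-16). The real starting bracket depends on the input, so treat this as a rough guide.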

Backends

| Device | Implementation | Notes |
|---|---|---|
| CPU | Pure PyTorch bisection | Default; always available |
| CUDA (NVIDIA) | Triton fused kernel | Auto-selected when Triton is installed |
| ROCm (AMD) | Placeholder | Plugs into CUDA dispatch key |
| Ascend NPU | Placeholder | Activated when torch_npu is installed |

Even on CUDA, rows with N > 4096 elements, or any float64 input, fall back to the CPU backend (pure PyTorch bisection).
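The backend table and the fallback rule together amount to a small decision procedure. The function below is a hypothetical restatement of that text for reasoning about which path an input takes; it is not diffmax's dispatcher. It leaves out the ROCm and Ascend NPU placeholders and assumes that CUDA without Triton uses the default CPU backend, which the table does not state explicitly.

```python
def pick_backend(device: str, triton_available: bool, row_len: int, dtype: str) -> str:
    """Return "triton" or "cpu" following the README's backend table."""
    if device == "cuda" and triton_available:
        # Documented fallback: long rows or float64 leave the Triton path.
        if row_len > 4096 or dtype == "float64":
            return "cpu"
        return "triton"
    # Assumption: every other case (CPU, or CUDA without Triton) uses the default.
    return "cpu"


print(pick_backend("cuda", True, 1024, "float16"))  # triton
print(pick_backend("cuda", True, 8192, "float32"))  # cpu
```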

Development

# Install with uv (recommended)
uv venv && uv sync --extra dev

# Run tests
uv run pytest tests/ -v

# Install with conda / pip
conda create -n diffmax python=3.9 -y && conda activate diffmax
pip install -r requirements.txt && pip install -e ".[dev]"
pytest tests/ -v

CUDA benchmarks (require a GPU):

python -m benchmarks.bench_forward
python -m benchmarks.bench_backward
python -m benchmarks.bench_e2e_attention

Monitoring

Install swanlab to log per-call metrics:

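pip install "diffmax[monitoring]"

import torch
import swanlab
from diffmax import diffmax_bisect_monitored

swanlab.init(project="diffmax-experiments")

scores = torch.randn(2, 8, 64, 64)  # (B, H, L, L) attention scores
global_step = 0                     # replace with your training-loop step counter

Y = diffmax_bisect_monitored(
    scores, alpha=0.85, dim=-1,
    name_prefix="diffmax/encoder_0",  # prefix for every logged key
    step=global_step,
)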

Logged keys: {prefix}/forward/*, {prefix}/alpha/*, {prefix}/convergence/*, {prefix}/backward/*.

Dispatch flow

diffmax_bisect(X, alpha, ...)
    → _DiffmaxAutocast.apply(...)       # AMP-aware autograd.Function
        → torch.ops.diffmax.bisect(...) # PyTorch custom dispatcher op
            → CPU kernel (default)
            → CUDA / Triton kernel
            → NPU kernel

loss.backward()
    → torch.library.register_autograd  # _autograd.py
        → torch.ops.diffmax.bisect_backward(...)
            → CPU / CUDA / NPU backward kernel
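The flow above pairs a forward op with a separately registered backward op, each with per-device kernels. A toy, pure-Python model of that pattern follows; the _KERNELS registry and the register and dispatch names are hypothetical and only stand in for PyTorch's dispatcher and torch.library, which do considerably more (dispatch keys, autocast, autograd). Falling back to the CPU kernel when no device kernel is registered is a choice made for this toy, mirroring the "always available" CPU row in the backend table.

```python
from typing import Callable, Dict, Tuple

# (op name, device) -> kernel; a stand-in for PyTorch's dispatcher tables.
_KERNELS: Dict[Tuple[str, str], Callable] = {}


def register(op: str, device: str) -> Callable[[Callable], Callable]:
    """Decorator that records a kernel for one (op, device) pair."""
    def decorator(fn: Callable) -> Callable:
        _KERNELS[(op, device)] = fn
        return fn
    return decorator


def dispatch(op: str, device: str, *args):
    """Run the device-specific kernel if one exists, else the CPU kernel."""
    kernel = _KERNELS.get((op, device)) or _KERNELS[(op, "cpu")]
    return kernel(*args)


@register("bisect", "cpu")
def bisect_cpu(x):
    return f"cpu forward on {x}"


@register("bisect_backward", "cpu")
def bisect_backward_cpu(grad):
    return f"cpu backward on {grad}"


@register("bisect", "cuda")
def bisect_cuda(x):
    return f"cuda forward on {x}"


print(dispatch("bisect", "cuda", "X"))            # cuda forward on X
print(dispatch("bisect_backward", "cuda", "dY"))  # cpu backward on dY
```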

License

MIT

Download files

Download the file for your platform.

Source Distribution

diffmax-0.2.0.tar.gz (29.4 kB)

Uploaded Source

Built Distribution

diffmax-0.2.0-py3-none-any.whl (24.6 kB)

Uploaded Python 3

File details

Details for the file diffmax-0.2.0.tar.gz.

File metadata

  • Download URL: diffmax-0.2.0.tar.gz
  • Upload date:
  • Size: 29.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.10

File hashes

Hashes for diffmax-0.2.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | b04ed3d8f862a66e29dd5a057e1ddf1b422a1c75a4a79093c8fc0f028e254e3f |
| MD5 | c8bd3cc04514508d8129ae325ab574a9 |
| BLAKE2b-256 | dcd13cdc073cea007d6ed71f64aa8f09b7f69975f52af0187d7aa278749f6621 |

File details

Details for the file diffmax-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: diffmax-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 24.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.10

File hashes

Hashes for diffmax-0.2.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e23e4f0c86e8118b7c23782d308d365455da81adc1101ff6caa04556f3d9d13d |
| MD5 | ff07302f289aee15e9da744e95d17862 |
| BLAKE2b-256 | 4347b6adebb61e98243c7e17578a9d2b029602dd492b860668f1b07ab082473f |
