α-Diffmax: sparse power-law normalizing operator for attention (0 < α < 1)
diffmax
α-Diffmax is a sparse, power-law normalizing operator for attention (0 < α < 1). It generalises softmax: lower α produces sparser attention weights while preserving a valid probability distribution. The threshold τ is found by bisection, making it numerically stable and differentiable end-to-end.
The figure below shows a comparison of DiffMax, EntMax, and SoftMax:
Figure 1: Comparison of DiffMax, EntMax, and SoftMax under different temperature parameters
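This README does not state the α-Diffmax formula, so the following is only a minimal sketch of the general idea of finding a threshold τ by bisection. It uses sparsemax (weights of the form max(s − τ, 0)) as a plainly named stand-in. The function name `sparsemax_bisect` and the final renormalization are choices made for this illustration and are not part of the diffmax API.

```python
def sparsemax_bisect(scores, n_iter=50):
    """Sparsemax via bisection on the threshold tau (illustrative stand-in,
    NOT the alpha-Diffmax formula)."""
    # Bracket the root of f(tau) = sum(max(s - tau, 0)) - 1.
    lo = max(scores) - 1.0  # here the largest entry alone contributes 1, so f >= 0
    hi = max(scores)        # here every term is clipped to 0, so f = -1
    for _ in range(n_iter):
        tau = 0.5 * (lo + hi)
        total = sum(max(s - tau, 0.0) for s in scores)
        if total >= 1.0:
            lo = tau  # threshold too low: weights sum to at least 1
        else:
            hi = tau  # threshold too high: weights sum to less than 1
    tau = 0.5 * (lo + hi)
    weights = [max(s - tau, 0.0) for s in scores]
    norm = sum(weights)  # absorb the residual bisection error
    return [w / norm for w in weights]


weights = sparsemax_bisect([1.0, 0.8, -1.0])
print(weights)  # the last entry is clipped to exactly zero
```

Each iteration halves the bracket, so the threshold error shrinks geometrically. Entries below the threshold are clipped to exactly zero, which is where the sparsity comes from.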
Installation
pip install diffmax
Colab / CPU-only environments work out of the box. A CUDA GPU with Triton unlocks the fused kernel automatically.
from diffmax import diffmax_bisect
Optional extras:
pip install "diffmax[monitoring]" # swanlab metric logging
pip install "diffmax[dev]" # pytest, scipy (for development)
pip install "diffmax[examples]" # matplotlib, numpy, pandas
Quick start
import torch
from diffmax import diffmax_bisect, DiffmaxBisectModule
# Functional API — drop-in for torch.softmax
scores = torch.randn(2, 8, 64, 64) # (B, H, L, L)
weights = diffmax_bisect(scores, alpha=0.85, dim=-1)
# weights.sum(dim=-1) == 1.0, many entries exactly zero
# Module API — use inside nn.Sequential or nn.Module
layer = DiffmaxBisectModule(alpha=0.85, dim=-1)
weights = layer(scores)
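To show what "drop-in for torch.softmax" can look like in practice, here is a minimal sketch of scaled dot-product attention with a pluggable row normalizer. The `attention` function and its `normalize` parameter are hypothetical names for this example. The block itself runs with plain `torch.softmax` and does not import diffmax.

```python
import math

import torch


def attention(q, k, v, normalize=torch.softmax):
    """Scaled dot-product attention with a pluggable row normalizer.

    `normalize(scores, dim=-1)` is expected to return weights that sum
    to 1 along `dim` (an assumption of this sketch).
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (B, H, L, L)
    weights = normalize(scores, dim=-1)
    return weights @ v, weights


q = k = v = torch.randn(2, 8, 64, 32)  # (B, H, L, D)
out, weights = attention(q, k, v)      # softmax baseline
```

Given the signature documented in this README, passing `normalize=functools.partial(diffmax_bisect, alpha=0.85)` should swap in the sparse operator without touching the rest of the function.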
API
| Symbol | Description |
|---|---|
| `diffmax_bisect(X, alpha, dim, ...)` | Functional forward pass |
| `DiffmaxBisectModule(alpha, dim, ...)` | nn.Module wrapper, optionally learnable α |
| `HillMap(c, beta0)` | Hill β→α map (recommended) |
| `TanhMap(c, beta0)` | Tanh β→α map |
| `ExpMap(c, beta0)` | Exponential β→α map |
| `diffmax_bisect_monitored(...)` | Monitored variant (requires swanlab) |
| `__version__` | Package version string |
diffmax_bisect signature
diffmax_bisect(
X: Tensor,
alpha: float | Tensor | Callable = 0.9,
dim: int = -1,
n_iter: int = 50,
ensure_sum_one: bool = True,
) -> Tensor
- alpha: must satisfy 0 < α < 1. Scalars, broadcast tensors, and zero-arg callables are all accepted.
- n_iter=50 is sufficient for fp32 convergence; raise to 200 for fp64.
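The two notes above can be illustrated with a small, library-independent sketch. `resolve_alpha` is a hypothetical helper showing how a scalar or a zero-arg callable might be resolved and range-checked (tensors are left out for brevity). `bracket_width` shows how far the bisection bracket shrinks after `n_iter` halvings, assuming an initial bracket of width 1, which is an assumption of this example and not a documented property of the library.

```python
def resolve_alpha(alpha):
    """Hypothetical helper: accept a float or a zero-arg callable, then validate."""
    value = float(alpha() if callable(alpha) else alpha)
    if not 0.0 < value < 1.0:
        raise ValueError(f"alpha must satisfy 0 < alpha < 1, got {value}")
    return value


def bracket_width(n_iter, initial_width=1.0):
    """Width of the bisection bracket after n_iter halvings."""
    return initial_width / 2.0 ** n_iter


FP32_EPS = 2.0 ** -23  # machine epsilon of float32
FP64_EPS = 2.0 ** -52  # machine epsilon of float64

print(resolve_alpha(0.85))           # plain scalar
print(resolve_alpha(lambda: 0.5))    # zero-arg callable, e.g. a schedule
print(bracket_width(50) < FP32_EPS)  # 50 halvings get well below fp32 epsilon
print(bracket_width(50) < FP64_EPS)  # but not below fp64 epsilon
```

Under these assumptions, 50 iterations resolve the threshold beyond fp32 precision but stop short of fp64 precision, which is consistent with the advice to raise `n_iter` for fp64 inputs.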
Backends
| Device | Implementation | Notes |
|---|---|---|
| CPU | Pure PyTorch bisection | Default; always available |
| CUDA (NVIDIA) | Triton fused kernel | Auto-selected when Triton is installed |
| ROCm (AMD) | Placeholder | Plugs into CUDA dispatch key |
| Ascend NPU | Placeholder | Activated when torch_npu is installed |
Rows with more than 4096 elements (N > 4096) and inputs of dtype float64 fall back to the CPU backend, even on CUDA.
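The table and fallback rule above can be summarized as a small decision function. This is a hypothetical mirror of the documented behaviour, not the library's dispatch code: the function name, the string labels, and the parameter `row_len` (standing in for N) are invented for illustration. Falling back to pure PyTorch when Triton is missing is likewise an assumption. The ROCm and NPU placeholders are left out.

```python
def select_backend(device, row_len, dtype, triton_available=True):
    """Hypothetical sketch of the backend choice described in the table."""
    use_triton = (
        device == "cuda"
        and triton_available
        and row_len <= 4096    # longer rows fall back
        and dtype != "float64"  # float64 falls back
    )
    return "triton" if use_triton else "pytorch-bisection"


print(select_backend("cuda", 1024, "float32"))
print(select_backend("cuda", 8192, "float32"))
print(select_backend("cuda", 1024, "float64"))
print(select_backend("cpu", 1024, "float32"))
```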
Development
# Install with uv (recommended)
uv venv && uv sync --extra dev
# Run tests
uv run pytest tests/ -v
# Install with conda / pip
conda create -n diffmax python=3.9 -y && conda activate diffmax
pip install -r requirements.txt && pip install -e ".[dev]"
pytest tests/ -v
CUDA benchmarks (require a GPU):
python -m benchmarks.bench_forward
python -m benchmarks.bench_backward
python -m benchmarks.bench_e2e_attention
Monitoring
Install swanlab to log per-call metrics:
pip install "diffmax[monitoring]"
import swanlab
from diffmax import diffmax_bisect_monitored
swanlab.init(project="diffmax-experiments")
Y = diffmax_bisect_monitored(
scores, alpha=0.85, dim=-1,
name_prefix="diffmax/encoder_0",
step=global_step,
)
Logged keys: {prefix}/forward/*, {prefix}/alpha/*, {prefix}/convergence/*, {prefix}/backward/*.
Dispatch flow
diffmax_bisect(X, alpha, ...)
→ _DiffmaxAutocast.apply(...) # AMP-aware autograd.Function
→ torch.ops.diffmax.bisect(...) # PyTorch custom dispatcher op
→ CPU kernel (default)
→ CUDA / Triton kernel
→ NPU kernel
loss.backward()
→ torch.library.register_autograd # _autograd.py
→ torch.ops.diffmax.bisect_backward(...)
→ CPU / CUDA / NPU backward kernel
License
MIT