
α-Diffmax: sparse power-law normalizing operator for attention (0 < α < 1)


diffmax

α-Diffmax is a sparse, power-law normalizing operator for attention, defined for 0 < α < 1. It generalises softmax: lower α produces sparser attention weights, while the output remains a valid probability distribution. The normalization threshold τ is found by bisection, which keeps the operator numerically stable and differentiable end-to-end.
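
To make the idea of "finding τ by bisection" concrete, here is a minimal pure-Python sketch of a thresholded power-law normalizer. The functional form `max(s - tau, 0) ** power` and the `power` parameter are assumptions borrowed from entmax-style operators, not the formula diffmax uses; how `power` would relate to α is deliberately left open.

```python
def sparse_powerlaw_normalize(scores, power=2.0, n_iter=50):
    """Illustrative only: find tau so that sum(max(s - tau, 0) ** power) == 1."""
    # The sum decreases as tau grows. At tau = max - 1 it is >= 1,
    # and at tau = max it is 0, so the root lies inside this bracket.
    lo, hi = max(scores) - 1.0, max(scores)
    for _ in range(n_iter):
        tau = 0.5 * (lo + hi)
        total = sum(max(s - tau, 0.0) ** power for s in scores)
        if total > 1.0:
            lo = tau   # mass too large: raise the threshold
        else:
            hi = tau   # mass too small: lower the threshold
    tau = 0.5 * (lo + hi)
    raw = [max(s - tau, 0.0) ** power for s in scores]
    z = sum(raw)       # final renormalization so the weights sum to one
    return [r / z for r in raw], tau


weights, tau = sparse_powerlaw_normalize([1.0, 0.8, -2.0])
print([round(w, 4) for w in weights], round(tau, 4))
# approximately [0.64, 0.36, 0.0] 0.2
```

Entries below the threshold receive a weight of exactly zero, which is the kind of sparsity described above.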

Installation

pip install diffmax

Colab / CPU-only environments work out of the box. A CUDA GPU with Triton unlocks the fused kernel automatically. To confirm the installation, import the functional API:

from diffmax import diffmax_bisect

Optional extras:

pip install "diffmax[monitoring]"   # swanlab metric logging
pip install "diffmax[dev]"          # pytest, scipy (for development)
pip install "diffmax[examples]"     # matplotlib, numpy, pandas

Quick start

import torch
from diffmax import diffmax_bisect, DiffmaxBisectModule

# Functional API — drop-in for torch.softmax
scores = torch.randn(2, 8, 64, 64)          # (B, H, L, L)
weights = diffmax_bisect(scores, alpha=0.85, dim=-1)
# weights.sum(dim=-1) ≈ 1.0 for every row; many entries are exactly zero

# Module API — use inside nn.Sequential or nn.Module
layer = DiffmaxBisectModule(alpha=0.85, dim=-1)
weights = layer(scores)

Learnable α via a β→α map:

from diffmax import DiffmaxBisectModule, HillMap

layer = DiffmaxBisectModule(alpha=0.85, dim=-1, alpha_map=HillMap(c=0.3))
# Exposes `log_beta` as a trainable parameter; optimiser tunes α.
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
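
The point of a β→α map is to let the optimiser move an unconstrained parameter (`log_beta`) while α stays inside (0, 1). The sketch below shows one way such a map could look, using a first-order Hill function. The formula, the helper names `hill_alpha` and `hill_log_beta`, and the role given to `c` are assumptions for illustration; the actual `HillMap` definition (including its `beta0` argument) is not documented here and may differ.

```python
import math


def hill_alpha(log_beta, c=0.3):
    """Hypothetical Hill-type map: alpha = beta / (beta + c), with beta = exp(log_beta)."""
    beta = math.exp(log_beta)      # beta > 0 for any real log_beta
    return beta / (beta + c)       # therefore 0 < alpha < 1


def hill_log_beta(alpha, c=0.3):
    """Inverse of hill_alpha: the log_beta that yields a target alpha."""
    return math.log(c * alpha / (1.0 - alpha))


log_beta0 = hill_log_beta(0.85)            # initial value reproducing alpha = 0.85
print(round(hill_alpha(log_beta0), 6))     # 0.85
```

Under this assumed form, any gradient step on `log_beta` yields a valid α, so no clamping or projection is needed.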

API

| Symbol | Description |
| --- | --- |
| diffmax_bisect(X, alpha, dim, ...) | Functional forward pass |
| DiffmaxBisectModule(alpha, dim, ...) | nn.Module wrapper, optionally learnable α |
| HillMap(c, beta0) | Hill β→α map (recommended) |
| TanhMap(c, beta0) | Tanh β→α map |
| ExpMap(c, beta0) | Exponential β→α map |
| diffmax_bisect_monitored(...) | Monitored variant (requires swanlab) |
| __version__ | Package version string |

diffmax_bisect signature

diffmax_bisect(
    X: Tensor,
    alpha: float | Tensor | Callable = 0.9,
    dim: int = -1,
    n_iter: int = 50,
    ensure_sum_one: bool = True,
) -> Tensor

  • alpha: must satisfy 0 < α < 1. Scalars, broadcast tensors, and zero-arg callables are all accepted.
  • n_iter=50 is sufficient for fp32 convergence; raise to 200 for fp64.
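
As background for choosing `n_iter`: each bisection step halves the bracket that contains τ, so the remaining uncertainty after n steps is the initial width divided by 2^n. The sketch below is generic bisection arithmetic. The helper `bisection_steps` is hypothetical, and the initial bracket width of 1.0 is an assumption; the package's actual bracket and the reasoning behind its recommended values are not stated here.

```python
import math
import sys


def bisection_steps(bracket_width, tol):
    """Smallest n such that bracket_width / 2**n <= tol."""
    return max(0, math.ceil(math.log2(bracket_width / tol)))


fp32_eps = 2.0 ** -23                 # machine epsilon of float32
fp64_eps = sys.float_info.epsilon     # machine epsilon of float64 (2**-52)

print(bisection_steps(1.0, fp32_eps))   # 23
print(bisection_steps(1.0, fp64_eps))   # 52
print(1.0 / 2 ** 50)                    # bracket width after 50 steps, about 8.9e-16
```

A wider initial bracket adds one step per doubling of its width.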

Backends

| Device | Implementation | Notes |
| --- | --- | --- |
| CPU | Pure PyTorch bisection | Default; always available |
| CUDA (NVIDIA) | Triton fused kernel | Auto-selected when Triton is installed |
| ROCm (AMD) | Placeholder | Plugs into CUDA dispatch key |
| Ascend NPU | Placeholder | Activated when torch_npu is installed |

Rows whose length N along dim exceeds 4096, or whose dtype is float64, fall back to the CPU backend (pure PyTorch bisection) even on CUDA.
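
The selection rules above can be summarised as a small decision function. This is a sketch of the documented behaviour only: `choose_backend` and the labels it returns are hypothetical names, not package internals, and the ROCm and NPU placeholders are left out.

```python
def choose_backend(device, row_len, dtype, has_triton=True):
    """Return an illustrative backend label for a given input."""
    if device == "cuda" and has_triton:
        # The fused kernel handles rows up to 4096 entries, and not float64.
        if row_len <= 4096 and dtype != "float64":
            return "triton_fused"
    # CPU inputs, missing Triton, long rows and float64 all take this path.
    return "pytorch_bisection"


print(choose_backend("cuda", 64, "float32"))     # triton_fused
print(choose_backend("cuda", 8192, "float32"))   # pytorch_bisection
print(choose_backend("cuda", 64, "float64"))     # pytorch_bisection
print(choose_backend("cpu", 64, "float32"))      # pytorch_bisection
```

Whether the fallback keeps tensors on the GPU or copies them to host memory is not specified above, so the sketch makes no claim either way.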

Development

# Install with uv (recommended)
uv venv && uv sync --extra dev

# Run tests
uv run pytest tests/ -v

# Install with conda / pip
conda create -n diffmax python=3.9 -y && conda activate diffmax
pip install -r requirements.txt && pip install -e ".[dev]"
pytest tests/ -v

CUDA benchmarks (require a GPU):

python -m benchmarks.bench_forward
python -m benchmarks.bench_backward
python -m benchmarks.bench_e2e_attention

Monitoring

Install swanlab to log per-call metrics:

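pip install "diffmax[monitoring]"

import torch
import swanlab
from diffmax import diffmax_bisect_monitored

swanlab.init(project="diffmax-experiments")

scores = torch.randn(2, 8, 64, 64)   # (B, H, L, L), as in Quick start
global_step = 0                      # placeholder: your training-loop step counter

Y = diffmax_bisect_monitored(
    scores, alpha=0.85, dim=-1,
    name_prefix="diffmax/encoder_0",
    step=global_step,
)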

Logged keys: {prefix}/forward/*, {prefix}/alpha/*, {prefix}/convergence/*, {prefix}/backward/*.

Dispatch flow

diffmax_bisect(X, alpha, ...)
    → _DiffmaxAutocast.apply(...)       # AMP-aware autograd.Function
        → torch.ops.diffmax.bisect(...) # PyTorch custom dispatcher op
            → CPU kernel (default)
            → CUDA / Triton kernel
            → NPU kernel

loss.backward()
    → torch.library.register_autograd  # _autograd.py
        → torch.ops.diffmax.bisect_backward(...)
            → CPU / CUDA / NPU backward kernel

License

MIT
