Butterfly, low-displacement-rank, and Monarch structured-matrix primitives in PyTorch
torch-structured
Consolidated PyTorch library of structured-matrix primitives:
- torch_structured (core): butterfly matrices for exact fast linear transforms (FFT, iFFT, DCT, DST, Hadamard, circulant, Toeplitz) as learnable nn.Module drop-in replacements for nn.Linear.
- torch_structured.structured: low-displacement-rank layers ported from structured-nets (Toeplitz-like, Hankel, Vandermonde, Fastfood, Circulant, LDR subdiagonal / tridiagonal, Krylov utilities).
- torch_structured.monarch: Monarch / block-diagonal-butterfly primitives ported from m2 (block-diagonal and block-diagonal-butterfly multiplies, structured linear layers, butterfly-factor helper, and Hyena implicit long filter).
See the NOTICE file for upstream attributions and citations.
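To make "butterfly matrices for exact fast linear transforms" concrete, here is a minimal pure-Python sketch of the Walsh-Hadamard transform computed with log2(n) butterfly stages and checked against the dense matrix. It does not use this package; the helper names (fwht, hadamard_dense, matvec) are illustrative only.

```python
def fwht(values):
    """Unnormalized Walsh-Hadamard transform via log2(n) butterfly stages."""
    out = list(values)
    n = len(out)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    half = 1
    while half < n:
        # One butterfly stage: combine pairs (i, i + half) inside blocks of size 2 * half.
        for start in range(0, n, 2 * half):
            for i in range(start, start + half):
                a, b = out[i], out[i + half]
                out[i], out[i + half] = a + b, a - b
        half *= 2
    return out


def hadamard_dense(n):
    """Dense Hadamard matrix in natural order: H[i][j] = (-1) ** popcount(i & j)."""
    return [[-1 if bin(i & j).count("1") % 2 else 1 for j in range(n)]
            for i in range(n)]


def matvec(matrix, vector):
    """Plain O(n^2) matrix-vector product used as the reference."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


x = [3, -1, 4, 1, -5, 9, 2, -6]
fast = fwht(x)                           # O(n log n) additions
slow = matvec(hadamard_dense(len(x)), x)  # O(n^2) multiply-adds
print(fast == slow)
```

Each stage is a sparse matrix with two nonzeros per row, so the product of the log2(n) stages reproduces the dense transform in O(n log n) work. A learnable butterfly layer replaces the fixed +1/-1 entries with trainable parameters.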
Requirements
- Python >= 3.10
- PyTorch >= 2.0
- NumPy, SciPy, einops, opt_einsum
- A C++ compiler supporting C++14 (for building extensions)
- CUDA toolkit (optional, for GPU acceleration)
Installation
uv pip install . # or: pip install .
uv pip install -e ".[dev]" # development install
CUDA support
CUDA extensions are compiled automatically when a CUDA toolkit is detected. Override with env vars:
FORCE_CUDA=1 uv pip install . # force CUDA compilation
FORCE_CPU=1 uv pip install . # force CPU-only build
TORCH_CUDA_ARCH_LIST targets specific GPU architectures (default: "7.0 8.0 9.0+PTX").
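The override logic described above can be sketched as a small decision function. This is a hypothetical helper, not the package's actual build code: the precedence when both variables are set and the rule that only the value "1" enables a flag are assumptions.

```python
def decide_cuda_build(env, toolkit_detected):
    """Return True when CUDA extensions should be compiled.

    Hypothetical sketch. Assumes FORCE_CPU wins if both overrides are set
    and that only the literal value "1" turns an override on.
    """
    if env.get("FORCE_CPU") == "1":
        return False
    if env.get("FORCE_CUDA") == "1":
        return True
    # No override: fall back to automatic toolkit detection.
    return toolkit_detected


print(decide_cuda_build({}, toolkit_detected=True))
print(decide_cuda_build({"FORCE_CPU": "1"}, toolkit_detected=True))
print(decide_cuda_build({"FORCE_CUDA": "1"}, toolkit_detected=False))
```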
Built extensions (CUDA builds):
torch_structured._butterfly,torch_structured._version— core butterfly ops (torch.ops-style).torch_structured._hadamard_cuda— fast Walsh-Hadamard transform (pybind module).torch_structured._diag_mult_cuda— subdiagonal cycle-multiply helper (pybind module).
Quickstart
Core butterfly
import torch
from torch_structured import Butterfly
from torch_structured.special import fft, hadamard
layer = Butterfly(in_size=1024, out_size=1024)
fft_layer = fft(1024)
hadamard_layer = hadamard(1024)
Structured (LDR) layers
import torch
from torch_structured.structured.layers import ToeplitzLike, LDRSubdiagonal
from torch_structured.structured.hadamard import hadamard_transform_torch
toeplitz = ToeplitzLike(layer_size=256, r=2)
ldr_sd = LDRSubdiagonal(layer_size=256, r=2)
y = hadamard_transform_torch(torch.randn(4, 128))
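As background for the r=2 arguments above, the following pure-Python sketch shows why a Toeplitz matrix has low displacement rank. It uses the displacement A - Z A Z^T with a down-shift matrix Z, which is chosen here only for illustration. The layers in this package may use different displacement operators, and the helper names are not part of its API.

```python
def toeplitz(first_col, first_row):
    """Build a Toeplitz matrix: entry (i, j) depends only on i - j."""
    n = len(first_col)
    return [[first_col[i - j] if i >= j else first_row[j - i] for j in range(n)]
            for i in range(n)]


def stein_displacement(a):
    """Return A - Z A Z^T, where (Z A Z^T)[i][j] = A[i-1][j-1] and 0 on the border."""
    n = len(a)
    shifted = [[a[i - 1][j - 1] if i > 0 and j > 0 else 0 for j in range(n)]
               for i in range(n)]
    return [[a[i][j] - shifted[i][j] for j in range(n)] for i in range(n)]


t = toeplitz([1, 2, 3, 4], [1, 5, 6, 7])
d = stein_displacement(t)
for row in d:
    print(row)

# Everything outside the first row and first column cancels.
interior = [d[i][j] for i in range(1, 4) for j in range(1, 4)]
print(all(v == 0 for v in interior))
```

Because the displacement is nonzero only in its first row and first column, it is a sum of at most two rank-one matrices, so its rank is at most 2. Toeplitz-like layers generalize this by learning matrices whose displacement has rank r.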
Monarch primitives
import torch
from torch_structured.monarch.blockdiag_linear import BlockdiagLinear
from torch_structured.monarch.blockdiag_butterfly_multiply import (
blockdiag_butterfly_multiply,
)
linear = BlockdiagLinear(in_features=512, out_features=512, nblocks=4)
# low-level multiply:
x = torch.randn(8, 64)
w1 = torch.randn(8, 8, 8)
w2 = torch.randn(8, 8, 8)
out = blockdiag_butterfly_multiply(x, w1, w2)
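The shapes in the low-level call above are not arbitrary. The sketch below encodes the shape rule used by the upstream Monarch reference implementation; that this port keeps the same convention is an assumption, and the helper name is hypothetical.

```python
def blockdiag_butterfly_output_shape(x_shape, w1_shape, w2_shape):
    """Check block shapes and return the expected output shape.

    Assumes the upstream Monarch convention:
      x:  (batch, n)
      w1: (k, q, p)  k blocks, each mapping p inputs to q outputs
      w2: (l, s, r)  l blocks, each mapping r inputs to s outputs
    """
    batch, n = x_shape
    k, q, p = w1_shape
    l, s, r = w2_shape
    if k * p != n:
        raise ValueError(f"w1 expects {k * p} input features, got {n}")
    if l * r != k * q:
        raise ValueError(f"w2 expects {l * r} features, w1 produces {k * q}")
    return (batch, l * s)


shape = blockdiag_butterfly_output_shape((8, 64), (8, 8, 8), (8, 8, 8))
print(shape)  # (8, 64)
```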
Triton backend (v1.2+)
Starting with v1.2, torch_structured ships a Triton-based GPU backend that replaces
the CUDA C++ extensions for the main kernels (butterfly_multiply, diag_mult,
hadamard_transform). The Triton path is the default when both a CUDA device and
PyTorch >= 2.6 are available.
Hardware requirements
The Triton backend requires an NVIDIA GPU with CUDA compute capability 8.0 or later (Ampere and newer: RTX 30xx/40xx, A100, H100, etc.). Older GPUs are NOT supported on the Triton path:
- Volta (sm_70: V100, Titan V): pin to v1.1 (pip install torch-structured==1.1.*) OR use the CUDA backend with a self-built .so.
- Turing (sm_75: T4, RTX 20xx): same recommendation as Volta.
Switch to the CUDA backend (when the legacy .so is built) via:
export TORCH_STRUCTURED_BACKEND=cuda
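A quick way to see which side of the compute-capability cutoff a machine falls on is sketched below. The helper names are hypothetical. torch.cuda.get_device_capability() is a standard PyTorch call, and the block degrades to a no-op when PyTorch or a GPU is absent.

```python
def triton_path_supported(capability):
    """capability is a (major, minor) tuple such as (8, 6)."""
    return tuple(capability) >= (8, 0)


def describe(capability):
    major, minor = capability
    if triton_path_supported(capability):
        return f"sm_{major}{minor}: Triton path supported"
    return f"sm_{major}{minor}: pin to v1.1 or use the CUDA backend"


print(describe((7, 5)))  # Turing
print(describe((8, 9)))  # Ada

try:
    import torch
    if torch.cuda.is_available():
        # Reports the capability of the current CUDA device.
        print(describe(torch.cuda.get_device_capability()))
except ImportError:
    pass  # PyTorch not installed; only the static examples above run.
```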
Deterministic mode
By default, the Triton backward kernel uses atomic-add reductions for
d_twiddle accumulation, which can produce slightly different results across
runs (within documented tolerance, but not bit-identical).
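The reason atomic-add reductions are not bit-reproducible is that floating-point addition is not associative, so the order in which partial sums arrive changes the rounding. A minimal illustration in plain Python, unrelated to the kernel itself:

```python
# Same three numbers, two accumulation orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right == right_to_left)      # False
print(abs(left_to_right - right_to_left))  # tiny, but nonzero
```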
For reproducible gradients, opt into deterministic mode:
import torch_structured
torch_structured.set_deterministic(True)
# ... training step ...
torch_structured.set_deterministic(False)
Under deterministic mode, the backward routes through the pure-PyTorch oracle
(butterfly_multiply_torch) — slower, but deterministic by construction.
Deterministic mode also activates automatically when
torch.use_deterministic_algorithms(True) is set globally. The two switches are
combined with a logical OR, so enabling either one is enough.
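If a training step raises, the manual enable/disable pattern above leaves deterministic mode switched on. One way to guard against that is a small context manager, sketched here as a user-side wrapper rather than a library feature. It takes the setter as an argument, so torch_structured.set_deterministic could be passed in; the demo below uses a list as a stand-in.

```python
from contextlib import contextmanager


@contextmanager
def deterministic(set_flag):
    """Enable deterministic mode for the duration of a with-block.

    set_flag is any callable taking a bool. The flag is reset to False on
    exit; no getter is documented above, so an earlier True state is not
    restored.
    """
    set_flag(True)
    try:
        yield
    finally:
        set_flag(False)


# Stand-in setter that records every call.
calls = []
with deterministic(calls.append):
    calls.append("training step")
print(calls)  # [True, 'training step', False]
```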
Switching backends
Use TORCH_STRUCTURED_BACKEND at import time OR
torch_structured.set_backend() at runtime:
TORCH_STRUCTURED_BACKEND=triton # default on Ampere+
TORCH_STRUCTURED_BACKEND=cuda # legacy CUDA C++ path (requires built .so)
TORCH_STRUCTURED_BACKEND=torch # pure-PyTorch fallback (CPU OK)
TORCH_STRUCTURED_BACKEND=auto # try triton -> cuda -> torch
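The "auto" fallback chain can be pictured as follows. This is a hypothetical sketch of the selection order only: the function name, the error behavior for an explicitly requested but unavailable backend, and the availability set are all assumptions.

```python
FALLBACK_ORDER = ("triton", "cuda", "torch")


def resolve_backend(requested, available):
    """Pick a backend name given a request and the set of usable backends."""
    if requested == "auto":
        # Try triton -> cuda -> torch and take the first usable one.
        for name in FALLBACK_ORDER:
            if name in available:
                return name
        raise RuntimeError("no backend available")
    if requested not in FALLBACK_ORDER:
        raise ValueError(f"unknown backend: {requested!r}")
    if requested not in available:
        # Assumption: an explicit request does not silently fall back.
        raise RuntimeError(f"backend {requested!r} is not available")
    return requested


print(resolve_backend("auto", {"cuda", "torch"}))  # cuda
print(resolve_backend("auto", {"torch"}))          # torch
```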
Runtime selector
On some shapes the Triton kernel may be slower than the legacy CUDA path. To
avoid forcing users to choose between backends, the library ships a static
routing table (torch_structured/_routing.json) baked from
triton.testing.do_bench-style measurements at packaging time. When you call
a routed shape with the Triton backend AND the legacy .so is available, the
call transparently routes to CUDA. The selector is dormant when no cell is
marked route_to_cuda — Triton handles every shape.
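The selector's behavior can be sketched as a table lookup. The JSON layout below is invented for illustration; only the route_to_cuda marker comes from the description above, and the real _routing.json schema may differ.

```python
import json


def pick_kernel(table, direction, n, dtype, legacy_so_available):
    """Return "cuda" for routed cells when the legacy .so exists, else "triton"."""
    cell = table.get(direction, {}).get(f"{n}|{dtype}", {})
    if cell.get("route_to_cuda", False) and legacy_so_available:
        return "cuda"
    return "triton"


# Hypothetical table with a single routed cell.
table = json.loads('{"forward": {"1024|complex64": {"route_to_cuda": true}}}')

print(pick_kernel(table, "forward", 1024, "complex64", legacy_so_available=True))   # cuda
print(pick_kernel(table, "forward", 1024, "complex64", legacy_so_available=False))  # triton
print(pick_kernel(table, "forward", 256, "fp32", legacy_so_available=True))         # triton
```

With an empty table every lookup falls through to "triton", which matches the dormant case described above.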
To regenerate the routing table on your hardware:
python tests/_baseline_butterfly.py # regenerate forward perf grid
python tests/_baseline_butterfly_backward.py # regenerate backward perf grid
python scripts/regenerate_routing_table.py # rebake _routing.json
Measured performance
The numbers below were measured on an NVIDIA RTX 2000 Ada Generation Laptop
GPU (sm_89) with PyTorch's CUDA 13.0 build, batch_size=64, nstacks=1,
nblocks=1. Each cell is the p50 over a triton.testing.do_bench sweep
(warmup=25ms, rep=100ms), in milliseconds — lower is better. "torch" is the
pure-PyTorch oracle (butterfly_multiply_torch); "CUDA" is the legacy C++
backend; "Triton" is the v1.2 default.
Forward (butterfly_multiply):
| size (n) | dtype | Triton (ms) | CUDA (ms) | torch (ms) | Triton vs torch | routed → CUDA |
|---|---|---|---|---|---|---|
| 256 | fp32 | 0.033 | 0.060 | 0.384 | 11.7× | |
| 256 | complex64 | 0.043 | 0.054 | 0.380 | 8.8× | |
| 512 | fp32 | 0.044 | 0.049 | 0.427 | 9.8× | |
| 512 | complex64 | 0.074 | 0.087 | 0.571 | 7.8× | |
| 1024 | fp32 | 0.072 | 0.076 | 0.464 | 6.5× | |
| 1024 | complex64 | 0.125 | 0.071 | 0.473 | 3.8× | ✓ |
| 2048 | fp32 | 0.135 | 0.083 | 0.570 | 4.2× | |
| 2048 | complex64 | 0.255 | 0.080 | 0.510 | 2.0× | ✓ |
Backward (gradient, full autograd callback incl. trail recompute):
| size (n) | dtype | Triton (ms) | CUDA (ms) | torch (ms) | Triton vs torch | routed → CUDA |
|---|---|---|---|---|---|---|
| 256 | fp32 | 0.303 | 0.483 | 2.421 | 8.0× | |
| 256 | complex64 | 0.290 | 0.324 | 2.734 | 9.4× | |
| 512 | fp32 | 0.265 | 0.330 | 2.601 | 9.8× | |
| 512 | complex64 | 1.713 | 0.439 | 3.326 | 1.9× | ✓ |
| 1024 | fp32 | 2.723 | 0.829 | 5.842 | 2.1× | ✓ |
| 1024 | complex64 | 1.661 | 1.049 | 10.265 | 6.2× | |
| 2048 | fp32 | 2.171 | 0.586 | 5.517 | 2.5× | ✓ |
| 2048 | complex64 | 2.132 | 0.462 | 5.684 | 2.7× | ✓ |
Takeaways on this machine: Triton beats the pure-PyTorch oracle in every cell
(~2–12×). It is also faster than the legacy CUDA kernel on the smaller shapes:
forward up to n = 512 plus n = 1024 fp32, and backward for n = 256 and
n = 512 fp32. On the remaining larger shapes the legacy CUDA kernel is faster,
in forward as well as backward, so the shipped _routing.json transparently
routes the cells where the gap is largest to CUDA when the legacy .so is built
(the ✓ rows above, matching the baked routing table).
Other GPUs will perform differently. These figures are specific to this laptop Ada GPU, this driver/toolkit, and these problem sizes. Absolute times and the Triton-vs-CUDA crossover points shift with compute capability, memory bandwidth, PyTorch/Triton versions, and shape. The committed _routing.json reflects this dev host; regenerate it on your own hardware (see above) for routing decisions tuned to your GPU. Treat the table as illustrative, not as a performance guarantee.
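The published cells can be cross-checked with a few lines of Python. The sketch below recomputes the Triton/CUDA time ratio for every cell in both tables and compares routed with unrouted cells. The rule that actually produced the routing table is not stated here, so this only checks that the ✓ marks are consistent with the timings.

```python
# (direction, n, dtype, triton_ms, cuda_ms, routed_to_cuda), copied from the tables.
CELLS = [
    ("fwd", 256, "fp32", 0.033, 0.060, False),
    ("fwd", 256, "complex64", 0.043, 0.054, False),
    ("fwd", 512, "fp32", 0.044, 0.049, False),
    ("fwd", 512, "complex64", 0.074, 0.087, False),
    ("fwd", 1024, "fp32", 0.072, 0.076, False),
    ("fwd", 1024, "complex64", 0.125, 0.071, True),
    ("fwd", 2048, "fp32", 0.135, 0.083, False),
    ("fwd", 2048, "complex64", 0.255, 0.080, True),
    ("bwd", 256, "fp32", 0.303, 0.483, False),
    ("bwd", 256, "complex64", 0.290, 0.324, False),
    ("bwd", 512, "fp32", 0.265, 0.330, False),
    ("bwd", 512, "complex64", 1.713, 0.439, True),
    ("bwd", 1024, "fp32", 2.723, 0.829, True),
    ("bwd", 1024, "complex64", 1.661, 1.049, False),
    ("bwd", 2048, "fp32", 2.171, 0.586, True),
    ("bwd", 2048, "complex64", 2.132, 0.462, True),
]

# Ratio > 1 means the legacy CUDA kernel was faster than Triton for that cell.
routed = [t / c for _, _, _, t, c, flag in CELLS if flag]
unrouted = [t / c for _, _, _, t, c, flag in CELLS if not flag]

print(f"routed cells:   {len(routed)}, smallest Triton/CUDA ratio {min(routed):.2f}")
print(f"unrouted cells: {len(unrouted)}, largest Triton/CUDA ratio {max(unrouted):.2f}")
```

Every routed cell has a larger Triton/CUDA ratio than every unrouted cell, so the marks behave as if a slowdown threshold had been applied. Two unrouted cells (forward 2048 fp32 and backward 1024 complex64) are still slower on Triton than on CUDA, just by a smaller margin.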
Deprecation timeline
torch_structured ships Triton as the default backend in v1.2. The legacy CUDA
C++ backend (csrc/) is being retired over a two-release deprecation cadence:
- v1.2 (current): Triton is the default. TORCH_STRUCTURED_BACKEND=cuda still works for users who built _butterfly.so / _diag_mult.so / _hadamard.so locally, but emits a one-time DeprecationWarning at import time pointing here. The Monarch Mixer MathDx kernel (previously vendored under csrc/) is removed entirely in v1.2; see the CHANGELOG for the full file list.
- v1.3 (next minor release, ~6 months out): CUDA build is default-disabled. csrc/ extensions stay in the source tree and can still be compiled via FORCE_CUDA=1, but the PyPI wheel does NOT include them. The DeprecationWarning still fires when a locally-built CUDA path is used.
- v1.4+ (post-milestone): csrc/ tree, setup.py CUDA extension code, and _cuda_legacy/ are deleted. The standard 2-release deprecation cadence gives users two minor releases to migrate.
Migration: most users should set nothing and let the Triton default take over. If your hardware is not supported on the Triton path (e.g., Volta sm_70 or Turing sm_75), either use the CUDA backend with a locally built extension (see the hardware requirements in the "Triton backend" section above) or pin to v1.1.
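To find out whether a deployment still depends on the deprecated path, the DeprecationWarning can be captured with the standard warnings module. In the sketch below, import_legacy_backend is a hypothetical stand-in for whatever triggers the warning in your code; the capture pattern itself is plain standard-library Python.

```python
import warnings


def import_legacy_backend():
    """Hypothetical stand-in for an import that selects the legacy CUDA path."""
    warnings.warn("legacy CUDA backend is deprecated", DeprecationWarning, stacklevel=2)


def uses_deprecated_path(action):
    """Run action() and report whether it emitted a DeprecationWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # do not let default filters hide it
        action()
    return any(issubclass(w.category, DeprecationWarning) for w in caught)


print(uses_deprecated_path(import_legacy_backend))  # True
print(uses_deprecated_path(lambda: None))           # False
```

In CI, running Python with -W error::DeprecationWarning turns such warnings into failures instead.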
Tests
pytest tests/
CUDA-only tests are automatically skipped when the corresponding extension is not built.
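A common way to implement this kind of skip is to probe for the compiled module before running a test. The sketch below shows the probe with the standard library only. How this project's test suite actually does it is not documented here, and the extension_built helper is hypothetical.

```python
import importlib.util


def extension_built(module_name):
    """Return True if module_name can be located, without failing when it is absent."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package is missing or a module has no spec.
        return False


print(extension_built("json.decoder"))            # stdlib module, present
print(extension_built("no_such_package_xyz._ext"))  # absent

# In a test file this could feed a skip marker, for example:
#   pytest.mark.skipif(not extension_built("torch_structured._hadamard_cuda"),
#                      reason="CUDA extension not built")
```

Note that find_spec on a dotted name imports the parent package as a side effect.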
Citation
See NOTICE for full upstream attributions and BibTeX entries for:
- Dao, Gu, Eichhorn, Rudra, Ré, Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations, ICML 2019
- Dao et al., Kaleidoscope, ICLR 2020
- Thomas, Gu, Dao, Rudra, Ré, Learning Compressed Transforms with Low Displacement Rank, NeurIPS 2018
- Dao et al., Monarch: Expressive Structured Matrices for Efficient and Accurate Training, ICML 2022
- Fu, Arora, Grogan et al., Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture, NeurIPS 2023
License
Apache-2.0 (see LICENSE).