Butterfly, low-displacement-rank, and Monarch structured-matrix primitives in PyTorch

torch-structured

Consolidated PyTorch library of structured-matrix primitives:

  • torch_structured (core) — butterfly matrices for exact fast linear transforms (FFT, iFFT, DCT, DST, Hadamard, circulant, Toeplitz) as learnable nn.Module drop-in replacements for nn.Linear.
  • torch_structured.structured — low-displacement-rank layers ported from structured-nets: Toeplitz-like, Hankel, Vandermonde, Fastfood, Circulant, LDR subdiagonal / tridiagonal, Krylov utilities.
  • torch_structured.monarch — Monarch / block-diagonal-butterfly primitives ported from m2: block-diagonal and block-diagonal-butterfly multiplies, structured linear layers, butterfly-factor helper, and Hyena implicit long filter.

See the NOTICE file for upstream attributions and citations.

Requirements

  • Python >= 3.10
  • PyTorch >= 2.0
  • NumPy, SciPy, einops, opt_einsum
  • A C++ compiler supporting C++14 (for building extensions)
  • CUDA toolkit (optional, for GPU acceleration)

Installation

uv pip install .            # or: pip install .
uv pip install -e ".[dev]"  # development install

CUDA support

CUDA extensions are compiled automatically when a CUDA toolkit is detected. Override with env vars:

FORCE_CUDA=1 uv pip install .   # force CUDA compilation
FORCE_CPU=1 uv pip install .    # force CPU-only build

TORCH_CUDA_ARCH_LIST targets specific GPU architectures (default: "7.0 8.0 9.0+PTX").

Built extensions (CUDA builds):

  • torch_structured._butterfly, torch_structured._version — core butterfly ops (torch.ops-style).
  • torch_structured._hadamard_cuda — fast Walsh-Hadamard transform (pybind module).
  • torch_structured._diag_mult_cuda — subdiagonal cycle-multiply helper (pybind module).
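
To see which of the extensions listed above ended up in a given install, a small check along the following lines may help. It is a sketch rather than a utility shipped with the package: it only uses the standard library's importlib.util.find_spec with the module names from the list, and locating a module does not prove that it loads correctly against your CUDA runtime.

```python
import importlib.util

# Extension module names taken from the list above; whether each one
# exists depends on how the package was built.
EXTENSIONS = (
    "torch_structured._butterfly",
    "torch_structured._hadamard_cuda",
    "torch_structured._diag_mult_cuda",
)


def is_importable(name: str) -> bool:
    """Return True if `name` can be located without importing the leaf module."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package is missing or has no usable spec.
        return False


def extension_report(names=EXTENSIONS) -> dict:
    """Map each extension name to whether it could be located."""
    return {name: is_importable(name) for name in names}


if __name__ == "__main__":
    for name, found in extension_report().items():
        print(f"{name}: {'found' if found else 'missing'}")
```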

Quickstart

Core butterfly

import torch
from torch_structured import Butterfly
from torch_structured.butterfly.special import fft, hadamard

layer = Butterfly(in_size=1024, out_size=1024)
fft_layer = fft(1024)
hadamard_layer = hadamard(1024)
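
For intuition about what a butterfly factorization computes, the sketch below runs the classic unnormalized fast Walsh-Hadamard transform in plain Python. It is not the library's implementation, and the normalization used by hadamard(1024) is not assumed here; the point is only that log2(n) stages of 2x2 mixing reproduce a dense n x n transform in O(n log n) operations.

```python
def fwht(values):
    """Unnormalized fast Walsh-Hadamard transform of a power-of-two-length list."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    out = list(values)
    h = 1
    while h < n:
        # One butterfly stage: entries h apart are mixed by [[1, 1], [1, -1]].
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    return out


print(fwht([1, 2, 3, 4]))  # [10, -2, -4, 0]
```

Applying fwht twice returns the input scaled by n, which is a quick sanity check for any Hadamard implementation.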

Structured (LDR) layers

import torch
from torch_structured.structured.layers import ToeplitzLike, LDRSubdiagonal
from torch_structured.structured.hadamard import hadamard_transform_torch

toeplitz = ToeplitzLike(layer_size=256, r=2)
ldr_sd = LDRSubdiagonal(layer_size=256, r=2)
y = hadamard_transform_torch(torch.randn(4, 128))

Monarch primitives

import torch
from torch_structured.monarch.blockdiag_linear import BlockdiagLinear
from torch_structured.monarch.blockdiag_butterfly_multiply import (
    blockdiag_butterfly_multiply,
)

linear = BlockdiagLinear(in_features=512, out_features=512, nblocks=4)
# low-level multiply:
x = torch.randn(8, 64)
w1 = torch.randn(8, 8, 8)
w2 = torch.randn(8, 8, 8)
out = blockdiag_butterfly_multiply(x, w1, w2)
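
As a rough illustration of what a block-diagonal multiply does and why it saves parameters, here is a plain-Python sketch. It assumes equal-sized blocks with no padding and no bias; the helper names are hypothetical and the actual weight layout inside BlockdiagLinear may differ.

```python
def blockdiag_matvec(blocks, x):
    """Multiply a block-diagonal matrix, given as a list of dense blocks, by x."""
    out, offset = [], 0
    for block in blocks:
        width = len(block[0])
        chunk = x[offset:offset + width]
        # Each block only sees its own slice of the input.
        out.extend(sum(w * v for w, v in zip(row, chunk)) for row in block)
        offset += width
    if offset != len(x):
        raise ValueError("block widths must sum to len(x)")
    return out


def blockdiag_param_count(in_features, out_features, nblocks):
    # Assumption: equal-sized blocks, no padding, no bias.
    return nblocks * (in_features // nblocks) * (out_features // nblocks)


print(blockdiag_matvec([[[1, 2], [3, 4]], [[5]]], [1, 1, 2]))  # [3, 7, 10]
print(blockdiag_param_count(512, 512, 4), "vs dense", 512 * 512)  # 65536 vs dense 262144
```

Under these assumptions, four blocks cut the weight count of a 512 x 512 layer by a factor of four.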

Triton backend (v1.2+)

Starting with v1.2, torch_structured ships a Triton-based GPU backend that replaces the CUDA C++ extensions for the main kernels (butterfly_multiply, diag_mult, hadamard_transform). The Triton path is the default when both a CUDA device and PyTorch >= 2.6 are available.

Hardware requirements

The Triton backend requires an NVIDIA GPU with compute capability 8.0 or later (Ampere and newer: RTX 30xx/40xx, A100, H100, etc.). Older GPUs are NOT supported on the Triton path:

  • Volta (sm_70 — V100, Titan V): pin to v1.1 (pip install torch-structured==1.1.*) OR use the CUDA backend with a self-built .so.
  • Turing (sm_75 — T4, RTX 20xx): same recommendation as Volta.

Switch to the CUDA backend (when the legacy .so is built) via:

export TORCH_STRUCTURED_BACKEND=cuda
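
The hardware rule above can be turned into a small decision helper. This is a sketch with a hypothetical function name, not part of the library: it takes a (major, minor) tuple such as the one returned by torch.cuda.get_device_capability() and returns one of the backend names accepted by TORCH_STRUCTURED_BACKEND.

```python
def recommended_backend(capability, cuda_extension_built=False):
    """Map a (major, minor) compute capability to a backend name.

    Assumption: the only rule applied is the 8.0 threshold stated above.
    """
    major, _minor = capability
    if major >= 8:
        return "triton"
    # Pre-Ampere GPUs: use the legacy extension if it was built locally,
    # otherwise fall back to the pure-PyTorch path.
    return "cuda" if cuda_extension_built else "torch"


print(recommended_backend((8, 9)))                             # triton
print(recommended_backend((7, 5), cuda_extension_built=True))  # cuda
print(recommended_backend((7, 0)))                             # torch
```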

Deterministic mode

By default, the Triton backward kernel uses atomic-add reductions for d_twiddle accumulation, which can produce slightly different results across runs (within documented tolerance, but not bit-identical).

For reproducible gradients, opt into deterministic mode:

import torch_structured

torch_structured.set_deterministic(True)
# ... training step ...
torch_structured.set_deterministic(False)

Under deterministic mode, the backward pass routes through the pure-PyTorch oracle (butterfly_multiply_torch), which is slower but deterministic by construction. Deterministic mode also activates automatically when torch.use_deterministic_algorithms(True) is set globally: the two flags are OR-ed, so the deterministic path is taken if either one is enabled.
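
If you toggle the flag around individual training steps, a context manager keeps the on/off calls paired even when the step raises. The wrapper below is a hypothetical convenience, not something the library is known to ship; it takes the setter as an argument (for example torch_structured.set_deterministic) so that it stays self-contained.

```python
from contextlib import contextmanager


@contextmanager
def deterministic(set_flag):
    """Enable deterministic mode for the duration of a `with` block.

    `set_flag` is any callable accepting a bool, such as
    torch_structured.set_deterministic.
    """
    set_flag(True)
    try:
        yield
    finally:
        # Always switches the flag off; it does not restore an earlier True
        # state, because no getter is assumed to exist.
        set_flag(False)


# Stand-in setter that records the calls it receives.
calls = []
with deterministic(calls.append):
    calls.append("training step")
print(calls)  # [True, 'training step', False]
```

With the real library this would read `with deterministic(torch_structured.set_deterministic): loss.backward()`.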

Switching backends

Use TORCH_STRUCTURED_BACKEND at import time OR torch_structured.set_backend() at runtime:

TORCH_STRUCTURED_BACKEND=triton  # default on Ampere+
TORCH_STRUCTURED_BACKEND=cuda    # legacy CUDA C++ path (requires built .so)
TORCH_STRUCTURED_BACKEND=torch   # pure-PyTorch fallback (CPU OK)
TORCH_STRUCTURED_BACKEND=auto    # try triton -> cuda -> torch
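
The fallback order for auto can be pictured with the sketch below. It is an illustration of the documented triton -> cuda -> torch order, not the library's resolver: the function name and the two availability flags are hypothetical, and treating an unset variable as auto is an assumption.

```python
import os

VALID_BACKENDS = ("triton", "cuda", "torch", "auto")


def resolve_backend(env=None, triton_ok=False, cuda_ok=False):
    """Pick a backend name from TORCH_STRUCTURED_BACKEND and availability flags."""
    env = os.environ if env is None else env
    # Assumption: an unset variable behaves like "auto".
    choice = env.get("TORCH_STRUCTURED_BACKEND", "auto").strip().lower()
    if choice not in VALID_BACKENDS:
        raise ValueError(f"unknown backend {choice!r}; expected one of {VALID_BACKENDS}")
    if choice != "auto":
        return choice  # an explicit request is honored as-is
    if triton_ok:
        return "triton"
    if cuda_ok:
        return "cuda"
    return "torch"


print(resolve_backend({"TORCH_STRUCTURED_BACKEND": "auto"}, triton_ok=False, cuda_ok=True))  # cuda
```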

Runtime selector

On some shapes the Triton kernel may be slower than the legacy CUDA path. To avoid forcing users to choose between backends, the library ships a static routing table (torch_structured/_routing.json) baked from triton.testing.do_bench-style measurements at packaging time. When you call a routed shape with the Triton backend AND the legacy .so is available, the call transparently routes to CUDA. The selector is dormant when no cell is marked route_to_cuda — Triton handles every shape.
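
The routing rule described above might look roughly like this. The JSON layout is hypothetical (the real schema of _routing.json is not documented here), the cell values are illustrative rather than copied from the shipped table, and only the route_to_cuda field name comes from the text.

```python
import json

# Hypothetical layout: one "cell" per measured shape. Illustrative values only.
EXAMPLE_TABLE = json.loads("""
{
  "cells": [
    {"op": "forward", "n": 4096, "dtype": "complex64", "route_to_cuda": true},
    {"op": "forward", "n": 128, "dtype": "float32", "route_to_cuda": false}
  ]
}
""")


def pick_backend(table, op, n, dtype, cuda_available):
    """Return "cuda" only for a routed cell when the legacy extension is available."""
    for cell in table["cells"]:
        if (cell["op"], cell["n"], cell["dtype"]) == (op, n, dtype):
            if cell["route_to_cuda"] and cuda_available:
                return "cuda"
            break
    return "triton"  # unrouted or unknown shapes stay on Triton


def selector_is_dormant(table):
    """True when no cell is marked route_to_cuda."""
    return not any(cell["route_to_cuda"] for cell in table["cells"])


print(pick_backend(EXAMPLE_TABLE, "forward", 4096, "complex64", cuda_available=True))   # cuda
print(pick_backend(EXAMPLE_TABLE, "forward", 4096, "complex64", cuda_available=False))  # triton
```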

To regenerate the routing table on your hardware:

python tests/_baseline_butterfly.py            # regenerate forward perf grid
python tests/_baseline_butterfly_backward.py   # regenerate backward perf grid
python scripts/regenerate_routing_table.py     # rebake _routing.json

Measured performance

The numbers below were measured on an NVIDIA RTX 2000 Ada Generation Laptop GPU (sm_89) with PyTorch's CUDA 13.0 build, batch_size=64, nstacks=1, nblocks=1. Each cell is the p50 over a triton.testing.do_bench sweep (warmup=25ms, rep=100ms), in milliseconds — lower is better. "torch" is the pure-PyTorch oracle (butterfly_multiply_torch); "CUDA" is the legacy C++ backend; "Triton" is the v1.2 default.

Forward (butterfly_multiply):

size (n)  dtype      Triton (ms)  CUDA (ms)  torch (ms)  Triton vs torch  routed → CUDA
256       fp32       0.033        0.060      0.384       11.7×
256       complex64  0.043        0.054      0.380       8.8×
512       fp32       0.044        0.049      0.427       9.8×
512       complex64  0.074        0.087      0.571       7.8×
1024      fp32       0.072        0.076      0.464       6.5×
1024      complex64  0.125        0.071      0.473       3.8×
2048      fp32       0.135        0.083      0.570       4.2×
2048      complex64  0.255        0.080      0.510       2.0×

Backward (gradient, full autograd callback incl. trail recompute):

size (n)  dtype      Triton (ms)  CUDA (ms)  torch (ms)  Triton vs torch  routed → CUDA
256       fp32       0.303        0.483      2.421       8.0×
256       complex64  0.290        0.324      2.734       9.4×
512       fp32       0.265        0.330      2.601       9.8×
512       complex64  1.713        0.439      3.326       1.9×
1024      fp32       2.723        0.829      5.842       2.1×
1024      complex64  1.661        1.049      10.265      6.2×
2048      fp32       2.171        0.586      5.517       2.5×
2048      complex64  2.132        0.462      5.684       2.7×

Takeaways on this machine: Triton beats the pure-PyTorch oracle everywhere (~2–12×) and is competitive with — often faster than — the legacy CUDA kernel on forward and on the smaller backward shapes. On the larger/complex backward shapes the legacy CUDA kernel is still faster, so the shipped _routing.json transparently routes those cells to CUDA when the legacy .so is built (the ✓ rows above, matching the baked routing table).
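
The crossover can be read straight off the published numbers. The sketch below copies the Triton, CUDA and torch timings from the two tables and lists the shapes where the measured CUDA time is lower than the Triton time. It says nothing about which cells the shipped _routing.json actually marks; that file remains the source of truth.

```python
# (n, dtype, triton_ms, cuda_ms, torch_ms), copied from the tables above.
FORWARD = [
    (256, "fp32", 0.033, 0.060, 0.384), (256, "complex64", 0.043, 0.054, 0.380),
    (512, "fp32", 0.044, 0.049, 0.427), (512, "complex64", 0.074, 0.087, 0.571),
    (1024, "fp32", 0.072, 0.076, 0.464), (1024, "complex64", 0.125, 0.071, 0.473),
    (2048, "fp32", 0.135, 0.083, 0.570), (2048, "complex64", 0.255, 0.080, 0.510),
]
BACKWARD = [
    (256, "fp32", 0.303, 0.483, 2.421), (256, "complex64", 0.290, 0.324, 2.734),
    (512, "fp32", 0.265, 0.330, 2.601), (512, "complex64", 1.713, 0.439, 3.326),
    (1024, "fp32", 2.723, 0.829, 5.842), (1024, "complex64", 1.661, 1.049, 10.265),
    (2048, "fp32", 2.171, 0.586, 5.517), (2048, "complex64", 2.132, 0.462, 5.684),
]


def cuda_faster(rows):
    """Shapes where the measured CUDA time beats the measured Triton time."""
    return [(n, dtype) for n, dtype, triton_ms, cuda_ms, _ in rows if cuda_ms < triton_ms]


def triton_beats_torch(rows):
    """True if Triton is faster than the pure-PyTorch oracle on every row."""
    return all(triton_ms < torch_ms for _, _, triton_ms, _, torch_ms in rows)


print("forward, CUDA faster: ", cuda_faster(FORWARD))
print("backward, CUDA faster:", cuda_faster(BACKWARD))
print("Triton beats torch everywhere:", triton_beats_torch(FORWARD + BACKWARD))
```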

Other GPUs will perform differently. These figures are specific to this laptop Ada GPU, this driver/toolkit, and these problem sizes — absolute times and the Triton-vs-CUDA crossover points shift with compute capability, memory bandwidth, PyTorch/Triton versions, and shape. The committed _routing.json reflects this dev host; regenerate it on your own hardware (see above) for routing decisions tuned to your GPU. Treat the table as illustrative, not as a performance guarantee.

Deprecation timeline

torch_structured ships Triton as the default backend in v1.2. The legacy CUDA C++ backend (csrc/) is being retired over a two-release deprecation cadence:

  • v1.2 (current): Triton is the default. TORCH_STRUCTURED_BACKEND=cuda still works for users who built _butterfly.so / _diag_mult.so / _hadamard.so locally, but emits a one-time DeprecationWarning at import time pointing here. The Monarch Mixer MathDx kernel (previously vendored under csrc/) is removed entirely in v1.2; see the CHANGELOG for the full file list.
  • v1.3 (next minor release, ~6 months out): CUDA build is default-disabled. csrc/ extensions stay in the source tree and can still be compiled via FORCE_CUDA=1, but the PyPI wheel does NOT include them. The DeprecationWarning still fires when a locally-built CUDA path is used.
  • v1.4+ (post-milestone): the csrc/ tree, the CUDA extension code in setup.py, and _cuda_legacy/ are deleted. By then users will have had two minor releases (v1.2 and v1.3) to migrate.

Migration: most users need no configuration; the Triton default simply takes over. If your workload needs the CUDA backend (e.g., Volta sm_70 or Turing sm_75 hardware, which the Triton path does not support), either build the legacy extensions and set TORCH_STRUCTURED_BACKEND=cuda, or pin to v1.1. See the "Triton backend" section above for hardware requirements.

Tests

pytest tests/

CUDA-only tests are automatically skipped when the corresponding extension is not built.

Citation

See NOTICE for full upstream attributions and BibTeX entries for:

  • Dao, Gu, Eichhorn, Rudra, Ré, Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations, ICML 2019
  • Dao et al., Kaleidoscope, ICLR 2020
  • Thomas, Gu, Dao, Rudra, Ré, Learning Compressed Transforms with Low Displacement Rank, NeurIPS 2018
  • Dao et al., Monarch: Expressive Structured Matrices for Efficient and Accurate Training, ICML 2022
  • Fu, Arora, Grogan et al., Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture, NeurIPS 2023

License

Apache-2.0 (see LICENSE).
