Butterfly, low-displacement-rank, and Monarch structured-matrix primitives in PyTorch

torch-structured

Consolidated PyTorch library of structured-matrix primitives:

  • torch_structured (core) — butterfly matrices for exact fast linear transforms (FFT, iFFT, DCT, DST, Hadamard, circulant, Toeplitz) as learnable nn.Module drop-in replacements for nn.Linear.
  • torch_structured.structured — low-displacement-rank layers ported from structured-nets: Toeplitz-like, Hankel, Vandermonde, Fastfood, Circulant, LDR subdiagonal / tridiagonal, Krylov utilities.
  • torch_structured.monarch — Monarch / block-diagonal-butterfly primitives ported from m2: block-diagonal and block-diagonal-butterfly multiplies, structured linear layers, butterfly-factor helper, and Hyena implicit long filter.

See the NOTICE file for upstream attributions and citations.

Requirements

  • Python >= 3.10
  • PyTorch >= 2.0
  • NumPy, SciPy, einops, opt_einsum
  • A C++ compiler supporting C++14 (for building extensions)
  • CUDA toolkit (optional, for GPU acceleration)

Installation

uv pip install .            # or: pip install .
uv pip install -e ".[dev]"  # development install

CUDA support

CUDA extensions are compiled automatically when a CUDA toolkit is detected. Override with env vars:

FORCE_CUDA=1 uv pip install .   # force CUDA compilation
FORCE_CPU=1 uv pip install .    # force CPU-only build

TORCH_CUDA_ARCH_LIST targets specific GPU architectures (default: "7.0 8.0 9.0+PTX").
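
To make the format of that default value concrete, here is a minimal sketch that splits an arch list into its parts. The helper name parse_arch_list is hypothetical and is not part of torch_structured or PyTorch; it only handles numeric entries such as "8.0" and the "+PTX" suffix, not named architectures.

```python
def parse_arch_list(value):
    """Split an arch list such as "7.0 8.0 9.0+PTX" into (major, minor, ptx) tuples."""
    archs = []
    # Entries may be separated by spaces or semicolons.
    for token in value.replace(";", " ").split():
        ptx = token.endswith("+PTX")
        major, minor = token.removesuffix("+PTX").split(".")
        archs.append((int(major), int(minor), ptx))
    return archs


# The default quoted above: sm_70 and sm_80 binaries, plus sm_90 with PTX.
default_archs = parse_arch_list("7.0 8.0 9.0+PTX")
```

The "+PTX" suffix asks the compiler to embed PTX for that architecture in addition to the binary, so newer GPUs can JIT-compile the kernels.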

Built extensions (CUDA builds):

  • torch_structured._butterfly, torch_structured._version — core butterfly ops (torch.ops-style).
  • torch_structured._hadamard_cuda — fast Walsh-Hadamard transform (pybind module).
  • torch_structured._diag_mult_cuda — subdiagonal cycle-multiply helper (pybind module).

Quickstart

Core butterfly

import torch
from torch_structured import Butterfly
from torch_structured.special import fft, hadamard

layer = Butterfly(in_size=1024, out_size=1024)
fft_layer = fft(1024)
hadamard_layer = hadamard(1024)
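
For intuition about why a butterfly can represent the Hadamard transform exactly, the following pure-Python sketch applies log2(n) butterfly stages and compares the result with a dense Hadamard multiply. The functions hadamard_dense and fwht are illustrative names, not library API, and the output is unnormalized (the library's layers may scale differently).

```python
def hadamard_dense(n):
    """Unnormalized n x n Hadamard matrix (Sylvester construction), n a power of 2."""
    h = [[1]]
    while len(h) < n:
        # H_2m = [[H_m, H_m], [H_m, -H_m]]
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    return h


def fwht(x):
    """Fast Walsh-Hadamard transform: log2(n) butterfly stages, O(n log n) work."""
    x = list(x)
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    half = 1
    while half < n:
        # One butterfly stage: combine pairs of entries that are `half` apart.
        for start in range(0, n, 2 * half):
            for i in range(start, start + half):
                a, b = x[i], x[i + half]
                x[i], x[i + half] = a + b, a - b
        half *= 2
    return x


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
dense = [sum(h * v for h, v in zip(row, x)) for row in hadamard_dense(len(x))]
fast = fwht(x)  # same values as `dense`, computed with n*log2(n) add/subtract pairs
```

Each stage touches every entry once, so the whole transform costs O(n log n) operations instead of the O(n^2) of a dense matrix-vector product.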

Structured (LDR) layers

from torch_structured.structured.layers import ToeplitzLike, LDRSubdiagonal
from torch_structured.structured.hadamard import hadamard_transform_torch

toeplitz = ToeplitzLike(layer_size=256, r=2)
ldr_sd = LDRSubdiagonal(layer_size=256, r=2)
y = hadamard_transform_torch(torch.randn(4, 128))

Monarch primitives

import torch
from torch_structured.monarch.blockdiag_linear import BlockdiagLinear
from torch_structured.monarch.blockdiag_butterfly_multiply import (
    blockdiag_butterfly_multiply,
)

linear = BlockdiagLinear(in_features=512, out_features=512, nblocks=4)
# low-level multiply:
x = torch.randn(8, 64)
w1 = torch.randn(8, 8, 8)
w2 = torch.randn(8, 8, 8)
out = blockdiag_butterfly_multiply(x, w1, w2)
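
As a rough picture of what a block-diagonal multiply does, the sketch below is a pure-Python reference with hypothetical helper names (blockdiag_matvec, blockdiag_param_count). It is not the library's implementation, and the parameter count assumes the feature sizes divide evenly by nblocks and ignores any bias.

```python
def blockdiag_matvec(blocks, x):
    """Multiply a block-diagonal matrix, given as a list of dense blocks, by vector x."""
    out, offset = [], 0
    for block in blocks:
        width = len(block[0])
        chunk = x[offset:offset + width]  # each block only sees its own slice of x
        out.extend(sum(w * v for w, v in zip(row, chunk)) for row in block)
        offset += width
    if offset != len(x):
        raise ValueError("block widths must sum to len(x)")
    return out


def blockdiag_param_count(in_features, out_features, nblocks):
    """Weights in nblocks equal-sized blocks, ignoring any bias term."""
    return nblocks * (in_features // nblocks) * (out_features // nblocks)


y = blockdiag_matvec([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [1, 1, 2, 2])
dense_params = 512 * 512                            # 262144 weights in a dense layer
block_params = blockdiag_param_count(512, 512, 4)   # 65536 weights in 4 blocks
```

Under these assumptions, four blocks cut the weight count of a 512 x 512 layer by a factor of four.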

Triton backend (v1.2+)

Starting with v1.2, torch_structured ships a Triton-based GPU backend that replaces the CUDA C++ extensions for the main kernels (butterfly_multiply, diag_mult, hadamard_transform). The Triton path is the default when both a CUDA device and PyTorch >= 2.6 are available.

Hardware requirements

The Triton backend requires an NVIDIA GPU with CUDA compute capability 8.0 or later, i.e. Ampere or newer (Ampere: A100, RTX 30xx; Ada: RTX 40xx; Hopper: H100; etc.). Older GPUs are NOT supported on the Triton path:

  • Volta (sm_70 — V100, Titan V): pin to v1.1 (pip install torch-structured==1.1.*) OR use the CUDA backend with a self-built .so.
  • Turing (sm_75 — T4, RTX 20xx): same recommendation as Volta.
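
The requirement above amounts to a tuple comparison on the device's compute capability. A minimal sketch, with the hypothetical names MIN_TRITON_CAPABILITY and triton_path_supported (the library's own check is not shown in this document):

```python
MIN_TRITON_CAPABILITY = (8, 0)  # threshold taken from the requirement stated above


def triton_path_supported(capability):
    """capability is a (major, minor) tuple, e.g. (8, 9) for sm_89."""
    return tuple(capability) >= MIN_TRITON_CAPABILITY


# Volta, Turing, Ampere, Ada, Hopper
checks = {cc: triton_path_supported(cc)
          for cc in [(7, 0), (7, 5), (8, 0), (8, 9), (9, 0)]}
```

On a machine with PyTorch and a GPU, torch.cuda.get_device_capability() returns such a (major, minor) tuple and can be passed in directly.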

Switch to the CUDA backend (when the legacy .so is built) via:

export TORCH_STRUCTURED_BACKEND=cuda

Deterministic mode

By default, the Triton backward kernel uses atomic-add reductions for d_twiddle accumulation, which can produce slightly different results across runs (within documented tolerance, but not bit-identical).
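
The underlying reason is that floating-point addition is not associative, so the order in which atomic adds land changes the rounded result. A two-line illustration in plain Python, independent of the library:

```python
# Same three numbers, two summation orders.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

orders_disagree = left != right   # True: the two sums differ in the last bit
gap = abs(left - right)           # tiny, but nonzero
```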

For reproducible gradients, opt into deterministic mode:

import torch_structured

torch_structured.set_deterministic(True)
# ... training step ...
torch_structured.set_deterministic(False)

Under deterministic mode, the backward pass routes through the pure-PyTorch oracle (butterfly_multiply_torch), which is slower but deterministic by construction. Deterministic mode also activates automatically when torch.use_deterministic_algorithms(True) is set globally: the two flags are OR-ed, so either one is enough to enable it.
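
If the flag should be reset even when a training step raises, the on/off pair can be wrapped in a context manager. This is a sketch, not something the library is known to ship: deterministic_scope is a hypothetical helper that takes the setter as an argument, and it always switches the flag back to False instead of restoring a previous state.

```python
from contextlib import contextmanager


@contextmanager
def deterministic_scope(set_flag):
    """Enable deterministic mode for the duration of a with-block.

    set_flag is any callable taking a bool, e.g. torch_structured.set_deterministic.
    """
    set_flag(True)
    try:
        yield
    finally:
        # Assumes the mode was off before entering; prior state is not restored.
        set_flag(False)


# Stand-in setter that records calls, so the sketch runs without the library.
calls = []
with deterministic_scope(calls.append):
    calls.append("training step")
```

With the library installed, the usage would be `with deterministic_scope(torch_structured.set_deterministic): ...`.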

Switching backends

Use TORCH_STRUCTURED_BACKEND at import time OR torch_structured.set_backend() at runtime:

TORCH_STRUCTURED_BACKEND=triton  # default on Ampere+
TORCH_STRUCTURED_BACKEND=cuda    # legacy CUDA C++ path (requires built .so)
TORCH_STRUCTURED_BACKEND=torch   # pure-PyTorch fallback (CPU OK)
TORCH_STRUCTURED_BACKEND=auto    # try triton -> cuda -> torch
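
The "auto" order above can be read as a simple fallback chain. The sketch below is an assumption about the behavior, not the library's code: resolve_backend and its two availability flags are hypothetical, and what the library does when an explicitly requested backend is unavailable is not specified here.

```python
KNOWN_BACKENDS = ("triton", "cuda", "torch", "auto")


def resolve_backend(requested, triton_ok, cuda_ext_ok):
    """Pick a backend name, following 'triton -> cuda -> torch' for 'auto'."""
    if requested not in KNOWN_BACKENDS:
        raise ValueError(f"unknown backend: {requested!r}")
    if requested != "auto":
        return requested  # explicit choices are passed through unchanged
    for name, available in (("triton", triton_ok), ("cuda", cuda_ext_ok)):
        if available:
            return name
    return "torch"  # the pure-PyTorch fallback also works on CPU


on_ampere = resolve_backend("auto", triton_ok=True, cuda_ext_ok=True)
on_volta_with_ext = resolve_backend("auto", triton_ok=False, cuda_ext_ok=True)
on_cpu = resolve_backend("auto", triton_ok=False, cuda_ext_ok=False)
```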

Runtime selector

On some shapes the Triton kernel may be slower than the legacy CUDA path. To avoid forcing users to choose between backends, the library ships a static routing table (torch_structured/_routing.json) baked from triton.testing.do_bench-style measurements at packaging time. When you call a routed shape with the Triton backend AND the legacy .so is available, the call transparently routes to CUDA. The selector is dormant when no cell is marked route_to_cuda — Triton handles every shape.
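
The routing rule described above can be sketched as a table lookup. The JSON layout, the key format, and the function choose_kernel are all hypothetical (the real schema of _routing.json is not documented here), and the sample cells are made-up values for illustration only.

```python
import json

# Hypothetical schema with made-up cells; not the contents of the shipped file.
ROUTING_JSON = """
{
  "forward":  {"4096:fp32": {"route_to_cuda": false}},
  "backward": {"4096:fp32": {"route_to_cuda": true}}
}
"""


def choose_kernel(table, direction, n, dtype, backend, cuda_ext_built):
    """Return the backend that would run this shape under the rule described above."""
    cell = table.get(direction, {}).get(f"{n}:{dtype}", {})
    # Reroute only when Triton is selected, the cell is marked, and the .so exists.
    if backend == "triton" and cuda_ext_built and cell.get("route_to_cuda", False):
        return "cuda"
    return backend


table = json.loads(ROUTING_JSON)
routed = choose_kernel(table, "backward", 4096, "fp32", "triton", cuda_ext_built=True)
not_routed = choose_kernel(table, "backward", 4096, "fp32", "triton", cuda_ext_built=False)
```

Shapes with no entry fall through to the selected backend, which mirrors the statement that the selector is dormant when no cell is marked.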

To regenerate the routing table on your hardware:

python tests/_baseline_butterfly.py            # regenerate forward perf grid
python tests/_baseline_butterfly_backward.py   # regenerate backward perf grid
python scripts/regenerate_routing_table.py     # rebake _routing.json

Measured performance

The numbers below were measured on an NVIDIA RTX 2000 Ada Generation Laptop GPU (sm_89) with PyTorch's CUDA 13.0 build, batch_size=64, nstacks=1, nblocks=1. Each cell is the p50 over a triton.testing.do_bench sweep (warmup=25ms, rep=100ms), in milliseconds — lower is better. "torch" is the pure-PyTorch oracle (butterfly_multiply_torch); "CUDA" is the legacy C++ backend; "Triton" is the v1.2 default.

Forward (butterfly_multiply):

size (n)  dtype      Triton (ms)  CUDA (ms)  torch (ms)  Triton vs torch  routed → CUDA
256       fp32       0.033        0.060      0.384       11.7×
256       complex64  0.043        0.054      0.380       8.8×
512       fp32       0.044        0.049      0.427       9.8×
512       complex64  0.074        0.087      0.571       7.8×
1024      fp32       0.072        0.076      0.464       6.5×
1024      complex64  0.125        0.071      0.473       3.8×
2048      fp32       0.135        0.083      0.570       4.2×
2048      complex64  0.255        0.080      0.510       2.0×

Backward (gradient, full autograd callback incl. trail recompute):

size (n)  dtype      Triton (ms)  CUDA (ms)  torch (ms)  Triton vs torch  routed → CUDA
256       fp32       0.303        0.483      2.421       8.0×
256       complex64  0.290        0.324      2.734       9.4×
512       fp32       0.265        0.330      2.601       9.8×
512       complex64  1.713        0.439      3.326       1.9×
1024      fp32       2.723        0.829      5.842       2.1×
1024      complex64  1.661        1.049      10.265      6.2×
2048      fp32       2.171        0.586      5.517       2.5×
2048      complex64  2.132        0.462      5.684       2.7×
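
As a check on the arithmetic, the sketch below recomputes the "Triton vs torch" column from the backward table above and lists the rows where the legacy CUDA time is lower than the Triton time. The data is copied from the table; the variable names are illustrative, and the rows it flags are a timing comparison only, not necessarily the cells marked in the routing table.

```python
# (n, dtype, triton_ms, cuda_ms, torch_ms), copied from the backward table above
BACKWARD = [
    (256, "fp32", 0.303, 0.483, 2.421),
    (256, "complex64", 0.290, 0.324, 2.734),
    (512, "fp32", 0.265, 0.330, 2.601),
    (512, "complex64", 1.713, 0.439, 3.326),
    (1024, "fp32", 2.723, 0.829, 5.842),
    (1024, "complex64", 1.661, 1.049, 10.265),
    (2048, "fp32", 2.171, 0.586, 5.517),
    (2048, "complex64", 2.132, 0.462, 5.684),
]

# Speedup of Triton over the pure-PyTorch oracle, rounded as in the table.
speedup_vs_torch = {(n, dt): round(torch_ms / tri, 1)
                    for n, dt, tri, _cuda, torch_ms in BACKWARD}

# Rows where the legacy CUDA kernel measured faster than Triton.
cuda_faster = [(n, dt) for n, dt, tri, cuda, _torch in BACKWARD if cuda < tri]
```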

Takeaways on this machine: Triton beats the pure-PyTorch oracle on every measured shape (roughly 2× to 12×). It is also faster than the legacy CUDA kernel on the smaller shapes: forward for every shape up to n=512 plus 1024 fp32, and backward for n=256 plus 512 fp32. On the larger/complex backward shapes the legacy CUDA kernel is still faster, so the shipped _routing.json transparently routes those cells to CUDA when the legacy .so is built (the ✓ rows above, matching the baked routing table).

Other GPUs will perform differently. These figures are specific to this laptop Ada GPU, this driver/toolkit, and these problem sizes — absolute times and the Triton-vs-CUDA crossover points shift with compute capability, memory bandwidth, PyTorch/Triton versions, and shape. The committed _routing.json reflects this dev host; regenerate it on your own hardware (see above) for routing decisions tuned to your GPU. Treat the table as illustrative, not as a performance guarantee.

Deprecation timeline

torch_structured ships Triton as the default backend in v1.2. The legacy CUDA C++ backend (csrc/) is being retired over a two-release deprecation cadence:

  • v1.2 (current): Triton is the default. TORCH_STRUCTURED_BACKEND=cuda still works for users who built _butterfly.so / _diag_mult.so / _hadamard.so locally, but emits a one-time DeprecationWarning at import time pointing here. The Monarch Mixer MathDx kernel (previously vendored under csrc/) is removed entirely in v1.2; see the CHANGELOG for the full file list.
  • v1.3 (next minor release, ~6 months out): CUDA build is default-disabled. csrc/ extensions stay in the source tree and can still be compiled via FORCE_CUDA=1, but the PyPI wheel does NOT include them. The DeprecationWarning still fires when a locally-built CUDA path is used.
  • v1.4+ (post-milestone): csrc/ tree, setup.py CUDA extension code, and _cuda_legacy/ are deleted. The standard 2-release deprecation cadence gives users two minor releases to migrate.

Migration: most users should set nothing and let the Triton default take over. If your hardware is not supported on the Triton path (e.g., Volta sm_70 or Turing sm_75), either build the legacy extensions and set TORCH_STRUCTURED_BACKEND=cuda, or pin to v1.1; see the "Triton backend" section above for hardware requirements.

Tests

pytest tests/

CUDA-only tests are automatically skipped when the corresponding extension is not built.
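
One common way to implement this kind of skip is to probe for the compiled module before running a test. The helper below is a sketch of that pattern, not the project's actual test code; extension_available is a hypothetical name.

```python
import importlib.util


def extension_available(module_name):
    """Return True if module_name can be located, without raising when it cannot."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # A missing parent package raises instead of returning None.
        return False


has_json = extension_available("json")                    # stdlib module, present
has_missing = extension_available("no_such_pkg_xyz.sub")  # parent package missing
```

A test could then be decorated with pytest.mark.skipif(not extension_available("torch_structured._hadamard_cuda"), reason="CUDA extension not built").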

Citation

See NOTICE for full upstream attributions and BibTeX entries for:

  • Dao, Gu, Eichhorn, Rudra, Ré, Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations, ICML 2019
  • Dao et al., Kaleidoscope, ICLR 2020
  • Thomas, Gu, Dao, Rudra, Ré, Learning Compressed Transforms with Low Displacement Rank, NeurIPS 2018
  • Dao et al., Monarch: Expressive Structured Matrices for Efficient and Accurate Training, ICML 2022
  • Fu, Arora, Grogan et al., Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture, NeurIPS 2023

License

Apache-2.0 (see LICENSE).
