Drop-in throughput and memory optimisations for FAIR Hiera (4D-SDPA, gather/scatter, Triton kernels).

hiera-optim

Drop-in throughput and memory optimisations for FAIR's Hiera and its MAE variant. Two lines:

from hiera_optim import optimize
optimize(model)

These two lines restore fused attention (FlashAttention / cuDNN-attn) where the model was silently falling back to the math kernel, replace boolean mask indexing with torch.gather / scatter_, and unblock torch.compile. Outputs are numerically equivalent to the original within bf16 noise.

Results

H100 (GH200), bf16, full forward + backward.

Production config: Hiera-Base, 224x224, 8 in-chans, B=128

| | ms / step | samples / s | peak mem |
|---|---|---|---|
| FAIR baseline + torch.compile | 131.7 | 972 | 14.0 GB |
| hiera-optim + torch.compile | 70.3 | 1820 | 9.4 GB |
| speedup / saving | 1.88x | 1.87x | 33% |
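The derived columns follow from ms / step and the batch size. The plain-Python sketch below (nothing from hiera-optim) recomputes them. The ms / step figures are presumably rounded to 0.1 ms, so recomputed values can differ from the table in the last digit; for example, 131.7 / 70.3 gives 1.87x rather than 1.88x.

```python
# Recompute the derived columns of the production table from ms / step.
BATCH = 128  # samples per step in the production config


def samples_per_s(ms_per_step: float, batch: int = BATCH) -> float:
    """Throughput implied by a per-step wall time."""
    return batch / (ms_per_step / 1000.0)


baseline_ms, optim_ms = 131.7, 70.3   # ms / step
baseline_gb, optim_gb = 14.0, 9.4     # peak memory, GB

speedup = baseline_ms / optim_ms
memory_saving = 1.0 - optim_gb / baseline_gb

print(f"baseline    {samples_per_s(baseline_ms):7.1f} samples/s")
print(f"hiera-optim {samples_per_s(optim_ms):7.1f} samples/s")
print(f"speedup {speedup:.2f}x, memory saving {memory_saving:.0%}")
```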

Across the variant matrix (444 GH200 cells)

| | median | mean | best | worst |
|---|---|---|---|---|
| speedup | 1.35x | 1.42x | 2.10x | 1.10x |
| memory ratio (peak mem, hiera-optim / baseline) | 74% | 73% | 29% | 99% |

RTX 4090, Hiera-Base, 8 in-chans, B=32: 1.81x eager, 2.86x with torch.compile.

Full matrix and per-cell numbers: MATRIX_RESULTS.md.

Install

pip install hiera-optim

From source:

git clone https://github.com/avocardio/hiera-optim.git
cd hiera-optim
pip install -e .

Requires PyTorch >= 2.5 and Triton >= 2.3. Recognises FAIR Hiera in-tree (models.hiera) or via PyPI (hiera-transformer).
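To check those minimums in an existing environment, the installed versions can be read with importlib.metadata. This is a stdlib-only sketch, not something hiera-optim ships. It assumes the distributions are named torch and triton (nightly PyTorch builds may ship Triton as pytorch-triton instead) and compares only major.minor.

```python
from importlib.metadata import PackageNotFoundError, version
from itertools import takewhile

# Minimums from the requirements line above (assumed distribution names).
MINIMUMS = {"torch": (2, 5), "triton": (2, 3)}


def major_minor(ver: str) -> tuple[int, int]:
    """'2.5.1+cu124' -> (2, 5); drops local tags and pre-release suffixes."""
    release = ver.split("+")[0]                     # strip e.g. "+cu124"
    major, minor = release.split(".")[:2]
    minor_digits = "".join(takewhile(str.isdigit, minor))  # "6rc1" -> "6"
    return int(major), int(minor_digits or 0)


def check_requirements() -> dict[str, str]:
    """Map each package that misses its minimum to a short reason."""
    problems = {}
    for pkg, minimum in MINIMUMS.items():
        try:
            found = version(pkg)
        except PackageNotFoundError:
            problems[pkg] = "not installed"
            continue
        if major_minor(found) < minimum:
            problems[pkg] = f"{found} < {minimum[0]}.{minimum[1]}"
    return problems


if __name__ == "__main__":
    print(check_requirements() or "torch and triton meet the minimums")
```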

Usage

import torch
from hiera_optim import optimize
from hiera import mae_hiera_base_224

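model = mae_hiera_base_224(pretrained=False, in_chans=3, input_size=(224, 224))
optimize(model)  # in-place module swap; weights preserved
model = model.cuda()
model = torch.compile(model, mode="default", dynamic=False)

# Inputs stay fp32; autocast runs the forward pass in bf16.
x = torch.randn(128, 3, 224, 224, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    loss, *_ = model(x, mask_ratio=0.6)  # the MAE forward returns the loss first
loss.backward()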

optimize(model) makes two changes in place and keeps the existing weights:

  1. Swap every MaskUnitAttention for a 4D-reshape variant so that PyTorch SDPA dispatches to FlashAttention / cuDNN-attn / mem-efficient attention instead of the math fallback. FAIR's original passes 5-D tensors to SDPA; the fused kernels reject them, so every call silently runs the math path, which costs ~13x per call on Ada and ~6x on Hopper.
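  2. Replace the boolean-mask indexing x[mask.tile(...)] and x_dec[mask] = ... with explicit torch.gather / scatter_. This removes a slow indexing_backward_kernel from the backward pass, as well as the aten::nonzero graph break (boolean indexing has a data-dependent output shape) that interrupts torch.compile.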
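The idea behind change 1 can be sketched with plain PyTorch on CPU. This is not hiera-optim's code: the sizes are made up, and the 5-D layout only mimics the [batch, heads, windows, tokens, head_dim] shape that FAIR's MaskUnitAttention hands to SDPA. Because attention is independent for every (batch, head, window) slice, folding heads and windows into one axis gives the same result in a 4-D layout the fused kernels accept.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical sizes: batch, heads, mask-unit windows, tokens per window, head dim.
B, H, W, N, D = 2, 4, 16, 49, 32

# MaskUnitAttention-style layout: the extra "windows" axis makes q/k/v 5-D.
q, k, v = (torch.randn(B, H, W, N, D) for _ in range(3))

# 5-D call: valid, but on CUDA only the math backend takes extra batch dims.
out_5d = F.scaled_dot_product_attention(q, k, v)


def to_4d(t: torch.Tensor) -> torch.Tensor:
    """Fold heads and windows into a single axis: (B, H * W, N, D)."""
    return t.reshape(B, H * W, N, D)


# 4-D call on the folded tensors, then unfold back to the 5-D layout.
out_4d = F.scaled_dot_product_attention(to_4d(q), to_4d(k), to_4d(v))
out_4d = out_4d.reshape(B, H, W, N, D)

print(torch.allclose(out_5d, out_4d, atol=1e-5))
```

On a GPU, wrapping the 4-D call in torch.nn.attention.sdpa_kernel with only the fused backends enabled is one way to confirm that a fused kernel actually runs.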

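Change 2 relies on every sample keeping the same number of tokens, which holds for MAE masking at a fixed mask_ratio. The sketch below compares boolean indexing with gather / scatter_ on a toy (B, L, C) tensor. It skips the mask.tile(...) step that FAIR's code applies over mask-unit tokens, and the sizes are hypothetical. The gathered output has a static shape, so no aten::nonzero is needed.

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes: batch, tokens, channels, tokens kept per sample.
B, L, C, keep = 2, 8, 4, 3

x = torch.randn(B, L, C)
# MAE-style mask: exactly `keep` True entries per row, chosen at random.
noise = torch.rand(B, L)
mask = noise.argsort(dim=1).argsort(dim=1) < keep  # per-row rank < keep

# Boolean indexing: data-dependent output shape (aten::nonzero underneath).
ref = x[mask].reshape(B, keep, C)

# Gather: positions of kept tokens in their original order, static shape.
idx = torch.argsort((~mask).int(), dim=1, stable=True)[:, :keep]  # (B, keep)
idx_c = idx.unsqueeze(-1).expand(-1, -1, C)                      # (B, keep, C)
visible = torch.gather(x, 1, idx_c)

# Scatter: write visible tokens back into a full-length decoder buffer.
x_dec = torch.zeros(B, L, C).scatter_(1, idx_c, visible)

# Reference for the scatter: boolean-mask assignment.
x_dec_ref = torch.zeros(B, L, C)
x_dec_ref[mask] = ref.reshape(-1, C)

print(torch.equal(visible, ref), torch.equal(x_dec, x_dec_ref))
```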
Optional

from hiera_optim import optimize, enable_stage_checkpointing

optimize(model, sdpa_backend="auto")           # per-block SDPA hint
enable_stage_checkpointing(model, stages=(2,)) # OOM lever
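The internals of enable_stage_checkpointing are not described here. The following is only a generic sketch of stage-level activation checkpointing with torch.utils.checkpoint, which trades recomputation in the backward pass for lower activation memory. CheckpointedStage is a hypothetical name, not part of hiera-optim, and a toy MLP stands in for a Hiera stage. If stage indices are 0-based, stages=(2,) would presumably cover Hiera-Base's third stage, which holds 16 of its 24 blocks.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStage(nn.Module):
    """Hypothetical wrapper: runs a stage without storing its inner activations."""

    def __init__(self, blocks: nn.Module):
        super().__init__()
        self.blocks = blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.is_grad_enabled():
            # Activations inside self.blocks are recomputed during backward.
            return checkpoint(self.blocks, x, use_reentrant=False)
        return self.blocks(x)


torch.manual_seed(0)
stage = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(4, 16, requires_grad=True)

# Plain forward/backward for reference.
stage(x).sum().backward()
grad_ref = x.grad.clone()

# Same stage, checkpointed: gradients match the reference.
x.grad = None
CheckpointedStage(stage).train()(x).sum().backward()
print(torch.allclose(x.grad, grad_ref))
```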

GPU support

| Architecture | SM | Status |
|---|---|---|
| Ada (RTX 4090, L40) | SM89 | Tested |
| Hopper (H100, GH200) | SM90 | Tested |
| Ampere (A100) | SM80 | Should work |
| Blackwell (B200) | SM100 | Should work |
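To see which row of the GPU support table a machine falls under, torch.cuda.get_device_capability() returns the (major, minor) compute capability. This is a minimal sketch that only knows the four rows listed; other parts, such as SM86 consumer Ampere cards, are reported as not listed rather than guessed.

```python
import torch

# (major, minor) -> (architecture, status), copied from the GPU support table.
SUPPORT = {
    (8, 9): ("Ada", "Tested"),
    (9, 0): ("Hopper", "Tested"),
    (8, 0): ("Ampere", "Should work"),
    (10, 0): ("Blackwell", "Should work"),
}


def support_status(capability: tuple[int, int]) -> str:
    """Format the table entry for a compute capability, if there is one."""
    arch, status = SUPPORT.get(capability, ("unknown", "Not listed"))
    return f"SM{capability[0]}{capability[1]} ({arch}): {status}"


if torch.cuda.is_available():
    print(support_status(torch.cuda.get_device_capability()))
else:
    print("No CUDA device visible")
```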

Tests

pip install -e ".[test]"
pytest

The 112 tests span all 5 Hiera variants, q_pool in {1, 2, 3}, a range of mask ratios, bf16/fp16/fp32, 1D/2D/3D inputs, and both classification and MAE models.

License

MIT.

Download files

Source Distribution

hiera_optim-0.1.0.tar.gz (30.6 kB)

Uploaded Source

Built Distribution

hiera_optim-0.1.0-py3-none-any.whl (24.2 kB)

Uploaded Python 3

File details

Details for the file hiera_optim-0.1.0.tar.gz.

File metadata

  • Download URL: hiera_optim-0.1.0.tar.gz
  • Upload date:
  • Size: 30.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.13

File hashes

Hashes for hiera_optim-0.1.0.tar.gz
Algorithm Hash digest
SHA256 d49cd4c56e4e8214a6516375a741963bd17f1a8eb26186d3e3d5db6c4e1f8565
MD5 170846394bc6b3f06e5091001ac754d0
BLAKE2b-256 6dde1e1f3ae43fb02f4c25cd9c02a6d44dd8ee172becd6ae9a5dc7f1a7df5a52

File details

Details for the file hiera_optim-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: hiera_optim-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 24.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.13

File hashes

Hashes for hiera_optim-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 3245a7386c7afaff0e1918d43a4bbbe8f4883c83ed09dac5f508b15edc032602
MD5 8847f8834d4997cfda2bf4cc67f06930
BLAKE2b-256 b5ae9034d08285e15def614e81905ecaea00227854570ba9117ecb9863072f99
