Drop-in throughput and memory optimisations for FAIR Hiera (4D-SDPA, gather/scatter, Triton kernels).
hiera-optim
Drop-in throughput and memory optimisations for FAIR's Hiera and its MAE variant. Two lines:
from hiera_optim import optimize
optimize(model)
restore the model's silent math-fallback attention to FlashAttention / cuDNN-attn, replace boolean mask indexing with torch.gather / scatter_, and unblock torch.compile. Numerically equivalent within bf16 noise.
Results
H100 (GH200), bf16, full forward + backward.
Production config: Hiera-Base, 224x224, 8 in-chans, B=128
| | ms / step | samples / s | peak mem |
|---|---|---|---|
| FAIR baseline + torch.compile | 131.7 | 972 | 14.0 GB |
| hiera-optim + torch.compile | 70.3 | 1820 | 9.4 GB |
| speedup / saving | 1.88x | 1.87x | 33% |
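The columns of the production table are related by simple arithmetic: throughput is batch size divided by step time, and the saving is one minus the memory ratio. The sketch below recomputes them from the reported figures; the function name is illustrative, and small differences from the table (for example 1821 vs 1820 samples / s) presumably come from rounding in the reported ms / step.

```python
BATCH = 128  # batch size from the production config above


def samples_per_second(ms_per_step: float, batch: int = BATCH) -> float:
    """Throughput implied by a step time in milliseconds."""
    return batch / (ms_per_step / 1000.0)


baseline = samples_per_second(131.7)   # FAIR baseline + torch.compile
optimized = samples_per_second(70.3)   # hiera-optim + torch.compile
speedup = 131.7 / 70.3                 # ratio of step times
mem_saving = 1 - 9.4 / 14.0            # fraction of peak memory saved

print(f"{baseline:.0f} -> {optimized:.0f} samples/s, "
      f"{speedup:.2f}x, {mem_saving:.0%} memory saved")
```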
Across the variant matrix (444 GH200 cells)
| | median | mean | best | worst |
|---|---|---|---|---|
| speedup | 1.35x | 1.42x | 2.10x | 1.10x |
| memory ratio | 74% | 73% | 29% | 99% |
RTX 4090, Hiera-Base, 8 in-chans, B=32: 1.81x eager, 2.86x with torch.compile.
Full matrix and per-cell numbers: MATRIX_RESULTS.md.
Install
pip install hiera-optim
From source:
git clone https://github.com/avocardio/hiera-optim.git
cd hiera-optim
pip install -e .
Requires PyTorch >= 2.5 and Triton >= 2.3. Recognises FAIR Hiera in-tree (models.hiera) or via PyPI (hiera-transformer).
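A quick way to check an environment against these minimums is to read the installed distribution versions. This is a minimal sketch using only the standard library; the helper names are illustrative, and it assumes the packages are installed under the distribution names `torch` and `triton` (some nightly builds ship Triton under a different name).

```python
import re
from importlib import metadata

# Minimum versions quoted from the requirements line above.
MINIMUMS = {"torch": (2, 5), "triton": (2, 3)}


def parse_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from strings such as '2.5.1+cu124'."""
    m = re.match(r"(\d+)\.(\d+)", version)
    if m is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(m.group(1)), int(m.group(2))


def meets_minimum(version: str, minimum: tuple[int, int]) -> bool:
    return parse_major_minor(version) >= minimum


def check_environment() -> dict:
    """Map each distribution to True/False, or None if it is not installed."""
    report = {}
    for dist, minimum in MINIMUMS.items():
        try:
            report[dist] = meets_minimum(metadata.version(dist), minimum)
        except metadata.PackageNotFoundError:
            report[dist] = None
    return report


print(check_environment())
```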
Usage
import torch
from hiera_optim import optimize
from hiera import mae_hiera_base_224
model = mae_hiera_base_224(pretrained=False, in_chans=3, input_size=(224, 224))
optimize(model)
model = torch.compile(model, mode="default", dynamic=False)
x = torch.randn(128, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
loss, *_ = model(x, mask_ratio=0.6)
loss.backward()
optimize(model) does two things, in place, weights preserved:
- Swap every MaskUnitAttention for a 4D-reshape variant so PyTorch SDPA dispatches to FlashAttention / cuDNN-attn / mem-efficient instead of math. FAIR's original feeds SDPA a 5-D tensor which the fused kernels reject, costing ~13x per call on Ada, ~6x on Hopper.
- Swap x[mask.tile(...)] and x_dec[mask] = ... for explicit torch.gather / scatter_. Removes a slow indexing_backward_kernel and the aten::nonzero graph break that stops torch.compile.
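Both ideas can be illustrated in isolation. The following is a minimal sketch, not hiera-optim's actual implementation: the function names are hypothetical, the 5-D layout `[batch, heads, windows, tokens, head_dim]` is an assumption, and the gather path assumes every sample keeps the same number of tokens so that integer indices can replace the boolean mask.

```python
import torch
import torch.nn.functional as F


def sdpa_4d(q, k, v):
    """Run SDPA on 5-D inputs through a 4-D view.

    Assumed layout: [batch, heads, windows, tokens, head_dim].
    Heads and windows are merged into one dim so SDPA sees 4-D tensors.
    """
    H, W = q.shape[1], q.shape[2]
    out = F.scaled_dot_product_attention(
        q.flatten(1, 2), k.flatten(1, 2), v.flatten(1, 2)
    )
    return out.unflatten(1, (H, W))  # back to the 5-D layout


def gather_tokens(x, keep_idx):
    """x: [B, N, C]; keep_idx: [B, K] integer indices of kept tokens."""
    idx = keep_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])
    return torch.gather(x, 1, idx)


def scatter_tokens(x_vis, keep_idx, num_tokens, fill):
    """Write kept tokens into a [B, num_tokens, C] buffer filled with `fill`."""
    B, K, C = x_vis.shape
    out = fill.expand(B, num_tokens, C).clone()
    idx = keep_idx.unsqueeze(-1).expand(-1, -1, C)
    return out.scatter_(1, idx, x_vis)


torch.manual_seed(0)

# 4-D view vs. direct 5-D call (CPU, fp32).
q = torch.randn(2, 2, 3, 4, 8)
k = torch.randn(2, 2, 3, 6, 8)
v = torch.randn(2, 2, 3, 6, 8)
attn_4d = sdpa_4d(q, k, v)
attn_5d = F.scaled_dot_product_attention(q, k, v)

# gather / scatter_ vs. boolean-mask indexing.
B, N, C, K = 2, 10, 4, 6
x = torch.randn(B, N, C)
keep_idx = torch.rand(B, N).argsort(dim=1)[:, :K].sort(dim=1).values
mask = torch.zeros(B, N, dtype=torch.bool)
mask[torch.arange(B).unsqueeze(1), keep_idx] = True
visible = gather_tokens(x, keep_idx)
restored = scatter_tokens(visible, keep_idx, N, torch.zeros(C))
```

Because the output shapes of `gather` and `scatter_` depend only on the index tensor's shape, not on how many mask entries are true, no data-dependent `nonzero` is needed, which is the property that matters for `torch.compile`.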
Optional
from hiera_optim import optimize, enable_stage_checkpointing
optimize(model, sdpa_backend="auto") # per-block SDPA hint
enable_stage_checkpointing(model, stages=(2,)) # OOM lever
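The source does not describe how `enable_stage_checkpointing` works internally. Assuming it applies standard activation checkpointing to the selected stages, the general trade (recompute activations during backward instead of storing them) can be sketched with `torch.utils.checkpoint`. The `CheckpointedStage` wrapper and the toy stage below are hypothetical and only illustrate the technique.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStage(nn.Module):
    """Hypothetical wrapper: recompute `stage` activations during backward."""

    def __init__(self, stage: nn.Module):
        super().__init__()
        self.stage = stage

    def forward(self, x):
        if self.training and torch.is_grad_enabled():
            # Activations inside `stage` are not stored; they are recomputed
            # when gradients are needed, trading compute for memory.
            return checkpoint(self.stage, x, use_reentrant=False)
        return self.stage(x)


torch.manual_seed(0)
stage = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
wrapped = CheckpointedStage(stage)
x = torch.randn(4, 16, requires_grad=True)

# Reference pass without checkpointing.
ref = stage(x)
ref.sum().backward()
grad_ref = x.grad.clone()
x.grad = None
stage.zero_grad()

# Same computation with checkpointing: same outputs and gradients.
out = wrapped(x)
out.sum().backward()
grad_ckpt = x.grad.clone()
```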
GPU support
| Architecture | SM | Status |
|---|---|---|
| Ada (RTX 4090, L40) | SM89 | Tested |
| Hopper (H100, GH200) | SM90 | Tested |
| Ampere (A100) | SM80 | Should work |
| Blackwell (B200) | SM100 | Should work |
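To see which row of the table applies to a given machine, the compute capability reported by PyTorch can be looked up against it. This is a small sketch; the helper names are illustrative, and the mapping simply transcribes the table above, so any SM version not listed there (for example SM86 consumer Ampere cards) is reported as not listed.

```python
# (major, minor) compute capability -> (architecture, status), from the table above.
SUPPORT = {
    (8, 9): ("Ada", "Tested"),
    (9, 0): ("Hopper", "Tested"),
    (8, 0): ("Ampere", "Should work"),
    (10, 0): ("Blackwell", "Should work"),
}


def support_status(major: int, minor: int) -> str:
    arch, status = SUPPORT.get((major, minor), ("unknown", "Not listed"))
    return f"SM{major}{minor}: {arch}, {status}"


def current_gpu_status() -> str:
    """Status for the first visible CUDA device, if any."""
    try:
        import torch
    except ImportError:
        return "PyTorch not installed"
    if not torch.cuda.is_available():
        return "No CUDA device visible"
    return support_status(*torch.cuda.get_device_capability())


print(current_gpu_status())
```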
Tests
pip install -e ".[test]"
pytest
112 tests cover all 5 Hiera variants x q_pool {1, 2, 3} x mask ratios x bf16/fp16/fp32 x 1D/2D/3D inputs x classification + MAE.
License
MIT.