
Compile-first PyTorch optimizer library - AdamW, Muon, SOAP/Shampoo, PSGD, Schedule-Free, and 30+ more with torch.compile fusion and composable features


HeavyBall


HeavyBall is an optimizer library for PyTorch where every optimizer is assembled from composable, compiled building blocks. It includes API-compatible replacements for torch.optim.AdamW, SGD, and RMSprop, alongside Muon, SOAP (Shampoo), PSGD (Kronecker), LATHER, ADOPT, Schedule-Free, LaProp, and others.

The building blocks, over 100 functions in utils.py, are each compiled with torch.compile(fullgraph=True) and fuse into Triton kernels. Features like MARS gradient correction, cautious updates, and ECC state compression are implemented as chainable transforms that work as flags on any optimizer. DDP and FSDP are supported, with automatic repartitioning for second-order methods.

Quick Start

pip install heavyball

Requires PyTorch >= 2.2.

from heavyball import AdamW
opt = AdamW(model.parameters(), lr=1e-3)

from heavyball import SOAP  # Shampoo-based preconditioning
opt = SOAP(model.parameters(), lr=3e-3)

from heavyball import LATHER  # Lie-group Adam Through Harmonic Eigenbasis Rotations
opt = LATHER(model.parameters(), lr=1e-3)

from heavyball import Muon
opt = Muon(model.parameters(), lr=0.02, ecc="bf16+8", mars=True, caution=True)

from heavyball import SplitOpt, Muon, AdamW
opt = SplitOpt([
    {'params': matrices, 'optimizer': Muon, 'lr': 0.02},
    {'params': vectors, 'optimizer': AdamW, 'lr': 1e-3},
])

The API matches torch.optim: the same parameter groups and the same step()/zero_grad() interface. See examples/ for training scripts. By default, HeavyBall consumes gradients during step(), clearing p.grad once it has been read. Pass consume_grad=False if your training loop needs gradients to remain attached after the optimizer step.
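Because the interface is the standard torch.optim loop, a minimal sketch looks like the following. It uses torch.optim.AdamW as a stand-in so the snippet runs without heavyball installed; per the compatibility claim above, swapping in heavyball.AdamW is a one-line change.

```python
import torch
from torch import nn

# Tiny model and data for illustration.
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
w0 = model.weight.detach().clone()

# torch.optim.AdamW stands in for heavyball.AdamW here so the
# snippet runs without heavyball installed.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

for _ in range(3):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # with heavyball, this also clears p.grad by default
```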

Optimizers

The library covers first-order methods (AdamW, NAdam, RMSprop, ADOPT, LaProp, SGD), orthogonal methods (Muon), Shampoo-based preconditioning (SOAP and variants), PSGD with Kronecker and low-rank factorization, Schedule-Free training, and SAM.

Full list

First-order: AdamW, NAdam, RMSprop, ADOPT, AdEMAMix, LaProp, SignLaProp, SGD, Scion, UnscaledAdamW, AdamC, SUDSAdamW

Schedule-Free: SFAdamW

Schedule-Free optimizers override .eval() and .train() to swap between training and evaluation parameter states. Call opt.eval() before validation and opt.train() before resuming training.

Orthogonal: Muon, MuonAdamW, MuonLaProp, HyperBallAdamW, OrthoLaProp, LaPropOrtho

Shampoo-based (SOAP): SOAP, SOAPNAdam, SOAPAdEMAMix, SOLP

PSGD (Kronecker): PSGDKron, LATHER, PSGDPRO

PSGD (Low-Rank): PSGDLRA

SAM: SAMWrapper, MSAMLaProp

SAMWrapper requires a closure passed to step().

MSAMLaProp overrides .eval() and .train() to swap between training and evaluation parameter states. Call opt.eval() before validation and opt.train() before resuming training.

Meta: SplitOpt
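The eval()/train() contract of the Schedule-Free optimizers (and MSAMLaProp) can be pictured with a toy state swap. This is an illustration of the pattern only, not HeavyBall's implementation: the optimizer keeps a second copy of each parameter and swaps which copy is live when switching modes.

```python
import torch

class ToySwapOptimizer:
    """Toy illustration of the train/eval state swap; NOT HeavyBall's
    implementation. It keeps a second copy of the parameter and swaps
    which copy is live when switching modes."""

    def __init__(self, p):
        self.p = p
        self.eval_copy = p.detach().clone()  # e.g. averaged weights
        self.training = True

    def eval(self):
        if self.training:
            self.p.data, self.eval_copy = self.eval_copy, self.p.data.clone()
            self.training = False

    def train(self):
        if not self.training:
            self.p.data, self.eval_copy = self.eval_copy, self.p.data.clone()
            self.training = True

p = torch.ones(2)
opt = ToySwapOptimizer(p)
p.data += 1.0   # a "training" update moves p to [2, 2]
opt.eval()      # validation now sees the held-out copy [1, 1]
evaluated = p.detach().clone()
opt.train()     # training resumes from [2, 2]
```

Forgetting opt.eval() before validation would evaluate the raw training weights instead of the evaluation state, which is why the note above insists on the call.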
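The closure that SAMWrapper requires follows the standard torch.optim pattern for optimizers that must re-evaluate the loss during a step. The sketch below uses torch.optim.LBFGS as a stand-in, since it accepts a closure with the same shape and runs without heavyball installed.

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# torch.optim.LBFGS stands in for SAMWrapper here: both take a closure
# that re-computes the loss so the optimizer can probe extra points
# within a single step.
opt = torch.optim.LBFGS(model.parameters(), lr=0.1, max_iter=5)

def closure():
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = opt.step(closure)
```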

Composable Features

These flags compose freely. For example, LaProp(..., ecc="bf16+8", mars=True, caution=True, palm=True) is valid. They are available on all optimizers except SAMWrapper and SplitOpt, which delegate to inner optimizers.

Flag | Effect
mars=True | Applies MARS variance reduction via previous gradients.
caution=True | Masks update elements that disagree with the gradient direction.
ecc="bf16+8" | Compresses optimizer state to bf16 + int8 correction (3 bytes vs fp32's 4). See ECC.
param_ecc="bf16+8" | Applies the same compression to parameters.
palm=True | Enables PaLM-style beta2 scheduling. Only available on optimizers that have a beta2.
gradient_clipping=... | Clips incoming gradients. Accepts "l2_clip_", "rmsnorm_clip_", "trust_region_clip_", "a_law_compress", "mu_law_compress", "softsign_compress", or a custom callable.
update_clipping=... | Clips outgoing updates after all transforms. Same options as gradient_clipping.
promote=True | Promotes gradients to fp32 before the update.
warmup_steps=N | Applies linear learning-rate warmup over N steps.

ECC

ECC stores each optimizer state tensor as a bf16 value plus an int8 correction term (3 bytes total vs fp32's 4 bytes), based on the approach from FlashOptim. HeavyBall integrates ECC as a composable flag: correction tensors are attached as attributes at call time, so any built-in optimizer handles ECC without per-optimizer changes.

opt = AdamW(model.parameters(), lr=1e-3, ecc="bf16+8")
opt = Muon(model.parameters(), lr=0.02, ecc="bf16+8", param_ecc="bf16+8")  # state + params

For first-order optimizers (where all state is momentum and variance), bf16+8 gives roughly 25% state memory savings compared to fp32. For second-order methods, preconditioner matrices are not compressed, so total savings are lower. The encode and decode operations are fully elementwise and fuse into the compiled kernel.

Available modes: bf16+8, bf16+16, fp16+8, fp16+16.
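The quoted 25% follows directly from the byte counts: an Adam-style optimizer keeps two state tensors (momentum and variance) per parameter, and bf16+8 stores each entry as a bf16 value plus an int8 correction.

```python
# Optimizer-state bytes per parameter for an Adam-style optimizer,
# which keeps two state tensors (momentum and variance).
fp32_state = 2 * 4        # two fp32 tensors -> 8 bytes per parameter
ecc_state = 2 * (2 + 1)   # two (bf16 + int8) tensors -> 6 bytes per parameter

savings = 1 - ecc_state / fp32_state   # 0.25, i.e. the quoted 25%
```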

Distributed Training

HeavyBall works with both DDP and FSDP. First-order optimizers are elementwise and operate directly on FSDP shards with no repartitioning. Second-order methods (Muon, SOAP, PSGD) need the full parameter to compute their update, so HeavyBall auto-detects FSDP-sharded parameters on the first step and repartitions them with a metadata-first all_to_all_single exchange: each weight matrix is deterministically assigned to one rank, shard metadata is exchanged up front, the owning rank reconstructs the full parameter, computes the update once, and returns the updated shards. Compared to DDP-style redundant updates, this saves both compute and memory at the cost of extra communication.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from heavyball import Muon

model = FSDP(model, use_orig_params=True)  # use_orig_params required for shape detection
opt = Muon(model.parameters(), lr=0.02)
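The deterministic assignment step can be sketched as follows. The hashing scheme here is illustrative (HeavyBall's actual assignment may differ); the point is that every rank computes the same owner from local information, so no negotiation round is needed before the all_to_all_single exchange.

```python
import zlib

def owner_rank(param_name: str, world_size: int) -> int:
    """Map a parameter to exactly one rank, deterministically.
    Every rank computes the same answer from local information
    alone. (Illustrative scheme, not HeavyBall's actual one.)"""
    return zlib.crc32(param_name.encode()) % world_size

# Each weight matrix lands on exactly one of 4 ranks.
assignments = {f"layer{i}.weight": owner_rank(f"layer{i}.weight", 4)
               for i in range(8)}
```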

For non-FSDP sharding backends, capture the original parameter shapes before wrapping:

from heavyball import SOAP, capture_param_shapes

shapes = capture_param_shapes(model)
model = your_sharding_wrapper(model)
opt = SOAP(model.parameters(), lr=3e-3, orig_shapes=shapes)

Building Custom Optimizers

Every built-in optimizer is a chain of FunctionTransforms, and the same API is available for building custom optimizers. The Parallel transform runs several transform paths on the same input and combines their outputs with a merge function, which is useful for grafted optimizers or ensemble updates.

import heavyball.chainable as C


def graft(outputs, eps=1e-8):
    adam_update, sgd_update = outputs
    return [s * (a.norm() / s.norm().add(eps)) for a, s in zip(adam_update, sgd_update)]


class GraftedAdam(C.BaseOpt):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, warmup_steps=0, multi_tensor=True):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        warmup_steps=warmup_steps, multi_tensor=multi_tensor)
        branch = C.Parallel(branches=[[C.scale_by_adam], [C.identity]], merge_fn=graft)
        super().__init__(params, defaults, fns=(branch,))

Custom optimizers that inherit from BaseOpt get ECC, MARS, caution, clipping, warmup, and stochastic rounding automatically.
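Stripped of optimizer details, the chain/Parallel structure reduces to plain function composition. The names below are hypothetical stand-ins for illustration, not the chainable API:

```python
from functools import reduce

def chain(*fns):
    """Apply transforms left to right, like a chain of FunctionTransforms."""
    return lambda update: reduce(lambda u, fn: fn(u), fns, update)

def parallel(branches, merge_fn):
    """Run branches on the same input and merge their outputs,
    like C.Parallel with a merge function."""
    return lambda update: merge_fn([branch(update) for branch in branches])

double = chain(lambda u: [2 * x for x in u])
negate = chain(lambda u: [-x for x in u])
step = parallel([double, negate],
                merge_fn=lambda outs: [a + b for a, b in zip(*outs)])

result = step([1.0, 2.0])  # 2x + (-x) = x elementwise -> [1.0, 2.0]
```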

Key transforms: scale_by_adam, scale_by_laprop, scale_by_soap, scale_by_psgd, scale_by_adopt, scale_by_ademamix, orthogonalize_update, exp_avg, nesterov_ema, heavyball_momentum, mars, palm_beta2, sign, identity.

How it compiles

Every building block in utils.py is wrapped with torch.compile(fullgraph=True). When one compiled function calls another, the inner function inlines and nested calls fuse into the same compiled graph.

For fused first-order optimizers (AdamW, LaProp, ADOPT, NAdam, AdEMAMix), the entire update runs in a single compiled function that fuses into a minimal set of kernels. Stochastic rounding, ECC encode/decode, weight decay, and cautious masking all fold into the same graph, cutting memory traffic to a minimum. Plain Adam drops from 14 reads + 9 writes spread across O(N) kernels to 4 reads + 3 writes in one kernel, roughly a 3x speedup.
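The 3x figure is memory-bandwidth arithmetic: an elementwise optimizer step is bandwidth-bound, so step time scales with total tensor reads plus writes.

```python
# An elementwise optimizer step is memory-bandwidth bound, so step time
# scales with the total number of tensor reads and writes per element.
eager_traffic = 14 + 9   # reads + writes across many separate kernels
fused_traffic = 4 + 3    # reads + writes in one fused kernel

speedup = eager_traffic / fused_traffic   # ~3.3x, matching the ~3x claim
```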

Second-order methods compile their preconditioning steps separately: Newton-Schulz iterations (Muon) and Kronecker factor updates (PSGD, SOAP) each compile as individual regions, while their elementwise portions still fuse. This avoids suboptimal code paths, at the cost of one graph break.

Custom optimizers built via the chainable API inherit this behavior.

Benchmarks

HeavyBall includes a benchmark suite via LightBench that tests for silent optimizer failures across difficulty levels. Results and methodology are documented in docs/benchmark.md.

benchmarks/bench_release_optimizers.py measures optimizer latency, with AdamW step times dropping from 10.63 ms in HeavyBall 2 to 4.15 ms in HeavyBall 3.

Migrating

From 2.x See the 3.0.0 migration guide for renamed classes, removed kwargs, and checkpoint conversion.

From 1.x See the 2.0.0 migration notes, then follow the 3.0.0 guide.

Contributing

To contribute, fork the repository, install with pip install -e .[dev], and run pytest.

License

BSD-2-Clause, see LICENSE.

The name "HeavyBall" comes from Polyak's heavy-ball method, the momentum technique underlying most modern optimizers.
