
Adams optimizer: next-generation optimizer blending element-wise methods with matrix-aware regularization


Adams Optimizer

Adams is a next-generation optimizer that blends the simplicity of element-wise methods with the stability benefits of matrix-aware regularization. It updates neural network parameters in both 1D (per-element) and 2D (per-matrix) ways, staying fast and easy to parallelize like Adam/AdamW while improving stability and generalization.

  • Stable: no loss spikes observed; no gradient clipping required.
  • Fast & scalable: element-wise updates + one rank-1 spectral decay step per matrix; easily parallelizable.
  • Simple: no epsilon hyperparameter; truly scale-invariant per-parameter update.

Definition 📝

[Figure: Adams pseudocode]
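Since the pseudocode figure does not render here, the following is a rough, unofficial sketch assembled from the description in the next section. Buffer names, defaults, and the exact ordering of bias correction and decay are assumptions, not the package's actual code:

```python
import torch
import torch.nn.functional as F

def adams_update(W, grad, m, n, u, t, lr=3e-4, betas=(0.9, 0.95), wd=0.1):
    """One sketched update for a matrix parameter W of shape (M, N).

    m, n: running first/second moment buffers (same shape as W).
    u:    persistent power-iteration vector (length M).
    t:    1-based step count, used for bias correction.
    """
    b1, b2 = betas
    m.mul_(b1).add_(grad, alpha=1 - b1)             # first moment
    n.mul_(b2).addcmul_(grad, grad, value=1 - b2)   # second moment
    m_hat = m / (1 - b1 ** t)
    n_hat = n / (1 - b2 ** t)
    W -= lr * torch.atan2(m_hat, torch.sqrt(n_hat))  # bounded, eps-free step

    # Decoupled spectral decay: one-step power iteration, then rank-1 decay.
    M = W.shape[0]
    v = F.normalize(W.t() @ u, dim=0)               # GEMV 1
    Wv = W @ v                                      # GEMV 2
    sigma = Wv.norm()                               # top singular value estimate
    u.copy_(Wv / sigma)
    W -= lr * wd * (M ** 0.5) * sigma * torch.outer(u, v)
    return W, u

# Toy usage with random gradients.
torch.manual_seed(0)
W = torch.randn(8, 4)
m, n = torch.zeros_like(W), torch.zeros_like(W)
u = F.normalize(torch.randn(8), dim=0)
for t in range(1, 11):
    W, u = adams_update(W, torch.randn_like(W), m, n, u, t)
```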

How Adams Works 🌟

1) Bounded, element-wise update (1D)

Small second-moment estimates are a major source of instability and loss spikes in Adam-like methods. Adams replaces the usual preconditioned step with a bounded update using atan2:

$$ \Delta \theta \propto \text{atan2}\big(\hat m_t,\sqrt{\hat n_t}\big), $$

which:

  • naturally bounds the step size,
  • removes the need for the epsilon hyperparameter,
  • yields true scale invariance of the update.
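Both properties can be seen directly in a toy re-implementation (the function below is illustrative, not the package API): the output is always in $(-\pi/2, \pi/2)$, and rescaling the gradient (so $\hat m$ scales by $c$ and $\hat n$ by $c^2$) leaves the step unchanged.

```python
import math
import torch

def adams_step(m_hat, n_hat):
    # Bounded in (-pi/2, pi/2); well-defined even as n_hat -> 0, so no epsilon.
    return torch.atan2(m_hat, torch.sqrt(n_hat))

m = torch.tensor([1e-8, 1.0, -2.0])
n = torch.tensor([1e-16, 1.0, 4.0])
steps = adams_step(m, n)

# Scale invariance: atan2(c*m, sqrt(c^2 * n)) == atan2(m, sqrt(n)) for c > 0.
c = 1000.0
assert torch.allclose(steps, adams_step(c * m, c ** 2 * n))
```

Note the first element: a tiny second moment ($\hat n = 10^{-16}$) would blow up Adam's $\hat m / \sqrt{\hat n}$ step without epsilon, while atan2 returns a bounded value.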

2) Spectral weight decay (2D)

For matrix parameters $W \in \mathbb{R}^{M \times N}$, spectral norm better reflects the scale relevant to activations than the Frobenius norm. Adams therefore applies decoupled spectral weight decay (akin to AdamW’s decoupling), replacing the usual $\tfrac{1}{2}|W|_F^2$ with the spectral norm $\tfrac{1}{2}\sigma_1^2$:

  • We compute a one-step power iteration with persistent state (same idea as PyTorch’s spectral_norm) to approximate the top singular triplet $(u_1, \sigma_1, v_1)$.
  • The decay term is applied as $\sqrt{M} u_1 \sigma_1 v^\top_1$ (the gradient of $\tfrac{1}{2}\sigma_1^2$, scaled by $\sqrt{M}$ to match the RMS of $W$) per update step.
  • This helps control activation scales and mitigates instabilities tied to large spectral norms.

Efficiency: the spectral step adds only two GEMV operations per matrix per update, comparable to a handful of extra element-wise ops. In typical FSDP/ZeRO setups the full weight matrix is available during forward/backward, so this integrates cleanly at scale.
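The two-GEMV step above can be sketched as follows (a stand-alone toy version, not the package's internal code; `spectral_decay_step` is a made-up name):

```python
import torch
import torch.nn.functional as F

def spectral_decay_step(W, u, lr, wd):
    """Decoupled spectral decay via one-step power iteration.
    u is persistent state carried across calls; only two GEMVs per call."""
    M = W.shape[0]
    v = F.normalize(W.t() @ u, dim=0)   # GEMV 1
    Wv = W @ v                          # GEMV 2
    sigma = Wv.norm()                   # estimate of the top singular value
    u = Wv / sigma
    W = W - lr * wd * (M ** 0.5) * sigma * torch.outer(u, v)
    return W, u

torch.manual_seed(0)
W = torch.randn(64, 32)
u = F.normalize(torch.randn(64), dim=0)
sigma_before = torch.linalg.matrix_norm(W, ord=2)
for _ in range(100):
    W, u = spectral_decay_step(W, u, lr=3e-3, wd=0.1)
# The top singular value shrinks; directions orthogonal to (u, v) are barely touched.
```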

Design Motivation 💡

Recent reports suggest that fully matrix-based optimizers (e.g., Muon) can be hard to implement/parallelize broadly and often show modest end-to-end benefits on large models (~1.1x or less), despite strong stability. Meanwhile, the dominant optimizer Adam is simple and fast but prone to instability and loss spikes.

Adams asks: Can we keep Adam’s speed and simplicity while gaining matrix-level stability?

Installation

pip install adams-torch

Quick Start 📈

You don’t need to manually broadcast parameters or all-reduce gradients—multi-GPU usage matches single-GPU usage. Fully compatible with torch.compile.

FSDP is not supported yet. Contributions welcome.

import os
import torch
import torch.distributed as dist
from adams import Adams_ZeRO  # main optimizer

def init():
    # Initialize distributed training if launched via torchrun/torch.distributed
    if "LOCAL_RANK" in os.environ:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Model().to(device)

    # Spectral decay applies to matrix-shaped params.
    # scalar_vector_weight_decay applies standard decoupled L2 to 0D/1D params.
    optimizer = Adams_ZeRO(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.1,                 # spectral decay for matrices
        scalar_vector_weight_decay=0.1,   # L2 for scalars/vectors
        betas=(0.9, 0.95)
    )

    # Sync any internal buffers across ranks if required by your setup.
    optimizer.broadcast_buffers(model.buffers())

    return model, optimizer

@torch.compile  # Optional: works with torch.compile
def train_step(model, optimizer, batch):
    loss = model(batch)        # forward; compute your loss
    loss.backward()            # backward
    optimizer.step()           # no gradient clipping needed
    optimizer.zero_grad(set_to_none=True)
    return loss

Notes ⚠️

As with other matrix-based optimizers (e.g., Muon), some care is needed:

  1. Non‑matrix parameters. Disable the matrix‑based part (spectral decay) for parameters that are scalars, vectors, or collections of vectors (e.g. LayerNorm, Embedding, Output Head, etc.) by setting param.use_spectral_decay = False. Adams uses a separate decoupled L2 term, controlled by scalar_vector_weight_decay (default 0.1).
  2. Batched matrices. Parameters that are conceptually multiple matrices concatenated along leading dimensions (e.g., attention QKV projections) should be expressed with shape (B, M, N). Adams treats all dimensions except the last two as batch dimensions. (P.S. In our experiments, we treat each attention head's q, k, and v as separate projection matrices, e.g., 24 matrices (8 heads × 3) in the QKV projection of an 8-head MHA layer.)
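The two notes above can be sketched as follows (the module shapes and dimensions here are illustrative; the `use_spectral_decay` attribute and the (B, M, N) convention are from the notes):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(512)
emb = nn.Embedding(1000, 512)

# 1) Scalars, vectors, and collections of vectors: opt out of spectral decay
#    so they receive the separate decoupled L2 term instead.
for p in list(ln.parameters()) + list(emb.parameters()):
    p.use_spectral_decay = False

# 2) Batched matrices: 8 heads * 3 projections = 24 (d_head, d_model)
#    matrices stored as a single (24, 64, 512) parameter. Adams treats the
#    leading dimension as a batch dimension, so each 64x512 slice gets its
#    own power-iteration state and spectral decay.
qkv = nn.Parameter(torch.randn(8 * 3, 64, 512) * 512 ** -0.5)
```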

Practical Tips ✏️

  • Hyperparameters: start with AdamW-like settings; the bounded update removes the epsilon hyperparameter. Adams also tolerates much larger weight decay (e.g., 1.0), which can improve generalization.
  • Stability: the bounded step and spectral decay together target sources of spikes linked to tiny second moments and large spectral norms.
  • Generalization & adversarial robustness: spectral regularization is widely observed to improve both, and Adams adopts a lightweight decoupled form.

References

  1. Scaling Exponents Across Parameterizations and Optimizers
  2. Adaptive Preconditioners Trigger Loss Spikes in Adam
  3. Muon: An optimizer for the hidden layers of neural networks
  4. Spectral Norm Regularization for Improving the Generalizability of Deep Learning
  5. Thinking from spectral norm gradient to new weight decay

License

Apache-2.0
