
Momentum-Aligned Gradient Masking — block-wise stochastic masking wrapper for PyTorch optimizers


Magma

Momentum-Aligned Gradient Masking for Adaptive Optimizers

Magma is a lightweight wrapper that applies block-wise stochastic masking to any PyTorch optimizer, modulated by the alignment between gradient momentum and the current gradient. It is an implementation of the algorithm described in "On Surprising Effectiveness of Masking Updates in Adaptive Optimizers" (arXiv:2602.15322).

The core insight is deceptively simple. At each step, a per-parameter Bernoulli coin flip decides whether to keep or discard the update. Updates that survive are further scaled by a smoothed alignment score derived from the cosine similarity between the gradient and its exponential moving average. The base optimizer's internal state (e.g., Adam's running means or RMSprop's squared gradients) is always updated; only the parameter itself is masked.

This acts as a form of implicit regularization, particularly effective under the heterogeneous curvature and heavy-tailed gradient noise characteristic of transformer training.

Installation

pip install magma-optimizer

Or directly from source:

pip install git+https://github.com/andrijdavid/magma-optimizer.git

Usage

Magma wraps any instantiated PyTorch optimizer. The interface mirrors the familiar torch.optim API.

from magma import Magma
import torch

model = ...  # your model
base = torch.optim.Adam(model.parameters(), lr=1e-3)

optimizer = Magma(
    base,
    mask_prob=0.5,        # prob of keeping an update
    tau=2.0,              # temperature for the alignment sigmoid
    momentum_beta=0.9,    # EMA coefficient for gradient momentum
    alignment_ema=0.9,    # EMA coefficient for smoothing the alignment score
    exclude=set(model.embed.parameters()),  # skip masking on embeddings
)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

The exclude parameter accepts a set of tensors that should bypass masking entirely. The paper recommends excluding embedding layers, as their update dynamics differ from attention and MLP blocks.
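For models with several embedding tables, the exclude set can also be gathered by module type rather than by attribute. A minimal sketch (embedding_params is an illustrative helper, not part of the package):

import torch.nn as nn

def embedding_params(model):
    # Gather every parameter owned by an nn.Embedding module.
    return {
        p
        for m in model.modules()
        if isinstance(m, nn.Embedding)
        for p in m.parameters()
    }

optimizer = Magma(base, exclude=embedding_params(model))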

Algorithm

The procedure, applied at each step for each non-excluded parameter:

  1. Update momentum EMA: μ = β·μ + (1−β)·g, with β = momentum_beta
  2. Compute alignment: s̃ = sigmoid(cosine_similarity(μ, g) / τ), with τ = tau
  3. Smooth alignment: s = α·s_prev + (1−α)·s̃, with α = alignment_ema
  4. Run the base optimizer step (all internal states update normally)
  5. Sample mask: m ~ Bernoulli(p), with p = mask_prob
  6. Apply: θ = (s·m)·θ_new + (1 − s·m)·θ_old

When the mask is zero, the parameter reverts to its pre-step value. When the mask is one, the update is scaled by the alignment score. The base optimizer sees every gradient regardless.
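For concreteness, here is the procedure above as a self-contained, single-parameter sketch. It is illustrative only, not the package's actual implementation; magma_step, base_step, and state are hypothetical names.

import torch

def magma_step(p, base_step, state, *, mask_prob=0.5, tau=2.0,
               momentum_beta=0.9, alignment_ema=0.9):
    # p: a parameter tensor with p.grad populated.
    # base_step: callable that runs the wrapped optimizer's update for p in place.
    # state: dict persisting "mu" (gradient EMA) and "s" (smoothed alignment).
    g = p.grad.detach()

    # 1. Update the gradient-momentum EMA.
    mu = state.get("mu", torch.zeros_like(g))
    mu.mul_(momentum_beta).add_(g, alpha=1 - momentum_beta)
    state["mu"] = mu

    # 2.-3. Alignment score, squashed by a temperature sigmoid, then smoothed.
    cos = torch.nn.functional.cosine_similarity(mu.flatten(), g.flatten(), dim=0)
    s_tilde = torch.sigmoid(cos / tau)
    s = alignment_ema * state.get("s", s_tilde) + (1 - alignment_ema) * s_tilde
    state["s"] = s

    # 4. Base optimizer step; its internal state always sees the gradient.
    theta_old = p.detach().clone()
    base_step()
    theta_new = p.detach()

    # 5.-6. One Bernoulli draw per parameter tensor, then interpolate
    # between the pre-step and post-step values.
    m = float(torch.rand(()) < mask_prob)
    with torch.no_grad():
        p.copy_((s * m) * theta_new + (1 - s * m) * theta_old)

Note that theta_old is captured before the base step precisely so that a zero mask can restore it; the wrapped optimizer's state is never rolled back.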

Citation

@article{joo2026magma,
  title={On Surprising Effectiveness of Masking Updates in Adaptive Optimizers},
  author={Joo, Taejong and Xia, Wenhan and Kim, Cheolmin and Zhang, Ming and Ie, Eugene},
  journal={arXiv preprint arXiv:2602.15322},
  year={2026}
}

License

MIT
