Fast Newton-Schulz Algorithm with Kernels


Gram Newton-Schulz: A Fast, Hardware-Aware Newton-Schulz Algorithm for Muon

Authors: Jack Zhang, Noah Amsel, Berlin Chen, Tri Dao
Blogpost: https://dao-ailab.github.io/blog/2026/gram-newton-schulz/

Achieve up to 2x faster Newton-Schulz with Gram Newton-Schulz and symmetric CuTeDSL GEMM kernels!

What you're probably here for:

  1. Gram Newton-Schulz: https://github.com/Dao-AILab/gram-newton-schulz/blob/main/gram_newton_schulz/gram_newton_schulz.py
  2. Gram Newton-Schulz Restart Autotune: https://github.com/Dao-AILab/gram-newton-schulz/blob/main/gram_newton_schulz/restart_autotune.py
  3. Symmetric GEMMs for Hopper and Blackwell in CuTeDSL: https://github.com/Dao-AILab/quack/blob/main/quack/gemm_symmetric.py

About

Gram Newton-Schulz is a hardware-aware algorithm for polar decomposition that is mathematically equivalent to and faster than Newton-Schulz. Polar decomposition is most commonly used in Muon, and Gram Newton-Schulz serves as a direct drop-in for standard Newton-Schulz with no training accuracy tradeoff.

Instead of iterating on the full $X \in \mathbb{R}^{n \times m}$, Gram Newton-Schulz iterates on the small, square, symmetric Gram matrix $XX^\top \in \mathbb{R}^{n \times n}$ (taking $n \leq m$ without loss of generality), which lowers FLOPs and enables symmetric GEMM kernels.
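The algebraic identity behind this is that one Newton-Schulz step $a X + b (XX^\top) X + c (XX^\top)^2 X$ equals $(aI + bG + cG^2) X$ with $G = XX^\top$. A self-contained sanity check (not part of the package; the coefficients are arbitrary placeholders):

import torch

# One Newton-Schulz step computed directly on X matches the same step
# computed through the Gram matrix G = X X^T.
torch.manual_seed(0)
a, b, c = 3.0, -3.2, 1.2  # placeholder coefficients
X = torch.randn(64, 256, dtype=torch.float64)
G = X @ X.T  # small, symmetric 64 x 64 Gram matrix

direct = a * X + b * (X @ X.T) @ X + c * (X @ X.T) @ ((X @ X.T) @ X)
via_gram = (a * torch.eye(64, dtype=torch.float64) + b * G + c * (G @ G)) @ X

assert torch.allclose(direct, via_gram)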

Gram Newton-Schulz

Input: $X \in \mathbb{R}^{n \times m}$, coefficients $\{(a_t, b_t, c_t)\}_{t=1}^{5}$

  1. $X \gets X / (\|X\|_{F} + \epsilon)$   // Normalize singular values to $[0, 1]$.   $\epsilon = 10^{-7}$
  2. $X \gets \texttt{float16}(X)$   // Cast to half precision for speed
  3. If $m < n$:   $X \gets X^\top$   // Trick to make $XX^\top$ cheaper
  4. $R_0 \gets XX^\top$
  5. $Q_0 \gets I$
  6. For $t = 1, \ldots, 5$:
    • If $t = 3$:   // Restart to stabilize
      • $X \gets Q_2 X$
      • $R_2 \gets XX^\top$
      • $Q_2 \gets I$
    • $Z_t \gets b_t R_{t-1} + c_t R_{t-1}^2$
    • $Q_t \gets Q_{t-1} Z_t + a_t Q_{t-1}$
    • $RZ_t \gets R_{t-1} Z_t + a_t R_{t-1}$
    • $R_t \gets Z_t (RZ_t) + a_t (RZ_t)$
  7. $X \gets Q_5 X$
  8. If $m < n$:   $X \gets X^\top$   // Undo trick
  9. Return $X$
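
For exposition, here is a minimal plain-PyTorch sketch of the pseudocode above. The function name is ours and the matmuls are unfused; the package itself implements these steps with fused symmetric CuTeDSL GEMM kernels.

import torch

def gram_newton_schulz_reference(X, coefs, eps=1e-7):
    # Illustrative rendering of the pseudocode; not the packaged implementation.
    X = X / (X.norm() + eps)              # 1. normalize singular values to [0, 1]
    X = X.half()                          # 2. cast to half precision for speed
    transposed = X.size(1) < X.size(0)    # 3. transpose so the Gram matrix is the small one
    if transposed:
        X = X.T
    I = torch.eye(X.size(0), dtype=X.dtype, device=X.device)
    R = X @ X.T                           # 4. Gram matrix
    Q = I                                 # 5.
    for t, (a, b, c) in enumerate(coefs, start=1):   # 6.
        if t == 3:                        # restart to stabilize
            X = Q @ X
            R = X @ X.T
            Q = I
        Z = b * R + c * (R @ R)
        Q = Q @ Z + a * Q
        RZ = R @ Z + a * R
        R = Z @ RZ + a * RZ
    X = Q @ X                             # 7. apply accumulated polynomial
    return X.T if transposed else X       # 8.-9. undo trick and return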

Installation

Requirements:

  • NVIDIA Hopper (H100) or Blackwell (B200/B300) GPU
  • PyTorch 2.7.1+
  • CUDA 12.9+

Install PyTorch first, then install from PyPI with pip install gram-newton-schulz --no-build-isolation or from source:

pip install . --no-build-isolation

--no-build-isolation is required so that pip uses your existing CUDA-enabled PyTorch instead of installing torch-cpu in an isolated build environment.

This will install:

  • gram-newton-schulz (this package)
  • nvidia-cutlass-dsl 4.4.2
  • quack-kernels>=0.3.10

Usage

WARNING: torch.compile is known to sometimes have issues on Blackwell; set TORCH_COMPILE_DISABLE=1 to disable torch.compile.

Gram Newton-Schulz

from gram_newton_schulz import GramNewtonSchulz, POLAR_EXPRESS_COEFFICIENTS

gram_NS = GramNewtonSchulz(
            ns_coefficients=POLAR_EXPRESS_COEFFICIENTS,
            gram_newton_schulz_reset_iterations=[2]
        )
result = gram_NS(X)

GramNewtonSchulz is a callable initialized with ns_coefficients (a list of lists of floats) and gram_newton_schulz_reset_iterations (a list of ints), the iterations immediately after which Gram Newton-Schulz's iterative loop restarts for stability. For example, [2] restarts after the 2nd iteration; [2, 4] restarts after the 2nd iteration and again after the 4th.
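
For instance, the same constructor with two restarts:

gram_NS_two_restarts = GramNewtonSchulz(
    ns_coefficients=POLAR_EXPRESS_COEFFICIENTS,
    gram_newton_schulz_reset_iterations=[2, 4],  # restart after the 2nd and 4th iterations
)
result = gram_NS_two_restarts(X)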

To find the best restart location(s) for a given number of restarts (num-restarts) and a set of coefficients, run

python -m gram_newton_schulz.autotune_restarts --num-restarts 1 --coefs "4.0848,-6.8946,2.9270;3.9505,-6.3029,2.6377;3.7418,-5.5913,2.3037;2.8769,-3.1427,1.2046;2.8366,-3.0525,1.2012"

For 5 steps of Newton-Schulz, we recommend --num-restarts 1 for maximum speed while maintaining numerical stability. However, users who experience numerical instability or use more than 5 steps should consider using more restarts.

Muon

The Muon class supports an auxiliary scalar optimizer that updates all non-Muon parameters, custom functions that split model weights for orthogonalization, and Gram Newton-Schulz with autotuned restart locations.

import torch
from torch.optim import AdamW
from gram_newton_schulz import Muon, YOU_COEFFICIENTS

qkv_params = []
regular_2d_params = []
scalar_params = []

# model is assumed to be your torch.nn.Module.
for name, param in model.named_parameters():
    if 'qkv_weight' in name:
        qkv_params.append(param)
    elif param.ndim >= 2:
        regular_2d_params.append(param)
    else:
        scalar_params.append(param)

scalar_optimizer = AdamW(
    scalar_params,
    lr=1e-3,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

def qkv_split_fn(param: torch.Tensor):
    """
    Split Wqkv into [Wq, Wk, Wv].

    Assumes param has shape (3*hidden_dim, hidden_dim) where the first dimension
    is concatenated [Q, K, V] weights.
    """
    hidden_dim = param.size(1)
    Wq = param[:hidden_dim, :]
    Wk = param[hidden_dim:2*hidden_dim, :]
    Wv = param[2*hidden_dim:, :]
    return [Wq, Wk, Wv]

def qkv_recombine_fn(splits):
    """Recombine [Wq, Wk, Wv] back into Wqkv."""
    return torch.cat(splits, dim=0)

muon_param_groups = []

muon_param_groups.append({
    'params': qkv_params,
    'param_split_fn': qkv_split_fn,
    'param_recombine_fn': qkv_recombine_fn,
    'lr': 3e-3,
    'weight_decay': 0.1,
    'momentum': 0.95,
})

muon_param_groups.append({
    'params': regular_2d_params,
    'lr': 3e-3,
    'weight_decay': 0.1,
    'momentum': 0.95,
})

optimizer = Muon(
    params=muon_param_groups,
    scalar_optimizer=scalar_optimizer,
    lr=3e-3,
    weight_decay=0.1,
    momentum=0.95,
    nesterov=True,
    adjust_lr='rms_norm',
    ns_algorithm='gram_newton_schulz',
    ns_use_kernels=True,
    ns_coefficients=YOU_COEFFICIENTS,
    gram_newton_schulz_num_restarts=1,
)
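
As a rough sketch of how the optimizer is driven (model, dataloader, and loss_fn are placeholders; per the description above, Muon's step is assumed to also advance the auxiliary scalar optimizer):

for batch, targets in dataloader:
    model.zero_grad(set_to_none=True)  # clear grads of all params, whichever optimizer owns them
    loss = loss_fn(model(batch), targets)
    loss.backward()
    optimizer.step()  # assumed to step both the Muon params and the scalar optimizer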

See example.py for a full training example.

Citation

If you use this codebase, or otherwise find our work valuable, please cite Gram Newton-Schulz:

@misc{GramNewtonSchulz,
  title   = {Gram Newton-Schulz},
  author  = {Jack Zhang and Noah Amsel and Berlin Chen and Tri Dao},
  year    = {2026},
  url     = {https://dao-ailab.github.io/blog/2026/gram-newton-schulz/}
}
