Muon optimizer

Project description

Muon: An optimizer for the hidden layers of neural networks

This repo contains an implementation of the Muon optimizer, originally described in this thread and this writeup (https://kellerjordan.github.io/posts/muon/).

Installation

pip install git+https://github.com/KellerJordan/Muon
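
Given the distribution files listed further down this page, the package also appears to be installable directly from PyPI:

pip install muon-optimizer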

Usage

Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and internal gains/biases should be optimized using AdamW.

# Instead of a single AdamW over all parameters, i.e. a setup like:
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)
# partition the parameters and let Muon handle the hidden weights:

from muon import MuonWithAuxAdam
# Find ≥2D parameters in the body of the network -- these should be optimized by Muon
hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
# Find everything else -- these should be optimized by AdamW
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
exterior_weights = [*model.head.parameters(), *model.embed.parameters()]
# Create the optimizer
# Note: you can also use multiple groups of each type with different hparams if you want.
muon_group = dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True)
adam_group = dict(params=hidden_gains_biases+exterior_weights, lr=3e-4,
                  betas=(0.9, 0.95), weight_decay=0.01, use_muon=False)
optimizer = MuonWithAuxAdam([muon_group, adam_group])

You'll have to replace model.body, model.head, and model.embed with whichever submodules are appropriate for your model. For example, for a ConvNet, Muon should optimize all of the convolutional filters except the first one, and AdamW should optimize everything else; a sketch of this partitioning follows.
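
A minimal sketch for a hypothetical ConvNet, assuming a model with a features conv stack and a classifier head (both attribute names are placeholders; adapt them to your architecture):

import torch
from muon import MuonWithAuxAdam

# Collect every Conv2d filter in the (hypothetical) model.features stack...
conv_filters = [m.weight for m in model.features.modules()
                if isinstance(m, torch.nn.Conv2d)]
# ...and hand all but the first one to Muon.
first_filter, *hidden_filters = conv_filters
hidden_ids = {id(p) for p in hidden_filters}
# Everything else (first filter, biases, norm gains, classifier head) goes to AdamW.
other_params = [p for p in model.parameters() if id(p) not in hidden_ids]

muon_group = dict(params=hidden_filters, lr=0.02, weight_decay=0.01, use_muon=True)
adam_group = dict(params=other_params, lr=3e-4, betas=(0.9, 0.95),
                  weight_decay=0.01, use_muon=False)
optimizer = MuonWithAuxAdam([muon_group, adam_group])

From there, training steps look the same as with any PyTorch optimizer: loss.backward(), then optimizer.step(), then optimizer.zero_grad().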

Example usage

Example use in the NanoGPT speedrun

Example use in the CIFAR-10 speedrun

Hyperparameter tuning

Typically, the default values of momentum (0.95), nesterov (True), and ns_steps (5) work well. The only hyperparameter that must be tuned is the learning rate. It should transfer with constant muP-style scaling: as you scale up the model size, you shouldn't need to retune it.
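
Concretely, here is the Muon group from the usage example above with those defaults spelled out (a sketch; treating momentum as a per-group key is an assumption about the MuonWithAuxAdam API, so check the implementation before relying on it):

muon_group = dict(
    params=hidden_weights,
    lr=0.02,            # the one hyperparameter that needs tuning
    momentum=0.95,      # default; rarely needs changing
    weight_decay=0.01,
    use_muon=True,      # route this group through Muon rather than AdamW
)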

Benchmarks

For a comparison between AdamW, Shampoo, SOAP, and Muon for training a 124M-parameter transformer, see here.

Accomplishments

More learning resources and results about Muon

Citation

@misc{jordan2024muon,
  author       = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and
                  Franz Cesista and Laker Newhouse and Jeremy Bernstein},
  title        = {Muon: An optimizer for hidden layers in neural networks},
  year         = {2024},
  url          = {https://kellerjordan.github.io/posts/muon/}
}

Download files

Download the file for your platform.

Source Distribution

muon_optimizer-0.1.0.tar.gz (6.6 kB)

Uploaded Source

Built Distribution

muon_optimizer-0.1.0-py3-none-any.whl (7.1 kB)

Uploaded Python 3

File details

Details for the file muon_optimizer-0.1.0.tar.gz.

File metadata

  • Download URL: muon_optimizer-0.1.0.tar.gz
  • Upload date:
  • Size: 6.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.3

File hashes

Hashes for muon_optimizer-0.1.0.tar.gz:

  • SHA256: 65c50441f29b7248e586383d371262eb9062653e820b11143c81a7a10ba48da7
  • MD5: c102e651eb96a06038b2bb403f6c47b6
  • BLAKE2b-256: 028e753080860c9c0f5333bfe20ebb194f293af77274cfaf3f2bd999511ecf6e
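
To check a download against these digests, one option is Python's standard hashlib (a minimal sketch; the filename assumes the tarball sits in the current directory):

import hashlib

expected = "65c50441f29b7248e586383d371262eb9062653e820b11143c81a7a10ba48da7"
with open("muon_optimizer-0.1.0.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == expected, "SHA256 mismatch: corrupted or tampered download"

pip can also enforce digests on its own via hash-checking mode (pip install --require-hashes -r requirements.txt).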


File details

Details for the file muon_optimizer-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: muon_optimizer-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 7.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.3

File hashes

Hashes for muon_optimizer-0.1.0-py3-none-any.whl:

  • SHA256: 4b8e22e7c1a8e73c254c8f6941128ad490af1aba883d4eec0c7f0b8e2f330735
  • MD5: 08d0c5c2d3f0a22a394aceb0a97ac055
  • BLAKE2b-256: e8e7ae0b654bb12f7dfeb688852695e3c54d9ce0ac51e6fdcbeebe8b8d95c0f1

