OptimFactory

PyTorch optimizer factory with modern initialization techniques.

Small utilities to make PyTorch optimizer setup and µParam/µP‑style initialization easier.

This repo currently provides:

  • µP initialization helpers
    • mup_init(parameters): init all non‑bias tensors with std 1/sqrt(fan_in).
    • mup_init_output(weight): output layer init with std 1/fan_in.
  • µP parameter‑group factory
    • mup_param_group(parameters, base_lr, base_dim=256, weight_decay=1e-3, weight_decay_scale=True): builds param groups where learning rate and weight decay are scaled by fan‑in.
  • Optional param‑group splitting
    • muon_param_group_split(param_groups, dim_threshold=64): split groups for separate optimizers (e.g. Muon vs AdamW) based on tensor shape/fan‑in.

The code is intentionally lightweight and depends only on PyTorch.

Install

From source:

pip install -e .

This package requires Python ≥ 3.10 and torch; torchvision is needed only for the MNIST example.

Quick start

import torch
import torch.nn as nn
import torch.optim as optim

from optimfactory import mup_init, mup_init_output, mup_param_group

model = nn.Sequential(
    nn.Linear(128, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)

# µP init: skip 1D bias tensors automatically
mup_init(model.parameters())
# output layer often uses a different scale
mup_init_output(model[-1].weight)

param_groups = mup_param_group(
    model.parameters(),
    base_lr=1e-3,
    base_dim=256,
    weight_decay=0.1,
    weight_decay_scale=True,
)

optimizer = optim.AdamW(param_groups, betas=(0.9, 0.98))

API details

mup_init(params)

Initializes each parameter tensor in params:

  • if param.ndim == 1 (bias / norm weight), leave untouched
  • otherwise compute fan_in = prod(param.shape[1:])
  • sample from a normal distribution with mean 0 and std 1/sqrt(fan_in)
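
A minimal sketch of equivalent logic (illustrative, not the library source; the helper name is made up):

import math
import torch

def mup_init_sketch(params):
    # Re-implements the rule above: skip 1D tensors, scale std by fan-in.
    for p in params:
        if p.ndim == 1:
            continue  # biases / norm weights keep their default init
        fan_in = math.prod(p.shape[1:])
        with torch.no_grad():
            p.normal_(mean=0.0, std=1.0 / math.sqrt(fan_in))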

mup_init_output(param)

Like mup_init, but uses std 1/fan_in. Useful for final classifiers/heads.
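
In sketch form, reusing the imports above (again illustrative rather than the library source):

def mup_init_output_sketch(weight):
    # Output-layer rule: std = 1 / fan_in instead of 1 / sqrt(fan_in).
    fan_in = math.prod(weight.shape[1:])
    with torch.no_grad():
        weight.normal_(mean=0.0, std=1.0 / fan_in)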

mup_param_group(params, base_lr, base_dim=256, weight_decay=1e-3, weight_decay_scale=True)

Builds param groups keyed by (fan_in, ndim), so tensors with the same fan-in and dimensionality share hyperparameters.

Scaling rules:

  • For 1D tensors: lr_scale = 1
  • For others: lr_scale = base_dim / fan_in
  • Group LR: base_lr * lr_scale
  • Group WD:
    • if weight_decay_scale=True: weight_decay / lr_scale
    • else: fixed weight_decay

The return value is a list of dicts that can be passed to any PyTorch optimizer.
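
As a worked example, applying these rules to the quick-start model (base_lr=1e-3, base_dim=256, weight_decay=0.1, weight_decay_scale=True) gives the values below; the numbers are computed by hand from the rules above, and the dict keys are the standard PyTorch param-group keys:

# Linear(128, 512).weight: fan_in=128 -> lr_scale=2.0 -> lr=2e-3, wd=0.05
# Linear(512, 10).weight:  fan_in=512 -> lr_scale=0.5 -> lr=5e-4, wd=0.2
# 1D biases:               lr_scale=1.0 -> lr=1e-3, wd=0.1
for group in param_groups:
    shapes = [tuple(p.shape) for p in group["params"]]
    print(shapes, group["lr"], group["weight_decay"])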

muon_param_group_split(param_groups, dim_threshold=64)

Given param groups (typically from mup_param_group), split into:

  • muon_group: 2D tensors where fan_in >= dim_threshold
  • adam_group: everything else

This is a convenience when you want to use a special optimizer for large matrices. optimfactory does not ship an optimizer named “Muon”; if you use one, it’s from elsewhere.
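
A usage sketch, assuming muon_param_group_split returns the two group lists as a (muon, adam) tuple; AdamW stands in for the external matrix optimizer so the snippet runs on its own:

from optimfactory import muon_param_group_split

muon_groups, adam_groups = muon_param_group_split(param_groups, dim_threshold=64)

# In practice muon_groups would feed an external Muon-style optimizer;
# AdamW is only a runnable placeholder here.
opt_big = optim.AdamW(muon_groups)
opt_rest = optim.AdamW(adam_groups, betas=(0.9, 0.98))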

ComboOptimizer(optimizers) / ComboLRScheduler(schedulers)

Lightweight wrappers to treat multiple optimizers or LR schedulers as one object.

  • ComboOptimizer.step() / .zero_grad() forward to each child optimizer.
  • ComboOptimizer accepts optional clip_grad_norm and grad_scaler (torch.amp.GradScaler) for global clipping and AMP.
  • ComboLRScheduler.step() forwards to each child scheduler.
  • Both support .state_dict() and .load_state_dict() by storing child state dicts in a list.
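
Continuing the split sketch above, both optimizers (plus matching schedulers) can then be driven as one object; the list-of-children constructor follows the signatures in the heading, the rest is illustrative:

from optimfactory import ComboOptimizer, ComboLRScheduler

combo_opt = ComboOptimizer([opt_big, opt_rest])
combo_sched = ComboLRScheduler([
    optim.lr_scheduler.CosineAnnealingLR(opt_big, T_max=100),
    optim.lr_scheduler.CosineAnnealingLR(opt_rest, T_max=100),
])

x = torch.randn(32, 128)          # dummy batch for the quick-start model
combo_opt.zero_grad()
loss = model(x).square().mean()   # dummy loss
loss.backward()
combo_opt.step()                  # steps both child optimizers
combo_sched.step()                # steps both child schedulers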

Examples

  • example/mnist.py: MNIST CNN/MLP hybrid with µP init and µP‑scaled param groups.
    • It references optim.Muon and anyschedule.AnySchedule, which are external.
    • If you don’t have them installed, set USE_MUON=False or use the basic example below.
  • example/basic_usage.py: minimal MLP training loop showing only optimfactory usage.

Running examples:

python example/basic_usage.py
python example/mnist.py

Notes / roadmap

  • The project is small; PRs for more init schemes, group rules, or example notebooks are welcome.
  • If you want more µP theory background, search for “μParametrization / µP” papers and guides.
