PyTorch optimizer factory with modern init techniques
Project description
OptimFactory
Small utilities to make PyTorch optimizer setup and µParam/µP‑style initialization easier.
This repo currently provides:
- µP initialization helpers
  `mup_init(parameters)`: init all non-bias tensors with std `1/sqrt(fan_in)`; `mup_init_output(weight)`: output-layer init with std `1/fan_in`.
- µP parameter‑group factory
  `mup_param_group(parameters, base_lr, base_dim=256, weight_decay=1e-3, weight_decay_scale=True)`: builds param groups where learning rate and weight decay are scaled by fan-in.
- Optional param‑group splitting
  `muon_param_group_split(param_groups, dim_threshold=64)`: splits groups for separate optimizers (e.g. Muon vs AdamW) based on tensor shape/fan-in.
The code is intentionally lightweight and pure‑PyTorch.
Install
From source:
```
pip install -e .
```
This package requires Python ≥3.10 and torch (and torchvision only if you run MNIST examples).
Quick start
```python
import torch
import torch.nn as nn
import torch.optim as optim

from optimfactory import mup_init, mup_init_output, mup_param_group

model = nn.Sequential(
    nn.Linear(128, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)

# µP init: skip 1D bias tensors automatically
mup_init(model.parameters())
# output layer often uses a different scale
mup_init_output(model[-1].weight)

param_groups = mup_param_group(
    model.parameters(),
    base_lr=1e-3,
    base_dim=256,
    weight_decay=0.1,
    weight_decay_scale=True,
)

optimizer = optim.AdamW(param_groups, betas=(0.9, 0.98))
```
API details
mup_init(params)
Initializes each parameter tensor in params:
- if `param.ndim == 1` (bias / norm weight), leave it untouched
- otherwise compute `fan_in = prod(param.shape[1:])` and sample from `N(0, 1/sqrt(fan_in))`
mup_init_output(param)
Like `mup_init`, but uses std `1/fan_in`. Useful for final classifiers/heads.
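As a rough sketch, the two rules above amount to the following (an illustrative reimplementation under the stated rules, not the package source; the `*_sketch` names are made up):

```python
import math
import torch

def mup_init_sketch(parameters):
    # skip 1D tensors (biases / norm weights); otherwise std = 1/sqrt(fan_in)
    for p in parameters:
        if p.ndim == 1:
            continue
        fan_in = math.prod(p.shape[1:])
        with torch.no_grad():
            p.normal_(mean=0.0, std=1.0 / math.sqrt(fan_in))

def mup_init_output_sketch(weight):
    # output head: std = 1/fan_in instead of 1/sqrt(fan_in)
    fan_in = math.prod(weight.shape[1:])
    with torch.no_grad():
        weight.normal_(mean=0.0, std=1.0 / fan_in)
```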
mup_param_group(params, base_lr, base_dim=256, weight_decay=1e-3, weight_decay_scale=True)
Builds param groups keyed by (fan_in, ndim) so same‑shaped tensors share hyper‑params.
Scaling rules:
- For 1D tensors: `lr_scale = 1`
- For others: `lr_scale = base_dim / fan_in`
- Group LR: `base_lr * lr_scale`
- Group WD:
  - if `weight_decay_scale=True`: `weight_decay / lr_scale`
  - else: fixed `weight_decay`
Returned value is a list of dicts suitable for any PyTorch optimizer.
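In pseudocode, the grouping and scaling described above look roughly like this (an illustrative sketch under the stated rules, not the package source; `mup_param_group_sketch` is a made-up name):

```python
import math
from collections import defaultdict

def mup_param_group_sketch(params, base_lr, base_dim=256,
                           weight_decay=1e-3, weight_decay_scale=True):
    # bucket tensors by (fan_in, ndim) so same-shaped tensors share hyper-params
    buckets = defaultdict(list)
    for p in params:
        fan_in = 1 if p.ndim == 1 else math.prod(p.shape[1:])
        buckets[(fan_in, p.ndim)].append(p)

    groups = []
    for (fan_in, ndim), tensors in buckets.items():
        lr_scale = 1.0 if ndim == 1 else base_dim / fan_in
        wd = weight_decay / lr_scale if weight_decay_scale else weight_decay
        groups.append({"params": tensors,
                       "lr": base_lr * lr_scale,
                       "weight_decay": wd})
    return groups
```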
muon_param_group_split(param_groups, dim_threshold=64)
Given param groups (typically from `mup_param_group`), splits them into:
- `muon_group`: 2D tensors where `fan_in >= dim_threshold`
- `adam_group`: everything else
This is a convenience when you want to use a special optimizer for large matrices.
optimfactory does not ship an optimizer named “Muon”; if you use one, it’s from elsewhere.
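For example (a hedged sketch reusing `param_groups` and the `optim` import from the quick start; it assumes the split comes back as a `(muon_group, adam_group)` pair, and `optim.SGD` merely stands in for whatever matrix optimizer you actually use, e.g. an external Muon implementation):

```python
from optimfactory import muon_param_group_split

muon_group, adam_group = muon_param_group_split(param_groups, dim_threshold=64)

# large 2D matrices get the "matrix" optimizer; everything else goes to AdamW.
# Per-group lr/weight_decay set by mup_param_group override these optimizer defaults.
matrix_opt = optim.SGD(muon_group, lr=1e-3, momentum=0.9)  # stand-in for e.g. Muon
adam_opt = optim.AdamW(adam_group, betas=(0.9, 0.98))
```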
ComboOptimizer(optimizers) / ComboLRScheduler(schedulers)
Lightweight wrappers to treat multiple optimizers or LR schedulers as one object.
- `ComboOptimizer.step()` / `.zero_grad()` forward to each child optimizer.
- `ComboOptimizer` accepts optional `clip_grad_norm` and `grad_scaler` (`torch.amp.GradScaler`) for global clipping and AMP.
- `ComboLRScheduler.step()` forwards to each child scheduler.
- Both support `.state_dict()` and `.load_state_dict()` by storing child state dicts in a list.
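Continuing the sketch above, both optimizers can then be driven as one object (assuming `ComboOptimizer` is importable from the package top level like the other helpers; the loss is a dummy on random data just to show the call order):

```python
import torch
from optimfactory import ComboOptimizer

optimizer = ComboOptimizer([matrix_opt, adam_opt])

# one training step: zero_grad/step are forwarded to both child optimizers
loss = model(torch.randn(32, 128)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```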
Examples
- `example/mnist.py`: MNIST CNN/MLP hybrid with µP init and µP-scaled param groups.
  - It references `optim.Muon` and `anyschedule.AnySchedule`, which are external.
  - If you don't have them installed, set `USE_MUON=False` or use the basic example below.
- `example/basic_usage.py`: minimal MLP training loop showing only optimfactory usage.
Running examples:
```
python example/basic_usage.py
python example/mnist.py
```
Notes / roadmap
- The project is small; PRs for more init schemes, group rules, or example notebooks are welcome.
- If you want more µP theory background, search for “μParametrization / µP” papers and guides.
File details
Details for the file optimfactory-0.0.1-py3-none-any.whl.
File metadata
- Download URL: optimfactory-0.0.1-py3-none-any.whl
- Upload date:
- Size: 12.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `0be97c220349a4e86b0809d62e0642a976071214291972dec6985df020696674` |
| MD5 | `625b30af8962d63fc87122993962f361` |
| BLAKE2b-256 | `88c99bbe78ba5086c2af71f4413dbbe1856a21aa3e6058f486d63ac906fb2369` |