Compile-friendly, world-size-aware unit scaling
dd-unit-scaling
A production-ready, thin wrapper around graphcore-research/unit-scaling that makes u-μP work reliably with torch.compile, FSDP2, and distributed training at scale. Used to train Toto 2.0.
Background
Why u-μP?
Toto 2.0 uses u-μP (Unit-Scaled Maximal Update Parameterization) rather than standard μP. Standard μP requires running a base model to compute reference scales, then transferring those scales to the target model size. u-μP eliminates this — scaling factors are derived directly from layer fan-in/fan-out, so there is no base model metadata to manage. u-μP has also been shown to converge better on decoder-only transformers.
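To make "derived directly from fan-in" concrete: for a hidden linear layer, unit scaling multiplies the matmul output by 1/sqrt(fan_in), so unit-variance inputs and weights give roughly unit-variance outputs at any width, with no reference model involved. The plain-Python sketch below checks this numerically. It covers only the forward rule for hidden layers (backward scales and u-μP's learning-rate rules are omitted), and `unit_scaled_dot` / `output_variance` are hypothetical helpers, not dd-unit-scaling functions.

```python
import math
import random
import statistics


def unit_scaled_dot(x: list[float], w: list[float]) -> float:
    """One output unit of a unit-scaled hidden linear layer: dot(x, w) / sqrt(fan_in)."""
    fan_in = len(x)
    return sum(xi * wi for xi, wi in zip(x, w)) / math.sqrt(fan_in)


def output_variance(fan_in: int, samples: int = 1000, seed: int = 0) -> float:
    """Empirical output variance for unit-variance Gaussian inputs and weights."""
    rng = random.Random(seed)
    outputs = []
    for _ in range(samples):
        x = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]
        w = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]
        outputs.append(unit_scaled_dot(x, w))
    return statistics.pvariance(outputs)


# Variance stays near 1 regardless of fan-in; an unscaled dot product would
# instead have variance close to fan_in.
for fan_in in (64, 256):
    print(fan_in, round(output_variance(fan_in), 2))
```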
Why a thin wrapper?
The upstream unit_scaling library implements u-μP correctly in principle, but has accumulated small issues that prevent it from working with torch.compile and distributed training at scale. Rather than fork the entire library, dd-unit-scaling re-exports everything from upstream and overrides only the broken pieces. This keeps the surface area minimal and makes it easy to drop if upstream fixes land.
Installation
```bash
pip install dd-unit-scaling
```
For Muon-family optimizers, also install dion:
```bash
pip install git+https://github.com/microsoft/dion.git
```
Usage
For a conceptual guide on applying u-μP to your model, see the unit-scaling demo notebook.
Drop-in replacement for unit_scaling:
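```python
import torch
import dd_unit_scaling as uu
from dd_unit_scaling import functional as U

dim = 64  # example sizes; unit-variance inputs and weights
x, weight, bias, w = torch.randn(8, dim), torch.randn(dim, dim), torch.zeros(dim), torch.ones(dim)
out = U.linear(x, weight, bias)
normed = U.rms_norm(x, normalized_shape=(dim,), weight=w)
```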
Optimizers:
```python
# AdamW for bias/norm/output params
opt_adam = uu.AdamW(bias_params, lr=1e-3)
# Muon-family for weight params
opt_muon = uu.Muon(weight_params, lr=0.02)
opt_dion2 = uu.Dion2(weight_params, lr=0.02)
opt_normuon = uu.NorMuon(weight_params, lr=0.02, use_polar_express=True)
```
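The optimizer snippet assumes `bias_params` and `weight_params` already exist. One possible way to build them, following its comments (hidden weight matrices to the Muon family; biases, norm gains, and output-layer parameters to AdamW), is sketched below. It works on `(name, shape)` pairs so it runs without PyTorch. The helper `split_param_names` and the assumption that output-layer names contain `"head"` or `"output"` are hypothetical conventions about your model, not part of dd-unit-scaling.

```python
def split_param_names(
    named_shapes: list[tuple[str, tuple[int, ...]]],
    output_keywords: tuple[str, ...] = ("head", "output"),
) -> tuple[list[str], list[str]]:
    """Return (muon_names, adamw_names).

    2-D hidden weight matrices go to the Muon-family optimizer. Biases and norm
    gains (1-D) and any parameter whose name matches an output keyword go to AdamW.
    """
    muon, adamw = [], []
    for name, shape in named_shapes:
        is_output = any(k in name for k in output_keywords)
        if len(shape) == 2 and not is_output:
            muon.append(name)
        else:
            adamw.append(name)
    return muon, adamw


shapes = [
    ("blocks.0.attn.qkv.weight", (768, 256)),
    ("blocks.0.attn.qkv.bias", (768,)),
    ("blocks.0.norm.weight", (256,)),
    ("head.weight", (10, 256)),
]
muon_names, adamw_names = split_param_names(shapes)
print(muon_names)   # ['blocks.0.attn.qkv.weight']
print(adamw_names)  # ['blocks.0.attn.qkv.bias', 'blocks.0.norm.weight', 'head.weight']
```

With a real model, pass `[(n, tuple(p.shape)) for n, p in model.named_parameters()]` and then collect the tensors whose names land in each list. Substring matching is crude (a name such as `multihead` would also match `"head"`), so adjust the keywords to your module names.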
Features
torch.compile fixes
Upstream unit_scaling has a collection of issues that break torch.compile — isinstance checks on fx.proxy.Proxy, ints where dynamo expects floats, and other tracing-unfriendly patterns. This package fixes all of them. Core primitives (scale_fwd, scale_bwd) are rewritten using PyTorch 2.x's setup_context pattern, and residual ops and activations are reimplemented on top of them.
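For readers unfamiliar with the `setup_context` pattern, here is a minimal sketch of what forward-only and backward-only scaling primitives can look like when written that way. It is an illustration under that assumption, not dd-unit-scaling's source: the class names `ScaleFwd` / `ScaleBwd` are hypothetical, and only the function names `scale_fwd` / `scale_bwd` come from the text above. With `setup_context` defined, `forward` takes no `ctx` argument, and anything `backward` needs is stashed in `setup_context`.

```python
import torch


class ScaleFwd(torch.autograd.Function):
    """Multiply by `scale` in the forward pass; pass gradients through unchanged."""

    @staticmethod
    def forward(x: torch.Tensor, scale: float) -> torch.Tensor:
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        pass  # nothing to save: backward is deliberately the identity

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        return grad, None  # no gradient for the Python-float scale


class ScaleBwd(torch.autograd.Function):
    """Identity in the forward pass; multiply gradients by `scale` in backward."""

    @staticmethod
    def forward(x: torch.Tensor, scale: float) -> torch.Tensor:
        return x.view_as(x)

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        _, scale = inputs
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        return grad * ctx.scale, None


def scale_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    return ScaleFwd.apply(x, scale)


def scale_bwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    return ScaleBwd.apply(x, scale)


x = torch.ones(3, requires_grad=True)
y = scale_bwd(scale_fwd(x, 2.0), 0.5)  # forward scaled by 2.0, gradient scaled by 0.5
y.sum().backward()
print(y.detach(), x.grad)
```

Keeping the scale a plain Python float (rather than a tensor) means it needs no gradient and can be treated as a constant when traced.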
World-size-aware scaling
In DDP/FSDP, gradients are averaged across workers, so batch-dependent scale factors need to account for the full effective batch: local_batch × world_size × grad_accumulation_steps. Call init_world_size_cache() before torch.compile to cache the world size as a plain int (avoiding graph breaks from process group calls).
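The caching idea can be sketched in plain Python: store the values once as module-level ints, then read them during the forward pass instead of querying the process group. The functions below are hypothetical stand-ins (the package's real entry points are `uu.init_world_size_cache` and `uu.set_grad_accumulation_steps`, whose internals may differ); `effective_batch` just applies the formula above.

```python
# Hypothetical module-level cache: plain ints that a compiler can treat as
# constants, instead of calling into torch.distributed while tracing.
_WORLD_SIZE: int = 1
_GRAD_ACCUM_STEPS: int = 1


def init_world_size_cache_sketch(world_size: int) -> None:
    global _WORLD_SIZE
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    _WORLD_SIZE = int(world_size)


def set_grad_accumulation_steps_sketch(steps: int) -> None:
    global _GRAD_ACCUM_STEPS
    if steps < 1:
        raise ValueError("steps must be >= 1")
    _GRAD_ACCUM_STEPS = int(steps)


def effective_batch(local_batch: int) -> int:
    """local_batch × world_size × grad_accumulation_steps."""
    return local_batch * _WORLD_SIZE * _GRAD_ACCUM_STEPS


init_world_size_cache_sketch(8)
set_grad_accumulation_steps_sketch(4)
print(effective_batch(16))  # 16 * 8 * 4 = 512
```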
u-μP optimizers with FSDP2 support
Provides AdamW, Muon, Dion2, and NorMuon wrappers that apply per-parameter u-μP LR scaling. Metadata (fan-in, fan-out, mup_type) is cached by parameter name before FSDP wrapping, since FSDP2 replaces parameter tensors with DTensors. Includes Polar Express orthogonalization as an alternative to Newton-Schulz for NorMuon (also used in Karpathy's nanochat).
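A name-keyed metadata cache can be sketched as follows: record fan-in and fan-out from the unsharded shapes before wrapping, so lookups by parameter name keep working after FSDP2 has sharded the tensors. This is a hypothetical illustration, not the code behind `uu.cache_fan_values`. It assumes PyTorch's `nn.Linear` layout (weight shape `(out_features, in_features)`), uses a made-up convention of `fan_in = 1` for 1-D parameters, and takes `mup_type` labels such as `"weight"` or `"bias"` from the caller.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FanInfo:
    fan_in: int
    fan_out: int
    mup_type: str  # e.g. "weight", "bias", "norm", "output"


def build_fan_cache(
    named_shapes: list[tuple[str, tuple[int, ...]]],
    mup_types: dict[str, str],
) -> dict[str, FanInfo]:
    """Record fan values from the *unsharded* shapes, keyed by parameter name."""
    cache = {}
    for name, shape in named_shapes:
        if len(shape) == 2:    # nn.Linear weight: (out_features, in_features)
            fan_out, fan_in = shape
        elif len(shape) == 1:  # bias / norm gain (hypothetical convention)
            fan_out, fan_in = shape[0], 1
        else:
            raise ValueError(f"{name}: only 1-D and 2-D parameters handled here")
        cache[name] = FanInfo(fan_in, fan_out, mup_types[name])
    return cache


cache = build_fan_cache(
    [("mlp.up.weight", (1024, 256)), ("mlp.up.bias", (1024,))],
    {"mlp.up.weight": "weight", "mlp.up.bias": "bias"},
)
# After sharding, a rank may hold only a dim-0 slice of mlp.up.weight,
# but the optimizer can still look up the original values by name.
print(cache["mlp.up.weight"])  # FanInfo(fan_in=256, fan_out=1024, mup_type='weight')
```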
Distributed training setup
When training with DDP/FSDP, three caching calls must happen in a specific order:
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # PyTorch >= 2.6 (earlier: torch.distributed._composable.fsdp)

import dd_unit_scaling as uu

# MyModel, accum_steps, bias_params, weight_params are placeholders for your own code.
# 1. Build model (param shapes are still intact here)
model = MyModel()
# 2. Cache μP metadata BEFORE FSDP wrapping
uu.cache_fan_values(model.named_parameters())
# 3. Apply FSDP wrapping (FSDP2 replaces param tensors with sharded DTensors)
fully_shard(model)
# 4. Set world size and grad accumulation BEFORE torch.compile
#    (these become plain ints that dynamo can trace without graph breaks)
uu.init_world_size_cache(world_size=dist.get_world_size())  # pure DP/FSDP; with TP/PP pass dp * cp
uu.set_grad_accumulation_steps(accum_steps)
# 5. Compile
model = torch.compile(model)
# 6. Create optimizers (they read the cached metadata)
opt = uu.AdamW(bias_params, lr=1e-3)
opt_muon = uu.Dion2(weight_params, lr=0.02)
```
cache_fan_values must run before FSDP wrapping because FSDP replaces parameter tensors with sharded DTensors, so each rank no longer holds the full original parameters. The cache stores fan-in, fan-out, and mup_type by parameter name so the optimizers can look them up later. init_world_size_cache and set_grad_accumulation_steps must run before torch.compile so their values are baked in as plain ints; calling them after compile causes graph breaks.
Setting world_size by parallelism strategy
The goal is for batch-dependent scale factors, computed from each GPU's local input.numel(), to reflect the global batch, as if all data were on a single device. The effective global batch is local_batch × world_size × grad_accumulation_steps.
world_size should be the product of all parallelism dimensions where ranks process different data:
| Dimension | Different data? | Counts toward world_size? |
|---|---|---|
| DP (data parallel, incl. FSDP) | Yes — each rank sees a different batch | Yes |
| CP (context parallel) | Yes — each rank sees different sequence chunks | Yes |
| TP (tensor parallel) | No — ranks split the same input across heads/features | No |
| SP (sequence parallel) | No — activation sharding within TP groups | No |
| PP (pipeline parallel) | No — ranks process different stages of the same micro-batch | No |
| EP (expert parallel) | No — same batch, different experts | No |
world_size = dp × cp (where dp is the total data-parallel degree, whether DDP, FSDP, or HSDP).
```python
# Examples
uu.init_world_size_cache(dp)       # DDP or FSDP only
uu.init_world_size_cache(dp * cp)  # DP + CP
uu.init_world_size_cache(dp * cp)  # DP + CP + TP + PP (TP/PP don't count)
# Single GPU: no call needed (defaults to 1)
```
If you use gradient accumulation, also set:
```python
uu.set_grad_accumulation_steps(accum_steps)  # defaults to 1
```
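The rule in the table above reduces to a product over the dimensions where ranks see different data. The helper below is a hypothetical convenience for computing the value to pass to `uu.init_world_size_cache`; the dimension names and the dict-of-degrees input are assumptions of this sketch.

```python
from math import prod

# True where ranks along that dimension process different data (per the table above).
DATA_SPLITTING_DIMS = {"dp": True, "cp": True, "tp": False, "sp": False, "pp": False, "ep": False}


def uscale_world_size(degrees: dict[str, int]) -> int:
    """Product of the parallel degrees whose ranks see different data (dp × cp)."""
    unknown = set(degrees) - set(DATA_SPLITTING_DIMS)
    if unknown:
        raise ValueError(f"unknown parallelism dims: {sorted(unknown)}")
    return prod(d for name, d in degrees.items() if DATA_SPLITTING_DIMS[name])


# 64 GPUs laid out as dp=8, cp=2, tp=4: only dp and cp count, so world_size = 16.
print(uscale_world_size({"dp": 8, "cp": 2, "tp": 4}))
```

An empty dict (single GPU) yields 1, matching the default.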
Loss normalization in distributed training
When computing weighted loss averages in DDP/FSDP, apply a world-size correction to maintain hyperparameter stability across different world sizes. After all-reducing the total weight across ranks, divide it by world_size:
```python
dist.all_reduce(total_weight)  # sum per-rank weights across ranks (default op: SUM)
loss = loss.sum() / (total_weight / world_size)  # same as loss.sum() * world_size / total_weight
```
This compensates for DDP's gradient averaging and ensures world-size invariance — your learning rate and other hyperparameters remain stable regardless of how many GPUs you train on.
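A small numeric check of why the correction works: DDP/FSDP average gradients across ranks, and since gradients are linear in the loss, averaging the per-rank loss scalars is a faithful stand-in. The sketch (hypothetical helpers, made-up per-rank numbers) shows that the corrected per-rank losses average to the single-device weighted mean, while dividing by the raw total weight shrinks the result by 1/world_size.

```python
def per_rank_losses(rank_sums: list[float], rank_weights: list[float]) -> list[float]:
    """Loss each rank backpropagates: local weighted sum / (global weight / world_size)."""
    world_size = len(rank_sums)
    total_weight = sum(rank_weights)  # what the SUM all-reduce of weights produces
    return [s / (total_weight / world_size) for s in rank_sums]


def ddp_averaged(values: list[float]) -> float:
    """Gradient averaging across ranks, applied here to the scalar losses."""
    return sum(values) / len(values)


rank_sums = [3.0, 5.0, 1.0, 7.0]     # per-rank sum of weighted per-token losses
rank_weights = [2.0, 4.0, 1.0, 1.0]  # per-rank total weight (e.g. valid-token counts)

global_mean = sum(rank_sums) / sum(rank_weights)  # single-device weighted mean: 16 / 8
corrected = ddp_averaged(per_rank_losses(rank_sums, rank_weights))
naive = ddp_averaged([s / sum(rank_weights) for s in rank_sums])  # no world-size correction
print(global_mean, corrected, naive)  # 2.0 2.0 0.5
```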
Design note: sequence-length-independent scaling in Toto 2.0
This isn't a feature of dd-unit-scaling itself, but a design choice in Toto 2.0 worth calling out. Toto 2.0 uses unscaled F.scaled_dot_product_attention (PyTorch native SDPA) instead of unit-scaled SDPA, so no scale factors depend on sequence length — making the model compatible with KV-cache inference. The resulting attn/MLP variance imbalance is compensated via residual_attn_ratio = sqrt(S / log(S)) (where S = context_length / patch_size), which adjusts the residual tau values so that attention branches get proportionally more weight. residual_mult is set to 0.75 (the unit_scaling default is 1.0).
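As a worked example of the ratio itself, the sketch below evaluates sqrt(S / log(S)) for a few hypothetical context lengths with a hypothetical patch size of 16. It assumes the natural logarithm (the base is not stated above) and stops at the ratio, since how Toto 2.0 maps it onto the residual tau values is not described here. For S > e the ratio grows with S, so longer contexts give attention branches more residual weight.

```python
import math


def residual_attn_ratio(context_length: int, patch_size: int) -> float:
    """sqrt(S / log S) with S = context_length / patch_size (natural log assumed)."""
    s = context_length / patch_size
    if s <= 1:
        raise ValueError("need S > 1 so that log(S) > 0")
    return math.sqrt(s / math.log(s))


for ctx_len in (2048, 4096, 8192):  # hypothetical context lengths, patch_size=16
    print(ctx_len, round(residual_attn_ratio(ctx_len, patch_size=16), 3))
```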
Requirements
- Python >= 3.12
- PyTorch >= 2.4.0
- unit-scaling >= 0.2.0
- dion (optional, for Muon/Dion2/NorMuon)
Project details
Download files
File details
Details for the file dd_unit_scaling-0.1.0.tar.gz.
File metadata
- Download URL: dd_unit_scaling-0.1.0.tar.gz
- Size: 15.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 126b6f740f10636aaabb4d237b2d868ac3a131f62d3ddecb589347d2fa6b1970 |
| MD5 | a09f32acd1ce68aeda1989e900bcb8a7 |
| BLAKE2b-256 | af8cb269c756abf79a37458a87aba450b2d8779ac8c3d00495481e3a812a392f |
Provenance
The following attestation bundles were made for dd_unit_scaling-0.1.0.tar.gz:
- Publisher: pypi-publish-dd-unit-scaling.yml on DataDog/toto
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: dd_unit_scaling-0.1.0.tar.gz
- Subject digest: 126b6f740f10636aaabb4d237b2d868ac3a131f62d3ddecb589347d2fa6b1970
- Sigstore transparency entry: 1717613742
- Permalink: DataDog/toto@44ea4e88852228039564aa3e76fac26aafac0803
- Branch / Tag: refs/tags/dd-unit-scaling/v0.1.0
- Owner: https://github.com/DataDog
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: pypi-publish-dd-unit-scaling.yml@44ea4e88852228039564aa3e76fac26aafac0803
- Trigger Event: release
File details
Details for the file dd_unit_scaling-0.1.0-py3-none-any.whl.
File metadata
- Download URL: dd_unit_scaling-0.1.0-py3-none-any.whl
- Size: 14.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 701b1fa8c5602e51e3668fc45d4e27b50b34aebf8189a772eede463d3cc1bc3a |
| MD5 | 3c8c0af8e89375231080f90ecfe54f3f |
| BLAKE2b-256 | 66f20f19c1ffcf9fc2e192ed0b2186a652956705fb60b026dfab6ea23f5b9ea2 |
Provenance
The following attestation bundles were made for dd_unit_scaling-0.1.0-py3-none-any.whl:
- Publisher: pypi-publish-dd-unit-scaling.yml on DataDog/toto
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: dd_unit_scaling-0.1.0-py3-none-any.whl
- Subject digest: 701b1fa8c5602e51e3668fc45d4e27b50b34aebf8189a772eede463d3cc1bc3a
- Sigstore transparency entry: 1717614114
- Permalink: DataDog/toto@44ea4e88852228039564aa3e76fac26aafac0803
- Branch / Tag: refs/tags/dd-unit-scaling/v0.1.0
- Owner: https://github.com/DataDog
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: pypi-publish-dd-unit-scaling.yml@44ea4e88852228039564aa3e76fac26aafac0803
- Trigger Event: release