Skip to main content

Compile-friendly, world-size-aware unit scaling

Project description

dd-unit-scaling

A production-ready, thin wrapper around graphcore-research/unit-scaling that makes u-μP work reliably with torch.compile, FSDP2, and distributed training at scale. Used to train Toto 2.0.

Background

Why u-μP?

Toto 2.0 uses u-μP (Unit-Scaled Maximal Update Parameterization) rather than standard μP. Standard μP requires running a base model to compute reference scales, then transferring those scales to the target model size. u-μP eliminates this — scaling factors are derived directly from layer fan-in/fan-out, so there is no base model metadata to manage. u-μP has also been shown to converge better on decoder-only transformers.

Why a thin wrapper?

The upstream unit_scaling library implements u-μP correctly in principle, but has accumulated small issues that prevent it from working with torch.compile and distributed training at scale. Rather than fork the entire library, dd-unit-scaling re-exports everything from upstream and overrides only the broken pieces. This keeps the surface area minimal and makes it easy to drop if upstream fixes land.

Installation

pip install dd-unit-scaling

For Muon-family optimizers, also install dion:

pip install git+https://github.com/microsoft/dion.git

Usage

For a conceptual guide on applying u-μP to your model, see the unit-scaling demo notebook.

Drop-in replacement for unit_scaling:

import dd_unit_scaling as uu
from dd_unit_scaling import functional as U

out = U.linear(x, weight, bias)
normed = U.rms_norm(x, normalized_shape=(dim,), weight=w)

Optimizers:

# AdamW for bias/norm/output params
opt_adam = uu.AdamW(bias_params, lr=1e-3)

# Muon-family for weight params
opt_muon = uu.Muon(weight_params, lr=0.02)
opt_dion2 = uu.Dion2(weight_params, lr=0.02)
opt_normuon = uu.NorMuon(weight_params, lr=0.02, use_polar_express=True)

Features

torch.compile fixes

Upstream unit_scaling has a collection of issues that break torch.compileisinstance checks on fx.proxy.Proxy, ints where dynamo expects floats, and other tracing-unfriendly patterns. This package fixes all of them. Core primitives (scale_fwd, scale_bwd) are rewritten using PyTorch 2.x's setup_context pattern, and residual ops and activations are reimplemented on top of them.

World-size-aware scaling

In DDP/FSDP, gradients are averaged across workers, so batch-dependent scale factors need to account for the full effective batch: local_batch × world_size × grad_accumulation_steps. Call init_world_size_cache() before torch.compile to cache the world size as a plain int (avoiding graph breaks from process group calls).

u-μP optimizers with FSDP2 support

Provides AdamW, Muon, Dion2, and NorMuon wrappers that apply per-parameter u-μP LR scaling. Metadata (fan-in, fan-out, mup_type) is cached by parameter name before FSDP wrapping, since FSDP2 replaces parameter tensors with DTensors. Includes Polar Express orthogonalization as an alternative to Newton-Schulz for NorMuon (also used in Karpathy's nanochat).

Distributed training setup

When training with DDP/FSDP, three caching calls must happen in a specific order:

import dd_unit_scaling as uu

# 1. Build model (param shapes are still intact here)
model = MyModel()

# 2. Cache μP metadata BEFORE FSDP wrapping
uu.cache_fan_values(model.named_parameters())

# 3. Apply FSDP/DDP wrapping (replaces param tensors with DTensors)
fully_shard(model)

# 4. Set world size and grad accumulation BEFORE torch.compile
#    (these become plain ints that dynamo can trace without graph breaks)
uu.init_world_size_cache(world_size=dist.get_world_size())
uu.set_grad_accumulation_steps(accum_steps)

# 5. Compile
model = torch.compile(model)

# 6. Create optimizers (they read the cached metadata)
opt = uu.AdamW(bias_params, lr=1e-3)
opt_muon = uu.Dion2(weight_params, lr=0.02)
  • cache_fan_values must happen before FSDP wrapping because FSDP replaces parameter tensors with DTensors, changing their shapes. The cache stores fan-in/fan-out/mup_type by parameter name so the optimizers can look them up later.
  • init_world_size_cache and set_grad_accumulation_steps must happen before torch.compile so the values are baked in as plain ints. If called after compile, they cause graph breaks.

Setting world_size by parallelism strategy

The goal is to make each GPU's local input.numel() reflect the global batch — as if all data were on a single device. The effective global batch is local_batch × world_size × grad_accumulation_steps.

world_size should be the product of all parallelism dimensions where ranks process different data:

Dimension Different data? Counts toward world_size?
DP (data parallel, incl. FSDP) Yes — each rank sees a different batch Yes
CP (context parallel) Yes — each rank sees different sequence chunks Yes
TP (tensor parallel) No — ranks split the same input across heads/features No
SP (sequence parallel) No — activation sharding within TP groups No
PP (pipeline parallel) No — ranks process different stages of the same micro-batch No
EP (expert parallel) No — same batch, different experts No

world_size = dp × cp (where dp is the total data-parallel degree, whether DDP, FSDP, or HSDP).

# Examples
uu.init_world_size_cache(dp)                # DDP or FSDP only
uu.init_world_size_cache(dp * cp)           # DP + CP
uu.init_world_size_cache(dp * cp)           # DP + CP + TP + PP (TP/PP don't count)
# Single GPU — no call needed (defaults to 1)

If you use gradient accumulation, also set:

uu.set_grad_accumulation_steps(accum_steps)  # defaults to 1

Loss normalization in distributed training

When computing weighted loss averages in DDP/FSDP, apply a world-size correction to maintain hyperparameter stability across different world sizes. After all-reducing the total weight across ranks, divide it by world_size:

loss = loss.sum() / (total_weight / world_size)  # equivalent to: loss * world_size / total_weight

This compensates for DDP's gradient averaging and ensures world-size invariance — your learning rate and other hyperparameters remain stable regardless of how many GPUs you train on.

Design note: sequence-length-independent scaling in Toto 2.0

This isn't a feature of dd-unit-scaling itself, but a design choice in Toto 2.0 worth calling out. Toto 2.0 uses unscaled F.scaled_dot_product_attention (PyTorch native SDPA) instead of unit-scaled SDPA, so no scale factors depend on sequence length — making the model compatible with KV-cache inference. The resulting attn/MLP variance imbalance is compensated via residual_attn_ratio = sqrt(S / log(S)) (where S = context_length / patch_size), which adjusts the residual tau values so that attention branches get proportionally more weight. residual_mult is set to 0.75 (the unit_scaling default is 1.0).

Requirements

  • Python >= 3.12
  • PyTorch >= 2.4.0
  • unit-scaling >= 0.2.0
  • dion (optional, for Muon/Dion2/NorMuon)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

dd_unit_scaling-0.1.0.tar.gz (15.1 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

dd_unit_scaling-0.1.0-py3-none-any.whl (14.9 kB view details)

Uploaded Python 3

File details

Details for the file dd_unit_scaling-0.1.0.tar.gz.

File metadata

  • Download URL: dd_unit_scaling-0.1.0.tar.gz
  • Upload date:
  • Size: 15.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for dd_unit_scaling-0.1.0.tar.gz
Algorithm Hash digest
SHA256 126b6f740f10636aaabb4d237b2d868ac3a131f62d3ddecb589347d2fa6b1970
MD5 a09f32acd1ce68aeda1989e900bcb8a7
BLAKE2b-256 af8cb269c756abf79a37458a87aba450b2d8779ac8c3d00495481e3a812a392f

See more details on using hashes here.

Provenance

The following attestation bundles were made for dd_unit_scaling-0.1.0.tar.gz:

Publisher: pypi-publish-dd-unit-scaling.yml on DataDog/toto

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file dd_unit_scaling-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for dd_unit_scaling-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 701b1fa8c5602e51e3668fc45d4e27b50b34aebf8189a772eede463d3cc1bc3a
MD5 3c8c0af8e89375231080f90ecfe54f3f
BLAKE2b-256 66f20f19c1ffcf9fc2e192ed0b2186a652956705fb60b026dfab6ea23f5b9ea2

See more details on using hashes here.

Provenance

The following attestation bundles were made for dd_unit_scaling-0.1.0-py3-none-any.whl:

Publisher: pypi-publish-dd-unit-scaling.yml on DataDog/toto

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page