Compile-friendly, world-size-aware unit scaling

Project description

dd-unit-scaling

A production-ready, thin wrapper around graphcore-research/unit-scaling that makes u-μP work reliably with torch.compile, FSDP2, and distributed training at scale. Used to train Toto 2.0.

Background

Why u-μP?

Toto 2.0 uses u-μP (Unit-Scaled Maximal Update Parameterization) rather than standard μP. Standard μP requires running a base model to compute reference scales, then transferring those scales to the target model size. u-μP eliminates this step: scaling factors are derived directly from each layer's fan-in/fan-out, so there is no base-model metadata to manage. u-μP has also been reported to converge better than standard μP on decoder-only transformers.
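
As a minimal illustration of a scale that depends only on fan-in, the sketch below applies the basic unit-scaling rule to a linear layer's forward pass: divide by sqrt(fan_in) so that unit-variance inputs and unit-variance weights produce roughly unit-variance outputs, with no base model involved. The function name `unit_scaled_linear_fwd` is hypothetical, backward-pass scales are omitted, and this is not the implementation in unit_scaling or dd-unit-scaling.

```python
import math

import torch


def unit_scaled_linear_fwd(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical forward-only unit-scaled linear: the scale depends only on fan-in."""
    fan_in = weight.shape[1]  # weight layout is (fan_out, fan_in)
    return (x @ weight.T) / math.sqrt(fan_in)


torch.manual_seed(0)
x = torch.randn(2048, 1024)  # unit-variance activations
w = torch.randn(512, 1024)   # unit-variance weights, no init-time scaling
out = unit_scaled_linear_fwd(x, w)
print(f"output std: {out.std().item():.3f}")  # close to 1.0
```

Without the 1/sqrt(fan_in) factor, the output standard deviation here would be about sqrt(1024) = 32.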

Why a thin wrapper?

The upstream unit_scaling library implements u-μP correctly in principle, but has accumulated small issues that prevent it from working with torch.compile and distributed training at scale. Rather than fork the entire library, dd-unit-scaling re-exports everything from upstream and overrides only the broken pieces. This keeps the surface area minimal and makes it easy to drop if upstream fixes land.

Installation

pip install dd-unit-scaling

For Muon-family optimizers, also install dion:

pip install git+https://github.com/microsoft/dion.git

Usage

For a conceptual guide on applying u-μP to your model, see the unit-scaling demo notebook.

Drop-in replacement for unit_scaling:

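import torch
import dd_unit_scaling as uu
from dd_unit_scaling import functional as U

dim = 64
x, weight, bias = torch.randn(8, dim), torch.randn(dim, dim), torch.zeros(dim)  # example inputs
w = torch.ones(dim)  # norm gain

out = U.linear(x, weight, bias)
normed = U.rms_norm(x, normalized_shape=(dim,), weight=w)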

Optimizers:

# AdamW for bias/norm/output params
opt_adam = uu.AdamW(bias_params, lr=1e-3)

# Muon-family for weight params
opt_muon = uu.Muon(weight_params, lr=0.02)
opt_dion2 = uu.Dion2(weight_params, lr=0.02)
opt_normuon = uu.NorMuon(weight_params, lr=0.02, use_polar_express=True)
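
The optimizer example assumes the parameters have already been split into weight_params and bias_params. One possible split, following the comments above, sends 2-D hidden weight matrices to the Muon-family optimizer and biases, norm gains, and output-layer parameters to AdamW. This is a hedged sketch: the `split_params` helper, its `output_prefixes` argument, and the "head" layer name are hypothetical, and real models may need extra rules (for example for embeddings).

```python
from torch import nn


def split_params(model: nn.Module, output_prefixes: tuple[str, ...] = ("head",)):
    """Hypothetical helper returning (weight_params, bias_params)."""
    weight_params, bias_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        is_output = name.startswith(output_prefixes)
        # 2-D hidden weight matrices -> Muon family; vectors and output params -> AdamW.
        if p.ndim >= 2 and not is_output:
            weight_params.append(p)
        else:
            bias_params.append(p)
    return weight_params, bias_params


class TinyModel(nn.Module):
    def __init__(self, dim: int = 32, out_dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, out_dim)  # output layer, routed to AdamW

    def forward(self, x):
        return self.head(self.norm(self.proj(x)))


model = TinyModel()
weight_params, bias_params = split_params(model)
print(len(weight_params), len(bias_params))  # 1 5
```

The two lists can then be passed to uu.Muon (or Dion2/NorMuon) and uu.AdamW as in the example above.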

Features

torch.compile fixes

Upstream unit_scaling has a collection of issues that break torch.compile: isinstance checks on fx.proxy.Proxy, ints passed where dynamo expects floats, and other tracing-unfriendly patterns. This package fixes them. The core primitives (scale_fwd, scale_bwd) are rewritten using the setup_context pattern for torch.autograd.Function introduced in PyTorch 2.x, and residual ops and activations are reimplemented on top of them.
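
To show what the setup_context pattern looks like for this kind of primitive, here is a minimal sketch: scale_fwd multiplies in the forward pass only, and scale_bwd multiplies the gradient in the backward pass only. The class names and the use of x.clone() are assumptions; this illustrates the pattern and is not the package's source.

```python
import torch


class _ScaleFwd(torch.autograd.Function):
    """Multiply by `scale` in forward; pass the gradient through unscaled."""

    @staticmethod
    def forward(x: torch.Tensor, scale: float) -> torch.Tensor:
        # With setup_context defined, forward() takes no ctx argument.
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # identity backward needs no saved state

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None  # no gradient for the float `scale`


class _ScaleBwd(torch.autograd.Function):
    """Identity in forward; multiply the gradient by `scale` in backward."""

    @staticmethod
    def forward(x: torch.Tensor, scale: float) -> torch.Tensor:
        return x.clone()

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, scale = inputs
        ctx.scale = float(scale)  # keep it a plain Python float

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None


def scale_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    return _ScaleFwd.apply(x, scale)


def scale_bwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    return _ScaleBwd.apply(x, scale)
```

Separating the context setup from forward() is what lets torch.func transforms and torch.compile trace forward() as a pure function of its inputs.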

World-size-aware scaling

In DDP/FSDP, gradients are averaged across workers, so batch-dependent scale factors need to account for the full effective batch: local_batch × world_size × grad_accumulation_steps. Call init_world_size_cache() before torch.compile to cache the world size as a plain int (avoiding graph breaks from process group calls).
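
The arithmetic on one example: with a local batch of 8, 4 workers, and 2 accumulation steps, one optimizer step averages over 64 samples. The 1/sqrt(batch) factor below is an illustrative assumption modeled on the weight-gradient scale used in unit scaling; the exact factors dd-unit-scaling applies may differ, and both helper names are hypothetical.

```python
import math


def effective_batch(local_batch: int, world_size: int = 1, grad_accumulation_steps: int = 1) -> int:
    """Global batch that a single optimizer step averages over."""
    return local_batch * world_size * grad_accumulation_steps


def inv_sqrt_batch_scale(batch: int) -> float:
    """Hypothetical batch-dependent factor of the 1/sqrt(batch) form."""
    return 1.0 / math.sqrt(batch)


local = 8
global_b = effective_batch(local, world_size=4, grad_accumulation_steps=2)
print(global_b)                        # 64
print(inv_sqrt_batch_scale(local))     # about 0.354 (wrong: local batch only)
print(inv_sqrt_batch_scale(global_b))  # 0.125
```

Computing the factor from the local batch alone would be off by sqrt(world_size × grad_accumulation_steps), here sqrt(8).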

u-μP optimizers with FSDP2 support

Provides AdamW, Muon, Dion2, and NorMuon wrappers that apply per-parameter u-μP LR scaling. Metadata (fan-in, fan-out, mup_type) is cached by parameter name before FSDP wrapping, since FSDP2 replaces parameter tensors with DTensors. Includes Polar Express orthogonalization as an alternative to Newton-Schulz for NorMuon (also used in Karpathy's nanochat).
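
Both Polar Express and Newton-Schulz approximate the orthogonal polar factor of a gradient matrix. As a reference point, here is a sketch of the quintic Newton-Schulz iteration with the coefficients published for Muon (a, b, c = 3.4445, -4.7750, 2.0315). It is not Polar Express, which uses different polynomial coefficients at each iteration, and it is not dion's or this package's implementation; the function name is hypothetical.

```python
import torch


def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximate U @ V.T for a 2-D matrix g = U @ diag(S) @ V.T."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic coefficients used by Muon
    x = g.float()
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T  # work in the wide orientation so x @ x.T is the smaller Gram matrix
    x = x / (x.norm() + eps)  # Frobenius norm >= spectral norm, so singular values <= 1
    for _ in range(steps):
        gram = x @ x.T
        poly = b * gram + c * gram @ gram
        x = a * x + poly @ x  # maps each singular value s -> a*s + b*s^3 + c*s^5
    return x.T if transposed else x
```

These coefficients trade exactness for speed: singular values are pushed into a band around 1 (roughly 0.7 to 1.2) rather than converging to exactly 1, which is the kind of trade-off Polar Express aims to improve.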

Distributed training setup

When training with DDP/FSDP, three caching calls must happen in a specific order:

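import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # PyTorch >= 2.6; in 2.4/2.5: torch.distributed._composable.fsdp

import dd_unit_scaling as uu

# Assumes dist.init_process_group() has already been called.
accum_steps = 4  # gradient accumulation steps (example value)

# 1. Build model (param shapes are still intact here)
model = MyModel()

# 2. Cache μP metadata BEFORE FSDP wrapping
uu.cache_fan_values(model.named_parameters())

# 3. Apply FSDP2 (or DDP) wrapping; FSDP2 replaces param tensors with DTensors
fully_shard(model)

# 4. Set world size and grad accumulation BEFORE torch.compile
#    (these become plain ints that dynamo can trace without graph breaks)
#    dist.get_world_size() is only correct for pure data parallelism; see below for CP/TP/PP
uu.init_world_size_cache(world_size=dist.get_world_size())
uu.set_grad_accumulation_steps(accum_steps)

# 5. Compile
model = torch.compile(model)

# 6. Create optimizers (they read the cached metadata);
#    bias_params / weight_params are your own split of the model's parameters
opt = uu.AdamW(bias_params, lr=1e-3)
opt_muon = uu.Dion2(weight_params, lr=0.02)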
  • cache_fan_values must happen before FSDP wrapping because FSDP replaces parameter tensors with DTensors whose local shards no longer have the original shapes. The cache stores fan-in/fan-out/mup_type by parameter name so the optimizers can look them up later.
  • init_world_size_cache and set_grad_accumulation_steps must happen before torch.compile so the values are baked in as plain ints. If called after compile, they cause graph breaks.
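
The name-keyed cache can be sketched as follows: shapes are recorded while the parameters are still plain tensors, and later lookups go through the parameter name instead of the tensor's current shape. Everything here (FanInfo, cache_fans_by_name, the convention used for 1-D parameters, and the mup_type labels) is a hypothetical illustration, not the package's cache_fan_values.

```python
from dataclasses import dataclass

from torch import nn


@dataclass(frozen=True)
class FanInfo:
    fan_in: int
    fan_out: int
    mup_type: str  # hypothetical labels: "matrix" or "vector"


def cache_fans_by_name(named_params) -> dict[str, FanInfo]:
    """Hypothetical stand-in for cache_fan_values: call before any wrapping."""
    cache = {}
    for name, p in named_params:
        if p.ndim == 2:
            # nn.Linear weight layout is (out_features, in_features)
            cache[name] = FanInfo(fan_in=p.shape[1], fan_out=p.shape[0], mup_type="matrix")
        else:
            # Hypothetical convention for biases and norm gains
            cache[name] = FanInfo(fan_in=1, fan_out=p.numel(), mup_type="vector")
    return cache


model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 4))
fan_cache = cache_fans_by_name(model.named_parameters())
for name, info in fan_cache.items():
    print(name, info)
```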

Setting world_size by parallelism strategy

The goal is to make each GPU's local input.numel() reflect the global batch — as if all data were on a single device. The effective global batch is local_batch × world_size × grad_accumulation_steps.

world_size should be the product of all parallelism dimensions where ranks process different data:

Dimension                        Different data?                                              Counts toward world_size?
DP (data parallel, incl. FSDP)   Yes: each rank sees a different batch                        Yes
CP (context parallel)            Yes: each rank sees different sequence chunks                Yes
TP (tensor parallel)             No: ranks split the same input across heads/features         No
SP (sequence parallel)           No: activation sharding within TP groups                     No
PP (pipeline parallel)           No: ranks process different stages of the same micro-batch   No
EP (expert parallel)             No: same batch, different experts                            No

world_size = dp × cp (where dp is the total data-parallel degree, whether DDP, FSDP, or HSDP).

# Examples
uu.init_world_size_cache(dp)                # DDP or FSDP only
uu.init_world_size_cache(dp * cp)           # DP + CP
uu.init_world_size_cache(dp * cp)           # DP + CP + TP + PP (TP/PP don't count)
# Single GPU — no call needed (defaults to 1)
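
A worked example of the rule on a hypothetical 64-GPU job laid out as a dp × cp × tp × pp mesh. Because TP and PP ranks see the same data, dividing them out of the total rank count leaves dp × cp. The helper name is an assumption, and the sketch assumes there is no separate expert-parallel dimension in the mesh.

```python
def unit_scaling_world_size(total_ranks: int, tp: int = 1, pp: int = 1) -> int:
    """world_size for init_world_size_cache, assuming a dp x cp x tp x pp mesh."""
    model_parallel = tp * pp  # ranks that share the same data
    if total_ranks % model_parallel != 0:
        raise ValueError("total_ranks must be divisible by tp * pp")
    return total_ranks // model_parallel  # equals dp * cp


dp, cp, tp, pp = 4, 2, 4, 2  # hypothetical layout, 64 ranks in total
world_size = unit_scaling_world_size(dp * cp * tp * pp, tp=tp, pp=pp)
print(world_size)  # 8, i.e. dp * cp rather than 64
```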

If you use gradient accumulation, also set:

uu.set_grad_accumulation_steps(accum_steps)  # defaults to 1

Loss normalization in distributed training

When computing weighted loss averages in DDP/FSDP, apply a world-size correction to maintain hyperparameter stability across different world sizes. After all-reducing the total weight across ranks, divide it by world_size:

loss = loss.sum() / (total_weight / world_size)  # equivalent to: loss * world_size / total_weight

This compensates for DDP's gradient averaging and ensures world-size invariance — your learning rate and other hyperparameters remain stable regardless of how many GPUs you train on.
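
The world-size invariance can be checked numerically with plain floats standing in for per-rank losses; since gradients are linear in the loss, averaging the losses mirrors DDP's gradient averaging. The helper names and numbers are hypothetical, and the sketch assumes examples are split evenly across ranks.

```python
def rank_loss(loss_sum: float, total_weight: float, world_size: int) -> float:
    """Per-rank loss following the formula above: loss.sum() / (total_weight / world_size)."""
    return loss_sum / (total_weight / world_size)


def averaged_loss(contribs: list[float], weights: list[float], world_size: int) -> float:
    total_weight = sum(weights)  # what all-reducing the weight gives every rank
    shard = len(contribs) // world_size
    per_rank = [
        rank_loss(sum(contribs[r * shard:(r + 1) * shard]), total_weight, world_size)
        for r in range(world_size)
    ]
    return sum(per_rank) / world_size  # DDP-style averaging across ranks


# Hypothetical weighted per-example losses (loss_i * weight_i) and weights for one global batch.
contribs = [2.0, 1.0, 4.0, 3.0, 0.5, 1.5, 2.5, 0.5]
weights = [1.0, 0.5, 2.0, 1.0, 0.5, 1.0, 1.0, 1.0]

for ws in (1, 2, 4, 8):
    print(ws, averaged_loss(contribs, weights, ws))  # 1.875 for every world size
```

Without the division by world_size, each rank would compute loss_sum / total_weight and the DDP average would shrink by a factor of world_size.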

Design note: sequence-length-independent scaling in Toto 2.0

This isn't a feature of dd-unit-scaling itself, but a design choice in Toto 2.0 worth calling out. Toto 2.0 uses unscaled F.scaled_dot_product_attention (PyTorch native SDPA) instead of unit-scaled SDPA, so no scale factors depend on sequence length — making the model compatible with KV-cache inference. The resulting attn/MLP variance imbalance is compensated via residual_attn_ratio = sqrt(S / log(S)) (where S = context_length / patch_size), which adjusts the residual tau values so that attention branches get proportionally more weight. residual_mult is set to 0.75 (the unit_scaling default is 1.0).
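
For a concrete sense of magnitude, the ratio can be computed for a hypothetical configuration. The text does not state the logarithm's base, so the natural log is an assumption here, and how the ratio is mapped onto the residual tau values is not specified, so the sketch stops at the ratio itself.

```python
import math


def residual_attn_ratio(context_length: int, patch_size: int) -> float:
    """sqrt(S / log(S)) with S = context_length / patch_size (natural log assumed)."""
    s = context_length / patch_size
    return math.sqrt(s / math.log(s))


# Hypothetical configuration: 4096 time steps in patches of 16, so S = 256.
print(residual_attn_ratio(4096, 16))  # roughly 6.8
print(residual_attn_ratio(8192, 16))  # larger: longer contexts give attention more weight
```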

Requirements

  • Python >= 3.12
  • PyTorch >= 2.4.0
  • unit-scaling >= 0.2.0
  • dion (optional, for Muon/Dion2/NorMuon)

Download files

Source Distribution

timecopilot_dd_unit_scaling-0.1.0.tar.gz (15.1 kB)

Built Distribution

timecopilot_dd_unit_scaling-0.1.0-py3-none-any.whl (15.0 kB)

File details

Details for the file timecopilot_dd_unit_scaling-0.1.0.tar.gz.

Hashes for timecopilot_dd_unit_scaling-0.1.0.tar.gz
Algorithm Hash digest
SHA256 249a126ca734ead5a7a7a9c48adf494d6d7e2f481c65efb9f7bd9140051fb44b
MD5 3f5d57229f12cfbd6a6fe4ec2528e61e
BLAKE2b-256 b390465f1d79e76698549679dab143d98730de947f3a8366cdde55d9e6bbeab0

File details

Details for the file timecopilot_dd_unit_scaling-0.1.0-py3-none-any.whl.

Hashes for timecopilot_dd_unit_scaling-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 fd703cec2c07c02bf183e61908b0b8d5386d9bd5c6678bfe49ff162bff9180f5
MD5 fb90aad630f8b407776a7d175837ccd6
BLAKE2b-256 fb5704e5e2ecf3932dfc50008c59e296c240c0f51b6bc13761ac9bb009273d20
