
A simple port of μP to Haiku/JAX.

Project description

MUP for Haiku

This is a (very preliminary) port of Yang and Hu et al.'s μP repo to Haiku and JAX. It's not feature complete, and I'm very open to suggestions on improving the usability.
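
Installation from PyPI (the distribution files below are published under the name haiku-mup):

pip install haiku-mup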

Learning rate demo

These plots show the evolution of the optimal learning rate for a 3-hidden-layer MLP on MNIST, trained for 10 epochs (5 trials per lr/width combination).

With standard parameterization, the learning rate optimum continues changing as the width increases.

With μP, the learning rate optimum stabilizes as width increases.
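
A sweep like the one above could be organized roughly as follows; train_mnist_mlp is a hypothetical helper standing in for a full training loop built with the usage pattern shown below, and is not part of this package.

import numpy as np

widths = [256, 512, 1024, 2048]            # widths to compare
learning_rates = np.logspace(-4, -1, 10)   # learning-rate grid
n_trials = 5                               # trials per (lr, width) pair

results = {}
for width in widths:
    for lr in learning_rates:
        losses = [train_mnist_mlp(width=width, lr=lr, epochs=10, seed=t)  # hypothetical helper
                  for t in range(n_trials)]
        results[(width, lr)] = float(np.mean(losses))
# The optimal learning rate per width is the argmin over the lr axis of results.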

Usage

from functools import partial

import jax
import jax.numpy as jnp
import haiku as hk
import optax

from hk_mup import apply_mup, Mup, Readout

class MyModel(hk.Module):
    def __init__(self, width, n_classes=10):
        super().__init__(name='model')
        self.width = width
        self.n_classes = n_classes

    def __call__(self, x):
        x = hk.Linear(self.width)(x)
        x = jax.nn.relu(x)
        return Readout(self.n_classes)(x) # 1. Replace output layer with Readout layer

def fn(x, width=100):
    with apply_mup(): # 2. Modify parameter creation with apply_mup()
        return MyModel(width)(x)

mup = Mup()

init_input = jnp.zeros(123)  # dummy input, used only for initialization
base_model = hk.transform(partial(fn, width=1))

with mup.init_base(): # 3. Use this context manager when initializing the base model
    base_model.init(jax.random.PRNGKey(0), init_input)

model = hk.transform(fn)

with mup.init_target(): # 4. Use this context manager when initializing the target model
    params = model.init(jax.random.PRNGKey(0), init_input)

model = mup.wrap_model(model) # 5. Modify your model with Mup

optimizer = optax.adam(3e-4)
optimizer = mup.wrap_optimizer(optimizer, adam=True) # 6. Use wrap_optimizer to get layer specific learning rates

# Now the model can be trained as normal
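
As an illustration, a minimal training step with the wrapped model and optimizer might look like the sketch below; the cross-entropy loss, the one-hot labels, and the assumption that the wrapped model keeps Haiku's apply(params, rng, x) signature are placeholders for your own training code, not part of haiku-mup.

opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        logits = model.apply(p, None, x)  # model is deterministic, so rng=None
        labels = jax.nn.one_hot(y, 10)
        return jnp.mean(optax.softmax_cross_entropy(logits, labels))

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss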

Summary

  1. Replace output layers with Readout layers
  2. Modify parameter creation with the apply_mup() context manager
  3. Initialize a base model inside a Mup.init_base() context
  4. Initialize the target model inside a Mup.init_target() context
  5. Wrap the model with Mup.wrap_model
  6. Wrap optimizer with Mup.wrap_optimizer

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

haiku-mup-0.1.0.tar.gz (5.7 kB)

Uploaded Source

Built Distribution

haiku_mup-0.1.0-py3-none-any.whl (6.5 kB)

Uploaded Python 3

File details

Details for the file haiku-mup-0.1.0.tar.gz.

File metadata

  • Download URL: haiku-mup-0.1.0.tar.gz
  • Upload date:
  • Size: 5.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.13 CPython/3.10.4 Linux/5.17.2-arch3-1

File hashes

Hashes for haiku-mup-0.1.0.tar.gz
  • SHA256: 1955d5a6c1a9c5c497653ca23033008dea2c373686daa3f0d5313b09c3ef473b
  • MD5: 2e0282fa2566387b3d22be5bd1a83cfe
  • BLAKE2b-256: 169050807c1c9b6806c97628f034f7be790bfb0b1c2e1171f55869b669e0095f

See more details on using hashes here.

File details

Details for the file haiku_mup-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: haiku_mup-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 6.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.13 CPython/3.10.4 Linux/5.17.2-arch3-1

File hashes

Hashes for haiku_mup-0.1.0-py3-none-any.whl
  • SHA256: b83b5ebc1f2b65d566c63771926cf63d94ab5b55f2164ed9e99e167dd2c71d28
  • MD5: 15113bac1ed7eb74f68f18706f3fcb0f
  • BLAKE2b-256: d0d1ab50050996d3439da3bfb2e92a6c39a269b1c0680224cfe8c6c58e368109

See more details on using hashes here.
