A simple port of μP to Haiku/JAX.
Project description
MUP for Haiku
This is a (very preliminary) port of Yang and Hu et al.'s μP repo to Haiku and JAX. It isn't feature-complete, and I'm very open to suggestions on improving its usability.
Learning rate demo
These plots show how the optimal learning rate evolves for a 3-hidden-layer MLP on MNIST, trained for 10 epochs (5 trials per learning-rate/width combination).
With standard parameterization, the optimal learning rate keeps shifting as the width increases.
With μP, the optimal learning rate stabilizes as the width increases.
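For reference, a sweep like the one behind these plots can be structured along the lines below. This is a minimal sketch, not the script that produced the plots: it uses synthetic data in place of MNIST, a fixed number of Adam steps in place of 10 epochs, standard parameterization only, and the helper name `train_and_eval` is illustrative.

```python
import jax
import jax.numpy as jnp
import haiku as hk
import optax

def train_and_eval(width, lr, seed=0, steps=200):
    """Train a 3-hidden-layer MLP (standard parameterization) and return the final loss."""
    key = jax.random.PRNGKey(seed)
    x_key, y_key, init_key = jax.random.split(key, 3)
    x = jax.random.normal(x_key, (256, 784))      # stand-in for MNIST images
    y = jax.random.randint(y_key, (256,), 0, 10)  # stand-in for MNIST labels

    model = hk.without_apply_rng(hk.transform(
        lambda inp: hk.nets.MLP([width, width, width, 10])(inp)))
    params = model.init(init_key, x)
    opt = optax.adam(lr)
    opt_state = opt.init(params)

    def loss_fn(p):
        logits = model.apply(p, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    @jax.jit
    def step(p, s):
        loss, grads = jax.value_and_grad(loss_fn)(p)
        updates, s = opt.update(grads, s, p)
        return optax.apply_updates(p, updates), s, loss

    loss = jnp.inf
    for _ in range(steps):
        params, opt_state, loss = step(params, opt_state)
    return float(loss)

# For each width, find the learning rate with the lowest final loss
for width in (128, 512, 2048):
    results = {lr: train_and_eval(width, lr) for lr in (1e-4, 3e-4, 1e-3, 3e-3)}
    print(width, min(results, key=results.get))
```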
Usage
```python
from functools import partial

import jax
import jax.numpy as jnp
import haiku as hk
from optax import adam
from haiku_mup import apply_mup, Mup, Readout

class MyModel(hk.Module):
    def __init__(self, width, n_classes=10):
        super().__init__(name='model')
        self.width = width
        self.n_classes = n_classes

    def __call__(self, x):
        x = hk.Linear(self.width)(x)
        x = jax.nn.relu(x)
        return Readout(self.n_classes)(x)  # 1. Replace the output layer with a Readout layer

def fn(x, width=100):
    with apply_mup():  # 2. Modify parameter creation with apply_mup()
        return MyModel(width)(x)

mup = Mup()
init_input = jnp.zeros(123)

base_model = hk.transform(partial(fn, width=1))
with mup.init_base():  # 3. Use this context manager when initializing the base model
    base_model.init(jax.random.PRNGKey(0), init_input)

model = hk.transform(fn)
with mup.init_target():  # 4. Use this context manager when initializing the target model
    params = model.init(jax.random.PRNGKey(0), init_input)
model = mup.wrap_model(model)  # 5. Modify your model with Mup

optimizer = adam(3e-4)
optimizer = mup.wrap_optimizer(optimizer, adam=True)  # 6. Use wrap_optimizer to get layer-specific learning rates

# Now the model can be trained as normal
```
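One training step with the wrapped objects might then look like the sketch below. This is illustrative rather than prescribed by the library: it assumes `wrap_model` preserves the usual transformed `apply(params, rng, x)` signature and that `wrap_optimizer` returns a standard optax `GradientTransformation`; `x` and `y` stand for a batch of inputs and integer labels.

```python
import optax

opt_state = optimizer.init(params)

def loss_fn(p, x, y):
    logits = model.apply(p, None, x)  # no RNG is used in the forward pass
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

@jax.jit
def train_step(p, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(p, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, p)
    return optax.apply_updates(p, updates), opt_state, loss
```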
Summary
- Replace output layers with `Readout` layers
- Modify parameter creation with the `apply_mup()` context manager
- Initialize a base model inside a `Mup.init_base()` context
- Initialize the target model inside a `Mup.init_target()` context
- Wrap the model with `Mup.wrap_model`
- Wrap the optimizer with `Mup.wrap_optimizer`
Project details
Download files
Download the file for your platform.
Source Distribution
haiku-mup-0.1.0.tar.gz (5.7 kB)
Built Distribution
haiku_mup-0.1.0-py3-none-any.whl (6.5 kB)
File details
Details for the file haiku-mup-0.1.0.tar.gz.
File metadata
- Download URL: haiku-mup-0.1.0.tar.gz
- Upload date:
- Size: 5.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.13 CPython/3.10.4 Linux/5.17.2-arch3-1
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1955d5a6c1a9c5c497653ca23033008dea2c373686daa3f0d5313b09c3ef473b
MD5 | 2e0282fa2566387b3d22be5bd1a83cfe
BLAKE2b-256 | 169050807c1c9b6806c97628f034f7be790bfb0b1c2e1171f55869b669e0095f
File details
Details for the file haiku_mup-0.1.0-py3-none-any.whl.
File metadata
- Download URL: haiku_mup-0.1.0-py3-none-any.whl
- Upload date:
- Size: 6.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.13 CPython/3.10.4 Linux/5.17.2-arch3-1
File hashes
Algorithm | Hash digest
---|---
SHA256 | b83b5ebc1f2b65d566c63771926cf63d94ab5b55f2164ed9e99e167dd2c71d28
MD5 | 15113bac1ed7eb74f68f18706f3fcb0f
BLAKE2b-256 | d0d1ab50050996d3439da3bfb2e92a6c39a269b1c0680224cfe8c6c58e368109
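If you want to check a download against the digests above, here is a minimal sketch using Python's `hashlib`, assuming the wheel sits in the current directory:

```python
import hashlib

# SHA256 digest published in the table above
expected = "b83b5ebc1f2b65d566c63771926cf63d94ab5b55f2164ed9e99e167dd2c71d28"
with open("haiku_mup-0.1.0-py3-none-any.whl", "rb") as f:
    actual = hashlib.sha256(f.read()).hexdigest()
assert actual == expected, f"hash mismatch: {actual}"
```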