A stateful optimizer library for JAX
opax
opax is an optimizer library for JAX. It is a reimplementation of optax using PAX's stateful module.
Installation
To install the latest version:
pip3 install git+https://github.com/ntt123/opax.git
Getting started
To create an optimizer:
import opax
optimizer = opax.adam(1e-4).init(parameters)
Note: parameters is a pytree of trainable parameters.
To update parameters:
updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)
Note: gradients has the same treedef as parameters.
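The treedef note can be checked in plain JAX. The sketch below assumes a toy linear model: loss_fn, x, y and the parameter names "w" and "b" are hypothetical. It shows that jax.grad returns a gradient pytree with the same structure as the parameters. The last step stands in for opax.transform_gradients with a plain SGD step, assuming (as optax.apply_updates does) that applying updates adds each update leaf to the matching parameter leaf.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy parameters: a pytree (dict) of arrays.
parameters = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

def loss_fn(params, x, y):
    # Hypothetical linear model with a mean squared-error loss.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((4, 3))
y = jnp.zeros((4,))

# jax.grad differentiates with respect to the first argument and returns a
# pytree with the same structure (treedef) as `parameters`.
gradients = jax.grad(loss_fn)(parameters, x, y)
same_treedef = (
    jax.tree_util.tree_structure(gradients)
    == jax.tree_util.tree_structure(parameters)
)
print(same_treedef)  # True

# Assumption: applying updates means adding each update leaf to the matching
# parameter leaf. These updates are a plain SGD step, not opax's Adam.
updates = jax.tree_util.tree_map(lambda g: -0.1 * g, gradients)
new_parameters = jax.tree_util.tree_map(lambda p, u: p + u, parameters, updates)
```

Because both pytrees share one treedef, tree_map can pair each update with its parameter leaf by leaf.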
The opax.chain function
opax follows optax's philosophy of combining GradientTransformations with the opax.chain function:
optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
).init(parameters)
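To illustrate what chaining means, here is a minimal sketch in plain JAX. It is not opax's implementation: opax's transformations are stateful PAX modules created with .init(parameters), and this sketch reproduces none of that. The hypothetical helpers chain, clip_by_global_norm and scale below are stateless functions that each map an updates pytree to a new one, and chain feeds the output of one transformation into the next, left to right. scale_by_adam is omitted because it carries optimizer state.

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(max_norm):
    # Rescale the whole pytree so its global L2 norm is at most max_norm.
    def transform(updates):
        leaves = jax.tree_util.tree_leaves(updates)
        norm = jnp.sqrt(sum(jnp.sum(jnp.square(u)) for u in leaves))
        factor = jnp.where(norm < max_norm, 1.0, max_norm / norm)
        return jax.tree_util.tree_map(lambda u: u * factor, updates)
    return transform

def scale(step_size):
    # Multiply every update leaf by a constant.
    def transform(updates):
        return jax.tree_util.tree_map(lambda u: u * step_size, updates)
    return transform

def chain(*transforms):
    # Apply the transformations in order, feeding each output to the next.
    def transform(updates):
        for t in transforms:
            updates = t(updates)
        return updates
    return transform

grads = {"a": jnp.array([3.0]), "b": jnp.array([4.0])}  # global norm is 5
optimizer_fn = chain(clip_by_global_norm(1.0), scale(0.1))
updates = optimizer_fn(grads)  # clipped to [0.6], [0.8], then scaled
```

Order matters: scaling first and clipping second would give a pytree of norm 0.5 that the clip leaves untouched, so the result differs whenever the clipping threshold is active.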
Learning rate schedule
opax supports learning rate scheduling with opax.scale_by_schedule.
import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the multiplier every 1000 steps: 2 ** -floor(step / 1000).
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
).init(parameters)
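Evaluating the staircase schedule on a few step counts makes its shape concrete. The sketch below repeats staircase_schedule_fn so it runs on its own; the step values are arbitrary examples.

```python
import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the multiplier every 1000 steps: 2 ** -floor(step / 1000).
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

steps = jnp.array([0, 999, 1000, 2500, 3000])
multipliers = staircase_schedule_fn(steps)
# values: 1.0, 1.0, 0.5, 0.25, 0.125
```

The chain above has no separate opax.scale, so, assuming scale_by_schedule multiplies the updates by the schedule value, these multipliers act as the whole step size applied to the Adam-scaled updates.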
Utilities
To get the current global norm (in the chain above, optimizer[0] is the clip_by_global_norm transformation):
print(optimizer[0].global_norm)
Note: global_norm is a property of the ClipByGlobalNorm class.
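For reference, the usual definition of the global norm and of global-norm clipping (as in optax) is given below. The source states that opax reimplements optax but does not spell out the formula, so treat this as an assumption. Here ℓ indexes the leaves of the gradient pytree, i indexes the elements of each leaf, and c is the max norm passed to clip_by_global_norm (1.0 above).

```latex
\|g\|_{\mathrm{global}} = \sqrt{\sum_{\ell} \sum_{i} g_{\ell,i}^{2}},
\qquad
g_{\ell} \leftarrow g_{\ell} \cdot \min\!\left(1, \frac{c}{\|g\|_{\mathrm{global}}}\right)
```

For example, a gradient pytree with leaves [3.0] and [4.0] has global norm sqrt(9 + 16) = 5, so with c = 1 every leaf is multiplied by 0.2.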
To get the current learning rate (in the chain above, optimizer[-1] is the scale_by_schedule transformation):
print(optimizer[-1].learning_rate)
Note: learning_rate is a property of the ScaleBySchedule class.
Download files
File details
Details for the file opax-0.2.10.tar.gz.
File metadata
- Download URL: opax-0.2.10.tar.gz
- Upload date:
- Size: 7.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 13a0b253e751c02cc675572f1754161e68b0db4e95cc4a6b192da8ce5465e556
MD5 | 3b84bbf0c2ded9ac9034cf3fa7149c9a
BLAKE2b-256 | 017101bbea8dc9cb672c7899ff73092b1359a8f1caab4ea028b652031e431681
File details
Details for the file opax-0.2.10-py3-none-any.whl.
File metadata
- Download URL: opax-0.2.10-py3-none-any.whl
- Upload date:
- Size: 8.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 268162bf1f859864d5ad37903ad306f126d737ff6d75d0ee021127b4b6b5f6c8
MD5 | 12fcb483c30103d24fc3026c108dfe30
BLAKE2b-256 | 169c119d947b0f8d209627983a6ef590720189afbd3c403619b148d186bf5c0c