A stateful optimizer library for JAX
opax
opax is an optimizer library for JAX. It reimplements optax using PAX's stateful modules.
Installation
To install the latest version:
pip3 install git+https://github.com/ntt123/opax.git
Getting started
To create an optimizer:
import opax
optimizer = opax.adam(1e-4).init(parameters)
Note: parameters is a pytree of trainable parameters.
To update parameters:
updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)
Note: gradients must have the same treedef (pytree structure) as parameters.
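The contract between opax.transform_gradients and opax.apply_updates can be illustrated without JAX. The following is a plain-Python approximation of one Adam step, not opax's code. Dicts of floats stand in for pytrees, and adam_init, adam_transform and apply_updates are hypothetical helpers. The hyperparameters (b1=0.9, b2=0.999, eps=1e-8) are the usual Adam defaults and are assumed rather than taken from opax. The sketch also assumes updates are subtracted from parameters, which is consistent with the positive learning rates used in this README.

```python
import math


def adam_init(params):
    """Hypothetical init: zero moments shaped like params, plus a step counter."""
    zeros = {k: 0.0 for k in params}
    return {"count": 0, "mu": dict(zeros), "nu": dict(zeros)}


def adam_transform(grads, state, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    """Return (updates, new_state); updates have the same keys as grads."""
    count = state["count"] + 1
    mu = {k: b1 * state["mu"][k] + (1 - b1) * g for k, g in grads.items()}
    nu = {k: b2 * state["nu"][k] + (1 - b2) * g * g for k, g in grads.items()}
    updates = {}
    for k in grads:
        mu_hat = mu[k] / (1 - b1 ** count)  # bias-corrected first moment
        nu_hat = nu[k] / (1 - b2 ** count)  # bias-corrected second moment
        updates[k] = lr * mu_hat / (math.sqrt(nu_hat) + eps)
    return updates, {"count": count, "mu": mu, "nu": nu}


def apply_updates(params, updates):
    # Assumed sign convention: updates are subtracted (positive lr descends).
    return {k: params[k] - updates[k] for k in params}


parameters = {"w": 1.0, "b": -0.5}
gradients = {"w": 0.2, "b": -0.4}  # same keys ("treedef") as parameters
state = adam_init(parameters)
updates, state = adam_transform(gradients, state)
parameters = apply_updates(parameters, updates)
```

On the first step the bias-corrected ratio is close to the sign of each gradient, so each parameter moves by about the learning rate (1e-4).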
The opax.chain function
opax follows optax's philosophy of combining GradientTransformations with the opax.chain function:
optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
).init(parameters)
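The sequential behaviour of opax.chain can be sketched in plain Python. This is a hypothetical model, not opax's implementation. Each transform is a stateless function that maps an updates dict to a new updates dict, and scale_by_adam is left out because it needs per-parameter state. The real opax transforms are PAX modules that also carry state.

```python
import math


def clip_by_global_norm(max_norm):
    """Hypothetical stateless clip: rescale so the global L2 norm is <= max_norm."""
    def transform(updates):
        norm = math.sqrt(sum(u * u for u in updates.values()))
        factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
        return {k: u * factor for k, u in updates.items()}
    return transform


def scale(step_size):
    """Hypothetical stateless scale: multiply every update by step_size."""
    def transform(updates):
        return {k: u * step_size for k, u in updates.items()}
    return transform


def chain(*transforms):
    """Run the transforms left to right; each consumes the previous one's output."""
    def transform(updates):
        for t in transforms:
            updates = t(updates)
        return updates
    return transform


optimizer = chain(clip_by_global_norm(1.0), scale(1e-4))
updates = optimizer({"w": 3.0, "b": 4.0})  # norm 5.0 is clipped to 1.0, then scaled
```

Order matters: because clipping comes first, it acts on the raw gradients before any later stage rescales them.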
Learning rate schedule
opax supports learning rate scheduling with opax.scale_by_schedule. The example below halves the step size every 1000 steps:
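import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Number of completed 1000-step stages.
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    # Scale factor 2^(-p): 1.0, 0.5, 0.25, ...
    return jnp.exp2(-p)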
optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
).init(parameters)
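To check what staircase_schedule_fn produces, here is a plain-Python mirror of it, with math.floor and 2.0 ** -p standing in for jnp.floor and jnp.exp2. This chain has no separate opax.scale, so the schedule value is the multiplier applied after scale_by_adam. It is 1.0 for steps 0 to 999 and then halves every 1000 steps. The sketch assumes the first update is evaluated at step 0.

```python
import math


def staircase_schedule(step: int) -> float:
    """Plain-Python mirror of staircase_schedule_fn: halve every 1000 steps."""
    p = math.floor(step / 1000)
    return 2.0 ** -p


values = [staircase_schedule(s) for s in (0, 999, 1000, 2500, 5000)]
print(values)  # [1.0, 1.0, 0.5, 0.25, 0.03125]
```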
Utilities
To get the current global norm:
print(optimizer[0].global_norm)
Note: global_norm is a property of the ClipByGlobalNorm class. Here optimizer[0] is the opax.clip_by_global_norm stage, the first element of the chain.
To get the current learning rate:
print(optimizer[-1].learning_rate)
Note: learning_rate is a property of the ScaleBySchedule class. Here optimizer[-1] is the opax.scale_by_schedule stage, the last element of the chain.
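The two utilities can be pictured as stateful stages that record a value on every call. This is a hypothetical plain-Python sketch, not opax's ClipByGlobalNorm or ScaleBySchedule. The chain is modelled as a list so that optimizer[0] and optimizer[-1] index its stages. The schedule is evaluated at the step count before it is incremented, which is how optax's scale_by_schedule behaves; opax is assumed to do the same. The objects also mutate in place for brevity, whereas opax.transform_gradients returns a new optimizer.

```python
import math


class ClipByGlobalNorm:
    """Hypothetical stateful clip: remembers the global norm of the last updates."""

    def __init__(self, max_norm):
        self.max_norm = max_norm
        self.global_norm = 0.0

    def __call__(self, updates):
        self.global_norm = math.sqrt(sum(u * u for u in updates.values()))
        factor = min(1.0, self.max_norm / self.global_norm) if self.global_norm > 0 else 1.0
        return {k: u * factor for k, u in updates.items()}


class ScaleBySchedule:
    """Hypothetical stateful scale: counts steps and remembers the last scale used."""

    def __init__(self, schedule_fn):
        self.schedule_fn = schedule_fn
        self.count = 0
        self.learning_rate = schedule_fn(0)

    def __call__(self, updates):
        self.learning_rate = self.schedule_fn(self.count)  # evaluate, then increment
        self.count += 1
        return {k: u * self.learning_rate for k, u in updates.items()}


def run(optimizer, updates):
    """Feed updates through each stage in order."""
    for stage in optimizer:
        updates = stage(updates)
    return updates


optimizer = [ClipByGlobalNorm(1.0), ScaleBySchedule(lambda step: 2.0 ** -(step // 1000))]
for _ in range(1001):  # steps 0 to 1000
    updates = run(optimizer, {"w": 3.0, "b": 4.0})

print(optimizer[0].global_norm)     # 5.0
print(optimizer[-1].learning_rate)  # 0.5
```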
Download files
Source Distribution: opax-0.2.11.tar.gz
Built Distribution: opax-0.2.11-py3-none-any.whl
File details
Details for the file opax-0.2.11.tar.gz.
File metadata
- Download URL: opax-0.2.11.tar.gz
- Upload date:
- Size: 7.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e7d9a90dc95e7c1bd91945e1448b2eef549a751136df323933b7f2864034e1bf |
| MD5 | 49ac462579eb5b36e7c57dae317ad6cf |
| BLAKE2b-256 | 2159ce6da3441cc7b53c06c516b80bd6f657b110ce981a87a13980ccbc66d60b |
File details
Details for the file opax-0.2.11-py3-none-any.whl.
File metadata
- Download URL: opax-0.2.11-py3-none-any.whl
- Upload date:
- Size: 8.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | efb65dfef1ace7fdad205a4310db0401277958071c047663dc0efdd42fd3319d |
| MD5 | 4a6ab85133f5f5be2a0ef43c8b647257 |
| BLAKE2b-256 | 37bcc5e94f4ba25a0dfd8ab328371bd1175e26d33342827347d5ecd3d94da0f7 |