
A stateful optimizer library for JAX


opax

opax is an optimizer library for JAX. It reimplements optax on top of PAX's stateful modules.

Installation

To install the latest version:

pip3 install git+https://github.com/ntt123/opax.git

Getting started

To create an optimizer:

import opax
optimizer = opax.adam(1e-4).init(parameters)

Note: parameters is a pytree of trainable parameters.
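For intuition about what an optimizer such as opax.adam(1e-4) has to keep after init(parameters), here is a minimal pure-Python sketch of Adam's state and update rule. It is not opax's implementation: parameters are assumed to be a flat dict of floats rather than a JAX pytree, the names AdamState, adam_init and adam_update are hypothetical, and b1=0.9, b2=0.999, eps=1e-8 are the commonly used Adam defaults, assumed here. The sketch also assumes updates are later subtracted from the parameters.

```python
import math
from dataclasses import dataclass
from typing import Dict, Tuple

Params = Dict[str, float]  # stand-in for a JAX pytree (assumption)


@dataclass
class AdamState:
    """Hypothetical Adam state: first/second moments per parameter plus a step count."""
    mu: Params
    nu: Params
    count: int = 0


def adam_init(params: Params) -> AdamState:
    # Both moment estimates start at zero and mirror the structure of params.
    zeros = {k: 0.0 for k in params}
    return AdamState(mu=dict(zeros), nu=dict(zeros), count=0)


def adam_update(grads: Params, state: AdamState, lr: float = 1e-4,
                b1: float = 0.9, b2: float = 0.999,
                eps: float = 1e-8) -> Tuple[Params, AdamState]:
    """Return (updates, new_state); updates are meant to be subtracted from params."""
    count = state.count + 1
    mu = {k: b1 * state.mu[k] + (1 - b1) * g for k, g in grads.items()}
    nu = {k: b2 * state.nu[k] + (1 - b2) * g * g for k, g in grads.items()}
    updates = {}
    for k in grads:
        mu_hat = mu[k] / (1 - b1 ** count)  # bias-corrected first moment
        nu_hat = nu[k] / (1 - b2 ** count)  # bias-corrected second moment
        updates[k] = lr * mu_hat / (math.sqrt(nu_hat) + eps)
    return updates, AdamState(mu=mu, nu=nu, count=count)


# Usage: the state is created once from the parameters, then threaded through updates.
params = {"w": 1.0, "b": -2.0}
state = adam_init(params)
updates, state = adam_update({"w": 0.5, "b": -0.5}, state)
```

On the first step the bias correction makes each update roughly lr times the sign of the gradient, which is why Adam's step size is largely insensitive to gradient scale early on.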

To update parameters:

updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)

Note: gradients must have the same treedef (tree structure) as parameters.
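The update step above follows a functional pattern: transform_gradients returns both the updates and a new optimizer carrying the new state, and apply_updates produces new parameters. The sketch below mimics that pattern in plain Python with a hypothetical momentum optimizer. ToyMomentum, toy_transform_gradients and toy_apply_updates are illustrative stand-ins, not opax functions; parameters are assumed to be a flat dict of floats, and the sign convention (updates are subtracted) is an assumption consistent with the positive opax.scale(1e-4) used in the chain example.

```python
from dataclasses import dataclass, replace
from typing import Dict, Tuple

Params = Dict[str, float]  # stand-in for a JAX pytree (assumption)


@dataclass(frozen=True)
class ToyMomentum:
    """Hypothetical optimizer object that carries its own state (one velocity per parameter)."""
    lr: float
    beta: float
    velocity: Params


def toy_transform_gradients(grads: Params, opt: ToyMomentum,
                            params: Params) -> Tuple[Params, ToyMomentum]:
    # Gradients must have the same structure as params (here: the same dict keys).
    if grads.keys() != params.keys():
        raise ValueError("gradients and parameters have different structures")
    velocity = {k: opt.beta * opt.velocity[k] + g for k, g in grads.items()}
    updates = {k: opt.lr * v for k, v in velocity.items()}
    # Return a new optimizer with the updated state instead of mutating the old one.
    return updates, replace(opt, velocity=velocity)


def toy_apply_updates(params: Params, updates: Params) -> Params:
    # Assumed sign convention: updates are subtracted from the parameters.
    return {k: p - updates[k] for k, p in params.items()}


# Minimize f(w) = w**2, whose gradient is 2 * w.
params = {"w": 1.0}
optimizer = ToyMomentum(lr=0.1, beta=0.9, velocity={"w": 0.0})
for _ in range(100):
    gradients = {"w": 2.0 * params["w"]}
    updates, optimizer = toy_transform_gradients(gradients, optimizer, params)
    params = toy_apply_updates(params, updates)
print(params)  # w ends up close to 0.0
```

Rebinding optimizer on every step, rather than mutating it in place, is what lets the same loop body run under jax.jit in the real library; the toy version only imitates the shape of that loop.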

The opax.chain function

opax follows optax's philosophy of combining GradientTransformations with the opax.chain function:

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
).init(parameters)
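Conceptually, a chain feeds the output of each transformation into the next, left to right, and (as the Utilities section relies on) can be indexed to reach an individual transformation. The following is a minimal pure-Python sketch of that idea, not opax's code: Trace (a momentum accumulator), Scale and Chain are hypothetical classes, updates are assumed to be a flat dict of floats, and the transformations mutate their own state in place for brevity, whereas opax.transform_gradients returns a new optimizer.

```python
from typing import Dict

Updates = Dict[str, float]  # stand-in for a JAX pytree (assumption)


class Trace:
    """Hypothetical stateful transformation: momentum accumulator."""

    def __init__(self, decay: float):
        self.decay = decay
        self.trace: Updates = {}

    def __call__(self, updates: Updates) -> Updates:
        self.trace = {k: self.decay * self.trace.get(k, 0.0) + v
                      for k, v in updates.items()}
        return dict(self.trace)


class Scale:
    """Hypothetical stateless transformation: multiply every update by a constant."""

    def __init__(self, step_size: float):
        self.step_size = step_size

    def __call__(self, updates: Updates) -> Updates:
        return {k: v * self.step_size for k, v in updates.items()}


class Chain:
    """Apply transformations left to right; indexing returns the i-th transformation."""

    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def __call__(self, updates: Updates) -> Updates:
        for t in self.transforms:
            updates = t(updates)
        return updates

    def __getitem__(self, index: int):
        return self.transforms[index]


optimizer = Chain(Trace(decay=0.9), Scale(1e-4))
print(optimizer({"w": 1.0}))  # first step: trace is 1.0, scaled by 1e-4
print(optimizer({"w": 1.0}))  # second step: trace is 0.9 * 1.0 + 1.0 = 1.9, scaled by 1e-4
```

Order matters: in the opax example, clipping runs on raw gradients before Adam rescales them, and the final scale sets the step size.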

Learning rate schedule

opax supports learning rate scheduling with opax.scale_by_schedule.

import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the schedule value every 1000 steps: 1.0, then 0.5, then 0.25, ...
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
).init(parameters)
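To see which values staircase_schedule_fn produces without running JAX, the same rule, 2 ** -floor(step / 1000), can be written with Python's math module. This is only a sketch for checking the arithmetic; staircase_schedule and its interval parameter are hypothetical names introduced here.

```python
import math


def staircase_schedule(step: int, interval: int = 1000) -> float:
    # Same rule as staircase_schedule_fn: 2 ** -floor(step / interval).
    return 2.0 ** -math.floor(step / interval)


for step in (0, 999, 1000, 2500, 10000):
    print(step, staircase_schedule(step))
```

Steps 0 to 999 get 1.0, steps 1000 to 1999 get 0.5, steps 2000 to 2999 get 0.25, and so on. Because the chain above ends with scale_by_schedule and has no separate opax.scale, this value acts as the full step-size multiplier.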

Utilities

To get the current global norm:

print(optimizer[0].global_norm)

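Note: global_norm is a property of the ClipByGlobalNorm class. Here optimizer is the chain defined above, so optimizer[0] is its first transformation, opax.clip_by_global_norm(1.0).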
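The global norm is the L2 norm of all gradient entries taken together, as if every leaf of the pytree were concatenated into one vector; clipping rescales all leaves by the same factor only when that norm exceeds the threshold. The sketch below illustrates that computation and how a transformation can remember the last norm for later inspection. ClipByGlobalNormSketch is a hypothetical stand-in, not opax's ClipByGlobalNorm, and gradients are assumed to be a dict of Python lists.

```python
import math
from typing import Dict, List

Grads = Dict[str, List[float]]  # stand-in for a pytree of arrays (assumption)


class ClipByGlobalNormSketch:
    """Hypothetical stand-in for ClipByGlobalNorm that remembers the last global norm."""

    def __init__(self, max_norm: float):
        self.max_norm = max_norm
        self.global_norm = 0.0

    def __call__(self, grads: Grads) -> Grads:
        # Global norm: L2 norm over every element of every leaf, treated as one vector.
        self.global_norm = math.sqrt(sum(x * x for leaf in grads.values() for x in leaf))
        # Rescale all leaves by the same factor, only when the norm exceeds max_norm.
        if self.global_norm > self.max_norm:
            factor = self.max_norm / self.global_norm
        else:
            factor = 1.0
        return {k: [x * factor for x in leaf] for k, leaf in grads.items()}


clip = ClipByGlobalNormSketch(max_norm=1.0)
clipped = clip({"w": [3.0], "b": [4.0]})
print(clip.global_norm)  # 5.0
print(clipped)  # roughly {'w': [0.6], 'b': [0.8]}
```

Using one shared factor preserves the direction of the full gradient, which is the reason to clip by global norm rather than per leaf.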

To get the current learning rate:

print(optimizer[-1].learning_rate)

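Note: learning_rate is a property of the ScaleBySchedule class. Here optimizer[-1] is the last transformation in the chain, opax.scale_by_schedule(staircase_schedule_fn).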
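A scheduled transformation needs a step counter as part of its state, and exposing the most recent schedule value makes it easy to log. The sketch below shows one way this can work; ScaleByScheduleSketch is hypothetical, and whether opax evaluates the schedule before or after incrementing its counter is an assumption here (this sketch evaluates it before). Updates are assumed to be a flat dict of floats.

```python
import math
from typing import Callable, Dict

Updates = Dict[str, float]  # stand-in for a JAX pytree (assumption)


class ScaleByScheduleSketch:
    """Hypothetical stand-in for ScaleBySchedule: tracks a step count, exposes learning_rate."""

    def __init__(self, schedule_fn: Callable[[int], float]):
        self.schedule_fn = schedule_fn
        self.count = 0
        self.learning_rate = schedule_fn(0)

    def __call__(self, updates: Updates) -> Updates:
        # Assumption: use the rate for the current step, then advance the counter.
        self.learning_rate = self.schedule_fn(self.count)
        self.count += 1
        return {k: v * self.learning_rate for k, v in updates.items()}


# Staircase rule from the schedule example: halve every 1000 steps.
schedule = ScaleByScheduleSketch(lambda step: 2.0 ** -math.floor(step / 1000))
for _ in range(1001):
    updates = schedule({"w": 1.0})
print(schedule.learning_rate)  # 0.5
```

After 1001 calls the last one ran at step 1000, the first step of the second stair, so the reported rate has dropped to 0.5.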

Download files

Source Distribution

opax-0.2.7.tar.gz (6.4 kB)

Built Distribution

opax-0.2.7-py3-none-any.whl (7.3 kB)

File details

Details for the file opax-0.2.7.tar.gz.

File metadata

  • Download URL: opax-0.2.7.tar.gz
  • Upload date:
  • Size: 6.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0

File hashes

Hashes for opax-0.2.7.tar.gz
  • SHA256: e8fd91c3b10dea955d3007f52931266f1fb6e61113d75ad79c9d632d8c5cd856
  • MD5: cbb16a35d128868d52d1bd929be124ec
  • BLAKE2b-256: 77b1ca80284684b2fa2b8c2617171c6ad67368fe8818d18d001ad072dee7a433
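A downloaded file can be checked against the SHA256 digest listed above before installing it. A minimal sketch using Python's standard hashlib module; sha256_of_file is a hypothetical helper name, and the commented-out call assumes opax-0.2.7.tar.gz has been downloaded into the current directory.

```python
import hashlib


def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


EXPECTED_SHA256 = "e8fd91c3b10dea955d3007f52931266f1fb6e61113d75ad79c9d632d8c5cd856"
# Assumes the sdist is in the current directory:
# print(sha256_of_file("opax-0.2.7.tar.gz") == EXPECTED_SHA256)
```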

File details

Details for the file opax-0.2.7-py3-none-any.whl.

File metadata

  • Download URL: opax-0.2.7-py3-none-any.whl
  • Upload date:
  • Size: 7.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0

File hashes

Hashes for opax-0.2.7-py3-none-any.whl
  • SHA256: 72a12f51692d074b3f8b277a8fe0dd9dea456345cb7ff3d41cb79ba3f5be37a7
  • MD5: 9474d3624caf68c267268e5fe594f3bb
  • BLAKE2b-256: 2b44c93d0fbe0334d8a1d2d8ea52c66f00d13bdb987a27700103c660efe86312
