A stateful optimizer library for JAX

opax

opax is an optimizer library for JAX. It is a reimplementation of optax using PAX's stateful module.

Installation

To install the latest version:

pip3 install git+https://github.com/ntt123/opax.git

Getting started

To create an optimizer:

import opax
optimizer = opax.adam(1e-4)(parameters)

Note: parameters is a pytree of trainable parameters.

To update parameters:

updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)

Note: gradients has the same treedef as parameters.
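The following sketch illustrates why gradients must share the tree structure of parameters: a leaf-wise update pairs every parameter with its own gradient. It uses plain Python nested dicts and a plain gradient-descent step instead of JAX or opax. `map_trees` and the learning rate of 0.25 are hypothetical names and values chosen for illustration; they are not part of the opax API.

```python
def map_trees(fn, params, grads):
    """Apply fn(param_leaf, grad_leaf) over two nested dicts with identical structure."""
    if isinstance(params, dict) != isinstance(grads, dict):
        raise ValueError("parameters and gradients must have the same structure")
    if isinstance(params, dict):
        if params.keys() != grads.keys():
            raise ValueError("parameters and gradients must have the same keys")
        return {k: map_trees(fn, params[k], grads[k]) for k in params}
    # Both arguments are leaves here.
    return fn(params, grads)


params = {"dense": {"w": 1.0, "b": 0.5}}
grads = {"dense": {"w": 2.0, "b": -1.0}}

# Plain gradient descent with an illustrative learning rate of 0.25.
new_params = map_trees(lambda p, g: p - 0.25 * g, params, grads)
print(new_params)  # {'dense': {'w': 0.5, 'b': 0.75}}
```

If the two trees differ in structure, the helper raises an error instead of silently pairing the wrong leaves.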

The opax.chain function

opax follows optax's philosophy of combining GradientTransformations with the opax.chain function:

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
)(parameters)
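To make the idea of chaining concrete, here is a minimal sketch in plain Python that applies transformations in order to a flat list of gradients. It is stateless and therefore much simpler than opax's stateful modules; the functions below only borrow the names `chain`, `clip_by_global_norm` and `scale` for readability and are assumptions about the general technique, not opax's implementation.

```python
import math


def clip_by_global_norm(max_norm):
    """Return a transform that rescales gradients so their L2 norm is at most max_norm."""
    def transform(grads):
        norm = math.sqrt(sum(g * g for g in grads))
        factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
        return [g * factor for g in grads]
    return transform


def scale(step_size):
    """Return a transform that multiplies every gradient by step_size."""
    def transform(grads):
        return [g * step_size for g in grads]
    return transform


def chain(*transforms):
    """Return a transform that applies the given transforms from left to right."""
    def transform(grads):
        for t in transforms:
            grads = t(grads)
        return grads
    return transform


pipeline = chain(clip_by_global_norm(1.0), scale(0.5))
# [3.0, 4.0] has norm 5, so it is first clipped to norm 1, then halved.
result = pipeline([3.0, 4.0])
print(result)
```

Order matters: clipping before scaling bounds the raw gradient norm, whereas scaling first would change which gradients get clipped.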

Learning rate schedule

opax supports learning rate scheduling with opax.scale_by_schedule.

import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the multiplier every 1000 steps: 1, 1/2, 1/4, ...
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
)(parameters)
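To see which values a staircase schedule of this shape produces, the sketch below evaluates the same formula, 2^(-floor(step / 1000)), in plain Python at a few steps. `staircase_schedule` and its `period` argument are hypothetical names used only for this illustration.

```python
import math


def staircase_schedule(step, period=1000):
    """Return 2 ** -(number of completed periods), i.e. halve once per period."""
    return 2.0 ** (-math.floor(step / period))


values = {step: staircase_schedule(step) for step in (0, 999, 1000, 2500)}
for step, value in values.items():
    print(step, value)
# 0 1.0
# 999 1.0
# 1000 0.5
# 2500 0.25
```

The multiplier stays constant within each block of 1000 steps and drops by a factor of two at every block boundary.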

Utilities

To get the current global norm:

print(optimizer[0].global_norm)

Note: global_norm is a property of the ClipByGlobalNorm class.

To get the current learning rate:

print(optimizer[-1].learning_rate)

Note: learning_rate is a property of the ScaleBySchedule class.
