
A stateful optimizer library for JAX


opax

opax is an optimizer library for JAX. It reimplements optax using PAX's stateful modules.

Installation

To install the latest version:

pip3 install git+https://github.com/ntt123/opax.git

Getting started

To create an optimizer:

import opax
optimizer = opax.adam(1e-4)(parameters)

Note: parameters is a pytree of trainable parameters.
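Adam keeps per-parameter state (running estimates of the gradient's first and second moments), which is a plausible reason the optimizer is built from parameters. The following is a minimal pure-Python sketch of that idea, not opax's implementation. A flat dict of floats stands in for a pytree. The names AdamState, adam_init and adam_update are hypothetical. The defaults b1=0.9, b2=0.999 and eps=1e-8 are the commonly used Adam defaults and are assumed here, not taken from opax.

```python
import math
from dataclasses import dataclass
from typing import Dict, Tuple

# A flat dict of floats stands in for a JAX pytree of arrays.
Tree = Dict[str, float]


@dataclass(frozen=True)
class AdamState:
    count: int  # number of updates computed so far
    mu: Tree    # running mean of gradients (first moment)
    nu: Tree    # running mean of squared gradients (second moment)


def adam_init(params: Tree) -> AdamState:
    # The state mirrors the parameter tree: one zero moment per leaf.
    return AdamState(count=0,
                     mu={k: 0.0 for k in params},
                     nu={k: 0.0 for k in params})


def adam_update(grads: Tree, state: AdamState, lr: float = 1e-4,
                b1: float = 0.9, b2: float = 0.999,
                eps: float = 1e-8) -> Tuple[Tree, AdamState]:
    count = state.count + 1
    mu = {k: b1 * state.mu[k] + (1 - b1) * g for k, g in grads.items()}
    nu = {k: b2 * state.nu[k] + (1 - b2) * g * g for k, g in grads.items()}
    updates = {}
    for k in grads:
        # Bias-corrected moment estimates.
        mu_hat = mu[k] / (1 - b1 ** count)
        nu_hat = nu[k] / (1 - b2 ** count)
        updates[k] = lr * mu_hat / (math.sqrt(nu_hat) + eps)
    # Return the updates and a fresh state object; the old state is untouched.
    return updates, AdamState(count, mu, nu)


params = {"w": 1.0, "b": 0.0}
state = adam_init(params)
grads = {"w": 2.0, "b": -3.0}
updates, state = adam_update(grads, state)
```

On the first step the bias-corrected ratio is close to ±1, so each update is roughly lr times the sign of the gradient. The sketch leaves the sign convention (whether updates are added to or subtracted from the parameters) to the apply step.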

To update parameters:

updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)

Note: gradients must have the same tree structure (treedef) as parameters.
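The update step returns a new optimizer instead of mutating the old one, which is why optimizer is reassigned. Below is a minimal pure-Python sketch of that calling pattern only. SGD, transform_gradients and apply_updates here are hypothetical stand-ins that mimic the shape of the opax calls, a flat dict stands in for a pytree, and subtracting the updates in apply_updates is this sketch's own convention.

```python
from dataclasses import dataclass, replace
from typing import Dict, Tuple

# A flat dict of floats stands in for a JAX pytree of arrays.
Tree = Dict[str, float]


@dataclass(frozen=True)
class SGD:
    """Hypothetical stateful optimizer: a learning rate plus a step counter."""
    learning_rate: float
    step: int = 0


def transform_gradients(grads: Tree, opt: SGD, params: Tree) -> Tuple[Tree, SGD]:
    # Gradients must have the same structure (here: the same keys) as params.
    if grads.keys() != params.keys():
        raise ValueError("gradients and parameters have different structures")
    updates = {k: opt.learning_rate * g for k, g in grads.items()}
    # Return a *new* optimizer with the advanced counter instead of mutating opt.
    return updates, replace(opt, step=opt.step + 1)


def apply_updates(params: Tree, updates: Tree) -> Tree:
    # Sketch convention (an assumption): subtract the updates (gradient descent).
    return {k: params[k] - updates[k] for k in params}


params = {"w": 1.0, "b": 0.5}
optimizer = SGD(learning_rate=0.1)
for _ in range(3):
    # Gradient of the toy loss w**2 + b**2.
    grads = {"w": 2.0 * params["w"], "b": 2.0 * params["b"]}
    updates, optimizer = transform_gradients(grads, optimizer, params)
    params = apply_updates(params, updates)
```

Each iteration multiplies both parameters by 0.8, and the returned optimizer carries the incremented step count.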

The opax.chain function

Following optax's design, opax combines GradientTransformations with the opax.chain function:

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
)(parameters)
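opax.chain applies its transformations in order, and each stage receives the previous stage's output. A stateless pure-Python sketch of that composition follows. chain, clip and scale here are hypothetical stand-ins: real opax stages such as scale_by_adam also carry state, which this sketch omits, and this clip clamps each value rather than the global norm.

```python
from functools import reduce
from typing import Callable, Dict

# A flat dict of floats stands in for a JAX pytree of arrays.
Tree = Dict[str, float]
Transform = Callable[[Tree], Tree]


def scale(factor: float) -> Transform:
    """Hypothetical stage: multiply every leaf by a constant."""
    return lambda u: {k: factor * v for k, v in u.items()}


def clip(limit: float) -> Transform:
    """Hypothetical stage: clamp every leaf to [-limit, limit]."""
    return lambda u: {k: max(-limit, min(limit, v)) for k, v in u.items()}


def chain(*stages: Transform) -> Transform:
    """Run the stages left to right, feeding each output into the next stage."""
    return lambda u: reduce(lambda acc, stage: stage(acc), stages, u)


grads = {"w": 5.0, "b": -0.5}
clip_then_scale = chain(clip(1.0), scale(0.1))(grads)
scale_then_clip = chain(scale(0.1), clip(1.0))(grads)
```

The two orderings give different results for "w" (0.1 versus 0.5), which is why clipping is usually placed first, acting on the raw gradients.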

Learning rate schedule

opax supports learning rate scheduling with opax.scale_by_schedule.

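import jax.numpy as jnp

# Staircase decay: 1.0 for steps 0-999, 0.5 for 1000-1999, 0.25 for 2000-2999, ...
def staircase_schedule_fn(step: jnp.ndarray):
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)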

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
)(parameters)
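The schedule is a step function: its value stays at 1.0 for the first 1000 steps and halves at every further multiple of 1000. A pure-Python mirror of staircase_schedule_fn makes the values easy to check without JAX. The name staircase_schedule and its period parameter are hypothetical.

```python
import math


def staircase_schedule(step: int, period: int = 1000) -> float:
    """Pure-Python mirror of staircase_schedule_fn: halve every `period` steps."""
    p = math.floor(step / period)  # number of completed periods
    return 2.0 ** -p


values = {s: staircase_schedule(s) for s in (0, 999, 1000, 2500, 3000)}
```

Because scale_by_schedule is the last stage of this chain, the schedule's output is presumably the effective Adam step size, which as written starts at 1.0. To start from a smaller rate such as 1e-4, one option is to multiply the returned value by that base rate.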

Utilities

To get the current global norm:

print(optimizer[0].global_norm)

Note: global_norm is a property of the ClipByGlobalNorm class; optimizer[0] is the first transformation in the chain, opax.clip_by_global_norm(1.0).
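The global norm is the L2 norm of all gradient entries taken together, as if every leaf were concatenated into one vector. Clipping by it rescales every leaf by one common factor, so the update direction is preserved. A pure-Python sketch of this standard technique follows; the function names are hypothetical and a dict of lists stands in for a pytree.

```python
import math
from typing import Dict, List

# A dict of lists of floats stands in for a JAX pytree of arrays.
Tree = Dict[str, List[float]]


def global_norm(tree: Tree) -> float:
    """L2 norm of all elements of all leaves, as if concatenated into one vector."""
    return math.sqrt(sum(x * x for leaf in tree.values() for x in leaf))


def clip_by_global_norm(tree: Tree, max_norm: float) -> Tree:
    """Rescale every leaf by the same factor so the global norm is <= max_norm."""
    norm = global_norm(tree)
    if norm <= max_norm:
        return tree  # already small enough; leave unchanged
    factor = max_norm / norm
    return {k: [factor * x for x in leaf] for k, leaf in tree.items()}


grads = {"w": [3.0], "b": [4.0]}
norm = global_norm(grads)
clipped = clip_by_global_norm(grads, max_norm=1.0)
```

Here the global norm is 5.0, so clipping to 1.0 scales both leaves by 0.2 while keeping their 3:4 ratio.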

To get the current learning rate:

print(optimizer[-1].learning_rate)

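Note: learning_rate is a property of the ScaleBySchedule class; optimizer[-1] is the last transformation in the chain, opax.scale_by_schedule(staircase_schedule_fn).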



Download files


Source Distribution

opax-0.2.4.tar.gz (6.5 kB)

Built Distribution

opax-0.2.4-py3-none-any.whl (7.3 kB)

File details

Details for the file opax-0.2.4.tar.gz.

File metadata

  • Download URL: opax-0.2.4.tar.gz
  • Upload date:
  • Size: 6.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for opax-0.2.4.tar.gz
Algorithm Hash digest
SHA256 0bd3d0995c75190a906f772116a643d0dedf6b988c67faca1c390509a318e3e6
MD5 570c60d3c8516966bac7ed1e78e762af
BLAKE2b-256 f61b48a8251036a1552de795b8a623f528f27c55f27bec80c6e802eef55e4651


File details

Details for the file opax-0.2.4-py3-none-any.whl.

File metadata

  • Download URL: opax-0.2.4-py3-none-any.whl
  • Upload date:
  • Size: 7.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for opax-0.2.4-py3-none-any.whl
Algorithm Hash digest
SHA256 d7ba07bf5efa86ee5449dcd736bc51fc1a8247dfa2f98ecf7212f7a53f620440
MD5 739e345d1c87495554d44c9dd9d094a9
BLAKE2b-256 e3bc76773bf80a325d9efe6437dd16a570636a59a97a9cb7d5de0aad5ba97e70

