A stateful optimizer library for JAX

opax

opax is an optimizer library for JAX. It is a reimplementation of optax built on PAX's stateful modules.

Installation

To install the latest version:

pip3 install git+https://github.com/ntt123/opax.git

Getting started

To create an optimizer:

import opax
optimizer = opax.adam(1e-4)(parameters)

Note: parameters is a pytree of trainable parameters.

To update parameters:

updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)

Note: gradients has the same treedef as parameters.
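
To make the create/update cycle concrete, here is a pure-Python toy that mimics the shape of the two calls above. The optimizer is an immutable object built in two stages (adam(lr)(params)), and each update returns a new optimizer alongside the updates instead of mutating state. Everything here is a hypothetical illustration, not opax's code: ToyAdam, the flat dict standing in for a pytree, and the optax-style sign convention (updates carry the minus sign and apply_updates adds them) are all assumptions. The update itself is the standard Adam rule with bias correction.

```python
import math
from dataclasses import dataclass, field, replace
from typing import Callable, Dict, Tuple

# Hypothetical flat "pytree": parameter name -> scalar value.
Params = Dict[str, float]


@dataclass(frozen=True)
class ToyAdam:
    """Hypothetical immutable optimizer state (not opax's class)."""
    lr: float
    b1: float = 0.9
    b2: float = 0.999
    eps: float = 1e-8
    step: int = 0
    mu: Params = field(default_factory=dict)  # first-moment estimates
    nu: Params = field(default_factory=dict)  # second-moment estimates


def adam(lr: float) -> Callable[[Params], ToyAdam]:
    """Two-stage construction mirroring `adam(lr)(parameters)`."""
    def init(params: Params) -> ToyAdam:
        return ToyAdam(lr=lr,
                       mu={k: 0.0 for k in params},
                       nu={k: 0.0 for k in params})
    return init


def transform_gradients(grads: Params, opt: ToyAdam,
                        params: Params) -> Tuple[Params, ToyAdam]:
    """Return (updates, new_optimizer); `params` is unused by plain Adam."""
    step = opt.step + 1
    mu = {k: opt.b1 * opt.mu[k] + (1 - opt.b1) * g for k, g in grads.items()}
    nu = {k: opt.b2 * opt.nu[k] + (1 - opt.b2) * g * g for k, g in grads.items()}
    updates = {}
    for k in grads:
        m_hat = mu[k] / (1 - opt.b1 ** step)  # bias-corrected moments
        v_hat = nu[k] / (1 - opt.b2 ** step)
        # Assumed optax-style convention: updates already carry the minus sign.
        updates[k] = -opt.lr * m_hat / (math.sqrt(v_hat) + opt.eps)
    # Return a new optimizer object; the old one is left untouched.
    return updates, replace(opt, step=step, mu=mu, nu=nu)


def apply_updates(params: Params, updates: Params) -> Params:
    return {k: params[k] + updates[k] for k in params}


params = {"w": 1.0, "b": 1.0}
optimizer = adam(0.1)(params)
gradients = {"w": 2.0, "b": -3.0}
updates, optimizer = transform_gradients(gradients, optimizer, params)
params = apply_updates(params, updates)
# On the first step each parameter moves by about lr against its gradient's sign.
print(params, optimizer.step)
```

Returning a fresh optimizer instead of mutating one keeps the update step a pure function, which is the style JAX transformations such as jax.jit expect.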

The opax.chain function

opax follows optax's philosophy of combining GradientTransformations with the opax.chain function:

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
)(parameters)
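
Conceptually, chaining feeds the output of each transformation into the next one, so order matters: above, gradients are clipped before they are rescaled. The sketch below shows that composition with plain Python functions over a list of floats. toy_chain, toy_clip_by_global_norm, and toy_scale are hypothetical, stateless stand-ins; opax's real transformations (scale_by_adam, for example) also carry state.

```python
import math
from functools import reduce
from typing import Callable, List

Updates = List[float]
Transform = Callable[[Updates], Updates]


def toy_clip_by_global_norm(max_norm: float) -> Transform:
    def apply(updates: Updates) -> Updates:
        norm = math.sqrt(sum(u * u for u in updates))
        if norm <= max_norm:
            return list(updates)
        # Rescale every element by the same factor so the global norm equals max_norm.
        return [u * max_norm / norm for u in updates]
    return apply


def toy_scale(factor: float) -> Transform:
    def apply(updates: Updates) -> Updates:
        return [u * factor for u in updates]
    return apply


def toy_chain(*transforms: Transform) -> Transform:
    """Apply the transforms left to right, each to the previous output."""
    def apply(updates: Updates) -> Updates:
        return reduce(lambda acc, t: t(acc), transforms, updates)
    return apply


optimizer = toy_chain(toy_clip_by_global_norm(1.0), toy_scale(0.5))
print(optimizer([3.0, 4.0]))  # norm 5 is clipped to 1, then halved

reversed_optimizer = toy_chain(toy_scale(0.5), toy_clip_by_global_norm(1.0))
print(reversed_optimizer([3.0, 4.0]))  # halved to norm 2.5, then clipped to 1
```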

Learning rate schedule

opax supports learning rate scheduling with opax.scale_by_schedule.

import jax.numpy as jnp
import opax

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the multiplier every 1000 steps: 1.0, 0.5, 0.25, ...
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
)(parameters)
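
Because staircase_schedule_fn depends only on the step count, its values are easy to tabulate. The helper below is a hypothetical pure-Python mirror of it (math instead of jnp, so it runs without JAX). In the chain above, the schedule is the only scaling applied after scale_by_adam, so its value acts as the learning rate.

```python
import math


def staircase(step: int) -> float:
    """Pure-Python mirror of staircase_schedule_fn: 2 ** -floor(step / 1000)."""
    return 2.0 ** -math.floor(step / 1000)


for step in (0, 999, 1000, 2500, 3000):
    print(step, staircase(step))
# 0 1.0
# 999 1.0
# 1000 0.5
# 2500 0.25
# 3000 0.125
```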

Utilities

To get the current global norm:

print(optimizer[0].global_norm)

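Note: global_norm is a property of the ClipByGlobalNorm class. Here optimizer[0] is the opax.clip_by_global_norm(1.0) transformation, the first element of the chain above.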
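
The global norm is the Euclidean norm of all gradient leaves taken together: the square root of the sum of squares over every element of every leaf. It is the quantity clip_by_global_norm compares against its threshold. Below is a minimal pure-Python sketch over a nested dict/list tree; the helper names are hypothetical and scalars stand in for arrays.

```python
import math
from typing import Any, Iterator


def leaves(tree: Any) -> Iterator[float]:
    """Yield scalar leaves of a nested dict/list/tuple tree (toy tree_leaves)."""
    if isinstance(tree, dict):
        for key in sorted(tree):  # order does not affect the norm
            yield from leaves(tree[key])
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from leaves(item)
    else:
        yield float(tree)


def global_norm(tree: Any) -> float:
    return math.sqrt(sum(x * x for x in leaves(tree)))


gradients = {"linear": {"w": [1.0, 2.0], "b": [2.0]}, "scale": 4.0}
print(global_norm(gradients))  # sqrt(1 + 4 + 4 + 16) = 5.0
```

With real JAX arrays, jax.tree_util.tree_leaves plays the role of leaves, and each leaf contributes jnp.sum(jnp.square(leaf)) to the sum.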

To get the current learning rate:

print(optimizer[-1].learning_rate)

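Note: learning_rate is a property of the ScaleBySchedule class. Here optimizer[-1] is the opax.scale_by_schedule(staircase_schedule_fn) transformation, the last element of the chain above.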


Download files

Source Distribution

opax-0.2.3.tar.gz (6.4 kB)

Uploaded Source

Built Distribution

opax-0.2.3-py3-none-any.whl (7.3 kB)

Uploaded Python 3

File details

Details for the file opax-0.2.3.tar.gz.

File metadata

  • Download URL: opax-0.2.3.tar.gz
  • Upload date:
  • Size: 6.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for opax-0.2.3.tar.gz
Algorithm Hash digest
SHA256 9a7ceaabc6f92770ed60743f5ea600c27d059ec8b49127b07de6a1391f1a435b
MD5 9cc0333db348fac777aaa90628fed9c0
BLAKE2b-256 e746b5bad08a23a5e0cab38a12c21ee2a5685d95cfcfac488cd9c7ec650e99b0

File details

Details for the file opax-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: opax-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 7.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for opax-0.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 76ea8de69d81e2cb58308802db5c177ef40a5d76d1c022cfb2657a30b87cb1db
MD5 734e3735f7798ec792efae0bbe0cf6fd
BLAKE2b-256 80d73a60d9f30aea9d00dfffad80ae9703754b2b9c9c889dee068a7d67337315
