
A stateful optimizer library for JAX


opax

opax is an optimizer library for JAX. It reimplements optax on top of PAX's stateful modules.

Installation

To install the latest version:

pip3 install git+https://github.com/ntt123/opax.git

Getting started

To create an optimizer:

import opax
optimizer = opax.adam(1e-4)(parameters)

Note: parameters is a pytree of trainable parameters.
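When the optimizer is created from parameters, it can allocate state shaped like the parameter pytree. As a rough illustration, the sketch below is a hypothetical, self-contained Adam written in plain JAX. `TinyAdam` and its `transform` method are illustrative names, not opax's API. The object keeps its moment estimates as attributes and returns a new object instead of mutating itself, which mirrors the `updates, optimizer = ...` pattern used for updates. It also assumes updates are later subtracted from the parameters. That is consistent with the positive learning rates used throughout this README.

```python
import copy

import jax
import jax.numpy as jnp


class TinyAdam:
    """Hypothetical stateful Adam sketch; not opax's actual implementation."""

    def __init__(self, learning_rate, params, b1=0.9, b2=0.999, eps=1e-8):
        self.learning_rate = learning_rate
        self.b1, self.b2, self.eps = b1, b2, eps
        # First and second moment estimates mirror the parameter pytree.
        self.mu = jax.tree_util.tree_map(jnp.zeros_like, params)
        self.nu = jax.tree_util.tree_map(jnp.zeros_like, params)
        self.count = 0

    def transform(self, grads):
        """Return (updates, new_optimizer) and leave self unchanged."""
        new = copy.copy(self)
        new.count = self.count + 1
        new.mu = jax.tree_util.tree_map(
            lambda m, g: self.b1 * m + (1 - self.b1) * g, self.mu, grads)
        new.nu = jax.tree_util.tree_map(
            lambda v, g: self.b2 * v + (1 - self.b2) * g * g, self.nu, grads)
        # Bias correction for the zero-initialised moments.
        c1 = 1 - self.b1 ** new.count
        c2 = 1 - self.b2 ** new.count
        # Positive-scaled updates, assumed to be subtracted from parameters.
        updates = jax.tree_util.tree_map(
            lambda m, v: self.learning_rate * (m / c1) / (jnp.sqrt(v / c2) + self.eps),
            new.mu, new.nu)
        return updates, new


params = {"w": jnp.ones((2,)), "b": jnp.zeros(())}
opt0 = TinyAdam(1e-4, params)
grads = {"w": jnp.array([0.5, -2.0]), "b": jnp.array(3.0)}
updates, opt1 = opt0.transform(grads)
```

On the first step, bias correction makes each update roughly `learning_rate * sign(g)`. So `updates["w"]` is close to `[1e-4, -1e-4]`, and `opt0` still reports `count == 0`.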

To update parameters:

updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)

Note: gradients must have the same tree structure (treedef) as parameters.
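The gradients themselves usually come from `jax.grad`. The sketch below computes them for a hypothetical least-squares loss, checks that they share the parameters' tree structure, and applies a plain SGD step through a local `apply_updates` helper. That helper is a stand-in, not opax's function. This README does not state opax's sign convention, so the sketch assumes updates are subtracted. That matches the positive `opax.scale(1e-4)` in the chain example, but check the opax source before relying on it.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical linear model with mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def apply_updates(params, updates):
    # Local stand-in; assumes updates are subtracted from parameters.
    return jax.tree_util.tree_map(lambda p, u: p - u, params, updates)


params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = jnp.array([1.0, 2.0])

grads = jax.grad(loss_fn)(params, x, y)
# Gradients come back with the same treedef as the parameters.
assert jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)

learning_rate = 0.01
updates = jax.tree_util.tree_map(lambda g: learning_rate * g, grads)
new_params = apply_updates(params, updates)
```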

The opax.chain function

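Following optax's design, opax combines GradientTransformations with the opax.chain function. The transformations are applied in the order listed, so the optimizer below clips gradients to a global norm of 1.0, rescales them with Adam, and finally multiplies them by the learning rate 1e-4: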

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
)(parameters)
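One way to picture opax.chain is as a container that feeds each transformation's output into the next and keeps every transformation's updated state. The hypothetical sketch below models that in plain JAX with a stateful momentum transform (standing in for Adam) followed by a constant scale. Clipping is omitted for brevity. The class names and the `step` method are illustrative, not opax's implementation. `Chain.__getitem__` shows how indexing such as `optimizer[0]` can reach an individual transformation.

```python
import jax
import jax.numpy as jnp


class Momentum:
    """Stateful transform: keeps a decaying trace of past updates."""

    def __init__(self, decay, trace):
        self.decay = decay
        self.trace = trace  # pytree shaped like the parameters

    def step(self, updates):
        trace = jax.tree_util.tree_map(
            lambda t, u: self.decay * t + u, self.trace, updates)
        return trace, Momentum(self.decay, trace)


class Scale:
    """Stateless transform: multiplies every update by a constant."""

    def __init__(self, factor):
        self.factor = factor

    def step(self, updates):
        return jax.tree_util.tree_map(lambda u: u * self.factor, updates), self


class Chain:
    """Applies transforms in order and collects each one's new state."""

    def __init__(self, *transforms):
        self.transforms = tuple(transforms)

    def __getitem__(self, index):
        # Allows optimizer[0] / optimizer[-1] style access.
        return self.transforms[index]

    def step(self, updates):
        new_transforms = []
        for transform in self.transforms:
            updates, transform = transform.step(updates)
            new_transforms.append(transform)
        return updates, Chain(*new_transforms)


params = {"w": jnp.ones((2,))}
optimizer = Chain(
    Momentum(0.9, jax.tree_util.tree_map(jnp.zeros_like, params)),
    Scale(1e-2),
)
grads = {"w": jnp.array([3.0, 4.0])}
updates, optimizer = optimizer.step(grads)
```

With decay 0.9 and gradients `[3.0, 4.0]`, the first step yields `[0.03, 0.04]`. Repeating the same gradients gives `[0.057, 0.076]`, because the trace accumulates `0.9 * 3.0 + 3.0 = 5.7` and `0.9 * 4.0 + 4.0 = 7.6`.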

Learning rate schedule

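opax supports learning-rate schedules through opax.scale_by_schedule, which multiplies the updates by the value a schedule function returns for the current step count. The staircase schedule below halves that value every 1000 steps: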

import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the multiplier every 1000 steps: 1.0, 0.5, 0.25, ...
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
)(parameters)
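Evaluating the schedule at a few step counts makes the staircase shape concrete. This standalone snippet repeats staircase_schedule_fn and needs only jax.numpy. The multiplier is 1.0 for steps 0 to 999, 0.5 for steps 1000 to 1999, 0.25 for steps 2000 to 2999, and so on.

```python
import jax.numpy as jnp


def staircase_schedule_fn(step: jnp.ndarray):
    # Number of completed 1000-step blocks, as a float.
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    # Multiplier 2**(-p): halves after every 1000 steps.
    return jnp.exp2(-p)


steps = jnp.array([0, 999, 1000, 2500, 3000])
multipliers = staircase_schedule_fn(steps)
print(multipliers)
```

The chain above has no separate opax.scale, so these multipliers act directly as the step size applied to Adam's normalized updates. If a smaller base rate is wanted, one option is to multiply it into the schedule's return value.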

Utilities

To get the current global norm:

print(optimizer[0].global_norm)

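Note: global_norm is a property of the ClipByGlobalNorm class. optimizer[0] refers to it because opax.clip_by_global_norm is the first transformation in the chain.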
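The global norm is the L2 norm of all gradient leaves taken together as one vector. The sketch below is a hypothetical stand-in for ClipByGlobalNorm, not opax's code. In this sketch, clipping rescales every leaf by `min(1, max_norm / norm)`, and the stored `global_norm` is assumed to be the norm of the incoming gradients before clipping. It is recorded on the new object that each step returns.

```python
import jax
import jax.numpy as jnp


class ClipByGlobalNormSketch:
    """Hypothetical stand-in for opax's ClipByGlobalNorm, not its real code."""

    def __init__(self, max_norm, global_norm=jnp.array(0.0)):
        self.max_norm = max_norm
        self._global_norm = global_norm

    @property
    def global_norm(self):
        # Norm of the most recent incoming gradients (before clipping).
        return self._global_norm

    def step(self, updates):
        leaves = jax.tree_util.tree_leaves(updates)
        norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
        # Rescale only when the norm exceeds max_norm.
        factor = jnp.minimum(1.0, self.max_norm / norm)
        clipped = jax.tree_util.tree_map(lambda x: x * factor, updates)
        return clipped, ClipByGlobalNormSketch(self.max_norm, norm)


clip = ClipByGlobalNormSketch(1.0)
grads = {"w": jnp.array([3.0, 0.0]), "b": jnp.array(4.0)}
clipped, clip = clip.step(grads)
print(clip.global_norm)
```

For these gradients the global norm is sqrt(3² + 4²) = 5, so every leaf is scaled by 0.2.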

To get the current learning rate:

print(optimizer[-1].learning_rate)

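Note: learning_rate is a property of the ScaleBySchedule class. optimizer[-1] refers to it because opax.scale_by_schedule is the last transformation in the chain.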
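To report its current learning rate, a schedule transformation needs only a step counter and the schedule function. The hypothetical `ScaleByScheduleSketch` below illustrates that idea in plain JAX. It is not opax's ScaleBySchedule. This README does not say whether opax's `learning_rate` reports the rate used on the last step or the rate for the next one. This sketch returns the schedule's value at the current count, which is the rate the next update will use.

```python
import jax
import jax.numpy as jnp


def staircase_schedule_fn(step):
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)


class ScaleByScheduleSketch:
    """Hypothetical stand-in for opax's ScaleBySchedule, not its real code."""

    def __init__(self, schedule_fn, count=0):
        self.schedule_fn = schedule_fn
        self.count = count  # number of steps taken so far

    @property
    def learning_rate(self):
        # Schedule value at the current count (used by the next step).
        return self.schedule_fn(jnp.array(self.count))

    def step(self, updates):
        lr = self.learning_rate
        scaled = jax.tree_util.tree_map(lambda u: u * lr, updates)
        return scaled, ScaleByScheduleSketch(self.schedule_fn, self.count + 1)


sched = ScaleByScheduleSketch(staircase_schedule_fn)
scaled, sched = sched.step({"w": jnp.ones((2,))})
late = ScaleByScheduleSketch(staircase_schedule_fn, count=2500)
print(sched.learning_rate, late.learning_rate)
```

After one step, the counter is 1 and the rate is still 1.0. A schedule that has reached step 2500 reports 0.25.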

Download files


Source Distribution

opax-0.2.5.tar.gz (6.5 kB)

Built Distribution

opax-0.2.5-py3-none-any.whl (7.4 kB)

File details

Details for the file opax-0.2.5.tar.gz.

File metadata

  • Download URL: opax-0.2.5.tar.gz
  • Upload date:
  • Size: 6.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0

File hashes

Hashes for opax-0.2.5.tar.gz
Algorithm     Hash digest
SHA256        3f287b2b040570ea4ad982a56832b74df9dd9d291effaf16937c320920c6f28c
MD5           8e732a6e89125cd1ce1a37375d498d9e
BLAKE2b-256   e643b38b222ef4f004e391f7927338cc164e680dc5445301dae67c9f0614b609


File details

Details for the file opax-0.2.5-py3-none-any.whl.

File metadata

  • Download URL: opax-0.2.5-py3-none-any.whl
  • Upload date:
  • Size: 7.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0

File hashes

Hashes for opax-0.2.5-py3-none-any.whl
Algorithm     Hash digest
SHA256        b17589acecb50b6bd116d5e50db74ec9db23d9a7d4d75f3baf65bbfdb159ee6c
MD5           b69a9d7509af17566c9d6128b4379de4
BLAKE2b-256   e48725ec5cb400d4a99baa87af5703ca057d321996f319185337b62cbb1d583d

