A stateful optimizer library for JAX

opax

opax is an optimizer library for JAX. It reimplements optax on top of PAX's stateful modules.

Installation

To install the latest version:

pip3 install git+https://github.com/ntt123/opax.git

Getting started

To create an optimizer:

import opax
optimizer = opax.adam(1e-4).init(parameters)

Note: parameters is a pytree of trainable parameters.

To update parameters:

updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)

Note: gradients has the same treedef as parameters.
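
The key point of the update step is that the optimizer is stateful: the call returns a new optimizer object, which must be rebound on every step. The following is a toy, standard-library-only sketch of that pattern. `ToyMomentum`, `toy_transform_gradients`, and `toy_apply_updates` are hypothetical stand-ins and not opax's implementation; flat lists replace pytrees, and subtracting the updates from the parameters is an assumption of this sketch.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass(frozen=True)
class ToyMomentum:
    """Hypothetical stand-in for a stateful optimizer (not an opax class)."""
    learning_rate: float
    decay: float
    velocity: Tuple[float, ...]


def toy_transform_gradients(gradients: List[float], optimizer: ToyMomentum):
    """Return the updates and a NEW optimizer that carries the updated velocity."""
    velocity = tuple(
        optimizer.decay * v + g for v, g in zip(optimizer.velocity, gradients)
    )
    updates = [optimizer.learning_rate * v for v in velocity]
    return updates, replace(optimizer, velocity=velocity)


def toy_apply_updates(parameters: List[float], updates: List[float]) -> List[float]:
    # Assumption of this sketch: updates are subtracted from the parameters.
    return [p - u for p, u in zip(parameters, updates)]


parameters = [1.0, -2.0]
optimizer = ToyMomentum(learning_rate=0.1, decay=0.9, velocity=(0.0, 0.0))

for _ in range(3):
    gradients = [2.0 * p for p in parameters]  # gradient of sum(p ** 2)
    # Rebind the optimizer: the old object still holds the old velocity.
    updates, optimizer = toy_transform_gradients(gradients, optimizer)
    parameters = toy_apply_updates(parameters, updates)

print(parameters, optimizer.velocity)
```

Forgetting to rebind `optimizer` in the loop would silently reuse the initial state on every step.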

The opax.chain function

opax follows optax's philosophy of combining GradientTransformations with the opax.chain function:

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
).init(parameters)
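
To illustrate the idea behind chaining, here is a minimal plain-Python sketch in which each transformation is a function from gradients to gradients and `chain` applies them in order. The helper names mirror the opax ones but are hypothetical reimplementations for illustration only: they work on flat lists instead of pytrees and carry no state, so the Adam step is left out.

```python
import math
from typing import Callable, List

Grads = List[float]
Transform = Callable[[Grads], Grads]


def clip_by_global_norm(max_norm: float) -> Transform:
    """Rescale the gradients so that their L2 norm is at most max_norm."""
    def apply(grads: Grads) -> Grads:
        norm = math.sqrt(sum(g * g for g in grads))
        factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
        return [g * factor for g in grads]
    return apply


def scale(step_size: float) -> Transform:
    """Multiply every gradient entry by a constant."""
    def apply(grads: Grads) -> Grads:
        return [g * step_size for g in grads]
    return apply


def chain(*transforms: Transform) -> Transform:
    """Apply the transformations left to right, feeding each output to the next."""
    def apply(grads: Grads) -> Grads:
        for transform in transforms:
            grads = transform(grads)
        return grads
    return apply


pipeline = chain(clip_by_global_norm(1.0), scale(0.5))
# Norm of [3, 4] is 5, so clipping gives roughly [0.6, 0.8]; scaling gives roughly [0.3, 0.4].
updates = pipeline([3.0, 4.0])
print(updates)
```

Order matters in this sketch: swapping the two transformations would clip after scaling and give a different result.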

Learning rate schedule

opax supports learning rate scheduling with opax.scale_by_schedule.

import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the scale every 1000 steps: 1, 1/2, 1/4, ...
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
).init(parameters)
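
To see what values this schedule produces, the same formula can be evaluated in plain Python. This is a small sketch using only the standard library; `staircase` is a hypothetical helper name and not part of opax. The factor starts at 1 and halves every 1000 steps.

```python
import math


def staircase(step: int) -> float:
    """Plain-Python mirror of staircase_schedule_fn: 2 ** -floor(step / 1000)."""
    p = math.floor(step / 1000)
    return 2.0 ** (-p)


for step in (0, 999, 1000, 2500, 3000):
    print(step, staircase(step))
# 0 1.0
# 999 1.0
# 1000 0.5
# 2500 0.25
# 3000 0.125
```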

Utilities

To get the current global norm:

print(optimizer[0].global_norm)

Note: global_norm is a property of the ClipByGlobalNorm class.
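
For reference, the global norm of a gradient pytree is conventionally the L2 norm taken over all leaves at once, not per leaf. The sketch below computes it for nested dicts and lists in plain Python. `leaves` and `global_norm` are hypothetical helpers written for this illustration, and whether opax reports the norm before or after clipping is not stated here.

```python
import math


def leaves(tree):
    """Yield every scalar leaf of a nested dict/list/tuple structure."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from leaves(value)
    else:
        yield float(tree)


def global_norm(tree) -> float:
    """Square root of the sum of squares over all leaves."""
    return math.sqrt(sum(x * x for x in leaves(tree)))


gradients = {"dense": {"w": [3.0, 0.0], "b": [4.0]}}
print(global_norm(gradients))  # 5.0
```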

To get the current learning rate:

print(optimizer[-1].learning_rate)

Note: learning_rate is a property of the ScaleBySchedule class.
