A stateful optimizer library for JAX

opax

opax is an optimizer library for JAX. It is a reimplementation of optax built on PAX's stateful modules.

Installation

To install the latest version:

pip3 install git+https://github.com/ntt123/opax.git

Getting started

To create an optimizer:

import opax
optimizer = opax.adam(1e-4).init(parameters)

Note: parameters is a pytree of trainable parameters.

To update parameters:

updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)

Note: gradients has the same treedef as parameters.
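As a rough, framework-free illustration of this update step, the sketch below runs one Adam step on flat Python lists standing in for pytrees. The names AdamState, adam_init, adam_transform and apply_updates are hypothetical, not opax's API, and the defaults (b1=0.9, b2=0.999, eps=1e-8) are the usual Adam values rather than values confirmed for opax.adam. Like opax.transform_gradients, the transformation returns the updates together with a new state instead of mutating the old one. This sketch subtracts the updates in apply_updates; opax's internal sign convention is not documented here and may differ.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class AdamState:
    """Hypothetical optimizer state: step count plus first and second moments."""
    count: int
    mu: List[float]
    nu: List[float]


def adam_init(params: List[float]) -> AdamState:
    # The moments start at zero and mirror the structure of the parameters.
    return AdamState(count=0, mu=[0.0] * len(params), nu=[0.0] * len(params))


def adam_transform(
    grads: List[float],
    state: AdamState,
    lr: float = 1e-4,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
) -> Tuple[List[float], AdamState]:
    """Return (updates, new_state); the old state is not modified."""
    count = state.count + 1
    mu = [b1 * m + (1 - b1) * g for m, g in zip(state.mu, grads)]
    nu = [b2 * v + (1 - b2) * g * g for v, g in zip(state.nu, grads)]
    # Bias correction compensates for the zero-initialised moments.
    mu_hat = [m / (1 - b1**count) for m in mu]
    nu_hat = [v / (1 - b2**count) for v in nu]
    updates = [lr * m / (math.sqrt(v) + eps) for m, v in zip(mu_hat, nu_hat)]
    return updates, AdamState(count=count, mu=mu, nu=nu)


def apply_updates(params: List[float], updates: List[float]) -> List[float]:
    # This sketch subtracts the updates, i.e. steps against the gradient.
    return [p - u for p, u in zip(params, updates)]


params = [1.0, -2.0]
grads = [0.5, -0.5]
state = adam_init(params)
updates, state = adam_transform(grads, state, lr=0.1)
params = apply_updates(params, updates)
print(state.count, params)
```

On the first step the bias-corrected Adam update has magnitude close to lr for every parameter, so each parameter moves by roughly 0.1 against the sign of its gradient.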

The opax.chain function

opax follows optax's philosophy of combining GradientTransformations with the opax.chain function:

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
).init(parameters)
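The composition itself can be sketched in plain Python, assuming each transformation is a stateless function from updates to updates (opax's transformations also carry state, and scale_by_adam is left out here). The helpers chain, clip_by_global_norm and scale below are hypothetical re-creations for illustration, not opax's implementations. The point is the ordering: clipping sees the raw gradients, and scaling is applied to the clipped result.

```python
import math
from typing import Callable, List

# Hypothetical simplification: a transformation is a stateless function
# mapping a list of updates to a new list of updates.
Transform = Callable[[List[float]], List[float]]


def clip_by_global_norm(max_norm: float) -> Transform:
    def fn(updates: List[float]) -> List[float]:
        norm = math.sqrt(sum(u * u for u in updates))
        # Rescale everything only when the global norm exceeds max_norm.
        factor = max_norm / norm if norm > max_norm else 1.0
        return [u * factor for u in updates]
    return fn


def scale(step_size: float) -> Transform:
    def fn(updates: List[float]) -> List[float]:
        return [u * step_size for u in updates]
    return fn


def chain(*transforms: Transform) -> Transform:
    def fn(updates: List[float]) -> List[float]:
        # Each transformation consumes the previous one's output, left to right.
        for transform in transforms:
            updates = transform(updates)
        return updates
    return fn


optimizer = chain(clip_by_global_norm(1.0), scale(1e-4))
print(optimizer([3.0, 4.0]))  # clipped to norm 1, then scaled by 1e-4
```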

Learning rate schedule

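opax supports learning rate scheduling with opax.scale_by_schedule, which takes a function mapping the current step to a scale factor. The staircase schedule below halves the learning rate every 1000 steps: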

import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Halve the learning rate every 1000 steps: 2 ** -floor(step / 1000).
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
).init(parameters)
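This chain has no separate opax.scale, so, assuming scale_by_schedule multiplies the updates by the schedule value as optax's transformation of the same name does, the schedule's output is itself the effective learning rate applied to the Adam direction. A plain-Python mirror of staircase_schedule_fn (math.floor and 2.0 ** -p in place of jnp.floor and jnp.exp2) makes the staircase easy to check, assuming the step counter starts at 0 and increases by one per update:

```python
import math


def staircase_schedule(step: int) -> float:
    """Plain-Python mirror of staircase_schedule_fn: 2 ** -floor(step / 1000)."""
    p = math.floor(step / 1000)
    return 2.0 ** -p


for step in (0, 999, 1000, 2500, 5000):
    print(step, staircase_schedule(step))
```

The rate stays at 1.0 for steps 0 through 999, drops to 0.5 at step 1000, and reaches 0.03125 (2 ** -5) at step 5000.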

Utilities

To get the current global norm (clip_by_global_norm is the first transformation in the chains above, hence index 0):

print(optimizer[0].global_norm)

Note: global_norm is a property of ClipByGlobalNorm class.
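Assuming the same definition as optax's clip_by_global_norm, the global norm is a single L2 norm taken over every leaf of the gradient pytree combined, not one norm per layer. A framework-free sketch follows; the small tree_leaves helper is a hypothetical stand-in for jax.tree_util.tree_leaves and only handles dicts, lists, tuples and plain numbers:

```python
import math


def tree_leaves(tree):
    """Flatten nested dicts, lists and tuples of numbers into a flat list."""
    if isinstance(tree, dict):
        # Visit dict entries in sorted key order, as JAX does for dicts.
        return [leaf for key in sorted(tree) for leaf in tree_leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in tree_leaves(item)]
    return [tree]


def global_norm(tree) -> float:
    # One L2 norm over all leaves together, not one norm per layer.
    return math.sqrt(sum(x * x for x in tree_leaves(tree)))


grads = {"linear": {"w": [1.0, 2.0], "b": [2.0]}, "scale": 4.0}
print(global_norm(grads))  # sqrt(1 + 4 + 4 + 16) = 5.0
```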

To get the current learning rate (scale_by_schedule is the last transformation in the schedule example above, hence index -1):

print(optimizer[-1].learning_rate)

Note: learning_rate is a property of ScaleBySchedule class.

Download files

Source Distribution

opax-0.2.11.tar.gz (7.6 kB)

Uploaded Source

Built Distribution

opax-0.2.11-py3-none-any.whl (8.1 kB)

Uploaded Python 3

File details

Details for the file opax-0.2.11.tar.gz.

File metadata

  • Download URL: opax-0.2.11.tar.gz
  • Upload date:
  • Size: 7.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.5

File hashes

Hashes for opax-0.2.11.tar.gz
Algorithm Hash digest
SHA256 e7d9a90dc95e7c1bd91945e1448b2eef549a751136df323933b7f2864034e1bf
MD5 49ac462579eb5b36e7c57dae317ad6cf
BLAKE2b-256 2159ce6da3441cc7b53c06c516b80bd6f657b110ce981a87a13980ccbc66d60b

File details

Details for the file opax-0.2.11-py3-none-any.whl.

File metadata

  • Download URL: opax-0.2.11-py3-none-any.whl
  • Upload date:
  • Size: 8.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.5

File hashes

Hashes for opax-0.2.11-py3-none-any.whl
Algorithm Hash digest
SHA256 efb65dfef1ace7fdad205a4310db0401277958071c047663dc0efdd42fd3319d
MD5 4a6ab85133f5f5be2a0ef43c8b647257
BLAKE2b-256 37bcc5e94f4ba25a0dfd8ab328371bd1175e26d33342827347d5ecd3d94da0f7
