A stateful optimizer library for JAX

opax

opax is an optimizer library for JAX. It is a reimplementation of optax using PAX's stateful module.

Installation

To install the latest version:

pip3 install git+https://github.com/ntt123/opax.git

Getting started

To create an optimizer:

import opax
optimizer = opax.adam(1e-4).init(parameters)

Note: parameters is a pytree of trainable parameters.

To update parameters:

updates, optimizer = opax.transform_gradients(gradients, optimizer, parameters)
parameters = opax.apply_updates(parameters, updates)

Note: gradients has the same treedef as parameters.
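As a rough illustration of what "same treedef" means, the sketch below compares the nesting of two plain Python containers without using JAX. The helper same_structure and the example dictionaries are hypothetical and not part of opax. With JAX installed, comparing jax.tree_util.tree_structure(parameters) with jax.tree_util.tree_structure(gradients) is the usual way to check this.

```python
def same_structure(a, b) -> bool:
    """Return True if two nested containers have identical nesting and keys."""
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(same_structure(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return (
            type(a) is type(b)
            and len(a) == len(b)
            and all(same_structure(x, y) for x, y in zip(a, b))
        )
    # Leaves: anything that is not a container on either side.
    containers = (dict, list, tuple)
    return not isinstance(a, containers) and not isinstance(b, containers)


# Hypothetical parameter and gradient trees (plain floats stand in for arrays).
parameters = {"dense": {"w": [0.1, 0.2], "b": 0.0}}
gradients = {"dense": {"w": [0.01, -0.02], "b": 0.5}}
mismatched = {"dense": {"w": [0.01, -0.02]}}  # the "b" entry is missing

matches = same_structure(parameters, gradients)
broken = same_structure(parameters, mismatched)
```

Here matches is True and broken is False, because the second gradient tree lacks the "b" leaf.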

The opax.chain function

opax follows optax's philosophy of combining gradient transformations with the opax.chain function. The transformations are applied in the order they are listed:

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale(1e-4),
).init(parameters)
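The idea behind chaining can be sketched in plain Python: each transformation maps a list of gradients to a new list, and the chain applies them one after another. The functions clip_norm, scale_by, and chain_transforms below are hypothetical, stateless stand-ins written for illustration only. They are not opax's classes, and opax's actual clipping formula may differ in detail.

```python
import math
from typing import Callable, List

Transform = Callable[[List[float]], List[float]]


def clip_norm(max_norm: float) -> Transform:
    """Shrink gradients so their global L2 norm is at most max_norm."""
    def transform(grads: List[float]) -> List[float]:
        norm = math.sqrt(sum(g * g for g in grads))
        # Only shrink; gradients already within the limit pass through unchanged.
        factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
        return [g * factor for g in grads]
    return transform


def scale_by(step_size: float) -> Transform:
    """Multiply every gradient by a constant."""
    def transform(grads: List[float]) -> List[float]:
        return [g * step_size for g in grads]
    return transform


def chain_transforms(*transforms: Transform) -> Transform:
    """Compose transformations, applying them in the order given."""
    def transform(grads: List[float]) -> List[float]:
        for t in transforms:
            grads = t(grads)
        return grads
    return transform


pipeline = chain_transforms(clip_norm(1.0), scale_by(0.5))
# The global norm of [3, 4] is 5, so clipping rescales to norm 1, then scaling halves it.
updates = pipeline([3.0, 4.0])
```

The result is approximately [0.3, 0.4]. Unlike this toy version, opax transformations such as scale_by_adam carry state, which is why the optimizer is returned again from opax.transform_gradients.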

Learning rate schedule

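opax supports learning rate scheduling with opax.scale_by_schedule, which takes a function mapping the step count to a scaling factor. The schedule below halves the factor every 1000 steps: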

import jax.numpy as jnp

def staircase_schedule_fn(step: jnp.ndarray):
    # Number of completed 1000-step stairs; the factor is 2 ** (-p).
    p = jnp.floor(step.astype(jnp.float32) / 1000)
    return jnp.exp2(-p)

optimizer = opax.chain(
    opax.clip_by_global_norm(1.0),
    opax.scale_by_adam(),
    opax.scale_by_schedule(staircase_schedule_fn),
).init(parameters)
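To see what the staircase schedule produces, here is the same arithmetic in plain Python, evaluated at a few step counts. The name staircase_multiplier and its period argument are illustrative assumptions, not part of opax.

```python
import math


def staircase_multiplier(step: int, period: int = 1000) -> float:
    """Plain-Python version of the schedule: halve the factor every `period` steps."""
    stairs = math.floor(step / period)
    return 2.0 ** (-stairs)


# Evaluate the schedule just before and after the stair boundaries.
multipliers = {step: staircase_multiplier(step) for step in (0, 999, 1000, 2500)}
```

The factor stays at 1.0 through step 999, drops to 0.5 at step 1000, and is 0.25 at step 2500.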

Utilities

To get the current global norm:

print(optimizer[0].global_norm)

Note: global_norm is a property of the ClipByGlobalNorm class. Here optimizer[0] selects the first transformation in the chain defined above.

To get the current learning rate:

print(optimizer[-1].learning_rate)

Note: learning_rate is a property of the ScaleBySchedule class. Here optimizer[-1] selects the last transformation in the chain defined above.
