A gradient processing and optimisation library in JAX.

Optax

Introduction

Optax is a gradient processing and optimization library for JAX.

Optax is designed to facilitate research by providing building blocks that can be easily recombined in custom ways.

Our goals are to

  • Provide simple, well-tested, efficient implementations of core components.
  • Improve research productivity by making it easy to combine low-level ingredients into custom optimisers (or other gradient processing components).
  • Accelerate adoption of new ideas by making it easy for anyone to contribute.

We favour small composable building blocks that can be effectively combined into custom solutions. Others may build upon these basic components to create more complicated abstractions. Whenever reasonable, implementations prioritise readability and structuring code to match standard equations, over code reuse.

An initial prototype of this library was made available in JAX's experimental folder as jax.experimental.optix. Given the wide adoption across DeepMind of optix, and after a few iterations on the API, optix was eventually moved out of experimental as a standalone open-source library, and renamed optax.

Documentation on Optax can be found at optax.readthedocs.io.

Installation

You can install the latest released version of Optax from PyPI via:

pip install optax

or you can install the latest development version from GitHub:

pip install git+https://github.com/google-deepmind/optax.git

Quickstart

Optax contains implementations of many popular optimizers and loss functions. For example, the following code snippet uses the Adam optimizer from optax.adam and the squared-error loss from optax.l2_loss (which returns half the squared difference for each element). We initialize the optimizer state by calling the optimizer's init function on the model's params.

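import jax
import jax.numpy as jnp
import optax
learning_rate, num_weights = 1e-3, 3  # illustrative values, not from the original
optimizer = optax.adam(learning_rate)
params = {'w': jnp.ones((num_weights,))}
# Obtain the `opt_state` that contains statistics for the optimizer.
opt_state = optimizer.init(params)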

To write the update loop, we need a loss function that JAX can differentiate (with jax.grad in this example) to obtain the gradients.

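# Assumes `xs` has shape (batch, num_weights) and `ys` has shape (batch,).
# Averaging gives the scalar loss that jax.grad requires.
compute_loss = lambda params, x, y: jnp.mean(optax.l2_loss(x.dot(params['w']), y))
grads = jax.grad(compute_loss)(params, xs, ys)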

optimizer.update then transforms the gradients into the updates that should be applied to the current parameters to obtain the new ones. optax.apply_updates is a convenience utility that applies them.

updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

You can continue the quick start in the Optax 101 notebook.

Development

We welcome new contributors.

Source code

You can check out the latest sources with the following command.

git clone https://github.com/google-deepmind/optax.git

Testing

To run the tests, please execute the following script.

sh ./test.sh

Documentation

To build the documentation, first ensure that all the dependencies are installed.

pip install -e ".[docs]"

Then, execute the following.

cd docs/
make html

Benchmarks

If you feel lost in the crowd of available optimizers for deep learning, there exist some extensive benchmarks:

Benchmarking Neural Network Training Algorithms, Dahl G. et al, 2023,

Descending through a Crowded Valley — Benchmarking Deep Learning Optimizers, Schmidt R. et al, 2021.

If you are interested in developing your own benchmark for some tasks, consider the following framework

Benchopt: Reproducible, efficient and collaborative optimization benchmarks, Moreau T. et al, 2022.

Finally, if you are searching for some recommendations on tuning optimizers, consider taking a look at

Deep Learning Tuning Playbook, Godbole V. et al, 2023.

Citing Optax

This repository is part of the DeepMind JAX Ecosystem. To cite Optax, please use the following citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/google-deepmind},
  year = {2020},
}
