A gradient processing and optimisation library in JAX.
Optax
Introduction
Optax is a gradient processing and optimization library for JAX.
Optax is designed to facilitate research by providing building blocks that can be easily recombined in custom ways.
Our goals are to:
- provide simple, well-tested, efficient implementations of core components,
- improve research productivity by making it easy to combine low-level ingredients into custom optimisers (or other gradient processing components),
- accelerate adoption of new ideas by making it easy for anyone to contribute.
We favour focusing on small composable building blocks that can be effectively combined into custom solutions. Others may build more complicated abstractions upon these basic components. Whenever reasonable, implementations prioritise readability, and structure code to match standard equations, over code reuse.
An initial prototype of this library was made available in JAX's experimental folder as jax.experimental.optix. Given the wide adoption of optix across DeepMind, and after a few iterations on the API, optix was eventually moved out of experimental as a standalone open-source library, renamed optax.
Installation
Optax can be installed with pip directly from GitHub, with the following command:
pip install git+https://github.com/deepmind/optax.git
Released versions can also be installed from PyPI with pip install optax.
Components
Gradient Transformations (transform.py)
One of the key building blocks of optax is a GradientTransformation.
Each transformation is defined by two functions:
- state = init(params)
- grads, state = update(grads, state, params=None)
The init function initializes a (possibly empty) set of statistics (aka state), and the update function transforms a candidate gradient given some statistics and (optionally) the current value of the parameters.
For example:
tx = scale_by_rms()
state = tx.init(params) # init stats
grads, state = tx.update(grads, state, params) # transform & update stats.
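The init/update contract can be illustrated with a minimal, self-contained sketch that re-implements an RMS-style scaling from scratch in JAX instead of importing optax. GradientTransformation, ScaleByRmsState and my_scale_by_rms are hypothetical names chosen for this illustration, and the update formula g / (sqrt(nu) + eps) is an assumption that may differ in detail from optax's own scale_by_rms.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class GradientTransformation(NamedTuple):
  """A pair of pure functions: init(params) and update(grads, state, params)."""
  init: Callable[[Any], Any]
  update: Callable[..., Any]


class ScaleByRmsState(NamedTuple):
  nu: Any  # running average of squared gradients, same tree structure as params


def my_scale_by_rms(decay: float = 0.9, eps: float = 1e-8) -> GradientTransformation:
  """Hypothetical RMS-style scaling: g / (sqrt(EMA of g**2) + eps)."""

  def init(params):
    # One zero-initialised accumulator per parameter leaf.
    return ScaleByRmsState(nu=jax.tree_util.tree_map(jnp.zeros_like, params))

  def update(grads, state, params=None):
    del params  # this particular transformation ignores the parameter values
    nu = jax.tree_util.tree_map(
        lambda n, g: decay * n + (1.0 - decay) * g**2, state.nu, grads)
    updates = jax.tree_util.tree_map(
        lambda g, n: g / (jnp.sqrt(n) + eps), grads, nu)
    return updates, ScaleByRmsState(nu=nu)

  return GradientTransformation(init, update)


params = {"w": jnp.ones(3), "b": jnp.zeros(())}
grads = {"w": jnp.ones(3), "b": jnp.array(2.0)}
tx = my_scale_by_rms()
state = tx.init(params)                           # init stats
updates, state = tx.update(grads, state, params)  # transform & update stats
```

Because the running statistic nu is passed in and returned explicitly rather than stored in an object, both functions stay pure, which is what lets them be wrapped in jax.jit.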
Composing Gradient Transformations (combine.py)
The fact that transformations take candidate gradients as input and return processed gradients as output (in contrast to returning the updated parameters) is critical: it allows arbitrary transformations to be combined into a custom optimiser / gradient processor, and also allows transformations for different gradients that operate on a shared set of variables to be combined.
For instance, chain combines several transformations sequentially, returning a new GradientTransformation that applies them in order.
For example:
my_optimiser = chain(
    clip_by_global_norm(max_norm),
    scale_by_adam(eps=1e-4),
    scale(-learning_rate))
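A composition like the one above can be sketched without optax as follows. my_chain, my_clip_by_global_norm and my_scale are hypothetical stand-ins for chain, clip_by_global_norm and scale (scale_by_adam is left out for brevity). They follow the init/update contract described earlier but are not optax's implementations.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class GradientTransformation(NamedTuple):
  init: Callable[[Any], Any]
  update: Callable[..., Any]


def my_scale(step_size: float) -> GradientTransformation:
  """Stateless: multiply every update by a constant."""
  def update(grads, state, params=None):
    return jax.tree_util.tree_map(lambda g: step_size * g, grads), state
  return GradientTransformation(lambda params: (), update)


def my_clip_by_global_norm(max_norm: float) -> GradientTransformation:
  """Stateless: rescale all leaves jointly so their global L2 norm is <= max_norm."""
  def update(grads, state, params=None):
    leaves = jax.tree_util.tree_leaves(grads)
    g_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in leaves))
    factor = jnp.minimum(1.0, max_norm / g_norm)
    return jax.tree_util.tree_map(lambda g: g * factor, grads), state
  return GradientTransformation(lambda params: (), update)


def my_chain(*transforms: GradientTransformation) -> GradientTransformation:
  """Apply transformations in order; the chained state is a tuple of their states."""
  def init(params):
    return tuple(t.init(params) for t in transforms)

  def update(grads, state, params=None):
    new_state = []
    for t, s in zip(transforms, state):
      grads, s = t.update(grads, s, params)  # output of one feeds the next
      new_state.append(s)
    return grads, tuple(new_state)

  return GradientTransformation(init, update)


params = {"a": jnp.array([0.0]), "b": jnp.array([0.0])}
grads = {"a": jnp.array([3.0]), "b": jnp.array([4.0])}  # global norm is 5
tx = my_chain(my_clip_by_global_norm(1.0), my_scale(-0.1))
state = tx.init(params)
updates, state = tx.update(grads, state, params)
```

Keeping each component's state separate inside a tuple is what lets any transformation be dropped into a chain without knowing about the others.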
Schedules (schedule.py)
Many popular transformations use time-dependent components, e.g. to anneal some hyper-parameter such as the learning rate. For this purpose, Optax provides schedules that can be used to decay scalars as a function of a step count.
For example:
import jax.numpy as jnp

def polynomial_schedule(
    init_value, end_value, power, transition_steps, transition_begin):
  def schedule(step_count):
    # Number of steps into the transition, clipped to [0, transition_steps].
    count = jnp.clip(
        step_count - transition_begin, 0, transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * (frac**power) + end_value
  return schedule
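One way a schedule can feed into a transformation is to keep a step counter in the transformation's state and evaluate the schedule on it at each update. The sketch below repeats polynomial_schedule so that it is self-contained (with transition_begin defaulting to 0 as an added convenience) and pairs it with a hypothetical my_scale_by_schedule. It illustrates the idea and is not optax's own code.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class GradientTransformation(NamedTuple):
  init: Callable[[Any], Any]
  update: Callable[..., Any]


class ScaleByScheduleState(NamedTuple):
  count: jnp.ndarray  # number of update calls made so far


def polynomial_schedule(
    init_value, end_value, power, transition_steps, transition_begin=0):
  def schedule(step_count):
    count = jnp.clip(step_count - transition_begin, 0, transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * (frac**power) + end_value
  return schedule


def my_scale_by_schedule(schedule: Callable[[jnp.ndarray], jnp.ndarray]) -> GradientTransformation:
  """Hypothetical: multiply updates by schedule(count), then increment count."""
  def init(params):
    return ScaleByScheduleState(count=jnp.zeros([], jnp.int32))

  def update(grads, state, params=None):
    step_size = schedule(state.count)
    updates = jax.tree_util.tree_map(lambda g: step_size * g, grads)
    return updates, ScaleByScheduleState(count=state.count + 1)

  return GradientTransformation(init, update)


# Linear decay from 1.0 to 0.0 over 10 steps.
schedule = polynomial_schedule(
    init_value=1.0, end_value=0.0, power=1, transition_steps=10)
tx = my_scale_by_schedule(schedule)
grads = {"w": jnp.ones(2)}
state = tx.init(grads)
scales = []
for _ in range(3):
  updates, state = tx.update(grads, state)
  scales.append(float(updates["w"][0]))
```

Storing the counter in the state, rather than in a Python variable, keeps update a pure function of its inputs.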
Popular optimisers (alias.py)
In addition to the low-level building blocks, we also provide aliases for popular optimisers built using these components (e.g. RMSProp, Adam, AdamW, etc.). These are all still instances of a GradientTransformation, and can therefore be further combined with any of the individual building blocks.
For example:
def adamw(learning_rate, b1, b2, eps, weight_decay):
  return chain(
      scale_by_adam(b1=b1, b2=b2, eps=eps),
      add_decayed_weights(weight_decay),  # decoupled weight decay
      scale(-learning_rate))
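To illustrate how an alias reduces to building blocks, the sketch below assembles an AdamW-style optimiser from from-scratch stand-ins: my_scale_by_adam (Adam's bias-corrected moment estimates), my_add_decayed_weights (decoupled weight decay, i.e. adding weight_decay * params to the updates) and my_scale. All my_* names are hypothetical, and the default hyper-parameter values are assumptions chosen for the example.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map


class GradientTransformation(NamedTuple):
  init: Callable[[Any], Any]
  update: Callable[..., Any]


class AdamState(NamedTuple):
  count: jnp.ndarray
  mu: Any  # exponential moving average of gradients
  nu: Any  # exponential moving average of squared gradients


def my_scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
  def init(params):
    zeros = tree_map(jnp.zeros_like, params)
    return AdamState(count=jnp.zeros([], jnp.int32), mu=zeros, nu=zeros)

  def update(grads, state, params=None):
    count = state.count + 1
    mu = tree_map(lambda m, g: b1 * m + (1 - b1) * g, state.mu, grads)
    nu = tree_map(lambda v, g: b2 * v + (1 - b2) * g**2, state.nu, grads)
    # Bias correction compensates for the zero initialisation of mu and nu.
    mu_hat = tree_map(lambda m: m / (1 - b1**count), mu)
    nu_hat = tree_map(lambda v: v / (1 - b2**count), nu)
    updates = tree_map(lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)
    return updates, AdamState(count=count, mu=mu, nu=nu)

  return GradientTransformation(init, update)


def my_add_decayed_weights(weight_decay):
  """Decoupled weight decay: add weight_decay * params to the updates."""
  def update(grads, state, params):
    return tree_map(lambda g, p: g + weight_decay * p, grads, params), state
  return GradientTransformation(lambda params: (), update)


def my_scale(step_size):
  def update(grads, state, params=None):
    return tree_map(lambda g: step_size * g, grads), state
  return GradientTransformation(lambda params: (), update)


def my_chain(*txs):
  def init(params):
    return tuple(t.init(params) for t in txs)

  def update(grads, state, params=None):
    states = []
    for t, s in zip(txs, state):
      grads, s = t.update(grads, s, params)
      states.append(s)
    return grads, tuple(states)

  return GradientTransformation(init, update)


def my_adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4):
  return my_chain(
      my_scale_by_adam(b1=b1, b2=b2, eps=eps),
      my_add_decayed_weights(weight_decay),
      my_scale(-learning_rate))


params = {"w": jnp.array([2.0])}
grads = {"w": jnp.array([1.0])}
tx = my_adamw(learning_rate=0.01, weight_decay=0.1)
state = tx.init(params)
updates, state = tx.update(grads, state, params)
```

Since my_adamw returns an ordinary GradientTransformation, it could itself be passed to my_chain, for example after a clipping step, which is the point made above about aliases.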
Applying updates (update.py)
An apply_updates function can be used to apply the transformed gradients to the set of parameters of interest.
Separating gradient transformations from the parameter update makes it possible to flexibly chain a sequence of transformations of the same gradients, as well as to combine multiple updates to the same parameters (e.g. in multi-task settings where the different tasks may benefit from different sets of gradient transformations).
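A typical training step ties these pieces together: compute gradients, transform them, then apply the result. The sketch below assumes that applying updates simply adds each update to the matching parameter leaf, and uses a hypothetical my_sgd transformation (updates = -learning_rate * grads) on a toy quadratic loss. my_apply_updates and my_sgd are illustrative names, not optax's.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class GradientTransformation(NamedTuple):
  init: Callable[[Any], Any]
  update: Callable[..., Any]


def my_apply_updates(params, updates):
  """Hypothetical helper: add each update to the matching parameter leaf."""
  return jax.tree_util.tree_map(lambda p, u: p + u, params, updates)


def my_sgd(learning_rate):
  """Plain SGD as a stateless transformation: updates = -learning_rate * grads."""
  def update(grads, state, params=None):
    return jax.tree_util.tree_map(lambda g: -learning_rate * g, grads), state
  return GradientTransformation(lambda params: (), update)


def loss(params):
  return jnp.sum((params["w"] - 3.0) ** 2)  # minimised at w = 3


tx = my_sgd(learning_rate=0.1)
params = {"w": jnp.array([0.0])}
state = tx.init(params)


@jax.jit
def step(params, state):
  grads = jax.grad(loss)(params)                   # raw gradients
  updates, state = tx.update(grads, state, params)  # processed gradients
  return my_apply_updates(params, updates), state   # parameter update


for _ in range(100):
  params, state = step(params, state)
```

Because update returns updates rather than new parameters, two transformations could each produce an update tree for the same params, and those trees could be summed before the single apply step.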
Second Order (second_order.py)
Computing the full Hessian or Fisher information matrix of a neural network is typically intractable, because memory requirements grow quadratically with the number of parameters. Computing only the diagonals of these matrices is often a better option. The library offers functions for computing these diagonals with sub-quadratic memory requirements.
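One way to obtain a Hessian diagonal without materialising the full matrix is to compute one Hessian-vector product per coordinate, trading memory for n evaluations. The sketch below does this with JAX forward-over-reverse differentiation, and also computes an empirical Fisher diagonal (the mean of squared per-example gradients). The helpers hessian_diag and empirical_fisher_diag, their signatures, and the choice of the empirical Fisher are assumptions for illustration and are not necessarily how second_order.py computes these quantities.

```python
import jax
import jax.numpy as jnp


def hvp(f, x, v):
  """Hessian-vector product H(x) @ v via forward-over-reverse differentiation."""
  return jax.jvp(jax.grad(f), (x,), (v,))[1]


def hessian_diag(f, x):
  """Diagonal of the Hessian of scalar f at a flat vector x.

  One HVP per coordinate: each step only needs length-n vectors, so memory is
  O(n) rather than the O(n**2) of the full Hessian, at the cost of n HVPs.
  """
  n = x.shape[0]

  def entry(i):
    e_i = jax.nn.one_hot(i, n, dtype=x.dtype)  # i-th standard basis vector
    return hvp(f, x, e_i)[i]                   # H[i, i]

  return jax.lax.map(entry, jnp.arange(n))


def empirical_fisher_diag(per_example_loss, params, xs):
  """Mean over examples of squared per-example gradients (memory O(batch * n))."""
  per_example_grads = jax.vmap(
      jax.grad(per_example_loss), in_axes=(None, 0))(params, xs)
  return jnp.mean(per_example_grads**2, axis=0)


def f(x):
  # The cross term x0 * x1 affects only off-diagonal Hessian entries.
  return jnp.sum(jnp.array([1.0, 2.0, 3.0]) * x**2) + x[0] * x[1]

h_diag = hessian_diag(f, jnp.ones(3))


def per_example_loss(w, x):
  return jnp.dot(w, x)  # per-example gradient w.r.t. w is simply x

fisher_diag = empirical_fisher_diag(
    per_example_loss, jnp.zeros(2), jnp.array([[1.0, 2.0], [3.0, 4.0]]))
```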
Stochastic gradient estimators (stochastic_gradient_estimators.py)
Stochastic gradient estimators compute Monte Carlo estimates of gradients of the expectation of a function under a distribution with respect to the distribution's parameters.
Three unbiased estimators are implemented: the score function estimator (REINFORCE), the pathwise estimator (reparametrization trick) and the measure-valued estimator, available as score_function_jacobians, pathwise_jacobians and measure_valued_jacobians respectively. Their applicability (both in terms of functions and distributions) is discussed in their respective documentation.
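The difference between the score function and pathwise estimators can be seen on a toy problem where the exact answer is known: for x ~ N(mu, sigma^2) and f(x) = x^2, E[f(x)] = mu^2 + sigma^2, so the gradient with respect to mu is 2 * mu. The sketch below estimates it both ways with plain JAX. It does not call optax's score_function_jacobians or pathwise_jacobians, and the measure-valued estimator is not shown.

```python
import jax
import jax.numpy as jnp


def f(x):
  return x**2  # for x ~ N(mu, sigma**2): E[f(x)] = mu**2 + sigma**2, so d/dmu = 2 * mu


mu, sigma, num_samples = 1.5, 1.0, 100_000
eps = jax.random.normal(jax.random.PRNGKey(0), (num_samples,))


# Pathwise (reparametrisation trick): write x = mu + sigma * eps and
# differentiate the Monte Carlo average through f.
def pathwise_objective(mu):
  return jnp.mean(f(mu + sigma * eps))

pathwise_grad = jax.grad(pathwise_objective)(mu)

# Score function (REINFORCE): average f(x) * d/dmu log N(x; mu, sigma**2),
# where d/dmu log N(x; mu, sigma**2) = (x - mu) / sigma**2. f is never differentiated.
x = mu + sigma * eps
score_grad = jnp.mean(f(x) * (x - mu) / sigma**2)
```

Both estimates should land close to 2 * mu = 3.0. For this particular f, the per-sample score-function terms have a much larger variance than the pathwise ones, which is what control variates aim to reduce.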
Stochastic gradient estimators can be combined with common control variates for variance reduction via control_variates_jacobians. For the control variates provided, see delta and moving_avg_baseline.
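A baseline is one of the simplest control variates for the score function estimator: since the score has zero mean, subtracting a constant from f leaves the estimate unbiased while, for a well-chosen constant, reducing its variance. The sketch below uses a constant baseline estimated from an independent batch of samples. It is a simplified, hypothetical stand-in and not optax's delta or moving_avg_baseline.

```python
import jax
import jax.numpy as jnp


def f(x):
  return x**2


mu, sigma, num_samples = 1.5, 1.0, 100_000
key_samples, key_baseline = jax.random.split(jax.random.PRNGKey(0))

x = mu + sigma * jax.random.normal(key_samples, (num_samples,))
score = (x - mu) / sigma**2  # d/dmu log N(x; mu, sigma**2); has zero mean

# Per-sample terms of the plain score-function estimator.
plain_terms = f(x) * score

# Constant baseline estimated from an independent batch, so subtracting it
# does not bias the estimator (E[baseline * score] = baseline * E[score] = 0).
x_baseline = mu + sigma * jax.random.normal(key_baseline, (num_samples,))
baseline = jnp.mean(f(x_baseline))
cv_terms = (f(x) - baseline) * score

print("plain:    mean %.3f, variance %.2f"
      % (float(plain_terms.mean()), float(plain_terms.var())))
print("baseline: mean %.3f, variance %.2f"
      % (float(cv_terms.mean()), float(cv_terms.var())))
```

Comparing the two printed variances shows the reduction, while the estimated mean stays close to the true gradient 2 * mu = 3.0.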
The result of a gradient estimator or of control_variates_jacobians contains, for each sample from the input distribution, the Jacobian of the function with respect to the distribution's parameters. These can then be used to update distributional parameters, or to assess gradient variance.
Example of how to use the pathwise_jacobians estimator:
dist_params = [mean, log_scale]
function = lambda x: jnp.sum(x * weights)
jacobians = pathwise_jacobians(
function, dist_params,
utils.multi_normal, rng, num_samples)
mean_grads = jnp.mean(jacobians[0], axis=0)
log_scale_grads = jnp.mean(jacobians[1], axis=0)
grads = [mean_grads, log_scale_grads]
optim_update, optim_state = optim.update(grads, optim_state)
updated_dist_params = optax.apply_updates(dist_params, optim_update)
where optim is an optax optimizer.
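The same workflow can be reproduced without optax to show what the per-sample Jacobians look like. The sketch below assumes a diagonal Gaussian parameterised by mean and log_scale (sampled as mean + exp(log_scale) * eps), uses a hypothetical helper per_sample_pathwise_jacobians in place of pathwise_jacobians, and replaces the optax optimiser with a plain gradient-descent step.

```python
import jax
import jax.numpy as jnp


def per_sample_pathwise_jacobians(function, mean, log_scale, rng, num_samples):
  """Hypothetical helper: per-sample gradients of function(x) w.r.t. (mean, log_scale).

  Samples are drawn as x = mean + exp(log_scale) * eps with eps ~ N(0, I), so the
  gradient flows through the sample (reparametrisation trick). Returns a tuple of
  two arrays, each of shape (num_samples,) + mean.shape.
  """
  eps = jax.random.normal(rng, (num_samples,) + mean.shape)

  def sample_value(mean, log_scale, e):
    return function(mean + jnp.exp(log_scale) * e)

  grad_fn = jax.grad(sample_value, argnums=(0, 1))
  return jax.vmap(grad_fn, in_axes=(None, None, 0))(mean, log_scale, eps)


weights = jnp.array([1.0, -2.0, 0.5])
function = lambda x: jnp.sum(x * weights)
mean, log_scale = jnp.zeros(3), jnp.zeros(3)

jacobians = per_sample_pathwise_jacobians(
    function, mean, log_scale, jax.random.PRNGKey(0), num_samples=1000)
mean_grads = jnp.mean(jacobians[0], axis=0)
log_scale_grads = jnp.mean(jacobians[1], axis=0)
mean_grad_var = jnp.var(jacobians[0], axis=0)  # spread of the estimator across samples

# Plain gradient-descent step in place of an optax optimiser.
learning_rate = 0.1
new_mean = mean - learning_rate * mean_grads
new_log_scale = log_scale - learning_rate * log_scale_grads
```

For this linear function every sample yields the same mean Jacobian (weights), so its variance is zero, while the log_scale Jacobians average to roughly zero, matching the fact that E[f] does not depend on the scale.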
Citing Optax
To cite this repository:
@software{optax2020github,
author = {Matteo Hessel and David Budden and Fabio Viola and Mihaela Rosca
and Tom Hennigan},
title = {Optax: composable gradient transformation and optimisation, in JAX!},
url = {http://github.com/deepmind/optax},
version = {0.0.1},
year = {2020},
}
In this bibtex entry, the version number is intended to be from optax/__init__.py, and the year corresponds to the project's open-source release.
Download files
Source Distribution
File details for optax-0.0.1.tar.gz
- Size: 34.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | ed3dddb3973ee65ad12048ed589d384efc0fa6b096bed2c75ed008aef4ad720c
MD5 | 36b060e9cf4dee69ff217759633eb02a
BLAKE2b-256 | 96219c7e30191613d2504a17410255fe1bfc9c37e5d3b1b4a87ee80a4a7a952a
Built Distribution
File details for optax-0.0.1-py3-none-any.whl
- Size: 49.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 037e67da96c2ca586422819feda06e4291a31e1fef7adad07d7946851309de44
MD5 | b3cdd3948e17ce0c68a275420198a770
BLAKE2b-256 | 04a2e00bcb7ec4cb3dd9d3973a279628f419e97f29d0efe140ee15a9ae53db6e