
A gradient processing and optimisation library in JAX.

Project description

Optax

Introduction

Optax is a gradient processing and optimization library for JAX.

Optax is designed to facilitate research by providing building blocks that can be easily recombined in custom ways.

Our goals are to:

  • provide simple, well-tested, efficient implementations of core components,
  • improve research productivity by enabling users to easily combine low-level ingredients into custom optimisers (or other gradient processing components),
  • accelerate adoption of new ideas by making it easy for anyone to contribute.

We favour focusing on small composable building blocks that can be effectively combined into custom solutions. Others may build more complicated abstractions on top of these basic components. Whenever reasonable, implementations prioritise readability and structuring code to match the standard equations over code reuse.

An initial prototype of this library was made available in JAX's experimental folder as jax.experimental.optix. Given optix's wide adoption across DeepMind, and after a few iterations on the API, optix was eventually moved out of experimental as a standalone open-source library and renamed optax.

Installation

Optax can be installed with pip directly from GitHub, with the following command:

pip install git+git://github.com/deepmind/optax.git

Components

Gradient Transformations (transform.py)

One of the key building blocks of optax is a GradientTransformation.

Each transformation is defined by two functions:

  • state = init(params)
  • grads, state = update(grads, state, params=None)

The init function initializes a (possibly empty) set of statistics (aka the state), and the update function transforms a candidate gradient given some statistics and (optionally) the current value of the parameters.

For example:

tx = scale_by_rms()
state = tx.init(params)  # init stats
grads, state = tx.update(grads, state, params)  # transform & update stats.
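
The snippet above assumes params and grads already exist. Below is a minimal, self-contained sketch of the same pattern; the toy parameters and loss are made up purely for the example:

import jax
import jax.numpy as jnp
import optax

# Toy parameters and a toy loss, purely for illustration.
params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
loss = lambda p: jnp.sum(jnp.square(p['w'])) + p['b'] ** 2

tx = optax.scale_by_rms()
state = tx.init(params)                 # init stats for every leaf in params
grads = jax.grad(loss)(params)          # candidate gradients, same structure as params
grads, state = tx.update(grads, state)  # transform grads & update stats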

Composing Gradient Transformations (combine.py)

The fact that transformations take candidate gradients as input and return processed gradients as output (in contrast to returning the updated parameters) is critical: it allows arbitrary transformations to be combined into a custom optimiser / gradient processor, and it also allows transformations for different gradients that operate on a shared set of variables to be combined.

For instance, chain returns a new GradientTransformation that applies several transformations in sequence.

For example:

my_optimiser = chain(
    clip_by_global_norm(max_norm),
    scale_by_adam(eps=1e-4),
    scale(-learning_rate))
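
The result of chain is itself a GradientTransformation, so it exposes the same init/update interface. A brief sketch, reusing the toy params and grads from the earlier example and treating max_norm and learning_rate as placeholder values:

state = my_optimiser.init(params)                           # init state for all chained parts
updates, state = my_optimiser.update(grads, state, params)  # clip, Adam-scale, then scale by -lr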

Schedules (schedule.py)

Many popular transformations use time-dependent components, e.g. to anneal some hyper-parameter (such as the learning rate). For this purpose, Optax provides schedules that can be used to decay scalars as a function of a step count.

For example:

def polynomial_schedule(
    init_value, end_value, power, transition_steps, transition_begin):
  def schedule(step_count):
    count = jnp.clip(
        step_count - transition_begin, 0, transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * (frac**power) + end_value
  return schedule
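
A schedule is simply a function of the step count, so it can be evaluated directly. A small usage sketch with made-up hyper-parameters; the values follow from the formula above:

schedule_fn = polynomial_schedule(
    init_value=1., end_value=0.01, power=1,
    transition_steps=100, transition_begin=0)
schedule_fn(0)    # -> 1.0
schedule_fn(50)   # -> 0.505, halfway through the linear decay
schedule_fn(100)  # -> 0.01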

Popular optimisers (alias.py)

In addition to the low-level building blocks, we also provide aliases for popular optimisers built using these components (e.g. RMSProp, Adam, AdamW, etc.). These are all still instances of GradientTransformation, and can therefore be further combined with any of the individual building blocks.

For example:

def adamw(learning_rate, b1, b2, eps, weight_decay):
  return chain(
      scale_by_adam(b1=b1, b2=b2, eps=eps),
      scale_and_decay(-learning_rate, weight_decay=weight_decay))
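
Since an alias is just another GradientTransformation, it can in turn be chained with other building blocks. A hedged sketch; the exact signature of the adam alias (here taking a learning_rate) is an assumption:

# Clip gradients by global norm before handing them to Adam (signature assumed).
my_tx = chain(
    clip_by_global_norm(1.0),
    adam(learning_rate=1e-3))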

Applying updates (update.py)

An apply_updates function can be used to eventually apply the transformed gradients to the set of parameters of interest.

Separating gradient transformations from the parameter update makes it possible to flexibly chain a sequence of transformations of the same gradients, as well as to combine multiple updates to the same parameters (e.g. in multi-task settings where different tasks may benefit from different sets of gradient transformations).
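
Putting the pieces together, a typical update step looks like the following minimal sketch; the toy model, data and hyper-parameters are made up for the example:

import jax
import jax.numpy as jnp
import optax

# Toy linear model and data, purely illustrative.
params = {'w': jnp.zeros((2,))}
batch = {'x': jnp.ones((4, 2)), 'y': jnp.zeros((4,))}

def loss_fn(p, batch):
  preds = batch['x'] @ p['w']
  return jnp.mean((preds - batch['y']) ** 2)

tx = optax.chain(optax.scale_by_adam(), optax.scale(-1e-3))  # as in the earlier examples
opt_state = tx.init(params)

grads = jax.grad(loss_fn)(params, batch)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)  # adds the updates to the params, leaf by leaf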

Second Order (second_order.py)

Computing the Hessian or Fisher information matrices for neural networks is typically intractable due to the quadratic memory requirements. Solving for the diagonals of these matrices is often a better solution. The library offers functions for computing these diagonals with sub-quadratic memory requirements.
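
A hedged sketch of how such a diagonal might be computed with this module; the hessian_diag name, its (loss, params, inputs, targets) signature, and the second_order import path are assumptions rather than a guaranteed API:

import jax.numpy as jnp
from optax import second_order  # module layout assumed

# Toy scalar loss of the form loss(params, inputs, targets), purely illustrative.
def loss(params, inputs, targets):
  preds = inputs @ params['w']
  return jnp.mean((preds - targets) ** 2)

params = {'w': jnp.ones((2,))}
inputs = jnp.ones((4, 2))
targets = jnp.zeros((4,))

# Assumed signature: returns the (flattened) diagonal of the Hessian of loss w.r.t. params.
hess_diag = second_order.hessian_diag(loss, params, inputs, targets)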

Stochastic gradient estimators (stochastic_gradient_estimators.py)

Stochastic gradient estimators compute Monte Carlo estimates of gradients of the expectation of a function under a distribution with respect to the distribution's parameters.

Unbiased estimators such as the score function estimator (REINFORCE), the pathwise estimator (reparametrization trick) and the measure-valued estimator are implemented: score_function_jacobians, pathwise_jacobians and measure_valued_jacobians. Their applicability (both in terms of functions and distributions) is discussed in their respective documentation.

Stochastic gradient estimators can be combined with common control variates for variance reduction via control_variates_jacobians. For provided control variates see delta and moving_avg_baseline.

The result of a gradient estimator or control_variates_jacobians contains the Jacobians of the function with respect to the samples from the input distribution. These can then be used to update distributional parameters, or to assess gradient variance.

Example of how to use the pathwise_jacobians estimator:

  dist_params = [mean, log_scale]
  function = lambda x: jnp.sum(x * weights)
  jacobians = pathwise_jacobians(
      function, dist_params,
      utils.multi_normal, rng, num_samples)

  mean_grads = jnp.mean(jacobians[0], axis=0)
  log_scale_grads = jnp.mean(jacobians[1], axis=0)
  grads = [mean_grads, log_scale_grads]
  optim_update, optim_state = optim.update(grads, optim_state)
  updated_dist_params = optax.apply_updates(dist_params, optim_update)

where optim is an optax optimizer.

Citing Optax

To cite this repository:

@software{optax2020github,
  author = {Matteo Hessel and David Budden and Fabio Viola and Mihaela Rosca
            and Tom Hennigan},
  title = {Optax: composable gradient transformation and optimisation, in JAX!},
  url = {http://github.com/deepmind/optax},
  version = {0.0.1},
  year = {2020},
}

In this bibtex entry, the version number is intended to be from optax/__init__.py, and the year corresponds to the project's open-source release.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

optax-0.0.1.tar.gz (34.5 kB)

Uploaded Source

Built Distribution

optax-0.0.1-py3-none-any.whl (49.4 kB)

Uploaded Python 3

File details

Details for the file optax-0.0.1.tar.gz.

File metadata

  • Download URL: optax-0.0.1.tar.gz
  • Upload date:
  • Size: 34.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for optax-0.0.1.tar.gz
Algorithm Hash digest
SHA256 ed3dddb3973ee65ad12048ed589d384efc0fa6b096bed2c75ed008aef4ad720c
MD5 36b060e9cf4dee69ff217759633eb02a
BLAKE2b-256 96219c7e30191613d2504a17410255fe1bfc9c37e5d3b1b4a87ee80a4a7a952a



File details

Details for the file optax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: optax-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 49.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for optax-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 037e67da96c2ca586422819feda06e4291a31e1fef7adad07d7946851309de44
MD5 b3cdd3948e17ce0c68a275420198a770
BLAKE2b-256 04a2e00bcb7ec4cb3dd9d3973a279628f419e97f29d0efe140ee15a9ae53db6e


