
Stochastic Weight Averaging for Optax


SWAG in Optax


This package implements SWAG (SWA-Gaussian) as an Optax gradient transformation, so it can be used in JAX training loops.

Installation

Install from PyPI with pip:

pip install optax-swag

To install the latest version directly from source, run:

pip install git+https://github.com/activatedgeek/optax-swag.git

Usage

To start collecting statistics of the iterates, chain the swag transform with your optimizer:

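import optax
from optax_swag import swag

freq = 1   # Example value (assumed meaning: how often statistics are collected).
rank = 20  # Example value (assumed meaning: rank of the low-rank covariance).

optimizer = optax.chain(
    optax.sgd(learning_rate=0.1, momentum=0.9),  ## Other optimizer and transform config.
    swag(freq, rank),  ## Always add as the last transform.
)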

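The SWAGState object holding the collected statistics can be retrieved from the optimizer state for downstream use. optax.chain keeps one state per transform, in order, so with swag as the last transform it is the last entry of the chained state (e.g. opt_state[-1]).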
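For intuition about what "iterate statistics" means: in the SWAG method of Maddox et al. (2019), the iterates are summarized by a running mean, a running mean of squares (which gives a diagonal variance), and the last K deviations of the iterates from the running mean (which give a low-rank covariance factor). The sketch below shows these updates for a single flat parameter vector. It is an illustration of the technique, not optax_swag's implementation: the helper names `init_swag_stats` and `update_swag_stats` and the dictionary keys are hypothetical, the package's `rank` argument plausibly plays the role of K (an assumption), and this sketch simply updates on every call instead of following the package's `freq` schedule.

```python
import jax.numpy as jnp


def init_swag_stats(params, rank):
    """Empty SWAG statistics for a flat parameter vector of shape (d,).

    Hypothetical helper for illustration; not part of optax_swag.
    """
    return {
        "n": jnp.zeros((), dtype=jnp.int32),  # number of collected iterates
        "mean": jnp.zeros_like(params),  # running mean of the iterates
        "sq_mean": jnp.zeros_like(params),  # running mean of squared iterates
        # Buffer holding the last `rank` deviations, one per row.
        "deviations": jnp.zeros((rank,) + params.shape, dtype=params.dtype),
    }


def update_swag_stats(stats, params):
    """Fold one iterate into the running SWAG statistics (hypothetical helper)."""
    n = stats["n"]
    mean = (n * stats["mean"] + params) / (n + 1)
    sq_mean = (n * stats["sq_mean"] + params**2) / (n + 1)
    # Store the deviation from the updated mean in a ring buffer of size `rank`;
    # row order does not matter for the low-rank covariance D D^T.
    rank = stats["deviations"].shape[0]
    deviations = stats["deviations"].at[n % rank].set(params - mean)
    return {"n": n + 1, "mean": mean, "sq_mean": sq_mean, "deviations": deviations}
```

For a model whose parameters are a pytree, `jax.flatten_util.ravel_pytree` can flatten them into a single vector before applying updates like these.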

Sampling

The following reference code generates samples from the collected statistics.

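import jax
import jax.numpy as jnp

from optax_swag import sample_swag

# SWAGState from the chained optimizer state; assumes swag is the last
# transform in optax.chain, so its state is the last entry.
swa_opt_state = opt_state[-1]
n_samples = 10

rng = jax.random.PRNGKey(42)
# One key per sample, plus one key kept for later use.
rng, *samples_rng = jax.random.split(rng, 1 + n_samples)

# Vectorize over the sample keys; the SWAG state is shared by all samples.
swag_sample_params = jax.vmap(sample_swag, in_axes=(0, None))(
    jnp.array(samples_rng), swa_opt_state)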

The resulting swag_sample_params can now be used for downstream evaluation.
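As background on what a SWAG sample is, the SWAG paper draws parameters as the SWA mean plus two Gaussian perturbations: theta = mean + (1/sqrt(2)) * sqrt(diag_var) * z1 + (1/sqrt(2 (K - 1))) * D z2, where diag_var = sq_mean - mean^2, D is the d x K matrix of stored deviations, z1 ~ N(0, I_d) and z2 ~ N(0, I_K). The sketch below implements that formula for one flat parameter vector. It is a sketch under these assumptions and not necessarily what `sample_swag` does internally; `sample_from_stats` is a hypothetical name, and its inputs follow the hypothetical layout of mean, mean of squares, and a (K, d) deviation buffer.

```python
import jax
import jax.numpy as jnp


def sample_from_stats(key, mean, sq_mean, deviations):
    """Draw one SWAG sample for a flat parameter vector (hypothetical helper).

    mean, sq_mean: shape (d,); deviations: shape (K, d) with K >= 2.
    """
    rank = deviations.shape[0]
    key_diag, key_low_rank = jax.random.split(key)
    # Diagonal variance; clamp tiny negative values caused by round-off.
    var = jnp.maximum(sq_mean - mean**2, 0.0)
    z1 = jax.random.normal(key_diag, mean.shape)
    z2 = jax.random.normal(key_low_rank, (rank,))
    diag_term = jnp.sqrt(var) * z1
    # D z2 with D = deviations.T (columns are deviations), scaled by 1/sqrt(K - 1).
    low_rank_term = deviations.T @ z2 / jnp.sqrt(rank - 1)
    # The overall 1/sqrt(2) splits the variance between both terms.
    return mean + (diag_term + low_rank_term) / jnp.sqrt(2.0)
```

As in the reference code, such a function can be vectorized over a batch of keys with `jax.vmap(sample_from_stats, in_axes=(0, None, None, None))`.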

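NOTE: Non-parameter variables (e.g. BatchNorm running statistics) are not part of the samples. Make sure to recompute them for each generated sample, for example with a pass over the training data, before evaluation.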
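A common form of downstream evaluation is Bayesian model averaging: evaluate the model once per sampled parameter set and average the predictive probabilities. Because `jax.vmap` stacks its outputs along a new leading axis, each leaf of `swag_sample_params` presumably carries a leading axis of size `n_samples` (an assumption about `sample_swag`'s return value). The sketch below uses a hypothetical linear classifier, `predict_probs`, with no BatchNorm layers, so the note above does not apply to it.

```python
import jax
import jax.numpy as jnp


def predict_probs(params, x):
    """Hypothetical linear classifier: params = {"w": (d, c), "b": (c,)}."""
    logits = x @ params["w"] + params["b"]
    return jax.nn.softmax(logits, axis=-1)


def bayesian_model_average(sample_params, x):
    """Average class probabilities over the leading sample axis of sample_params."""
    # Map over the sample axis of every parameter leaf; inputs are shared.
    probs = jax.vmap(predict_probs, in_axes=(0, None))(sample_params, x)  # (S, N, C)
    return probs.mean(axis=0)  # (N, C)
```

Averaging probabilities rather than logits keeps the result a proper predictive distribution over classes.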

License

Apache 2.0



Download files


Source distribution: optax-swag-0.1.0.tar.gz (7.8 kB)

Built distribution: optax_swag-0.1.0-py3-none-any.whl (8.6 kB, Python 3)
