Stochastic Weight Averaging for Optax
SWAG in Optax
This package implements SWAG (SWA-Gaussian) as an Optax gradient transformation, so it can be used in JAX training loops.
Installation
For now, the package can only be installed directly from source:
pip install git+https://github.com/activatedgeek/optax-swag.git
TODO: A PyPI package will be available soon.
Usage
To start collecting the iterate statistics, chain the swag transform after your other transforms:
import optax
from optax_swag import swag
optimizer = optax.chain(
    ...,  # Other optimizer and transform config.
    swag(freq, rank),  # Always add as the last transform.
)
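optax.chain keeps one state entry per transform, in order, so the SWAGState object can be read from the optimizer state tuple (as its last entry when swag is the last transform) for downstream usage.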
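For intuition about what "iterate statistics" means here, the following is a minimal, dependency-free sketch of the SWAG bookkeeping as described in the SWAG paper. It is not this package's implementation: `ToySWAGStats` is a hypothetical stand-in for `SWAGState`, parameters are a flat Python list rather than a JAX pytree, and reading `freq` as "collect every `freq`-th step" and `rank` as "keep the last `rank` deviations" is an assumption about the `swag(freq, rank)` arguments.

```python
from collections import deque


class ToySWAGStats:
    """Hypothetical stand-in for the statistics SWAG tracks (not SWAGState)."""

    def __init__(self, dim, rank):
        self.n = 0                             # number of iterates collected
        self.mean = [0.0] * dim                # running mean of the parameters
        self.sq_mean = [0.0] * dim             # running mean of squared parameters
        self.deviations = deque(maxlen=rank)   # last `rank` deviations from the mean

    def update(self, params):
        n = self.n
        self.mean = [(n * m + p) / (n + 1) for m, p in zip(self.mean, params)]
        self.sq_mean = [(n * s + p * p) / (n + 1) for s, p in zip(self.sq_mean, params)]
        # Deviation of the new iterate from the freshly updated mean.
        self.deviations.append([p - m for p, m in zip(params, self.mean)])
        self.n = n + 1

    def variance(self):
        # Diagonal covariance estimate, clipped at zero for numerical safety.
        return [max(s - m * m, 0.0) for s, m in zip(self.sq_mean, self.mean)]


freq, rank = 2, 2  # assumed meaning: collect every 2nd step, keep 2 deviations
stats = ToySWAGStats(dim=2, rank=rank)
iterates = [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0], [3.0, 6.0], [0.0, 0.0], [5.0, 10.0]]
for step, params in enumerate(iterates, start=1):
    if step % freq == 0:
        stats.update(params)
```

In this toy run only steps 2, 4 and 6 are collected, and the deviation buffer keeps just the two most recent entries.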
Sampling
Reference code for generating samples from the collected statistics is provided below.
import jax
import jax.numpy as jnp
from optax_swag import sample_swag
swa_opt_state = ...  # Placeholder: reference to a SWAGState object from the optimizer state.
n_samples = 10
rng = jax.random.PRNGKey(42)
rng, *samples_rng = jax.random.split(rng, 1 + n_samples)
swag_sample_params = jax.vmap(sample_swag, in_axes=(0, None))(
    jnp.array(samples_rng), swa_opt_state)
The resulting swag_sample_params can now be used for downstream evaluation.
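To make the sampling step less opaque, here is a small sketch of how one draw can be formed from a mean, a diagonal variance and a low-rank deviation buffer, following the sampling rule in the SWAG paper. It is only an illustration in plain Python: `sample_toy_swag` and its `scale` argument are hypothetical names, and `sample_swag` in this package may differ in details.

```python
import math
import random


def sample_toy_swag(rng, mean, variance, deviations, scale=1.0):
    """Draw one parameter vector from a SWAG-style Gaussian (illustrative only).

    mean, variance: per-coordinate running mean and diagonal variance.
    deviations: list of the last K deviation vectors (low-rank part).
    """
    rank = len(deviations)
    z_diag = [rng.gauss(0.0, 1.0) for _ in mean]
    z_low_rank = [rng.gauss(0.0, 1.0) for _ in range(rank)]
    low_rank_norm = math.sqrt(max(rank - 1, 1))  # guard against K < 2

    sample = []
    for i, (m, v) in enumerate(zip(mean, variance)):
        diag_term = math.sqrt(v) * z_diag[i]
        low_rank_term = sum(d[i] * z for d, z in zip(deviations, z_low_rank)) / low_rank_norm
        # Each covariance part contributes half of the total covariance.
        sample.append(m + math.sqrt(scale / 2.0) * (diag_term + low_rank_term))
    return sample


mean, variance = [3.0, 6.0], [0.0, 0.0]
deviations = [[1.0, 2.0], [2.0, 4.0]]
n_samples = 10
# One independent random stream per sample, mirroring the per-sample keys above.
toy_samples = [
    sample_toy_swag(random.Random(seed), mean, variance, deviations)
    for seed in range(n_samples)
]
```

With the diagonal variance set to zero, every sample moves away from the mean only along directions spanned by the stored deviations, which is the role of the low-rank term.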
NOTE: Make sure to update non-parameter variables (e.g. BatchNorm running statistics) for each generated sample.
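The note above can be illustrated with a toy example: normalization statistics depend on the weights, so each sampled weight vector needs its own pass over the data. This sketch assumes a single linear feature `w . x` and a hypothetical helper `recompute_feature_stats`; a real model would instead rerun its own BatchNorm update over training batches.

```python
def recompute_feature_stats(weights, inputs):
    """Recompute the mean and variance of the feature `w . x` over `inputs`."""
    feats = [sum(w * x for w, x in zip(weights, xs)) for xs in inputs]
    mean = sum(feats) / len(feats)
    var = sum((f - mean) ** 2 for f in feats) / len(feats)
    return mean, var


inputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
sampled_weights = [[1.0, 1.0], [2.0, 0.0]]  # hypothetical SWAG samples

# One set of running statistics per sample; reusing a single set would be stale.
stats_per_sample = [recompute_feature_stats(w, inputs) for w in sampled_weights]
```

The two samples yield different feature statistics on the same data, which is why statistics carried over from training cannot simply be reused.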
License
Apache 2.0