Stochastic Weight Averaging for Optax
SWAG in Optax
This package implements SWAG (SWA-Gaussian) as an Optax gradient transformation, so it can be used in JAX training loops.
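For reference, SWAG (Maddox et al., 2019) fits a Gaussian to snapshots of the SGD iterates. In the paper's notation, given T snapshots θ_1, ..., θ_T of a d-dimensional parameter vector, it keeps the running mean, the running mean of elementwise squares, and the last K deviations from the running mean. This summary follows the paper. How SWAGState actually stores these quantities is not described here, and reading the rank argument as K is an assumption.

```latex
\begin{aligned}
\bar{\theta} &= \frac{1}{T}\sum_{i=1}^{T}\theta_i,
\qquad
\overline{\theta^2} = \frac{1}{T}\sum_{i=1}^{T}\theta_i^2,
\qquad
\Sigma_{\mathrm{diag}} = \operatorname{diag}\!\left(\overline{\theta^2} - \bar{\theta}^2\right),\\
\hat{D} &= \left[\theta_{T-K+1} - \bar{\theta}_{T-K+1},\ \dots,\ \theta_T - \bar{\theta}_T\right] \in \mathbb{R}^{d \times K},\\
\Sigma_{\mathrm{SWAG}} &= \frac{1}{2}\left(\Sigma_{\mathrm{diag}} + \frac{\hat{D}\hat{D}^{\top}}{K-1}\right),
\qquad
\theta \sim \mathcal{N}\!\left(\bar{\theta},\ \Sigma_{\mathrm{SWAG}}\right).
\end{aligned}
```

Here θ̄_i denotes the running mean after i snapshots (so θ̄ = θ̄_T), and squares are taken elementwise.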
Installation
Install with pip:
pip install optax-swag
To install the latest version directly from source, run:
pip install git+https://github.com/activatedgeek/optax-swag.git
Usage
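To start collecting iterate statistics, chain the swag transform after the base optimizer:
import optax
from optax_swag import swag
freq, rank = 1, 20  # Example values only; choose them for your own run.
optimizer = optax.chain(
    optax.sgd(learning_rate=1e-2),  # Other optimizer and transform config (SGD is just an example).
    swag(freq, rank),  # Always add as the last transform.
)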
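Because swag is the last transform in the chain, its SWAGState is the last element of the optimizer state tuple returned by optax.chain (opt_state[-1]), and can be retrieved from there for downstream use.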
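The bookkeeping behind the collected statistics can be sketched in plain Python. This is a hypothetical illustration, not the package's implementation: SwagStatsSketch works on a flat list of floats instead of a JAX pytree, and it assumes freq is the snapshot interval in optimizer steps and rank is the number of deviation vectors kept. A real Optax transform would additionally pass the gradient updates through unchanged.

```python
from collections import deque


class SwagStatsSketch:
    """Hypothetical, dependency-free sketch of SWAG moment tracking.

    Works on a flat list of floats rather than a JAX pytree. Assumes `freq`
    is the snapshot interval (in steps) and `rank` is the number K of
    deviation vectors kept; this is not the package's actual SWAGState.
    """

    def __init__(self, freq, rank, dim):
        self.freq = freq
        self.step = 0                         # optimizer steps seen so far
        self.n = 0                            # snapshots collected so far
        self.mean = [0.0] * dim               # running mean of parameters
        self.sq_mean = [0.0] * dim            # running mean of squared parameters
        self.deviations = deque(maxlen=rank)  # last `rank` deviation vectors

    def update(self, params):
        """Record `params` only when this step falls on the snapshot interval."""
        self.step += 1
        if self.step % self.freq != 0:
            return
        self.n += 1
        w = 1.0 / self.n  # weight that keeps an exact running average
        self.mean = [m + w * (p - m) for m, p in zip(self.mean, params)]
        self.sq_mean = [s + w * (p * p - s) for s, p in zip(self.sq_mean, params)]
        # Deviation from the running mean *including* this snapshot.
        self.deviations.append([p - m for p, m in zip(params, self.mean)])


# Usage: with freq=2, snapshots are taken at steps 2, 4 and 6 (values 2, 4, 6).
stats = SwagStatsSketch(freq=2, rank=2, dim=1)
for p in ([1.0], [2.0], [3.0], [4.0], [5.0], [6.0]):
    stats.update(p)
```

Using a running average with weight 1/n keeps the statistics exact without storing every snapshot, and the bounded deque discards the oldest deviation once rank vectors are held.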
Sampling
Reference code for generating samples from the collected statistics is provided below.
import jax
import jax.numpy as jnp
from optax_swag import sample_swag
swa_opt_state = opt_state[-1]  # SWAGState from your training loop's optimizer state (swag is the last transform).
n_samples = 10
rng = jax.random.PRNGKey(42)
rng, *samples_rng = jax.random.split(rng, 1 + n_samples)  # One PRNG key per sample.
# Map over the stacked sample keys; the same SWAGState is shared by every sample.
swag_sample_params = jax.vmap(sample_swag, in_axes=(0, None))(
    jnp.array(samples_rng), swa_opt_state)
The resulting swag_sample_params stacks one parameter set per sample along a leading axis of size n_samples, and can be used for downstream evaluation.
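The sampling step draws from the diagonal plus low-rank Gaussian described in the SWAG paper: θ = θ̄ + (1/√2) Σ_diag^{1/2} z_1 + (1/√(2(K−1))) D̂ z_2, with z_1 ~ N(0, I_d) and z_2 ~ N(0, I_K). The hypothetical sample_swag_sketch below spells this out for a flat parameter vector in plain Python. It illustrates the technique only; it is not the code behind sample_swag, whose internals may differ.

```python
import math
import random


def sample_swag_sketch(mean, sq_mean, deviations, rng):
    """Draw one SWAG sample for a flat parameter vector (hypothetical helper).

    theta = mean + sqrt(var) * z1 / sqrt(2) + D z2 / sqrt(2 (K - 1)),
    where var = sq_mean - mean**2, D is the d x K deviation matrix given as a
    list of K column vectors, z1 ~ N(0, I_d) and z2 ~ N(0, I_K).
    """
    d, k = len(mean), len(deviations)
    # Diagonal part; clamp tiny negative variances caused by round-off.
    var = [max(s - m * m, 0.0) for m, s in zip(mean, sq_mean)]
    z1 = [rng.gauss(0.0, 1.0) for _ in range(d)]
    sample = [m + math.sqrt(v) * z / math.sqrt(2.0) for m, v, z in zip(mean, var, z1)]
    if k > 1:  # the low-rank term is only defined for K >= 2
        z2 = [rng.gauss(0.0, 1.0) for _ in range(k)]
        scale = 1.0 / math.sqrt(2.0 * (k - 1))
        for col, z in zip(deviations, z2):
            sample = [x + scale * z * c for x, c in zip(sample, col)]
    return sample


# Usage: two parameters, K = 2 deviation vectors.
rng = random.Random(42)
theta = sample_swag_sketch(
    mean=[0.0, 1.0], sq_mean=[0.25, 1.0],
    deviations=[[0.1, -0.2], [-0.1, 0.2]], rng=rng)
```

In JAX the same computation would run leaf by leaf over the parameter pytree with jax.random normal draws, taking care that a single z_2 is shared across all leaves so the low-rank term stays correlated across parameters.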
NOTE: Make sure to update non-parameter variables (e.g. BatchNorm running statistics) for each generated sample.
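One common way to refresh BatchNorm statistics for a sample is a pass over the training data in training mode, replacing the running averages with an equal-weight average of the per-batch statistics. The hypothetical recompute_bn_stats below shows that averaging for a single channel, given the pre-normalization activations a sampled parameter set produces on each batch. How those activations are obtained depends on your model code and is not shown.

```python
def recompute_bn_stats(feature_batches):
    """Equal-weight average of per-batch mean and variance for one BN channel.

    `feature_batches` is an iterable of lists holding the pre-normalization
    activations a *sampled* parameter set produces on each training batch
    (hypothetical input; in practice obtained by running the model forward).
    """
    mean_avg, var_avg, n = 0.0, 0.0, 0
    for feats in feature_batches:
        b_mean = sum(feats) / len(feats)
        # Biased (population) variance of this batch.
        b_var = sum((f - b_mean) ** 2 for f in feats) / len(feats)
        n += 1
        # Cumulative average: every batch gets the same weight (no momentum).
        mean_avg += (b_mean - mean_avg) / n
        var_avg += (b_var - var_avg) / n
    return mean_avg, var_avg


# Usage: two batches of activations for one channel.
bn_mean, bn_var = recompute_bn_stats([[1.0, 3.0], [5.0, 7.0]])
```

This sketch uses the biased batch variance; frameworks differ on whether the running variance uses the biased or unbiased estimate.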
License
Apache 2.0
Download files
Source Distribution
optax-swag-0.1.0.tar.gz (7.8 kB)
Built Distribution
Hashes for optax_swag-0.1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 399ccc8afd93979765a604e440fa5aa036e956d4f342b17a489c2f643bf2fc30
MD5 | b6831121aa74b7a2ca578ba20ff99e15
BLAKE2b-256 | 9783f15fbd3d7ad164c976b798395cebb8bbd27e5003403fce6adb5230f16bd5