
Stochastic Weight Averaging for Optax


SWAG in Optax


This package implements SWAG (SWA-Gaussian) as an Optax gradient transformation for use in JAX training loops.
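SWAG keeps a running mean of the iterates, a running mean of their squares (which gives a diagonal variance), and a buffer of recent deviations from the mean (which gives a low-rank covariance). The following is a minimal pure-JAX sketch of that bookkeeping for a single flat parameter vector. It assumes a circular buffer of `rank` deviation columns; MomentState, init_moments and update_moments are hypothetical names, not part of optax_swag.

```python
from typing import NamedTuple

import jax.numpy as jnp


class MomentState(NamedTuple):
    """Hypothetical SWAG statistics for one flat parameter vector."""
    mean: jnp.ndarray     # running mean of collected iterates, shape (d,)
    sq_mean: jnp.ndarray  # running mean of squared iterates, shape (d,)
    dev: jnp.ndarray      # last `rank` deviations from the mean, shape (rank, d)
    n: jnp.ndarray        # number of iterates collected so far


def init_moments(params: jnp.ndarray, rank: int) -> MomentState:
    return MomentState(
        mean=jnp.zeros_like(params),
        sq_mean=jnp.zeros_like(params),
        dev=jnp.zeros((rank,) + params.shape, params.dtype),
        n=jnp.zeros([], jnp.int32),
    )


def update_moments(state: MomentState, params: jnp.ndarray) -> MomentState:
    n = state.n.astype(params.dtype)
    # Incremental means over the n + 1 iterates seen so far.
    mean = (n * state.mean + params) / (n + 1)
    sq_mean = (n * state.sq_mean + params**2) / (n + 1)
    # Write the new deviation over the oldest column (circular buffer).
    col = state.n % state.dev.shape[0]
    dev = state.dev.at[col].set(params - mean)
    return MomentState(mean, sq_mean, dev, state.n + 1)


# Usage: collect two iterates with a rank-2 buffer.
state = init_moments(jnp.zeros(3), rank=2)
for p in [jnp.array([1.0, 2.0, 3.0]), jnp.array([3.0, 2.0, 1.0])]:
    state = update_moments(state, p)
diag_var = state.sq_mean - state.mean**2  # diagonal variance estimate
```

A real implementation would apply the same updates leaf by leaf over a parameter pytree (for example with jax.tree_util.tree_map) and only every few steps.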

Installation

Install from pip as:

pip install optax-swag

To install the latest version directly from source, run:

pip install git+https://github.com/activatedgeek/optax-swag.git

Usage

To collect iterate statistics during training, chain the swag transform after the other transforms:

import optax
from optax_swag import swag

optimizer = optax.chain(
    ...,  # Other optimizer and transform config.
    swag(freq, rank),  # Always add as the last transform.
)

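Because optax.chain keeps one state per transform in a tuple, the SWAGState is the last element of the optimizer state (e.g. opt_state[-1]) and can be retrieved from there for downstream use.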

Sampling

Reference code to generate samples from the collected statistics is provided below.

import jax
import jax.numpy as jnp

from optax_swag import sample_swag

swa_opt_state = ...  # Reference to a SWAGState object from the optimizer, e.g. opt_state[-1].
n_samples = 10

rng = jax.random.PRNGKey(42)
rng, *samples_rng = jax.random.split(rng, 1 + n_samples)

swag_sample_params = jax.vmap(sample_swag, in_axes=(0, None))(
    jnp.array(samples_rng), swa_opt_state)

Because of the vmap, each leaf of the resulting swag_sample_params has a leading axis of size n_samples, one slice per sample, ready for downstream evaluation.

NOTE: Make sure to update non-parameter variables (e.g. BatchNorm running statistics) for each generated sample.
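For intuition about what a sample is, the following is a minimal pure-JAX sketch of the standard SWAG sampling rule for a single flat parameter vector: theta = mean + (1/sqrt(2)) * sqrt(var) * z1 + (1/sqrt(2(K-1))) * D z2, where var is the diagonal variance, D holds K deviation columns, and z1, z2 are standard normal draws. sample_swag_sketch, its argument layout and var_clamp are hypothetical; this is not the package's sample_swag, and it assumes the deviation buffer is full with K >= 2.

```python
import jax
import jax.numpy as jnp


def sample_swag_sketch(rng, mean, sq_mean, dev, var_clamp=1e-30):
    """Hypothetical SWAG sample for a flat parameter vector.

    mean, sq_mean: shape (d,); dev: shape (K, d) holding K >= 2 deviation columns.
    """
    k = dev.shape[0]
    rng_diag, rng_low_rank = jax.random.split(rng)
    # Diagonal variance from the first two moments, clamped to stay positive.
    var = jnp.maximum(sq_mean - mean**2, var_clamp)
    z1 = jax.random.normal(rng_diag, mean.shape, mean.dtype)
    z2 = jax.random.normal(rng_low_rank, (k,), mean.dtype)
    diag_term = jnp.sqrt(var) * z1 / jnp.sqrt(2.0)
    low_rank_term = (dev.T @ z2) / jnp.sqrt(2.0 * (k - 1))
    return mean + diag_term + low_rank_term


# Usage: draw 10 samples in parallel, mirroring the vmap pattern above.
mean = jnp.array([2.0, 2.0, 2.0])
sq_mean = jnp.array([5.0, 4.0, 5.0])
dev = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, -1.0]])
keys = jax.random.split(jax.random.PRNGKey(0), 10)
samples = jax.vmap(sample_swag_sketch, in_axes=(0, None, None, None))(
    keys, mean, sq_mean, dev)
```

Coordinates with zero variance and zero deviation (the middle one here) stay at the mean in every sample, while the others spread out.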

License

Apache 2.0


Download files


Source Distribution

optax-swag-0.1.0.tar.gz (7.8 kB)

Uploaded Source

Built Distribution

optax_swag-0.1.0-py3-none-any.whl (8.6 kB)

Uploaded Python 3

File details

Details for the file optax-swag-0.1.0.tar.gz.

File metadata

  • Download URL: optax-swag-0.1.0.tar.gz
  • Upload date:
  • Size: 7.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.2

File hashes

Hashes for optax-swag-0.1.0.tar.gz
Algorithm    Hash digest
SHA256       bf8826a07314c25b917c091e65acf75db09f179658c4fe1cbf0f2625127d9ed5
MD5          5c0c716dfe702c86e58322ee25a25474
BLAKE2b-256  31de4fc302bf37918613fbc4242840f9478672c0c8ae4f303d3fb962d0b362dd


File details

Details for the file optax_swag-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: optax_swag-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 8.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.11.2

File hashes

Hashes for optax_swag-0.1.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       399ccc8afd93979765a604e440fa5aa036e956d4f342b17a489c2f643bf2fc30
MD5          b6831121aa74b7a2ca578ba20ff99e15
BLAKE2b-256  9783f15fbd3d7ad164c976b798395cebb8bbd27e5003403fce6adb5230f16bd5

