
A smaller, faster, cleaner WeightWatcher for FLAX/JAX with XLA

Project description

Flax-WeightWatcher

Unfortunately, CalculatedContent's WeightWatcher does not support FLAX models, and it could run faster if it used an accelerated linear algebra (XLA) framework.

Since I found the process of making a PR to the original WeightWatcher repository too tedious, I wrote my own version for FLAX models, which is what I'm working with at the moment. It also helps that JAX uses XLA as its compiler.
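As a rough illustration of where XLA helps: WeightWatcher-style analysis is built on the empirical spectral density (ESD), the eigenvalues of each layer's correlation matrix W^T W. The sketch below assumes that is the quantity being computed (the helper name `layer_esd` is hypothetical, not part of flax-weightwatcher's API) and compiles it with `jax.jit`, so XLA can optimize the matrix product and eigendecomposition together.

```python
import jax
import jax.numpy as jnp


@jax.jit
def layer_esd(weight: jax.Array) -> jax.Array:
    """Hypothetical helper: eigenvalues of W^T W, sorted ascending."""
    # For a Flax (n_in, n_out) kernel this is an (n_out, n_out) symmetric matrix.
    correlation = weight.T @ weight
    # eigvalsh exploits symmetry and returns real eigenvalues in ascending order.
    return jnp.linalg.eigvalsh(correlation)


key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (784, 128)) / jnp.sqrt(784.0)
# The first call traces and compiles; later calls with the same shape reuse the compiled code.
eigs = layer_esd(w)
print(eigs.shape)
```

Because compiled functions are cached per input shape, layers that share a kernel shape only pay the compilation cost once.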

Flax-WeightWatcher is not meant to be a one-to-one match with the original, but it is designed to be extensible. Contributions that extend its functionality, perhaps to match or exceed the original, are welcome.

Installation

Running pip install flax-weightwatcher installs this tool as a Python library, along with the following dependencies:

  • jax
  • flax
  • numpy
  • pandas
  • powerlaw
  • jaxtyping
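To confirm that these dependencies are importable in the current environment, a quick standard-library check can help. This is a convenience sketch, not something the package ships; it assumes each dependency's import name matches its listed name, which holds for these six, and the function name `missing_dependencies` is hypothetical.

```python
import importlib.util

# Dependencies listed above; for these packages the import name matches the PyPI name.
DEPENDENCIES = ["jax", "flax", "numpy", "pandas", "powerlaw", "jaxtyping"]


def missing_dependencies(names: list[str]) -> list[str]:
    """Return the names that cannot be found on the current import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_dependencies(DEPENDENCIES)
print("all dependencies found" if not missing else f"missing: {', '.join(missing)}")
```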

Usage

The usage is intended to match that of the original WeightWatcher but has some minor changes.

```python
from flax import nnx
from flax_weightwatcher import FlaxWeightWatcher

# Two Linear layers: 784 -> 128 -> 10
model = nnx.Sequential(
    nnx.Linear(28 * 28, 128, rngs=nnx.Rngs(0)),
    nnx.Linear(128, 10, rngs=nnx.Rngs(0)),
)
# details_format="dict" returns the details as a dictionary instead of a pandas DataFrame
watcher = FlaxWeightWatcher(model=model, details_format="df")
details = watcher.analyze()
print(details.head())
```

This should print something like the following (exact values depend on the weights):

```
   layer_index layer_name weight_shape     alpha  num_eigenvals_fit  num_eigenvals  stable_ranks  effective_ranks  ranks
0           13   layers.0      784,128  6.143005                 38            128      8.157706       117.944939    128
1           25   layers.1       128,10  5.302257                  8             10      2.552803         9.536791     10
```
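The columns above can be read against common definitions, though flax-weightwatcher's exact formulas are not documented here, so the sketch below is an assumption and its numbers need not match the table. Given the eigenvalues of W^T W: `num_eigenvals` is their count (the smaller kernel dimension), a stable rank is the sum of eigenvalues over the largest one, an effective rank is the entropy-based exp(H(p)) with p the normalized singular values, and rank is the numerical rank. `alpha` is the exponent of a power law fitted to the tail of the ESD, and `num_eigenvals_fit` is presumably the number of eigenvalues in that tail. The package lists `powerlaw` as a dependency, presumably for this fit; the sketch swaps in a small hand-written version of the same Clauset-style procedure (continuous maximum-likelihood alpha, with xmin chosen by minimizing the Kolmogorov-Smirnov distance). Both function names are hypothetical.

```python
import numpy as np


def spectral_metrics(weight: np.ndarray) -> dict:
    """Rank-style metrics of a layer, using common textbook definitions."""
    # Singular values of W; the eigenvalues of W^T W (the ESD) are their squares.
    sv = np.linalg.svd(weight, compute_uv=False)
    eigs = sv**2
    # Numerical-rank threshold: the same default numpy.linalg.matrix_rank uses.
    tol = sv.max() * max(weight.shape) * np.finfo(sv.dtype).eps
    # Entropy-based effective rank (Roy & Vetterli) over normalized singular values.
    p = sv / sv.sum()
    p = p[p > 0]
    return {
        "num_eigenvals": int(eigs.size),
        "stable_rank": float(eigs.sum() / eigs.max()),  # ||W||_F^2 / ||W||_2^2
        "effective_rank": float(np.exp(-(p * np.log(p)).sum())),
        "rank": int((sv > tol).sum()),
    }


def fit_power_law_alpha(eigs: np.ndarray, min_tail: int = 10) -> tuple[float, int]:
    """Fit a power law to the ESD tail; return (alpha, number of values in the tail).

    For every candidate xmin, estimate alpha by continuous maximum likelihood,
    then keep the xmin whose fit has the smallest Kolmogorov-Smirnov distance.
    """
    x = np.sort(np.asarray(eigs, dtype=float))
    x = x[x > 0]
    best_ks, best_alpha, best_n = np.inf, float("nan"), 0
    for i in range(x.size - min_tail + 1):
        tail = x[i:]  # tail[0] is the candidate xmin
        n = tail.size
        log_sum = np.log(tail / tail[0]).sum()
        if log_sum <= 0:
            continue  # all values equal xmin; alpha is undefined
        alpha = 1.0 + n / log_sum
        empirical_cdf = np.arange(1, n + 1) / n
        fitted_cdf = 1.0 - (tail / tail[0]) ** (1.0 - alpha)
        ks = np.abs(empirical_cdf - fitted_cdf).max()
        if ks < best_ks:
            best_ks, best_alpha, best_n = ks, alpha, n
    return best_alpha, best_n


# Usage on a randomly initialized 784 x 128 kernel (illustrative values only)
w = np.random.default_rng(0).normal(scale=784**-0.5, size=(784, 128))
print(spectral_metrics(w))
print(fit_power_law_alpha(np.linalg.svd(w, compute_uv=False) ** 2))
```

Scanning every candidate xmin is quadratic in the number of eigenvalues, which is acceptable for layer-sized spectra but is one place where a vectorized or jitted implementation would pay off.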

Features to be added

  • the remaining metrics computed by the original WeightWatcher
  • ESD plotting utilities


Download files

Source Distribution

flax_weightwatcher-0.1.1.tar.gz (4.1 kB)

Uploaded Source

Built Distribution

flax_weightwatcher-0.1.1-py3-none-any.whl (4.1 kB)

Uploaded Python 3

File details

Details for the file flax_weightwatcher-0.1.1.tar.gz.

File metadata

  • Download URL: flax_weightwatcher-0.1.1.tar.gz
  • Size: 4.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for flax_weightwatcher-0.1.1.tar.gz
Algorithm Hash digest
SHA256 7a8ad472c36f3e2fe955976b00c028c28598bf5298a921f2a12e1f0f500f2aff
MD5 708f621f7ae7487ad86153159b254ef4
BLAKE2b-256 4c7e52951df71824ff80f35c6e9f7cf02906eab16abf6ed48cc2aee798b92e94

File details

Details for the file flax_weightwatcher-0.1.1-py3-none-any.whl.

File hashes

Hashes for flax_weightwatcher-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 6cebfb0486aa4db4d99c8f0a56d77d83e173c275361c2c5556d08bcfa86cf3b9
MD5 8a0a7984f331fd4abcfde351cd6d797f
BLAKE2b-256 35839197e2317c857257e3eea4de5eb0188bac292cd77357bcdaac579bb8a3a2
