
A smaller, faster, cleaner WeightWatcher for FLAX/JAX with XLA

Project description

Flax-WeightWatcher

Unfortunately, CalculatedContent's WeightWatcher does not support Flax models, and it could run faster on an accelerated linear algebra (XLA) backend.

I found opening a PR against the original WeightWatcher repository too tedious, so I wrote my own version for Flax models, since that is what I am working with at the moment. It also helps that JAX already uses XLA.

Flax-WeightWatcher is not meant to be a one-to-one match with the original, but it is designed to be extensible. Contributions that extend its functionality to match or even exceed the original are welcome.

Installation

Running pip install flax-weightwatcher installs this tool as a Python library, along with the following dependencies:

jax
flax
numpy
pandas
powerlaw
jaxtyping
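The powerlaw dependency is presumably what produces the alpha and num_eigenvals_fit columns in the analysis output, by fitting a power law to the upper tail of each layer's eigenvalue spectrum. For continuous data, powerlaw follows the Clauset, Shalizi and Newman recipe: estimate the exponent by maximum likelihood for every candidate lower cutoff xmin and keep the cutoff that minimizes the Kolmogorov-Smirnov (KS) distance. The sketch below reimplements that recipe in plain JAX as an illustration only. `fit_power_law` and its `min_tail` guard are hypothetical names, not part of flax-weightwatcher or powerlaw, and the empirical-CDF convention used for the KS distance may differ slightly from powerlaw's.

```python
import jax
import jax.numpy as jnp


def fit_power_law(eigs: jax.Array, min_tail: int = 10) -> dict:
    """Fit p(x) ~ x^(-alpha) to the upper tail of a spectrum (illustrative sketch).

    For every candidate xmin the exponent is the continuous maximum-likelihood
    estimate alpha = 1 + n / sum(log(x_i / xmin)); the xmin that minimizes the
    KS distance between the tail and the fitted CDF is kept.
    """
    x = jnp.sort(eigs[eigs > 0])  # positive eigenvalues, ascending
    n = x.shape[0]
    # min_tail is a hypothetical guard against tiny tails, not a powerlaw parameter.
    if n < min_tail:
        raise ValueError("need at least min_tail positive eigenvalues")

    idx = jnp.arange(n)
    k = idx[:, None]  # rows: candidate xmin = x[k]
    i = idx[None, :]  # columns: data point x[i]
    in_tail = i >= k
    n_tail = n - idx  # number of points with x >= x[k]

    # MLE exponent for each candidate xmin, summing log(x_i / xmin) over the tail only.
    log_ratio = jnp.log(x)[None, :] - jnp.log(x)[:, None]
    alphas = 1.0 + n_tail / jnp.sum(jnp.where(in_tail, log_ratio, 0.0), axis=1)

    # KS distance between the empirical tail CDF and the fitted CDF 1 - (x / xmin)^(1 - alpha).
    emp_cdf = (i - k + 1) / n_tail[:, None]
    fit_cdf = 1.0 - jnp.exp((1.0 - alphas[:, None]) * log_ratio)
    gap = jnp.where(in_tail, jnp.abs(emp_cdf - fit_cdf), 0.0)
    ks = jnp.max(gap, axis=1)

    # Skip tails that are too short or degenerate (infinite alpha).
    valid = (n_tail >= min_tail) & jnp.isfinite(alphas)
    best = jnp.argmin(jnp.where(valid, ks, jnp.inf))
    return {
        "alpha": float(alphas[best]),
        "xmin": float(x[best]),
        "num_eigenvals_fit": int(n_tail[best]),
    }
```

To apply it to a layer, pass the eigenvalues of W^T W, for example `jnp.linalg.svd(W, compute_uv=False) ** 2`. The (n, n) matrices keep every candidate xmin vectorized, which suits XLA but costs O(n²) memory, so a very wide layer would call for a chunked loop instead.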

Usage

Usage is intended to match the original WeightWatcher, with some minor differences.

from flax import nnx
from flax_weightwatcher import FlaxWeightWatcher

# A two-layer MLP for flattened 28x28 inputs
model = nnx.Sequential(nnx.Linear(28 * 28, 128, rngs=nnx.Rngs(0)), nnx.Linear(128, 10, rngs=nnx.Rngs(0)))

# details_format="df" returns a pandas DataFrame; "dict" returns a dictionary instead
watcher = FlaxWeightWatcher(model=model, details_format="df")
details = watcher.analyze()
print(details.head())

This should print something like this:

   layer_index layer_name weight_shape     alpha  num_eigenvals_fit  num_eigenvals  stable_ranks  effective_ranks  ranks
0           13   layers.0      784,128  6.143005                 38            128      8.157706       117.944939    128
1           25   layers.1       128,10  5.302257                  8             10      2.552803         9.536791     10
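Apart from alpha, the columns summarize each layer's empirical spectral density (ESD), the eigenvalues of W^T W for a weight matrix W. The sketch below computes common textbook versions of these quantities in JAX: stable rank as ||W||_F² / ||W||_2², effective rank as the exponential of the entropy of the normalized singular values (Roy and Vetterli), and numerical rank with NumPy's default tolerance. `esd` and `rank_metrics` are hypothetical helpers, not the flax-weightwatcher API, and the library's exact definitions and normalizations are not documented here, so its numbers need not match these.

```python
import jax
import jax.numpy as jnp


# Hypothetical helpers for illustration; not part of flax-weightwatcher.
@jax.jit
def esd(weight: jax.Array) -> jax.Array:
    """Eigenvalues of W^T W in ascending order (the empirical spectral density)."""
    # The squared singular values of W are the eigenvalues of W^T W; taking
    # them from the SVD avoids forming W^T W explicitly.
    singular_values = jnp.linalg.svd(weight, compute_uv=False)
    return jnp.sort(singular_values**2)


def rank_metrics(weight: jax.Array) -> dict:
    """Stable rank, effective rank and numerical rank of a 2-D weight matrix."""
    eigs = esd(weight)
    sv = jnp.sqrt(eigs)  # singular values, ascending

    # Stable rank: ||W||_F^2 / ||W||_2^2 = sum of eigenvalues / largest eigenvalue.
    stable_rank = jnp.sum(eigs) / eigs[-1]

    # Effective rank: exp of the Shannon entropy of the singular values
    # normalized to sum to one; zero singular values contribute nothing.
    p = sv / jnp.sum(sv)
    entropy = -jnp.sum(jnp.where(p > 0, p * jnp.log(p), 0.0))
    effective_rank = jnp.exp(entropy)

    # Numerical rank with NumPy's default tolerance: sigma_max * max(M, N) * eps.
    tol = sv[-1] * max(weight.shape) * jnp.finfo(weight.dtype).eps
    rank = jnp.sum(sv > tol)

    return {
        "num_eigenvals": int(eigs.shape[0]),
        "stable_rank": float(stable_rank),
        "effective_rank": float(effective_rank),
        "rank": int(rank),
    }
```

For example, `rank_metrics(jax.random.normal(jax.random.PRNGKey(0), (784, 128)))` returns these values for a random matrix shaped like layers.0. Taking the spectrum from the SVD rather than from an eigendecomposition of W^T W avoids squaring the condition number.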

Features to be added

  • the remaining metrics computed by the original WeightWatcher
  • ESD plotting utilities


Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution


flax_weightwatcher-0.1.0-py3-final-any.whl (4.1 kB)

Uploaded Python 3

File details

Details for the file flax_weightwatcher-0.1.0-py3-final-any.whl.


File hashes

Hashes for flax_weightwatcher-0.1.0-py3-final-any.whl
Algorithm Hash digest
SHA256 08efac0f3025a9e3412656d6d4c3af067b729ad53974ae2f0f4d39bd01c566a7
MD5 1591665f34dd0c445b6ca85f38f79067
BLAKE2b-256 70d47cfc90a3d5c147b470b0b293ea931c8ac3246ffc28404594b35eed8a2f72

