Lottery Ticket Pruner

Enables pruning of Keras DNNs using "lottery ticket" pruning.
Deep Neural Networks (DNNs) can often benefit from "pruning" some of the weights in the network, turning dense weight matrices into sparse ones with little or no loss in the accuracy of the overall model.
This is a Keras implementation of the most relevant pruning strategies outlined in two papers:
- The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks
- Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask
The pruning strategies implemented in this package can reduce the number of non-zero weights of CNNs and DNNs by 40-98% with negligible loss in accuracy of the final model. Techniques like MorphNet can then be applied to these now-sparse models to further decrease model size and/or inference time.
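As an illustrative aside (this helper is not part of the package), the sparsity of a pruned Keras model can be verified directly from its weights. A minimal sketch:

```python
import numpy as np

def nonzero_weight_fraction(model):
    """Return the fraction of a Keras model's weights that are non-zero.

    Illustrative helper only: counts every trainable array returned by
    model.get_weights() (including biases etc.), which is good enough for a
    rough check of how sparse a model has become after pruning.
    """
    flat = np.concatenate([w.flatten() for w in model.get_weights()])
    return np.count_nonzero(flat) / flat.size
```

For example, a return value of 0.02 would correspond to roughly 98% of the weights having been pruned to zero.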
Installation
```
pip install lottery-ticket-pruner
```
Usage
A typical use of the code in this repo looks something like this:
```python
from lottery_ticket_pruner import LotteryTicketPruner, PrunerCallback

model = ...  # create a Keras model with random initial weights

# Save the initial weights of the model so we can start pruned training from them later
initial_weights = model.get_weights()
# Initialize the pruner so it knows the starting initial (random) weights
pruner = LotteryTicketPruner(model)
...
# Train the model
model.fit(X, y)
...
# Let the pruner know what the trained weights are
pruner.set_pretrained_weights(model)
...
# Revert the model so it has the random initial weights again
model.set_weights(initial_weights)
# Now train the model using pruning: e.g. prune 50% of the weights with the 'large_final' strategy
pruner.calc_prune_mask(model, 0.5, 'large_final')
untrained_loss, untrained_accuracy = model.evaluate(x_test, y_test)
model.fit(X, y, callbacks=[PrunerCallback(pruner)])
trained_loss, trained_accuracy = model.evaluate(x_test, y_test)
```
For a full working example that computes the accuracy of an untrained model that has been pruned, as well as training a model from scratch using lottery ticket pruning, see the example code. The example code uses the MNIST and CIFAR10 datasets.
Results
examples/example.sh was run to see the effects of pruning at 20%, 55.78%, 89.6%, and 99.33% using the 3 supported pruning strategies across the MNIST and CIFAR10 datasets. Training was capped at 100 epochs to help control AWS expenses.
The results below are averaged across 3 iterations:
MNIST (100 epochs)
| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | mnist | n/a | 0.937 |
| 20% | mnist | large_final | 0.935 |
| 20% | mnist | smallest_weights | 0.936 |
| 20% | mnist | smallest_weights_global | 0.939 |
| 55.78% | mnist | large_final | 0.936 |
| 55.78% | mnist | smallest_weights | 0.936 |
| 55.78% | mnist | smallest_weights_global | 0.939 |
| 89.6% | mnist | large_final | 0.936 |
| 89.6% | mnist | smallest_weights | 0.937 |
| 89.6% | mnist | smallest_weights_global | 0.939 |
| 99.33% | mnist | large_final | 0.936 |
| 99.33% | mnist | smallest_weights | 0.937 |
| 99.33% | mnist | smallest_weights_global | 0.939 |
CIFAR (100 epochs)
| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | cifar10 | n/a | 0.427 |
| 20% | cifar10 | large_final | 0.298 |
| 20% | cifar10 | smallest_weights | 0.427 |
| 20% | cifar10 | smallest_weights_global | 0.423 |
| 55.78% | cifar10 | large_final | 0.294 |
| 55.78% | cifar10 | smallest_weights | 0.427 |
| 55.78% | cifar10 | smallest_weights_global | 0.424 |
| 89.6% | cifar10 | large_final | 0.289 |
| 89.6% | cifar10 | smallest_weights | 0.427 |
| 89.6% | cifar10 | smallest_weights_global | 0.424 |
| 99.33% | cifar10 | large_final | 0.288 |
| 99.33% | cifar10 | smallest_weights | 0.428 |
| 99.33% | cifar10 | smallest_weights_global | 0.425 |
CIFAR (500 epochs)
| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | cifar10 | n/a | 0.550 |
| 20% | cifar10 | smallest_weights_global | 0.550 |
| 55.78% | cifar10 | smallest_weights_global | 0.552 |
| 89.6% | cifar10 | smallest_weights_global | 0.554 |
| 99.33% | cifar10 | smallest_weights_global | 0.554 |
Pruning the initial model weights with no training
One of the surprising findings of these papers is that if we simply run inference with the model's original random weights, with no training at all but with pruning applied, the resulting models perform far (far!) better than random guessing. Below are the results of inference done after applying pruning to the random initial weights of the model without any training. The trained model used as input to the pruning logic was trained for 100 epochs.
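A minimal sketch of this experiment, reusing only the calls already shown in the Usage section above (the model, data, prune percentage, and strategy here are placeholders):

```python
# Assumes `model`, `initial_weights`, `pruner`, and the test data are set up as in
# the Usage example above, and that the pruner has already been given trained
# weights via pruner.set_pretrained_weights(model).
model.set_weights(initial_weights)                                 # back to the random initial weights
pruner.calc_prune_mask(model, 0.896, 'smallest_weights_global')    # e.g. ~89.6% pruning
no_train_loss, no_train_accuracy = model.evaluate(x_test, y_test)  # note: no model.fit() call
```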
MNIST
| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | mnist | no pruning done - random weights | 0.121 |
| n/a | mnist | source model trained for 100 epochs | 0.936 |
| 20% | mnist | large_final | 0.760 |
| 20% | mnist | smallest_weights | 0.737 |
| 20% | mnist | smallest_weights_global | 0.722 |
| 55.78% | mnist | large_final | 0.911 |
| 55.78% | mnist | smallest_weights | 0.899 |
| 55.78% | mnist | smallest_weights_global | 0.920 |
| 89.6% | mnist | large_final | 0.744 |
| 89.6% | mnist | smallest_weights | 0.703 |
| 89.6% | mnist | smallest_weights_global | 0.925 |
| 99.33% | mnist | large_final | 0.176 |
| 99.33% | mnist | smallest_weights | 0.164 |
| 99.33% | mnist | smallest_weights_global | 0.098 |
CIFAR
| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | cifar10 | no pruning done - random weights | 0.094 |
| n/a | cifar10 | source model trained for 100 epochs | 0.424 |
| 20% | cifar10 | large_final | 0.232 |
| 20% | cifar10 | smallest_weights | 0.180 |
| 20% | cifar10 | smallest_weights_global | 0.201 |
| 55.78% | cifar10 | large_final | 0.192 |
| 55.78% | cifar10 | smallest_weights | 0.240 |
| 55.78% | cifar10 | smallest_weights_global | 0.251 |
| 89.6% | cifar10 | large_final | 0.101 |
| 89.6% | cifar10 | smallest_weights | 0.102 |
| 89.6% | cifar10 | smallest_weights_global | 0.240 |
| 99.33% | cifar10 | large_final | 0.100 |
| 99.33% | cifar10 | smallest_weights | 0.099 |
| 99.33% | cifar10 | smallest_weights_global | 0.100 |
Working In This Repo
The information in this section is only needed if you need to modify this package.
This repo uses GitHub Actions to run continuous integration checks and tests for every push and pull request.
Likewise, when a new release is tagged, a new version of the package is automatically built and uploaded to PyPI.
Local Testing
Running unit tests locally is done via tox. This automatically generates a code coverage report too.
```
tox
```
FAQ
Q: The two papers cited above refer to more pruning strategies than are implemented here. When will you support the XXX pruning strategy?
A: The goal of this repo is to provide an implementation of the more effective strategies described in the two papers. If other effective strategies are developed, pull requests implementing them are welcome.
Q: Why isn't python 3.5 supported?
A: keras>=2.1.0 and pandas>=1.0 don't support Python 3.5, hence neither does this package.
Contributing
Pull requests to this repo are always welcome.