
Enables pruning of Keras DNNs using "lottery ticket" pruning

Project description

Lottery Ticket Pruner

Deep Neural Networks (DNNs) can often benefit from "pruning" some weights in the network, turning dense matrices of weights into sparse matrices of weights with little or no loss in accuracy of the overall model.

This is a Keras implementation of the most relevant pruning strategies outlined in two papers: "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks" (Frankle & Carbin, 2019) and "Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask" (Zhou et al., 2019).

The pruning strategies implemented in this package can reduce the number of non-zero weights of CNNs and DNNs by 40-98% with negligible loss in accuracy of the final model. Various techniques like MorphNet can then be applied to further optimize these now sparse models to decrease model size and/or inference times.
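
As a rough way to verify the sparsity of a pruned model, the fraction of exactly-zero weights can be counted directly. A minimal sketch, assuming model is any Keras model that has been pruned:

import numpy as np

def fraction_zero_weights(model):
    # Count how many weights in a Keras model are exactly zero (a rough sparsity check)
    weights = np.concatenate([w.flatten() for w in model.get_weights()])
    return np.sum(weights == 0) / weights.size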

Installation

pip install lottery-ticket-pruner

Usage

A typical use of the code in this repo looks something like this:

from lottery_ticket_pruner import LotteryTicketPruner, PrunerCallback

model = <create model with random initial weights>
# Save the initial weights of the model so we can start pruning training from them later
initial_weights = model.get_weights()
# Initialize pruner so it knows the starting initial (random) weights
pruner = LotteryTicketPruner(model)
...
# Train the model
model.fit(X, y)
...
# Tell the pruner which trained weights the prune mask should be based on
pruner.set_pretrained_weights(model)
...
# Revert model so it has random initial weights
model.set_weights(initial_weights)
# Compute a prune mask that prunes 50% of the weights using the 'large_final' strategy
pruner.calc_prune_mask(model, 0.5, 'large_final')
# Evaluate the untrained model, then train it with pruning applied via the callback
untrained_loss, untrained_accuracy = model.evaluate(x_test, y_test)
model.fit(X, y, callbacks=[PrunerCallback(pruner)])
trained_loss, trained_accuracy = model.evaluate(x_test, y_test)
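
The snippet above is an outline. Below is a minimal, self-contained sketch that fills in the placeholders purely for illustration: only the LotteryTicketPruner and PrunerCallback calls come from this package, while the toy model, data and hyperparameters are assumptions, so adapt them to your own model and data.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical

from lottery_ticket_pruner import LotteryTicketPruner, PrunerCallback

# Toy data, assumed purely for illustration
X = np.random.rand(256, 20)
y = to_categorical(np.random.randint(0, 3, size=256), 3)
x_test = np.random.rand(64, 20)
y_test = to_categorical(np.random.randint(0, 3, size=64), 3)

# Small model with random initial weights (architecture is an assumption)
model = Sequential([
    Dense(64, activation='relu', input_shape=(20,)),
    Dense(3, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Save the initial weights and register them with the pruner
initial_weights = model.get_weights()
pruner = LotteryTicketPruner(model)

# Train, then tell the pruner which trained weights the prune mask should be based on
model.fit(X, y, epochs=5, verbose=0)
pruner.set_pretrained_weights(model)

# Revert to the initial weights, prune 50% using 'large_final', and retrain with pruning
model.set_weights(initial_weights)
pruner.calc_prune_mask(model, 0.5, 'large_final')
model.fit(X, y, epochs=5, verbose=0, callbacks=[PrunerCallback(pruner)])
print(model.evaluate(x_test, y_test, verbose=0))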

For a full working example that computes the accuracy of an untrained model that has been pruned, as well as training a model from scratch using lottery ticket pruning, see the example code. The example code uses the MNIST and CIFAR10 datasets.

Results

examples/example.sh was run to see the effects of pruning at 20%, 55.78%, 89.6% and 99.33% using the three supported pruning strategies across the MNIST and CIFAR10 datasets. Training was capped at 100 epochs to help control AWS expenses.

The results averaged across 3 iterations:

MNIST (100 epochs)

| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | mnist | n/a | 0.937 |
| 20% | mnist | large_final | 0.935 |
| 20% | mnist | smallest_weights | 0.936 |
| 20% | mnist | smallest_weights_global | 0.939 |
| 55.78% | mnist | large_final | 0.936 |
| 55.78% | mnist | smallest_weights | 0.936 |
| 55.78% | mnist | smallest_weights_global | 0.939 |
| 89.6% | mnist | large_final | 0.936 |
| 89.6% | mnist | smallest_weights | 0.937 |
| 89.6% | mnist | smallest_weights_global | 0.939 |
| 99.33% | mnist | large_final | 0.936 |
| 99.33% | mnist | smallest_weights | 0.937 |
| 99.33% | mnist | smallest_weights_global | 0.939 |

CIFAR (100 epochs)

| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | cifar10 | n/a | 0.427 |
| 20% | cifar10 | large_final | 0.298 |
| 20% | cifar10 | smallest_weights | 0.427 |
| 20% | cifar10 | smallest_weights_global | 0.423 |
| 55.78% | cifar10 | large_final | 0.294 |
| 55.78% | cifar10 | smallest_weights | 0.427 |
| 55.78% | cifar10 | smallest_weights_global | 0.424 |
| 89.6% | cifar10 | large_final | 0.289 |
| 89.6% | cifar10 | smallest_weights | 0.427 |
| 89.6% | cifar10 | smallest_weights_global | 0.424 |
| 99.33% | cifar10 | large_final | 0.288 |
| 99.33% | cifar10 | smallest_weights | 0.428 |
| 99.33% | cifar10 | smallest_weights_global | 0.425 |

CIFAR (500 epochs)

| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | cifar10 | n/a | 0.550 |
| 20% | cifar10 | smallest_weights_global | 0.550 |
| 55.78% | cifar10 | smallest_weights_global | 0.552 |
| 89.6% | cifar10 | smallest_weights_global | 0.554 |
| 99.33% | cifar10 | smallest_weights_global | 0.554 |

Pruning the initial model weights with no training

One of the surprising findings of these papers is that if we simply run inference with the model's original random weights, with no training but with pruning applied, the resulting models perform far (far!) better than a random guess. Here are the results of inference done after applying pruning to the model's random initial weights without any training. The source model used as input to the pruning logic was trained for 100 epochs.
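
Conceptually, that experiment amounts to computing a prune mask from the trained weights, applying the mask to the untouched random initial weights, and running inference only. The sketch below illustrates the idea with plain NumPy and Keras; it is not this package's API (the pruner handles mask computation and application for you), and it reuses the model, initial_weights, x_test and y_test names from the usage example above.

import numpy as np

# Conceptual sketch only - not this package's API. 'model' is the trained model from the
# usage example above, 'initial_weights' is its saved random initialization.
prune_pct = 0.5  # fraction of weights to prune (assumed for illustration)

trained_weights = model.get_weights()
masked_initial_weights = []
for w_init, w_trained in zip(initial_weights, trained_weights):
    # Keep the weights whose trained magnitude is largest; zero out the rest.
    # (Applied to every weight array here for simplicity, including biases.)
    threshold = np.quantile(np.abs(w_trained), prune_pct)
    masked_initial_weights.append(np.where(np.abs(w_trained) >= threshold, w_init, 0.0))

# Inference with pruned, untrained weights - accuracy is well above a random guess
model.set_weights(masked_initial_weights)
loss, accuracy = model.evaluate(x_test, y_test)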

MNIST

| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | mnist | no pruning done - random weights | 0.121 |
| n/a | mnist | source model trained for 100 epochs | 0.936 |
| 20% | mnist | large_final | 0.760 |
| 20% | mnist | smallest_weights | 0.737 |
| 20% | mnist | smallest_weights_global | 0.722 |
| 55.78% | mnist | large_final | 0.911 |
| 55.78% | mnist | smallest_weights | 0.899 |
| 55.78% | mnist | smallest_weights_global | 0.920 |
| 89.6% | mnist | large_final | 0.744 |
| 89.6% | mnist | smallest_weights | 0.703 |
| 89.6% | mnist | smallest_weights_global | 0.925 |
| 99.33% | mnist | large_final | 0.176 |
| 99.33% | mnist | smallest_weights | 0.164 |
| 99.33% | mnist | smallest_weights_global | 0.098 |

CIFAR

| Prune Percentage | Dataset | Prune Strategy | Avg Accuracy |
|:---|:---|:---:|:---:|
| n/a | cifar10 | no pruning done - random weights | 0.094 |
| n/a | cifar10 | source model trained for 100 epochs | 0.424 |
| 20% | cifar10 | large_final | 0.232 |
| 20% | cifar10 | smallest_weights | 0.180 |
| 20% | cifar10 | smallest_weights_global | 0.201 |
| 55.78% | cifar10 | large_final | 0.192 |
| 55.78% | cifar10 | smallest_weights | 0.240 |
| 55.78% | cifar10 | smallest_weights_global | 0.251 |
| 89.6% | cifar10 | large_final | 0.101 |
| 89.6% | cifar10 | smallest_weights | 0.102 |
| 89.6% | cifar10 | smallest_weights_global | 0.240 |
| 99.33% | cifar10 | large_final | 0.100 |
| 99.33% | cifar10 | smallest_weights | 0.099 |
| 99.33% | cifar10 | smallest_weights_global | 0.100 |

Working In This Repo

The information in this section is only needed if you need to modify this package.

This repo uses GitHub Actions to perform Continuous Integration checks and run tests for each push and pull request.

Likewise, when a new release is tagged, a new version of the package is automatically built and uploaded to PyPI.

Local Testing

Running the unit tests locally is done via tox, which also generates a code coverage report:

tox

FAQ

Q: The two papers cited above refer to more pruning strategies than are implemented here. When will you support the XXX pruning strategy?

A: The goal of this repo is to provide an implementation of the more effective strategies described by the two papers. If other effective strategies are developed, pull requests implementing them are welcome.

Q: Why isn't python 3.5 supported?

A: keras>=2.1.0 and pandas>=1.0 don't support Python 3.5, so this package doesn't either.

Contributing

Pull requests to this repo are always welcome.

