
Differentiable quantization framework for PyTorch -- fixed for compatibility with Python 3.11+

Project description

Differentiable Model Compression via Pseudo Quantization Noise


DiffQ performs differentiable quantization using pseudo quantization noise. It can automatically tune the number of bits used per weight or group of weights, in order to achieve a given trade-off between model size and accuracy.
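
To build intuition, here is a minimal sketch of the idea (an illustration only, not DiffQ's internal code): during training, each weight is perturbed with uniform noise whose scale matches the rounding error of b-bit quantization, so the number of bits can stay a float and receive gradients.

import torch

def pseudo_quant_noise(weight: torch.Tensor, bits: torch.Tensor) -> torch.Tensor:
    # Step size of a uniform quantizer with `bits` bits over the weight range.
    delta = (weight.max() - weight.min()) / (2 ** bits - 1)
    # Uniform noise with the same scale as the rounding error: unlike rounding,
    # the result stays differentiable with respect to both `weight` and `bits`.
    noise = torch.empty_like(weight).uniform_(-0.5, 0.5)
    return weight + delta * noise

In the actual library, quantization operates per weight or per group of weights, and the bits parameters are learned jointly with the model, as shown in the Usage section below.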

Go read our paper for more details.

What's up?

See the changelog for details on releases.

  • 2022-08-24: v0.2.3: fixed a bug when loading old quantized states.
  • 2021-11-25: v0.2.2: added support for TorchScript.

Requirements

DiffQ requires Python 3.7 or newer and a reasonably recent version of PyTorch (1.7.1 or later). To install DiffQ, run the following from the root of the repository:

pip install .

You can also install this fork directly from PyPI with pip install diffq_fixed.

Usage

import torch
from torch.nn import functional as F
import diffq
from diffq import DiffQuantizer

model = MyModel()
optim = ...  # The optimizer must be created before the quantizer
quantizer = DiffQuantizer(model)
quantizer.setup_optimizer(optim)

# Distributed data parallel must be created after the DiffQuantizer!
dmodel = torch.nn.parallel.DistributedDataParallel(...)

penalty = 1e-3
model.train()  # call model.eval() at eval time to automatically use the true quantized weights.
for batch in loader:
    ...  # forward pass producing predictions x and targets y
    optim.zero_grad()

    # The `penalty` parameter controls the trade-off between model size and model accuracy.
    loss = F.mse_loss(x, y) + penalty * quantizer.model_size()
    loss.backward()
    optim.step()

# Get the true model size when doing proper bit packing.
print(f"Model is {quantizer.true_model_size():.1f} MB")

# When you want to dump your final model:
torch.save(quantizer.get_quantized_state(), "some_file.th")

# You can later load back the model with
model = MyModel()
diffq.restore_quantized_state(model, torch.load("some_file.th"))

# For DiffQ models, we support exporting the model to TorchScript with optimal storage.
# Once loaded, the model is stored in fp32 in memory (int8 support is planned).
from diffq.ts_export import export
export(quantizer, 'quantized.ts')

Documentation

See the API documentation for a detailed reference. We cover a few key aspects below.

Quantizer object

A Quantizer is attached to a model at its creation. All Quantizer objects provide the same basic capabilities:

  • automatically switching to quantized weights in the forward pass when the model is in eval mode (see the example after this list);
  • running quantizer-specific code in the training forward pass (e.g. the straight-through estimator for UniformQuantizer with QAT, noise injection for DiffQ);
  • providing access to the quantized model size and state.
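
For instance, the eval-mode switch means benchmarking the quantized model requires no extra code. A short illustration, reusing MyModel from the Usage section (the input shape is hypothetical):

model = MyModel()
quantizer = DiffQuantizer(model)
x = torch.randn(1, 16)  # dummy input; the shape depends on MyModel

model.train()  # training forward: pseudo quantization noise is injected
y_noisy = model(x)

model.eval()  # eval forward: the true quantized weights are used automatically
with torch.no_grad():
    y_quant = model(x)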

Quantized size and state

The method quantizer.model_size() provides a differentiable model size (for DiffQ), while quantizer.true_model_size() provides the true, optimally bit-packed, model size (non differentiable). With quantizer.compressed_model_size() you can get the model size under gzip compression. This can actually be larger than the true model size, and reveals interesting information on the entropy usage of a specific quantization method.
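
Side by side, assuming the quantizer from the Usage section (sizes in MB, as in the example above):

size_mb = quantizer.model_size()  # differentiable, usable inside the training loss
print(f"estimate: {float(size_mb):.2f} MB")
print(f"true:     {quantizer.true_model_size():.2f} MB")  # exact bit-packed size
print(f"gzip:     {quantizer.compressed_model_size():.2f} MB")  # may exceed the true size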

The bit-packed quantized state is obtained with quantizer.get_quantized_state(), and restored with quantizer.restore_quantized_state(). Bit packing is optimized for speed and can suffer from some overhead (in practice no more than 120 B for Uniform and LSQ, and no more than 1 kB for DiffQ).

If you do not have access to the original quantizer, for instance at inference time, you can load the state with diffq.restore_quantized_state(model, quantized_state).

Quantizer and optimization

Some quantizers (DiffQuantizer and LSQ) add extra optimizable parameters. Those parameters can require different optimizers or hyper-parameters than the main model weights. Typically, DiffQ bits parameters are always optimized with Adam. For that reason, you should always create the main optimizer before the quantizer. You can then set up the quantizer with this optimizer or another:

model = MyModel(...)
opt = torch.optim.Adam(model.parameters())
quantizer = diffq.DiffQuantizer(model)
quantizer.setup_optimizer(opt, **optim_overrides)  # optional hyper-parameter overrides for the quantizer parameters

This offers the freedom to use separate hyper-parameters. For instance, DiffQuantizer always deactivates weight_decay for the bits parameters.

If the main optimizer is SGD, it is advised to use a second Adam optimizer for the quantizer, as sketched below.
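
A hedged sketch of that setup (loader, compute_loss and penalty are placeholders; the throwaway tensor is only there because torch.optim.Adam rejects an empty parameter list, after which setup_optimizer registers the actual bits parameters with the optimizer):

model = MyModel(...)
main_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

quantizer = diffq.DiffQuantizer(model)
bits_opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
quantizer.setup_optimizer(bits_opt)

for batch in loader:
    main_opt.zero_grad()
    bits_opt.zero_grad()
    loss = compute_loss(batch) + penalty * quantizer.model_size()
    loss.backward()
    main_opt.step()
    bits_opt.step()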

Warning: you must always wrap your model with DistributedDataParallel after having created the quantizer, otherwise the quantizer parameters won't be optimized!

TorchScript support

At the moment, TorchScript support is experimental. We support saving the model with TorchScript to disk with optimal storage. Once loaded, the model is stored in fp32 in memory. We are working towards adding support for int8 in memory. See the diffq.ts_export.export function in the API documentation.
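
Loading the exported file back should only require plain TorchScript (assuming the export produces a standard TorchScript archive):

import torch

ts_model = torch.jit.load("quantized.ts")  # weights are decompressed to fp32 in memory
ts_model.eval()
out = ts_model(torch.randn(1, 16))  # dummy input; the shape depends on your model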

Examples

We provide three examples in the examples/ folder. The first is for CIFAR-10/100, using standard architectures such as Wide-ResNet, ResNet or MobileNet. The second is based on the DeiT vision transformer. The third is a language modeling task on WikiText-103, using Fairseq.

The DeiT and Fairseq examples are provided as patches on the original codebases at specific commits. You can initialize the git submodules and apply the patches by running

make examples

For more details on each example, check out its specific README.

Installation for development

This will install diffq in developer mode (changes to the files are directly reflected), along with the dependencies needed to run the unit tests:

pip install -e '.[dev]'

Updating the patch based examples

To update the patches:

  • Run make examples to properly initialize the sub-repositories.
  • Make the changes you want in the submodules and commit them there.
  • Run make patches to regenerate the patch file for each repo.
  • Check that all your changes are included in the new patch files, then run make reset (this removes all your changes from the submodules, so do verify the patch files first).
  • Finally, run git add -u .; git commit -m "my changes" and push.

Test

You can run the unit tests with

make tests

Citation

If you use this code or results in your paper, please cite our work as:

@article{defossez2021differentiable,
  title={Differentiable Model Compression via Pseudo Quantization Noise},
  author={D{\'e}fossez, Alexandre and Adi, Yossi and Synnaeve, Gabriel},
  journal={TMLR},
  year={2022}
}

License

This repository is released under the CC-BY-NC 4.0 license, as found in the LICENSE file, except for the following parts, which are under the MIT license. The files examples/cifar/src/mobilenet.py and examples/cifar/src/resnet.py are taken from kuangliu/pytorch-cifar, released as MIT. The file examples/cifar/src/wide_resnet.py is taken from meliketoy/wide-resnet, released as MIT. See each file's header for the detailed license.


Download files

Download the file for your platform.

Source Distribution

  • diffq_fixed-0.2.4.tar.gz (189.7 kB): Source

Built Distributions

  • diffq_fixed-0.2.4-cp312-cp312-win_amd64.whl (106.2 kB): CPython 3.12, Windows x86-64
  • diffq_fixed-0.2.4-cp312-cp312-manylinux_2_28_x86_64.whl (569.2 kB): CPython 3.12, manylinux glibc 2.28+ x86-64
  • diffq_fixed-0.2.4-cp311-cp311-win_amd64.whl (106.3 kB): CPython 3.11, Windows x86-64
  • diffq_fixed-0.2.4-cp311-cp311-manylinux_2_28_x86_64.whl (575.7 kB): CPython 3.11, manylinux glibc 2.28+ x86-64
  • diffq_fixed-0.2.4-cp310-cp310-win_amd64.whl (106.2 kB): CPython 3.10, Windows x86-64
  • diffq_fixed-0.2.4-cp310-cp310-manylinux_2_28_x86_64.whl (536.9 kB): CPython 3.10, manylinux glibc 2.28+ x86-64

File details

Details for the file diffq_fixed-0.2.4.tar.gz.

File metadata

  • Download URL: diffq_fixed-0.2.4.tar.gz
  • Size: 189.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.12.7

File hashes

Hashes for diffq_fixed-0.2.4.tar.gz:

  • SHA256: cbc906b76fa23d1cf3c0ae517fbab744d9624980a068fe7fbb00dede1d83208d
  • MD5: a68a7f4c76d5509434d624d764dd5c25
  • BLAKE2b-256: be968ca5acf5ecfd4108aa6f345cef171f2fda0c081cf0e7430671712586f172


File details

Details for the file diffq_fixed-0.2.4-cp312-cp312-win_amd64.whl.

File hashes

Hashes for diffq_fixed-0.2.4-cp312-cp312-win_amd64.whl:

  • SHA256: 81087279979db570723ca0911265635d7c3132cc8b74bc52bdea94881440ce96
  • MD5: 82660fbd0f0894fd2a91f396c2b948f0
  • BLAKE2b-256: 331494327d99f551136ac95bdc9697a1738f78ab3bb63afe8ce58b2b9854cd3a


File details

Details for the file diffq_fixed-0.2.4-cp312-cp312-manylinux_2_28_x86_64.whl.

File hashes

Hashes for diffq_fixed-0.2.4-cp312-cp312-manylinux_2_28_x86_64.whl:

  • SHA256: 5aea342125e12b9388f23a4f5bf40cf9ea3fc3888d755e848864546b104443ab
  • MD5: b6c8c9ace51f74bcab2cabcf2820f211
  • BLAKE2b-256: f7ea9d831d1fd2c78432282d8e876d6896bb4f098637ce6e48afc2f0ee8775eb


File details

Details for the file diffq_fixed-0.2.4-cp311-cp311-win_amd64.whl.

File hashes

Hashes for diffq_fixed-0.2.4-cp311-cp311-win_amd64.whl:

  • SHA256: 67798dccc9b4f46196071953d52fafbef0751c5d0acb064e0883c7a8509d84f9
  • MD5: ab29aca8d87716a57721a41333f98511
  • BLAKE2b-256: 4c70a11838773e1c2bd9fa47d5016a90b7584b4052179407fb934eb89499552e


File details

Details for the file diffq_fixed-0.2.4-cp311-cp311-manylinux_2_28_x86_64.whl.

File hashes

Hashes for diffq_fixed-0.2.4-cp311-cp311-manylinux_2_28_x86_64.whl:

  • SHA256: 6cc6c47820b7823fda855f8793ff48bc33a3d9f0f0e7dee6e430dade25085c9a
  • MD5: 5b137bdb0abdb05d1f0d37f12525b59a
  • BLAKE2b-256: 87d04a763200110bbc04aa4599f6163da7e486e0855f81404b1225839fb515c3


File details

Details for the file diffq_fixed-0.2.4-cp310-cp310-win_amd64.whl.

File hashes

Hashes for diffq_fixed-0.2.4-cp310-cp310-win_amd64.whl:

  • SHA256: d6fc81bd6ca17ebc04b8a163f9efae5afedf3274a4eee473f6742471a8d904bc
  • MD5: 4e9573fcfb8fbb88528bf6dc410c71ef
  • BLAKE2b-256: 0fa9b3bf2493447a7492c4208ee9003514017e9ee009798df34abada3666211b


File details

Details for the file diffq_fixed-0.2.4-cp310-cp310-manylinux_2_28_x86_64.whl.

File hashes

Hashes for diffq_fixed-0.2.4-cp310-cp310-manylinux_2_28_x86_64.whl:

  • SHA256: c911cda65ddebb672aaa8d80d902259ffc4ee469ce231021544093d83372ed15
  • MD5: bb6849520b249518683161cb9d5a2149
  • BLAKE2b-256: 0a1a4e6981b6ac4b876526e354bae41a27f704508580d74bac93c8be56fe7804

