

torchac: Fast Arithmetic Coding for PyTorch

About

This is a simplified version of the arithmetic coder we used in the
neural compression paper "Practical Full Resolution Learned Lossless Image
Compression", which lives in the L3C-Pytorch repo.

In particular, we removed the L3C-specific parts, which relied on CUDA
compilation and were tricky to get going.

The implementation is based on this blog post,
meaning that we implement arithmetic coding.

While it could be further optimized, it is already much faster than doing the equivalent thing in pure Python (because of all the
bit shifts etc.). In L3C, encoding an entire 512 x 512 image takes 0.202 s (see Appendix A in the paper).

What torchac is

  • A simple library to encode a stream of symbols into a bitstream given
    the cumulative distribution of the symbols.

  • The number of possible symbols must be finite.
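To illustrate what "encoding a stream of symbols given a CDF" involves, here is a minimal pure-Python sketch of an integer arithmetic coder in the classic Witten-Neal-Cleary style, working on a single 16-bit integer CDF that starts at 0 and ends at 2**16. This is not torchac's code (torchac does this in C++ over whole tensors); the names encode and decode, the 32-bit register width and the list-of-bits output are assumptions made for illustration. The per-bit renormalization loops are exactly the kind of work that makes a pure-Python coder slow.

```python
import bisect

PRECISION = 16                    # integer CDFs run from 0 to 2**PRECISION
STATE_BITS = 32                   # width of the low/high registers (assumed)
FULL = (1 << STATE_BITS) - 1
HALF = 1 << (STATE_BITS - 1)
QUARTER = 1 << (STATE_BITS - 2)


def _narrow(low, high, cdf, s):
    """Shrink [low, high] to the sub-interval that belongs to symbol s."""
    width = high - low + 1
    return (low + (width * cdf[s] >> PRECISION),
            low + (width * cdf[s + 1] >> PRECISION) - 1)


def encode(symbols, cdf):
    """Encode symbols (ints in 0..len(cdf)-2) into a list of bits."""
    low, high, pending, bits = 0, FULL, 0, []

    def emit(bit):
        nonlocal pending
        bits.append(bit)
        bits.extend([1 - bit] * pending)  # bits deferred by the middle case
        pending = 0

    for s in symbols:
        low, high = _narrow(low, high, cdf, s)
        while True:  # shift out leading bits that can no longer change
            if high < HALF:
                emit(0)
            elif low >= HALF:
                emit(1)
                low, high = low - HALF, high - HALF
            elif low >= QUARTER and high < 3 * QUARTER:
                pending += 1  # interval straddles the midpoint: decide later
                low, high = low - QUARTER, high - QUARTER
            else:
                break
            low, high = 2 * low, 2 * high + 1
    pending += 1  # two final bits select a point inside [low, high]
    emit(0 if low < QUARTER else 1)
    return bits


def decode(bits, cdf, num_symbols):
    """Decode num_symbols symbols from bits using the same CDF."""
    stream = iter(bits)

    def next_bit():
        return next(stream, 0)  # pad with zeros past the end of the stream

    low, high, value = 0, FULL, 0
    for _ in range(STATE_BITS):
        value = (value << 1) | next_bit()
    out = []
    for _ in range(num_symbols):
        width = high - low + 1
        count = (((value - low + 1) << PRECISION) - 1) // width
        s = bisect.bisect_right(cdf, count) - 1  # cdf[s] <= count < cdf[s + 1]
        out.append(s)
        low, high = _narrow(low, high, cdf, s)
        while True:  # mirror the encoder's renormalization
            if high < HALF:
                pass
            elif low >= HALF:
                low, high, value = low - HALF, high - HALF, value - HALF
            elif low >= QUARTER and high < 3 * QUARTER:
                low, high, value = low - QUARTER, high - QUARTER, value - QUARTER
            else:
                break
            low, high = 2 * low, 2 * high + 1
            value = (value << 1) | next_bit()
    return out


cdf = [0, 8192, 40960, 65536]          # P(0) = 1/8, P(1) = 1/2, P(2) = 3/8
message = [1, 0, 2, 2, 1, 1, 0, 2, 1, 1]
bits = encode(message, cdf)
assert decode(bits, cdf, len(message)) == message
```

Note that this decoder has to be told how many symbols to read; with torchac, that number follows from the shape of the CDF passed to the decoder.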

What torchac is not

  • We do not provide classes to learn or represent probability/cumulative
    distributions. These have to be provided by you.

HowTo

Set up conda environment

This library has been tested with

  • PyTorch 1.7

  • Python 3.8

That's all you need. Other versions may also work.

If you don't have an environment set up, you can make one with conda:

# We use Python 3.8; other versions may also work.
conda create --name <YOUR_ENV_NAME> python==3.8
conda activate <YOUR_ENV_NAME>

# Installing PyTorch: find the conda command for your system at https://pytorch.org

Test installation

To (optionally) test your installation, you need pytest:

# If you don't have pytest
pip install pytest

# Run tests
python -m pytest test.py -s

Output should end in something like:

===== 5 passed, 2 warnings in 0.95s =========

Example

The examples/ folder contains an example for training an auto-encoder on MNIST.

[Figure: Output of the example script. The first two columns show the training set, the next two columns show the test set.]

FAQ

1. Output is not equal to the input

Either the normalization went wrong, or you encoded a symbol that is larger than Lp-2 (symbols must be in {0, ..., Lp-2}, see below).
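A quick way to narrow down which of the two causes applies is to check the inputs before encoding. The following is a hypothetical helper (not part of torchac) that assumes a single integer CDF given as plain Python ints whose last value is 2**16, rather than the int16 representation described further down:

```python
def find_input_problems(cdf, symbols):
    """Hypothetical helper: list problems that make decoding silently fail.

    cdf is one integer CDF of length Lp (unsigned values, last one 2**16);
    symbols is an iterable of ints to be encoded with this CDF.
    """
    Lp = len(cdf)
    problems = []
    for i in range(1, Lp):
        if cdf[i] <= cdf[i - 1]:  # a zero-width symbol: CDF is unnormalized
            problems.append(f"cdf not strictly increasing at index {i}")
    for s in symbols:
        if not 0 <= s <= Lp - 2:  # valid symbols are 0, ..., Lp-2
            problems.append(f"symbol {s} outside [0, {Lp - 2}]")
    return problems
```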

Important Implementation Details

How we represent probability distributions

The probabilities are specified as CDFs. For each possible symbol,
we need 2 CDF values (a lower and an upper bound), and neighboring symbols
share a bound. This means that if there are L possible symbols
{0, ..., L-1}, the CDF must specify L+1 values.

Example:

Let's say we have L = 3 possible symbols. We need a CDF with 4 values
to specify the symbol distribution:

symbol:        0     1     2
cdf:       C_0   C_1   C_2   C_3

This corresponds to the 3 probabilities

P(0) = C_1 - C_0
P(1) = C_2 - C_1
P(2) = C_3 - C_2

NOTE: The arithmetic coder assumes that C_3 == 1.
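The relationship between probabilities and CDF values above can be written as a pair of small helpers. This is a sketch with hypothetical names (pmf_to_cdf, cdf_to_pmf), using exact fractions so that the round trip is exact:

```python
from fractions import Fraction
from itertools import accumulate


def pmf_to_cdf(pmf):
    """[P(0), ..., P(L-1)] -> [C_0, ..., C_L], with C_0 = 0 (L + 1 values)."""
    return list(accumulate(pmf, initial=0))


def cdf_to_pmf(cdf):
    """[C_0, ..., C_L] -> [P(0), ..., P(L-1)], with P(i) = C_{i+1} - C_i."""
    return [hi - lo for lo, hi in zip(cdf, cdf[1:])]


pmf = [Fraction(1, 2), Fraction(1, 4), Fraction(1, 4)]  # L = 3 symbols
cdf = pmf_to_cdf(pmf)        # 4 values: 0, 1/2, 3/4, 1
assert cdf_to_pmf(cdf) == pmf
```

The accumulate(..., initial=0) form requires Python 3.8 or newer, which matches the tested Python version.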

Important:

  • If you have L possible symbols, you need to pass a CDF that
    specifies L + 1 values. Since this is a common number, we call it
    Lp = L + 1 throughout the code (the "p" stands for prime, i.e., L').

  • The last value of the CDF should be 1. Note that the arithmetic coder
    in torchac.cpp will just assume it's 1 regardless of what is passed, so if your CDF
    does not end in 1, bitrates estimated from it will be wrong. More details below.

  • Note that even though the CDF specifies Lp values, symbols are only allowed
    to be in {0, ..., Lp-2}. In the above example, Lp == 4, but the
    maximum symbol is Lp-2 == 2. Bigger values will yield wrong outputs.
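As an illustration of why a CDF that does not end in 1 leads to wrong bitrate estimates, here is a sketch (the helpers estimate_bits and as_seen_by_coder are hypothetical, not torchac functions) that sums the ideal code length -log2 P(s) over the symbols, once from the CDF as passed and once from the CDF as the coder treats it, with the last value replaced by 1:

```python
import math


def estimate_bits(cdf, symbols):
    """Ideal code length in bits: sum of -log2 P(s), with P(s) = C_{s+1} - C_s."""
    return sum(-math.log2(cdf[s + 1] - cdf[s]) for s in symbols)


def as_seen_by_coder(cdf):
    """The coder ignores the passed last value and treats it as 1."""
    return list(cdf[:-1]) + [1.0]


symbols = [0, 1, 2, 3]                     # L = 4 symbols, so Lp = 5
good = [0.0, 0.25, 0.5, 0.75, 1.0]
bad = [0.0, 0.125, 0.25, 0.375, 0.5]       # does not end in 1
print(estimate_bits(good, symbols))                    # 8.0
print(estimate_bits(bad, symbols))                     # 12.0 (our estimate)
print(estimate_bits(as_seen_by_coder(bad), symbols))   # about 9.68 (coder)
```

For the bad CDF, the last symbol effectively gets all of the missing mass, so the estimate and the actual cost disagree.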

Expected input shapes

We allow any shapes for the inputs, but the spatial dimensions of the
input CDF and the input symbols must match. In particular, we expect:

  • CDF must have shape (N1, ..., Nm, Lp), where N1, ..., Nm are the
    m spatial dimensions, and Lp is as described above.

  • Symbols must have shape (N1, ..., Nm), i.e., the same spatial dimensions
    as the CDF.

For example, in a typical CNN, you might have a CDF of shape
(batch, channels, height, width, Lp) and symbols of shape
(batch, channels, height, width).
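A minimal shape check matching these rules, written against plain tuples so that it works with any array library. check_shapes is a hypothetical helper, not a torchac function:

```python
def check_shapes(cdf_shape, sym_shape):
    """Validate input shapes and return Lp.

    cdf_shape must be (N1, ..., Nm, Lp) and sym_shape must be (N1, ..., Nm).
    """
    cdf_shape, sym_shape = tuple(cdf_shape), tuple(sym_shape)
    if len(cdf_shape) == 0 or cdf_shape[-1] < 2:
        raise ValueError(f"CDF needs a last dimension Lp >= 2, got {cdf_shape}")
    if cdf_shape[:-1] != sym_shape:
        raise ValueError(
            f"spatial dims differ: CDF {cdf_shape[:-1]} vs symbols {sym_shape}")
    return cdf_shape[-1]


Lp = check_shapes((8, 3, 32, 32, 257), (8, 3, 32, 32))  # typical CNN layout
```

With PyTorch tensors one could call check_shapes(tuple(cdf.shape), tuple(sym.shape)); the number of encoded symbols is then the product of N1, ..., Nm.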

Normalized vs. Unnormalized / Floating Point vs. Integer CDFs

The library differentiates between "normalized" and "unnormalized" CDFs,
and between "floating point" and "integer" CDFs. What do these mean?

  • A CDF in which every symbol has non-zero probability is strictly
    monotonically increasing, and we call this a "normalized" CDF.

  • However, since we work with finite precision (16 bits, to be precise,
    in this implementation), it may be that you have a CDF that is strictly
    monotonically increasing in float32 space, but not when it is converted
    to 16-bit precision. An "unnormalized" CDF is what we call a CDF that has
    the same value for at least two consecutive elements.

  • "floating point" CDFs are CDFs that are specified as float32 and need
    to be converted to 16-bit precision.

  • "integer" CDFs are CDFs specified as int16 - BUT are then interpreted
    as uint16 on the C++ side. See "int16 vs uint16" below.

Examples:

float_unnormalized_cdf = [0.1, 0.2, 0.2, 0.3, ..., 1.]

float_normalized_cdf = [0.1, 0.2, 0.20001, 0.3, ..., 1.]

integer_unnormalized_cdf = [10, 20, 20, 30, ..., 0]  # See below for why last is 0.

integer_normalized_cdf = [10, 20, 21, 30, ..., 0]    # See below for why last is 0.
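One simple way to turn a float CDF into a normalized 16-bit integer CDF is to scale it to 2**16 - (Lp - 1), round, and then add the index i to entry i, so that every symbol keeps a width of at least one. The sketch below (hypothetical function float_cdf_to_int; torchac's internal routine may differ in its details) returns unsigned values, so the last entry is 2**16; in the int16 representation described below, that entry is stored as 0.

```python
PRECISION = 16


def float_cdf_to_int(cdf_float, needs_normalization=True):
    """Quantize a float CDF (ending in 1.0) to ints in [0, 2**PRECISION]."""
    Lp = len(cdf_float)
    max_value = 2 ** PRECISION
    if needs_normalization:
        # Reserve Lp - 1 units so that adding i to entry i still ends at 2**16.
        max_value -= Lp - 1
    cdf = [round(c * max_value) for c in cdf_float]
    if needs_normalization:
        # Rounding keeps the CDF non-decreasing; adding i makes it strict.
        cdf = [c + i for i, c in enumerate(cdf)]
    return cdf


print(float_cdf_to_int([0.0, 0.2, 0.2, 1.0], needs_normalization=False))
# [0, 13107, 13107, 65536]  -> unnormalized: symbol 1 has zero width
print(float_cdf_to_int([0.0, 0.2, 0.2, 1.0]))
# [0, 13108, 13109, 65536]  -> normalized
```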

There are two APIs:

  • encode_float_cdf and decode_float_cdf are to be used for floating point
    CDFs. These functions have a flag needs_normalization that specifies
    whether the input may be unnormalized and thus needs to be normalized.
    You can set needs_normalization=False if you have CDFs that you know are
    normalized, e.g., Gaussian distributions with a large enough sigma. This
    speeds up encoding and decoding of large tensors somewhat, and makes
    bitrate estimation from the CDF more precise.

  • encode_int16_normalized_cdf and decode_int16_normalized_cdf are to be
    used for integer CDFs that are already normalized.
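To decide whether needs_normalization=False is safe, you can check whether a float CDF stays strictly increasing after rounding to 16 bits. The sketch below assumes that quantization is a plain round(c * 2**16); discretized_gaussian_cdf and stays_normalized are hypothetical helpers that illustrate the "large enough sigma" remark:

```python
import math

PRECISION = 16


def discretized_gaussian_cdf(mu, sigma, L):
    """Hypothetical helper: CDF over symbols 0..L-1 of N(mu, sigma^2), with
    bin edges at i - 0.5, rescaled so that C_0 = 0 and C_L = 1."""
    def phi(x):
        return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

    lo, hi = phi(-0.5), phi(L - 0.5)
    return [(phi(i - 0.5) - lo) / (hi - lo) for i in range(L + 1)]


def stays_normalized(cdf_float, precision=PRECISION):
    """True if rounding to `precision` bits keeps the CDF strictly increasing."""
    scale = 2 ** precision
    q = [round(c * scale) for c in cdf_float]
    return all(a < b for a, b in zip(q, q[1:]))


wide = discretized_gaussian_cdf(mu=8.0, sigma=4.0, L=16)
narrow = discretized_gaussian_cdf(mu=8.0, sigma=0.1, L=16)
print(stays_normalized(wide), stays_normalized(narrow))  # True False
```

With the narrow Gaussian, the tail symbols get a probability far below 2**-16, so several CDF values collapse onto the same integer.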

int16 vs uint16 - it gets confusing!

One big source of confusion can be that PyTorch does not support uint16.
Yet, that's exactly what we need. So we simply represent integer CDFs
with int16 on the Python side, and interpret/cast them to uint16 on the
C++ side. This means that if you were to look at the int16 CDFs, you
would see confusing things:

# Python
cdf_float = [0., 1/3, 2/3, 1.]  # A uniform distribution for L=3 symbols.
cdf_int = [0, 21845, -21845, 0]

# C++
uint16* cdf_int = [0, 21845, 43691, 0]

Note:

  1. In the Python cdf_int, numbers greater than or equal to 2**15 (= 2**16/2)
     are negative, since the same 16 bits are read as a signed int16.

  2. The final value is actually 0. This is then handled in torchac.cpp, which
     just assumes cdf[..., -1] == 2**16, a value that cannot be represented as a uint16.

Fun stuff!
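The int16/uint16 reinterpretation can be reproduced in pure Python with the struct module, which packs the values as unsigned 16-bit integers and unpacks the same bytes as signed ones. This only illustrates the bit-level cast (the helper names are hypothetical); the first step reduces the final value 2**16 modulo 2**16, which gives the 0 seen above:

```python
import struct


def uint16_to_int16(values):
    """Reinterpret unsigned 16-bit values (0..65535) as signed int16."""
    n = len(values)
    return list(struct.unpack(f"<{n}h", struct.pack(f"<{n}H", *values)))


def int16_to_uint16(values):
    """Reinterpret signed int16 values (-32768..32767) as unsigned 16-bit."""
    n = len(values)
    return list(struct.unpack(f"<{n}H", struct.pack(f"<{n}h", *values)))


# Uniform distribution over L = 3 symbols, scaled to 16 bits and rounded.
cdf_scaled = [round(c * 2**16) for c in (0.0, 1 / 3, 2 / 3, 1.0)]
cdf_uint16 = [v % 2**16 for v in cdf_scaled]  # 2**16 does not fit: becomes 0
cdf_int16 = uint16_to_int16(cdf_uint16)        # what the Python side holds
assert int16_to_uint16(cdf_int16) == cdf_uint16
```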


Download files


Source Distribution

torchac-0.8.2.tar.gz (12.7 kB)

Built Distribution


torchac-0.8.2-py3-none-any.whl (21.8 kB)

File details

Details for the file torchac-0.8.2.tar.gz.

File metadata

  • Download URL: torchac-0.8.2.tar.gz
  • Size: 12.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.8.0

File hashes

Hashes for torchac-0.8.2.tar.gz:
  • SHA256: 70ba0a3b50ccf26adafb06180e6412f0e94547379993f93a810fd521d6e9faf5
  • MD5: 7039140e158507f7047c71d14c9dd09d
  • BLAKE2b-256: cb82a305a34e9eb053c31fee599d64bc4e8667c3637f6562c420a22f1bc78b00


File details

Details for the file torchac-0.8.2-py3-none-any.whl.

File metadata

  • Download URL: torchac-0.8.2-py3-none-any.whl
  • Size: 21.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.6.1 requests/2.25.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.8.0

File hashes

Hashes for torchac-0.8.2-py3-none-any.whl:
  • SHA256: 420040fd6b62b804fdfdb76ee5fe70db7fa62a356db29b0f97ebdf18fc93eee0
  • MD5: 9d8934362770d55a60c3164196c3ad34
  • BLAKE2b-256: 47584c76f7c57600299a41d63cfbb83f2ba43940580256fcaa3a3ac3160d8ea6

