torchac: Fast Arithmetic Coding for PyTorch
About
This is a simplified version of the arithmetic coder we used in the neural compression paper "Practical Full Resolution Learned Lossless Image Compression", which lives in the L3C-Pytorch repo. In particular, we removed the L3C-specific parts, which relied on CUDA compilations and were tricky to get going.
The implementation is based on this blog post; that is, we implement arithmetic coding. While it could be optimized further, it is already much faster than doing the equivalent in pure Python (because of all the bit-shifts etc.). In L3C, encoding an entire 512 x 512 image takes 0.202s (see Appendix A in the paper).
What torchac is
- A simple library to encode a stream of symbols into a bitstream given the cumulative distribution of the symbols.
- The number of possible symbols must be finite.
What torchac is not
- We do not provide classes to learn or represent probability/cumulative distributions. These have to be provided by you.
HowTo
Set up conda environment
This library has been tested with
- PyTorch 1.7
- Python 3.8
And that's all you need. Other versions may also work.
If you don't have an environment set up, you can make one with conda:

```sh
# We use Python 3.8; other versions may be supported.
conda create --name <YOUR_ENV_NAME> python==3.8
conda activate <YOUR_ENV_NAME>

# Install PyTorch: find the conda command for your system at https://pytorch.org
```
Test installation
To (optionally) test your installation, you need pytest:

```sh
# If you don't have pytest:
pip install pytest

# Run the tests:
python -m pytest test.py -s
```

The output should end in something like:

```
===== 5 passed, 2 warnings in 0.95s =========
```
Example
The examples/ folder contains an example for training an auto-encoder on MNIST.

[Figure: output of the example script. The first two columns show the training set, the second two columns show the testing set.]
FAQ
1. Output is not equal to the input: either normalization went wrong, or you encoded a symbol that is greater than `Lp - 2` (see "Important Implementation Details" below).
Important Implementation Details
How we represent probability distributions.
The probabilities are specified as CDFs. For each possible symbol, we need 2 CDF values. This means that if there are `L` possible symbols `{0, ..., L-1}`, the CDF must specify values at `L + 1` points.
Example: let's say we have L = 3 possible symbols. We need a CDF with 4 values to specify the symbol distribution:

```
symbol:   0    1    2
cdf:    C_0  C_1  C_2  C_3
```

This corresponds to the 3 probabilities

```
P(0) = C_1 - C_0
P(1) = C_2 - C_1
P(2) = C_3 - C_2
```
NOTE: The arithmetic coder assumes that C_3 == 1.
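As an illustration, here is a minimal sketch of building such a CDF from a probability vector; the helper `pmf_to_cdf` is hypothetical and not part of torchac:

```python
import torch

# Hypothetical helper (not part of torchac): turn L probabilities into
# a CDF with L + 1 values, starting at C_0 = 0 and ending at C_L = 1.
def pmf_to_cdf(pmf):
    cdf = torch.cumsum(pmf, dim=-1)
    zeros = torch.zeros(pmf.shape[:-1] + (1,), dtype=pmf.dtype)
    return torch.cat([zeros, cdf], dim=-1)

pmf = torch.tensor([0.5, 0.25, 0.25])  # L = 3 symbols
print(pmf_to_cdf(pmf))  # tensor([0.0000, 0.5000, 0.7500, 1.0000])
```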
Important:
- If you have `L` possible symbols, you need to pass a CDF that specifies `L + 1` values. Since this is a common number, we call it `Lp = L + 1` throughout the code (the "p" stands for prime, i.e., `L'`).
- The last value of the CDF should be `1`. Note that the arithmetic coder in `torchac.cpp` will just assume it's `1` regardless of what is passed, so a CDF that does not end in `1` will mean you estimate bitrates wrongly. More details below.
- Note that even though the CDF specifies `Lp` values, symbols are only allowed to be in `{0, ..., Lp-2}`. In the above example, `Lp == 4`, but the max symbol is `Lp-2 == 2`. Bigger values will yield wrong outputs; see the range check sketched below.
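A minimal sanity check along these lines (illustrative only, not torchac code):

```python
import torch

# cdf has Lp values in its last dimension; valid symbols are {0, ..., Lp-2}.
def check_symbol_range(cdf, sym):
    Lp = cdf.shape[-1]
    assert 0 <= int(sym.min()) and int(sym.max()) <= Lp - 2, \
        f"symbols must be in [0, {Lp - 2}]"

cdf = torch.tensor([0.0, 1/3, 2/3, 1.0])           # Lp == 4
check_symbol_range(cdf, torch.tensor([0, 1, 2]))   # OK: max symbol is Lp-2 == 2
```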
Expected input shapes
We allow any shapes for the inputs, but the spatial dimensions of the input CDF and the input symbols must match. In particular, we expect:
- The CDF must have shape `(N1, ..., Nm, Lp)`, where `N1, ..., Nm` are the `m` spatial dimensions, and `Lp` is as described above.
- Symbols must have shape `(N1, ..., Nm)`, i.e., the same spatial dimensions as the CDF.

For example, in a typical CNN, you might have a CDF of shape `(batch, channels, height, width, Lp)`, as in the shape sketch below.
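A minimal shape sketch (all sizes are made up for illustration):

```python
import torch

batch, channels, height, width = 1, 3, 8, 8
L = 5        # number of possible symbols
Lp = L + 1   # CDF values per spatial position

# CDF: one vector of Lp values per spatial position.
cdf = torch.zeros(batch, channels, height, width, Lp)

# Symbols: same spatial dimensions, values in {0, ..., L-1}.
sym = torch.randint(0, L, (batch, channels, height, width))

assert cdf.shape[:-1] == sym.shape  # spatial dimensions must match
```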
Normalized vs. Unnormalized / Floating Point vs. Integer CDFs
The library differentiates between "normalized" and "unnormalized" CDFs, and between "floating point" and "integer" CDFs. What do these mean?
- A proper CDF is strictly monotonically increasing, and we call this a "normalized" CDF.
- However, since we work with finite precision (16 bits, to be precise, in this implementation), it may be that you have a CDF that is strictly monotonically increasing in `float32` space, but not when it is converted to 16-bit precision. An "unnormalized" CDF is what we call a CDF that has the same value for at least two subsequent elements.
- "Floating point" CDFs are CDFs that are specified as `float32` and need to be converted to 16-bit precision.
- "Integer" CDFs are CDFs specified as `int16` - BUT they are then interpreted as `uint16` on the C++ side. See "int16 vs uint16" below.
Examples:

```python
float_unnormalized_cdf   = [0.1, 0.2, 0.2, 0.3, ..., 1.]
float_normalized_cdf     = [0.1, 0.2, 0.20001, 0.3, ..., 1.]
integer_unnormalized_cdf = [10, 20, 20, 30, ..., 0]  # See below for why the last value is 0.
integer_normalized_cdf   = [10, 20, 21, 30, ..., 0]  # See below for why the last value is 0.
```
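To see how a CDF can be normalized in `float32` but unnormalized after quantization, here is a minimal sketch; the exact rounding used by the library may differ:

```python
import torch

# Two neighboring CDF values that differ in float32...
cdf = torch.tensor([0.1, 0.2, 0.200001, 0.3, 1.0])

# ...can collide once quantized to 16-bit precision:
q = (cdf * 2**16).round().long()
print(q)  # tensor([ 6554, 13107, 13107, 19661, 65536]) -> repeated value, i.e., unnormalized
```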
There are two APIs:
- `encode_float_cdf` and `decode_float_cdf` are to be used for floating point CDFs. These functions have a flag `needs_normalization` that specifies whether the input is assumed to be normalized. You can set `needs_normalization=False` if you have CDFs that you know are normalized, e.g., Gaussian distributions with a large enough sigma. This speeds up encoding and decoding of large tensors somewhat, and makes bitrate estimation from the CDF more precise.
- `encode_int16_normalized_cdf` and `decode_int16_normalized_cdf` are to be used for integer CDFs that are already normalized.

A round-trip sketch with the floating point API follows below.
int16 vs uint16 - it gets confusing!
One big source of confusion can be that PyTorch does not support `uint16`. Yet, that's exactly what we need. So what we do is represent integer CDFs with `int16` on the Python side, and interpret/cast them to `uint16` on the C++ side. This means that if you were to look at the int16 CDFs, you would see confusing things:

```
# Python
cdf_float = [0., 1/3, 2/3, 1.]  # A uniform distribution for L=3 symbols.
cdf_int = [0, 21845, -21845, 0]

# C++
uint16* cdf_int = [0, 21845, 43690, 0]
```
Note:
- On the Python side, `cdf_int` values bigger than `2**16 / 2` appear as negative numbers.
- The final value is actually 0. This is handled in `torchac.cpp`, which just assumes `cdf[..., -1] == 2**16`, a value that cannot be represented as a `uint16`. The sketch below mimics this reinterpretation.
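As an illustration, this sketch reproduces the int16/uint16 reinterpretation in plain PyTorch. It is not the library's actual conversion code, and rounding details may differ by +/- 1:

```python
import torch

# Quantize a float CDF to 16-bit fixed point; the final value 2**16
# wraps around to 0, since it cannot be represented in 16 bits.
cdf_float = torch.tensor([0., 1/3, 2/3, 1.])
cdf_uint16 = (cdf_float * 2**16).round().long() % 2**16

# Reinterpret as signed int16 (two's complement): values >= 2**15
# show up as negative numbers on the Python side.
cdf_int16 = torch.where(cdf_uint16 >= 2**15, cdf_uint16 - 2**16, cdf_uint16).to(torch.int16)
print(cdf_int16)                 # tensor([     0,  21845, -21845,      0], dtype=torch.int16)

# The uint16 view that the C++ side sees:
print(cdf_int16.long() % 2**16)  # tensor([    0, 21845, 43691,     0])
```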
Fun stuff!