DiffJPEG implemented in JAX
DiffJPEG: A JAX Implementation
This is a JAX implementation of the differentiable JPEG compression algorithm. It is based on the PyTorch implementation and adds some of the modifications found in this repository, which improve quality at high compression rates.
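What makes JPEG differentiable is replacing the hard rounding in the quantization step with a smooth surrogate. The following is a minimal sketch of that idea under stated assumptions, not this package's code. It uses the example luminance table from the JPEG standard, scales it by quality with the IJG/libjpeg formula, and applies the cubic rounding approximation round(x) + (x - round(x))^3 that is commonly used in PyTorch DiffJPEG implementations. The names scaled_table, diff_round and quantize_dequantize are hypothetical. Whether diffjpeg_jax maps quality the same way is an assumption, and its high-compression modifications are not reproduced here.

```python
import jax
import jax.numpy as jnp

# Example luminance quantization table from the JPEG standard (ITU-T T.81,
# Annex K). Assumption: diffjpeg_jax may use different or modified tables.
BASE_LUMA_TABLE = jnp.array([
    [16, 11, 10, 16, 24, 40, 51, 61],
    [12, 12, 14, 19, 26, 58, 60, 55],
    [14, 13, 16, 24, 40, 57, 69, 56],
    [14, 17, 22, 29, 51, 87, 80, 62],
    [18, 22, 37, 56, 68, 109, 103, 77],
    [24, 35, 55, 64, 81, 104, 113, 92],
    [49, 64, 78, 87, 103, 121, 120, 101],
    [72, 92, 95, 98, 112, 100, 103, 99],
], dtype=jnp.int32)


def scaled_table(quality: int) -> jnp.ndarray:
    """Scale the base table by quality (1-100) using the IJG/libjpeg formula."""
    if not 1 <= quality <= 100:
        raise ValueError("quality must be in [1, 100]")
    # Lower quality -> larger scale -> coarser quantization.
    scale = 5000 // quality if quality < 50 else 200 - 2 * quality
    table = (BASE_LUMA_TABLE * scale + 50) // 100
    # Keep entries in the baseline range and never divide by zero.
    return jnp.clip(table, 1, 255).astype(jnp.float32)


def diff_round(x):
    """Differentiable rounding surrogate: round(x) + (x - round(x))**3.

    The forward value stays within 0.125 of round(x). JAX gives jnp.round a
    zero derivative, so the gradient here is 3 * (x - round(x))**2 instead of 0.
    """
    r = jnp.round(x)
    return r + (x - r) ** 3


def quantize_dequantize(coeffs, quality: int):
    """Quantize 8x8 DCT coefficients with the scaled table, then undo the scaling."""
    table = scaled_table(quality)
    return diff_round(coeffs / table) * table
```

For example, jax.grad(lambda c: quantize_dequantize(c, 75).sum()) yields a nonzero gradient. The same function built on jnp.round alone would return all zeros.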
Requirements
- JAX
Installation
Can be installed with pip:
pip install diffjpeg_jax
Usage
Unlike the PyTorch version, this implementation does not depend on any particular ML library, so it is provided as a plain function. Inputs should be in the range [0, 255] and have shape (H, W, C).
import jax.numpy as jnp
from diffjpeg_jax import diff_jpeg
img = jnp.zeros((256, 256, 3), dtype=jnp.float32)  # placeholder: any (H, W, C) image in [0, 255]
jpeg = diff_jpeg(img, quality=75)
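The function expects (H, W, C) images in [0, 255], while many training pipelines store images channels-first in [0, 1]. The following is a minimal sketch of the conversion in both directions, assuming that channels-first [0, 1] layout. The names to_hwc_255 and from_hwc_255 are hypothetical helpers and are not part of diffjpeg_jax.

```python
import jax.numpy as jnp


def to_hwc_255(img_chw01):
    """Convert a (C, H, W) image with values in [0, 1] to (H, W, C) in [0, 255]."""
    img = jnp.clip(img_chw01, 0.0, 1.0) * 255.0  # rescale to the expected range
    return jnp.transpose(img, (1, 2, 0))         # channels-first -> channels-last


def from_hwc_255(img_hwc255):
    """Inverse of to_hwc_255: (H, W, C) in [0, 255] back to (C, H, W) in [0, 1]."""
    return jnp.transpose(img_hwc255, (2, 0, 1)) / 255.0
```

For example: jpeg = from_hwc_255(diff_jpeg(to_hwc_255(img), quality=75)). Both helpers are plain jax.numpy operations, so gradients flow through them. Note that clip zeroes the gradient for out-of-range input values.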
Note: the function is not wrapped in jax.jit, so apply jit yourself if you want it compiled. For batch processing, use jax.vmap.
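The following sketch shows one way to combine the two. It binds quality with functools.partial, vmaps over the leading batch axis, and jits the result. It assumes only that diff_jpeg accepts quality as a keyword argument, as in the usage example. make_batched is a hypothetical helper name.

```python
from functools import partial
from typing import Callable

import jax


def make_batched(compress: Callable, quality: int) -> Callable:
    """Return a jitted function mapping (N, H, W, C) -> (N, H, W, C).

    `quality` is bound as a Python constant, so it is static under jit;
    each make_batched call (one per quality value) compiles separately.
    """
    single = partial(compress, quality=quality)  # (H, W, C) -> (H, W, C)
    return jax.jit(jax.vmap(single))             # map over the leading batch axis
```

For example, batched = make_batched(diff_jpeg, quality=75), then batched(images) for an (N, H, W, C) array. Binding quality as a Python constant is the safe choice whether or not diff_jpeg branches on it in Python. The cost is one compilation per distinct quality value.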
Download files
Source Distribution
diffjpeg_jax-1.0.0.tar.gz (4.0 kB)
Built Distribution
Hashes for diffjpeg_jax-1.0.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 1915318f08693e5f87bc2f02a863f5fec49a65d5da6d4beab8184c5084674354
MD5 | ddd1b708181593bf900105ddab647f18
BLAKE2b-256 | 36632dca458376b1ad329be977e82d9793b6089033ae010574206bf507a2a8f7