
DiffJPEG implemented in JAX


DiffJPEG: A JAX Implementation

This is a JAX implementation of differentiable JPEG compression (a JPEG approximation through which gradients can be backpropagated to the input image). It is based on the PyTorch implementation and incorporates some of the modifications found in this repository to improve quality at high compression rates.
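
Rounding the quantized coefficients has zero gradient almost everywhere, which would block backpropagation. The PyTorch DiffJPEG replaces it with a cubic approximation, round(x) + (x - round(x))^3. The sketch below shows that idea in JAX; it is an illustration rather than code from the diffjpeg_jax source, so whether this package uses exactly the same approximation is an assumption.

```python
import jax
import jax.numpy as jnp


def diff_round(x: jax.Array) -> jax.Array:
    """Cubic approximation of rounding with a non-zero gradient.

    Because |x - round(x)| <= 0.5, the output stays within 0.125 of round(x).
    JAX gives jnp.round a zero gradient, so the whole gradient comes from the
    cubic term: d/dx = 3 * (x - round(x)) ** 2.
    """
    r = jnp.round(x)
    return r + (x - r) ** 3


x = jnp.array(0.3)
print(diff_round(x))            # approximately 0.027
print(jax.grad(diff_round)(x))  # approximately 0.27
print(jax.grad(jnp.round)(x))   # 0.0: plain rounding blocks the gradient
```

The forward value is close to true rounding, while the backward pass still carries a signal to the input.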

Requirements

  • JAX

Installation

Can be installed with pip:

pip install diffjpeg_jax

Usage

Unlike the PyTorch version, this implementation is not tied to any particular ML library, so it is provided as a plain function rather than a module. Inputs should have values in the range [0, 255] and shape (H, W, C).
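
Images coming from PyTorch-style pipelines are often (C, H, W) with values in [0, 1]. A minimal conversion to and from the expected layout might look like this; the helper names are hypothetical and not part of diffjpeg_jax.

```python
import jax.numpy as jnp


def chw01_to_hwc255(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: (C, H, W) in [0, 1] -> (H, W, C) float in [0, 255]."""
    return jnp.transpose(x, (1, 2, 0)).astype(jnp.float32) * 255.0


def hwc255_to_chw01(y: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: inverse of chw01_to_hwc255."""
    return jnp.transpose(y, (2, 0, 1)) / 255.0


# A uint8 (H, W, C) image is already in [0, 255]; cast it to float so
# jax.grad can differentiate with respect to it.
img_u8 = jnp.full((16, 24, 3), 128, dtype=jnp.uint8)
img = img_u8.astype(jnp.float32)

chw = jnp.linspace(0.0, 1.0, 3 * 16 * 24).reshape(3, 16, 24)
hwc = chw01_to_hwc255(chw)
print(hwc.shape)  # (16, 24, 3)
```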

from diffjpeg_jax import diff_jpeg

img = ...  # float array, values in [0, 255], shape (H, W, C)
jpeg = diff_jpeg(img, quality=75)
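
The quality argument follows JPEG's 1 to 100 scale. In libjpeg, quality is turned into a percentage scale factor (5000 / quality below 50, 200 - 2 * quality otherwise) that multiplies the standard quantization tables from ITU-T T.81 Annex K, and the PyTorch DiffJPEG follows a similar rule. The sketch below reproduces the libjpeg rule for the luminance table only; the function names are hypothetical, and whether diffjpeg_jax rounds and clips the table the same way is an assumption.

```python
import jax.numpy as jnp

# Standard JPEG luminance quantization table (ITU-T T.81, Annex K, Table K.1).
Y_TABLE = jnp.array(
    [
        [16, 11, 10, 16, 24, 40, 51, 61],
        [12, 12, 14, 19, 26, 58, 60, 55],
        [14, 13, 16, 24, 40, 57, 69, 56],
        [14, 17, 22, 29, 51, 87, 80, 62],
        [18, 22, 37, 56, 68, 109, 103, 77],
        [24, 35, 55, 64, 81, 104, 113, 92],
        [49, 64, 78, 87, 103, 121, 120, 101],
        [72, 92, 95, 98, 112, 100, 103, 99],
    ],
    dtype=jnp.int32,
)


def quality_to_scale(quality: int) -> int:
    """libjpeg rule: map quality in [1, 100] to a percentage scale factor."""
    quality = min(max(int(quality), 1), 100)
    return 5000 // quality if quality < 50 else 200 - 2 * quality


def scaled_luma_table(quality: int) -> jnp.ndarray:
    """Scale Y_TABLE as libjpeg does for baseline JPEG (entries clipped to [1, 255])."""
    scale = quality_to_scale(quality)
    return jnp.clip((Y_TABLE * scale + 50) // 100, 1, 255)


print(scaled_luma_table(75)[0].tolist())  # [8, 6, 5, 8, 12, 20, 26, 31]
```

Lower quality gives a larger scale and coarser quantization: at quality 50 the table is used unchanged, and at quality 100 every entry becomes 1.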

Note: diff_jpeg is not wrapped in jax.jit, so apply jit yourself if you want compilation. For batch processing, use jax.vmap.
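
Because the function is plain JAX, jit, vmap and grad compose with it. To keep the snippet runnable without the package, toy_jpeg below is a hypothetical stand-in with the same call shape (an (H, W, C) array in [0, 255] plus a quality keyword); it is not JPEG. Treating quality as static is an assumption: it is the safe choice if the implementation branches on quality in Python, at the cost of one compilation per distinct quality value.

```python
import functools

import jax
import jax.numpy as jnp


def toy_jpeg(img: jax.Array, quality: int = 75) -> jax.Array:
    """Hypothetical stand-in with the same call shape as diff_jpeg.

    Takes an (H, W, C) float array in [0, 255] and returns the same shape.
    It only soft-rounds pixel values to a quality-dependent step; it is NOT JPEG.
    The Python-level branch on `quality` is why quality must be static under jit.
    """
    scale = 5000.0 / quality if quality < 50 else 200.0 - 2.0 * quality
    step = max(scale, 1.0) / 10.0
    q = img / step
    r = jnp.round(q)
    return (r + (q - r) ** 3) * step  # cubic soft rounding keeps gradients


# Single image: jit with `quality` marked static (one compilation per value).
single = jax.jit(toy_jpeg, static_argnames="quality")

# Batch of images (N, H, W, C): fix quality, vmap over axis 0, then jit.
batched = jax.jit(jax.vmap(functools.partial(toy_jpeg, quality=75)))

key = jax.random.PRNGKey(0)
batch = jax.random.uniform(key, (4, 32, 32, 3), maxval=255.0)

out = batched(batch)
print(out.shape)  # (4, 32, 32, 3)
print(single(batch[0], quality=50).shape)  # (32, 32, 3)

# Gradients flow through the batched, compiled function.
grads = jax.grad(lambda x: jnp.mean((batched(x) - x) ** 2))(batch)
print(grads.shape)  # (4, 32, 32, 3)
```

To use the real package, replace toy_jpeg with diff_jpeg imported from diffjpeg_jax; the jit, vmap and grad wrapping stays the same.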


Download files


Source distribution: diffjpeg_jax-1.0.0.tar.gz (4.0 kB)

Built distribution: diffjpeg_jax-1.0.0-py3-none-any.whl (4.7 kB, Python 3)
