
Qwix: a quantization library for Jax.

Qwix is a Jax quantization library supporting Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ) for both XLA targets (CPU/GPU/TPU) and ODML targets (LiteRT).

Features

  • Supported schemas:
    • Weight-only quantization.
    • Dynamic-range quantization.
    • Static-range quantization.
  • Supported modes:
    • QAT: this mode emulates quantized serving behavior during training via fake quantization.
    • PTQ: this mode achieves the best serving performance on XLA devices such as TPU and GPU.
    • ODML: this mode adds proper annotations to the model so that the LiteRT converter can produce full-integer models.
    • LoRA/QLoRA: this mode enables LoRA and QLoRA on a model.
  • Supported numerics:
    • Native: int4, int8, fp8.
    • Emulated: int1 to int7, nf4.
  • Supported array calibration methods (see the absmax sketch after this list):
    • absmax: symmetric quantization using maximum absolute value.
    • minmax: asymmetric quantization using minimum and maximum values.
    • rms: symmetric quantization using root mean square.
    • fixed: fixed range.
  • Supported Jax ops and their quantization granularity:
    • XLA:
      • conv_general_dilated: per-channel.
      • dot_general and einsum: per-channel and sub-channel.
    • LiteRT:
      • conv, matmul, and fully_connected: per-channel.
      • Other ops available in LiteRT: per-tensor.
  • Integration with any Flax Linen or NNX models via a single function call.
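
To make the calibration methods concrete, here is a minimal sketch of absmax symmetric int8 quantization in plain Jax. It is illustrative only and is not Qwix's internal implementation.

import jax.numpy as jnp

def absmax_quantize_int8(x, axis=None):
  # Calibration: scale from the maximum absolute value along `axis`
  # (axis=None gives per-tensor, axis=0 gives per-channel for a 2D kernel).
  scale = jnp.max(jnp.abs(x), axis=axis, keepdims=True) / 127.0
  scale = jnp.where(scale == 0.0, 1.0, scale)  # guard against all-zero slices
  qvalue = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
  return qvalue, scale

def dequantize(qvalue, scale):
  return qvalue.astype(jnp.float32) * scale

For a (16, 64) kernel, axis=0 yields a (1, 64) scale, i.e. per-channel quantization; minmax would additionally track a zero point for asymmetric ranges.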

Usage

Qwix is available on PyPI and can also be installed from GitHub directly for the latest development version.

pip install qwix

To install from GitHub:

pip install git+https://github.com/google/qwix

Model definition

We're going to use a simple MLP model in this example. Qwix integrates with models without requiring any changes to their code, so any model could be used in its place.

import jax
from flax import linen as nn

class MLP(nn.Module):

  dhidden: int
  dout: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.dhidden, use_bias=False)(x)
    x = nn.relu(x)
    x = nn.Dense(self.dout, use_bias=False)(x)
    return x

model = MLP(64, 16)
model_input = jax.random.uniform(jax.random.key(0), (8, 16))

Quantization config

Qwix uses a regex-based configuration system to describe how a Jax model should be quantized. Configurations are defined as a list of QuantizationRule objects. Each rule consists of a module_path regex that matches Flax modules, and a set of values that control the quantization behavior of the matched modules.

For example, to quantize the above model in int8 (w8a8), we need to define the rules as below.

import qwix

rules = [
    qwix.QuantizationRule(
        module_path='.*',  # this rule matches all modules.
        weight_qtype='int8',  # quantizes weights in int8.
        act_qtype='int8',  # quantizes activations in int8.
    )
]

Unlike some other libraries that provide a limited number of quantization recipes, Qwix doesn't have a list of presets. Instead, different quantization schemas are achieved by combining quantization configs.
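
For example, weight-only quantization is just a rule that sets weight_qtype and omits act_qtype. The sketch below is hypothetical: it assumes rules are matched in order with the first match winning, and that leaving act_qtype unset keeps activations in floating point.

rules = [
    # Hypothetical: quantize the first layer's weights more aggressively.
    qwix.QuantizationRule(
        module_path='Dense_0',
        weight_qtype='int4',
    ),
    # Weight-only int8 everywhere else; act_qtype is unset, so
    # activations stay in floating point.
    qwix.QuantizationRule(
        module_path='.*',
        weight_qtype='int8',
    ),
]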

Post-Training Quantization

To apply PTQ to the above model, we only need to call qwix.quantize_model.

ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))

Now ptq_model operates on quantized weights, which we can verify:

>>> jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)['params']
{
  'Dense_0': {
    'kernel': WithAux(
        array=QArray(
            qvalue=ShapeDtypeStruct(shape=(16, 64), dtype=int8),
            scale=ShapeDtypeStruct(shape=(1, 64), dtype=float32),
            ...
        ),
        ...
    )
  },
  'Dense_1': {
    'kernel': WithAux(
        array=QArray(
            qvalue=ShapeDtypeStruct(shape=(64, 16), dtype=int8),
            scale=ShapeDtypeStruct(shape=(1, 16), dtype=float32),
            ...
        ),
        ...
    )
  }
}

Weight quantization

Since Flax Linen modules are purely functional, weight quantization is a separate step from model quantization. To quantize weights for the above ptq_model, we need to call qwix.quantize_params.

# Floating-point params, usually loaded from checkpoints.
fp_params = ...

# Abstract quantized params, which serve as a template for quantize_params.
abs_ptq_params = jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)['params']

# Weight quantization.
ptq_params = qwix.quantize_params(fp_params, abs_ptq_params)

# ptq_params contains the quantized weights and can be consumed by ptq_model.
quantized_model_output = ptq_model.apply({'params': ptq_params}, model_input)
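
Putting the steps together, a quick sanity check might compare the floating-point and quantized outputs. This sketch reuses abs_ptq_params from above and, in place of a checkpoint, initializes fp_params with a fresh model.init.

import jax.numpy as jnp

# In place of a checkpoint, initialize floating-point params directly.
fp_params = model.init(jax.random.key(0), model_input)['params']
ptq_params = qwix.quantize_params(fp_params, abs_ptq_params)

# Compare the two models on the same input.
fp_output = model.apply({'params': fp_params}, model_input)
ptq_output = ptq_model.apply({'params': ptq_params}, model_input)
print(jnp.abs(fp_output - ptq_output).max())  # expected to be small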

Relation with AQT

The design of Qwix was inspired by AQT and borrowed many great ideas from it. Here's a brief list of the similarities and the differences.

  • Qwix's QArray is similar to AQT's QTensor, both supporting sub-channel quantization.
  • AQT has quantized training support (quantized forward and backward passes), while Qwix's QAT is based on fake quantization (sketched after this list), which doesn't improve training performance.
  • AQT provides drop-in replacements for einsum and dot_general, each of which must be configured separately. Qwix provides additional mechanisms to integrate with a whole model implicitly.
  • Applying static-range quantization is easier in Qwix as it has more in-depth support with Flax.
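
As a rough illustration of the fake quantization mentioned above, the straight-through-estimator sketch below round-trips values through int8 in the forward pass while keeping an identity gradient. It is a generic sketch, not Qwix's actual QAT code.

import jax
import jax.numpy as jnp

def fake_quant_int8(x):
  scale = jnp.max(jnp.abs(x)) / 127.0
  scale = jnp.where(scale == 0.0, 1.0, scale)
  q = jnp.clip(jnp.round(x / scale), -127, 127) * scale
  # Straight-through estimator: the forward value is the quantized q,
  # but stop_gradient hides the rounding from autodiff, so the
  # backward pass sees the identity function.
  return x + jax.lax.stop_gradient(q - x)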

Contributing

Please refer to CONTRIBUTING.md for more information.

Citing Qwix

To cite Qwix, please use:

@software{Qwix,
  title = {Qwix: A Quantization Library for Jax},
  author = {Dangyi Liu and Jiwon Shin and others},
  year = {2024},
  howpublished = {\url{https://github.com/google/qwix}},
}
