
Qwix is a Jax quantization library.

Reason this release was yanked:

Please use 0.1.4

Project description

Qwix: a quantization library for Jax.

Qwix is a Jax quantization library supporting Quantization-Aware Training (QAT) and Post-Training Quantization (PTQ) for both XLA targets (CPU/GPU/TPU) and ODML targets (LiteRT).

Features

  • Supported schemas:
    • Weight-only quantization.
    • Dynamic-range quantization.
    • Static-range quantization.
  • Supported modes:
    • QAT: this mode uses fake quantization to emulate quantized serving behavior during training.
    • PTQ: this mode achieves the best serving performance on XLA devices such as TPU and GPU.
    • ODML: this mode adds the annotations the LiteRT converter needs to produce fully integer models.
    • LoRA/QLoRA: this mode enables LoRA and QLoRA on a model.
  • Supported numerics:
    • Native: int4, int8, fp8.
    • Emulated: int1 to int7, nf4.
  • Supported array calibration methods:
    • absmax: symmetric quantization using maximum absolute value.
    • minmax: asymmetric quantization using minimum and maximum values.
    • rms: symmetric quantization using root mean square.
    • fixed: fixed range.
  • Supported Jax ops and their quantization granularity:
    • XLA:
      • conv_general_dilated: per-channel.
      • dot_general and einsum: per-channel and sub-channel.
    • LiteRT:
      • conv, matmul, and fully_connected: per-channel.
      • Other ops available in LiteRT: per-tensor.
  • Integration with any Flax Linen or NNX model via a single function call.
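The calibration methods listed above differ mainly in how they turn a tensor's range into a scale and a zero point. The NumPy sketch below illustrates absmax and minmax calibration, together with fake quantization (quantize, then dequantize), which is the mechanism QAT relies on. NumPy is used only so the sketch runs standalone, and jax.numpy offers the same calls. The helper names are hypothetical and are not part of Qwix's API. rms and fixed calibration are not shown. Passing a smaller `bits` value shows one way a narrower type such as int4 can be emulated: values are clipped to a smaller integer range while staying in a wider container.

```python
import numpy as np

# Hypothetical helpers illustrating calibration; not Qwix's implementation.

def absmax_params(x, bits=8):
    """Symmetric calibration: the scale comes from max(|x|) and the zero point is 0."""
    qmax = 2 ** (bits - 1) - 1  # 127 for int8, 7 for int4 (bits must be >= 2)
    scale = max(float(np.max(np.abs(x))), 1e-12) / qmax
    return scale, 0

def minmax_params(x, bits=8):
    """Asymmetric calibration: map [min(x), max(x)] onto the full signed integer range."""
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    lo, hi = float(np.min(x)), float(np.max(x))
    scale = max(hi - lo, 1e-12) / (qmax - qmin)
    zero_point = int(np.round(qmin - lo / scale))
    return scale, zero_point

def fake_quantize(x, scale, zero_point, bits=8, symmetric=True):
    """Quantize, then dequantize: the result stays floating point but carries rounding error."""
    qmax = 2 ** (bits - 1) - 1
    qmin = -qmax if symmetric else -(2 ** (bits - 1))  # symmetric uses a restricted range
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = np.linspace(0.0, 1.0, 101)  # non-negative data, where minmax uses the range better
s_abs, z_abs = absmax_params(x)
s_mm, z_mm = minmax_params(x)
err_abs = np.max(np.abs(fake_quantize(x, s_abs, z_abs) - x))
err_mm = np.max(np.abs(fake_quantize(x, s_mm, z_mm, symmetric=False) - x))
```

For this non-negative input, minmax roughly halves the step size compared with absmax (1/255 versus 1/127). The reason is that symmetric absmax reserves half of the integer range for negative values that never occur.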

Usage

Qwix doesn't provide a PyPI package yet. To use Qwix, install it directly from GitHub:

pip install git+https://github.com/google/qwix

Model definition

We're going to use a simple MLP model in this example. Qwix integrates with models without requiring changes to their code, so any model can be used below.

import jax
from flax import linen as nn

class MLP(nn.Module):

  dhidden: int
  dout: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.dhidden, use_bias=False)(x)
    x = nn.relu(x)
    x = nn.Dense(self.dout, use_bias=False)(x)
    return x

model = MLP(64, 16)
model_input = jax.random.uniform(jax.random.key(0), (8, 16))

Quantization config

Qwix uses a regex-based configuration system to specify how a Jax model should be quantized. A configuration is a list of QuantizationRule objects. Each rule consists of a key that matches Flax modules (module_path, a regular expression) and a set of values that control quantization behavior.

For example, to quantize the above model in int8 (w8a8: int8 weights and int8 activations), define the rules as follows.

import qwix

rules = [
    qwix.QuantizationRule(
        module_path='.*',  # this rule matches all modules.
        weight_qtype='int8',  # quantizes weights in int8.
        act_qtype='int8',  # quantizes activations in int8.
    )
]

Unlike some other libraries that provide a limited number of quantization recipes, Qwix doesn't have a list of presets. Instead, different quantization schemas are expressed by combining quantization settings in the rules.
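To make the regex-based rule system concrete, here is a small pure-Python sketch of how a list of rules could select settings for each module path. `Rule` and `match_rule` are hypothetical stand-ins, not Qwix's API. The sketch assumes full-match semantics and assumes that the first matching rule wins. Qwix's actual matching and precedence behavior are not described on this page. The example combines a weight-only int4 rule for `Dense_1` (no `act_qtype` in this sketch) with a catch-all w8a8 rule.

```python
import re
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Rule:
    """Hypothetical stand-in for a quantization rule: a regex key plus quantization values."""
    module_path: str
    weight_qtype: Optional[str] = None
    act_qtype: Optional[str] = None

def match_rule(rules, path):
    """Return the first rule whose regex fully matches `path`, or None (assumed semantics)."""
    for rule in rules:
        if re.fullmatch(rule.module_path, path):
            return rule
    return None

rules = [
    Rule(module_path='Dense_1', weight_qtype='int4'),  # weight-only int4 for the output layer
    Rule(module_path='.*', weight_qtype='int8', act_qtype='int8'),  # w8a8 everywhere else
]

for path in ['Dense_0', 'Dense_1']:
    rule = match_rule(rules, path)
    print(path, rule.weight_qtype, rule.act_qtype)
```

Under these assumptions, rule order matters: placing the catch-all `.*` rule first would shadow the more specific `Dense_1` rule.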

Post-Training Quantization

To apply PTQ to the above model, we only need to call qwix.quantize_model.

ptq_model = qwix.quantize_model(model, qwix.PtqProvider(rules))

ptq_model now contains quantized weights, which we can verify by inspecting the abstract shapes of its parameters:

>>> jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)['params']
{
  'Dense_0': {
    'kernel': WithAux(
        array=QArray(
            qvalue=ShapeDtypeStruct(shape=(16, 64), dtype=int8),
            scale=ShapeDtypeStruct(shape=(1, 64), dtype=float32),
            ...
        ),
        ...
    )
  },
  'Dense_1': {
    'kernel': WithAux(
        array=QArray(
            qvalue=ShapeDtypeStruct(shape=(64, 16), dtype=int8),
            scale=ShapeDtypeStruct(shape=(1, 16), dtype=float32),
            ...
        ),
        ...
    )
  }
}
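The (1, 64) and (1, 16) scale shapes in the output above correspond to per-channel quantization, with one scale per output channel of each kernel. The NumPy sketch below reproduces those shapes for a (16, 64) kernel. It assumes symmetric absmax calibration, which may not be what Qwix uses by default. It illustrates the idea and is not Qwix's implementation.

```python
import numpy as np

def quantize_per_channel(kernel, bits=8):
    """Symmetric absmax quantization with one scale per output channel (last axis)."""
    qmax = 2 ** (bits - 1) - 1
    # Reduce over the input (contraction) axis 0; keepdims yields a (1, out) scale.
    absmax = np.max(np.abs(kernel), axis=0, keepdims=True)
    scale = (np.maximum(absmax, np.finfo(np.float32).tiny) / qmax).astype(np.float32)
    qvalue = np.clip(np.round(kernel / scale), -qmax, qmax).astype(np.int8)
    return qvalue, scale

rng = np.random.default_rng(0)
kernel = rng.standard_normal((16, 64)).astype(np.float32)  # same shape as Dense_0's kernel
qvalue, scale = quantize_per_channel(kernel)
print(qvalue.shape, qvalue.dtype, scale.shape, scale.dtype)
# (16, 64) int8 (1, 64) float32
```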

Weight quantization

Since Flax Linen modules are purely functional, weight quantization is separate from model quantization. To quantize the weights for the above ptq_model, call qwix.quantize_params.

# Floating-point params, usually loaded from checkpoints.
fp_params = ...

# Abstract quantized params, which serve as a template for quantize_params.
abs_ptq_params = jax.eval_shape(ptq_model.init, jax.random.key(0), model_input)['params']

# Weight quantization.
ptq_params = qwix.quantize_params(fp_params, abs_ptq_params)

# ptq_params contains the quantized weights and can be consumed by ptq_model.
quantized_model_output = ptq_model.apply({'params': ptq_params}, model_input)
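The sketch below shows why int8 parameters can speed up serving. With one scale per row of the activations and one scale per output channel of the weights, both scales factor out of the dot product. The matmul can then run on int8 values with int32 accumulation and be rescaled afterwards. The NumPy code and the per-row activation granularity are assumptions made for illustration. The kernels and granularity that ptq_model.apply actually uses may differ.

```python
import numpy as np

def quantize_symmetric(x, axis, bits=8):
    """Absmax quantization with one scale per slice, reducing over `axis`."""
    qmax = 2 ** (bits - 1) - 1
    absmax = np.max(np.abs(x), axis=axis, keepdims=True)
    scale = np.maximum(absmax, 1e-12) / qmax
    q = np.clip(np.round(x / scale), -qmax, qmax).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
x = rng.uniform(size=(8, 16))          # activations, shaped like model_input
w = rng.standard_normal((16, 64))      # a Dense kernel

qx, sx = quantize_symmetric(x, axis=1)  # one scale per row of x: shape (8, 1)
qw, sw = quantize_symmetric(w, axis=0)  # one scale per output channel: shape (1, 64)

# Integer matmul with int32 accumulation, then rescale by the outer product of the scales.
acc = qx.astype(np.int32) @ qw.astype(np.int32)
y_quant = acc * sx * sw                 # (8, 64) * (8, 1) * (1, 64)

y_float = x @ w                         # floating-point reference
```

The rescaled integer result equals the matmul of the dequantized operands up to float rounding, so the only accuracy loss comes from quantizing x and w.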

Relation with AQT

The design of Qwix was inspired by AQT and borrowed many great ideas from it. Here's a brief list of the similarities and the differences.

  • Qwix's QArray is similar to AQT's QTensor, both supporting sub-channel quantization.
  • AQT supports quantized training (quantized forward and backward passes), while Qwix's QAT is based on fake quantization, which doesn't improve training performance.
  • AQT provides drop-in replacements for einsum and dot_general, each of which must be configured separately. Qwix additionally provides mechanisms to quantize a whole model implicitly.
  • Applying static-range quantization is easier in Qwix because of its deeper integration with Flax.
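Sub-channel quantization, which both QArray and QTensor support, splits the contraction axis into blocks that each get their own scale. The NumPy sketch below applies this to the example's (16, 64) kernel with an illustrative block size of 4. The block layout, the helper names, and absmax calibration are assumptions, not QArray's actual structure.

```python
import numpy as np

def quantize_subchannel(kernel, block_size, bits=8):
    """Symmetric absmax quantization with one scale per block of `block_size` rows
    along the contraction axis (axis 0), for every output channel."""
    in_dim, out_dim = kernel.shape
    assert in_dim % block_size == 0, "contraction dim must be divisible by block_size"
    qmax = 2 ** (bits - 1) - 1
    blocks = kernel.reshape(in_dim // block_size, block_size, out_dim)
    absmax = np.max(np.abs(blocks), axis=1, keepdims=True)  # (num_blocks, 1, out_dim)
    scale = np.maximum(absmax, 1e-12) / qmax
    qvalue = np.clip(np.round(blocks / scale), -qmax, qmax).astype(np.int8)
    return qvalue, scale

def dequantize_subchannel(qvalue, scale):
    """Multiply each block by its own scale and restore the original 2-D layout."""
    blocks = qvalue.astype(scale.dtype) * scale
    return blocks.reshape(-1, blocks.shape[-1])

rng = np.random.default_rng(0)
kernel = rng.standard_normal((16, 64))
qvalue, scale = quantize_subchannel(kernel, block_size=4)
print(qvalue.shape, scale.shape)  # (4, 4, 64) (4, 1, 64)
```

Each block's scale is never larger than the corresponding per-channel scale, so smaller blocks track local ranges more closely. The cost is storing more scales.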



Download files


Source Distribution

qwix-0.1.3.tar.gz (70.8 kB)

Built Distribution

qwix-0.1.3-py3-none-any.whl (91.6 kB)

File details

Details for the file qwix-0.1.3.tar.gz.

File metadata

  • Download URL: qwix-0.1.3.tar.gz
  • Upload date:
  • Size: 70.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.7

File hashes

Hashes for qwix-0.1.3.tar.gz
Algorithm Hash digest
SHA256 b83408a1be41129bd06112a845cd60a961cc713d297f7ac0ca7485ec04d409e7
MD5 7d8a769d6db0288f25c3531ffbe578e7
BLAKE2b-256 3076d8b64c6ca116e43fec88c9d60438513274e08e76dc4e0d23a9035e87400d


File details

Details for the file qwix-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: qwix-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 91.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.7

File hashes

Hashes for qwix-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 c7e9ea278f33d24720044f043a6b9d014812025def40a608b4cea630bbe61f48
MD5 c2335494824995bbf1a70b33b09e4e26
BLAKE2b-256 3a787271e5be3222c713b6fa93b7bf323e84529f6856ed5b6b22e6a1069500f6

