AQT: Accurate Quantized Training

AQT is a quantization library designed to let models take advantage of the low-bit, high-performance numerics of contemporary ML hardware accelerators. AQT supports both research and production[^research-vs-prod], but focuses on the latter.

[^research-vs-prod]: The support for research is exemplified by state-of-the-art quantization quality on standard models such as ResNet and the Transformer. The production aspect means high performance and robust out-of-the-box results with good defaults.

Vision

We believe that in the long term, the AQT library is the best way to deliver quantization capabilities to model owners. Without a solution like AQT shared across Google, and perhaps across the whole TF and Jax ecosystems, every team would develop its own quantization solution. This is unsatisfactory because quantization is a difficult topic and the devil is in the details (and in the papers). AQT codifies this knowledge into an implementation hidden behind a config, and into best practices described in the documentation.

To be widely applicable, a quantization solution needs to be easy to use. That's why the central goals of AQTp are robustness to model changes and a high-performance implementation (as opposed to achieving paper state-of-the-art quantization results on one particular model).

Supported machine learning frameworks

AQTp supports two frameworks:

  • TensorFlow
  • Jax

AQT legacy

There is also an AQT legacy directory (https://github.com/google/aqt/tree/main/aqt/jax_legacy). AQT legacy was the first prototype of AQT, focused only on researching quantization quality, robustness, and ease of use, without any regard for performance. AQT legacy is intended to be obsoleted by AQTp Jax, but about 10-20% of the AQT legacy features have not been ported yet. The differences between AQTp and AQT legacy are described below.

Usage

Low-level API

On typical ML accelerators, only the tensor operations enjoy quantization acceleration (as opposed to pointwise operations such as ReLU, normalization, residual addition, etc.):

  • matmul (dense layer),
  • all kinds of convolutions,
  • and einsum.

At a low level, AQTp is a Python-level drop-in replacement for these tensor operations. The AQTp variants of these ops, apart from the standard arguments, also take an AQTp config as an argument. For instance, the JAX matmul operation:

def dot(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray) -> jnp.ndarray: ...

has an AQTp counterpart:

def dot(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    lhs_quantizer: aqt_tensor.TensorQuantizer,
    rhs_quantizer: aqt_tensor.TensorQuantizer) -> jnp.ndarray: ...

where the TensorQuantizers can be created with a call to the following constructor:

TensorQuantizer(
  data_shape: List[Optional[int]],
  config: aqt_config.AqtScheduleConfig,
)

The AQTp config specifies which hardware numerics, which kind of quantization, and which calibration algorithms should be used. Internally, AQTp replaces, say, a float-typed tf.matmul with an int8-typed tf.matmul surrounded by appropriate conversion ops, as sketched below.
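
The sketch below is purely illustrative and is not the AQTp implementation: it shows, in plain JAX, the kind of "integer op surrounded by conversion ops" rewrite described above, with a naive max-abs calibration standing in for AQTp's configurable calibration.

# Conceptual sketch only -- not AQTp code.
import jax.numpy as jnp

def int8_matmul_sketch(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray:
  # Calibration: derive per-tensor scales that map max|x| onto the int8 range.
  lhs_scale = 127.0 / jnp.max(jnp.abs(lhs))
  rhs_scale = 127.0 / jnp.max(jnp.abs(rhs))
  # Conversion ops in: quantize both operands to int8.
  lhs_q = jnp.clip(jnp.round(lhs * lhs_scale), -127, 127).astype(jnp.int8)
  rhs_q = jnp.clip(jnp.round(rhs * rhs_scale), -127, 127).astype(jnp.int8)
  # The accelerated op: an integer matmul (on TPU this is the int8 path;
  # here it is emulated by accumulating in int32).
  acc = jnp.matmul(lhs_q.astype(jnp.int32), rhs_q.astype(jnp.int32))
  # Conversion ops out: rescale the integer accumulator back to float.
  return acc.astype(jnp.float32) / (lhs_scale * rhs_scale)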

Framework-level API

The strategy for internal Google customers is to integrate AQT into the training frameworks that production is using (not into individual models). For example: Smartass for Ads, Babelfish for Translate and ASR (Assistant), ImageFactory for VSS, Pax and P5X for Large Language Models, and Flax in general. Some of these integrations are finished (e.g. in Smartass one can just change the config proto and AQT is enabled), and some are not started (e.g. ImageFactory). When a framework integration is done, we (the AQT owners) support the model owners with our expertise. We are also building a growing 'best-practices' documentation on how to get the most value out of AQTp.

AQTp workflow

The default workflow for retrainable models is as follows:

  1. Modify the model with AQTp at the Python level, either using the low-level ops directly or using the framework-level configuration.
  2. Train with AQTp enabled.
  3. Benchmark the model with xprof and observe the speedups.

With that, quantization is available throughout all of the primary training, and no additional 'quantization phase' such as fine-tuning or PTQ is needed. This is a welcome simplification. Importantly, thanks to QAT, there is no quantization-induced training-serving bias.

However, for larger models such as Large Language Models (LLMs) (e.g. PaLM, Primer, ULM), AQT supports enabling quantization only in the (already pre-existing) fine-tuning phase. This allows the quantization decisions to be delayed until after the expensive main training. It also enables quantization of already-trained models.

In both cases, using AQTp with int8 is not expected to increase the training length (which would add training cost). However, with stronger quantization such as int4, we observe that additional training is indeed beneficial for model quality. This is a good tradeoff in both of these cases:

  • We can afford to train longer in order for the final model to serve (run inference) faster.
  • When using quantized training (where backpropagation is accelerated too, not just QAT), the wall-clock time would still be better thanks to the int4 acceleration.

Shared config system {#config}

The AQTp TF and Jax implementations share the config system (https://github.com/google/aqt/tree/main/aqt/common).

Example quantized models

Expected speedups

  • HBM bound / PCIe bound
  • Smartass measurement
  • Babelfish measurement

Focus on quantization aware training

AQT focuses on Quantization Aware Training (QAT; the anagram is accidental). We chose this because the quality of QAT is much better than that of Post-Training Quantization (PTQ). QAT buys us several things:

  • With QAT, int8 works for almost all models out of the box; PTQ works only sometimes.
  • With QAT, int4 works for most layers in most models; PTQ works only very rarely, and only for weights.
  • With QAT, much less quantization tuning (tuning the model so that it achieves the desired quality when quantized) is needed. This is especially important in Google, where TPUv1 (SeaStar) induced quantization fatigue (frustration) across many PAs.

The costs of doing QAT as opposed to PTQ:

  • Quantization has to be applied before (or during) the training.
  • Model owners have to be aware of the quantization aspect up front.

We believe that with AQTp these costs are not large, especially when using the framework-level APIs.

Best practices

Install

Publishing the pip package:

  • Make sure the newest commits are in.
  • Update the version number in setup.py.
  • Build the package:

    python3 -m pip install --upgrade build
    python3 -m build

  • Upload the package:

    python3 -m pip install --upgrade twine
    python3 -m twine upload dist/*

Supported numerical formats {#snf}

AQT supports research-level experiments with int8, int4, fp8, binary, and many other formats. However, as of today (Aug 2022), only int8, and only on TPU, is practical in a production setting.

AQTp and AQT legacy differences {#aqt_legacy}

AQT legacy is intended to be obsoleted by AQTp Jax, but about 20% of the AQT legacy features are not yet implemented in the main AQTp. Currently (July 2022) the following differences exist between AQT legacy and AQTp Jax:

  • Structured sparsity research is still being done in AQT legacy.
  • Per-input-channel quantization is implemented only in AQT legacy, not in AQTp. However, per-example-in-batch and per-output-channel quantization are implemented in both.
  • AQT legacy has implemented support for FP8 and other FP formats; AQTp does not yet.
  • AQT legacy has a high-quality Transformer example, while AQTp does not.
  • AQT legacy has support for the accelerated unsigned type uint8, while AQTp has hardware support only for signed int8; unsigned types can only be emulated.
  • The AQT legacy ResNet is of higher quality than the AQTp ResNet due to the missing features listed above.
  • We expect AQTp to catch up on all the differences except sparsity by early Q4 2022.

AQTp has features that AQT legacy does not:

  • High performance on actual hardware.
  • Maintainability: significantly simpler code.
  • Both TF and Jax implementations, with a shared config system and tests.
  • Design compatible with backprop quantization.
  • Per-example quantization.
  • others (todo lew@)

The path not taken: graph rewriter

We have found that Python-level model code modifications are not as expensive in terms of engineering time as one might expect. In productionisation efforts, the biggest cost is in quantization debugging. Having a flexible (fine-grained) API facilitates robust, selective application of quantization. Semi-automated rewriting is still possible with Python-level monkey-patching, but so far this has not been important enough to implement.

There are other complications with the rewriter approach:

  • How to deliver a per-layer quantization configuration?
  • How to inject state (calibration statistics) into the graph in the presence of control structures?
  • How to save said state into a checkpoint?

FAQ

This section contains questions and answers that might in the future be converted into their own sections above.

Q: It looks like AQTp always computes stats during training. I think stats should be disabled when I use dynamic quantization (it will save memory): it makes no sense to keep stats when they are not used.

Yes, it would not only save some (but very little) training memory, it would also save the stats-updating computation (important in the Ads and LLM cases, but not others). There was an idea to disable the stats collection when using dynamic quantization and also after calibration. Stats will not be collected in the inference graph (train=false) anyway.

Q: AQTp supports per-example quantization: it computes stats per example in a batch. This approach will not work during inference as soon as inference and training use different batch sizes.

Per-example quantization makes sense only when ema_update_count=1, i.e. stats are always updated based on one example. This results in dynamic quantization, where per-example is the only thing that makes sense. The current AQT config is flexible and allows cases that don't make sense in NNs, because at its core AQT quantizes arbitrary matmuls and is not specialized for NNs. In particular, we have the same configuration options on the LHS and RHS of a matmul.

Q: How does AQT allow us to configure channel-wise quantization?

In AQT we have a lot of flexibility in the channel-wise configuration. We inform the TensorQuantizer, through StatsConfig, which axes will share scales (and statistics).

For example, take a matmul of shape [B, I] * [I, F] => [B, F], where B is the batch, I the input channels, and F the output channels. Note that the axis I disappears in the output, but the axes B and F do not. It means we have to share scales along axis I, but we can choose for B and F.

In that case we can have separate calibration scales for every element of the B axis and separate scales for every element of the F axis. We cannot have separate scales for every element of the I (input channels) axis, because the matmul summation (the dot products) is done along that axis, and when you add elements they need to share a scale so that the following holds:

$$\big((s\cdot x_1)\cdot w_1 + \ldots + (s\cdot x_I)\cdot w_I\big) \cdot \frac{1}{s} = x_1 \cdot w_1 + \ldots + x_I \cdot w_I$$

Here $$s$$ is the scale of the activations $$x_i$$, and we can divide by it after the matmul only if it is shared across the sum.

If we used a different $$s$$ for every channel in the sum, we could not "remove it". By contrast, we can have a separate scale for every output channel, because if we multiply the weights in channel $$f$$ by $$s_f$$, this axis does not disappear in the output shape [B, F], and we can divide each column $$f$$ by $$s_f$$. This is exactly what AQT does. In general, we can have separate scales along an axis whenever we do not add together different elements along that axis. For instance, if we have a 3x1 Conv2D (3 rows, 1 column, a vertical convolution), in principle we could have separate scales for every column of pixels, every output channel, and every example in a batch, but we need to share scales across input channels and across the different rows.
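
This scale-sharing rule can be checked numerically. The following sketch (plain JAX with made-up shapes and scale values, not AQT code) shows that per-output-channel scales can be divided out after the matmul, while per-input-channel scales cannot.

# Numeric illustration of the scale-sharing rule above (not AQT code).
import jax.numpy as jnp

B, I, F = 4, 8, 3
x = jnp.arange(B * I, dtype=jnp.float32).reshape(B, I) / 10.0  # activations [B, I]
w = jnp.arange(I * F, dtype=jnp.float32).reshape(I, F) / 7.0   # weights     [I, F]
reference = x @ w                                              # float matmul [B, F]

# Separate scale per output channel (axis F): scale each weight column by
# s_f, then divide each output column by the same s_f -- exact recovery.
s_f = jnp.array([2.0, 0.5, 4.0])
recovered = (x @ (w * s_f)) / s_f
assert jnp.allclose(recovered, reference)

# Separate scale per input channel (axis I): the contraction over I sums
# terms carrying different scales, so no post-matmul division undoes them.
s_i = jnp.linspace(1.0, 4.0, I)
mixed = (x * s_i) @ w
assert not jnp.allclose(mixed / s_i.mean(), reference)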

Q: What happens if we don't put any axes in share_stats_axes?

Take a 1x1 convolution, with kernel.shape = [1, 1, i, o]. If we don't put any axes in share_stats_axes, it would ask to quantize each element of the kernel separately. That is not mathematically feasible. AQT will return an error, as the contraction dimension must share scales.

Q: Does AQT support static and dynamic quantization? Static quantization = scales are fixed during inference. Dynamic quantization = scales are recomputed during inference.

Dynamic quantization is more expensive during inference. For weights, though, both could be (with a smart compiler) more or less the same, because dynamically recomputing scales for weights will always return the same scales.

TODO(lew): Check inference graph - can statistics/scales be updated during inference??? I think Not.

By default, AQT will provide a static quantization graph for inference. The scales, though, will be found through statistics gathering, either during a fixed period of training (the scale-freezing option in the config) or throughout the whole training.

In order to use dynamic quantization, you need to not freeze the scales and to set the EMA parameter to 1 batch only (stats and scales are then calculated only from the last batch).
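
The distinction can be pictured with a small, purely conceptual JAX sketch (this is not the AQTp config or code): dynamic quantization derives a fresh scale from each incoming batch, while static quantization reuses a scale frozen from training-time statistics.

# Conceptual sketch of static vs. dynamic calibration (not AQTp code).
import jax.numpy as jnp

def quantize_int8(x: jnp.ndarray, scale: jnp.ndarray) -> jnp.ndarray:
  # Quantize to int8 given a scale that maps max|x| onto 127.
  return jnp.clip(jnp.round(x * scale), -127, 127).astype(jnp.int8)

def dynamic_scale(batch: jnp.ndarray) -> jnp.ndarray:
  # Dynamic: recompute the scale from the current batch, also at inference
  # time (the extra max-abs reduction is the runtime overhead).
  return 127.0 / jnp.max(jnp.abs(batch))

# Static: this constant stands in for a scale frozen from training-time
# statistics (e.g. an EMA of max|x|) and is reused unchanged at inference.
frozen_scale = jnp.float32(25.4)

batch = jnp.linspace(-5.0, 5.0, 16)
q_dynamic = quantize_int8(batch, dynamic_scale(batch))
q_static = quantize_int8(batch, frozen_scale)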

Static quantization is better than dynamic for a few reasons:

  • We observed consistently better quality with static quantization.
  • Static is cheaper at inference time.
  • Static can be cheaper at training time (not implemented yet), because once you freeze the scale you don't have to gather stats anymore.

Q: How many batches do you need to observe?

It depends on the batch size and on whether you share the scale along the batch dimension. We do recommend sharing the scale along the batch axis, because all examples have the same distribution. We think that a few thousand examples (one to a few batches, with a large batch size) are sufficient for good stats.

Q: What is the overhead of dynamic quantization?

In some cases the overhead of dynamic quantization might be small, namely when the number of channels is large. However, the reduction along the input channels is not cheap. In some cases dynamic quantization might yield better quality, especially when it is hard to find the right calibration coefficients in the presence of outliers with Post-Training Quantization (not supported by AQT). With Quantization-Aware Training, clipping outliers is not harmful, as the model can adapt during the training.

Citing AQT

Please use the following BibTeX entry:

@software{aqt2022github,
  author = {Lew, Lukasz and Feinberg, Vlad and Agrawal, Shivani and Lee, Jihwan and Malmaud, Jonathan and Wang, Lisa and Dormiani, Pouya and Pope, Reiner},
  title = {AQT: Accurate Quantized Training},
  url = {http://github.com/google/aqt},
  year = {2022},
}
