Package for applying ao techniques to GPU models

torchao: PyTorch Architecture Optimization

Introduction

torchao: PyTorch library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training.

From the team that brought you the fast series:

  • 9.5x speedups for image segmentation models with sam-fast
  • 10x speedups for language models with gpt-fast
  • 3x speedup for diffusion models with sd-fast

torchao just works, out of the box, with torch.compile() and FSDP2 on most PyTorch models on Hugging Face.

Inference

Post Training Quantization

Quantizing and sparsifying your models is a one-liner that should work on any model with an nn.Linear, including your favorite Hugging Face model. You can find more comprehensive usage instructions here, sparsity instructions here, and a Hugging Face inference example here

For inference, you can choose to

  1. Quantize only the weights: works best for memory-bound models
  2. Quantize the weights and activations: works best for compute-bound models
  3. Quantize the activations and weights and sparsify the weights
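import torch
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight,
    int4_weight_only,
    int8_weight_only
)
# A toy stand-in for your model; int4_weight_only is typically used with bfloat16 weights on CUDA
m = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
quantize_(m, int4_weight_only())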
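To make option 1 concrete, here is a simplified pure-Python sketch of groupwise asymmetric 4-bit weight quantization: each group of weights gets its own scale and zero point, and only the weights become integer codes while activations stay in floating point. The group size of 128 and the unsigned 0..15 code range are assumptions for illustration; torchao's real int4 path packs codes into a kernel-specific layout that this sketch does not model.

```python
# Simplified sketch of groupwise asymmetric 4-bit weight-only quantization.
# Assumptions: unsigned codes in [0, 15], one float scale and zero point per
# group, group size 128. This is NOT torchao's packed int4 layout or kernel.

def quantize_group(weights, n_bits=4):
    """Map a group of floats to integer codes plus (scale, zero_point)."""
    qmax = 2 ** n_bits - 1
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero_point = round(-lo / scale)
    # Round to the nearest code, shift by the zero point, clamp to the code range.
    codes = [min(max(round(w / scale) + zero_point, 0), qmax) for w in weights]
    return codes, scale, zero_point


def dequantize_group(codes, scale, zero_point):
    """Recover approximate floats from the integer codes."""
    return [(c - zero_point) * scale for c in codes]


def quantize_weight_only(row, group_size=128, n_bits=4):
    """Quantize one weight row group by group; activations are left untouched."""
    return [
        quantize_group(row[start:start + group_size], n_bits)
        for start in range(0, len(row), group_size)
    ]


codes, scale, zp = quantize_group([0.0, 0.5, 1.0, 1.5])
print(codes, dequantize_group(codes, scale, zp))
```

Storing one scale and zero point per small group, rather than per row, keeps the rounding error bounded by each group's own range at the cost of a little extra metadata.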

For gpt-fast, int4_weight_only() is the best option at bs=1, as it doubles tok/s and reduces VRAM requirements by about 65% over a torch.compiled baseline.
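A back-of-envelope estimate shows why 4-bit weights save so much memory. This sketch assumes a bf16 baseline, a group size of 128, and one 16-bit scale plus one 16-bit zero point per group; it counts only the quantized linear weights, so it is not meant to reproduce the 65% figure above, which refers to overall VRAM requirements.

```python
# Back-of-envelope storage estimate for groupwise int4 weights.
# Assumptions: bf16 baseline (16 bits/weight), group size 128, and one
# 16-bit scale plus one 16-bit zero point stored per group.

def bits_per_weight(n_bits=4, group_size=128, meta_bits_per_group=2 * 16):
    """Effective bits per weight, including per-group metadata."""
    return n_bits + meta_bits_per_group / group_size


def weight_memory_reduction(baseline_bits=16, **kwargs):
    """Fractional reduction in weight storage relative to the baseline."""
    return 1 - bits_per_weight(**kwargs) / baseline_bits


print(bits_per_weight())                    # 4.25
print(round(weight_memory_reduction(), 4))  # 0.7344
```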

If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization too slow, you can use the device argument, as in quantize_(model, int8_weight_only(), device="cuda"), which sends each layer to your GPU and quantizes it individually.

If you see slowdowns with any of these techniques or you're unsure which option to use, consider using autoquant, which automatically profiles layers and picks the best way to quantize each one.

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
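The idea behind autoquant, profiling each layer and keeping whichever quantization is fastest for it, can be illustrated with a small hypothetical sketch. The layer names, candidate names and `benchmark` callable below are assumptions for illustration only and do not reflect autoquant's internals.

```python
# Hypothetical sketch of "profile each layer, keep the fastest option".
# Candidate names and the benchmark callable are illustrative assumptions,
# not torchao's autoquant internals.
import time
from typing import Callable, Dict


def time_call(fn: Callable[[], object], repeats: int = 5) -> float:
    """Return the best wall-clock time over several runs of fn()."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


def pick_per_layer(
    layers: Dict[str, Dict[str, Callable[[], object]]],
    benchmark: Callable[[Callable[[], object]], float] = time_call,
) -> Dict[str, str]:
    """For each layer, benchmark every candidate and keep the fastest one's name."""
    choice = {}
    for layer_name, candidates in layers.items():
        timings = {name: benchmark(fn) for name, fn in candidates.items()}
        choice[layer_name] = min(timings, key=timings.get)
    return choice
```

In a real model each candidate would run that layer's forward pass with one quantization option applied, on representative inputs; different layers can end up with different choices.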

We also provide a developer-facing API so you can implement your own quantization algorithms; please use the excellent HQQ algorithm as a motivating example.

KV Cache Quantization

We've added KV cache quantization and other features to enable long-context inference, which necessarily has to be memory efficient.

In practice, these features alongside int4 weight-only quantization allow us to reduce peak memory by ~55%, meaning we can run Llama3.1-8B inference with a 130k context length using only 18.9 GB of peak memory. More details can be found here
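Some quick arithmetic shows why the KV cache dominates at long context lengths. The model shape below (32 layers, 8 key/value heads, head dimension 128) is taken from the published Llama 3.1 8B config and is an assumption of this sketch, as is comparing a 16-bit cache against an 8-bit one; the figure covers only the cache, while peak memory also includes the weights and other buffers.

```python
# Back-of-envelope KV cache size. The Llama 3.1 8B shape (32 layers,
# 8 key/value heads, head_dim 128) and the 16-bit vs 8-bit element sizes
# are assumptions of this sketch.

def kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, head_dim=128, bytes_per_elem=2):
    """Bytes needed to store keys and values for every layer and token."""
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # 2 = K and V
    return per_token * seq_len


GIB = 1024 ** 3
context = 130_000
print(f"bf16 KV cache: {kv_cache_bytes(context) / GIB:.1f} GiB")                    # 15.9 GiB
print(f"int8 KV cache: {kv_cache_bytes(context, bytes_per_elem=1) / GIB:.1f} GiB")  # 7.9 GiB
```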

Quantization Aware Training

Post-training quantization (PTQ) can result in a fast and compact model, but it may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering 96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext for Llama3 compared to PTQ. A full recipe is available here

from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

qat_quantizer = Int8DynActInt4WeightQATQuantizer()

# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics
model = qat_quantizer.prepare(model)

# Run Training...

# Convert fake quantize to actual quantize operations
model = qat_quantizer.convert(model)
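The "fake quantize" step inserted by prepare() can be pictured as quantize-then-dequantize: values are snapped to the quantized grid and immediately mapped back to floats, so training sees the quantization error while tensors stay in floating point. The sketch below assumes symmetric int4 with a single scale per group and is not torchao's implementation. During training, the gradient of the rounding step is typically approximated as identity (a straight-through estimator); plain Python floats have no autograd, so only the forward pass is shown.

```python
# Minimal sketch of fake quantization (forward pass only).
# Assumption: symmetric int4 with one scale per group; not torchao's code.

def fake_quantize(values, n_bits=4):
    """Quantize-then-dequantize a group with a symmetric scale."""
    qmax = 2 ** (n_bits - 1) - 1   # 7 for int4
    qmin = -(2 ** (n_bits - 1))    # -8 for int4
    amax = max(abs(v) for v in values)
    scale = amax / qmax if amax > 0 else 1.0
    # Snap each value to the nearest representable code, then map it back to float.
    return [min(max(round(v / scale), qmin), qmax) * scale for v in values]


print(fake_quantize([0.0, 3.5, -3.5, 1.2]))  # [0.0, 3.5, -3.5, 1.0]
```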

Training

Float8

torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.

With torch.compile on, current results show throughput speedups of up to 1.5x on Llama 3 70B pretraining jobs running on 128 H100 GPUs (details)

from torchao.float8 import convert_to_float8_training
convert_to_float8_training(m, module_filter_fn=...)
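To illustrate what "scaled" float8 means, here is a pure-Python sketch of one common variant, dynamic tensorwise scaling: the whole tensor is multiplied by a single scale so that its absolute maximum lands on the largest finite float8 e4m3 value (448) before the cast. The final rounding to float8 is omitted, and this is a simplified illustration rather than torchao.float8's code.

```python
# Sketch of dynamic, tensorwise scaling before a float8 cast.
# 448 is the largest finite float8 e4m3 value; the actual cast to float8
# (rounding the mantissa to 3 bits) is omitted. Not torchao.float8's code.

E4M3_MAX = 448.0


def compute_scale(values, eps=1e-12):
    """Scale that maps the tensor's absolute max onto the float8 range."""
    amax = max(abs(v) for v in values)
    return E4M3_MAX / max(amax, eps)


def scale_for_float8(values):
    """Return scaled, saturated values and the scale needed to undo them."""
    scale = compute_scale(values)
    scaled = [min(max(v * scale, -E4M3_MAX), E4M3_MAX) for v in values]
    return scaled, scale


scaled, scale = scale_for_float8([1.0, -2.0, 0.5])
print(scale, scaled)  # 224.0 [224.0, -448.0, 112.0]
```

In a scaled float8 matmul, both operands are scaled this way, multiplied in float8, and the result is divided by the product of the two scales to return to the original range.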

For a minimal end-to-end recipe for pretraining with float8, you can check out torchtitan

Sparse Training

We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L. The full blog post is here

The code change is a one-liner, with the full example available here

from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
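The 2:4 pattern means that in every contiguous group of 4 weights, at most 2 are nonzero. A simple way to produce that pattern is magnitude pruning: keep the two largest-magnitude entries per group and zero the rest. The sketch below assumes magnitude-based selection for illustration; it is not the logic inside SemiSparseLinear.

```python
# Sketch of producing a 2:4 semi-structured pattern by magnitude.
# Magnitude-based selection is an assumption; not SemiSparseLinear's logic.

def prune_2_4(row):
    """Return a copy of row where each group of 4 has at most 2 nonzeros."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest-magnitude entries in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        out.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return out


print(prune_2_4([1.0, -3.0, 2.0, 0.5]))  # [0.0, -3.0, 2.0, 0.0]
```

Because the zeros follow a fixed per-group pattern, hardware with 2:4 sparse support can skip them, which is what makes this structure faster than unstructured sparsity at the same density.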

Memory-efficient optimizers

Adam's optimizer state takes 2x as much memory as the model parameters, so we can quantize the optimizer state to either 8 or 4 bits, effectively reducing optimizer VRAM requirements by 2x or 4x, respectively, over an fp16 baseline

from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
optim = AdamW8bit(model.parameters()) # replace with AdamW4bit or AdamWFp8 for the 4-bit / FP8 versions
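The 2x and 4x figures follow from simple arithmetic: Adam/AdamW keeps two states (first and second moments) per parameter. The sketch below uses a hypothetical 7B-parameter model and mirrors the 16-bit baseline from the text; it ignores the small per-block scale overhead that quantized optimizer states carry.

```python
# Back-of-envelope optimizer-state memory for Adam/AdamW (two states per
# parameter). Per-block scale overhead of quantized states is ignored.

def adam_state_bytes(n_params, bits_per_state):
    """Bytes used by Adam's two per-parameter states."""
    return 2 * n_params * bits_per_state // 8


n_params = 7_000_000_000  # hypothetical 7B-parameter model
for label, bits in [("16-bit", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label}: {adam_state_bytes(n_params, bits) / 1e9:.1f} GB")
# 16-bit: 28.0 GB
# 8-bit: 14.0 GB
# 4-bit: 7.0 GB
```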

In practice, these optimizers are slightly slower than expertly written kernels, but they were implemented in a few hundred lines of PyTorch code and compiled, so please use them or copy-paste them for your own quantized optimizers. Benchmarks are here

We also support single-GPU CPU offloading, where both the gradients (same size as the weights) and the optimizer states are efficiently sent to the CPU. This alone can reduce your VRAM requirements by 60%

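import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])  # ckpt: a previously saved checkpoint dict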

Composability

  1. torch.compile: A key design principle for us is composability: any new dtype or layout we provide needs to work with our compiler. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton; things should just work! So we write the dtype, layout, or bit packing logic in pure PyTorch and code-generate efficient kernels.
  2. FSDP2: Historically, most quantization has been done for inference, but there is now a thriving area of research combining distributed algorithms and quantization.

The best example we have of combining lower-bit dtypes with compile and FSDP is NF4, which we used to implement the QLoRA algorithm. So if you're doing research at the intersection of these areas, we'd love to hear from you.
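Bit packing is the kind of dtype logic described above that can be written in plain tensor code and left to the compiler to optimize. As a minimal illustration, this pure-Python sketch packs two unsigned 4-bit values into each byte. The low-nibble-first order is an assumption; torchao's packed layouts are kernel-specific and are not reproduced here.

```python
# Pure-Python sketch of bit packing: two unsigned 4-bit values per byte,
# low nibble first. The nibble order is an assumption, not torchao's layout.

def pack_uint4(values):
    """Pack a list of ints in [0, 15] into bytes, two per byte."""
    if any(not 0 <= v <= 15 for v in values):
        raise ValueError("values must fit in 4 bits")
    values = list(values)
    if len(values) % 2:
        values.append(0)  # pad to an even count
    return bytes(values[i] | (values[i + 1] << 4) for i in range(0, len(values), 2))


def unpack_uint4(packed, count):
    """Inverse of pack_uint4: recover the first `count` 4-bit values."""
    out = []
    for byte in packed:
        out.extend((byte & 0x0F, byte >> 4))
    return out[:count]


packed = pack_uint4([1, 2, 15])
print(packed.hex(), unpack_uint4(packed, 3))  # 210f [1, 2, 15]
```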

Custom Kernels

We've added support for authoring and releasing custom ops that do not graph break with torch.compile(), so if you love writing kernels but hate packaging them so they work on all operating systems and CUDA versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow

  1. fp6 for 2x faster inference over fp16, with an easy-to-use API: quantize_(model, fpx_weight_only(3, 2))
  2. 2:4 Sparse Marlin GEMM: 2x speedups for FP16xINT4 kernels, even at batch sizes up to 256
  3. int4 tinygemm unpacker which makes it easier to switch quantized backends for inference
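Reading the fpx_weight_only(3, 2) arguments as 3 exponent bits and 2 mantissa bits gives a 6-bit float (1 sign + 3 exponent + 2 mantissa bits). This sketch decodes such a bit pattern, assuming an exponent bias of 3, subnormals when the exponent field is 0, and no encodings reserved for inf/NaN, as in the OCP FP6 E3M2 format; it illustrates the number format only, not the fp6 kernel.

```python
# Decode a 6-bit float: 1 sign bit, 3 exponent bits, 2 mantissa bits.
# Assumptions: bias 3, subnormals when the exponent field is 0, and no
# inf/NaN encodings (as in the OCP FP6 E3M2 format).

EBITS, MBITS = 3, 2
BIAS = 2 ** (EBITS - 1) - 1  # 3


def decode_fp6_e3m2(bits):
    """Convert a 6-bit pattern (int in [0, 63]) to a Python float."""
    if not 0 <= bits < 64:
        raise ValueError("expected a 6-bit pattern")
    sign = -1.0 if (bits >> (EBITS + MBITS)) & 1 else 1.0
    exponent = (bits >> MBITS) & ((1 << EBITS) - 1)
    mantissa = bits & ((1 << MBITS) - 1)
    if exponent == 0:  # subnormal: no implicit leading 1
        return sign * (mantissa / (1 << MBITS)) * 2.0 ** (1 - BIAS)
    return sign * (1 + mantissa / (1 << MBITS)) * 2.0 ** (exponent - BIAS)


print(decode_fp6_e3m2(0b001100), decode_fp6_e3m2(0b011111))  # 1.0 28.0
```

With only 64 patterns, the largest magnitude is 28.0, which is why weights must be scaled into that range before conversion.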

If you believe there are other CUDA kernels we should take a closer look at, please leave a comment on this issue

Alpha features

Things we're excited about but need more time to cook in the oven

  1. MX training and inference support with tensors using the OCP MX spec data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
  2. Int8 Quantized Training: We're trying out full int8 training. This is easy to use with quantize_(model, int8_weight_only_quantized_training()). This work is prototype as the memory benchmarks are not compelling yet.
  3. IntX: We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype because, without more investment in either the compiler or low-bit kernels, int4 is unfortunately more compelling than any smaller dtype.
  4. Bitnet: Mostly this is very cool to people on the team. This is prototype because the usefulness of these kernels depends heavily on better hardware and kernel support.
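The MX idea of groupwise scaling with power-of-two scales can be sketched in a few lines: every block of values (32 elements per block in the OCP spec) shares one scale of the form 2^k, chosen from the block's absolute maximum and the largest exponent of the element format. This simplified sketch assumes float8 e4m3 elements (largest exponent 8, since 448 = 1.75 * 2^8) and omits both the rounding of elements to the element format and the clamping of any scaled values that still exceed 448.

```python
# Simplified sketch of MX-style block scaling with a shared power-of-two
# scale. Assumptions: float8 e4m3 elements (largest exponent 8); element
# rounding and saturation are omitted.
import math

E4M3_EMAX = 8


def mx_scale_block(block, elem_emax=E4M3_EMAX):
    """Return (shared power-of-two scale, scaled elements) for one block."""
    amax = max(abs(v) for v in block)
    if amax == 0:
        return 1.0, [0.0 for _ in block]
    shared_exp = math.floor(math.log2(amax)) - elem_emax
    scale = 2.0 ** shared_exp
    return scale, [v / scale for v in block]


print(mx_scale_block([1.0, 3.0, -0.5]))  # (0.0078125, [128.0, 384.0, -64.0])
```

Restricting scales to powers of two means applying or removing a scale only shifts exponents, which keeps the hardware cost of the scaling step low.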

Installation

torchao makes liberal use of several new features in PyTorch, so it's recommended to use it with the current nightly or latest stable version of PyTorch.

Stable release from PyPI, which defaults to CUDA 12.1

pip install torchao

Stable Release from the PyTorch index

pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124

Nightly Release

pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124

Most developers will probably want to skip building the custom C++/CUDA extensions for faster iteration when installing from a source checkout:

USE_CPP=0 pip install -e .

OSS Integrations

We're also fortunate to be integrated into some of the leading open-source libraries including

  1. Hugging Face transformers with a built-in inference backend and low-bit optimizers
  2. Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo diffusers-torchao
  3. Mobius HQQ backend leveraged our int4 kernels to get 195 tok/s on a 4090
  4. TorchTune for our QLoRA and QAT recipes
  5. torchchat for post training quantization
  6. SGLang for LLM inference quantization

Videos

License

torchao is released under the BSD 3-Clause license.

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

torchao-0.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)

Uploaded CPython 3.12 manylinux: glibc 2.17+ x86-64

torchao-0.6.1-cp312-cp312-macosx_11_0_arm64.whl (813.9 kB)

Uploaded CPython 3.12 macOS 11.0+ ARM64

torchao-0.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)

Uploaded CPython 3.11 manylinux: glibc 2.17+ x86-64

torchao-0.6.1-cp311-cp311-macosx_11_0_arm64.whl (815.1 kB)

Uploaded CPython 3.11 macOS 11.0+ ARM64

torchao-0.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)

Uploaded CPython 3.10 manylinux: glibc 2.17+ x86-64

torchao-0.6.1-cp310-cp310-macosx_11_0_arm64.whl (813.8 kB)

Uploaded CPython 3.10 macOS 11.0+ ARM64

torchao-0.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)

Uploaded CPython 3.9 manylinux: glibc 2.17+ x86-64

torchao-0.6.1-cp39-cp39-macosx_11_0_arm64.whl (813.8 kB)

Uploaded CPython 3.9 macOS 11.0+ ARM64

File details

Details for the file torchao-0.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

Hashes for torchao-0.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 c9522762a7da4cd1fc92b3fcc4c1283b0443c287829077ee7854bbeca00f38b8
MD5 6fc6526edab6c3b58aa20f9f2b32ef70
BLAKE2b-256 ecb132ed4c6c1a82ea734f7b46d940a2cd9af9f7425ca0cafdcfaf69108655ab

See more details on using hashes here.

File details

Details for the file torchao-0.6.1-cp312-cp312-macosx_11_0_arm64.whl.

Hashes for torchao-0.6.1-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 11382b798686eae4260278187ea882ba03a977376a4eee16fca3e73a0f7654ba
MD5 e419a13394b5627f37c26cf2b1089d69
BLAKE2b-256 701194a04984f727e834ccd09d80153f749effc37c8e6ecfa17bbdf62a46c4c8

File details

Details for the file torchao-0.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

Hashes for torchao-0.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 ffe91fa7e4a5a9dda226f25d54519f28c0b56344801a01df8859cbc91821650f
MD5 ac337539b9e455dc8c08134e2f7e53cf
BLAKE2b-256 90c219661fed3ea15e6a886b7175fdfc517e0b1be69122ca3337b8ac588358c4

File details

Details for the file torchao-0.6.1-cp311-cp311-macosx_11_0_arm64.whl.

Hashes for torchao-0.6.1-cp311-cp311-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 ec420a0853314044f1895b0f513c37001ea2f685f5db5b4b375a4aa56342bd7e
MD5 3cf91986082be4a79d604f49fb235468
BLAKE2b-256 5d80dfbd92550528b39c33920ec4126d6d91676295970e9512562a9949539de6

File details

Details for the file torchao-0.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

Hashes for torchao-0.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 188fb6fb63595fe90f7b96aff93536e8e0bd131c8cfa39258da790d52affd805
MD5 c7f0cbb2c2a9248cb6d1c442d9d01625
BLAKE2b-256 4ada1b38da2c13dd33fa32656e39f4d2bbd1d34585964c6b2da54771cd1b670c

File details

Details for the file torchao-0.6.1-cp310-cp310-macosx_11_0_arm64.whl.

Hashes for torchao-0.6.1-cp310-cp310-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 51823f7c993e53fbf36a0e365cfd8ab2493a1d419783b663ef0954d763cea4a3
MD5 51ff881d9a740ffd885629a18a319d93
BLAKE2b-256 9ab9d98bc56b8f428dab47099edcb0efe2c97ade265ddf7ce3b22192c9f0cf3b

File details

Details for the file torchao-0.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

Hashes for torchao-0.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 8dfef48b95a790127cd467e4ea823a1c85702ff197f405e59bb860c4b85324c7
MD5 d8451e8ecbacd3ee27d5de5d02db1bc9
BLAKE2b-256 250487a2572ba573e9e9e66b8546f97adb353bec6b7eef15c04bed6b3d1f8ad5

File details

Details for the file torchao-0.6.1-cp39-cp39-macosx_11_0_arm64.whl.

Hashes for torchao-0.6.1-cp39-cp39-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 b5125488512c600de629bd6c81171a85fbd834bab979c4286da5e3c23363939b
MD5 edf1a3cbd68c9bbd354643e6d4c9936e
BLAKE2b-256 204d8a31a7e5194b1a34f0aa511b4440de9714f5a3b368775de1e38f61d736d4
