
Package for applying ao techniques to GPU models

Project description

torchao: PyTorch Architecture Optimization

Note: This repository is currently under heavy development. If you have suggestions on the API or use cases you'd like to see covered, please open a GitHub issue.

The torchao package allows you to quantize and prune your models using native PyTorch.

The repo hosts:

  1. Lower precision dtypes such as nf4, uint4
  2. Quantization algorithms such as dynamic quant, smoothquant
  3. Sparsity algorithms such as Wanda

Success stories

Our kernels have been used to achieve SOTA inference performance on

  1. Image segmentation models with sam-fast
  2. Language models with gpt-fast
  3. Diffusion models with sd-fast

Installation

Note: this library makes liberal use of several new features in PyTorch, so it's recommended to use it with the current PyTorch nightly if you want full feature coverage. Otherwise, the subclass APIs may not work, though the module swap APIs will still work.

  1. From PyPI:
pip install torchao
  2. From Source:
git clone https://github.com/pytorch-labs/ao
cd ao
pip install -e .

Examples

Quantization algorithms typically use different schemes for the activations and the weights. For instance, A16W8 means the activations use 16 bits whereas the weights are quantized to 8 bits. Trying out a different quantization scheme in torchao is generally a one-line change.
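To make the naming convention concrete, here is a small sketch that splits a scheme label into its activation and weight bit widths and estimates the weight storage for a linear layer. The helpers `parse_scheme` and `weight_bytes` are hypothetical names used only for illustration; they are not part of torchao, and the estimate ignores scales and zero points.

```python
import re


def parse_scheme(name: str) -> tuple[int, int]:
    """Split a label such as 'A16W8' into (activation_bits, weight_bits)."""
    match = re.fullmatch(r"A(\d+)W(\d+)", name)
    if match is None:
        raise ValueError(f"unrecognized scheme label: {name!r}")
    return int(match.group(1)), int(match.group(2))


def weight_bytes(name: str, in_features: int, out_features: int) -> float:
    """Bytes needed for a Linear weight matrix, ignoring scales and zero points."""
    _, weight_bits = parse_scheme(name)
    return in_features * out_features * weight_bits / 8


# Compare the schemes covered in this README on a 4096 x 4096 linear layer.
for scheme in ("A8W8", "A16W8", "A16W4"):
    act_bits, w_bits = parse_scheme(scheme)
    size = weight_bytes(scheme, 4096, 4096)
    print(f"{scheme}: activations={act_bits} bits, weights={w_bits} bits, {size:.0f} bytes")
```

Halving the weight bit width halves the bytes that must be loaded per matmul, which is where weight-only schemes get their speedup.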

A8W8 Dynamic Quantization

import torch
from torchao.quantization import quant_api

# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int8_dqtensors(model)

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
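For intuition about what the int8*int8 -> int32 matmul followed by a mul computes, the arithmetic can be emulated in plain PyTorch on CPU. This is a minimal sketch under stated assumptions (symmetric quantization, one scale per activation row and per weight row); `quantize_per_row` and `dynamic_int8_linear` are hypothetical helpers, not the torchao implementation.

```python
import torch


def quantize_per_row(t: torch.Tensor):
    """Symmetric int8 quantization with one scale per row."""
    scale = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(t / scale), -127, 127).to(torch.int8)
    return q, scale


def dynamic_int8_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Emulate y = x @ weight.T with int8 operands and int32 accumulation."""
    x_q, x_scale = quantize_per_row(x)       # activation scales, computed at run time
    w_q, w_scale = quantize_per_row(weight)  # weight scales, one per output channel
    # Integer matmul; int8 values are widened to int32 so the sums cannot overflow.
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()
    # The "subsequent mul": rescale the int32 result back to floating point.
    return acc.to(torch.float32) * x_scale * w_scale.t()


torch.manual_seed(0)
x = torch.randn(32, 32)
weight = torch.randn(64, 32)

reference = x @ weight.t()
approx = dynamic_int8_linear(x, weight)
rel_err = ((approx - reference).norm() / reference.norm()).item()
print(f"relative error vs. float matmul: {rel_err:.4f}")
```

The int32 tensor `acc` is the intermediary that the fusion flag above avoids writing out to memory.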

A16W8 WeightOnly Quantization

quant_api.change_linear_weights_to_int8_woqtensors(model)

This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
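One way to see why the dequantized weight never needs to exist as a full tensor: with one scale per output channel, scaling the weight before the matmul gives the same result as scaling the output columns after it. The sketch below checks that equivalence on CPU. It assumes symmetric per-channel int8 weights and is not the kernel torchao or Inductor generates; in eager mode the cast in option B still allocates a temporary, whereas a fused kernel would convert each tile on the fly.

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 32)
weight = torch.randn(64, 32)

# Per-output-channel symmetric int8 quantization of the weight.
w_scale = weight.abs().amax(dim=1, keepdim=True) / 127.0  # shape (64, 1)
w_q = torch.clamp(torch.round(weight / w_scale), -127, 127).to(torch.int8)

# Option A: dequantize first, materializing a full floating point weight.
w_dequant = w_q.to(torch.float32) * w_scale
y_materialized = x @ w_dequant.t()

# Option B: matmul against the raw int8 values, then scale the output columns.
y_folded = (x @ w_q.to(torch.float32).t()) * w_scale.t()

max_diff = (y_materialized - y_folded).abs().max().item()
print(f"max difference between the two orderings: {max_diff:.2e}")
```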

A16W4 WeightOnly Quantization

quant_api.change_linear_weights_to_int4_woqtensors(model)

Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
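To get a feel for how much larger the error is at 4 bits, the following sketch compares the round-trip error of 8-bit and 4-bit quantization on a random weight matrix. It assumes simple symmetric per-row quantization, which is a simplification and may differ from the int4 format torchao uses; `roundtrip_error` is a hypothetical helper.

```python
import torch


def roundtrip_error(weight: torch.Tensor, bits: int) -> float:
    """Mean absolute error after quantizing to `bits` and dequantizing again."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits, 7 for 4 bits
    scale = weight.abs().amax(dim=1, keepdim=True) / qmax
    q = torch.clamp(torch.round(weight / scale), -qmax, qmax)
    return (q * scale - weight).abs().mean().item()


torch.manual_seed(0)
weight = torch.randn(64, 32)

err_int8 = roundtrip_error(weight, bits=8)
err_int4 = roundtrip_error(weight, bits=4)
print(f"int8 error: {err_int8:.5f}")
print(f"int4 error: {err_int4:.5f}")
print(f"ratio: {err_int4 / err_int8:.1f}x")
```

The step size grows from 1/127 to 1/7 of each row's maximum, so the rounding error grows by roughly the same factor.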

A8W8 Dynamic Quantization with Smoothquant

We've also implemented a version of smoothquant with the same GEMM format as above. Due to requiring calibration, the API is more complicated.

Example

import torch
from torch.utils.data import DataLoader
from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference

# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True

# plug in your model (get_model is a placeholder for your own code)
model = get_model()

# convert linear modules to smoothquant
# linear module in calibration mode
swap_linear_with_smooth_fq_linear(model)

# Create a data loader for calibration
# (get_calibration_data and MyDataset are placeholders for your own data pipeline)
calibration_data = get_calibration_data()
calibration_dataset = MyDataset(calibration_data)
calibration_loader = DataLoader(calibration_dataset, batch_size=32, shuffle=True)

# Calibrate the model
model.train()
for batch in calibration_loader:
    inputs = batch
    model(inputs)

# set it to inference mode
smooth_fq_linear_to_inference(model)

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
# `input` is an example input tensor for your model
model(input)
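The reason calibration is needed can be sketched with the smoothing step from the SmoothQuant method: per-input-channel factors, derived from activation statistics, move outlier magnitude from the activations into the weights without changing the layer output. The code below is an illustration of that idea on CPU under stated assumptions (a migration strength `alpha` of 0.5, a synthetic outlier channel); `smoothing_factors` is a hypothetical helper, and torchao's implementation may differ in its details.

```python
import torch


def smoothing_factors(x_absmax: torch.Tensor, w_absmax: torch.Tensor, alpha: float = 0.5):
    """Per-input-channel factors s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)."""
    return (x_absmax.pow(alpha) / w_absmax.pow(1.0 - alpha)).clamp(min=1e-5)


torch.manual_seed(0)
x = torch.randn(16, 32)
x[:, 3] *= 50.0  # synthetic outlier channel, as often seen in LLM activations
weight = torch.randn(64, 32)

# Statistics that calibration would collect over representative batches.
x_absmax = x.abs().amax(dim=0)       # one value per input channel
w_absmax = weight.abs().amax(dim=0)  # one value per input channel
s = smoothing_factors(x_absmax, w_absmax)

x_smooth = x / s       # activations become easier to quantize
w_smooth = weight * s  # weights absorb the scale

reference = x @ weight.t()
smoothed = x_smooth @ w_smooth.t()
print("activation max before:", x.abs().max().item())
print("activation max after: ", x_smooth.abs().max().item())
print("outputs match:", torch.allclose(reference, smoothed, rtol=1e-4, atol=1e-3))
```

Because the factors depend on observed activation ranges, the model has to see calibration data before it can be switched to inference mode.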

Sharp edges

  1. While these techniques are designed to improve model performance, in some cases the opposite can occur. Quantization adds overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or faster weight loading (weight-only quantization). If your matmuls are small enough, or your non-quantized performance isn't bottlenecked by weight load time, these techniques may reduce performance.
  2. Use the PyTorch nightlies so you can leverage tensor subclasses. This approach is preferred over the older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
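Since quantization can slow a model down as well as speed it up, it is worth timing both variants before committing to one. Below is a minimal, framework-agnostic timing sketch; `median_latency_ms` is a hypothetical helper rather than a torchao API, and the workload at the bottom is a stand-in for a model call.

```python
import statistics
import time
from typing import Callable, Optional


def median_latency_ms(
    fn: Callable[[], object],
    warmup: int = 3,
    iters: int = 20,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Median wall-clock time of fn() in milliseconds.

    `sync` is called after each run so that asynchronous work (for example
    queued GPU kernels) is finished before the clock is read.
    """

    def run_once() -> None:
        fn()
        if sync is not None:
            sync()

    # Warm-up runs absorb one-time costs such as compilation and autotuning.
    for _ in range(warmup):
        run_once()

    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run_once()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# Stand-in workload; replace the lambda with a call to your own model.
baseline_ms = median_latency_ms(lambda: sum(i * i for i in range(10_000)))
print(f"median latency: {baseline_ms:.3f} ms")
```

For a CUDA model, passing `sync=torch.cuda.synchronize` and timing `lambda: model(input)` before and after quantization gives a like-for-like comparison.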

License

torchao is released under the BSD 3-Clause license.
