
Package for applying ao techniques to GPU models

Project description

TorchAO

PyTorch-Native Training-to-Serving Model Optimization

  • Pre-train Llama-3.1-70B 1.5x faster with float8 training
  • Recover 67% of quantized accuracy degradation on Gemma3-4B with QAT
  • Quantize Llama-3-8B to int4 for 1.89x faster inference with 58% less memory


🌅 Overview

TorchAO is an easy-to-use quantization library for native PyTorch. It works out-of-the-box with torch.compile() and FSDP2 across most Hugging Face PyTorch models.

For a detailed overview of stable and prototype workflows for different hardware and dtypes, see the Workflows documentation.

Check out our docs for more details!

🚀 Quick Start

First, install TorchAO. We recommend installing the latest stable version:

pip install torchao

Quantize your model weights to int4!

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
# `model` is any torch.nn.Module, e.g. a model loaded from Hugging Face
if torch.cuda.is_available():
  # quantize on CUDA
  quantize_(model, Int4WeightOnlyConfig(group_size=32, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq"))
elif torch.xpu.is_available():
  # quantize on XPU
  quantize_(model, Int4WeightOnlyConfig(group_size=32, int4_packing_format="plain_int32"))

See our quick start guide for more details.

🛠 Installation

To install the latest stable version:

pip install torchao
Other installation options
# Nightly
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu128

# Different CUDA versions
pip install torchao --index-url https://download.pytorch.org/whl/cu126  # CUDA 12.6
pip install torchao --index-url https://download.pytorch.org/whl/cu129  # CUDA 12.9
pip install torchao --index-url https://download.pytorch.org/whl/xpu    # XPU
pip install torchao --index-url https://download.pytorch.org/whl/cpu    # CPU only

# For developers
# Note: the `--no-build-isolation` flag is required.
USE_CUDA=1 pip install -e . --no-build-isolation
USE_XPU=1 pip install -e . --no-build-isolation
USE_CPP=0 pip install -e . --no-build-isolation

Please see the torchao compatibility table for version requirements for dependencies.

🔎 Inference

TorchAO delivers substantial performance gains with minimal code changes. The following is our recommended flow for quantization and deployment:

from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

# Create quantization configuration
quantization_config = TorchAoConfig(quant_type=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# Load and automatically quantize
quantized_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-32B",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

If the flow above doesn't work for your model, an alternative is the quantize_ API described in the quick start guide.
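
As a rough sketch of that fallback (assuming `model` is an nn.Module you have already loaded onto a CUDA device, rather than going through transformers):

from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerRow

# Quantize weights to float8 with per-row scales and dynamically quantize activations, in-place
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))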

Serving with vLLM on a 1xH100 machine:

# Server
VLLM_DISABLE_COMPILE_CACHE=1 vllm serve pytorch/Qwen3-32B-FP8 --tokenizer Qwen/Qwen3-32B -O3
# Client
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "pytorch/Qwen3-32B-FP8",
  "messages": [
    {"role": "user", "content": "Give me a short introduction to large language models."}
  ],
  "temperature": 0.6,
  "top_p": 0.95,
  "top_k": 20,
  "max_tokens": 32768
}'

For diffusion models, you can quantize using Hugging Face Diffusers:

import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128))}
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
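
Once loaded, the quantized pipeline is used like any other Diffusers pipeline; the prompt and generation settings below are illustrative only:

# Run inference with the int8-quantized transformer
image = pipeline(
    "a photo of an astronaut riding a horse on the moon",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_int8.png")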

We also support deployment to edge devices through ExecuTorch; for more details, see the quantization and serving guide. We also release pre-quantized models here.

🚅 Training

Quantization-Aware Training

Post-training quantization (PTQ) can produce a fast and compact model, but it may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with TorchTune, we've developed a QAT recipe that demonstrates significant accuracy improvements over PTQ, recovering 96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext for Llama3. For more details, please refer to the QAT README and the original blog:

import torch
from torchao.quantization import quantize_, Int8DynamicActivationIntxWeightConfig, PerGroup
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
)
quantize_(my_model, QATConfig(base_config, step="prepare"))

# train model (not shown)

# convert
quantize_(my_model, QATConfig(base_config, step="convert"))

Users can also combine LoRA + QAT to speed up training by 1.89x compared to vanilla QAT using this fine-tuning recipe.

Quantized training

torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. With torch.compile enabled, current results show throughput speedups of up to 1.5x at up to 512-GPU / 405B-parameter scale (details):

from torchao.float8 import convert_to_float8_training
# Swap eligible nn.Linear modules in your model `m` to float8 training linears, in-place
convert_to_float8_training(m)
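
For context, a minimal end-to-end sketch of a float8 training step might look like the following; the toy model, optimizer, and shapes are illustrative, and it assumes H100-class hardware with CUDA available:

import torch
from torchao.float8 import convert_to_float8_training

# Toy stand-in for a real transformer; dims kept divisible by 16 for float8 kernels
m = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.Linear(4096, 4096),
).to(torch.bfloat16).cuda()

convert_to_float8_training(m)  # swap eligible nn.Linear modules for float8 training linears
m = torch.compile(m)           # float8 training is designed to be used with torch.compile

opt = torch.optim.AdamW(m.parameters())
x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
m(x).sum().backward()
opt.step()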

Our float8 training is integrated into TorchTitan's pre-training flows so users can easily try it out. For more details, check out these blog posts about our float8 training support.

Other features (sparse training, memory-efficient optimizers)

Sparse Training

We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L. Full blog here. The code change is a one-liner, with the full example available here:

from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
# Swap the nn.Linear registered under the fully-qualified name "seq.0" in `model` for a 2:4 semi-sparse linear
swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})

Memory-efficient optimizers

Optimizers like Adam can consume substantial GPU memory, up to 2x as much as the model parameters themselves. TorchAO provides two approaches to reduce this overhead:

1. Quantized optimizers: Reduce optimizer state memory by 2-4x by quantizing to lower precision

from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8
optim = AdamW8bit(model.parameters())  # replace with AdamW4bit or AdamWFp8 for the 4-bit / fp8 variants

Our quantized optimizers are implemented in just a few hundred lines of PyTorch code and compiled for efficiency. While slightly slower than specialized kernels, they offer an excellent balance of memory savings and performance. See detailed benchmarks here.

2. CPU offloading: Move optimizer state and gradients to CPU memory

For maximum memory savings, we support single GPU CPU offloading that efficiently moves both gradients and optimizer state to CPU memory. This approach can reduce your VRAM requirements by 60% with minimal impact on training speed:

from torchao.optim import CPUOffloadOptimizer

optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])  # optional: resume optimizer state from a loaded checkpoint `ckpt`

🔗 Integrations

TorchAO is integrated into some of the leading open-source libraries, including Hugging Face Transformers and Diffusers, vLLM, TorchTune, TorchTitan, and ExecuTorch.

🎥 Videos

💬 Citation

If you find the torchao library useful, please cite it in your work as below.

@software{torchao,
  title={TorchAO: PyTorch-Native Training-to-Serving Model Optimization},
  author={torchao},
  url={https://github.com/pytorch/ao},
  license={BSD-3-Clause},
  month={oct},
  year={2024}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See the tutorial on generating distribution archives.

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

torchao-0.16.0-py3-none-any.whl (1.2 MB)

Uploaded: Python 3

torchao-0.16.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)

Uploaded: CPython 3.10+, manylinux: glibc 2.24+ x86-64, manylinux: glibc 2.28+ x86-64

File details

Details for the file torchao-0.16.0-py3-none-any.whl.

File metadata

  • Download URL: torchao-0.16.0-py3-none-any.whl
  • Upload date:
  • Size: 1.2 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchao-0.16.0-py3-none-any.whl

  • SHA256: d0a8d773351fd17b95fee81dfbcbf98577b567dcdbec47d221b0ee258432101d
  • MD5: 1eb2ab9f9b344e7986e8542c0dcc220b
  • BLAKE2b-256: d03d0c5a5833a135a045510e06c06b3d4cf316b06d59415bc21e0b021a000cc8

See more details on using hashes here.

Provenance

The following attestation bundles were made for torchao-0.16.0-py3-none-any.whl:

Publisher: release-pypi.yml on pytorch/test-infra

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file torchao-0.16.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for torchao-0.16.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

  • SHA256: 2d6293a0c57c9dd505efb025a7189459d154965fbed000efd638cf299f9362dd
  • MD5: 56c0efb2f14f0efa17da6f440d1ea4bc
  • BLAKE2b-256: 8d7f0acda8a429ac9cfabd142d30af624d7958bf828c438be5a54ca87bbe16d7

See more details on using hashes here.

Provenance

The following attestation bundles were made for torchao-0.16.0-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl:

Publisher: release-pypi.yml on pytorch/test-infra

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
