torchao

PyTorch native quantization and sparsity for training and inference

https://github.com/pytorch/ao

Science Score: 64.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
    2 of 114 committers (1.8%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (9.5%) to scientific vocabulary

Keywords

brrr cuda dtypes float8 inference llama mx offloading optimizer pytorch quantization sparsity training transformer

Keywords from Contributors

cryptocurrency cryptography jax distribution deep-neural-networks interpretability language-model dl mesh evaluation-framework
Last synced: 4 months ago

Repository

PyTorch native quantization and sparsity for training and inference

Basic Info
Statistics
  • Stars: 2,291
  • Watchers: 45
  • Forks: 322
  • Open Issues: 482
  • Releases: 13
Topics
brrr cuda dtypes float8 inference llama mx offloading optimizer pytorch quantization sparsity training transformer
Created about 2 years ago · Last pushed 4 months ago
Metadata Files
Readme Contributing License Code of conduct Citation Codeowners

README.md

# TorchAO

PyTorch-Native Training-to-Serving Model Optimization

  • Pre-train Llama-3.1-70B 1.5x faster with float8 training
  • Recover 77% of quantized perplexity degradation on Llama-3.2-3B with QAT
  • Quantize Llama-3-8B to int4 for 1.89x faster inference with 58% less memory
[![](https://img.shields.io/badge/CodeML_%40_ICML-2025-blue)](https://openreview.net/attachment?id=HpqH0JakHf&name=pdf) [![](https://dcbadge.vercel.app/api/server/gpumode?style=flat&label=TorchAO%20in%20GPU%20Mode)](https://discord.com/channels/1189498204333543425/1205223658021458100) [![](https://img.shields.io/github/contributors-anon/pytorch/ao?color=yellow&style=flat-square)](https://github.com/pytorch/ao/graphs/contributors) [![](https://img.shields.io/badge/torchao-documentation-blue?color=DE3412)](https://docs.pytorch.org/ao/stable/index.html) [![license](https://img.shields.io/badge/license-BSD_3--Clause-lightgrey.svg)](./LICENSE) [Latest News](#-latest-news) | [Overview](#-overview) | [Quick Start](#-quick-start) | [Installation](#-installation) | [Integrations](#-integrations) | [Inference](#-inference) | [Training](#-training) | [Videos](#-videos) | [Citation](#-citation)

Latest News

  • [Nov 24] We achieved [1.43-1.51x faster pre-training](https://pytorch.org/blog/training-using-float8-fsdp2/) on Llama-3.1-70B and 405B using float8 training
  • [Oct 24] TorchAO was added as a quantization backend to HF Transformers!
  • [Sep 24] We officially launched TorchAO. Check out our blog [here](https://pytorch.org/blog/pytorch-native-architecture-optimization/)!
  • [Jul 24] QAT [recovered up to 96% of the accuracy degradation](https://pytorch.org/blog/quantization-aware-training/) from quantization on Llama-3-8B
  • [Jun 24] Semi-structured 2:4 sparsity [achieved 1.1x inference speedup and 1.3x training speedup](https://pytorch.org/blog/accelerating-neural-network-training/) on the SAM and ViT models respectively
  • [Jun 24] Block sparsity [achieved 1.46x training speedup](https://pytorch.org/blog/speeding-up-vits/) on the ViT model with <2% drop in accuracy

Overview

TorchAO is a PyTorch-native model optimization framework that leverages quantization and sparsity to provide an end-to-end, training-to-serving workflow for AI models. TorchAO works out of the box with torch.compile() and FSDP2 across most HuggingFace PyTorch models. Key features include:

  • Float8 training and inference for speedups without compromising accuracy
  • MX training and inference, providing MX tensor formats based on native PyTorch MX dtypes (prototype)
  • Quantization-Aware Training (QAT) for mitigating quantization degradation
  • Post-Training Quantization (PTQ) for int4, int8, fp6, etc., with matching kernels targeting a variety of backends including CUDA, ARM CPU, and XNNPACK
  • Sparsity, including techniques such as 2:4 sparsity and block sparsity

Check out our docs for more details!

From the team that brought you the fast series:

  • 9.5x inference speedups for image segmentation models with sam-fast
  • 10x inference speedups for language models with gpt-fast
  • 3x inference speedup for diffusion models with sd-fast (new: flux-fast)
  • 2.7x inference speedup for FAIR's Seamless M4T-v2 model with seamlessv2-fast

Quick Start

First, install TorchAO. We recommend installing the latest stable version:

```
pip install torchao
```

Quantize your model weights to int4:

```python
from torchao.quantization import Int4WeightOnlyConfig, quantize_

quantize_(model, Int4WeightOnlyConfig(group_size=32))
```

Compared to a torch.compile'd bf16 baseline, your quantized model should be significantly smaller and faster on a single A100 GPU:

```
int4 model size: 1.25 MB
bfloat16 model size: 4.00 MB
compression ratio: 3.2

bf16 mean time: 30.393 ms
int4 mean time: 4.410 ms
speedup: 6.9x
```

For the full model setup and benchmark details, check out our quick start guide. Alternatively, try quantizing your favorite model using our HuggingFace space!
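Those numbers come from the quick start guide's example model; as a rough illustration, a comparison along the same lines could be timed like this (a minimal sketch with a hypothetical toy model and basic CUDA-event timing, not the guide's actual benchmark script):

```python
import copy

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Hypothetical toy model standing in for the quick start guide's example model.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 1024)
).to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

@torch.no_grad()
def mean_time_ms(m, inp, iters=100):
    # Basic CUDA-event timing with warmup; torch.utils.benchmark is more robust.
    for _ in range(10):
        m(inp)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        m(inp)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

bf16_model = torch.compile(copy.deepcopy(model))       # compiled bf16 baseline
quantize_(model, Int4WeightOnlyConfig(group_size=32))  # int4 weight-only quantization
int4_model = torch.compile(model)

print(f"bf16 mean time: {mean_time_ms(bf16_model, x):.3f} ms")
print(f"int4 mean time: {mean_time_ms(int4_model, x):.3f} ms")
```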

Installation

To install the latest stable version:

```
pip install torchao
```

Other installation options:

```
# Nightly
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126

# Different CUDA versions
pip install torchao --index-url https://download.pytorch.org/whl/cu126  # CUDA 12.6
pip install torchao --index-url https://download.pytorch.org/whl/cpu    # CPU only

# For developers
USE_CUDA=1 python setup.py develop
USE_CPP=0 python setup.py develop
```

Please see the torchao compatibility table for version requirements for dependencies.
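A quick way to sanity-check an install against the compatibility table is to print the installed versions (a minimal sketch; assumes both packages expose `__version__`, which they normally do):

```python
import torch
import torchao

# Print versions to cross-check against the torchao compatibility table.
print("torch:", torch.__version__)
print("torchao:", torchao.__version__)
print("CUDA available:", torch.cuda.is_available())
```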

Integrations

TorchAO is integrated into several leading open-source libraries.

Inference

TorchAO delivers substantial performance gains with minimal code changes:

Quantize any model with nn.Linear layers in just one line (Option 1), or load the quantized model directly from HuggingFace using our integration with HuggingFace transformers (Option 2):

Option 1: Direct TorchAO API

```python
from torchao.quantization.quant_api import quantize_, Int4WeightOnlyConfig

quantize_(model, Int4WeightOnlyConfig(group_size=128, use_hqq=True))
```
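By default `quantize_` targets the `nn.Linear` layers in the model; recent torchao versions also accept an optional `filter_fn` callback for quantizing only a subset of layers (treat the exact signature as an assumption and check the docs for your version):

```python
import torch.nn as nn
from torchao.quantization.quant_api import quantize_, Int4WeightOnlyConfig

# Assumed filter_fn signature: (module, fully_qualified_name) -> bool.
# Here we skip the LM head and quantize every other Linear layer.
def skip_lm_head(module, fqn):
    return isinstance(module, nn.Linear) and "lm_head" not in fqn

quantize_(model, Int4WeightOnlyConfig(group_size=128, use_hqq=True), filter_fn=skip_lm_head)
```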

Option 2: HuggingFace Integration

```python
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization.quant_api import Int4WeightOnlyConfig

# Create quantization configuration
quantization_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True))

# Load and automatically quantize
quantized_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-4-mini-instruct",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)
```
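Once loaded, the quantized model behaves like any other transformers model; a minimal generation sketch (the prompt and decoding settings are illustrative only):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")

# Run a short generation to sanity-check the quantized model.
inputs = tokenizer(
    "Explain int4 weight-only quantization in one sentence.", return_tensors="pt"
).to(quantized_model.device)
outputs = quantized_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```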

Deploy quantized models in vLLM with one command:

```shell
vllm serve pytorch/Phi-4-mini-instruct-int4wo-hqq --tokenizer microsoft/Phi-4-mini-instruct -O3
```
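Once the server is up it speaks vLLM's OpenAI-compatible API; a minimal client sketch (assuming the default port 8000 and the `openai` Python package):

```python
from openai import OpenAI

# Point the OpenAI client at the local vLLM server started above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="pytorch/Phi-4-mini-instruct-int4wo-hqq",
    messages=[{"role": "user", "content": "Give me a one-line summary of TorchAO."}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```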

With this quantization flow, we achieve 67% VRAM reduction and 12-20% speedup on A100 GPUs while maintaining model quality. For more details, see this step-by-step quantization guide. We also release some pre-quantized models here.

Training

Quantization-Aware Training

Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with TorchTune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering 96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the QAT README and the original blog:

```python
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(my_model, QATConfig(base_config, step="prepare"))

# train model (not shown)

# convert
quantize_(my_model, QATConfig(base_config, step="convert"))
```
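The "train model (not shown)" step above is an ordinary fine-tuning loop run on the prepared model; a minimal sketch, assuming an HF-style model that returns a loss and a placeholder data loader (neither is part of the torchao API):

```python
import torch

# After QATConfig(step="prepare"), the model contains fake-quantized linears,
# so the usual forward/backward/step loop works unchanged.
optimizer = torch.optim.AdamW(my_model.parameters(), lr=1e-5)
for batch in train_dataloader:   # hypothetical data loader
    optimizer.zero_grad()
    outputs = my_model(**batch)  # hypothetical HF-style model returning .loss
    outputs.loss.backward()
    optimizer.step()
```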

Users can also combine LoRA + QAT to speed up training by 1.89x compared to vanilla QAT using this fine-tuning recipe.

Float8

torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. With torch.compile on, current results show throughput speedups of up to 1.5x at up to 512-GPU / 405B-parameter scale (details):

```python
from torchao.float8 import convert_to_float8_training

convert_to_float8_training(m)
```
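After conversion, training proceeds with a normal loop, typically under torch.compile and on float8-capable hardware such as H100-class GPUs; a minimal sketch with a placeholder data loader and loss:

```python
import torch
from torchao.float8 import convert_to_float8_training

convert_to_float8_training(m)  # swap the model's nn.Linear modules for float8 training
m = torch.compile(m)           # torch.compile generates the fused float8 kernels

optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)
for inputs, targets in train_dataloader:  # hypothetical data loader
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(m(inputs), targets)
    loss.backward()
    optimizer.step()
```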

Our float8 training is integrated into TorchTitan's pre-training flows so users can easily try it out. For more details, check out these blog posts about our float8 training support:

  • Accelerating Large Scale Training and Convergence with PyTorch Float8 Rowwise on Crusoe 2K H200s
  • Supercharging Training using float8 and FSDP2
  • Efficient Pre-training of Llama 3-like model architectures using torchtitan on Amazon SageMaker
  • Float8 in PyTorch

Sparse Training

We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L (full blog here). The code change is a one-liner, with the full example available here:

```python
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear

swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
```
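The second argument maps fully qualified module names to SemiSparseLinear; to cover more than one layer, the config can be built programmatically (a sketch; in practice you would usually restrict swapping to the large FFN linears):

```python
from torch import nn
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear

# Build a config that swaps every nn.Linear in the model for SemiSparseLinear.
sparse_config = {
    name: SemiSparseLinear
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
}
swap_linear_with_semi_sparse_linear(model, sparse_config)
```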

Memory-efficient optimizers

Optimizers like Adam can consume substantial GPU memory, often 2x as much as the model parameters themselves. TorchAO provides two approaches to reduce this overhead:

1. Quantized optimizers: Reduce optimizer state memory by 2-4x by quantizing to lower precision

```python
from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8

optim = AdamW8bit(model.parameters())  # replace with AdamW4bit or AdamWFp8 for the 4-bit / fp8 versions
```

Our quantized optimizers are implemented in just a few hundred lines of PyTorch code and compiled for efficiency. While slightly slower than specialized kernels, they offer an excellent balance of memory savings and performance. See detailed benchmarks here.
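To gauge the savings on your own setup, you can compare peak CUDA memory between a regular AdamW run and an AdamW8bit run; a minimal sketch using standard PyTorch memory counters, where `model_fp` and `model_8bit` are hypothetical fresh bf16 CUDA copies of the same model:

```python
import torch
from torchao.optim import AdamW8bit

def peak_memory_gb_after_steps(model, optim_cls, steps=5):
    # Reset the peak-memory counter, run a few training steps, and report peak usage.
    torch.cuda.reset_peak_memory_stats()
    optim = optim_cls(model.parameters())
    for _ in range(steps):
        loss = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()
    return torch.cuda.max_memory_allocated() / 1e9

print("AdamW peak GB:    ", peak_memory_gb_after_steps(model_fp, torch.optim.AdamW))
print("AdamW8bit peak GB:", peak_memory_gb_after_steps(model_8bit, AdamW8bit))
```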

2. CPU offloading: Move optimizer state and gradients to CPU memory

For maximum memory savings, we support single GPU CPU offloading that efficiently moves both gradients and optimizer state to CPU memory. This approach can reduce your VRAM requirements by 60% with minimal impact on training speed:

```python
import torch
from torchao.optim import CPUOffloadOptimizer

# Wrap a standard AdamW; extra kwargs such as fused=True are forwarded to it.
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])  # optionally resume optimizer state from a checkpoint
```

Videos

Citation

If you find the torchao library useful, please cite it in your work as below.

```bibtex
@software{torchao,
  title   = {TorchAO: PyTorch-Native Training-to-Serving Model Optimization},
  author  = {torchao},
  url     = {https://github.com/pytorch/ao},
  license = {BSD-3-Clause},
  month   = {oct},
  year    = {2024}
}
```

Owner

  • Name: pytorch
  • Login: pytorch
  • Kind: organization
  • Location: where the eigens are valued

Citation (CITATION.cff)

cff-version: 1.2.0
title: "torchao: PyTorch native quantization and sparsity for training and inference"
message: "If you use this software, please cite it as below."
type: software
authors:
  - given-names: "torchao maintainers and contributors"
url: "https//github.com/pytorch/ao"
license: "BSD-3-Clause"
date-released: "2024-10-25"

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 1,171
  • Total Committers: 114
  • Avg Commits per committer: 10.272
  • Development Distribution Score (DDS): 0.853
Past Year
  • Commits: 1,073
  • Committers: 111
  • Avg Commits per committer: 9.667
  • Development Distribution Score (DDS): 0.857
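Note: DDS is commonly defined as one minus the top committer's share of commits; under that definition the all-time figure checks out as 1 - 172/1171 ≈ 0.853.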
Top Committers
Name Email Commits
Jerry Zhang j****8@g****m 172
Vasiliy Kuznetsov v****o 120
Mark Saroufim m****m@m****m 95
Apurva Jain a****a@g****m 85
HDCharles 3****s 74
Scott Roy 1****y 73
cpuhrsch c****h@g****m 62
andrewor14 a****4@g****m 56
Thien Tran g****t@y****g 50
Driss Guessous 3****g 49
Jesse Cai j****i@m****m 44
Daniel Vega-Myhre d****m@m****m 38
Manuel Candales 4****s 14
Kimish Patel k****l@m****m 9
Wei (Will) Feng 1****y 9
Aleksandar Samardžić 1****c 9
Svetlana Karslioglu s****s@m****m 8
Jane (Yuan) Xu 3****9 8
Andrey Talman a****n@f****m 8
jeromeku j****u@g****m 8
Yi Liu y****u@i****m 7
supriyar s****r@f****m 7
Masaki Kozuki m****i@n****m 7
Tobias van der Werff 3****f 6
Pawan Jayakumar 1****s@g****m 6
Peter Yeh p****x 6
gmagogsfm g****m 6
Hanxian97 h****g@m****m 6
Huy Do h****n@g****m 5
y-sq 5****q 5
and 84 more...
Committer Domains (Top 20 + Academic)

Issues and Pull Requests

Last synced: 4 months ago

All Time
  • Total issues: 534
  • Total pull requests: 3,224
  • Average time to close issues: 22 days
  • Average time to close pull requests: 8 days
  • Total issue authors: 190
  • Total pull request authors: 206
  • Average comments per issue: 1.82
  • Average comments per pull request: 2.17
  • Merged pull requests: 2,043
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 380
  • Pull requests: 2,628
  • Average time to close issues: 16 days
  • Average time to close pull requests: 7 days
  • Issue authors: 166
  • Pull request authors: 182
  • Average comments per issue: 1.65
  • Average comments per pull request: 2.18
  • Merged pull requests: 1,654
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • msaroufim (56)
  • vkuzo (53)
  • jerryzh168 (35)
  • drisspg (24)
  • gau-nernst (24)
  • andrewor14 (14)
  • danielvegamyhre (14)
  • HDCharles (11)
  • metascroy (10)
  • jcaip (9)
  • felipemello1 (7)
  • goldhuang (7)
  • ebsmothers (6)
  • jainapurva (5)
  • zigzagcai (5)
Pull Request Authors
  • jerryzh168 (385)
  • jainapurva (314)
  • vkuzo (312)
  • danielvegamyhre (216)
  • metascroy (211)
  • msaroufim (179)
  • drisspg (169)
  • andrewor14 (148)
  • HDCharles (119)
  • jcaip (92)
  • gau-nernst (84)
  • cpuhrsch (67)
  • petrex (45)
  • kimishpatel (43)
  • szyszyzys (35)
Top Labels
Issue Labels
float8 (29) triaged (26) good first issue (23) bug (18) ci (13) question (12) quantize (10) optimizer (8) tracker (7) binaries (7) enhancement (7) multibackend (7) performance (5) CLA Signed (5) autoquant (4) topic: documentation (4) mx (4) module: rocm (4) cpu (3) rfc (3) build (3) distributed (2) qat (2) reproduction needed (2) pt2e_quant (2) topic: bug fix (1) ciflow/rocm (1) topic: improvement (1) documentation (1) dependencies (1)
Pull Request Labels
CLA Signed (2,718) topic: not user facing (850) fb-exported (285) topic: improvement (205) topic: bug fix (137) topic: new feature (129) topic: documentation (95) module: rocm (76) topic: for developers (66) topic: performance (64) ciflow/rocm (59) topic: bc-breaking (42) ci-no-td (36) float8 (34) mx (32) ci (30) Merged (16) topic: deprecation (15) build (15) sparsity (14) benchmark (9) quantize (9) cpu (8) merging (7) pt2e_quant (6) ciflow/tutorials (6) ciflow/binaries/all (5) bug (5) performance (4) ciflow/benchmark (4)

Packages

  • Total packages: 2
  • Total downloads:
    • pypi 539,890 last-month
  • Total dependent packages: 1
    (may contain duplicates)
  • Total dependent repositories: 0
    (may contain duplicates)
  • Total versions: 64
  • Total maintainers: 7
proxy.golang.org: github.com/pytorch/ao
  • Versions: 46
  • Dependent Packages: 0
  • Dependent Repositories: 0
Rankings
Dependent packages count: 6.5%
Average: 6.7%
Dependent repos count: 7.0%
Last synced: 4 months ago
pypi.org: torchao

Package for applying ao techniques to GPU models

  • Versions: 18
  • Dependent Packages: 1
  • Dependent Repositories: 0
  • Downloads: 539,890 Last month
Rankings
Dependent packages count: 9.9%
Forks count: 29.8%
Average: 36.6%
Stargazers count: 38.9%
Dependent repos count: 67.8%
Last synced: 4 months ago

Dependencies

.github/workflows/test_install.yml actions
  • actions/checkout v2 composite
  • actions/setup-python v2 composite
setup.py pypi
  • torch *