A native-PyTorch library for LLM fine-tuning

torchtune

Introduction | Installation | Get Started | Documentation | Community | License

[!IMPORTANT] Update September 25, 2024: torchtune has support for Llama 3.2 11B Vision, Llama 3.2 3B, and Llama 3.2 1B models! Try them out by following the installation instructions below, then run any of the Llama 3.2 text or vision configs.

 

Introduction

torchtune is a PyTorch library for easily authoring, finetuning and experimenting with LLMs.

torchtune provides:

  • PyTorch implementations of popular LLMs from Llama, Gemma, Mistral, Phi, and Qwen model families
  • Hackable training recipes for full finetuning, LoRA, QLoRA, DPO, PPO, QAT, knowledge distillation, and more
  • Out-of-the-box memory efficiency, performance improvements, and scaling with the latest PyTorch APIs
  • YAML configs for easily configuring training, evaluation, quantization or inference recipes
  • Built-in support for many popular dataset formats and prompt templates

 

Models

torchtune currently supports the following models.

| Model | Sizes |
|---|---|
| Llama3.2-Vision | 11B [models, configs] |
| Llama3.2 | 1B, 3B [models, configs] |
| Llama3.1 | 8B, 70B, 405B [models, configs] |
| Llama3 | 8B, 70B [models, configs] |
| Llama2 | 7B, 13B, 70B [models, configs] |
| Code-Llama2 | 7B, 13B, 70B [models, configs] |
| Mistral | 7B [models, configs] |
| Gemma | 2B, 7B [models, configs] |
| Microsoft Phi3 | Mini [models, configs] |
| Qwen2 | 0.5B, 1.5B, 7B [models, configs] |

We're always adding new models, but feel free to file an issue if there's a new one you would like to see in torchtune.

 

Finetuning recipes

torchtune provides the following finetuning recipes for training on one or more devices.

| Finetuning Method | Devices | Recipe | Example Config(s) |
|---|---|---|---|
| Full Finetuning | 1-8 | full_finetune_single_device, full_finetune_distributed | Llama3.1 8B single-device, Llama 3.1 70B distributed |
| LoRA Finetuning | 1-8 | lora_finetune_single_device, lora_finetune_distributed | Qwen2 0.5B single-device, Gemma 7B distributed |
| QLoRA Finetuning | 1-8 | lora_finetune_single_device, lora_finetune_distributed | Phi3 Mini single-device, Llama 3.1 405B distributed |
| DoRA/QDoRA Finetuning | 1-8 | lora_finetune_single_device, lora_finetune_distributed | Llama3 8B QDoRA single-device, Llama3 8B DoRA distributed |
| Quantization-Aware Training | 4-8 | qat_distributed | Llama3 8B QAT |
| Direct Preference Optimization | 1-8 | lora_dpo_single_device, lora_dpo_distributed | Llama2 7B single-device, Llama2 7B distributed |
| Proximal Policy Optimization | 1 | ppo_full_finetune_single_device | Mistral 7B |
| Knowledge Distillation | 1 | knowledge_distillation_single_device | Qwen2 1.5B -> 0.5B |

The configs above are just examples to get you started. If a model from the list of supported models does not appear in this table, we likely still support it. If you're unsure whether something is supported, please open an issue on the repo.
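
The table above can be read as a lookup from finetuning method and device count to recipe name. The sketch below encodes three of its rows in a plain dictionary. Only the recipe names come from the table; `RECIPES` and `pick_recipe` are hypothetical names introduced for illustration and are not part of torchtune.

```python
# Recipe names copied from the table above; the helper itself is illustrative only.
RECIPES = {
    "full": ("full_finetune_single_device", "full_finetune_distributed"),
    "lora": ("lora_finetune_single_device", "lora_finetune_distributed"),
    "dpo": ("lora_dpo_single_device", "lora_dpo_distributed"),
}


def pick_recipe(method: str, num_devices: int) -> str:
    """Return the single-device recipe for one device, else the distributed one."""
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    single, distributed = RECIPES[method]
    return single if num_devices == 1 else distributed


if __name__ == "__main__":
    print(pick_recipe("lora", 1))
    print(pick_recipe("full", 4))
```

The resulting name is what you would pass to tune run as the recipe.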

 

Memory and training speed

Below is an example of the memory requirements and training speed for different Llama 3.1 models.

[!NOTE] For ease of comparison, all numbers below are reported for batch size 2 (without gradient accumulation), a dataset packed to sequence length 2048, and torch.compile enabled.

If you are interested in running on different hardware or with different models, check out our documentation on memory optimizations to find the right setup for you.

| Model | Finetuning Method | Runnable On | Peak Memory per GPU | Tokens/sec * |
|---|---|---|---|---|
| Llama 3.1 8B | Full finetune | 1x 4090 | 18.9 GiB | 1650 |
| Llama 3.1 8B | Full finetune | 1x A6000 | 37.4 GiB | 2579 |
| Llama 3.1 8B | LoRA | 1x 4090 | 16.2 GiB | 3083 |
| Llama 3.1 8B | LoRA | 1x A6000 | 30.3 GiB | 4699 |
| Llama 3.1 8B | QLoRA | 1x 4090 | 7.4 GiB | 2413 |
| Llama 3.1 70B | Full finetune | 8x A100 | 13.9 GiB ** | 1568 |
| Llama 3.1 70B | LoRA | 8x A100 | 27.6 GiB | 3497 |
| Llama 3.1 405B | QLoRA | 8x A100 | 44.8 GB | 653 |

* = Measured over one full training epoch

** = Uses CPU offload with fused optimizer
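
To compare rows of the table, it can help to turn tokens/sec into relative throughput and a rough wall-clock estimate. The sketch below does this for the three Llama 3.1 8B rows measured on 1x 4090. The throughput figures are copied from the table; the 10-million-token dataset size and the helper names are hypothetical.

```python
# Tokens/sec figures copied from the table above (Llama 3.1 8B on 1x 4090).
TOKENS_PER_SEC = {"full": 1650, "lora": 3083, "qlora": 2413}


def speedup(method: str, baseline: str = "full") -> float:
    """Throughput of `method` relative to `baseline`."""
    return TOKENS_PER_SEC[method] / TOKENS_PER_SEC[baseline]


def epoch_hours(num_tokens: int, method: str) -> float:
    """Estimated hours to process `num_tokens` at the reported throughput."""
    return num_tokens / TOKENS_PER_SEC[method] / 3600


if __name__ == "__main__":
    # 10M tokens is a hypothetical dataset size, not a number from the table.
    for method in TOKENS_PER_SEC:
        print(f"{method}: {speedup(method):.2f}x vs full, "
              f"{epoch_hours(10_000_000, method):.2f} h per 10M tokens")
```

These estimates only hold for the settings stated in the note above (batch size 2, sequence length 2048, compile enabled); other setups will differ.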

 

Installation

torchtune is tested with the latest stable PyTorch release as well as the preview nightly version. torchtune leverages torchvision for finetuning multimodal LLMs and torchao for the latest in quantization techniques; you should install these as well.

 

Install stable release

# Install stable releases of PyTorch, torchvision, and torchao
pip install torch torchvision torchao
pip install torchtune

 

Install nightly release

# Install PyTorch, torchvision, torchao nightlies
pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
pip install --pre --upgrade torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu

You can also check out our install documentation for more information, including installing torchtune from source.

 

To confirm that the package is installed correctly, you can run the following command:

tune --help

You should see the following output:

usage: tune [-h] {ls,cp,download,run,validate} ...

Welcome to the torchtune CLI!

options:
  -h, --help            show this help message and exit

...
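
The same check can be scripted from Python. The sketch below uses only the standard library to report the installed versions of the four packages named in the pip commands above and to look for the tune entry point on PATH. `installed_versions` is a hypothetical helper, not part of torchtune.

```python
import shutil
from importlib import metadata

# Distribution names taken from the pip commands above.
PACKAGES = ("torch", "torchvision", "torchao", "torchtune")


def installed_versions(packages=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    # shutil.which returns None when the `tune` entry point is not on PATH.
    print("tune CLI:", shutil.which("tune") or "not found on PATH")
```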

 

Get Started

To get started with torchtune, see our First Finetune Tutorial. Our End-to-End Workflow Tutorial will show you how to evaluate, quantize and run inference with a Llama model. The rest of this section will provide a quick overview of these steps with Llama3.1.

Downloading a model

Follow the instructions on the official meta-llama repository to ensure you have access to the official Llama model weights. Once you have confirmed access, you can run the following command to download the weights to your local machine. This will also download the tokenizer model and a responsible use guide.

To download Llama3.1, you can run:

tune download meta-llama/Meta-Llama-3.1-8B-Instruct \
--output-dir /tmp/Meta-Llama-3.1-8B-Instruct \
--hf-token <HF_TOKEN>

[!TIP] Set the HF_TOKEN environment variable or pass --hf-token to the command to validate your access. You can find your token at https://huggingface.co/settings/tokens
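
If you launch the download from a script, one option is to assemble the argument list in Python. The sketch below uses only the flags shown in the command above (`--output-dir` and `--hf-token`). `build_download_command` is a hypothetical helper, and it omits the token flag when none is passed, on the assumption that HF_TOKEN is set in the environment as the tip describes.

```python
import shlex
from typing import List, Optional


def build_download_command(
    repo_id: str, output_dir: str, token: Optional[str] = None
) -> List[str]:
    """Assemble the argv for `tune download`, using only flags shown above."""
    cmd = ["tune", "download", repo_id, "--output-dir", output_dir]
    # Only add the flag for an explicit token; otherwise rely on HF_TOKEN.
    if token is not None:
        cmd += ["--hf-token", token]
    return cmd


if __name__ == "__main__":
    cmd = build_download_command(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "/tmp/Meta-Llama-3.1-8B-Instruct",
    )
    # Printing keeps the sketch side-effect free; pass `cmd` to subprocess.run to execute it.
    print(shlex.join(cmd))
```

Relying on the environment variable also keeps the token out of your shell history and logs.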

 

Running finetuning recipes

You can finetune Llama3.1 8B with LoRA on a single GPU using the following command:

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

For distributed training, the tune CLI integrates with torchrun. To run a full finetune of Llama3.1 8B on two GPUs:

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full

[!TIP] Make sure to place any torchrun arguments before the recipe name. Any CLI arguments after the recipe name override the config and do not affect distributed training.
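
The ordering rule in the tip can be made explicit in code. The sketch below assembles a tune run argument list with torchrun arguments first, then the recipe, then the config and any overrides. `build_run_command` is a hypothetical helper written for this illustration, not a torchtune API.

```python
from typing import List, Sequence


def build_run_command(
    recipe: str,
    config: str,
    torchrun_args: Sequence[str] = (),
    overrides: Sequence[str] = (),
) -> List[str]:
    """Assemble `tune run` argv: torchrun args, then recipe, then config and overrides."""
    return ["tune", "run", *torchrun_args, recipe, "--config", config, *overrides]


if __name__ == "__main__":
    cmd = build_run_command(
        "full_finetune_distributed",
        "llama3_1/8B_full",
        torchrun_args=["--nproc_per_node", "2"],
    )
    # Reproduces the two-GPU command shown above.
    print(" ".join(cmd))
```

Keeping the two groups as separate parameters makes it hard to put a torchrun flag after the recipe name by accident.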

 

Modify Configs

There are two ways in which you can modify configs:

Config Overrides

You can directly overwrite config fields from the command line:

tune run lora_finetune_single_device \
--config llama2/7B_lora_single_device \
batch_size=8 \
enable_activation_checkpointing=True \
max_steps_per_epoch=128
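
Conceptually, each key=value pair replaces the field of the same name in the loaded config. The sketch below illustrates that idea with plain dictionaries. It is a simplified stand-in and not torchtune's actual config parser; `parse_value` and `apply_overrides` are hypothetical helpers, and the base values are invented for illustration.

```python
import ast
from typing import Any, Dict, Iterable


def parse_value(text: str) -> Any:
    """Interpret '8' as int and 'True' as bool; fall back to the raw string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def apply_overrides(config: Dict[str, Any], overrides: Iterable[str]) -> Dict[str, Any]:
    """Return a copy of `config` with each 'key=value' override applied."""
    merged = dict(config)
    for item in overrides:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        merged[key] = parse_value(value)
    return merged


if __name__ == "__main__":
    # Base values are invented for illustration, not copied from a real config.
    base = {"batch_size": 2, "enable_activation_checkpointing": False}
    overrides = [
        "batch_size=8",
        "enable_activation_checkpointing=True",
        "max_steps_per_epoch=128",
    ]
    print(apply_overrides(base, overrides))
```

Fields that are not overridden keep the value from the config file.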

Update a Local Copy

You can also copy the config to your local directory and modify the contents directly:

tune cp llama3_1/8B_full ./my_custom_config.yaml
Copied to ./my_custom_config.yaml

Then, you can run the recipe with your custom config by pointing the tune run command at your local file:

tune run full_finetune_distributed --config ./my_custom_config.yaml

 

Check out tune --help for all possible CLI commands and options. For more information on using and updating configs, take a look at our config deep-dive.

 

Custom Datasets

torchtune supports finetuning on a variety of different datasets, including instruct-style, chat-style, preference datasets, and more. If you want to learn more about how to apply these components to finetune on your own custom dataset, please check out the provided links along with our API docs.

 

Community

torchtune focuses on integrating with popular tools and libraries from the ecosystem. These are just a few examples, with more under development:

 

Community Contributions

We really value our community and the contributions made by our wonderful users. We'll use this section to call out some of these contributions. If you'd like to help out as well, please see the CONTRIBUTING guide.

 

Acknowledgements

The Llama2 code in this repository is inspired by the original Llama2 code.

We want to give a huge shout-out to EleutherAI, Hugging Face and Weights & Biases for being wonderful collaborators and for working with us on some of these integrations within torchtune.

We also want to acknowledge some awesome libraries and tools from the ecosystem:

  • gpt-fast for performant LLM inference techniques which we've adopted out-of-the-box
  • llama recipes for spring-boarding the llama2 community
  • bitsandbytes for bringing several memory and performance based techniques to the PyTorch ecosystem
  • @winglian and axolotl for early feedback and brainstorming on torchtune's design and feature set.
  • lit-gpt for pushing the LLM finetuning community forward.
  • HF TRL for making reward modeling more accessible to the PyTorch community.

 

License

torchtune is released under the BSD 3 license. However, you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See tutorial on generating distribution archives.

Built Distribution

torchtune-0.3.1-py3-none-any.whl (596.6 kB view details)

Uploaded Python 3

File details

Details for the file torchtune-0.3.1-py3-none-any.whl.

File metadata

  • Download URL: torchtune-0.3.1-py3-none-any.whl
  • Upload date:
  • Size: 596.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.7

File hashes

Hashes for torchtune-0.3.1-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | b409f2bb180ed444d50e2600c854654a8954dbfab9d655cf1461eeedf6a74094 |
| MD5 | 7b49937abfd6d2665ca86c6b5cec06a9 |
| BLAKE2b-256 | 645e09722b84ce4ee8eaec0d9b06a653a3c6ad0cb6efdbd23f1cbdeff834b567 |

See more details on using hashes here.
