Accelerate LLM preference finetuning with a single line of code
Flash Preference
Accelerate LLM preference tuning via prefix sharing with a single line of code. Applicable to Direct Preference Optimization (DPO), Reward Modeling (RM), Group Relative Policy Optimization (GRPO), etc.
Getting Started
Install the stable version from PyPI:
```sh
pip install flash-pref
```
Or install the latest version from GitHub:
```sh
pip install git+https://github.com/li-plus/flash-preference.git@main
```
All you need to do is wrap the model forward and backward passes in a `shared_prefix` context. Common prefixes of the input sequences are detected and shared automatically, reducing computation and memory footprint without loss of accuracy.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from flash_pref import shared_prefix

model_id = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", use_cache=False, torch_dtype=torch.bfloat16, device_map="cuda"
)

prompt = "What are the next 10 numbers of this sequence: " + ", ".join(str(x) for x in range(500))
chosen_response = ", ".join(str(x) for x in range(500, 500 + 10))
rejected_response = ", ".join(str(x) for x in range(500, 500 + 10, 2))
conversations = [
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen_response}],
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected_response}],
]
inputs = tokenizer.apply_chat_template(
    conversations, tokenize=True, padding=True, return_tensors="pt", return_dict=True
).to("cuda")

# ===== MAGIC HERE =====
with shared_prefix(model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask):
    output = model(**inputs)
    output.logits.backward(torch.randn_like(output.logits))
```
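The README does not spell out how the shared prefix is found or laid out internally. As a minimal sketch of the general idea, assuming right-padded batches like the one above and plain Python lists instead of tensors, the helpers below find the longest prefix common to all rows, then pack the batch into a single sequence that holds the prefix once, followed by each row's own suffix. `shared_prefix_length` and `pack_with_shared_prefix` are hypothetical names for illustration, not part of the `flash_pref` API.

```python
from typing import List, Tuple


def shared_prefix_length(input_ids: List[List[int]], attention_mask: List[List[int]]) -> int:
    """Length of the longest token prefix shared by every row of a right-padded batch.

    Hypothetical helper for illustration; not the flash_pref implementation.
    """
    # With right padding, the real (non-pad) tokens of each row come first.
    lengths = [sum(row_mask) for row_mask in attention_mask]
    prefix_len = 0
    for pos in range(min(lengths)):
        token = input_ids[0][pos]
        if any(row[pos] != token for row in input_ids):
            break
        prefix_len += 1
    return prefix_len


def pack_with_shared_prefix(
    input_ids: List[List[int]], attention_mask: List[List[int]]
) -> Tuple[List[int], List[int]]:
    """Pack a batch into one sequence: the shared prefix once, then each row's suffix.

    Returns (packed token ids, position ids). Each suffix continues its positions
    right after the prefix, so its tokens keep the positions they had in the batch.
    """
    prefix_len = shared_prefix_length(input_ids, attention_mask)
    packed_ids = list(input_ids[0][:prefix_len])
    position_ids = list(range(prefix_len))
    for row, row_mask in zip(input_ids, attention_mask):
        length = sum(row_mask)  # number of real tokens in this row
        packed_ids.extend(row[prefix_len:length])
        position_ids.extend(range(prefix_len, length))
    return packed_ids, position_ids


# Toy batch: two sequences sharing the prefix [1, 2, 3]; pad id 0 on the right.
ids = [[1, 2, 3, 4, 5, 0], [1, 2, 3, 7, 8, 9]]
mask = [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]]
print(shared_prefix_length(ids, mask))  # 3
packed_ids, position_ids = pack_with_shared_prefix(ids, mask)
print(packed_ids)    # [1, 2, 3, 4, 5, 7, 8, 9]
print(position_ids)  # [0, 1, 2, 3, 4, 3, 4, 5]
```

The packed sequence holds 8 tokens instead of 5 + 6 = 11. This sketch only builds token and position ids; for the outputs to match the unpacked batch, attention must also be restricted so that each suffix attends to the prefix and to itself, but not to the other suffixes.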
For huggingface/trl users, drop-in replacements for the TRL trainers are also available. Check out the end-to-end training examples below.
| Algorithm | Original Trainer | Accelerated Trainer with Prefix Sharing | Example |
|---|---|---|---|
| Direct Preference Optimization | `from trl import DPOTrainer` | `from flash_pref import FlashDPOTrainer` | examples/dpo_trl.py |
| Reward Modeling | `from trl import RewardTrainer` | `from flash_pref import FlashRewardTrainer` | examples/reward_modeling_trl.py |
Benchmark
The performance speedup and memory saved relative to the baseline:
The benchmark settings are listed below. Please refer to tests/benchmark.py for more details.
- Model: Qwen/Qwen2.5-7B-Instruct with gradient checkpointing, Liger-Kernel and FlashAttention-2 enabled.
- Data: mocked pairwise preference data where prompt and response lengths vary from 64 to 16k.
- Computation: 1 forward pass and then 1 backward pass.
- Hardware: 1x NVIDIA A800-SXM4-80GB GPU.
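To see why the gains depend on how much of each sequence is prompt, here is a rough count of the token positions fed through the model. It rests on the simplifying assumption that cost is proportional to token count, ignoring the quadratic attention term and all fixed overheads, so it is not a prediction of the measured results and is not taken from tests/benchmark.py. For a prompt of length P and responses of lengths R1, ..., Rn, the baseline processes n * P + (R1 + ... + Rn) tokens, while sharing the prompt processes P + (R1 + ... + Rn). `tokens_processed` is a hypothetical helper.

```python
from typing import List


def tokens_processed(prompt_len: int, response_lens: List[int], share_prefix: bool) -> int:
    """Count token positions run through the model for one prompt and its responses.

    Hypothetical helper for a back-of-the-envelope estimate only.
    """
    if share_prefix:
        # The prompt is computed once and followed by every response.
        return prompt_len + sum(response_lens)
    # Baseline: every (prompt, response) pair is a separate sequence.
    return sum(prompt_len + r for r in response_lens)


# A chosen/rejected pair with equal response lengths, at a few length mixes.
for prompt_len, response_len in [(16384, 64), (4096, 4096), (64, 16384)]:
    pair = [response_len, response_len]
    baseline = tokens_processed(prompt_len, pair, share_prefix=False)
    shared = tokens_processed(prompt_len, pair, share_prefix=True)
    print(f"prompt={prompt_len:>5} response={response_len:>5}: "
          f"{baseline} -> {shared} tokens (ratio {baseline / shared:.2f})")
```

With a 16k prompt and short responses, the pair costs roughly half as many tokens; with a 64-token prompt and 16k responses, sharing saves almost nothing. The same count extends to groups of more than two responses per prompt, as in GRPO.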
Developing
Unit Tests
Currently tested on the LLaMA, Gemma, Gemma2, Qwen2, Qwen2VL and Qwen2.5VL architectures. The unit tests require at least 2 GPUs. To run them, type:
```sh
make test
```
Code Format
To format the code, type:
```sh
make lint
```
License
This project is released under the MIT License.