Accelerate LLM preference finetuning with a single line of code

Flash Preference

Accelerate LLM preference tuning via prefix sharing with a single line of code. Applicable to Direct Preference Optimization (DPO), Reward Modeling (RM), Group Relative Policy Optimization (GRPO), etc.

Getting Started

Install the stable version from PyPI:

pip install flash-pref

Or install the latest version from GitHub:

pip install git+https://github.com/li-plus/flash-preference.git@main

All you have to do is wrap the model's forward and backward passes in a shared_prefix context. The common prefixes of the input sequences are detected automatically and shared, which reduces computation and memory footprint without loss of accuracy.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from flash_pref import shared_prefix

model_id = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", use_cache=False, torch_dtype=torch.bfloat16, device_map="cuda"
)

prompt = "What are the next 10 numbers of this sequence: " + ", ".join(str(x) for x in range(500))
chosen_response = ", ".join(str(x) for x in range(500, 500 + 10))
rejected_response = ", ".join(str(x) for x in range(500, 500 + 10, 2))

conversations = [
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen_response}],
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected_response}],
]
inputs = tokenizer.apply_chat_template(
    conversations, tokenize=True, padding=True, return_tensors="pt", return_dict=True
).to("cuda")

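# Wrap the forward and backward passes so the common prompt prefix is shared.
with shared_prefix(model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask):
    output = model(**inputs)
    # A random gradient stands in for a real loss in this demo.
    output.logits.backward(torch.randn_like(output.logits))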
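To make the idea of a "common prefix" concrete, here is a minimal pure-Python sketch of how a shared prefix could be detected across token sequences. It is only an illustration, not the library's implementation: the helper names `common_prefix_length` and `split_shared_prefix` and the toy token IDs are hypothetical.

```python
from typing import List, Sequence, Tuple


def common_prefix_length(sequences: Sequence[Sequence[int]]) -> int:
    """Return the number of leading tokens shared by every sequence."""
    if not sequences:
        return 0
    length = 0
    # zip stops at the shortest sequence, so we never index out of range.
    for column in zip(*sequences):
        if any(token != column[0] for token in column):
            break
        length += 1
    return length


def split_shared_prefix(
    sequences: Sequence[Sequence[int]],
) -> Tuple[List[int], List[List[int]]]:
    """Split sequences into one shared prefix and the per-sequence suffixes."""
    n = common_prefix_length(sequences)
    prefix = list(sequences[0][:n]) if sequences else []
    suffixes = [list(seq[n:]) for seq in sequences]
    return prefix, suffixes


# Toy token IDs (hypothetical): the first four tokens play the role of the prompt.
chosen = [101, 7, 8, 9, 42, 43]
rejected = [101, 7, 8, 9, 55]

prefix, suffixes = split_shared_prefix([chosen, rejected])
print(prefix)    # [101, 7, 8, 9]
print(suffixes)  # [[42, 43], [55]]
```

In a chosen/rejected pair the prompt tokens form the prefix, so only the two response suffixes differ between the sequences.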

For huggingface/trl users, drop-in replacements for the trl trainers are also available. Check out the end-to-end training examples below.

Algorithm | Original Trainer | Accelerated Trainer with Prefix Sharing | Example
Direct Preference Optimization | from trl import DPOTrainer | from flash_pref import FlashDPOTrainer | examples/dpo_trl.py
Reward Modeling | from trl import RewardTrainer | from flash_pref import FlashRewardTrainer | examples/reward_modeling_trl.py
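Because the accelerated trainers are described as drop-in replacements, a training script could choose between them at import time. The sketch below assumes that FlashDPOTrainer accepts the same constructor arguments as DPOTrainer, which the "drop-in" wording suggests but which you should confirm against examples/dpo_trl.py. The helper `pick_dpo_trainer` is hypothetical and not part of either library.

```python
import importlib
from typing import Optional


def pick_dpo_trainer() -> Optional[type]:
    """Prefer the prefix-sharing trainer and fall back to the plain trl trainer.

    Returns None when neither package can be imported.
    """
    candidates = (
        ("flash_pref", "FlashDPOTrainer"),  # accelerated trainer
        ("trl", "DPOTrainer"),              # original trainer
    )
    for module_name, class_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # package not installed, try the next candidate
        trainer_cls = getattr(module, class_name, None)
        if trainer_cls is not None:
            return trainer_cls
    return None


trainer_cls = pick_dpo_trainer()
if trainer_cls is None:
    print("Neither flash_pref nor trl is installed")
else:
    print(f"Using {trainer_cls.__module__}.{trainer_cls.__name__}")
```

The rest of the script would then construct `trainer_cls(...)` exactly as it would construct a DPOTrainer.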

Benchmark

[Figure: performance speedup and memory saved relative to the baseline]

The benchmark settings are listed below. See tests/benchmark.py for more details.

  • Model: Qwen/Qwen2.5-7B-Instruct with gradient checkpointing, Liger-Kernel and FlashAttention-2 enabled.
  • Data: mocked pairwise preference data where prompt and response lengths vary from 64 to 16k.
  • Computation: 1 forward pass and then 1 backward pass.
  • Hardware: 1x NVIDIA A800-SXM4-80GB GPU.
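As a rough guide to when prefix sharing helps most, the sketch below counts the tokens processed for one chosen/rejected pair with and without a shared prompt. It is back-of-the-envelope arithmetic, not a measured result: token count is only a proxy for compute and memory, "16k" is assumed to mean 16384 tokens, and the function names are hypothetical.

```python
from typing import Tuple


def pairwise_token_counts(prompt_len: int, chosen_len: int, rejected_len: int) -> Tuple[int, int]:
    """Return (baseline, shared) token counts for one preference pair.

    baseline: the prompt is processed twice, once per response.
    shared:   the prompt is processed once and reused by both responses.
    """
    baseline = 2 * prompt_len + chosen_len + rejected_len
    shared = prompt_len + chosen_len + rejected_len
    return baseline, shared


def fraction_saved(prompt_len: int, chosen_len: int, rejected_len: int) -> float:
    """Fraction of baseline tokens that sharing the prompt avoids."""
    baseline, shared = pairwise_token_counts(prompt_len, chosen_len, rejected_len)
    return (baseline - shared) / baseline


# Long prompt, short responses: nearly half of the tokens are avoided.
print(f"{fraction_saved(16384, 64, 64):.1%}")        # 49.8%
# Short prompt, long responses: almost nothing is shared.
print(f"{fraction_saved(64, 16384, 16384):.1%}")     # 0.2%
```

The saving approaches 50% for a pair as the prompt grows relative to the responses, and shrinks toward zero when the responses dominate.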

Developing

Unit Tests

The unit tests currently cover the LLaMA, Gemma, Gemma2, Qwen2, Qwen2VL and Qwen2.5VL architectures and require at least 2 GPUs. To run them:

make test

Code Format

To format the code, run:

make lint

License

This project is released under the MIT License.
