Accelerate LLM preference finetuning with a single line of code

Flash Preference

Accelerate LLM preference tuning via prefix sharing with a single line of code. It applies to Direct Preference Optimization (DPO), Reward Modeling (RM), Group Relative Policy Optimization (GRPO), and similar methods.
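Preference objectives such as DPO score a chosen and a rejected response to the same prompt, so every training pair repeats the prompt tokens. That repeated prompt is the main common prefix. For reference, here is a minimal pure-Python sketch of the standard DPO loss for one pair, given summed response log-probabilities from the trained policy and a frozen reference model. It is not part of flash_pref. The function name `dpo_loss` and the default `beta=0.1` are illustrative choices.

```python
import math


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """DPO loss for one (chosen, rejected) pair.

    Each *_logp argument is the summed log-probability of a response given
    the shared prompt, under the trained policy or the frozen reference model.
    """
    # Implicit reward margin: how much more the policy (relative to the
    # reference) prefers the chosen response over the rejected one.
    margin = (policy_chosen_logp - ref_chosen_logp) - (policy_rejected_logp - ref_rejected_logp)
    z = -beta * margin
    # -log(sigmoid(beta * margin)) == softplus(z), computed without overflow.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


# The policy shifts probability toward the chosen response, so the loss is below log(2).
print(dpo_loss(-10.0, -12.0, -11.0, -11.0))
```

When the policy and reference agree (zero margin) the loss is exactly log(2); it falls toward 0 as the margin grows.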

Getting Started

Install the stable version from PyPI:

```sh
pip install flash-pref
```

Or install the latest version from GitHub:

```sh
pip install git+https://github.com/li-plus/flash-preference.git@main
```

All you have to do is wrap the model's forward and backward passes in a shared_prefix context. The common prefixes of the input sequences are detected and shared automatically, which reduces computation and memory footprint without loss of accuracy.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from flash_pref import shared_prefix

model_id = "Qwen/Qwen2.5-7B-Instruct"
# Right padding keeps the shared prompt aligned at the start of every sequence.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", use_cache=False, torch_dtype=torch.bfloat16, device_map="cuda"
)

# A long prompt shared by a correct (chosen) and an incorrect (rejected) continuation.
prompt = "What are the next 10 numbers of this sequence: " + ", ".join(str(x) for x in range(500))
chosen_response = ", ".join(str(x) for x in range(500, 500 + 10))
rejected_response = ", ".join(str(x) for x in range(500, 500 + 10, 2))

conversations = [
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen_response}],
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected_response}],
]
inputs = tokenizer.apply_chat_template(
    conversations, tokenize=True, padding=True, return_tensors="pt", return_dict=True
).to("cuda")

# Wrap both the forward and the backward pass so the common prefix is shared.
with shared_prefix(model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask):
    output = model(**inputs)
    # A random upstream gradient stands in for a real preference loss such as DPO.
    output.logits.backward(torch.randn_like(output.logits))
```
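flash_pref detects the common prefix automatically, and its internal logic is not described here. Purely to illustrate the idea, the hypothetical helper `shared_prefix_length` below counts how many leading tokens all sequences in a right-padded batch share, using plain Python lists in place of the `input_ids` and `attention_mask` tensors.

```python
from typing import List


def shared_prefix_length(input_ids: List[List[int]], attention_mask: List[List[int]]) -> int:
    """Number of leading tokens shared by every sequence in a right-padded batch."""
    # With right padding, the real tokens of each row are its first sum(mask) ids.
    seqs = [ids[: sum(mask)] for ids, mask in zip(input_ids, attention_mask)]
    if not seqs:
        return 0
    limit = min(len(s) for s in seqs)
    n = 0
    # Advance while every sequence agrees on the token at position n.
    while n < limit and all(s[n] == seqs[0][n] for s in seqs):
        n += 1
    return n


# Toy batch: a 3-token "prompt" followed by different "responses"; 0 is padding.
ids = [[1, 2, 3, 4, 5, 0], [1, 2, 3, 7, 8, 9]]
mask = [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]]
print(shared_prefix_length(ids, mask))  # → 3
```

In a prefix-sharing scheme, this toy pair would need 3 + 2 + 3 = 8 token positions instead of 5 + 6 = 11.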

Benchmark

[Figure: performance speedup and memory savings relative to the baseline]

The benchmark settings are as follows. Please refer to tests/benchmark.py for more details.

  • Model: Qwen/Qwen2.5-7B-Instruct with gradient checkpointing, Liger-Kernel and FlashAttention-2 enabled.
  • Data: mocked pairwise preference data with prompt and response lengths ranging from 64 to 16k tokens.
  • Computation: 1 forward pass and then 1 backward pass.
  • Hardware: 1x NVIDIA A800-SXM4-80GB GPU.

Developing

Unit Tests

Currently tested on the LLaMA, Gemma, Gemma2, Qwen2, Qwen2VL and Qwen2.5VL architectures. Running the unit tests requires at least 2 GPUs. To run them, type:

```sh
make test
```

Code Format

To format the code, type:

```sh
make lint
```

License

This project is licensed under the MIT License.

Download files

Source Distribution

flash_pref-0.1.0.tar.gz (12.8 kB)

Built Distribution

flash_pref-0.1.0-py3-none-any.whl (8.6 kB)

File details

Details for the file flash_pref-0.1.0.tar.gz.

File metadata

  • Download URL: flash_pref-0.1.0.tar.gz
  • Size: 12.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for flash_pref-0.1.0.tar.gz
  • SHA256: 689660a328ebe1acbd612941f3f2b392e81bef0ae5d88c59e04e6c04ac0b36d4
  • MD5: d3c7a5937ce923ccde5f53c33af403b1
  • BLAKE2b-256: 01fe17ca5a4d445a7391f8ce79a59fa591b78f678dcd2d84b58b4be7e392a180

Provenance

The following attestation bundles were made for flash_pref-0.1.0.tar.gz:

Publisher: python-publish.yml on li-plus/flash-preference

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file flash_pref-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: flash_pref-0.1.0-py3-none-any.whl
  • Size: 8.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for flash_pref-0.1.0-py3-none-any.whl
  • SHA256: 640b4cc77b2e5eee0386e6561949d402def5d3f8b85b72c36ec1d1a7c169d060
  • MD5: 909b4a49f30e05cbaba8bdd16ea9eafc
  • BLAKE2b-256: 5b86bc5650e960a4563202fd1a584979542e7417f412ed05f028ae3728f2fa6b

Provenance

The following attestation bundles were made for flash_pref-0.1.0-py3-none-any.whl:

Publisher: python-publish.yml on li-plus/flash-preference

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
