
Easy-to-use fine-tuning framework using PEFT


LLaMA Efficient Tuning


👋 Join our WeChat.


Changelog

[23/09/10] We now support FlashAttention for the LLaMA models. Use the --flash_attn argument to enable FlashAttention-2 on RTX 4090, A100 or H100 GPUs (experimental feature).

[23/08/18] We now support resuming training. Upgrade transformers to 4.31.0 to use this feature.

[23/08/12] We now support RoPE scaling to extend the context length of the LLaMA models. Use the --rope_scaling linear argument for training and the --rope_scaling dynamic argument at inference to extrapolate the position embeddings.

[23/08/11] We now support DPO training for instruction-tuned models. See this example to train your models.

[23/07/31] We now support dataset streaming. Use the --streaming and --max_steps 10000 arguments to load your dataset in streaming mode (a streamed dataset has no known length, so the number of training steps must be given explicitly).

[23/07/29] We released two instruction-tuned 13B models on Hugging Face. See these Hugging Face repos (LLaMA-2 / Baichuan) for details.

[23/07/18] We developed an all-in-one Web UI for training, evaluation and inference. Run train_web.py to fine-tune models in your web browser. Thanks to @KanadeSiina and @codemayq for their efforts in its development.

[23/07/09] We released FastEdit ⚡🩹, an easy-to-use package for efficiently editing the factual knowledge of large language models. Please follow FastEdit if you are interested.

[23/06/29] We provide a reproducible example of training a chat model on instruction-following datasets; see Baichuan-7B-sft for details.

[23/06/22] We aligned the demo API with OpenAI's format, so you can plug the fine-tuned model into any ChatGPT-based application.

[23/06/03] We now support quantized training and inference (aka QLoRA). Use the --quantization_bit 4/8 argument to work with quantized models.

Supported Models

| Model | Model size | Default module | Template |
| ----- | ---------- | -------------- | -------- |
| LLaMA | 7B/13B/33B/65B | q_proj,v_proj | - |
| LLaMA-2 | 7B/13B/70B | q_proj,v_proj | llama2 |
| BLOOM | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| BLOOMZ | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| Falcon | 7B/40B | query_key_value | - |
| Baichuan | 7B/13B | W_pack | baichuan |
| Baichuan2 | 7B/13B | W_pack | baichuan2 |
| InternLM | 7B | q_proj,v_proj | intern |
| Qwen | 7B | c_attn | chatml |
| XVERSE | 13B | q_proj,v_proj | xverse |
| ChatGLM2 | 6B | query_key_value | chatglm2 |

[!NOTE] The default module is used for the --lora_target argument. You can use --lora_target all to target all available modules.

For "base" models, the --template argument can be chosen from default, alpaca, vicuna, etc. For "chat" models, make sure to use the corresponding template.
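As a sketch of how the table and the note above fit together, the helper below picks the `--lora_target` and `--template` flags for one model family. The dictionaries are transcribed from the Supported Models table; the helper name `lora_cli_args` and its string keys are hypothetical and not part of llmtuner.

```python
from typing import List

# Default --lora_target values, transcribed from the Supported Models table.
# The model-family keys and the helper below are hypothetical, for illustration only.
DEFAULT_LORA_TARGETS = {
    "LLaMA": "q_proj,v_proj",
    "LLaMA-2": "q_proj,v_proj",
    "BLOOM": "query_key_value",
    "BLOOMZ": "query_key_value",
    "Falcon": "query_key_value",
    "Baichuan": "W_pack",
    "Baichuan2": "W_pack",
    "InternLM": "q_proj,v_proj",
    "Qwen": "c_attn",
    "XVERSE": "q_proj,v_proj",
    "ChatGLM2": "query_key_value",
}

# Template column of the same table; families listed with "-" are omitted.
CHAT_TEMPLATES = {
    "LLaMA-2": "llama2",
    "Baichuan": "baichuan",
    "Baichuan2": "baichuan2",
    "InternLM": "intern",
    "Qwen": "chatml",
    "XVERSE": "xverse",
    "ChatGLM2": "chatglm2",
}


def lora_cli_args(model: str, is_chat: bool, all_modules: bool = False) -> List[str]:
    """Return --lora_target and --template flags for one model family."""
    target = "all" if all_modules else DEFAULT_LORA_TARGETS[model]
    if is_chat:
        # Chat models must use their own template.
        if model not in CHAT_TEMPLATES:
            raise ValueError(f"no chat template listed for {model}")
        template = CHAT_TEMPLATES[model]
    else:
        # Base models can use a generic template such as default, alpaca or vicuna.
        template = "default"
    return ["--lora_target", target, "--template", template]
```

For example, `lora_cli_args("Baichuan2", is_chat=True)` yields the flags for a Baichuan2 chat checkpoint, while asking for a chat template for a family marked "-" raises an error instead of silently guessing.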

Supported Training Approaches

| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
| -------- | -------------- | ----------------- | ---- | ----- |
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
| PPO Training | | | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |

[!NOTE] Use --quantization_bit 4/8 argument to enable QLoRA.
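To give a rough sense of why --quantization_bit helps, the arithmetic below estimates only the memory needed to store the frozen base weights at 16, 8 and 4 bits. It is a simplification: it ignores activations, LoRA parameters, optimizer states and the small per-block constants that quantization schemes store, so real usage will be higher.

```python
def weight_memory_gb(num_params: float, bits: int) -> float:
    """Approximate memory (decimal GB) needed just to hold the model weights."""
    bytes_per_param = bits / 8
    return num_params * bytes_per_param / 1e9


for bits in (16, 8, 4):
    # fp16 baseline vs. --quantization_bit 8 and --quantization_bit 4
    print(f"7B model, {bits}-bit weights: ~{weight_memory_gb(7e9, bits):.1f} GB")
```

For a 7B model this gives about 14.0 GB at 16 bits, 7.0 GB at 8 bits and 3.5 GB at 4 bits, which is the main reason a 4-bit QLoRA run fits on far smaller GPUs than a 16-bit one.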

Provided Datasets

Please refer to data/README.md for details.

Some datasets require confirmation before use, so we recommend logging in with your Hugging Face account using the following commands.

pip install --upgrade huggingface_hub
huggingface-cli login

Requirements

  • Python 3.8+ and PyTorch 1.13.1+
  • 🤗Transformers, Datasets, Accelerate, PEFT and TRL
  • sentencepiece, protobuf and tiktoken
  • jieba, rouge-chinese and nltk (used at evaluation)
  • gradio and matplotlib (used in web_demo.py)
  • uvicorn, fastapi and sse-starlette (used in api_demo.py)

And powerful GPUs!

Getting Started

Data Preparation (optional)

Please refer to data/example_dataset for details about the format of dataset files. You can use either a single .json file or a dataset loading script with multiple files to create a custom dataset.

[!NOTE] Please update data/dataset_info.json to use your custom dataset. For the format of this file, please refer to data/README.md.
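A minimal sketch of preparing a custom dataset follows. It assumes an Alpaca-style record layout (instruction/input/output) and a dataset_info.json entry with a `file_name` and a `columns` mapping whose keys (`prompt`, `query`, `response`) we believe match data/README.md; treat data/README.md as authoritative if it differs. The dataset name `my_dataset` and both helper functions are hypothetical.

```python
import json
from pathlib import Path
from typing import Any, Dict, List


def make_dataset_entry(file_name: str) -> Dict[str, Any]:
    """Build one dataset_info.json entry for a local Alpaca-style JSON file (assumed schema)."""
    return {
        "file_name": file_name,
        # Map the roles the loader expects to the keys used in our records.
        "columns": {"prompt": "instruction", "query": "input", "response": "output"},
    }


def register_dataset(data_dir: str, name: str, records: List[Dict[str, str]]) -> Dict[str, Any]:
    """Write records to data_dir/<name>.json and add an entry to dataset_info.json."""
    root = Path(data_dir)
    file_name = f"{name}.json"
    (root / file_name).write_text(json.dumps(records, ensure_ascii=False, indent=2), encoding="utf-8")

    info_path = root / "dataset_info.json"
    info = json.loads(info_path.read_text(encoding="utf-8")) if info_path.exists() else {}
    info[name] = make_dataset_entry(file_name)  # existing entries are kept
    info_path.write_text(json.dumps(info, ensure_ascii=False, indent=2), encoding="utf-8")
    return info[name]
```

After calling `register_dataset("data", "my_dataset", [{"instruction": "Translate to French.", "input": "Hello", "output": "Bonjour"}])`, the dataset would be selected with `--dataset my_dataset`.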

Dependency Installation (optional)

git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
cd LLaMA-Efficient-Tuning
pip install -r requirements.txt

To enable quantized LoRA (QLoRA) on Windows, you need to install a pre-built version of the bitsandbytes library that supports CUDA 11.1 to 12.1.

pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl

All-in-one Web UI

CUDA_VISIBLE_DEVICES=0 python src/train_web.py

We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts automatically.

[!WARNING] Currently the web UI only supports training on a single GPU.

Train on a single GPU

[!IMPORTANT] If you want to train models on multiple GPUs, please refer to Distributed Training.

Pre-Training

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage pt \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset wiki_demo \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_pt_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16
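Pre-training data such as wiki_demo is plain text, and a common way to turn it into causal-LM training examples is to concatenate the tokenized documents and cut the stream into fixed-length blocks. The sketch below shows that technique on token ids. We assume, without having confirmed it from this page, that --stage pt does something similar; `block_size` and `pack_into_blocks` are hypothetical names.

```python
from itertools import chain
from typing import Dict, List


def pack_into_blocks(tokenized_docs: List[List[int]], block_size: int) -> Dict[str, List[List[int]]]:
    """Concatenate tokenized documents and cut them into equal-length training blocks."""
    stream = list(chain.from_iterable(tokenized_docs))
    # Keep only whole blocks; the leftover tail is dropped.
    usable = (len(stream) // block_size) * block_size
    input_ids = [stream[i : i + block_size] for i in range(0, usable, block_size)]
    # Causal-LM labels equal the inputs; Hugging Face causal LM models shift them internally.
    labels = [block[:] for block in input_ids]
    return {"input_ids": input_ids, "labels": labels}
```

Packing avoids padding entirely, at the cost of letting a block span a document boundary.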

Supervised Fine-Tuning

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_sft_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16
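For supervised fine-tuning, a widespread convention is to compute the loss only on the response: prompt positions in the labels are set to -100, which PyTorch's cross-entropy ignores by default. The sketch assumes the prompt has already been rendered with the chosen --template and tokenized; `build_sft_example` is a hypothetical helper, not llmtuner's preprocessing code.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_sft_example(prompt_ids: List[int], response_ids: List[int],
                      eos_id: int, max_len: int) -> Dict[str, List[int]]:
    """Concatenate prompt and response; mask prompt positions so only the response is learned."""
    input_ids = prompt_ids + response_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    # Truncate both sequences identically so positions stay aligned.
    return {"input_ids": input_ids[:max_len], "labels": labels[:max_len]}
```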

Reward Modeling

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage rm \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset comparison_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --output_dir path_to_rm_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-6 \
    --num_train_epochs 1.0 \
    --plot_loss \
    --fp16
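The reward model is trained on pairs of preferred and rejected responses (as in comparison_gpt4_en). The usual objective is the pairwise ranking loss -log(sigmoid(r_chosen - r_rejected)), which pushes the chosen response's scalar reward above the rejected one's. A minimal sketch over plain scalar rewards follows; we assume, rather than confirm, that the rm stage uses this form.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_ranking_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over preference pairs."""
    if len(chosen) != len(rejected) or not chosen:
        raise ValueError("need equally sized, non-empty reward lists")
    # -log(sigmoid(m)) == softplus(-m)
    losses = [softplus(-(rc - rr)) for rc, rr in zip(chosen, rejected)]
    return sum(losses) / len(losses)
```

Equal rewards give a loss of log 2; the loss approaches 0 as the chosen reward pulls ahead and grows linearly when the ranking is wrong.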

PPO Training

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage ppo \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --reward_model path_to_rm_checkpoint \
    --output_dir path_to_ppo_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --plot_loss \
    --fp16
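In PPO fine-tuning as done by TRL-style trainers, the score from the reward model (--reward_model) is usually combined with a per-token KL penalty that keeps the policy close to the reference (SFT) model. The sketch below shows that reward shaping for one response. `kl_coef` is a hypothetical name, its value in the usage example is arbitrary, and we have not checked llmtuner's exact settings.

```python
from typing import List, Sequence


def kl_penalized_rewards(score: float, logprobs: Sequence[float],
                         ref_logprobs: Sequence[float], kl_coef: float) -> List[float]:
    """Per-token PPO rewards: a KL penalty on every token plus the reward-model score on the last one."""
    if len(logprobs) != len(ref_logprobs) or not logprobs:
        raise ValueError("need equally sized, non-empty log-prob lists")
    # log pi(a|s) - log pi_ref(a|s) is a per-token estimate of the KL divergence.
    rewards = [-kl_coef * (lp - ref) for lp, ref in zip(logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards
```

For instance, `kl_penalized_rewards(0.5, [-1.0], [-1.5], kl_coef=0.2)` returns roughly `[0.4]`: the policy is more confident than the reference, so 0.1 is subtracted from the reward-model score.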

DPO Training

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage dpo \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset comparison_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --output_dir path_to_dpo_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --plot_loss \
    --fp16
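DPO skips the separate reward model and optimizes the policy directly on preference pairs. For one pair, the standard DPO loss is -log(sigmoid(beta * ((log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l))))), where y_w is the chosen and y_l the rejected response. The sketch takes summed sequence log-probabilities as plain floats. The default `beta=0.1` is an illustrative value, and we assume the SFT checkpoint plays the role of the reference model.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float, beta: float = 0.1) -> float:
    """-log(sigmoid(beta * (chosen log-ratio - rejected log-ratio))) for one preference pair."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_logratio - rejected_logratio)
    # Stable -log(sigmoid(z)) = softplus(-z) = max(-z, 0) + log1p(exp(-|z|))
    return max(-logits, 0.0) + math.log1p(math.exp(-abs(logits)))
```

When the policy equals the reference the loss is log 2; it drops below that once the policy favors the chosen response more than the reference does.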

Distributed Training

Use Hugging Face Accelerate

accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above)
Example config for LoRA training
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Use DeepSpeed

deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
    --deepspeed ds_config.json \
    ... # arguments (same as above)
Example config for full-parameter training with DeepSpeed ZeRO-2
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },  
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "overlap_comm": false,
    "contiguous_gradients": true
  }
}
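With DeepSpeed, the "auto" entries are filled in by the Hugging Face Trainer integration from the training arguments. DeepSpeed also requires train_batch_size to equal train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs. The sketch below reproduces only that batch-size bookkeeping; `resolve_batch_sizes` is a hypothetical helper, not the integration code itself.

```python
from typing import Any, Dict


def resolve_batch_sizes(config: Dict[str, Any], per_device_batch: int,
                        grad_accum: int, world_size: int) -> Dict[str, Any]:
    """Fill "auto" batch fields and check DeepSpeed's batch-size identity."""
    resolved = dict(config)  # leave the caller's config untouched
    defaults = {
        "train_micro_batch_size_per_gpu": per_device_batch,
        "gradient_accumulation_steps": grad_accum,
        "train_batch_size": per_device_batch * grad_accum * world_size,
    }
    for key, value in defaults.items():
        if resolved.get(key, "auto") == "auto":
            resolved[key] = value
    # DeepSpeed requires train_batch_size == micro_batch * grad_accum * world_size.
    expected = resolved["train_micro_batch_size_per_gpu"] * resolved["gradient_accumulation_steps"] * world_size
    if resolved["train_batch_size"] != expected:
        raise ValueError(f"train_batch_size={resolved['train_batch_size']} but expected {expected}")
    return resolved
```

With --num_gpus 8 and the single-GPU values used above (4 per device, 4 accumulation steps), the resolved global batch is 128.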

Export model

python src/export_model.py \
    --model_name_or_path path_to_llama_model \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_export
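Exporting presumably merges the LoRA adapter from --checkpoint_dir into the base weights, so the result can be loaded without PEFT; we have not verified export_model.py itself. The standard LoRA merge for one weight matrix is W' = W + (alpha / r) · B A. The plain-Python sketch below applies it to tiny matrices; `merge_lora` and `matmul` are illustrative helpers, not library functions.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product (rows of a times columns of b)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight: Matrix, lora_a: Matrix, lora_b: Matrix, alpha: float, r: int) -> Matrix:
    """Return W + (alpha / r) * B @ A, i.e. the dense weight with the adapter folded in."""
    scale = alpha / r
    delta = matmul(lora_b, lora_a)  # (out x r) @ (r x in) -> (out x in)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weight, delta)]
```

Once merged, the model runs at the same speed as the original because the low-rank update no longer needs to be computed separately.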

API Demo

python src/api_demo.py \
    --model_name_or_path path_to_llama_model \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint

[!NOTE] Visit http://localhost:8000/docs for API documentation.
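Because the demo API follows OpenAI's format, a client can send a standard chat-completion request. The sketch builds such a request with the standard library only. The /v1/chat/completions route is assumed from the OpenAI convention, and the `model` value is a placeholder. Check http://localhost:8000/docs for the routes your version actually exposes.

```python
import json
import urllib.request
from typing import Dict, List


def build_chat_request(messages: List[Dict[str, str]], base_url: str = "http://localhost:8000",
                       model: str = "llama") -> urllib.request.Request:
    """Build (but do not send) an OpenAI-style chat completion request."""
    payload = {"model": model, "messages": messages}  # "llama" is a placeholder name
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",  # assumed OpenAI-compatible route
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request) -> str:
    """Send the request and return the assistant's reply text (OpenAI response shape assumed)."""
    with urllib.request.urlopen(request) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]
```

With api_demo.py running, `send(build_chat_request([{"role": "user", "content": "Hello"}]))` should return the model's reply.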

CLI Demo

python src/cli_demo.py \
    --model_name_or_path path_to_llama_model \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint

Web Demo

python src/web_demo.py \
    --model_name_or_path path_to_llama_model \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint

Evaluation (BLEU and ROUGE_CHINESE)

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path path_to_llama_model \
    --do_eval \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_eval_result \
    --per_device_eval_batch_size 8 \
    --max_samples 100 \
    --predict_with_generate

[!NOTE] We recommend using --per_device_eval_batch_size=1 and --max_target_length 128 for 4/8-bit evaluation.
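ROUGE-L, one of the reported metrics, scores the longest common subsequence (LCS) between a generated answer and its reference. The sketch below computes a plain ROUGE-L F1 on already-segmented tokens to show what the number measures. The project itself uses rouge-chinese (with jieba segmentation) and nltk, which may differ in details, so this is not a drop-in replacement for its scorer.

```python
from typing import Sequence


def lcs_length(a: Sequence[str], b: Sequence[str]) -> int:
    """Length of the longest common subsequence, using a rolling one-row DP table."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, 1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(candidate: Sequence[str], reference: Sequence[str]) -> float:
    """F1 of LCS-based precision (vs. candidate) and recall (vs. reference)."""
    if not candidate or not reference:
        return 0.0
    lcs = lcs_length(candidate, reference)
    if lcs == 0:
        return 0.0
    precision = lcs / len(candidate)
    recall = lcs / len(reference)
    return 2 * precision * recall / (precision + recall)
```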

Predict

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path path_to_llama_model \
    --do_predict \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 8 \
    --max_samples 100 \
    --predict_with_generate
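After --do_predict, the generated text is written under --output_dir. The sketch below assumes a JSON Lines file named generated_predictions.jsonl with `label` and `predict` fields. Both the file name and the field names are assumptions to check against your own path_to_predict_result directory. `exact_match_rate` is a hypothetical extra metric for a quick sanity check.

```python
import json
from pathlib import Path
from typing import Iterable, List, Tuple


def parse_predictions(lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Parse JSON Lines records into (label, predict) pairs, skipping blank lines."""
    pairs = []
    for line in lines:
        line = line.strip()
        if line:
            record = json.loads(line)
            pairs.append((record["label"], record["predict"]))
    return pairs


def load_predictions(output_dir: str, file_name: str = "generated_predictions.jsonl") -> List[Tuple[str, str]]:
    """Read the (assumed) predictions file from the --output_dir of a predict run."""
    with open(Path(output_dir) / file_name, encoding="utf-8") as f:
        return parse_predictions(f)


def exact_match_rate(pairs: List[Tuple[str, str]]) -> float:
    """Fraction of predictions identical to their label after trimming whitespace."""
    if not pairs:
        return 0.0
    return sum(label.strip() == pred.strip() for label, pred in pairs) / len(pairs)
```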

License

This repository is licensed under the Apache-2.0 License.

Please follow the license of each model listed under Supported Models when using the corresponding model weights.

Citation

If this work is helpful, please cite it as:

@Misc{llama-efficient-tuning,
  title = {LLaMA Efficient Tuning},
  author = {hiyouga},
  howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
  year = {2023}
}

Acknowledgement

This repo benefits from PEFT, QLoRA and OpenChatKit. Thanks for their wonderful work.



Source Distribution

llmtuner-0.1.8.tar.gz (69.9 kB)

Built Distribution

llmtuner-0.1.8-py3-none-any.whl (86.5 kB, Python 3)
