PyTorch implementation of low-rank adaptation (LoRA) and Adamix, parameter-efficient approaches for adapting large pre-trained deep learning models that can outperform full fine-tuning.
Project description
Adapting GPT-2 using LoRA, Adapters and Adamix
This folder contains the implementation of LoRA, Adamix Adapter, and Adamix LoRA in GPT-2 using the modified Python package lora, along with the steps to replicate the results in our recent paper.
This repo reproduces our experiments on GPT-2 Medium.
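As background, LoRA freezes the pre-trained weights and learns a low-rank update BA that is scaled by alpha/r and added to the frozen layer's output. The sketch below is a minimal, self-contained illustration of that idea in plain PyTorch; it is not the API of the modified lora package used in this repo, and the class and attribute names are ours.

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Minimal LoRA illustration: frozen base weight plus a scaled low-rank update."""

    def __init__(self, in_features, out_features, r=4, alpha=32, dropout=0.1):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad_(False)   # pre-trained weight stays frozen
        self.base.bias.requires_grad_(False)
        # Low-rank factors: A projects down to rank r, B projects back up.
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: no change at start
        self.scaling = alpha / r                  # mirrors --lora_alpha / --lora_dim
        self.dropout = nn.Dropout(dropout)        # mirrors --lora_dropout

    def forward(self, x):
        update = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        return self.base(x) + self.scaling * update

# Only the low-rank factors are trainable, so very few parameters are updated.
layer = LoRALinearSketch(1024, 1024, r=4, alpha=32)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 8192 vs. ~1M frozen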
Repository Overview
Our implementation is based on the fine-tuning code for GPT-2 in Hugging Face. There are several directories in this repo:
- src/ contains the source code used for data processing, training, and decoding.
- eval/ contains the code for task-specific evaluation scripts.
- data/ contains the raw data we used in our experiments.
- vocab/ contains the GPT-2 vocabulary files.
Getting Started
- You can start with the Docker image nvcr.io/nvidia/pytorch:20.03-py3 on a GPU-capable machine, but any generic PyTorch image should work:
docker run -it nvcr.io/nvidia/pytorch:20.03-py3
- Clone the repo and install dependencies using the provided setup script in a virtual environment (remove sudo where necessary if running inside a Docker container):
bash setup.sh
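Once setup.sh finishes, a quick sanity check (our suggestion, not part of the repo's scripts) confirms that PyTorch is installed and can see the GPU:

import torch

# Verify the PyTorch build and CUDA visibility before launching training.
print(torch.__version__)
print(torch.cuda.is_available())  # expect True on a GPU-capable machine
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))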
Now we are ready to replicate the results in our paper.
Replicating Our Results on GPT-2 Medium
(see our paper for hyperparameters for GPT-2 Medium)
1. Train GPT-2 Medium on the E2E NLG Challenge dataset. (Adamix's expert routing via --n_experts, --share_A, and --share_B is sketched after this list.)

LoRA:
python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_ft.py \
    --train_data ./data/e2e/train.jsonl \
    --valid_data ./data/e2e/valid.jsonl \
    --train_batch_size 8 \
    --grad_acc 1 \
    --valid_batch_size 4 \
    --seq_len 512 \
    --model_card gpt2.md \
    --init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
    --platform local \
    --clip 0.0 \
    --lr 0.0002 \
    --weight_decay 0.01 \
    --correct_bias \
    --adam_beta2 0.999 \
    --scheduler linear \
    --warmup_step 500 \
    --max_epoch 5 \
    --save_interval 1000 \
    --lora_dim 4 \
    --lora_alpha 32 \
    --lora_dropout 0.1 \
    --label_smooth 0.1 \
    --work_dir ./trained_models/GPT2_M/e2e/lora_only \
    --random_seed 110 \
    --lora_only 1
Adapter with Adamix:
python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_ft.py \
    --train_data ./data/e2e/train.jsonl \
    --valid_data ./data/e2e/valid.jsonl \
    --train_batch_size 8 \
    --grad_acc 1 \
    --valid_batch_size 4 \
    --seq_len 512 \
    --model_card gpt2.md \
    --init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
    --platform local \
    --clip 0.0 \
    --lr 0.0002 \
    --weight_decay 0.01 \
    --correct_bias \
    --adam_beta2 0.999 \
    --scheduler linear \
    --warmup_step 2000 \
    --max_epoch 20 \
    --eval_interval 5000 \
    --save_interval 5000 \
    --lora_dim 4 \
    --lora_alpha 32 \
    --lora_dropout 0.1 \
    --label_smooth 0.1 \
    --work_dir ./trained_models/GPT2_M/e2e/adapter_adamix \
    --random_seed 110 \
    --adamix_only 1 \
    --n_experts 8 \
    --share_A 0 \
    --share_B 1
LoRA with Adamix:
python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_ft.py \
    --train_data ./data/e2e/train.jsonl \
    --valid_data ./data/e2e/valid.jsonl \
    --train_batch_size 8 \
    --grad_acc 1 \
    --valid_batch_size 4 \
    --seq_len 512 \
    --model_card gpt2.md \
    --init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
    --platform local \
    --clip 0.0 \
    --lr 0.0002 \
    --weight_decay 0.01 \
    --correct_bias \
    --adam_beta2 0.999 \
    --scheduler linear \
    --warmup_step 2000 \
    --max_epoch 20 \
    --eval_interval 5000 \
    --save_interval 5000 \
    --lora_dim 4 \
    --lora_alpha 32 \
    --lora_dropout 0.1 \
    --label_smooth 0.1 \
    --work_dir ./trained_models/GPT2_M/e2e/lora_adamix \
    --random_seed 110 \
    --n_experts 8 \
    --share_A 0 \
    --share_B 1
2. Generate outputs from the trained model using beam search (LoRA with Adamix):

python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_beam.py \
    --data ./data/e2e/test.jsonl \
    --batch_size 1 \
    --seq_len 128 \
    --eval_len 64 \
    --model_card gpt2.md \
    --init_checkpoint ./trained_models/GPT2_M/e2e/lora_adamix/model.final.pt \
    --platform local \
    --lora_dim 4 \
    --lora_alpha 32 \
    --beam 10 \
    --length_penalty 0.8 \
    --no_repeat_ngram_size 4 \
    --repetition_penalty 1.0 \
    --eos_token_id 628 \
    --work_dir ./trained_models/GPT2_M/e2e/lora_adamix \
    --output_file predict.jsonl \
    --n_experts 8 \
    --share_A 0 \
    --share_B 1
3. Decode outputs from step (2):

python src/gpt2_decode.py \
    --vocab ./vocab \
    --sample_file ./trained_models/GPT2_M/e2e/lora_adamix/predict.jsonl \
    --input_file ./data/e2e/test_formatted.jsonl \
    --output_ref_file e2e_ref.txt \
    --output_pred_file e2e_pred.txt
4. Run evaluation on the E2E test set:
python eval/e2e/measure_scores.py e2e_ref.txt e2e_pred.txt -p
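The --n_experts, --share_A, and --share_B flags in step 1 enable Adamix's mixture of adaptation modules. Roughly, Adamix keeps several copies ("experts") of an adaptation module, stochastically routes each training forward pass through one expert, and merges the experts (e.g., by averaging their weights) at inference so serving costs the same as a single module; the paper also adds a consistency regularizer that is omitted here. The sketch below is our illustrative reading of the LoRA variant in plain PyTorch, under the assumption that --share_A 0 --share_B 1 gives each expert its own first projection while sharing the second; the class and attribute names are ours, not those in src/.

import torch
import torch.nn as nn

class AdamixLoRASketch(nn.Module):
    """Illustrative mixture-of-adaptations module (not the repo's actual class)."""

    def __init__(self, d_model, r=4, alpha=32, n_experts=8):
        super().__init__()
        # One low-rank "A" projection per expert (--share_A 0) ...
        self.experts_A = nn.Parameter(torch.randn(n_experts, d_model, r) * 0.01)
        # ... and a single "B" projection shared by all experts (--share_B 1).
        self.shared_B = nn.Parameter(torch.zeros(r, d_model))
        self.scaling = alpha / r
        self.n_experts = n_experts

    def forward(self, x):
        if self.training:
            # Stochastic routing: pick one expert at random for this forward pass.
            A = self.experts_A[torch.randint(self.n_experts, (1,)).item()]
        else:
            # Merge experts by averaging so inference uses a single module.
            A = self.experts_A.mean(dim=0)
        # Returns only the low-rank update, to be added to the frozen layer's output.
        return (x @ A @ self.shared_B) * self.scaling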
Replicating Our Results on WebNLG

1. Follow steps 1 and 2 from the E2E pipeline, replacing references to e2e with webnlg (see our paper for hyperparameters).
2. Decode outputs from beam search (step 2 above):

python src/gpt2_decode.py \
    --vocab ./vocab \
    --sample_file ./trained_models/GPT2_M/webnlg/lora_adamix/predict.jsonl \
    --input_file ./data/webnlg_challenge_2017/test_formatted.jsonl \
    --ref_type webnlg \
    --ref_num 6 \
    --output_ref_file eval/GenerationEval/data/references_webnlg \
    --output_pred_file eval/GenerationEval/data/hypothesis_webnlg \
    --tokenize --lower
3. Run evaluation on the WebNLG test set:

cd ./eval/GenerationEval/
python eval.py \
    -R data/references_webnlg/reference \
    -H data/hypothesis_webnlg \
    -nr 6 \
    -m bleu,meteor,ter
cd ../..
Replicating Our Results on DART

1. Follow steps 1 and 2 from the E2E pipeline, replacing references to e2e with dart (see our paper for hyperparameters).
2. Decode outputs from beam search (step 2 above):

python src/gpt2_decode.py \
    --vocab ./vocab \
    --sample_file ./trained_models/GPT2_M/dart/lora_adamix/predict.jsonl \
    --input_file ./data/dart/test_formatted.jsonl \
    --ref_type dart \
    --ref_num 6 \
    --output_ref_file eval/GenerationEval/data/references_dart \
    --output_pred_file eval/GenerationEval/data/hypothesis_dart \
    --tokenize --lower
3. Run evaluation on the DART test set:

cd ./eval/GenerationEval/
python eval.py \
    -R data/references_dart/reference \
    -H data/hypothesis_dart \
    -nr 6 \
    -m bleu,meteor,ter
cd ../..
File details
Details for the file adamix_gpt2-0.0.2.tar.gz.
File metadata
- Download URL: adamix_gpt2-0.0.2.tar.gz
- Upload date:
- Size: 31.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.5
File hashes
Algorithm | Hash digest
--- | ---
SHA256 | b7d4db5ee496b3ba60dd951d03e1b9cba4cac175152202994df1b349baab9ab5
MD5 | ab60efd5a06831ba82f937a9e84d227b
BLAKE2b-256 | ab5c625b20fc436c8e978bf0f215078814bb94cb40981a7a99c5fcf5930d07ea
File details
Details for the file adamix_gpt2-0.0.2-py3-none-any.whl.
File metadata
- Download URL: adamix_gpt2-0.0.2-py3-none-any.whl
- Upload date:
- Size: 37.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.5
File hashes
Algorithm | Hash digest
--- | ---
SHA256 | 9fa4b6a30be905d5b0fd4dc733d78a5cddc7d112a579610a61befbf14eafb4f4
MD5 | 0c462b0f65f4801603da9d4f3fff53e7
BLAKE2b-256 | e7f2f29456a339e6f1bd2d57b2afb808e3cd28a79b26d1d7badeeeeb69e5e4fe
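If you download either archive by hand, you can compare it against the SHA256 digests listed above. The snippet below is a generic check in Python; the file path assumes the source archive was saved under its default name.

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    """Stream the file in chunks and return its SHA256 hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Digest for adamix_gpt2-0.0.2.tar.gz from the table above.
expected = "b7d4db5ee496b3ba60dd951d03e1b9cba4cac175152202994df1b349baab9ab5"
print(sha256_of("adamix_gpt2-0.0.2.tar.gz") == expected)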