Implementation of Reinforcement Learning from Human Feedback (RLHF)
Project description
InstructGoose
Paper: InstructGPT - Training language models to follow instructions with human feedback
Install
Install from PyPI
pip install instruct-goose
Install directly from the source code
git clone https://github.com/xrsrke/instructGOOSE.git
cd instructGOOSE
pip install -e .
Train the RL-based language model
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from torch import optim
from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_model
Step 1: Load dataset
dataset = load_dataset("imdb", split="train")
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
Step 2: Load the pre-trained model and tokenizer
model_base = AutoModelForCausalLM.from_pretrained("gpt2")
reward_model = RewardModel("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
Step 3: Create the RL-based language model agent and the reference model
model = Agent(model_base)
ref_model = create_reference_model(model)
Step 4: Train it
max_new_tokens = 20
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": max_new_tokens
}
config = RLHFConfig()
N_EPOCH = 100
trainer = RLHFTrainer(model, ref_model, config)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
for epoch in range(N_EPOCH):
    for batch in train_dataloader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
        responses = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"],
            **generation_kwargs
        )

        # keep only the newly generated tokens
        responses = responses[:, -max_new_tokens:]

        # score the prompt + response pairs with the reward model
        with torch.no_grad():
            text_input_ids = torch.stack([torch.concat([q, r]) for q, r in zip(inputs["input_ids"], responses)], dim=0)
            texts = tokenizer.batch_decode(text_input_ids, skip_special_tokens=True)
            rewards = reward_model(texts)

        # calculate the PPO loss and update the policy
        loss = trainer.compute_loss(inputs["input_ids"], responses, rewards)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss={loss}")
TODO
- Add support for custom reward functions (a hypothetical sketch follows this list)
- Add support for custom value functions
- Add support for non-transformer models
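Until custom reward functions are supported, note that in the training loop above the rewards passed to trainer.compute_loss are simply the output of reward_model(texts); conceptually, any callable mapping the decoded texts to a tensor of per-sample rewards could play that role. A purely hypothetical sketch (length_penalty_reward is not part of the library):

def length_penalty_reward(texts):
    # toy reward: prefer shorter completions (hypothetical example only)
    return torch.tensor([-len(text) / 100.0 for text in texts])

# hypothetical drop-in for `rewards = reward_model(texts)` in the training loop
# rewards = length_penalty_reward(texts)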
Resources
I implemented this using these resources:
- Copied the load_yaml function from https://github.com/Dahoas/reward-modeling
- How to build a dataset to train a reward model: https://wandb.ai/carperai/summarize_RLHF/reports/Implementing-RLHF-Learning-to-Summarize-with-trlX--VmlldzozMzAwODM2
- How to add a value head to a PPO agent: https://github.com/lvwerra/trl
- How to calculate the loss of a PPO agent (see the sketch after this list): https://github.com/lvwerra/trl/blob/main/trl/trainer/ppo_trainer.py
- How to use PPO to train an RLHF agent: https://github.com/voidful/TextRL
- How PPO works: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
- Copied the computation of advantages and returns from TRL: https://github.com/lvwerra/trl/blob/d2e8bcf8373726fb92d2110c500f7df6d0bd566d/trl/trainer/ppo_trainer.py#L686
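For orientation, the PPO pieces referenced above boil down to generalized advantage estimation and the clipped surrogate objective. A generic sketch of both, independent of this library's internal API (function and variable names are illustrative):

import torch

def gae_advantages_and_returns(rewards, values, gamma=1.0, lam=0.95):
    # rewards, values: 1-D tensors of equal length for one response trajectory
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(rewards.shape[0])):
        next_value = values[t + 1] if t + 1 < rewards.shape[0] else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns

def ppo_policy_loss(log_probs, old_log_probs, advantages, clip_range=0.2):
    # probability ratio between the current policy and the rollout policy
    ratio = torch.exp(log_probs - old_log_probs)
    # clipped surrogate objective; PPO maximizes the minimum of the two terms
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    return -torch.min(unclipped, clipped).mean()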
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
instruct_goose-0.0.5.tar.gz (10.8 kB)
Built Distribution
instruct_goose-0.0.5-py3-none-any.whl (12.2 kB)
File details
Details for the file instruct_goose-0.0.5.tar.gz.
File metadata
- Download URL: instruct_goose-0.0.5.tar.gz
- Upload date:
- Size: 10.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6fd46d16a890309d3943b09707c1ee5dfd176af17205aa57499d848df78cd8b6
MD5 | a89698cefdf9cdfb6df401c50a7a4876
BLAKE2b-256 | 7ba14a6952d33462698e8b1cafe0b492587094df5c5d1517202c2ffc361c297e
File details
Details for the file instruct_goose-0.0.5-py3-none-any.whl.
File metadata
- Download URL: instruct_goose-0.0.5-py3-none-any.whl
- Upload date:
- Size: 12.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | a05ba3a12e13289f1e3a5fa5a0d9c7c7bba368df4a82a013d895f814db7238ab
MD5 | fb84a14d6587741ca20bf58089ec9076
BLAKE2b-256 | c8e52ec76949a4fd4e6dba455c5524390ca5b823a52e08f9744bbec14d777176