Implementation of Reinforcement Learning from Human Feedback (RLHF)
InstructGoose - 🚧 WORK IN PROGRESS 🚧
Paper: InstructGPT - Training language models to follow instructions with human feedback
Install
Install from PyPI
pip install instruct-goose
Install directly from the source code
git clone https://github.com/xrsrke/instructGOOSE.git
cd instructGOOSE
pip install -e .
Train the RL-based language model
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from torch import optim
from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_model
Step 1: Load dataset
dataset = load_dataset("imdb", split="train")
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
Step 2: Load the pre-trained model and tokenizer
model_base = AutoModelForCausalLM.from_pretrained("gpt2")
reward_model = RewardModel("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
Step 3: Create the RL-based language model agent and the reference model (a frozen copy of the policy, used to compute the KL penalty)
model = Agent(model_base)
ref_model = create_reference_model(model)
Step 4: Train it
max_new_tokens = 20
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": max_new_tokens
}
config = RLHFConfig()
N_EPOCH = 100
trainer = RLHFTrainer(model, ref_model, config)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
for epoch in range(N_EPOCH):
    for batch in train_dataloader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
        responses = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"],
            **generation_kwargs
        )

        # keep only the newly generated tokens (drop the prompt)
        responses = responses[:, -max_new_tokens:]

        # score prompt + response with the reward model
        with torch.no_grad():
            text_input_ids = torch.stack([torch.concat([q, r]) for q, r in zip(inputs["input_ids"], responses)], dim=0)
            texts = tokenizer.batch_decode(text_input_ids, skip_special_tokens=True)
            rewards = reward_model(texts)

        # compute the PPO loss and update the agent
        loss = trainer.compute_loss(inputs["input_ids"], responses, rewards)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss={loss}")
TODO
- Add support for custom reward functions (a rough sketch follows this list)
- Add support for custom value functions
- Add support for non-transformer models
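As an illustration of the first item above (a hypothetical interface, not an existing instruct-goose API), a custom reward function could simply be a callable that maps a batch of generated texts to a tensor of scalar rewards and replaces the reward_model call in the training loop:

import torch

def length_penalty_reward(texts, target_length=50):
    # Toy custom reward (hypothetical example): prefer responses whose word
    # count is close to target_length; higher is better, 0 is a perfect match.
    rewards = []
    for text in texts:
        num_words = len(text.split())
        rewards.append(-abs(num_words - target_length) / target_length)
    return torch.tensor(rewards)

# In the training loop above, instead of: rewards = reward_model(texts)
# rewards = length_penalty_reward(texts)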
Questions
- In the context of RLHF, how do we calculate $L_t^{VF}(\theta)$? (a sketch follows this list)
  - Is it a function the PPO agent uses to predict how much reward it would get if it generated the sequence?
- Do the RL model and the SFT model use the same tokenizer? Yes.
- How do we return the logits of the generation model?
- Does the PPO agent (the language model) have a value network just like a regular PPO agent?
- I don't understand how to calculate the advantage in PPO.
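For the value-loss and advantage questions above, here is a minimal, library-independent sketch in plain PyTorch (tensor names such as rewards and values are illustrative, not part of instruct-goose): the value-function loss $L_t^{VF}(\theta)$ is the squared error between the value head's predictions and the returns, and the advantages can be computed with generalized advantage estimation (GAE), as in the cleanrl and TRL references listed under Resources.

import torch

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    # Generalized advantage estimation over one generated sequence.
    # rewards, values: 1-D tensors with one entry per generated token.
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]   # TD residual
        last_gae = delta + gamma * lam * last_gae             # discounted sum of residuals
        advantages[t] = last_gae
    returns = advantages + values                             # targets for the value head, V_t^targ
    return advantages, returns

def value_loss(values, returns):
    # L_t^{VF}(theta): mean squared error between predicted values and returns
    return ((values - returns) ** 2).mean()

In TRL-style RLHF implementations, the per-token reward is the negative KL penalty against the reference model, with the reward model's scalar score added at the last generated token, and the values come from a value head attached to the agent.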
Resources
I used these resources to implement this:
- Copied the load_yaml function from https://github.com/Dahoas/reward-modeling
- How to build a dataset to train a reward model (see the loss sketch after this list): https://wandb.ai/carperai/summarize_RLHF/reports/Implementing-RLHF-Learning-to-Summarize-with-trlX--VmlldzozMzAwODM2
- How to add a value head to the PPO agent: https://github.com/lvwerra/trl
- How to calculate the loss of the PPO agent: https://github.com/lvwerra/trl/blob/main/trl/trainer/ppo_trainer.py
- How to use PPO to train an RLHF agent: https://github.com/voidful/TextRL
- How PPO works: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
- Copied the compute advantages and returns functions from TRL: https://github.com/lvwerra/trl/blob/d2e8bcf8373726fb92d2110c500f7df6d0bd566d/trl/trainer/ppo_trainer.py#L686
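For context on the reward-modeling resource above: InstructGPT-style reward models are typically trained with a pairwise ranking loss over (chosen, rejected) response pairs for the same prompt. The snippet below is a minimal, library-independent sketch of that loss, not instruct-goose's implementation:

import torch.nn.functional as F

def pairwise_reward_loss(chosen_rewards, rejected_rewards):
    # Bradley-Terry style ranking loss: -log sigmoid(r_chosen - r_rejected).
    # chosen_rewards / rejected_rewards: 1-D tensors of reward-model scores for
    # the human-preferred and the rejected response of the same prompts.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

A training batch for this loss is then just the reward model's scores on prompt + chosen response and prompt + rejected response.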