Implementation of Reinforcement Learning from Human Feedback (RLHF)
Project description
InstructGoose
Paper: InstructGPT - Training language models to follow instructions with human feedback
Install
Install from PyPI
pip install instruct-goose
Install directly from the source code
git clone https://github.com/xrsrke/instructGOOSE.git
cd instructGOOSE
pip install -e .
Train the RL-based language model
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader, random_split
from torch import optim
from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_model
Step 1: Load dataset
dataset = load_dataset("imdb", split="train")
dataset, _ = random_split(dataset, lengths=[10, len(dataset) - 10]) # for demonstration purposes
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
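The raw IMDB reviews are long; a common optional preprocessing step, not part of instruct-goose, is to cut each review down to a short prompt so the policy is asked to continue it. A minimal sketch using the `datasets` map API, applied right after `load_dataset` and before `random_split`:

```python
from datasets import load_dataset

# Hypothetical preprocessing helper, not part of instruct-goose:
# keep only the first few words of each review as the prompt.
def truncate_to_prompt(example, max_words=8):
    example["text"] = " ".join(example["text"].split()[:max_words])
    return example

dataset = load_dataset("imdb", split="train")
dataset = dataset.map(truncate_to_prompt)  # run before random_split / DataLoader
```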
Step 2: Load the pre-trained model and tokenizer
model_base = AutoModelForCausalLM.from_pretrained("gpt2") # for demonstration purposes
reward_model = RewardModel("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
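Note that `RewardModel("gpt2")` here starts from a base GPT-2; in a real pipeline the reward model would first be trained on human preference pairs. A minimal sketch of the standard pairwise (Bradley-Terry) objective, assuming, as in the training loop below, that calling the reward model on a batch of token ids returns one scalar score per sequence:

```python
import torch.nn.functional as F

# Hypothetical reward-model training step on a preference pair.
# `chosen_ids` / `rejected_ids` are token ids of the preferred and
# rejected responses (tokenization is up to you).
def pairwise_reward_loss(reward_model, chosen_ids, rejected_ids):
    chosen_scores = reward_model(chosen_ids)      # shape: (batch,)
    rejected_scores = reward_model(rejected_ids)  # shape: (batch,)
    # the chosen response should receive the higher score
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()
```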
Step 3: Create the RL-based language model agent and the reference model
model = Agent(model_base)
ref_model = create_reference_model(model)
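The reference model is a frozen copy of the policy. During PPO the trainer penalizes the policy for drifting too far from it by adding a per-token KL term to the reward. The actual computation lives inside `RLHFTrainer`; the sketch below only illustrates the idea, and the coefficient name is an assumption:

```python
import torch

# Illustrative KL-shaped reward (not instruct-goose's internal code).
# `logprobs` / `ref_logprobs`: per-token log-probs of the sampled response
# under the policy and the frozen reference model, shape (batch, seq_len).
# `rewards`: one scalar per sequence from the reward model, shape (batch,).
def shape_rewards(rewards, logprobs, ref_logprobs, kl_coef=0.2):
    kl = logprobs - ref_logprobs          # per-token KL estimate
    shaped = -kl_coef * kl                # penalize divergence from the reference
    shaped[:, -1] += rewards              # add the scalar reward on the final token
    return shaped
```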
Step 4: Train it
max_new_tokens = 20
generation_kwargs = {
"min_length":-1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"max_new_tokens": max_new_tokens
}
config = RLHFConfig()
N_EPOCH = 1 # for demonstration purposes
trainer = RLHFTrainer(model, ref_model, config)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
for epoch in range(N_EPOCH):
    for batch in train_dataloader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
        response_ids = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"],
            **generation_kwargs
        )

        # keep only the newly generated tokens
        response_ids = response_ids[:, -max_new_tokens:]
        response_attention_mask = torch.ones_like(response_ids)

        # score query + response with the reward model
        with torch.no_grad():
            text_input_ids = torch.stack([torch.concat([q, r]) for q, r in zip(inputs["input_ids"], response_ids)], dim=0)
            rewards = reward_model(text_input_ids)

        # calculate the PPO loss
        loss = trainer.compute_loss(
            query_ids=inputs["input_ids"],
            query_attention_mask=inputs["attention_mask"],
            response_ids=response_ids,
            response_attention_mask=response_attention_mask,
            rewards=rewards
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss={loss}")
loss=-824.6560668945312
loss=0.030958056449890137
loss=4.284017562866211
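After training, the agent can be sampled from directly. Assuming `Agent.generate` mirrors the Hugging Face `generate` API (as it is used in the loop above):

```python
prompt = "The movie was"
inputs = tokenizer(prompt, return_tensors="pt")
output_ids = model.generate(
    inputs["input_ids"], attention_mask=inputs["attention_mask"],
    **generation_kwargs
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```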
TODO
- Add support for custom reward functions (a possible shape is sketched below)
- Add support for custom value functions
- Add support for non-transformer models
- Write a config class
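As an illustration of what a custom reward function could look like (purely hypothetical, not yet supported by the library), one option is to score generated text with an off-the-shelf sentiment classifier:

```python
import torch
from transformers import pipeline

# Hypothetical custom reward: sentiment of the generated text.
sentiment = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)

def custom_reward_fn(texts):
    outputs = sentiment(texts)
    # map POSITIVE/NEGATIVE to signed scores in [-1, 1]
    scores = [o["score"] if o["label"] == "POSITIVE" else -o["score"] for o in outputs]
    return torch.tensor(scores)
```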
Resources
I implemented this using the following resources:
- Copied the `load_yaml` function from https://github.com/Dahoas/reward-modeling
- How to build a dataset to train a reward model: https://wandb.ai/carperai/summarize_RLHF/reports/Implementing-RLHF-Learning-to-Summarize-with-trlX–VmlldzozMzAwODM2
- How to add value head in PPO agent: https://github.com/lvwerra/trl
- How to calculate the loss of PPO agent: https://github.com/lvwerra/trl/blob/main/trl/trainer/ppo_trainer.py
- How to use PPO to train RLHF agent: https://github.com/voidful/TextRL
- How PPO works: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
- Copied the compute `advantages` and `returns` functions from TRL (a minimal sketch of both follows this list): https://github.com/lvwerra/trl/blob/d2e8bcf8373726fb92d2110c500f7df6d0bd566d/trl/trainer/ppo_trainer.py#L686
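For reference, here is a minimal, self-contained sketch of the two pieces the last resources cover: GAE-style advantages/returns and the clipped PPO policy objective. Variable names are illustrative and do not necessarily match instruct-goose's internals.

```python
import torch

def compute_advantages_and_returns(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation over token-level rewards and value estimates.

    rewards, values: tensors of shape (batch, seq_len).
    """
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    seq_len = rewards.shape[1]
    for t in reversed(range(seq_len)):
        next_value = values[:, t + 1] if t + 1 < seq_len else 0.0
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    returns = advantages + values
    return advantages, returns

def clipped_policy_loss(logprobs, old_logprobs, advantages, clip_range=0.2):
    """PPO clipped surrogate objective (returned as a loss to minimize)."""
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return torch.max(unclipped, clipped).mean()
```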
File details
Details for the file instruct_goose-0.0.7.tar.gz.
File metadata
- Download URL: instruct_goose-0.0.7.tar.gz
- Upload date:
- Size: 11.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `532aa9676e27e9e8c570d5663bb6e2e55de1765fcda2c3c0f1b666cfb0c05877` |
| MD5 | `e8349b2024d7777732fef99094f6545d` |
| BLAKE2b-256 | `9b4e9bd9eafab6ba2a564f741645513210b87caec3ffccfae56500b8a7f29e27` |
File details
Details for the file instruct_goose-0.0.7-py3-none-any.whl.
File metadata
- Download URL: instruct_goose-0.0.7-py3-none-any.whl
- Upload date:
- Size: 12.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `d04fd839a81f82ca03272d94d74c5f34ec52c9458292a55e64d1a85a51ec0bcd` |
| MD5 | `e7fc7e0ad43baaf91da421d8d14ed41a` |
| BLAKE2b-256 | `14f027ef27a25ed93d747363926cb4516b146317889f408fc18291a1e8228459` |