
Implementation of Reinforcement Learning from Human Feedback (RLHF)

Project description

InstructGoose

Paper: InstructGPT - Training language models to follow instructions with human feedback

Install

Install from PyPI

pip install instruct-goose

Install directly from the source code

git clone https://github.com/xrsrke/instructGOOSE.git
cd instructGOOSE
pip install -e .
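
A quick way to check the install is to import the classes used below:

python -c "from instruct_goose import Agent, RewardModel, RLHFTrainer"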

Train the RL-based language model

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader
from torch import optim

from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_model

Step 1: Load dataset

dataset = load_dataset("imdb", split="train")
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
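
Each batch from the dataloader is a dict of lists; with the IMDB dataset the prompts come from the "text" field (the "label" field is unused here). A quick peek at a batch:

batch = next(iter(train_dataloader))
print(batch.keys())           # expected: dict_keys(['text', 'label'])
print(batch["text"][0][:80])  # first 80 characters of the first review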

Step 2: Load the pre-trained model and tokenizer

model_base = AutoModelForCausalLM.from_pretrained("gpt2")  # base language model to fine-tune with PPO
reward_model = RewardModel("gpt2")                         # scores generated responses

tokenizer = AutoTokenizer.from_pretrained("gpt2")
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS for padding

Step 3: Create the RL-based language model agent and the reference model

model = Agent(model_base)
ref_model = create_reference_model(model)
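
create_reference_model makes a copy of the agent that stays frozen during training. Following the InstructGPT recipe, PPO optimizes the reward-model score minus a KL penalty against this reference so the policy does not drift too far from the pre-trained model. A rough, purely illustrative sketch of that per-sample objective (RLHFTrainer handles this internally; beta is an assumed coefficient, not part of the instruct_goose API):

# illustrative only -- not part of the instruct_goose API
def kl_penalized_reward(reward, logprob_policy, logprob_ref, beta=0.2):
    # r_total = r_reward_model - beta * (log pi(y|x) - log pi_ref(y|x))
    return reward - beta * (logprob_policy - logprob_ref)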

Step 4: Train it

max_new_tokens = 20

# plain sampling: no top-k/top-p filtering, responses capped at 20 new tokens
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": max_new_tokens
}

config = RLHFConfig()
N_EPOCH = 100
trainer = RLHFTrainer(model, ref_model, config)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
for epoch in range(N_EPOCH):
    for batch in train_dataloader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
        responses = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"],
            **generation_kwargs
        )
        
        # extract the generated text
        responses = responses[:, -max_new_tokens:]
        
        # evaluate from the reward model
        with torch.no_grad():
            text_input_ids = torch.stack([torch.concat([q, r]) for q, r in zip(inputs["input_ids"], responses)], dim=0)
            texts = tokenizer.batch_decode(text_input_ids, skip_special_tokens=True)
            rewards = reward_model(texts)
        
        # calculate PPO loss
        loss = trainer.compute_loss(inputs["input_ids"], responses, rewards)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss={loss}")
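
After training, the same generate call can be used for inference. A minimal sketch that reuses the tokenizer and generation_kwargs from above (the prompt is just an example):

prompt = "The movie was"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output = model.generate(
        inputs["input_ids"], attention_mask=inputs["attention_mask"],
        **generation_kwargs
    )
# drop the prompt tokens, same extraction as in the training loop
response = output[:, -max_new_tokens:]
print(tokenizer.decode(response[0], skip_special_tokens=True))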

TODO

  • Add support for custom reward functions (a stand-in sketch follows this list)
  • Add support for custom value functions
  • Add support for non-transformer models
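
In the meantime, the training loop above only needs a callable that maps a list of decoded texts to a tensor of scores, so any such function can stand in for reward_model. A toy, purely hypothetical example:

import torch

def toy_reward_fn(texts):
    # illustrative stand-in: mildly favor longer completions
    return torch.tensor([min(len(t), 200) / 200.0 for t in texts])

# in the training loop, replace
#   rewards = reward_model(texts)
# with
#   rewards = toy_reward_fn(texts)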

Resources

I implemented this using these resources.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

instruct_goose-0.0.5.tar.gz (10.8 kB, Source)

Built Distribution

instruct_goose-0.0.5-py3-none-any.whl (12.2 kB, Python 3)

File details

Details for the file instruct_goose-0.0.5.tar.gz.

File metadata

  • Download URL: instruct_goose-0.0.5.tar.gz
  • Upload date:
  • Size: 10.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.6

File hashes

Hashes for instruct_goose-0.0.5.tar.gz

  • SHA256: 6fd46d16a890309d3943b09707c1ee5dfd176af17205aa57499d848df78cd8b6
  • MD5: a89698cefdf9cdfb6df401c50a7a4876
  • BLAKE2b-256: 7ba14a6952d33462698e8b1cafe0b492587094df5c5d1517202c2ffc361c297e

See more details on using hashes here.

File details

Details for the file instruct_goose-0.0.5-py3-none-any.whl.

File metadata

File hashes

Hashes for instruct_goose-0.0.5-py3-none-any.whl

  • SHA256: a05ba3a12e13289f1e3a5fa5a0d9c7c7bba368df4a82a013d895f814db7238ab
  • MD5: fb84a14d6587741ca20bf58089ec9076
  • BLAKE2b-256: c8e52ec76949a4fd4e6dba455c5524390ca5b823a52e08f9744bbec14d777176

See more details on using hashes here.
