Implementation of Reinforcement Learning from Human Feedback (RLHF)

Project description

InstructGoose - 🚧 WORK IN PROGRESS 🚧

Paper: InstructGPT - Training language models to follow instructions with human feedback

Install

Install from PyPI

pip install instruct-goose

Install directly from the source code

git clone https://github.com/xrsrke/instructGOOSE.git
cd instructGOOSE
pip install -e .

Train the RL-based language model

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader
from torch import optim

from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_model

Step 1: Load dataset

dataset = load_dataset("imdb", split="train")
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

Step 2: Load the pre-trained model and tokenizer

model_base = AutoModelForCausalLM.from_pretrained("gpt2")
reward_model = RewardModel("gpt2")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

Step 3: Create the RL-based language model agent and the reference model

model = Agent(model_base)
ref_model = create_reference_model(model)
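
For context, the reference model is a frozen copy of the initial policy. In the InstructGPT objective (omitting the pretraining-mix term), PPO maximizes the reward-model score minus a KL penalty that keeps the RL policy close to this reference model; the formula below is a sketch of that idea, not necessarily the exact quantity RLHFTrainer.compute_loss implements:

$$\text{objective}(\phi) \;=\; \mathbb{E}_{(x,y)\sim \pi^{RL}_\phi}\Big[\, r_\theta(x, y) \;-\; \beta \,\log\frac{\pi^{RL}_\phi(y \mid x)}{\pi^{ref}(y \mid x)} \,\Big]$$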

Step 4: Train it

max_new_tokens = 20
generation_kwargs = {
    "min_length":-1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": max_new_tokens
}

config = RLHFConfig()
N_EPOCH = 100
trainer = RLHFTrainer(model, ref_model, config)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
for epoch in range(N_EPOCH):
    for batch in train_dataloader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
        responses = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"],
            **generation_kwargs
        )
        
        # keep only the newly generated tokens (generate returns the prompt followed by the continuation)
        responses = responses[:, -max_new_tokens:]
        
        # evaluate from the reward model
        with torch.no_grad():
            text_input_ids = torch.stack([torch.concat([q, r]) for q, r in zip(inputs["input_ids"], responses)], dim=0)
            texts = tokenizer.batch_decode(text_input_ids, skip_special_tokens=True)
            rewards = reward_model(texts)
        
        # calculate PPO loss
        loss = trainer.compute_loss(inputs["input_ids"], responses, rewards)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss={loss}")

TODO

  • Add support for custom reward functions (a possible interface is sketched after this list)
  • Add support for custom value functions
  • Add support for non-transformer models
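
As a rough sketch of what a custom reward function could look like once it is supported: any callable that maps the decoded texts of a batch to a 1-D tensor of rewards could stand in for RewardModel in the training loop. The name length_penalty_reward and the idea of passing a plain Python callable are assumptions about a possible interface, not part of the current API.

import torch

def length_penalty_reward(texts):
    # Hypothetical reward: prefer responses whose length is close to a target
    # Returns one scalar reward per decoded text in the batch
    target_len = 15
    rewards = []
    for text in texts:
        n_words = len(text.split())
        rewards.append(-abs(n_words - target_len) / target_len)
    return torch.tensor(rewards)

# In the training loop above, it would replace the reward model call:
# rewards = length_penalty_reward(texts)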

Questions

  • In the context of RLHF, how is $L_t^{VF}(\theta)$ calculated? (a sketch follows this list)
    • Is it the loss of the function the PPO agent uses to predict how much reward it gets if it generates the sequence?
  • Do the RL model and the SFT model use the same tokenizer? Yes
  • I don't know how to return the logits of the generation model
  • Does the PPO agent (the language model) have a value network just like a regular PPO agent?
  • I don't understand how to calculate the advantage in PPO (see the sketch below)
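
For the first and last questions, this is how the value-function loss and the advantage are usually defined in standard PPO; it is a sketch of the textbook recipe, not necessarily what RLHFTrainer does internally. The value head predicts the expected future reward at each generated token, the advantage measures how much better the sampled continuation turned out than that prediction (here via generalized advantage estimation), and $L_t^{VF}(\theta)$ is the squared error between predicted values and empirical returns.

import torch

def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    # rewards, values: 1-D tensors over the generated tokens of one response
    # Standard generalized advantage estimation (GAE), as in regular PPO
    advantages = torch.zeros_like(rewards)
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_adv = delta + gamma * lam * last_adv
        advantages[t] = last_adv
    returns = advantages + values
    return advantages, returns

def value_loss(values, returns):
    # L_t^{VF}(theta): squared error between the value head's predictions
    # and the empirical returns (a clipped variant is also common)
    return 0.5 * ((values - returns) ** 2).mean()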

Resources

I used these resources to implement this

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

instruct_goose-0.0.3.tar.gz (13.0 kB)

Uploaded Source

Built Distribution

instruct_goose-0.0.3-py3-none-any.whl (13.2 kB)

Uploaded Python 3

File details

Details for the file instruct_goose-0.0.3.tar.gz.

File metadata

  • Download URL: instruct_goose-0.0.3.tar.gz
  • Upload date:
  • Size: 13.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.6

File hashes

Hashes for instruct_goose-0.0.3.tar.gz
Algorithm Hash digest
SHA256 cfba80a01ef34f3c194b48f30a66181aa0336ee1a58344dfd604e4db32197fa6
MD5 0800e7294545e4fa79f99f9a3da4072d
BLAKE2b-256 c536cd107ba26166841c2037aea13d08d756125aceb959ce3583f38f7e229768

See more details on using hashes here.

File details

Details for the file instruct_goose-0.0.3-py3-none-any.whl.

File metadata

File hashes

Hashes for instruct_goose-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 24809cb154bac441d37ae97b23f723969815d5bb9c1599ce671b08b5ddbfd3ab
MD5 89d7c5853d126c11387ee07dfc558f53
BLAKE2b-256 636847e3fecbfd0e0a4753de990427b4f43efb7ff262b9692ed0cfc3ff42dedc

See more details on using hashes here.
