
Implementation of Reinforcement Learning from Human Feedback (RLHF)


InstructGoose

Paper: InstructGPT - Training language models to follow instructions with human feedback
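For context, the objective optimized in the InstructGPT paper (the "PPO-ptx" variant) combines the learned reward with a per-token KL penalty against the supervised fine-tuned (SFT) policy, plus an optional pretraining-mix term. Here r_θ is the reward model, π_φ^RL the policy being trained, π^SFT the frozen reference policy, and β, γ are coefficients:

\mathrm{objective}(\phi) = \mathbb{E}_{(x, y) \sim D_{\pi_\phi^{\mathrm{RL}}}} \left[ r_\theta(x, y) - \beta \log \frac{\pi_\phi^{\mathrm{RL}}(y \mid x)}{\pi^{\mathrm{SFT}}(y \mid x)} \right] + \gamma \, \mathbb{E}_{x \sim D_{\mathrm{pretrain}}} \left[ \log \pi_\phi^{\mathrm{RL}}(x) \right]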


Install

Install from PyPI

pip install instruct-goose

Install directly from the source code

git clone https://github.com/xrsrke/instructGOOSE.git
cd instructGOOSE
pip install -e .

Train the RL-based language model

from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader, random_split
from torch import optim

from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_model

Step 1: Load dataset

dataset = load_dataset("imdb", split="train")
dataset, _ = random_split(dataset, lengths=[10, len(dataset) - 10]) # take a small subset for demonstration purposes
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

Step 2: Load the pre-trained model and tokenizer

model_base = AutoModelForCausalLM.from_pretrained("gpt2") # for demonstration purposes
reward_model = RewardModel("gpt2")

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
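As a quick sanity check before training, the reward model can be called directly on tokenized text; it is assumed to return one scalar reward per input sequence, exactly as it is used in the training loop below:

# sanity-check the reward model on a single tokenized text
# (assumes RewardModel returns one scalar per sequence, as in the loop below)
sample = tokenizer(["This movie was surprisingly good."], return_tensors="pt")
with torch.no_grad():
    sample_reward = reward_model(sample["input_ids"])
print(sample_reward)  # one reward value for the single input sequence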

Step 3: Create the RL-based language model agent and the reference model

model = Agent(model_base)
ref_model = create_reference_model(model)
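The reference model is a frozen copy of the initial policy; during PPO the trainer uses it to keep the updated policy from drifting too far from where it started, by subtracting a scaled KL estimate from the reward. A minimal sketch of the idea (illustrative names, not instruct_goose's internal API):

# sketch of the KL penalty used in RLHF (illustrative, not the library's API):
# logprobs / ref_logprobs are per-token log-probabilities of the sampled
# response under the policy and under the frozen reference model
def kl_penalized_reward(reward, logprobs, ref_logprobs, beta=0.1):
    kl = (logprobs - ref_logprobs).sum(dim=-1)  # per-sequence KL estimate
    return reward - beta * kl  # beta is an assumed coefficient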

Step 4: Train it

max_new_tokens = 20
generation_kwargs = {
    "min_length":-1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": max_new_tokens
}

config = RLHFConfig()
N_EPOCH = 1 # for demonstration purposes
trainer = RLHFTrainer(model, ref_model, config)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
for epoch in range(N_EPOCH):
    for batch in train_dataloader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
        response_ids = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"],
            **generation_kwargs
        )
        
        # keep only the newly generated response tokens
        response_ids = response_ids[:, -max_new_tokens:]
        response_attention_mask = torch.ones_like(response_ids)
        
        # evaluate from the reward model
        with torch.no_grad():
            text_input_ids = torch.stack([torch.concat([q, r]) for q, r in zip(inputs["input_ids"], response_ids)], dim=0)
            rewards = reward_model(text_input_ids)
        
        # calculate PPO loss
        loss = trainer.compute_loss(
            query_ids=inputs["input_ids"],
            query_attention_mask=inputs["attention_mask"],
            response_ids=response_ids,
            response_attention_mask=response_attention_mask,
            rewards=rewards
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss={loss}")
loss=-824.6560668945312
loss=0.030958056449890137
loss=4.284017562866211
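After training, the fine-tuned policy can be sampled with the same generate call used in the loop above, for example:

# generate a completion from the fine-tuned policy (illustrative prompt)
prompt = tokenizer(["The movie was"], return_tensors="pt")
output_ids = model.generate(
    prompt["input_ids"], attention_mask=prompt["attention_mask"],
    **generation_kwargs
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))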

TODO

  • Add support for custom reward functions (see the sketch after this list)
  • Add support for custom value functions
  • Add support for non-transformer models
  • Write a config class
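Until custom reward functions are supported natively, the reward in the loop above can come from any callable that maps generated text to one scalar per sequence. A hedged sketch using a Hugging Face sentiment classifier (illustrative only, not part of the instruct_goose API):

# hand-rolled reward: positive-sentiment probability of the generated text
# (illustrative only; swap it in for reward_model in the training loop)
from transformers import pipeline
import torch

sentiment = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english"
)

def sentiment_reward(texts):
    outputs = sentiment(texts)
    return torch.tensor([
        o["score"] if o["label"] == "POSITIVE" else 1.0 - o["score"]
        for o in outputs
    ])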

Resources

I implemented this using these resources
