TextRL - use reinforcement learning to adjust text generation results.
TextRL
Text generation with reinforcement learning using Hugging Face's Transformers.
RLHF (Reinforcement Learning from Human Feedback)
A ChatGPT-style setup in which human feedback is used to improve the generation model with reinforcement learning.
Introduction
This project uses reinforcement learning to adjust text generation results. It works with any text-generation model from Hugging Face's Transformers, combined with PFRL and OpenAI Gym.
Example 1
Run on the 7B multilingual BLOOM model: bigscience/bloomz-7b1-mt
import pfrl
from textrl import TextRLEnv, TextRLActor
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigscience/bloomz-7b1-mt"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")
model = model.cuda()

class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):  # predicted_list is the list of predicted tokens
        reward = 0
        if finish:
            reward = len(predicted_list)
        return reward

observation_list = [["explain how attention work in seq2seq model"]]

env = MyRLEnv(model, tokenizer, observation_input=observation_list)
actor = TextRLActor(env, model, tokenizer, act_deterministically=True)
agent = actor.agent_ppo(update_interval=2, minibatch_size=2, epochs=10)

print(actor.predict(observation_list[0]))

pfrl.experiments.train_agent_batch_with_evaluation(
    agent,
    env,
    steps=100,
    eval_n_steps=None,
    eval_n_episodes=1,
    eval_interval=2,
    outdir='bloom_test',
)

print(actor.predict(observation_list[0]))
Example 2
Controllable generation via RL: make an Elon Musk persona speak ill of Dogecoin.
before: i think dogecoin is a great idea.
after: i think dogecoin is a great idea, but I think it is a little overused.
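The exact reward used for this example is not shown here, so the snippet below is only a sketch of how such a controllable-generation reward could look; the NegativeDogeEnv name, the keyword list, and the use of tokenizer.convert_tokens_to_string on the token list are illustrative assumptions, not the setup that produced the sentences above.

from textrl import TextRLEnv

class NegativeDogeEnv(TextRLEnv):  # hypothetical env name; reuses the tokenizer loaded in Example 1
    def get_reward(self, input_item, predicted_list, finish):
        # Only score the finished sentence; intermediate steps get no reward.
        if not finish:
            return 0
        # Treat predicted_list as a flat list of token strings, as in the
        # examples above, and detokenize it for keyword matching.
        text = tokenizer.convert_tokens_to_string(predicted_list).lower()
        # Assumed keyword list: reward outputs that speak ill of dogecoin.
        negative_words = ["overused", "overrated", "risky", "bad"]
        return 1 if any(word in text for word in negative_words) else -1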
Installation
pip install
pip install pfrl@git+https://github.com/voidful/pfrl.git
pip install textrl
Build from source
git clone and cd into this project.
pip install -e .
Usage
init agent and environment
import torch
from textrl import TextRLEnv, TextRLActor
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "bigscience/bloomz-7b1-mt"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")
model = model.cuda()
setup reward function for environment
- predicted_list (list[str]): the list of predicted tokens
- finish (bool): whether generation has reached the end of the sentence
class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):  # predicted_list is the list of predicted tokens
        if "[UNK]" in predicted_list:
            reward = -1
        else:
            reward = 1
        return reward
prepare for training
- observation_list should be a list of all possible input strings for model training
e.g.:
observation_list = [['testing sent 1'], ['testing sent 2']]
env = MyRLEnv(model, tokenizer, observation_input=observation_list)
actor = TextRLActor(env, model, tokenizer)
agent = actor.agent_ppo(update_interval=10, minibatch_size=2000, epochs=20)
Train
n_episodes = 1000
max_episode_len = 200  # max sentence length

for i in range(1, n_episodes + 1):
    obs = env.reset()
    R = 0
    t = 0
    while True:
        action = agent.act(obs)
        obs, reward, done, pred = env.step(action)
        R += reward
        t += 1
        reset = t == max_episode_len
        agent.observe(obs, reward, done, reset)
        if done or reset:
            break
    if i % 10 == 0:
        print('episode:', i, 'R:', R)
    if i % 50 == 0:
        print('statistics:', agent.get_statistics())
print('Finished.')
another way to train
import logging
import sys

import pfrl

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=1000,
    eval_n_steps=None,
    eval_n_episodes=1500,
    train_max_episode_len=50,
    eval_interval=10000,
    outdir='somewhere',
)
prediction
agent.load("somewhere/best") # loading the best model
actor.predict("input text")
dump the trained model to Hugging Face's model format
textrl-dump --model ./model_path_before_rl --rl ./rl_path --dump ./output_dir
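The dumped directory can then be loaded like any other Hugging Face checkpoint. A minimal sketch, assuming the dump was written to ./output_dir as in the command above; if the tokenizer files are not included in the dump, load the tokenizer from the original checkpoint instead.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./output_dir")
model = AutoModelForCausalLM.from_pretrained("./output_dir")

inputs = tokenizer("explain how attention work in seq2seq model", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))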