# LlamaGym

Fine-tune LLM agents with online reinforcement learning
"Agents" originated in reinforcement learning, where they learn by interacting with an environment and receiving a reward signal. However, LLM-based agents today do not learn online (i.e. continuously in real time) via reinforcement.
OpenAI created Gym to standardize and simplify RL environments, but if you try dropping an LLM-based agent into a Gym environment for training, you'd find it's still quite a bit of code to handle LLM conversation context, episode batches, reward assignment, PPO setup, and more.
LlamaGym seeks to simplify fine-tuning LLM agents with RL. Right now, it's a single `Agent` abstract class that handles all the issues mentioned above, letting you quickly iterate and experiment with agent prompting & hyperparameters across any Gym environment.
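
For a sense of what that entails, here is a minimal sketch of the pattern (illustrative only, not LlamaGym's actual implementation; the class name and method bodies below are assumptions): keep a chat buffer per episode, query the LLM for each action, and buffer rewards for a later PPO update.

```python
from abc import ABC, abstractmethod

class ChatGymAgent(ABC):
    """Illustrative sketch of the agent pattern, not LlamaGym's internals."""

    def __init__(self):
        # one running conversation per episode, seeded with the system prompt
        self.messages = [{"role": "system", "content": self.get_system_prompt()}]
        self.rewards = []

    @abstractmethod
    def get_system_prompt(self) -> str: ...

    @abstractmethod
    def format_observation(self, observation) -> str: ...

    @abstractmethod
    def extract_action(self, response: str): ...

    @abstractmethod
    def generate(self, messages) -> str:
        """Tokenize the conversation, call model.generate, decode the reply."""

    def act(self, observation):
        # fold the new observation into the running conversation
        self.messages.append({"role": "user", "content": self.format_observation(observation)})
        response = self.generate(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return self.extract_action(response)

    def assign_reward(self, reward: float):
        self.rewards.append(reward)

    def terminate_episode(self):
        # in LlamaGym, this is also where a PPO step runs once a batch of
        # episodes has accumulated; here we just reset the buffers
        stats = {"episode_return": sum(self.rewards)}
        self.messages = [{"role": "system", "content": self.get_system_prompt()}]
        self.rewards = []
        return stats
```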
## Usage
Fine-tuning an LLM-based agent to play in a Gym-style environment with RL has never been easier! Once you install LlamaGym...
```
pip install llamagym
```
First, implement 3 abstract methods on the `Agent` class:

```python
from llamagym import Agent

class BlackjackAgent(Agent):
    def get_system_prompt(self) -> str:
        return "You are an expert blackjack player."

    def format_observation(self, observation) -> str:
        return f"Your current total is {observation[0]}"

    def extract_action(self, response: str):
        return 0 if "stick" in response else 1
```
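
In practice, LLM outputs can be noisy, so you may want a more defensive `extract_action`. A sketch (the regex and fallback policy here are assumptions, not part of LlamaGym's API):

```python
import re

def extract_action(self, response: str):
    # prefer an explicit decision word; fall back to sticking (action 0)
    # if the model rambles without committing to a move
    match = re.search(r"\b(stick|stand|hit)\b", response.lower())
    if match is None:
        return 0  # arbitrary fallback; re-prompting the model is another option
    return 1 if match.group(1) == "hit" else 0
```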
Then, define your base LLM (as you would for any fine-tuning job) and instantiate your agent:

```python
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLMWithValueHead.from_pretrained("Llama-2-7b").to(device)
tokenizer = AutoTokenizer.from_pretrained("Llama-2-7b")
agent = BlackjackAgent(model, tokenizer, device)
```
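
A 7B model plus a value head is heavy for a single consumer GPU. If memory is tight, trl's value-head models accept a peft `LoraConfig` so that only the adapters and value head are trained; the hyperparameters below are illustrative, not tuned:

```python
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "Llama-2-7b",
    peft_config=lora_config,  # train LoRA adapters + value head only
).to(device)
```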
Finally, write your RL loop as usual and simply call your agent to act, reward, and terminate:

```python
import gymnasium as gym
from tqdm import trange

env = gym.make("Blackjack-v1")
for episode in trange(5000):
    observation, info = env.reset()
    done = False

    while not done:
        action = agent.act(observation)  # act based on observation
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)  # provide reward to agent
        done = terminated or truncated

    train_stats = agent.terminate_episode()  # trains if batch is full
```
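
Online RL rewards are noisy, so a sliding-window average of episode returns is a cheap way to see whether training is going anywhere. A sketch layered on the loop above (the bookkeeping is ours, not a LlamaGym feature):

```python
from collections import deque

returns = deque(maxlen=100)  # sliding window of recent episode returns

for episode in trange(5000):
    observation, info = env.reset()
    done, episode_return = False, 0.0
    while not done:
        action = agent.act(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)
        episode_return += reward
        done = terminated or truncated
    agent.terminate_episode()

    returns.append(episode_return)
    if episode % 100 == 0:
        print(f"episode {episode}: mean return (last {len(returns)}) = "
              f"{sum(returns) / len(returns):.3f}")
```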
Some reminders:

- The code snippets above are mildly simplified, but a fully working example is available in `examples/blackjack.py`.
- Getting online RL to converge is notoriously difficult, so expect to experiment with hyperparameters before you see improvement.
- Your model may also benefit from a supervised fine-tuning stage on sampled trajectories before running RL (we may add this feature in the future); a rough sketch follows this list.
- Our implementation values simplicity, so it is not as compute-efficient as e.g. Lamorel, but it is easier to start playing around with.
- LlamaGym is a weekend project and still a WIP, but we love contributions!
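
On the supervised fine-tuning point above, here is a rough sketch of what a warm-start could look like: roll out the untrained agent, keep transcripts of high-return episodes, and run a standard causal-LM pass over them before any PPO updates. Everything here is an assumption about how you would wire it up (including access to the agent's chat buffer via `agent.messages` and a tokenizer with a chat template), not a LlamaGym feature:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def collect_good_transcripts(agent, env, n_episodes=500, min_return=1.0):
    """Roll out the agent and keep chat transcripts of winning episodes.
    Assumes the agent exposes its conversation as `agent.messages`."""
    transcripts = []
    for _ in range(n_episodes):
        observation, info = env.reset()
        done, episode_return = False, 0.0
        while not done:
            action = agent.act(observation)
            observation, reward, terminated, truncated, info = env.step(action)
            episode_return += reward
            done = terminated or truncated
        if episode_return >= min_return:
            transcripts.append(list(agent.messages))
        agent.terminate_episode()
    return transcripts

def sft_epoch(model_name, transcripts, device, lr=1e-5):
    """One supervised epoch over the rendered transcripts (batch size 1)."""
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for transcript in transcripts:
        # render the conversation and train with a standard causal-LM loss
        text = tokenizer.apply_chat_template(transcript, tokenize=False)
        batch = tokenizer(text, return_tensors="pt",
                          truncation=True, max_length=1024).to(device)
        loss = model(**batch, labels=batch["input_ids"]).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return model
```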
## Relevant Work
- Grounding Large Language Models with Online Reinforcement Learning
- True Knowledge Comes from Practice: Aligning LLMs with Embodied Environments via Reinforcement Learning
## Citation

```bibtex
@misc{pandey2024llamagym,
  title = {LlamaGym: Fine-tune LLM agents with Online Reinforcement Learning},
  author = {Rohan Pandey},
  year = {2024},
  howpublished = {GitHub},
  url = {https://github.com/KhoomeiK/LlamaGym}
}
```
## Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
### Details for the file llamagym-0.1.1.tar.gz

File metadata
- Download URL: llamagym-0.1.1.tar.gz
- Upload date:
- Size: 5.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 360a0b6315261018bf06aa6779bdb1c7205d07b20bda2ab4622076795109fc31 |
| MD5 | 0c3ddf2d9665d743013cba3fbafed8ff |
| BLAKE2b-256 | b3f3c634acf96c26fffde64cfcaddb92b8e5102be67cd2f14a024cfe835cec5b |
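
To check a downloaded archive against the digests above, e.g. with Python's standard library:

```python
import hashlib

expected = "360a0b6315261018bf06aa6779bdb1c7205d07b20bda2ab4622076795109fc31"
with open("llamagym-0.1.1.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print("OK" if digest == expected else "hash mismatch!")
```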
### Details for the file llamagym-0.1.1-py3-none-any.whl

File metadata
- Download URL: llamagym-0.1.1-py3-none-any.whl
- Upload date:
- Size: 5.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 24cf73effc11cc9d1fd5bef8d5d59816982c39ef8147a95b203f04094f04f046 |
| MD5 | a47312d15e969c981f0655176711536e |
| BLAKE2b-256 | f72fe680605d96e1159ead0b06a82868477ad139acb5f8b95f72f5cab0b3d111 |