LlamaGym

Fine-tune LLM agents with online reinforcement learning


"Agents" originated in reinforcement learning, where they learn by interacting with an environment and receiving a reward signal. However, LLM-based agents today do not learn online (i.e. continuously in real time) via reinforcement.

OpenAI created Gym to standardize and simplify RL environments, but if you try dropping an LLM-based agent into a Gym environment for training, you'll find there's still quite a bit of code to write to handle LLM conversation context, episode batches, reward assignment, PPO setup, and more.

LlamaGym seeks to simplify fine-tuning LLM agents with RL. Right now, it's a single Agent abstract class that handles all the issues mentioned above, letting you quickly iterate and experiment with agent prompting & hyperparameters across any Gym environment.

Usage

Fine-tuning an LLM-based agent to play in a Gym-style environment with RL has never been easier! Once you install LlamaGym...

pip install llamagym

First, implement 3 abstract methods on the Agent class:

from llamagym import Agent

class BlackjackAgent(Agent):
    def get_system_prompt(self) -> str:
        return "You are an expert blackjack player."

    def format_observation(self, observation) -> str:
        return f"Your current total is {observation[0]}"

    def extract_action(self, response: str):
        return 0 if "stick" in response else 1
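
To make the flow concrete, here is a rough sketch of how these three methods fit together each time the agent acts. This is an illustration only, not LlamaGym's actual source; the generate argument stands in for whatever chat-formatted generation the base class performs with the wrapped model:

# Illustration only, not LlamaGym's implementation: roughly how the three
# methods you implement are composed on each act() call.
def act_sketch(agent, observation, generate) -> int:
    # build the conversation from your system prompt and the formatted observation
    messages = [
        {"role": "system", "content": agent.get_system_prompt()},
        {"role": "user", "content": agent.format_observation(observation)},
    ]
    response = generate(messages)          # the base class generates with the wrapped model
    return agent.extract_action(response)  # map free-form text back to a Gym action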

Then, define your base LLM (as you would for any fine-tuning job) and instantiate your agent (device is whatever torch device you're training on, e.g. "cuda"):

from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("Llama-2-7b").to(device)
tokenizer = AutoTokenizer.from_pretrained("Llama-2-7b")
agent = BlackjackAgent(model, tokenizer, device)
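
If the full model is too heavy for your GPU, one common option (an assumption about your setup, not something LlamaGym requires) is to attach a LoRA adapter when loading the value-head model, via trl's peft integration:

from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead

# assumption: standard trl + peft options, not a LlamaGym-specific API
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "Llama-2-7b", peft_config=lora_config
).to(device)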

Finally, write your RL loop as usual and simply call your agent to act, reward, and terminate:

import gymnasium as gym  # Blackjack-v1 and the 5-tuple step API come from Gymnasium (or gym >= 0.26)
from tqdm import trange

env = gym.make("Blackjack-v1")

for episode in trange(5000):
    observation, info = env.reset()
    done = False

    while not done:
        action = agent.act(observation)  # act based on observation
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)  # provide reward to agent
        done = terminated or truncated

    train_stats = agent.terminate_episode()  # trains if batch is full

Some reminders:

  • the code snippets above are mildly simplified, but a fully working example is available in examples/blackjack.py
  • getting online RL to converge is notoriously difficult, so you'll have to mess with hyperparameters to see improvement (see the sketch after this list)
    • your model may also benefit from a supervised fine-tuning stage on sampled trajectories before running RL (we may add this feature in the future)
  • our implementation values simplicity, so it is not as compute-efficient as e.g. Lamorel, but it is easier to start playing around with
  • LlamaGym is a weekend project and still a WIP, but we love contributions!
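
As a starting point for that hyperparameter fiddling: the value-head model above comes from trl, so the usual generation and PPO knobs apply. The keyword arguments below are assumptions rather than a documented API, so check examples/blackjack.py and the Agent source for the names the constructor actually accepts:

# assumption: keyword names are illustrative, not a documented LlamaGym API;
# check the Agent constructor for the arguments it actually accepts
agent = BlackjackAgent(
    model,
    tokenizer,
    device,
    generate_config_dict={"max_new_tokens": 32, "do_sample": True, "top_p": 0.9},
    ppo_config_dict={"batch_size": 16, "mini_batch_size": 16, "learning_rate": 1e-5},
)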

Relevant Work

Citation

@misc{pandey2024llamagym,
  title        = {LlamaGym: Fine-tune LLM agents with Online Reinforcement Learning},
  author       = {Rohan Pandey},
  year         = {2024},
  howpublished = {GitHub},
  url          = {https://github.com/KhoomeiK/LlamaGym}
}

