
Project description

LlamaGym

Fine-tune LLM agents with online reinforcement learning

🔗 Agents for Web Data Extraction   •   🐦 Twitter

"Agents" originated in reinforcement learning, where they learn by interacting with an environment and receiving a reward signal. However, LLM-based agents today do not learn online (i.e. continuously in real time) via reinforcement.

OpenAI created Gym to standardize and simplify RL environments, but if you try dropping an LLM-based agent into a Gym environment for training, you'll find it still takes quite a bit of code to handle LLM conversation context, episode batches, reward assignment, PPO setup, and more.

LlamaGym seeks to simplify fine-tuning LLM agents with RL. Right now, it's a single Agent abstract class that handles all the issues mentioned above, letting you quickly iterate and experiment with agent prompting & hyperparameters across any Gym environment.
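
To make that concrete, here is a rough, illustrative sketch of the per-episode bookkeeping such a class has to do. This is not LlamaGym's actual implementation: the SketchAgent name, its attribute names, and the _generate/_ppo_step placeholders are made up for illustration.

from abc import ABC, abstractmethod

class SketchAgent(ABC):
    # Illustrative only: the state an online-RL LLM agent has to track.
    def __init__(self, batch_size: int = 8):
        self.batch_size = batch_size
        self.messages = []   # chat history for the current episode
        self.rewards = []    # rewards received in the current episode
        self.batch = []      # finished episodes waiting for a PPO update

    @abstractmethod
    def get_system_prompt(self) -> str: ...
    @abstractmethod
    def format_observation(self, observation) -> str: ...
    @abstractmethod
    def extract_action(self, response: str): ...

    def _generate(self, messages) -> str:
        raise NotImplementedError  # placeholder: sample a completion from the LLM

    def _ppo_step(self, batch) -> dict:
        raise NotImplementedError  # placeholder: one PPO update over the batched episodes

    def act(self, observation):
        if not self.messages:
            self.messages = [{"role": "system", "content": self.get_system_prompt()}]
        self.messages.append({"role": "user", "content": self.format_observation(observation)})
        response = self._generate(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return self.extract_action(response)

    def assign_reward(self, reward):
        self.rewards.append(reward)

    def terminate_episode(self):
        self.batch.append((self.messages, self.rewards))
        self.messages, self.rewards = [], []
        if len(self.batch) >= self.batch_size:
            stats = self._ppo_step(self.batch)
            self.batch = []
            return stats
        return {}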

Usage

Fine-tuning an LLM-based agent to play in a Gym-style environment with RL has never been easier! Once you install LlamaGym...

pip install llamagym

First, implement 3 abstract methods on the Agent class:

from llamagym import Agent

class BlackjackAgent(Agent):
    def get_system_prompt(self) -> str:
        return "You are an expert blackjack player."

    def format_observation(self, observation) -> str:
        return f"Your current total is {observation[0]}"  # observation[0] is the player's hand total

    def extract_action(self, response: str):
        return 0 if "stick" in response else 1  # Blackjack-v1 actions: 0 = stick, 1 = hit
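
Note that Blackjack-v1 observations are a tuple of (player sum, dealer's showing card, usable ace), so a slightly fuller format_observation could expose all three fields to the model. This is an optional sketch, not something LlamaGym requires:

    def format_observation(self, observation) -> str:
        player_sum, dealer_card, usable_ace = observation
        return (
            f"Your current total is {player_sum}. "
            f"The dealer is showing a {dealer_card}. "
            f"You {'have' if usable_ace else 'do not have'} a usable ace."
        )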

Then, define your base LLM (as you would for any fine-tuning job) and instantiate your agent:

import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLMWithValueHead.from_pretrained("Llama-2-7b").to(device)
tokenizer = AutoTokenizer.from_pretrained("Llama-2-7b")
agent = BlackjackAgent(model, tokenizer, device)
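
PPO on a 7B model is memory-hungry. If full fine-tuning doesn't fit on your GPU, one common workaround is to attach a LoRA adapter when loading the value-head model. The sketch below uses peft with trl and only changes the model-loading call relative to the snippet above; the rank, alpha, and target modules are illustrative, and none of this is required by LlamaGym:

from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                                   # adapter rank (illustrative)
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],    # attention projections in Llama-style models
    task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "Llama-2-7b", peft_config=lora_config
).to(device)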

Finally, write your RL loop as usual and simply call your agent to act, reward, and terminate:

import gymnasium as gym
from tqdm import trange

env = gym.make("Blackjack-v1")

for episode in trange(5000):
    observation, info = env.reset()
    done = False

    while not done:
        action = agent.act(observation) # act based on observation
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward) # provide reward to agent
        done = terminated or truncated

    train_stats = agent.terminate_episode() # trains if batch is full
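
Because online RL is noisy, it also helps to track a rolling average of episode returns so you can tell whether anything is improving. A small, optional extension of the loop above (the window size of 100 is arbitrary):

from collections import deque

recent_returns = deque(maxlen=100)   # rolling window of recent episode returns

for episode in trange(5000):
    observation, info = env.reset()
    done, episode_return = False, 0.0

    while not done:
        action = agent.act(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)
        episode_return += reward
        done = terminated or truncated

    agent.terminate_episode()
    recent_returns.append(episode_return)
    if episode % 100 == 0:
        print(f"episode {episode}: mean return = {sum(recent_returns) / len(recent_returns):.3f}")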

Some reminders:

  • the code snippets above are mildly simplified, but a fully working example is available in examples/blackjack.py
  • getting online RL to converge is notoriously difficult, so you'll have to mess with hyperparameters to see improvement
    • your model may also benefit from a supervised fine-tuning stage on sampled trajectories before running RL (we may add this feature in the future)
  • our implementation values simplicity, so it is not as compute-efficient as e.g. Lamorel, but it is easier to start playing around with
  • LlamaGym is a weekend project and still a WIP, but we love contributions!

Relevant Work

Citation

BibTeX:
@misc{pandey2024llamagym,
  title        = {LlamaGym: Fine-tune LLM agents with Online Reinforcement Learning},
  author       = {Rohan Pandey},
  year         = {2024},
  howpublished = {GitHub},
  url          = {https://github.com/KhoomeiK/LlamaGym}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

llamagym-0.1.1.tar.gz (5.2 kB)


Built Distribution

llamagym-0.1.1-py3-none-any.whl (5.7 kB)


File details

Details for the file llamagym-0.1.1.tar.gz.

File metadata

  • Download URL: llamagym-0.1.1.tar.gz
  • Upload date:
  • Size: 5.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for llamagym-0.1.1.tar.gz

  • SHA256: 360a0b6315261018bf06aa6779bdb1c7205d07b20bda2ab4622076795109fc31
  • MD5: 0c3ddf2d9665d743013cba3fbafed8ff
  • BLAKE2b-256: b3f3c634acf96c26fffde64cfcaddb92b8e5102be67cd2f14a024cfe835cec5b

See more details on using hashes here.

File details

Details for the file llamagym-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: llamagym-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 5.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for llamagym-0.1.1-py3-none-any.whl

  • SHA256: 24cf73effc11cc9d1fd5bef8d5d59816982c39ef8147a95b203f04094f04f046
  • MD5: a47312d15e969c981f0655176711536e
  • BLAKE2b-256: f72fe680605d96e1159ead0b06a82868477ad139acb5f8b95f72f5cab0b3d111

See more details on using hashes here.
