
LlamaGym

Fine-tune LLM agents with online reinforcement learning

🐦 Twitter

LLM agents are called agents for a reason: they should be trainable with RL in classic Gym-style environments. But if you try, you'll find it takes quite a bit of code to handle LLM conversation context, episode batches, reward assignment, PPO setup, and more.

LlamaGym seeks to simplify fine-tuning LLM agents with RL. Right now, it's a single Agent abstract class that handles all the issues mentioned above, letting you quickly iterate and experiment with agent prompting & hyperparameters across any Gym environment.
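To make that concrete, here is a rough sketch of what such an Agent abstraction might look like internally. The abstract method names match the usage example below, but the bodies, helper methods, and batch size are illustrative assumptions, not LlamaGym's actual implementation.

from abc import ABC, abstractmethod


class Agent(ABC):
    # Illustrative sketch only: one class owns the conversation context,
    # the per-episode rewards, and the episode batch that feeds a PPO update.
    def __init__(self, model, tokenizer, device):
        self.model, self.tokenizer, self.device = model, tokenizer, device
        self.current_messages = []  # chat history for the episode in progress
        self.current_rewards = []   # per-step rewards for the episode in progress
        self.episode_batch = []     # finished episodes waiting for a PPO update

    @abstractmethod
    def get_system_prompt(self) -> str: ...

    @abstractmethod
    def format_observation(self, observation) -> str: ...

    @abstractmethod
    def extract_action(self, response: str): ...

    def act(self, observation):
        # Format the observation, append it to the conversation, query the LLM,
        # and parse the reply into an environment action.
        if not self.current_messages:
            self.current_messages = [{"role": "system", "content": self.get_system_prompt()}]
        self.current_messages.append({"role": "user", "content": self.format_observation(observation)})
        response = self._generate(self.current_messages)
        self.current_messages.append({"role": "assistant", "content": response})
        return self.extract_action(response)

    def assign_reward(self, reward):
        self.current_rewards.append(reward)

    def terminate_episode(self):
        # Stash the finished episode; run a PPO update once enough episodes are collected.
        self.episode_batch.append((self.current_messages, self.current_rewards))
        self.current_messages, self.current_rewards = [], []
        if len(self.episode_batch) >= 8:  # illustrative batch size
            stats = self._ppo_update(self.episode_batch)
            self.episode_batch = []
            return stats
        return {}

    def _generate(self, messages) -> str:
        # Stub: a real implementation would apply the chat template,
        # call self.model.generate, and decode the assistant reply.
        raise NotImplementedError

    def _ppo_update(self, batch) -> dict:
        # Stub: a real implementation would build (query, response, reward)
        # tensors from the batch and run a PPO training step.
        raise NotImplementedError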

Usage

Fine-tuning an LLM-based agent to play in a Gym-style environment with RL has never been easier!

First, implement 3 abstract methods on the Agent class:

from llamagym import Agent

class BlackjackAgent(Agent):
    def get_system_prompt(self) -> str:
        return "You are an expert blackjack player."

    def format_observation(self, observation) -> str:
        # observation[0] is the player's current hand total
        return f"Your current total is {observation[0]}"

    def extract_action(self, response: str):
        # Blackjack-v1 actions: 0 = stick, 1 = hit
        return 0 if "stick" in response else 1
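
In Gymnasium's Blackjack-v1, the observation is a (player_sum, dealer_showing, usable_ace) tuple, so you can surface more of it to the model than the simplified snippet above does. A purely illustrative variant:

    def format_observation(self, observation) -> str:
        player_sum, dealer_card, usable_ace = observation
        return (
            f"Your current total is {player_sum}. "
            f"The dealer is showing a {dealer_card}. "
            f"You {'have' if usable_ace else 'do not have'} a usable ace. "
            "Reply with 'stick' or 'hit'."
        )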

Then, define your base LLM (as you would for any fine-tuning job) and instantiate your agent:

from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

device = "cuda"  # or "cpu"
model = AutoModelForCausalLMWithValueHead.from_pretrained("Llama-2-7b").to(device)
tokenizer = AutoTokenizer.from_pretrained("Llama-2-7b")
agent = BlackjackAgent(model, tokenizer, device)
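
Fine-tuning a 7B model with a value head can be memory-hungry, so one common option is to attach a LoRA adapter when loading the base model. The snippet below is a sketch that assumes your trl and peft versions support passing a peft_config to from_pretrained; it is not something LlamaGym requires.

from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead

# Assumed setup, not required by LlamaGym: wrap the base model with a LoRA adapter
# so only a small set of weights is trained during PPO updates.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # typical attention projections in Llama-style models
    task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "Llama-2-7b", peft_config=lora_config
).to(device)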

Finally, write your RL loop as usual and simply call your agent to act, reward, and terminate:

import gymnasium as gym
from tqdm import trange

env = gym.make("Blackjack-v1")

for episode in trange(5000):
    observation, info = env.reset()
    done = False

    while not done:
        action = agent.act(observation)  # act based on observation
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)  # provide reward to agent
        done = terminated or truncated

    train_stats = agent.terminate_episode()  # trains if batch is full

Note: the above code snippets are mildly simplified, but a fully working example is available in examples/blackjack.py.
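
If you want to watch training progress, one option is to log the episode return and the returned training stats to Weights & Biases. This is a minimal sketch that assumes terminate_episode returns a dictionary of PPO statistics (the exact keys depend on the library) and that wandb is installed; "llamagym-blackjack" is just a placeholder project name.

import wandb

wandb.init(project="llamagym-blackjack")  # placeholder project name

for episode in trange(5000):
    observation, info = env.reset()
    done, episode_return = False, 0.0

    while not done:
        action = agent.act(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)
        episode_return += reward
        done = terminated or truncated

    train_stats = agent.terminate_episode()
    wandb.log({"episode": episode, "return": episode_return, **(train_stats or {})})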

TODO

  • set up logging on examples
  • understand the PPO logs and fix hyperparams
  • run wandb hyperparam sweep

