Reinforcement learning environments for fine-tuning language models for reasoning tasks.
aigym
Self-supervised reinforcement learning environments for LLM fine-tuning
aigym is a library that provides a suite of novel reinforcement learning (RL) environments for fine-tuning pre-trained language models on various reasoning tasks.
Built on top of the Gymnasium API, the project aims to provide lightweight, extensible environments for fine-tuning language models with techniques like PPO and GRPO.
It is designed to complement training frameworks like trl, transformers, PyTorch, and PyTorch Lightning.
See the project roadmap here
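The quickstart below relies on the Gymnasium convention: `reset()` returns `(observation, info)` and `step(action)` returns `(observation, reward, terminated, truncated, info)`. The sketch that follows illustrates that contract with a toy, text-based link-navigation environment. It is hypothetical and uses only the standard library (it does not subclass `gymnasium.Env`); the class name, link graph, and reward scheme are assumptions, not aigym's implementation.

```python
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical toy link graph: page -> pages it links to.
TOY_LINKS: Dict[str, List[str]] = {
    "Mammal": ["Whale", "Dog"],
    "Whale": ["Ocean"],
    "Dog": ["Wolf"],
    "Ocean": [],
    "Wolf": [],
}


class ToyLinkMazeEnv:
    """Hypothetical environment that mimics the Gymnasium reset/step contract."""

    def __init__(self, start: str, target: str, max_steps: int = 10) -> None:
        self.start = start
        self.target = target
        self.max_steps = max_steps
        self.current = start
        self.n_steps = 0

    def _observation(self) -> str:
        # Text observation: where the agent is, where it should go, and what it can click.
        links = ", ".join(TOY_LINKS[self.current]) or "(no links)"
        return f"Current page: {self.current}. Target: {self.target}. Links: {links}"

    def reset(self, seed: Optional[int] = None) -> Tuple[str, Dict[str, Any]]:
        self.current = self.start
        self.n_steps = 0
        return self._observation(), {"current": self.current}

    def step(self, action: str) -> Tuple[str, float, bool, bool, Dict[str, Any]]:
        self.n_steps += 1
        # Only follow links present on the current page; invalid actions leave the agent in place.
        if action in TOY_LINKS[self.current]:
            self.current = action
        terminated = self.current == self.target  # reached the target page
        truncated = not terminated and self.n_steps >= self.max_steps  # ran out of steps
        reward = 1.0 if terminated else 0.0  # assumed sparse reward
        return self._observation(), reward, terminated, truncated, {"current": self.current}


env = ToyLinkMazeEnv(start="Mammal", target="Ocean")
observation, info = env.reset()
observation, reward, terminated, truncated, info = env.step("Whale")
observation, reward, terminated, truncated, info = env.step("Ocean")
```

Separating `terminated` (goal reached) from `truncated` (step budget exhausted) is what lets a training loop distinguish success from timeouts.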
Installation
```bash
pip install aigym
```
Development Installation
Install uv:

```bash
pip install uv
```

Create a virtual environment:

```bash
uv venv --python 3.12
```

Activate the virtual environment:

```bash
source .venv/bin/activate
```

Install the package:

```bash
uv sync --extra ollama --group dev
```
Install ollama to run a local model: https://ollama.com/download
Quickstart
```python
from typing import Generator

import ollama

from aigym.agent import Agent
from aigym.env import WikipediaGymEnv


# define a policy function for the agent using Ollama
def policy(prompt: str) -> Generator[str, None, None]:
    for chunk in ollama.generate(
        model="gemma3:1b",
        prompt=prompt,
        stream=True,
    ):
        yield chunk.response


# initialize the agent with the policy function in streaming mode
agent = Agent(policy=policy, stream=True)

# initialize the wikipedia maze environment
env = WikipediaGymEnv(n_hops=2)

# create a travel path between two pages that are two hops away
observation, info = env.reset()

# allow the agent to take 10 steps to try to find the target page
for step in range(10):
    # generate an action
    action = agent.act(observation)
    if action.action is None:
        print(f"No valid action taken at step {step}")
        continue

    # take a step in the environment
    observation, reward, terminated, truncated, info = env.step(action)

    # break early if the episode is terminated or truncated
    if terminated or truncated:
        print(f"Episode terminated or truncated at step {step}")
        break
```
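For offline testing, or when Ollama is not installed, any function with the same signature as `policy` can stand in for the model: it takes a prompt string and yields text chunks. The stand-in below is hypothetical; its canned completion (including the `"url"` key inside `<answer>`) is an assumption and may not match the action format aigym's `Agent` actually expects.

```python
from typing import Generator


def fake_policy(prompt: str) -> Generator[str, None, None]:
    """Hypothetical offline policy that streams a fixed completion in small chunks."""
    # Assumed completion format; the JSON schema inside <answer> is illustrative only.
    completion = (
        "<think>Follow the first link.</think>"
        '<answer>{"url": "https://en.wikipedia.org/wiki/Mammal"}</answer>'
    )
    chunk_size = 16
    for i in range(0, len(completion), chunk_size):
        # Yield the completion piece by piece, as a streaming model would.
        yield completion[i : i + chunk_size]


# A consumer of a streaming policy reassembles the full text by concatenating chunks.
full_text = "".join(fake_policy("any prompt"))
```

It can be passed in place of the Ollama policy, e.g. `Agent(policy=fake_policy, stream=True)`, which makes the rollout loop deterministic.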
Usage
The `examples` directory contains examples of how to use the aigym environments.
Run an ollama-based agent on the Wikipedia maze environment:
Basic example: inference only
This example uses ollama to run a local model and performs rollouts of the
Wikipedia maze environment.
```bash
python examples/ollama_agent.py
```
Training example
This example uses the examples/agent_training.py script to train a small model
on the Wikipedia maze environment.
```bash
python examples/agent_training.py --model_id google/gemma-3-270m-it
```
[!NOTE] Because the model has low capacity, it may take a while before it generates any valid actions at all: a valid action must contain correctly formatted `<think>` and `<answer>` tags, where the `<answer>` tag contains valid JSON.
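The sketch below checks whether a completion has that shape and, if so, decodes the JSON inside `<answer>`. It is a hypothetical validator, not aigym's parser: the regular expression, the requirement that `<think>` precede `<answer>`, and returning `None` on any failure are all assumptions.

```python
import json
import re
from typing import Any, Optional

# Hypothetical pattern: a <think> block followed by an <answer> block.
_FORMAT = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)


def parse_completion(text: str) -> Optional[Any]:
    """Return the decoded JSON inside <answer>, or None if the format is invalid."""
    match = _FORMAT.search(text)
    if match is None:
        return None  # missing or misordered tags
    try:
        return json.loads(match.group(2))
    except json.JSONDecodeError:
        return None  # tags present, but the answer is not valid JSON
```

A validator like this makes it easy to count how many rollouts are well formed, which is a useful signal while a small model is still learning the output format.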
Training on Flyte
Flyte is an AI orchestration platform that provides an easy way to run workloads on the cloud, including data processing, model training, model inference, and agentic pipelines.
You can train an agent on a Flyte cluster using the examples/agent_training_flyte.py example:
Install flyte:
```bash
uv pip install '.[flyte]'
```
Then create a configuration:
```bash
flyte create config \
    --endpoint demo.hosted.unionai.cloud \
    --builder remote \
    --project aigym \
    --domain development
```
[!NOTE] Modify the `--endpoint` flag to point to your Flyte cluster.
This will create a config.yaml file in the current directory.
Basic example:
This is the easiest difficulty setting: the target page is one hop away from the start URL.
```bash
PYTHONPATH=. python examples/agent_training_flyte.py \
    --n_hops 1 \
    --model_id google/gemma-3-12b-it \
    --enable_gradient_checkpointing
```
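Here a hop is one link click between the start page and the target page. One way to pick a target exactly `n_hops` away is a breadth-first search (BFS) over outgoing links, keeping only pages whose shortest distance from the start equals `n_hops`. The sketch below does this over a hypothetical in-memory link graph; aigym's actual sampling procedure is not described here and may differ (for example, it may walk live Wikipedia pages).

```python
import random
from collections import deque
from typing import Dict, List, Optional


def pages_at_distance(links: Dict[str, List[str]], start: str, n_hops: int) -> List[str]:
    """Return pages whose shortest link distance from `start` is exactly n_hops (BFS)."""
    distance = {start: 0}
    queue = deque([start])
    while queue:
        page = queue.popleft()
        if distance[page] == n_hops:
            continue  # no need to explore beyond the requested depth
        for nxt in links.get(page, []):
            if nxt not in distance:  # first visit in BFS is the shortest distance
                distance[nxt] = distance[page] + 1
                queue.append(nxt)
    return sorted(p for p, d in distance.items() if d == n_hops)


def sample_target(
    links: Dict[str, List[str]], start: str, n_hops: int, rng: random.Random
) -> Optional[str]:
    """Hypothetical helper: pick a random page exactly n_hops away, or None if none exist."""
    candidates = pages_at_distance(links, start, n_hops)
    return rng.choice(candidates) if candidates else None
```

Using shortest distance, rather than the end point of a random walk, guarantees the target cannot be reached in fewer than `n_hops` clicks, so the difficulty setting means what it says.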
Increased difficulty setting: five hops away
```bash
PYTHONPATH=. python examples/agent_training_flyte.py \
    --model_id google/gemma-3-12b-it \
    --enable_gradient_checkpointing \
    --n_episodes 100 \
    --lora_r 64 \
    --n_hops 5 \
    --n_tries_per_hop 4 \
    --rollout_min_new_tokens 256 \
    --rollout_max_new_tokens 512 \
    --group_size 4 \
    --wandb_project aigym-agent-training \
    --attn_implementation eager
```
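GRPO, one of the techniques mentioned at the top of this page, scores each of several rollouts for the same prompt relative to its group: the advantage is the rollout's reward minus the group mean, divided by the group standard deviation. Assuming `--group_size` sets the number of rollouts per group (an inference from the flag name), the sketch below computes those advantages for one group. It is illustrative only, not the training script's code; it uses the population standard deviation, and implementations differ on that choice.

```python
import statistics
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """GRPO-style advantages: normalize each reward by its group's mean and std."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; some implementations use sample std
    # eps avoids division by zero when every rollout in the group gets the same reward.
    return [(r - mean) / (std + eps) for r in rewards]


# Example: a group of 4 rollouts where only the first reached the target page.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 0.0])
```

When every rollout in a group earns the same reward, all advantages are zero and the group contributes no learning signal, which is one reason a weak model that rarely produces valid actions can train slowly.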
Anchor the start URL to the "Mammal" page
```bash
PYTHONPATH=. python examples/agent_training_flyte.py \
    --model_id google/gemma-3-12b-it \
    --start_url_anchors '["https://en.wikipedia.org/wiki/Mammal"]' \
    --enable_gradient_checkpointing \
    --n_episodes 1000 \
    --lr 1e-3 \
    --max_grad_norm 4.0 \
    --lora_r 64 \
    --n_hops 2 \
    --n_tries_per_hop 2 \
    --static_env \
    --rollout_min_new_tokens 256 \
    --rollout_max_new_tokens 512 \
    --group_size 4 \
    --wandb_project aigym-agent-training \
    --attn_implementation eager
```
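`--lora_r` presumably sets the rank r of the LoRA adapters. LoRA freezes a weight matrix W of shape d_out × d_in and learns a low-rank update BA, with A of shape r × d_in and B of shape d_out × r, so each adapted matrix adds r·(d_in + d_out) trainable parameters. The arithmetic below uses a hypothetical 4096 × 4096 projection, not Gemma 3's actual layer shapes.

```python
def lora_params_per_matrix(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight: A (r x d_in) + B (d_out x r)."""
    return r * d_in + d_out * r


def full_params(d_in: int, d_out: int) -> int:
    """Parameters in the frozen weight matrix itself."""
    return d_in * d_out


# Hypothetical square projection of size 4096 x 4096 adapted with rank 64.
lora = lora_params_per_matrix(4096, 4096, r=64)
full = full_params(4096, 4096)
fraction = lora / full  # share of the matrix's size that is actually trained
```

Adapter size grows linearly in r, so raising the rank trades extra memory and optimizer state for more adapter capacity.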
Sweep over different numbers of hops
```bash
PYTHONPATH=. python examples/agent_training_flyte_sweep.py \
    --model_id google/gemma-3-12b-it \
    --enable_gradient_checkpointing \
    --n_episodes 100 \
    --n_hops_list "[1, 2, 3, 4, 5]" \
    --n_tries_per_hop 1 \
    --rollout_min_new_tokens 256 \
    --rollout_max_new_tokens 1024 \
    --group_size 4 \
    --wandb_project aigym-agent-training \
    --attn_implementation eager
```
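Flags such as `--n_hops_list` and `--start_url_anchors` take list values written as JSON strings. One way a script could parse them is an `argparse` type function built on `json.loads`. The parser below is a hypothetical sketch that reuses two of the flag names above; it is not the example scripts' actual CLI code.

```python
import argparse
import json


def json_list(value: str) -> list:
    """argparse type that decodes a JSON array such as '[1, 2, 3]'."""
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as exc:
        raise argparse.ArgumentTypeError(f"not valid JSON: {value!r}") from exc
    if not isinstance(parsed, list):
        raise argparse.ArgumentTypeError(f"expected a JSON list, got {value!r}")
    return parsed


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring two list-valued flags from the commands above."""
    parser = argparse.ArgumentParser(description="Hypothetical sweep-argument parser")
    parser.add_argument("--n_hops_list", type=json_list, default=[1])
    parser.add_argument("--start_url_anchors", type=json_list, default=None)
    return parser


args = build_parser().parse_args(["--n_hops_list", "[1, 2, 3, 4, 5]"])
```

Quoting the whole value in the shell (single quotes around `'["https://..."]'`) keeps the inner double quotes intact, so the script receives valid JSON.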
File details

Details for the file aigym-0.0.3.tar.gz.

File metadata
- Download URL: aigym-0.0.3.tar.gz
- Size: 6.1 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.0

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 96bb108c28fb8a709b934f741760f7a7338dc15849c483814b0eff32b0fc855a |
| MD5 | 0fb737ab52c90b8f248d081cfdba13c8 |
| BLAKE2b-256 | 409a4bb8942f25c947289f08be219d6037c89323a873d23b0592859800d83d30 |

File details

Details for the file aigym-0.0.3-py3-none-any.whl.

File metadata
- Download URL: aigym-0.0.3-py3-none-any.whl
- Size: 21.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.0

File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | ac9c5fef1066dd525560c3e43e37ed22eaa5920770cfe3d6f3b04c034860c2ec |
| MD5 | 7a473b1bdee9840edb59643859211976 |
| BLAKE2b-256 | 94889692dd8115fbac042c2ec98d28adba6f0a14e65d7066446d69b868ca16ad |
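To check a downloaded file against the SHA256 digest listed above, hash it locally and compare the two hex strings. A minimal sketch using Python's `hashlib`; the file path is whatever you downloaded, and the function names are illustrative.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Compute a file's SHA256 hex digest, reading in chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_hex: str) -> bool:
    """True if the file's SHA256 digest matches the expected hex digest."""
    return sha256_of(path) == expected_hex.strip().lower()


# SHA256 digest for aigym-0.0.3-py3-none-any.whl, as listed above.
WHEEL_SHA256 = "ac9c5fef1066dd525560c3e43e37ed22eaa5920770cfe3d6f3b04c034860c2ec"
```

For example, `verify(Path("aigym-0.0.3-py3-none-any.whl"), WHEEL_SHA256)` should return `True` for an intact download. pip can also enforce digests automatically in hash-checking mode (`--require-hashes`).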