Group Relative Policy Optimization for Efficient RL Training
Project description
🚀 OptimRL: Group Relative Policy Optimization
OptimRL is a high-performance reinforcement learning library that introduces a groundbreaking algorithm, Group Relative Policy Optimization (GRPO). Designed to streamline the training of RL agents, GRPO eliminates the need for a critic network while ensuring robust performance with group-based advantage estimation and KL regularization. Whether you're building an AI to play games, optimize logistics, or manage resources, OptimRL provides state-of-the-art efficiency and stability.
🌟 Features
Why Choose OptimRL?
- 🚫 Critic-Free Learning
  Traditional RL methods require training both an actor and a critic network. GRPO eliminates this dual-network requirement, roughly halving model complexity while retaining top-tier performance.
- 👥 Group-Based Advantage Estimation
  GRPO introduces a novel way to normalize rewards within groups of experiences (see the sketch after this list). This ensures:
  - Stable training across diverse reward scales.
  - Adaptive behavior for varying tasks and environments.
- 📏 KL Regularization
  GRPO's built-in KL divergence regularization prevents policy collapse, ensuring:
  - Smooth policy updates.
  - Reliable and stable learning in any domain.
- ⚡ Vectorized NumPy Operations with PyTorch Tensor Integration
  OptimRL leverages NumPy's vectorized operations and PyTorch's GPU-accelerated tensor computations for maximum performance. This hybrid implementation provides:
  - 10-100x speedups over pure Python through optimized array programming.
  - Seamless CPU/GPU execution via the PyTorch backend.
  - Native integration with deep learning workflows.
  - Full automatic differentiation support.
- 🔄 Experience Replay Buffer
  Improve sample efficiency with built-in experience replay:
  - Learn from past experiences multiple times.
  - Reduce correlation between consecutive samples.
  - Configurable buffer capacity and batch sizes.
- 🔄 Continuous Action Space Support
  Train agents in environments with continuous control:
  - Gaussian policy implementation for continuous actions.
  - Configurable action bounds.
  - Adaptive standard deviation for exploration.
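To make group-based advantage estimation and KL regularization concrete, here is a minimal, self-contained sketch of a GRPO-style loss: rewards are normalized within each group to form advantages, a clipped surrogate objective is controlled by epsilon, and a KL penalty weighted by beta keeps the updated policy close to the old one. This is an illustrative reading of the algorithm; the grpo_loss helper below is hypothetical and is not OptimRL's internal implementation.

import torch

def grpo_loss(log_probs, old_log_probs, group_rewards, epsilon=0.2, beta=0.01):
    """Illustrative GRPO-style loss (hypothetical; not OptimRL's internal code).

    log_probs, old_log_probs: log pi(a|s) under the new and old policies, shape (G,)
    group_rewards: rewards for one group of experiences, shape (G,)
    """
    # Group-based advantage: normalize rewards within the group
    advantages = (group_rewards - group_rewards.mean()) / (group_rewards.std() + 1e-8)

    # Clipped surrogate objective (ratio clipping controlled by epsilon)
    ratio = torch.exp(log_probs - old_log_probs.detach())
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)
    surrogate = torch.min(ratio * advantages, clipped * advantages)

    # KL regularization: penalize drift from the old policy, weighted by beta
    kl = (old_log_probs.detach() - log_probs).mean()

    return -surrogate.mean() + beta * kl

The epsilon and beta arguments correspond to the grpo_params={"epsilon": 0.2, "beta": 0.01} settings used in the Quick Start examples below.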
🛠️ Installation
For End Users
Simply install from PyPI:
pip install optimrl
For Developers
Clone the repository and set up a development environment:
git clone https://github.com/subaashnair/optimrl.git
cd optimrl
pip install -e '.[dev]'
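To confirm either installation, check that the package imports cleanly:

python -c "import optimrl"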
⚡ Quick Start
Discrete Action Space Example (CartPole)
import torch
import torch.nn as nn
import torch.optim as optim
import gym
from optimrl import create_agent

# Define a simple policy network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x):
        return self.network(x)

# Create environment and network
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy = PolicyNetwork(state_dim, action_dim)

# Create GRPO agent
agent = create_agent(
    "grpo",
    policy_network=policy,
    optimizer_class=optim.Adam,
    learning_rate=0.001,
    gamma=0.99,
    grpo_params={"epsilon": 0.2, "beta": 0.01},
    buffer_capacity=10000,
    batch_size=32
)

# Training loop
state, _ = env.reset()
for step in range(1000):
    action = agent.act(state)
    next_state, reward, done, truncated, _ = env.step(action)
    agent.store_experience(reward, done)
    if done or truncated:
        state, _ = env.reset()
        agent.update()  # Update policy after episode ends
    else:
        state = next_state
Complete CartPole Implementation
For a complete implementation of CartPole with OptimRL, check out our examples in the simple_test directory:
- cartpole_simple.py: Basic implementation with GRPO
- cartpole_improved.py: Improved implementation with tuned parameters
- cartpole_final.py: Final implementation with optimized performance
- cartpole_tuned.py: Enhanced implementation with advanced features
- cartpole_simple_pg.py: Vanilla Policy Gradient implementation for comparison
The vanilla policy gradient implementation (cartpole_simple_pg.py) achieves excellent performance on CartPole-v1, reaching the maximum reward of 500 consistently. It serves as a useful baseline for comparing against the GRPO implementations.
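For reference, here is a minimal REINFORCE-style policy gradient for CartPole of the kind that baseline represents. It is an illustrative, self-contained sketch, not the actual code from cartpole_simple_pg.py.

import torch
import torch.nn as nn
import torch.optim as optim
import gym

# Minimal REINFORCE-style baseline (illustrative; not the repository's
# cartpole_simple_pg.py implementation).
env = gym.make('CartPole-v1')
policy = nn.Sequential(
    nn.Linear(env.observation_space.shape[0], 64),
    nn.ReLU(),
    nn.Linear(64, env.action_space.n),
    nn.LogSoftmax(dim=-1)
)
optimizer = optim.Adam(policy.parameters(), lr=1e-2)

for episode in range(500):
    state, _ = env.reset()
    log_probs, rewards, done, truncated = [], [], False, False
    while not (done or truncated):
        log_pi = policy(torch.as_tensor(state, dtype=torch.float32))
        dist = torch.distributions.Categorical(logits=log_pi)
        action = dist.sample()
        log_probs.append(dist.log_prob(action))
        state, reward, done, truncated, _ = env.step(action.item())
        rewards.append(reward)
    # Discounted returns, normalized for variance reduction
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + 0.99 * g
        returns.insert(0, g)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    loss = -(torch.stack(log_probs) * returns).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()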
Continuous Action Space Example (Pendulum)
import torch
import torch.nn as nn
import torch.optim as optim
import gym
from optimrl import create_agent

# Define a continuous policy network
class ContinuousPolicyNetwork(nn.Module):
    def __init__(self, input_dim, action_dim):
        super().__init__()
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        # Output both mean and log_std for each action dimension
        self.output_layer = nn.Linear(64, action_dim * 2)

    def forward(self, x):
        x = self.shared_layers(x)
        return self.output_layer(x)

# Create environment and network
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bounds = (env.action_space.low[0], env.action_space.high[0])
policy = ContinuousPolicyNetwork(state_dim, action_dim)

# Create Continuous GRPO agent
agent = create_agent(
    "continuous_grpo",
    policy_network=policy,
    optimizer_class=optim.Adam,
    action_dim=action_dim,
    learning_rate=0.0005,
    gamma=0.99,
    grpo_params={"epsilon": 0.2, "beta": 0.01},
    buffer_capacity=10000,
    batch_size=64,
    min_std=0.01,
    action_bounds=action_bounds
)

# Training loop
state, _ = env.reset()
for step in range(1000):
    action = agent.act(state)
    next_state, reward, done, truncated, _ = env.step(action)
    agent.store_experience(reward, done)
    if done or truncated:
        state, _ = env.reset()
        agent.update()  # Update policy after episode ends
    else:
        state = next_state
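The continuous policy network above emits 2 * action_dim values per state. How the agent consumes them is handled inside OptimRL, but a common convention is to split the output into a mean and a log standard deviation, sample from the resulting Gaussian, and squash the sample into the action bounds. The sample_action helper below is a hypothetical sketch of that convention, not part of the OptimRL API.

import torch

def sample_action(policy_output, action_low, action_high, min_std=0.01):
    """Hypothetical Gaussian sampling for a mean/log_std policy head.
    Shown for intuition only; OptimRL's internal handling may differ."""
    mean, log_std = policy_output.chunk(2, dim=-1)  # split the 2 * action_dim outputs
    std = log_std.exp().clamp(min=min_std)          # keep exploration noise above a floor
    raw_action = torch.distributions.Normal(mean, std).rsample()
    squashed = torch.tanh(raw_action)               # map the sample to (-1, 1)
    # Rescale from (-1, 1) into the environment's action bounds
    return action_low + 0.5 * (squashed + 1.0) * (action_high - action_low)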
📊 Performance Comparison
Our simple policy gradient implementation consistently solves the CartPole-v1 environment in under 1000 episodes, achieving the maximum reward of 500. The GRPO implementations offer competitive performance with additional benefits:
- Lower variance: More stable learning across different random seeds
- Improved sample efficiency: Learns from fewer interactions with the environment
- Better regularization: Prevents policy collapse during training
Kaggle Notebook
The "OptimRL Trading Experiment" notebook can be viewed on Kaggle. Alternatively, you can open the notebook locally as an .ipynb file:
Open the OptimRL Trading Experiment Notebook (.ipynb)
🤝 Contributing
We're excited to have you on board! Here's how you can help improve OptimRL:
- Fork the repo.
- Create a feature branch:
git checkout -b feature/AmazingFeature
- Commit your changes:
git commit -m 'Add some AmazingFeature'
- Push to the branch:
git push origin feature/AmazingFeature
- Open a Pull Request.
Before submitting, make sure you run all tests:
pytest tests/
📜 License
This project is licensed under the MIT License. See the LICENSE file for details.
📚 Citation
If you use OptimRL in your research, please cite:
@software{optimrl2024,
  title={OptimRL: Group Relative Policy Optimization},
  author={Subashan Nair},
  year={2024},
  url={https://github.com/subaashnair/optimrl}
}
Download files
File details
Details for the file optimrl-1.0.1.tar.gz.
File metadata
- Download URL: optimrl-1.0.1.tar.gz
- Upload date:
- Size: 37.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7e2911b446865d6429e33c7ba3408833b135ac57e1e088da8bac779ebd396d6f |
| MD5 | 67baad39f7fd646fe905cf99e6912377 |
| BLAKE2b-256 | 4a4ffbf4a15589f7855f826a2154e6acf06b9028a68192e18252e25b1845e05f |
File details
Details for the file optimrl-1.0.1-py3-none-any.whl.
File metadata
- Download URL: optimrl-1.0.1-py3-none-any.whl
- Upload date:
- Size: 13.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1539f8bcdb0a9760b741a572e17197bfcd3d7ecc22a35297724e6c061573e3b6 |
| MD5 | 4bab6009c5cea2289c9b2ad908104898 |
| BLAKE2b-256 | 10e2a554a15836ee88db89531f593c1725b3295c3ce8451fa5e63db56dea2a3e |