🎯 NeatRL
A clean, modern Python library for reinforcement learning algorithms
NeatRL provides high-quality implementations of popular RL algorithms with a focus on simplicity, performance, and ease of use. Built with PyTorch and designed for both research and production use.
✨ Features
- 📊 Experiment Tracking: Built-in support for Weights & Biases logging
- 🎮 Gymnasium Compatible: Works with Gymnasium environments, with more environment suites planned
- 🎯 Atari Support: Full support for Atari games with automatic CNN architectures
- ⚡ Parallel Training: Vectorized environments for faster data collection
- 🔧 Easy to Extend: Modular design for adding new algorithms
- 📈 State-of-the-Art: Implements modern RL techniques and best practices
- 🎥 Video Recording: Automatic video capture and WandB integration
- 📉 Advanced Logging: Per-layer gradient monitoring and comprehensive metrics
🏗️ Supported Algorithms
Current Implementations
- DQN (Deep Q-Network) - Classic value-based RL algorithm
  - Support for discrete action spaces
  - Experience replay and target networks
  - Atari preprocessing and frame stacking
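The replay-and-target-network pattern behind DQN can be sketched in a few lines (an illustrative sketch, not NeatRL's actual implementation; the class and function names here are hypothetical):

```python
import random
from collections import deque

import torch
import torch.nn as nn

class ReplayBuffer:
    """Minimal experience replay: store transitions, sample random
    minibatches to break temporal correlation in the training data."""
    def __init__(self, capacity=10_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(list(self.buffer), batch_size)

def sync_target(online: nn.Module, target: nn.Module):
    """Target network update: copy the online network's weights into a
    frozen copy that is used to compute stable TD targets."""
    target.load_state_dict(online.state_dict())
```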
- Dueling DQN - Enhanced DQN with separate value and advantage streams
  - Improved learning stability
  - Better performance on complex environments
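The value/advantage split can be sketched as follows (a generic illustration of the dueling decomposition, not NeatRL's internal architecture):

```python
import torch
import torch.nn as nn

class DuelingHead(nn.Module):
    """Split features into a scalar state value V(s) and per-action
    advantages A(s, a), then recombine as Q = V + A - mean(A)."""
    def __init__(self, feature_dim, n_actions):
        super().__init__()
        self.value = nn.Linear(feature_dim, 1)
        self.advantage = nn.Linear(feature_dim, n_actions)

    def forward(self, features):
        v = self.value(features)      # shape (batch, 1)
        a = self.advantage(features)  # shape (batch, n_actions)
        # Subtracting the mean advantage makes V and A identifiable.
        return v + a - a.mean(dim=1, keepdim=True)
```

Because the mean advantage is subtracted, the mean of Q over actions equals V(s), which is what stabilizes learning relative to a plain Q head.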
- REINFORCE - Policy gradient method for discrete and continuous action spaces
  - ✨ NEW: Atari game support with automatic CNN architecture
  - ✨ NEW: Parallel environment training (`n_envs` support)
  - ✨ NEW: Continuous action space support
  - ✨ NEW: Per-layer gradient logging
  - Episode-based Monte Carlo returns
  - Variance reduction through baseline subtraction
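The Monte Carlo returns and baseline subtraction mentioned above look like this in isolation (a minimal sketch, not NeatRL's implementation; the function name is hypothetical):

```python
import torch

def discounted_returns(rewards, gamma=0.99):
    """Compute Monte Carlo returns G_t = r_t + gamma * G_{t+1} by
    scanning the episode backwards, then subtract the mean as a
    simple baseline to reduce gradient variance without adding bias."""
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns = torch.tensor(list(reversed(returns)))
    return returns - returns.mean()
```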
- PPO-RND (Proximal Policy Optimization with Random Network Distillation) - State-of-the-art exploration method
  - ✨ NEW: Intrinsic motivation through novelty detection
  - ✨ NEW: Combined extrinsic and intrinsic rewards for better exploration
  - ✨ NEW: Support for both discrete and continuous action spaces
  - ✨ NEW: Automatic render mode handling for video recording
  - ✨ NEW: Comprehensive WandB logging with global step tracking
  - PPO with clipped surrogate objective
  - Vectorized environments for parallel training
  - Intrinsic reward normalization and advantage calculation
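The two core ideas in PPO-RND can each be sketched in a few lines (generic illustrations of the standard techniques, not NeatRL's internals; names are hypothetical):

```python
import torch
import torch.nn as nn

def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """PPO clipped surrogate: take the pessimistic minimum of the
    unclipped and clipped probability-ratio objectives."""
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()

class RNDBonus(nn.Module):
    """Random Network Distillation: a trained predictor chases a frozen
    random target network; high prediction error marks a novel state
    and becomes the intrinsic reward."""
    def __init__(self, obs_dim, embed_dim=16):
        super().__init__()
        self.target = nn.Linear(obs_dim, embed_dim)
        self.predictor = nn.Linear(obs_dim, embed_dim)
        for p in self.target.parameters():
            p.requires_grad_(False)  # target stays fixed forever

    def forward(self, obs):
        with torch.no_grad():
            t = self.target(obs)
        return ((self.predictor(obs) - t) ** 2).mean(dim=-1)
```

In training, the intrinsic bonus is typically normalized and added to the extrinsic reward before computing advantages, which matches the "combined extrinsic and intrinsic rewards" item above.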
- More algorithms coming soon...
📦 Installation
python -m venv neatrl-env
source neatrl-env/bin/activate  # On Windows use `neatrl-env\Scripts\activate`
pip install "neatrl[classic,box2d,atari]"
🚀 Quick Start
Train DQN on CartPole
from neatrl import train_dqn
model = train_dqn(
env_id="CartPole-v1",
total_timesteps=10000,
seed=42
)
Train REINFORCE on Atari
from neatrl import train_reinforce
model = train_reinforce(
env_id="BreakoutNoFrameskip-v4",
total_steps=2000,
atari_wrapper=True, # Automatic Atari preprocessing
n_envs=4, # Parallel environments
use_wandb=True, # Track with WandB
seed=42
)
Train REINFORCE with Continuous Actions
from neatrl import train_reinforce
import torch
import torch.nn as nn
# Custom policy for continuous actions
class ContinuousPolicyNet(nn.Module):
def __init__(self, state_space, action_space):
super().__init__()
self.fc1 = nn.Linear(state_space, 32)
self.fc2 = nn.Linear(32, 16)
self.mean = nn.Linear(16, action_space)
self.logstd = nn.Linear(16, action_space)
def forward(self, x):
x = torch.relu(self.fc2(torch.relu(self.fc1(x))))
return self.mean(x), torch.exp(self.logstd(x))
def get_action(self, x):
mean, std = self.forward(x)
dist = torch.distributions.Normal(mean, std)
action = dist.sample()
return action, dist.log_prob(action).sum(dim=-1)
model = train_reinforce(
env_id="Pendulum-v1",
total_steps=2000,
custom_agent=ContinuousPolicyNet(3, 1),
seed=42
)
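The log-probability returned by `get_action` is the quantity REINFORCE scales the return by. In isolation, that calculation looks like this (a generic PyTorch illustration, independent of NeatRL's training loop):

```python
import torch

# A Gaussian policy produces a mean and std; sampling the resulting
# Normal gives the action, and the action's log-density is the term
# the policy gradient weights by the episode return.
mean = torch.zeros(1, requires_grad=True)
std = torch.ones(1)
dist = torch.distributions.Normal(mean, std)
action = dist.sample()
log_prob = dist.log_prob(action).sum(dim=-1)
# Single-step REINFORCE loss: -log_prob * return (return = 1.0 here)
loss = -log_prob * 1.0
loss.backward()  # gradients flow into the mean parameter
```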
📚 Documentation
The docs include:
- Detailed usage examples
- Hyperparameter tuning guides
- Environment compatibility
- Experiment tracking setup
- Troubleshooting tips
🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
Development Setup
git clone https://github.com/YuvrajSingh-mist/NeatRL.git
cd NeatRL
pip install -e .[dev]
📋 Changelog
[0.2.1] - 2025-12-17
- Added: REINFORCE Atari support with automatic CNN architecture
- Added: Parallel environment training (`n_envs` parameter)
- Added: Continuous action space support for REINFORCE
- Added: Advanced gradient logging (per-layer norms, clip ratios)
- Changed: REINFORCE parameter `episodes` → `total_steps`
- Fixed: Multi-environment action handling for vectorized training
[0.2.0] - 2025-12-14
- Added: Grid environment support with automatic one-hot encoding
- Changed: Renamed `record` to `capture_video` for consistency
[0.1.4] - 2025-12-13
- Added: Custom agent support for DQN training
- Added: Network architecture display using torchinfo
- Improved: Error handling for custom agent constructors
- Changed: Agent parameter now accepts nn.Module subclasses
[0.1.3] - 2025-12-01
- Initial release with DQN implementation
- Weights & Biases integration
- Video recording capabilities
- Comprehensive documentation
For the complete changelog, see CHANGELOG.md.
📄 License
This project is licensed under the MIT License - see the LICENSE file for details.
Made with ❤️ for the RL community
File details
Details for the file neatrl-0.4.0.tar.gz.
File metadata
- Download URL: neatrl-0.4.0.tar.gz
- Upload date:
- Size: 23.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 77c153a4d043e5b13f5d48abf801b46d4449f5f60633499fa1661d9cf29df587 |
| MD5 | 661a44b776036a41881c0f9b405d4314 |
| BLAKE2b-256 | 4e7ed977a132d75722072c59b82dbf8dee4e871f4b49bc346436e319c8814e22 |
File details
Details for the file neatrl-0.4.0-py3-none-any.whl.
File metadata
- Download URL: neatrl-0.4.0-py3-none-any.whl
- Upload date:
- Size: 31.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.2
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9bf6eb340eb5350ae5b0e4acc5e87bda18348fab39c81e1b5af7cd2eaa6dcf88 |
| MD5 | d5bb1fd39afebd4e872fc626cc076e84 |
| BLAKE2b-256 | 885b87d78387b1144bd13d8a51ec2571f22b427d6c81d1c8a52716dc3a79ee44 |