🎯 NeatRL

A clean, modern Python library for reinforcement learning algorithms

NeatRL provides implementations of popular RL algorithms with a focus on simplicity, performance, and ease of use. It is built with PyTorch and designed for both research and production use.

✨ Features

  • 📊 Experiment Tracking: Built-in support for Weights & Biases logging
  • 🎮 Gymnasium Compatible: Works with Gymnasium environments, with support for more being added
  • 🎯 Atari Support: Full support for Atari games with automatic CNN architectures
  • Parallel Training: Vectorized environments for faster data collection
  • 🔧 Easy to Extend: Modular design for adding new algorithms
  • 📈 State-of-the-Art: Implements modern RL techniques and best practices
  • 🎥 Video Recording: Automatic video capture and WandB integration
  • 📉 Advanced Logging: Per-layer gradient monitoring and comprehensive metrics

🏗️ Supported Algorithms

Current Implementations

  • DQN (Deep Q-Network) - Classic value-based RL algorithm

    • Support for discrete action spaces
    • Experience replay and target networks
    • Atari preprocessing and frame stacking
  • Dueling DQN - Enhanced DQN with separate value and advantage streams

    • Improved learning stability
    • Better performance on complex environments
  • REINFORCE - Policy gradient method for discrete and continuous action spaces

    • NEW: Atari game support with automatic CNN architecture
    • NEW: Parallel environment training (n_envs support)
    • NEW: Continuous action space support
    • NEW: Per-layer gradient logging
    • Episode-based Monte Carlo returns
    • Variance reduction through baseline subtraction
  • PPO (Proximal Policy Optimization) - State-of-the-art policy gradient method with GAE

    • NEW: Full PPO implementation with Generalized Advantage Estimation (GAE)
    • NEW: Support for both discrete and continuous action spaces
    • NEW: Atari game support with automatic CNN architecture (train_ppo_cnn)
    • NEW: Clipped surrogate objective for stable policy updates
    • NEW: Value function clipping and entropy regularization
    • NEW: Vectorized environments for parallel training
    • NEW: Comprehensive WandB logging with advantage distributions
    • NEW: Per-layer gradient monitoring and video recording
    • Generalized Advantage Estimation with configurable lambda
    • Flexible network architecture with custom actor/critic classes
  • PPO-RND (Proximal Policy Optimization with Random Network Distillation) - State-of-the-art exploration method

    • NEW: Intrinsic motivation through novelty detection
    • NEW: Combined extrinsic and intrinsic rewards for better exploration
    • NEW: Support for both discrete and continuous action spaces
    • NEW: Automatic render mode handling for video recording
    • NEW: Comprehensive WandB logging with global step tracking
    • PPO with clipped surrogate objective
    • Vectorized environments for parallel training
    • Intrinsic reward normalization and advantage calculation
  • More algorithms coming soon...
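
The "separate value and advantage streams" listed for Dueling DQN are commonly recombined as Q(s, a) = V(s) + A(s, a) - mean(A(s, ·)). The snippet below is a minimal pure-Python sketch of that aggregation step only. It is not taken from NeatRL's source, and the function name dueling_q_values is hypothetical.

```python
from typing import List


def dueling_q_values(state_value: float, advantages: List[float]) -> List[float]:
    """Combine a scalar state value V(s) with per-action advantages A(s, a).

    Subtracting the mean advantage removes the ambiguity between V and A:
    without it, a constant could be shifted from one stream to the other.
    """
    mean_advantage = sum(advantages) / len(advantages)
    return [state_value + a - mean_advantage for a in advantages]


# Hypothetical outputs of the two streams for a 3-action environment.
q_values = dueling_q_values(state_value=1.0, advantages=[0.5, -0.5, 0.0])
greedy_action = max(range(len(q_values)), key=lambda i: q_values[i])
```

In a real network both streams would be heads on a shared feature extractor; the arithmetic above is the part that distinguishes the dueling architecture from a plain DQN head.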

📦 Installation

python -m venv neatrl-env
source neatrl-env/bin/activate 

pip install "neatrl[classic,box2d,atari]"

🚀 Quick Start

Train DQN on CartPole

from neatrl import train_dqn

model = train_dqn(
    env_id="CartPole-v1",
    total_timesteps=10000,
    seed=42
)

Train REINFORCE on Atari

from neatrl import train_reinforce

model = train_reinforce(
    env_id="BreakoutNoFrameskip-v4",
    total_steps=2000,
    atari_wrapper=True,  # Automatic Atari preprocessing
    n_envs=4,            # Parallel environments
    use_wandb=True,      # Track with WandB
    seed=42
)
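
REINFORCE is described above as using episode-based Monte Carlo returns with baseline subtraction. As a rough illustration of those two steps, here is a pure-Python sketch. The helper names and the choice of the episode mean as the baseline are assumptions made for this example, not NeatRL's confirmed behavior.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float = 0.99) -> List[float]:
    """Return G_t = r_t + gamma * G_{t+1} for every step of one finished episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each return reuses the one that follows it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def subtract_baseline(returns: List[float]) -> List[float]:
    """Center returns on their mean (assumed baseline; a learned one is also common)."""
    baseline = sum(returns) / len(returns)
    return [g - baseline for g in returns]


episode_rewards = [1.0, 1.0, 1.0]
returns = discounted_returns(episode_rewards, gamma=0.5)
centered = subtract_baseline(returns)
```

The centered returns then weight the log-probabilities of the actions taken; centering leaves the expected gradient unchanged while reducing its variance.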

Train REINFORCE with Continuous Actions

from neatrl import train_reinforce
import torch
import torch.nn as nn

# Custom policy for continuous actions
class ContinuousPolicyNet(nn.Module):
    def __init__(self, state_space, action_space):
        super().__init__()
        self.fc1 = nn.Linear(state_space, 32)
        self.fc2 = nn.Linear(32, 16)
        self.mean = nn.Linear(16, action_space)
        self.logstd = nn.Linear(16, action_space)
    
    def forward(self, x):
        x = torch.relu(self.fc2(torch.relu(self.fc1(x))))
        return self.mean(x), torch.exp(self.logstd(x))
    
    def get_action(self, x):
        mean, std = self.forward(x)
        dist = torch.distributions.Normal(mean, std)
        action = dist.sample()
        return action, dist.log_prob(action).sum(dim=-1)

model = train_reinforce(
    env_id="Pendulum-v1",
    total_steps=2000,
    custom_agent=ContinuousPolicyNet(3, 1),
    seed=42
)
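
In the policy above, get_action returns dist.log_prob(action).sum(dim=-1), which is the log-density of a diagonal Gaussian summed over action dimensions. To make that quantity concrete, the following standard-library sketch computes the same value by hand. The function name gaussian_log_prob is hypothetical and used only for illustration.

```python
import math
from typing import Sequence


def gaussian_log_prob(
    action: Sequence[float], mean: Sequence[float], std: Sequence[float]
) -> float:
    """Log-density of a diagonal Gaussian, summed over action dimensions."""
    total = 0.0
    for a, mu, sigma in zip(action, mean, std):
        # Per-dimension log N(a; mu, sigma^2).
        total += (
            -0.5 * ((a - mu) / sigma) ** 2
            - math.log(sigma)
            - 0.5 * math.log(2.0 * math.pi)
        )
    return total


# One-dimensional action (as in Pendulum-v1) evaluated at the mean.
log_prob_at_mean = gaussian_log_prob([0.0], [0.0], [1.0])
```

Summing over dimensions treats the action components as independent, which is why a single scalar log-probability per action can be fed into the policy-gradient loss.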

Train PPO on Classic Control

from neatrl import train_ppo

model = train_ppo(
    env_id="CartPole-v1",
    total_timesteps=50000,
    n_envs=4,           # Parallel environments
    GAE=0.95,           # Generalized Advantage Estimation lambda
    clip_value=0.2,     # PPO clipping parameter
    use_wandb=True,     # Track with WandB
    seed=42
)
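
The GAE argument in the call above is described as the lambda of Generalized Advantage Estimation. The sketch below shows, under stated assumptions, how that lambda blends one-step TD errors into advantages for a single environment. The function compute_gae and its parameter names are hypothetical and do not mirror NeatRL's internals.

```python
from typing import List


def compute_gae(
    rewards: List[float],
    values: List[float],
    last_value: float,
    dones: List[bool],
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> List[float]:
    """Generalized Advantage Estimation over one rollout of a single environment.

    values[t] is the critic estimate V(s_t); last_value is V(s_T), used to
    bootstrap past the final step. dones[t] is True if step t ended the episode.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; no bootstrapping across an episode boundary.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of future TD errors.
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages


advantages = compute_gae(
    rewards=[1.0, 1.0],
    values=[0.5, 0.5],
    last_value=0.5,
    dones=[False, False],
    gamma=1.0,
    gae_lambda=1.0,
)
```

With gae_lambda=0 the result reduces to plain one-step TD errors (low variance, more bias); with gae_lambda=1 it becomes the Monte Carlo return minus the value estimate (less bias, more variance). Values such as 0.95 sit between the two.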

Train PPO on Atari

from neatrl import train_ppo_cnn

model = train_ppo_cnn(
    env_id="BreakoutNoFrameskip-v4",
    total_timesteps=100000,
    n_envs=8,           # More parallel environments for Atari
    atari_wrapper=True, # Automatic Atari preprocessing
    use_wandb=True,     # Track with WandB
    seed=42
)
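
The PPO feature list mentions a clipped surrogate objective, and the classic-control example exposes clip_value=0.2 as the clipping parameter. A minimal standard-library sketch of that loss follows. It assumes per-sample log-probabilities and advantages are already available, and the function name clipped_surrogate_loss is hypothetical.

```python
import math
from typing import List


def clipped_surrogate_loss(
    new_log_probs: List[float],
    old_log_probs: List[float],
    advantages: List[float],
    clip_value: float = 0.2,
) -> float:
    """Mean PPO clipped policy loss (negated objective, so lower is better)."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        # Probability ratio between the updated and the data-collecting policy.
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = max(1.0 - clip_value, min(ratio, 1.0 + clip_value))
        # Taking the minimum gives a pessimistic bound on the improvement.
        total += min(ratio * adv, clipped_ratio * adv)
    return -total / len(advantages)


# Ratio of 2 with a positive advantage: the gain is capped at 1 + clip_value.
capped_loss = clipped_surrogate_loss([math.log(2.0)], [0.0], [1.0], clip_value=0.2)
```

The clip removes the incentive to move the policy far from the one that collected the data in a single update, which is the stability property the feature list refers to.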

📚 Documentation

📖 Complete Documentation

The docs include:

  • Detailed usage examples
  • Hyperparameter tuning guides
  • Environment compatibility
  • Experiment tracking setup
  • Troubleshooting tips

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Development Setup

git clone https://github.com/YuvrajSingh-mist/NeatRL.git
cd NeatRL
pip install -e ".[dev]"

📋 Changelog

[0.2.1] - 2025-12-17

  • Added: REINFORCE Atari support with automatic CNN architecture
  • Added: Parallel environment training (n_envs parameter)
  • Added: Continuous action space support for REINFORCE
  • Added: Advanced gradient logging (per-layer norms, clip ratios)
  • Changed: REINFORCE parameter episodes renamed to total_steps
  • Fixed: Multi-environment action handling for vectorized training

[0.2.0] - 2025-12-14

  • Added: Grid environment support with automatic one-hot encoding
  • Changed: Renamed record to capture_video for consistency

[0.1.4] - 2025-12-13

  • Added: Custom agent support for DQN training
  • Added: Network architecture display using torchinfo
  • Improved: Error handling for custom agent constructors
  • Changed: Agent parameter now accepts nn.Module subclasses

[0.1.3] - 2025-12-01

  • Initial release with DQN implementation
  • Weights & Biases integration
  • Video recording capabilities
  • Comprehensive documentation

For the complete changelog, see CHANGELOG.md.

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.


Made with ❤️ for the RL community
