nanoPPO

A flexible and efficient implementation of the Proximal Policy Optimization (PPO) algorithm for reinforcement learning.

nanoPPO is a Python package that provides a simple and efficient implementation of the Proximal Policy Optimization (PPO) algorithm for reinforcement learning. It is designed to support both continuous and discrete action spaces, making it suitable for a wide range of applications.
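
For context, PPO maximizes the standard clipped surrogate objective (shown here for reference; the eps_clip argument in the training API below corresponds to ε):

    L_CLIP(θ) = E_t[ min( r_t(θ) · Â_t, clip(r_t(θ), 1 − ε, 1 + ε) · Â_t ) ],  where r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t)

and Â_t is the advantage estimate at timestep t.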

Installation

You can install nanoPPO directly from PyPI using pip:

pip install nanoPPO

Alternatively, you can clone the repository and install from source:

git clone https://github.com/jamesliu/nanoPPO.git
cd nanoPPO
pip install .

Usage

Here are examples of how to use nanoPPO to train an agent.

On the MountainCarContinuous-v0 environment:

from nanoppo.train_ppo_agent import train_agent
import pickle

# Train a PPO agent on MountainCarContinuous-v0 and checkpoint the best weights.
ppo, model_file, metrics_file = train_agent(
    env_name="MountainCarContinuous-v0",
    max_episodes=50,
    policy_lr=0.0005,
    value_lr=0.0005,
    vl_coef=0.5,  # value-loss coefficient
    checkpoint_dir="checkpoints",
    checkpoint_interval=10,
    log_interval=10,
    wandb_log=False,
)

# Reload the best checkpoint and the training metrics saved alongside it.
ppo.load(model_file)
print("Loaded best weights from", model_file)
with open(metrics_file, "rb") as f:
    metrics = pickle.load(f)
print("Loaded metrics from", metrics_file)
best_reward = metrics["best_reward"]
episode = metrics["episode"]
print("best_reward", best_reward, "episode", episode)

Use Custom LR Scheduler and Custom Policy

  • Set a cosine annealing learning-rate scheduler
  • Use the CausalAttention policy instead of the linear policy

from nanoppo.train_ppo_agent import train_agent
from nanoppo.cosine_lr_scheduler import CosineLRScheduler
from nanoppo.policy.actor_critic_causal_attention import ActorCriticCausalAttention

# Illustrative hyperparameters (placeholders, not tuned recommendations).
config = {
    "cosine_lr": 0.0005,
    "cosine_warmup_iters": 100,
    "cosine_decay_steps": 10000,
    "cosine_min_lr": 0.00005,
    "max_episode": 50,
    "stop_reward": 90,
    "policy_lr": 0.0005,
    "value_lr": 0.0005,
    "vl_coef": 0.5,
    "betas": (0.9, 0.999),
    "n_latent_var": 64,
    "gamma": 0.99,
    "K_epochs": 4,
    "eps_clip": 0.2,
    "el_coef": 0.01,
}

lr_scheduler = CosineLRScheduler(
    learning_rate=config["cosine_lr"],
    warmup_iters=config["cosine_warmup_iters"],
    lr_decay_iters=config["cosine_decay_steps"],
    min_lr=config["cosine_min_lr"],
)

policy_class = ActorCriticCausalAttention

ppo, model_file, metrics_file = train_agent(
    env_name="MountainCarContinuous-v0",  # substitute your own environment
    env_config=None,  # optional environment kwargs, if your env takes any
    max_episodes=config["max_episode"],
    stop_reward=config["stop_reward"],
    policy_class=policy_class,
    lr_scheduler=lr_scheduler,
    policy_lr=config["policy_lr"],
    value_lr=config["value_lr"],
    vl_coef=config["vl_coef"],
    betas=config["betas"],
    n_latent_var=config["n_latent_var"],
    gamma=config["gamma"],
    K_epochs=config["K_epochs"],
    eps_clip=config["eps_clip"],
    el_coef=config["el_coef"],
    checkpoint_dir="checkpoints",
    checkpoint_interval=10,
    log_interval=10,
    wandb_log=False,
    debug=True,
)
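
The CosineLRScheduler arguments follow the usual warmup-then-cosine-decay pattern. As an illustration of the schedule those arguments describe (a sketch of the math, not the package's actual implementation):

import math

def cosine_lr(step, learning_rate, warmup_iters, lr_decay_iters, min_lr):
    # Linear warmup from 0 up to learning_rate over the first warmup_iters steps.
    if step < warmup_iters:
        return learning_rate * step / warmup_iters
    # Past the decay horizon, hold at min_lr.
    if step > lr_decay_iters:
        return min_lr
    # In between, decay from learning_rate to min_lr along a half cosine.
    ratio = (step - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return min_lr + coeff * (learning_rate - min_lr)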

On the CartPole-v1 environment:

from nanoppo.discrete_action_ppo import PPO
import gym

# CartPole-v1: 4-dimensional observations, 2 discrete actions.
env = gym.make('CartPole-v1')
ppo = PPO(env.observation_space.shape[0], env.action_space.n)

# Training code here...
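
A minimal training loop could then look like the sketch below. The select_action, store_reward, and update method names on ppo are hypothetical placeholders (check nanoppo.discrete_action_ppo for the actual interface); the env calls use the classic gym API, where reset() returns an observation and step() returns a 4-tuple.

for episode in range(200):
    state = env.reset()
    done = False
    episode_reward = 0.0
    while not done:
        # Placeholder: sample an action from the current policy.
        action = ppo.select_action(state)
        state, reward, done, info = env.step(action)
        episode_reward += reward
        # Placeholder: record the reward/terminal flag for the PPO update.
        ppo.store_reward(reward, done)
    # Placeholder: run the clipped-surrogate PPO update on collected rollouts.
    ppo.update()
    print(f"episode {episode}: reward {episode_reward}")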

Examples

See the examples directory for more comprehensive usage examples.

examples/train_mountaincar.sh

python nanoppo/train_ppo_agent.py --env_name=MountainCarContinuous-v0 --policy_lr=0.0005 --value_lr=0.0005 --max_episodes=50 --vl_coef=0.5 --wandb_log

(Plot: MountainCarContinuous-v0 training run)

examples/train_pointmass1d.sh

python nanoppo/train_ppo_agent.py --env_name=PointMass1D-v0 --policy_lr=0.0005 --value_lr=0.0005 --max_episodes=50 --vl_coef=0.5 --wandb_log

examples/train_pointmass2d.sh

python nanoppo/train_ppo_agent.py --env_name=PointMass2D-v0 --policy_lr=0.0005 --value_lr=0.0005 --max_episodes=50 --vl_coef=0.5 --wandb_log

Documentation

Full documentation is available here.

Contributing

We welcome contributions to nanoPPO! If you're interested in contributing, please see our contribution guidelines.

License

nanoPPO is licensed under the Apache License 2.0. See the LICENSE file for more details.

Support

For support, questions, or feature requests, please open an issue on our GitHub repository or contact the maintainers.

Changelog

See the releases page for a detailed changelog of each version.

