PyTorch Reinforcement Learning Framework for Researchers

Project description


Cherry is a reinforcement learning framework for researchers built on top of PyTorch.

Unlike other reinforcement learning implementations, cherry doesn't implement a single monolithic interface to existing algorithms. Instead, it provides you with low-level, common tools to write your own algorithms. Drawing from the UNIX philosophy, each tool strives to be as independent from the rest of the framework as possible. So if you don't like a specific tool, you don’t need to use it.

Features

Cherry extends PyTorch with only a handful of new core concepts.

  • PyTorch Modules for reinforcement learning:
    • cherry.nn.Policy: maps states to distributions over actions.
    • cherry.nn.ActionValue: maps state-action pairs to values (Q-functions).
    • cherry.nn.StateValue: maps states to values (V-functions).
  • Data structures for reinforcement learning (sketched briefly after this list):
    • cherry.Transition: a single (state, action, reward, next_state, done) interaction.
    • cherry.ExperienceReplay: a list-like buffer of transitions with batched accessors.
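
For a sense of how these pieces fit together, here is a minimal sketch of the replay buffer in isolation. It is not the full API: it sticks to calls that also appear in the example further below, plus the assumption that indexing a replay yields a cherry.Transition.

import torch as th
import cherry as ch

replay = ch.ExperienceReplay()

# Toy, hand-built transition; CartPole-like shapes chosen purely for illustration.
state = th.randn(1, 4)
action = th.tensor([[0]])
next_state = th.randn(1, 4)

# Append one interaction; extra keyword arguments (e.g. log_prob) are stored with it.
# Plain Python scalars for reward/done are assumed to be converted to tensors by cherry.
replay.append(state, action, 1.0, next_state, False)

transition = replay[0]          # assumed to yield a cherry.Transition
print(transition.reward)        # fields: state, action, reward, next_state, done
print(replay.reward().size())   # batched accessor, as used in the example below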

Cherry also includes additional features to help implement existing and new RL algorithms.

  • Pythonic and low-level interface à la PyTorch.
  • Support for tabular (!) and function approximation algorithms.
  • Various OpenAI Gym environment wrappers.
  • Helper functions for popular algorithms (e.g. A2C, DDPG, TRPO, PPO, SAC); a toy sketch follows this list.
  • Logging, visualization, and debugging tools.
  • Painless and efficient distributed training on CPUs and GPUs.
  • Unit, integration, and regression tested, continuously integrated.
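
As a toy illustration of the helper functions, the sketch below discounts and normalizes a hand-built reward tensor with ch.td.discount and ch.normalize, the same two calls used in the example of the next section; the tensor shapes are illustrative assumptions.

import torch as th
import cherry as ch

rewards = th.ones(5, 1)   # five steps, each with reward 1.0
dones = th.zeros(5, 1)    # no episode boundary...
dones[-1] = 1.0           # ...except at the very last step

returns = ch.td.discount(0.99, rewards, dones)  # per-step discounted returns
returns = ch.normalize(returns)                 # rescale to zero mean and unit variance
print(returns)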

To learn more about the tools and philosophy behind cherry, check out our Getting Started tutorial.

Example

The following snippet showcases some of the tools offered by cherry by training a simple policy-gradient (REINFORCE-style) agent on CartPole.

import gym
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import cherry as ch


class PolicyNet(nn.Module):
    # Minimal example policy network for CartPole: 4 observations -> 2 action probabilities.
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 128)
        self.out = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        return F.softmax(self.out(x), dim=-1)


# Wrap environments
env = gym.make('CartPole-v0')
env = ch.envs.Logger(env, interval=1000)  # log training statistics at a fixed interval
env = ch.envs.Torch(env)                  # convert between NumPy arrays and PyTorch tensors

policy = PolicyNet()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
replay = ch.ExperienceReplay()  # Manage transitions

for step in range(1000):
    state = env.reset()
    while True:
        mass = Categorical(policy(state))  # action distribution for the current state
        action = mass.sample()
        log_prob = mass.log_prob(action)
        next_state, reward, done, _ = env.step(action)

        # Build the ExperienceReplay
        replay.append(state, action, reward, next_state, done, log_prob=log_prob)
        if done:
            break
        else:
            state = next_state

    # Discounting and normalizing rewards
    rewards = ch.td.discount(0.99, replay.reward(), replay.done())
    rewards = ch.normalize(rewards)

    # REINFORCE-style policy gradient loss
    loss = -th.sum(replay.log_prob() * rewards)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    replay.empty()  # clear stored transitions before the next episode

Many more high-quality examples are available in the examples/ folder.

Installation

Note: Cherry is in early alpha release. Stuff might break.

pip install cherry-rl

Changelog

A human-readable changelog is available in the CHANGELOG.md file.

Documentation

Documentation and tutorials are available on cherry’s website: http://cherry-rl.net.

Contributing

First, thanks for considering a contribution to cherry. Here are a couple of guidelines we strive to follow.

  • It's always a good idea to open an issue first, where we can discuss how to best proceed.
  • If you want to contribute a new example using cherry, it should preferably fit in a single file.
  • If you would like to contribute a new feature to the core library, we suggest first implementing an example that showcases the new functionality. Doing so is quite useful:
    • it allows for automatic testing,
    • it ensures that the functionality is correctly implemented,
    • it shows users how to use your functionality, and
    • it gives a concrete example when discussing the best way to merge your implementation.

We don't have a forum, but we're happy to discuss with you on Slack. Send an email to smr.arnold@gmail.com to get an invite.

Acknowledgements

Cherry draws inspiration from many other reinforcement learning implementations.

Why 'cherry'?

Because it's the sweetest part of the cake.
