
PyTorch Reinforcement Learning Framework for Researchers

Project description

Cherry is a reinforcement learning framework for researchers built on top of PyTorch.

Unlike other reinforcement learning implementations, cherry doesn't implement a single monolithic interface to existing algorithms. Instead, it provides you with low-level, common tools to write your own algorithms. Drawing from the UNIX philosophy, each tool strives to be as independent from the rest of the framework as possible. So if you don't like a specific tool, you don’t need to use it.
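
As an illustration of that independence, a single utility such as cherry.td.discount (used again in the A2C example below) can be dropped into an otherwise plain PyTorch script. A minimal sketch, where the tensor shapes and values are made up for the example:

import torch
import cherry

rewards = torch.ones(10, 1)   # per-step rewards from any rollout
dones = torch.zeros(10, 1)    # episode-termination flags
returns = cherry.td.discount(0.99, rewards, dones, bootstrap=torch.zeros(1))  # discounted returns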

Features

Cherry extends PyTorch with only a handful of new core concepts: policies (cherry.nn.Policy), experience replays (cherry.ExperienceReplay and cherry.Transition), and low-level algorithmic utilities (cherry.td, cherry.pg, and cherry.algorithms), all showcased below.

To learn more about the tools and philosophy behind cherry, check out our Getting Started tutorial.

Overview and Examples

The following snippets showcase a few of the tools offered by cherry. Many more high-quality examples are available in the examples/ folder.

Defining a cherry.nn.Policy

import torch
import cherry

class VisionPolicy(cherry.nn.Policy):  # inherits from torch.nn.Module

   def __init__(self, feature_extractor, actor):
      super(VisionPolicy, self).__init__()
      self.feature_extractor = feature_extractor
      self.actor = actor

   def forward(self, obs):
      mean = self.actor(self.feature_extractor(obs))
      std = 0.1 * torch.ones_like(mean)
      return cherry.distributions.TanhNormal(mean, std)  # policies always return a distribution

policy = VisionPolicy(MyResnetExtractor(), MyMLPActor())
action = policy.act(obs)  # sampled from policy's distribution
deterministic_action = policy.act(obs, deterministic=True)  # distribution's mode
action_distribution = policy(obs)  # work with the policy's distribution

Building a cherry.ExperienceReplay of cherry.Transition

# building the replay
replay = cherry.ExperienceReplay()
state = env.reset()
for t in range(1000):
   action = policy.act(state)
   next_state, reward, done, info = env.step(action)
   replay.append(state, action, reward, next_state, done)
   state = next_state

# manipulating the replay
replay = replay[-256:]  # indexes like a list
batch = replay.sample(32, contiguous=True)  # sample 32 contiguous transitions into a new replay
batch = batch.to('cuda')  # move replay to device
for transition in reversed(batch): # iterate over a replay
   transition.reward *= 0.99

# get all states, actions, and rewards as PyTorch tensors.
reinforce_loss = -torch.sum(policy(batch.state()).log_prob(batch.action()) * batch.reward())
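
The resulting loss is an ordinary PyTorch scalar, so the usual optimization step applies. A minimal sketch (the choice of Adam here is illustrative, not prescribed by cherry):

optimizer = torch.optim.Adam(policy.parameters())  # any torch optimizer works
optimizer.zero_grad()
reinforce_loss.backward()  # gradients flow through the policy's log-probabilities
optimizer.step()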

Designing algorithms with cherry.td, cherry.pg, and cherry.algorithms

# defining a new algorithm
import dataclasses

@dataclasses.dataclass
class MyA2C:
   discount: float = 0.99
   
   def update(self, replay, policy, state_value, optimizer):
      # discount rewards
      values = state_value(replay.state())
      discounted_rewards = cherry.td.discount(
         self.discount, replay.reward(), replay.done(), bootstrap=values[-1].detach()
      )

      # Compute losses
      policy_loss = cherry.algorithms.A2C.policy_loss(
         log_probs=policy(replay.state()).log_prob(replay.action()),
         advantages=discounted_rewards - values.detach(),
      )
      value_loss = cherry.algorithms.A2C.state_value_loss(values, discounted_rewards)

      # Optimization step
      optimizer.zero_grad()
      (policy_loss + value_loss).backward()
      optimizer.step()
      return {'a2c/policy_loss': policy_loss, 'a2c/value_loss': value_loss}

# using MyA2C
my_a2c = MyA2C(discount=0.95)
my_policy = MyPolicy()
linear_value = cherry.models.LinearValue(128)
adam = torch.optim.Adam(list(my_policy.parameters()) + list(linear_value.parameters()))  # optimize policy and value function jointly
for step in range(1000):
   replay = collect_experience(my_policy)
   my_a2c.update(replay, my_policy, linear_value, adam)
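
The collect_experience helper is left undefined above. One possible sketch, mirroring the replay-building loop from the earlier example (the env variable and the 1000-step rollout length are assumptions, not part of cherry):

def collect_experience(policy, n_steps=1000):
   replay = cherry.ExperienceReplay()
   state = env.reset()  # assumes a Gym-style env, as in the replay example above
   for t in range(n_steps):
      action = policy.act(state)
      next_state, reward, done, info = env.step(action)
      replay.append(state, action, reward, next_state, done)
      state = env.reset() if done else next_state
   return replay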

Install

pip install cherry-rl

Changelog

A human-readable changelog is available in the CHANGELOG.md file.

Documentation

Documentation and tutorials are available on cherry’s website: http://cherry-rl.net.

Contributing

Here are a few guidelines we strive to follow.

  • It's always a good idea to open an issue first, where we can discuss how to best proceed.
  • If you want to contribute a new example using cherry, it should preferably be contained in a single file.
  • If you would like to contribute a new feature to the core library, we suggest first implementing an example that showcases the new functionality. Doing so is quite useful:
    • it allows for automatic testing,
    • it ensures that the functionality is correctly implemented,
    • it shows users how to use your functionality, and
    • it gives a concrete example when discussing the best way to merge your implementation.

Acknowledgements

Cherry draws inspiration from many existing reinforcement learning implementations.

Why 'cherry'?

Because it's the sweetest part of the cake.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

cherry-rl-0.2.0.tar.gz (3.5 MB)


File details

Details for the file cherry-rl-0.2.0.tar.gz.

File metadata

  • Download URL: cherry-rl-0.2.0.tar.gz
  • Size: 3.5 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/32.0 requests/2.25.1 requests-toolbelt/0.9.1 urllib3/1.26.6 tqdm/4.61.2 importlib-metadata/4.11.1 keyring/23.5.0 rfc3986/1.5.0 colorama/0.4.4 CPython/3.9.5

File hashes

Hashes for cherry-rl-0.2.0.tar.gz

  • SHA256: c53ace32ec902b2af4f08b7963a3bff4245939cc1577dc48b0dd0e4521d8f942
  • MD5: a473f2cd43320dd2339e84c5eb27a713
  • BLAKE2b-256: d0d56136ee97d27a47a11b17e049c7aee3a5070e4bf9557057b75c90f7887ebd

