Cardio RL. In development...

Cardio: Runners for Deep Reinforcement Learning in Gym Environments

Motivation | Installation | Usage | Under the hood | Development | Contributing

So many reinforcement learning libraries, what makes Cardio different?

  • Easy and readable: Focus on the agent and leave the boilerplate code to Cardio
  • Extensible: Easy progression from simple algorithms all the way up to Rainbow and beyond
  • Research friendly: Cardio was designed to be a whiteboard for your RL research

Cardio aims to make new algorithm implementations easy to do, readable and framework agnostic by providing a collection of modular environment interaction loops for the research and implementation of deep reinforcement learning (RL) algorithms in Gymnasium environments. Out of the box these loops are capable of more complex experience collection approaches such as n-step transitions, trajectories, and storing of auxiliary values to a replay buffer. Accompanying these core components are helpful utilities (such as replay buffers and data transformations), and single-file reference implementations for state-of-the-art algorithms.

Motivation

In the spectrum of RL libraries, Cardio sits between large, complete packages such as stable-baselines3, which deliver full algorithm implementations but lack modularity and extensibility, and more research-friendly repositories like CleanRL, which repeat boilerplate code. Its design paradigm is similar to Google’s Dopamine and Acme.

To achieve the desired structure and API, Cardio makes some concessions, the first of which is speed. There's no competing against end-to-end jitted implementations, but going in that direction greatly hinders modularity and the application of implementations to arbitrary environments. If you are interested in lightning-quick training of agents on established baselines, then please look towards the likes of Stoix.

Secondly, taking a modular approach leaves Cardio less immediately extensible than the likes of CleanRL. Despite the features in place to make the environment loops transparent, there will inevitably be edge cases where Cardio is not the best choice.

Installation

NOTE: Jax is a major requirement of the runner internals. The installation process will be updated soon to better distinguish between setting up Cardio with Jax for GPUs, CPUs or TPUs.

Via pip

pip install cardio-rl

Or from GitHub:

git clone https://github.com/mmcaulif/Cardio.git
cd Cardio
poetry install

Usage

Below is a simple example that uses Cardio's off-policy runner to write a short implementation of a core deep RL algorithm, Deep Q-Networks (DQN), for the CartPole environment.

It is assumed that you have a beginner's understanding of deep RL; this section just serves to demonstrate how different algorithms might fit into Cardio.

DQN

In this algorithm our agent performs a fixed number of environment steps (aka a rollout) and saves the transitions experienced in a replay buffer for performing update steps. Once the rollout is done, we sample from the replay buffer and pass the sampled transitions to the agent's update method. To implement our agent we will use the provided Cardio Agent class and override the __init__, update and step methods:

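import copy

import cardio_rl as crl  # assumed import name for the cardio-rl package
import gymnasium as gym
import jax
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class DQN(crl.Agent):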
    def __init__(
        self,
        env: gym.Env,
        critic: nn.Module,
        gamma: float = 0.99,
        targ_freq: int = 1_000,
        optim_kwargs: dict = {"lr": 1e-4},
        init_eps: float = 0.9,
        min_eps: float = 0.05,
        schedule_len: int = 5000,
        use_rmsprop: bool = False,
    ):
        self.env = env
        self.critic = critic
        self.targ_critic = copy.deepcopy(critic)
        self.gamma = gamma
        self.targ_freq = targ_freq
        self.update_count = 0

        if not use_rmsprop:
            self.optimizer = th.optim.Adam(self.critic.parameters(), **optim_kwargs)
        else:
            # TODO: fix mypy crying about return type
            self.optimizer = th.optim.RMSprop(self.critic.parameters(), **optim_kwargs)

        self.eps = init_eps
        self.min_eps = min_eps
        self.ann_coeff = self.min_eps ** (1 / schedule_len)

    def update(self, batches):
        data = jax.tree.map(th.from_numpy, batches)
        s, a, r, s_p, d = data["s"], data["a"], data["r"], data["s_p"], data["d"]

        q = self.critic(s).gather(-1, a)
        q_p = self.targ_critic(s_p).max(dim=-1, keepdim=True).values
        y = r + self.gamma * q_p * ~d

        loss = F.mse_loss(q, y.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_count += 1
        if self.update_count % self.targ_freq == 0:
            self.targ_critic.load_state_dict(self.critic.state_dict())

        return {}

    def step(self, state):
        if np.random.rand() > self.eps:
            th_state = th.from_numpy(state)
            action = self.critic(th_state).argmax().numpy(force=True)
        else:
            action = self.env.action_space.sample()

        self.eps = max(self.min_eps, self.eps * self.ann_coeff)
        return action, {}
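To see what the exploration schedule in step does, the sketch below replays the same multiplicative decay in plain Python, using the default values from __init__. The helper name eps_schedule is made up for illustration and is not part of Cardio.

```python
def eps_schedule(init_eps=0.9, min_eps=0.05, schedule_len=5000, steps=5000):
    """Replay the multiplicative epsilon decay used in DQN.step (illustrative helper)."""
    ann_coeff = min_eps ** (1 / schedule_len)  # per-step decay factor
    eps, history = init_eps, []
    for _ in range(steps):
        history.append(eps)  # epsilon used for this step's action choice
        eps = max(min_eps, eps * ann_coeff)  # decay, clipped at the floor
    return history, eps


history, final_eps = eps_schedule()
print(f"first step: {history[0]:.3f}, after {len(history)} steps: {final_eps:.3f}")
```

Because the decay starts from init_eps rather than 1, epsilon reaches the min_eps floor slightly before schedule_len steps have elapsed and then stays there.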

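Next we instantiate our runner, passing it our environment, our agent, the rollout length, and the keyword args for the buffer (in this case, the batch size). Q_critic is a user-defined network that is not shown here; presumably it maps CartPole's 4 observation features to its 2 action values.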

env = gym.make("CartPole-v1")
runner = crl.Runner.off_policy(
    env=env,
    agent=DQN(env, Q_critic(4, 2)),
    rollout_len=4,
    buffer_kwargs={"batch_size": 32}
)

And finally, to run 50,000 rollouts (in this case, 50,000 x 4 = 200,000 environment steps) and perform an agent update after each one, we just use the run method:

runner.run(rollouts=50_000, eval_freq=1_250)

Under the hood

Below we'll go over the inner workings of Cardio. The intention was to make Cardio quite minimal and easy to parse, akin to Dopamine, but I hope it is interesting to practitioners and I'm eager to hear any feedback/opinions on the design paradigm. This section also serves to highlight a couple of the nuances of Cardio's components.

Diagram pending creation

Transition

Borrowing an idea from TorchRL, the core building block that Cardio centers around is a dictionary that represents an MDP transition. By default the transition dict has the following keys: s, a, r, s_p, d corresponding to state, action, reward, state' (state prime or next state) and done. Two important concepts to be aware of are:

  1. A Cardio Transition dictionary does not necessarily correspond to a single environment step. For example, in the case of n-step transitions, s will correspond to s_t but s_p will correspond to s_(t+n), with the reward key having n entries. Furthermore, the replay buffer stores data as a transition dictionary with keys pointing to multiple states, actions, rewards, etc.
  2. The done value used in Cardio is the result of an OR between the terminated and truncated values used in Gymnasium. Empirically, decoupling termination and truncation has been shown to have a negligible effect. However, this is a trivial feature to change and it's possible that leaving it up to the user is best.
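For point 1, an agent consuming n-step transitions would typically collapse the list of rewards into a single discounted return before bootstrapping from s_p. A minimal sketch in plain Python follows; the function names and example numbers are illustrative assumptions, not Cardio's API.

```python
def n_step_return(rewards, gamma=0.99):
    """Discounted sum r_0 + gamma * r_1 + ... + gamma^(n-1) * r_(n-1)."""
    return sum(gamma**k * r for k, r in enumerate(rewards))


def n_step_target(rewards, q_next, done, gamma=0.99):
    """Bootstrap from the value of s_(t+n) unless the transition is marked done."""
    n = len(rewards)
    bootstrap = 0.0 if done else gamma**n * q_next
    return n_step_return(rewards, gamma) + bootstrap


# A 3-step transition with rewards [1, 1, 1] and a bootstrap value of 10.
target = n_step_target([1.0, 1.0, 1.0], q_next=10.0, done=False, gamma=0.5)
print(target)
```

Note that with done defined as terminated OR truncated (point 2), a target like this also stops bootstrapping on truncated episodes.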

By using dictionaries, new entries are easy to add and thus the storing of user-defined variables (such as intrinsic reward or policy probabilities) is built in to the framework, whereas this would be nontrivial to implement in more abstract libraries like stable-baselines3.
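As a rough illustration of why dictionaries make this easy, the sketch below builds a transition with the default keys and merges in a user-defined entry. The helper make_transition and the log_prob key are hypothetical; how Cardio merges extras internally is not shown here.

```python
def make_transition(s, a, r, s_p, terminated, truncated, **extras):
    """Build a transition dict using the default keys plus any user-defined extras."""
    transition = {"s": s, "a": a, "r": r, "s_p": s_p, "d": terminated or truncated}
    transition.update(extras)  # e.g. intrinsic reward or policy log-probability
    return transition


t = make_transition(
    s=[0.0, 0.1],
    a=1,
    r=1.0,
    s_p=[0.0, 0.2],
    terminated=False,
    truncated=True,
    log_prob=-0.69,  # hypothetical extra value returned by an agent
)
print(sorted(t))
```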

Agent

Much like Acme, the Cardio agent class is very minimal, simply defining some base methods that are used by the environment interaction loops. The most important things to know are when each method is called, what data it is given, and which component calls it. The most important methods are step (given a state, return an action and any extras), view (given a step transition, return any extras) and update (given a batch of transitions, update the agent).
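The shape of such an interface can be sketched as below. This is a hypothetical stand-in written for illustration, not Cardio's actual Agent class; the class names and exact signatures are assumptions based only on the method descriptions above.

```python
import random


class SketchAgent:
    """Hypothetical stand-in for a minimal agent base class."""

    def step(self, state):
        """Given a state, return an action and a dict of extras."""
        raise NotImplementedError

    def view(self, transition):
        """Given a single-step transition, return a dict of extras."""
        return {}

    def update(self, batches):
        """Given a batch of transitions, learn from it and return any info."""
        return {}


class RandomAgent(SketchAgent):
    """Only overrides step; view and update fall back to the no-op defaults."""

    def __init__(self, n_actions, seed=0):
        self.n_actions = n_actions
        self.rng = random.Random(seed)

    def step(self, state):
        return self.rng.randrange(self.n_actions), {}


agent = RandomAgent(n_actions=2)
action, extras = agent.step([0.0, 0.0, 0.0, 0.0])
print(action, extras)
```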

Gatherer

The gatherer is the primary component in Cardio and serves the purpose of stepping through the environment directly with a provided agent, or a random policy. The gatherer has two buffers that are used to package the transitions for the Runner in the desired manner. The step buffer collects transitions obtained from single environment steps and has a capacity equal to n. When the step buffer is full, it transforms its elements into one n-step transition and adds that transition to the transition buffer. Some rough pseudocode is provided below.

Gatherer pseudocode

The step buffer is emptied after terminal states to prevent transitions overlapping across episodes. When n > 1, the step buffer needs to be "flushed", i.e. transitions are created from steps that would otherwise be thrown away. Please refer to the example below provided by my esteemed colleague, ChatGPT:

If you are collecting 3-step transitions, here's how you handle the transitions where s_3 is a terminal state:

  1. Transition from s_0: (s_0, a_0, [r_0, r_1, r_2], s_3)
  2. Transition from s_1: (s_1, a_1, [r_1, r_2], s_3)
  3. Transition from s_2: (s_2, a_2, r_2, s_3)
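The flushing behaviour above can be reproduced with a small sliding-window sketch in plain Python. This is one possible way to implement it under the assumption that the step buffer acts as a sliding window of length n; the function names are illustrative and Cardio's actual gatherer may differ.

```python
from collections import deque


def to_n_step(steps):
    """Merge consecutive single-step transitions into one multi-step transition."""
    first, last = steps[0], steps[-1]
    return {
        "s": first["s"],
        "a": first["a"],
        "r": [step["r"] for step in steps],
        "s_p": last["s_p"],
        "d": last["d"],
    }


def gather(single_steps, n):
    """Turn a stream of single-step transitions into n-step transitions."""
    step_buffer, transitions = deque(maxlen=n), []
    for step in single_steps:
        step_buffer.append(step)  # the oldest step is dropped once the buffer is full
        emitted_full = len(step_buffer) == n
        if emitted_full:
            transitions.append(to_n_step(list(step_buffer)))
        if step["d"]:
            # Flush: emit the shorter tail transitions, then empty the buffer
            # so that nothing overlaps into the next episode.
            remaining = list(step_buffer)[1:] if emitted_full else list(step_buffer)
            while remaining:
                transitions.append(to_n_step(remaining))
                remaining = remaining[1:]
            step_buffer.clear()
    return transitions


# Three environment steps where s_3 is terminal, as in the example above.
steps = [
    {"s": f"s_{i}", "a": f"a_{i}", "r": f"r_{i}", "s_p": f"s_{i + 1}", "d": i == 2}
    for i in range(3)
]
transitions = gather(steps, n=3)
for t in transitions:
    print(t["s"], t["a"], t["r"], t["s_p"])
```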

The transition buffer is even simpler, just containing the processed transitions from the step buffer. The transition buffer starts empty when the gatherer's step method is called and also maintains its data across terminal steps. Both of these characteristics are opposite to the step buffer which persists across gatherer.step calls but not across terminal steps.

Due to the nature of n-step transitions, sometimes the gatherer's transition buffer will have fewer transitions than environment steps taken (as the step buffer gets filled) and other times it will have more (when the step buffer gets flushed), but at any given time there will be a rough one-to-one mapping between environment steps taken and transitions collected. Lastly, rollout lengths can be less than n.

Runner

The runner is the high-level orchestrator that deals with the different components and data. It contains a gatherer, your agent and any replay buffer you might have. The runner calls the gatherer's step function as part of its own step function, or as part of its built-in warmup (for collecting a large amount of initial data with your agent) and burnin (for randomly stepping through an environment without collecting data, such as for initialising normalisation values) methods. The runner can either be used via its run method (which iteratively calls the runner.step and agent.update methods) or just with its step method if you'd like more fine-grained control.
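The control flow of an off-policy run loop can be sketched roughly as follows. This is a toy illustration under stated assumptions, not Cardio's implementation: the function run, the list-based replay buffer and fake_gather_step are all hypothetical stand-ins for the runner, buffer and gatherer.

```python
import random


def run(gather_step, agent_update, rollouts, rollout_len, batch_size, seed=0):
    """Toy off-policy loop: gather a rollout, store it, sample a batch, update."""
    rng = random.Random(seed)
    replay_buffer, n_updates = [], 0
    for _ in range(rollouts):
        # Plays the role of runner.step, which delegates to the gatherer.
        replay_buffer.extend(gather_step(rollout_len))
        batch = rng.sample(replay_buffer, min(batch_size, len(replay_buffer)))
        agent_update(batch)  # plays the role of agent.update
        n_updates += 1
    return len(replay_buffer), n_updates


def fake_gather_step(length):
    """Stand-in for the gatherer: one dummy transition per environment step."""
    return [{"s": i, "a": 0, "r": 0.0, "s_p": i + 1, "d": False} for i in range(length)]


stored, updates = run(
    fake_gather_step, lambda batch: {}, rollouts=10, rollout_len=4, batch_size=32
)
print(stored, updates)
```

Driving the loop by hand in this way corresponds to the finer-grained option of calling the step method yourself instead of run.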

Development

The main development goal for Cardio will be to make it as fast, easy to use, and extensible as possible. The aim is not to include many RL features or to cater to every domain. Far down the line I could imagine trying to incorporate async runners but that can get messy quickly. However, if you notice any bugs, or have any suggestions or feature requests, user input is greatly appreciated!

Some tentative tasks right now are:

  • Integrated loggers (WandB, Neptune, Tensorboard etc.)
  • Implement seeding for reproducibility.
  • Widespread and rigorous testing!
  • Asynchronous features

A wider goal is to perform profiling and squash any immediate performance bottlenecks. Wrapping an environment in a Cardio runner should introduce as little overhead as possible.

Any RL components (like neural network layers) are likely to be better suited to Cardio's sibling repo, Sprinter.

Contributing

Cat pull request image

Jokes aside, given the roadmap described above for Cardio, PRs related to bugs and performance are the main interest. If you would like a new feature, please create an issue first and we can discuss.

License

This repository is licensed under the Apache 2.0 License.
