A reinforcement learning module

Project description

Reinforcement

The reinforcement package aims to provide simple implementations of basic reinforcement learning algorithms, using Test-Driven Development and other software engineering principles in an attempt to minimize defects and improve reproducibility.

Installation

The library can be installed using pip:

pip install reinforcement

Example Implementation

This section demonstrates how to implement a REINFORCE agent and benchmark it on the 'CartPole' gym environment.

You can find the full implementation in examples/reinforce.py. The examples folder also contains some additional utility classes and functions that are used in the implementation.

def run_reinforce(config):
    reporter, env, rewards = Reporter(config), gym.make('CartPole-v0'), []
    with tf1.Session() as session:
        agent = _make_agent(config, session, env)
        for episode in range(1, config.episodes + 1):
            reward = _run_episode(env, episode, agent, reporter)
            rewards.append(reward)
            if reporter.should_log(episode):
                logger.info(reporter.report(episode, rewards))
    env.close()

This is the main function that sets up the boilerplate code. It creates the TensorFlow session, constructs the agent, and logs the progress. The Reporter class is just a helper that makes logging at a certain frequency more convenient.
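
The Reporter ships with the example code rather than with the library itself; a minimal sketch of what such a helper might look like is shown below (the log_frequency and render_frequency attributes are assumptions for illustration, not documented parts of the package):

class SimpleReporter:
    """Illustrative stand-in for the Reporter helper from the examples folder."""

    def __init__(self, config):
        # log_frequency and render_frequency are assumed config attributes.
        self.log_frequency = getattr(config, 'log_frequency', 100)
        self.render_frequency = getattr(config, 'render_frequency', 500)

    def should_log(self, episode):
        return episode % self.log_frequency == 0

    def should_render(self, episode):
        return episode % self.render_frequency == 0

    def report(self, episode, rewards):
        last_100 = rewards[-100:]
        mean_last_100 = sum(last_100) / len(last_100)
        return ("Episode %d: reward=%.1f; mean reward of last 100 episodes: %.2f"
                % (episode, rewards[-1], mean_last_100))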

def _make_agent(config, session, env):
    p = ParameterizedPolicy(session, env.observation_space.shape[0], env.action_space.n, NoLog(), config.lr_policy)
    b = ValueBaseline(session, env.observation_space.shape[0], NoLog(), config.lr_baseline)
    alg = Reinforce(p, config.gamma, b, config.num_trajectories)
    return BatchAgent(alg)

The factory function _make_agent creates the REINFORCE agent object. It uses a parameterized policy to select actions and a value baseline to estimate the expected return of the observed states. In this case, both parameterizations are straightforward artificial neural networks with no hidden layer. Both have the same input layer, but the output layer of the policy is a softmax function, whereas the baseline outputs a single linear value. The BatchAgent type records trajectories (states, actions, rewards), which are then used to optimize the policy and the baseline. The NoLog class is a null object implementing the TensorBoard FileWriter interface.
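
To make the roles of the two parameterizations concrete, the following plain-NumPy sketch shows the quantities a REINFORCE update with a baseline is built from. It is illustrative only, not the library's internal code, and the reward and baseline values are made up:

import numpy as np

def discounted_returns(rewards, gamma):
    # G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

rewards = [1.0, 1.0, 1.0, 1.0]             # rewards recorded along one trajectory
baseline = np.array([3.5, 2.8, 1.9, 1.0])  # hypothetical ValueBaseline outputs
advantages = discounted_returns(rewards, gamma=0.99) - baseline
# The policy gradient scales grad log pi(a_t | s_t) by these advantages,
# while the baseline itself is regressed towards the returns.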

def _run_episode(env, episode, agent, report):
    obs = env.reset()
    done, reward = False, 0
    while not done:
        if report.should_render(episode):
            env.render()
        obs, r, done, _ = env.step(agent.next_action(obs))
        agent.signal(r)
        reward += r

    agent.train()
    return reward

This function performs a run through a single episode of the environment. Observations from the environment are passed to the agent's next_action interface function. The resulting action is passed back to the environment, which yields the next observation and a reward signal. The agent is trained at the end of the episode because we want to train it on whole trajectories. The loop also contains a call to env.render() to visualize some of the runs.
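
The call order next_action, signal, train can be exercised by any object that follows the same protocol; the RandomAgent below is purely hypothetical and only meant to illustrate the interaction loop without TensorFlow:

import random
import gym

class RandomAgent:
    """Hypothetical stand-in that follows the same call protocol as BatchAgent."""

    def __init__(self, n_actions):
        self.n_actions = n_actions

    def next_action(self, observation):
        return random.randrange(self.n_actions)

    def signal(self, reward):
        pass  # a real agent records the reward for the current trajectory

    def train(self):
        pass  # a real agent updates policy and baseline from stored trajectories

env = gym.make('CartPole-v0')
agent, obs, done, total = RandomAgent(env.action_space.n), env.reset(), False, 0.0
while not done:
    obs, r, done, _ = env.step(agent.next_action(obs))
    agent.signal(r)
    total += r
agent.train()
env.close()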

Running an Example

Running the REINFORCE agent example with default settings:

python examples/reinforce.py

After a few thousand episodes, it should get very close to the highest achievable reward of 200:

...
INFO:__main__:Episode 2800: reward=200.0; mean reward of last 100 episodes: 199.71
INFO:__main__:Episode 2900: reward=200.0; mean reward of last 100 episodes: 199.36
INFO:__main__:Episode 3000: reward=200.0; mean reward of last 100 episodes: 198.09



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

reinforcement-1.2.0.tar.gz (14.4 kB)

Uploaded Source

Built Distribution

reinforcement-1.2.0-py3-none-any.whl (23.6 kB)

Uploaded Python 3

File details

Details for the file reinforcement-1.2.0.tar.gz.

File metadata

  • Download URL: reinforcement-1.2.0.tar.gz
  • Upload date:
  • Size: 14.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.6.7

File hashes

Hashes for reinforcement-1.2.0.tar.gz

  • SHA256: ed2c9ff0a8b6f902abaa3bb39f589d0a8c17995cca0a01bf7a77f57497fcdbc8
  • MD5: d5d627ca6caa55d0aa6c6b1d2f4a2698
  • BLAKE2b-256: 2d009799ba20fb1276da5c5c3e1308f4ea2e33be95030868100c6dbc282dba44

See more details on using hashes here.
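
For instance, a downloaded archive could be checked against the SHA256 digest listed above with a few lines of Python (the local file path is a placeholder for wherever the file was saved):

import hashlib

expected = "ed2c9ff0a8b6f902abaa3bb39f589d0a8c17995cca0a01bf7a77f57497fcdbc8"
with open("reinforcement-1.2.0.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print("OK" if digest == expected else "hash mismatch")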

File details

Details for the file reinforcement-1.2.0-py3-none-any.whl.

File metadata

  • Download URL: reinforcement-1.2.0-py3-none-any.whl
  • Upload date:
  • Size: 23.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.36.1 CPython/3.6.7

File hashes

Hashes for reinforcement-1.2.0-py3-none-any.whl

  • SHA256: f21e299b5467a5d73f7293d0c72ea75648671e3590bb284ff69699a96d41ecdc
  • MD5: 0bde5b15d09e604eec05bb4d4669133a
  • BLAKE2b-256: 8c2ddec9f01534257eb7edde9c51c44a62dfa329de5875c8aec965863ae0e2ec

See more details on using hashes here.
