A library to build and train reinforcement learning agents in OpenAI Gym environments.

Read the full documentation here.

Getting Started

An agent must implement the act() method, which takes the current state as input and returns an action:

from train import Agent

class RandomAgent(Agent):

    def act(self, state):
        return self.env.action_space.sample()
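For an agent whose act() actually uses the state, here is a minimal sketch. It assumes CartPole's observation layout (cart position, cart velocity, pole angle, pole angular velocity) and its two discrete actions (0 pushes left, 1 pushes right). The class name PoleBalancingAgent is hypothetical, and the class is written standalone so it runs without the library installed; in practice you would subclass Agent as above.

```python
class PoleBalancingAgent:
    """Hypothetical hand-written CartPole policy (it does not learn).

    In real use this would subclass train.Agent; it is standalone here
    so the sketch runs without the library installed.
    """

    def act(self, state):
        # Assumed CartPole observation: [cart_pos, cart_vel, pole_angle, pole_ang_vel]
        _, _, angle, angular_velocity = state
        # Push the cart toward the side the pole is falling to:
        # 1 = push right, 0 = push left.
        return 1 if angle + angular_velocity > 0 else 0


agent = PoleBalancingAgent()
print(agent.act([0.0, 0.0, 0.05, 0.1]))  # 1 (pole leaning right, push right)
```

This fixed heuristic often keeps the pole up much longer than random actions, which makes it a useful baseline when comparing scores from a learning agent.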

Create an environment using OpenAI Gym:

import gym

env = gym.make('CartPole-v0')

Initialize your agent using the environment:

agent = RandomAgent(env=env)

Now you can start training your agent (in this example, the agent always acts randomly and doesn’t learn anything):

scores = agent.train(episodes=100)

You can also visualize how training progresses, although rendering slows down the process:

scores = agent.train(episodes=100, render=True)
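The internals of train() are not described here. As a rough, hypothetical sketch of what an episodic loop of this kind usually does, the code below assumes the classic Gym API (reset() returns an observation; step() returns (observation, reward, done, info)) and assumes that a score is the total reward collected in one episode. ToyEnv and run_episodes are illustrative stand-ins so the sketch runs without Gym; they are not part of the library.

```python
import random


class ToyEnv:
    """Hypothetical stand-in for a Gym environment using the classic API."""

    class _ActionSpace:
        def sample(self):
            return random.randint(0, 1)

    def __init__(self, max_steps=10):
        self.action_space = self._ActionSpace()
        self.max_steps = max_steps
        self.t = 0

    def reset(self):
        self.t = 0
        return [0.0, 0.0, 0.0, 0.0]

    def step(self, action):
        self.t += 1
        done = self.t >= self.max_steps
        # Reward of 1 for every step survived, as in CartPole.
        return [0.0, 0.0, 0.0, 0.0], 1.0, done, {}


def run_episodes(env, act, episodes):
    """Run `episodes` episodes and return the total reward of each one."""
    scores = []
    for _ in range(episodes):
        state = env.reset()
        done = False
        total = 0.0
        while not done:
            action = act(state)
            state, reward, done, _ = env.step(action)
            total += reward
        scores.append(total)
    return scores


env = ToyEnv(max_steps=10)
scores = run_episodes(env, lambda state: env.action_space.sample(), episodes=5)
print(scores)  # [10.0, 10.0, 10.0, 10.0, 10.0]
```

A learning agent would additionally update its parameters from each (state, action, reward, next state) transition inside the inner loop; a random agent skips that step.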

Once you are done with training, you can test the agent:

scores = agent.test(episodes=10)

Alternatively, visualize how it performs:

scores = agent.test(episodes=10, render=True)
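Gym registers CartPole-v0 with a 200-step episode limit and a reward threshold of 195.0; the environment is usually considered solved when the average reward over 100 consecutive episodes reaches that threshold. Assuming the returned scores are a list of per-episode total rewards (an assumption; check the Agent documentation), a minimal check could look like this. The function name is_solved is hypothetical.

```python
from statistics import mean

# CartPole-v0: reward_threshold=195.0, judged over 100 consecutive episodes.
SOLVED_THRESHOLD = 195.0
WINDOW = 100


def is_solved(scores, threshold=SOLVED_THRESHOLD, window=WINDOW):
    """Return True if the mean of the last `window` scores reaches `threshold`."""
    if len(scores) < window:
        return False  # Not enough episodes to judge yet.
    return mean(scores[-window:]) >= threshold


print(is_solved([200.0] * 100))  # True
print(is_solved([200.0] * 10))   # False: too few episodes
```

Note that the ten test episodes used above are fewer than this criterion needs, so you would test for at least 100 episodes to apply it.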

To learn how to build an agent that actually learns, see the Agent documentation.
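One common ingredient of value-based learning agents such as DQN is epsilon-greedy action selection: explore with a random action with probability epsilon, otherwise pick the action with the highest estimated value. The sketch below is generic and independent of this library; q_values stands in for whatever value estimates the agent's model produces (hypothetical).

```python
import random


def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick a random action with probability epsilon, else the greedy one.

    q_values: estimated action values for the current state
    (hypothetical; in a DQN these would come from the Q-network).
    """
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore
    # exploit: index of the largest estimated value
    return max(range(len(q_values)), key=lambda a: q_values[a])


print(epsilon_greedy([0.1, 0.7, 0.2], epsilon=0.0))  # 1
```

In an act() method this would typically be called with a decaying epsilon, so the agent explores a lot early in training and acts mostly greedily later.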

See the examples directory for TensorFlow implementations of some algorithms (DQN, A3C, PPO, etc.).
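Policy-gradient methods such as A3C and PPO typically work with discounted returns, G_t = r_t + gamma * G_{t+1}, computed from an episode's rewards. The examples directory is not reproduced here; this is a generic plain-Python sketch, and the function name and default gamma are illustrative.

```python
def discounted_returns(rewards, gamma=0.99):
    """Compute G_t = r_t + gamma * G_{t+1}, working backwards from the end."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```

Iterating backwards lets each return reuse the one after it, so the whole episode is processed in a single linear pass.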

Installation

Requirements:

  • Python >= 3.6

Install from PyPI (recommended):

pip install train

Alternatively, install from source:

git clone https://github.com/marella/train.git
cd train
pip install -e .

To run examples and tests, install from source.

Other libraries such as Gym and TensorFlow should be installed separately.

Examples

To run the examples, install TensorFlow and the example dependencies:

pip install -e '.[examples]'

and run an example from the examples directory:

cd examples
python PPO.py

Testing

To run the tests, install the test dependencies:

pip install -e '.[tests]'

and run:

pytest tests

Download files

Source Distribution

train-0.0.5.tar.gz (8.1 kB)

Uploaded Source

File details

Details for the file train-0.0.5.tar.gz.

File metadata

  • Download URL: train-0.0.5.tar.gz
  • Upload date:
  • Size: 8.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.11.0 pkginfo/1.4.2 requests/2.22.0 setuptools/42.0.2 requests-toolbelt/0.8.0 tqdm/4.20.0 CPython/3.6.4

File hashes

Hashes for train-0.0.5.tar.gz
Algorithm     Hash digest
SHA256        441beaa4b792bdca301ac3c69bb8299256873e081a9f0de1e3782064a2f36cdf
MD5           fac4b208ce8d3ac2361794642096ec88
BLAKE2b-256   f7bd03ef37dfb2f0550f1fa43423bf8a2c833d1833ce1d90eb71dee05131eee2
