AGNES - Flexible Reinforcement Learning Framework with PyTorch

Status: This framework is under active development and bugs may occur.

Results

MuJoCo

(Current results) MuJoCo "Ant-v2", trained for 1M steps: the ending average is 5326.2. Single runner with the PPO algorithm, an MLP network, and 32 environments. The curve is an average of 3 runs.

You can get the Tensorboard log file by clicking the image above (you will be redirected to the destination GitHub folder). The default config for the MuJoCo environment was used. The plot was produced with examples/plot.py.
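The curve above is an average of 3 runs. As a rough illustration of that averaging step (not the code in examples/plot.py, whose internals are not shown here), the sketch below averages several reward curves point by point, truncating to the shortest run. The function name and the sample numbers are made up for the example.

```python
from statistics import mean


def average_curves(runs):
    """Average several reward curves element-wise.

    `runs` is a list of lists of floats, one list per training run.
    Curves are truncated to the shortest run so that every point is
    an average over all runs.
    """
    if not runs:
        raise ValueError("need at least one run")
    shortest = min(len(run) for run in runs)
    return [mean(run[i] for run in runs) for i in range(shortest)]


if __name__ == '__main__':
    # Made-up numbers, only to show the shape of the data.
    runs = [
        [1.0, 2.0, 3.0],
        [3.0, 4.0, 5.0, 6.0],
        [2.0, 3.0, 4.0],
    ]
    print(average_curves(runs))
```

Truncating to the shortest run is one simple choice; interpolating all runs onto a common step axis is another option when runs log at different intervals.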

Atari

(Old results)

Atari "BreakoutNoFrameskip-v4" with frame stacking, trained for 10M steps: the score peaks at 861.8 at the end, and the ending average is 854.8. DistributedMPI runner with the PPO algorithm, an LSTMCNN network, and 16 environments.

You can get the Tensorboard log file by clicking the image above (you will be redirected to the destination GitHub folder). The default config for the Atari environment was used.

LSTMCNN agent plays Breakout

The Grad-CAM technique was applied to the sampled action chosen by the trained LSTMCNN (from the previous point).
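For readers unfamiliar with Grad-CAM, the core computation is: average the gradient of the chosen action's score over each feature map to get one weight per channel, take the weighted sum of the feature maps, and keep only the positive part. The sketch below shows that arithmetic on plain Python lists. It is a simplified illustration, not the code used to produce the visualization; in practice the activations and gradients would come from a convolutional layer of the trained network.

```python
def grad_cam(activations, gradients):
    """Compute a Grad-CAM heat map.

    activations, gradients: lists of K feature maps, each an H x W
    list of lists. gradients[k][i][j] is the gradient of the chosen
    action's score with respect to activations[k][i][j].
    """
    height = len(activations[0])
    width = len(activations[0][0])

    # One weight per channel: the spatial mean of its gradient map.
    weights = [
        sum(sum(row) for row in grad_map) / (height * width)
        for grad_map in gradients
    ]

    # Weighted sum of the feature maps, followed by ReLU.
    heat_map = []
    for i in range(height):
        row = []
        for j in range(width):
            value = sum(w * act[i][j] for w, act in zip(weights, activations))
            row.append(max(value, 0.0))
        heat_map.append(row)
    return heat_map
```

The resulting map has the spatial size of the feature maps, so it is usually upsampled to the input frame size before being overlaid on the image.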

Runners

Single

One worker and one trainer. agnes.make_vec_env can also be used here.

import agnes


if __name__ == '__main__':
    env = agnes.make_env("InvertedDoublePendulum-v2")
    runner = agnes.Single(env, agnes.PPO, agnes.MLP)
    runner.log(agnes.log, agnes.TensorboardLogger(".logs/"), agnes.CsvLogger(".logs/"))
    runner.run()

  • agnes.log is an object of the StandardLogger class that prints parameters to the console.
  • agnes.TensorboardLogger is a class that writes logs to a Tensorboard file.
  • agnes.CsvLogger is a class that writes logs to a CSV file (required for plotting).

DistributedMPI

Because execution is asynchronous, the workers in the DistributedMPI runner collect data with weights that lag by one rollout. This does not affect learning: the weights are behind by only one update, as they are in the Single runner, so quantities such as the probability ratio stay the same.

Run it with:

mpiexec -n 3 python -m mpi4py script_name.py
or
mpirun -n 3 python -m mpi4py script_name.py

This command will run 2 workers and 1 trainer.

# script_name.py
import agnes


if __name__ == '__main__':
    env = agnes.make_vec_env("BreakoutNoFrameskip-v4")
    runner = agnes.DistributedMPI(env, agnes.PPO, agnes.CNN)
    runner.run()
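The process count passed to -n determines the split between workers and the trainer: with -n 3 there are 2 workers and 1 trainer. The helper below just spells out that arithmetic. It assumes exactly one trainer process per launch, which matches the example above but is otherwise an assumption; `split_roles` is a hypothetical name, not part of agnes.

```python
def split_roles(num_processes, num_trainers=1):
    """Return (workers, trainers) for a given MPI process count.

    Assumes a fixed number of trainer processes (1 by default) and
    that every remaining process is a worker.
    """
    if num_processes <= num_trainers:
        raise ValueError("need at least one worker in addition to the trainer")
    return num_processes - num_trainers, num_trainers


if __name__ == '__main__':
    print(split_roles(3))  # (2, 1), matching `mpiexec -n 3`
```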

Algorithms

A2C

A synchronous version of Advantage Actor Critic is implemented in this framework and can be used as follows:

import agnes


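if __name__ == '__main__':
    # Create the environment first; this id is the one used in the Single runner example.
    env = agnes.make_env("InvertedDoublePendulum-v2")
    runner = agnes.Single(env, agnes.A2C, agnes.MLP)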
    runner.run()
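As background, the "advantage" in Advantage Actor Critic is the difference between the observed discounted return and the critic's value estimate. The sketch below computes discounted returns and advantages for one rollout. It is a generic textbook formulation, not necessarily how agnes.A2C computes them; the function name, the `gamma` default, and the argument layout are illustrative.

```python
def discounted_advantages(rewards, values, last_value, dones, gamma=0.99):
    """Return (returns, advantages) for a single rollout.

    rewards[t], values[t], dones[t] describe step t; last_value is the
    critic's estimate for the state after the final step. A True in
    dones[t] means the episode ended at step t, so nothing is
    bootstrapped across that boundary.
    """
    returns = [0.0] * len(rewards)
    running = last_value
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0
        running = rewards[t] + gamma * running
        returns[t] = running
    # Advantage: how much better the outcome was than the critic predicted.
    advantages = [ret - val for ret, val in zip(returns, values)]
    return returns, advantages
```

The policy gradient is then weighted by these advantages, while the critic is regressed toward the returns.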

PPO

Proximal Policy Optimization is implemented in this framework and can be used as follows:

import agnes


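if __name__ == '__main__':
    # Create the environment first; this id is the one used in the Single runner example.
    env = agnes.make_env("InvertedDoublePendulum-v2")
    runner = agnes.Single(env, agnes.PPO, agnes.MLP)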
    runner.run()
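For orientation, PPO's central quantity is the probability ratio between the updated policy and the policy that collected the data, clipped to a small interval around 1. The sketch below evaluates the standard clipped surrogate objective from the PPO paper on plain Python lists. It is not taken from agnes.PPO; the function name and the `clip_range` default of 0.2 are assumptions chosen for illustration.

```python
import math


def clipped_surrogate(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """Mean PPO clipped surrogate objective (a quantity to maximize).

    new_log_probs[i] and old_log_probs[i] are log-probabilities of the
    same action under the current policy and under the policy that
    collected the rollout.
    """
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)
        clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # Take the more pessimistic of the clipped and unclipped terms.
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)
```

When the two policies agree, the ratio is 1 and the objective reduces to the mean advantage; the clipping only matters once the updated policy drifts away from the one that gathered the data.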

Neural Network Architectures

Multi Layer Perceptron

Can be used with both continuous and discrete action spaces.

...
runner = agnes.Single(env, agnes.PPO, agnes.MLP)
...

Convolutional Neural Network

Can be used only with discrete action spaces.

...
runner = agnes.Single(env, agnes.PPO, agnes.CNN)
...

Recurrent Neural Network

Can be used with both continuous and discrete action spaces.

...
runner = agnes.Single(env, agnes.PPO, agnes.RNN)
...

Convolutional Recurrent Neural Network

Can be used only with discrete action spaces.

...
runner = agnes.Single(env, agnes.PPO, agnes.RNNCNN)
...

Convolutional Neural Network with last LSTM layer

Can be used only with discrete action spaces.

...
runner = agnes.Single(env, agnes.PPO, agnes.LSTMCNN)
...
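The compatibility notes in this section can be collected in one place. The dictionary below only restates what the section says about each architecture; `SUPPORTED_SPACES` and `check_architecture` are hypothetical helpers for illustration and are not part of agnes.

```python
# Action spaces each architecture can be used with, as stated above.
SUPPORTED_SPACES = {
    "MLP": {"continuous", "discrete"},
    "CNN": {"discrete"},
    "RNN": {"continuous", "discrete"},
    "RNNCNN": {"discrete"},
    "LSTMCNN": {"discrete"},
}


def check_architecture(name, action_space):
    """Return True if the architecture is listed for the action space."""
    if name not in SUPPORTED_SPACES:
        raise KeyError(f"unknown architecture: {name}")
    return action_space in SUPPORTED_SPACES[name]


if __name__ == '__main__':
    print(check_architecture("MLP", "continuous"))
    print(check_architecture("CNN", "continuous"))
```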

Make environment

  • make_vec_env(env, envs_num=ncpu, config=None)

    Parameters:

    • env (str or function) is the id of a gym environment or a function that returns an initialized environment
    • envs_num (int) is the number of environments to initialize; by default, the number of logical CPU cores
    • config (dict) is a dictionary with parameters for Monitor and for initializing the environment; by default None (the default config is used)

    Returns:

    • dict of
      1. "env" (VecEnv object)
      2. "env_type" (str)
      3. "env_num" (int) is the number of envs in the VecEnv object
      4. "env_name" (str) is the name of the envs in the VecEnv object (gym id or class name)

    The whole returned object should be passed to a runner.

  • make_env(env, config=None) is an alias of make_vec_env without the envs_num argument, which is set to 1.
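To make the documented return value concrete, the sketch below checks that an object has the four keys listed above and prints a short summary. It uses a stand-in dictionary rather than calling agnes, so every value shown is a placeholder; `EXPECTED_KEYS` and `summarize_env_info` are hypothetical helpers, not part of the library.

```python
EXPECTED_KEYS = ("env", "env_type", "env_num", "env_name")


def summarize_env_info(env_info):
    """Validate the documented keys and return a one-line summary."""
    missing = [key for key in EXPECTED_KEYS if key not in env_info]
    if missing:
        raise KeyError(f"missing keys: {missing}")
    return (
        f"{env_info['env_name']} ({env_info['env_type']}), "
        f"{env_info['env_num']} env(s)"
    )


if __name__ == '__main__':
    # Placeholder values; a real object would come from agnes.make_vec_env(...).
    fake_info = {
        "env": object(),
        "env_type": "mujoco",
        "env_num": 4,
        "env_name": "InvertedDoublePendulum-v2",
    }
    print(summarize_env_info(fake_info))
```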

Notice: Some plot functions and environment wrappers were taken from OpenAI Baselines (2017).
