A plotter for reinforcement learning
rl-plotter
A simple tool for plotting learning curves in reinforcement learning.
Installation
From pip:
pip install rl_plotter
From source:
python3 setup.py install
Examples
First, add a logger to your training code (DQN is used as the example here):
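from rl_plotter.logger import Logger
import gym  # assumes the classic gym API (< 0.26): reset() -> obs, step() -> 4-tuple

# DQN, EPSILON and REPLAY_MEMORY_SIZE come from your own agent code.
def train(name, max_episodes=1000):  # max_episodes: hypothetical stopping criterion
    env = gym.make('PongNoFrameskip-v4')
    dqn = DQN()
    logger = Logger(name, env_name='PongNoFrameskip-v4', use_tensorboard=False)
    for _ in range(max_episodes):
        s = env.reset()
        episode_reward = 0  # reset the return at the start of every episode
        while True:
            total_step = logger.add_step()  # record one environment step
            a = dqn.select_action(s, EPSILON)
            s_, r, done, info = env.step(a)
            dqn.store_transition(s, a, r, s_)
            episode_reward += r
            # start learning once the replay memory holds enough transitions
            if dqn.replay_memory.memory_counter > REPLAY_MEMORY_SIZE:
                loss = dqn.learn()
                logger.add_loss(loss.cpu().item())
            if done:
                break
            s = s_
        logger.add_episode()  # mark the end of an episode
        logger.add_reward(episode_reward, freq=10)
    logger.finish()  # reached once the episode loop ends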
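The `Logger` calls above need a real agent and environment to run. To show in isolation the per-episode bookkeeping a learning-curve plot relies on (steps so far, episode count, episode reward), here is a self-contained sketch with a hypothetical `ToyLogger` and a toy environment. Neither is part of rl_plotter; the method names only mirror the calls used above, and the CSV layout is an assumption for illustration.

```python
import csv
import io
import random


class ToyLogger:
    """Hypothetical stand-in for a training logger (NOT rl_plotter's Logger).

    Counts environment steps and episodes, and keeps one row per episode:
    (episode index, total environment steps so far, episode reward).
    """

    def __init__(self):
        self.total_steps = 0
        self.episodes = 0
        self.rows = []

    def add_step(self):
        # One call per environment step; returns the running step count.
        self.total_steps += 1
        return self.total_steps

    def add_episode(self):
        # One call at the end of each episode.
        self.episodes += 1

    def add_reward(self, reward):
        # Store the episode reward together with the current counters.
        self.rows.append((self.episodes, self.total_steps, reward))

    def to_csv(self):
        # Serialize the rows so a plotting tool could read them later.
        buf = io.StringIO()
        writer = csv.writer(buf)
        writer.writerow(["episode", "total_steps", "reward"])
        writer.writerows(self.rows)
        return buf.getvalue()


def toy_episode(rng, max_len=20):
    """Hypothetical environment: random episode length, reward 1.0 per step."""
    length = rng.randint(1, max_len)
    return [1.0] * length


def train(num_episodes=5, seed=0):
    rng = random.Random(seed)
    logger = ToyLogger()
    for _ in range(num_episodes):
        episode_reward = 0.0
        for r in toy_episode(rng):
            logger.add_step()
            episode_reward += r
        logger.add_episode()
        logger.add_reward(episode_reward)
    return logger
```

Calling `train(num_episodes=5)` yields five rows; plotting `reward` against `total_steps` (or `episode`) gives a learning curve.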
While your agent is training, or after training has finished, plot the learning curves with:
python -m rl_plotter.plotter
For help, run:
python -m rl_plotter.plotter --help
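Learning-curve plots commonly smooth the noisy per-episode rewards and, when several runs (e.g. random seeds) are available, draw their mean with a shaded spread. The sketch below illustrates that general technique in plain Python; it is not rl_plotter's own implementation, and the function names `moving_average` and `aggregate_runs` are hypothetical.

```python
from statistics import mean, pstdev


def moving_average(values, window=10):
    """Trailing moving average; the first points average over fewer values."""
    if window < 1:
        raise ValueError("window must be >= 1")
    out = []
    total = 0.0
    for i, v in enumerate(values):
        total += v
        if i >= window:
            # Drop the value that just left the window.
            total -= values[i - window]
        out.append(total / min(i + 1, window))
    return out


def aggregate_runs(runs):
    """Per-point (mean, population std) across equal-length runs."""
    if len({len(r) for r in runs}) != 1:
        raise ValueError("all runs must have the same length")
    return [(mean(col), pstdev(col)) for col in zip(*runs)]
```

For example, `moving_average([1, 2, 3, 4], window=2)` returns `[1.0, 1.5, 2.5, 3.5]`. The resulting (mean, std) pairs can be drawn with matplotlib as a line plus a shaded band, e.g. via `Axes.fill_between`.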
The resulting learning curves look like this:
[Figure: example learning curves plotted by rl_plotter]
You can customize the style of your curves by modifying `rl_plotter.plotter`.
To Do
- reinforcement learning plot tools
- timestamp features
- history experiment data plot tools
- ~~basic data plot tools (including ML-Loss plot)~~
- dynamic plot tools
Download files
Source Distribution: rl_plotter-1.0.3.tar.gz (11.4 kB)
Built Distribution: rl_plotter-1.0.3-py3-none-any.whl (12.4 kB)
Hashes for rl_plotter-1.0.3-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 9369bd5dcfe26f0a4a7cccd3116dec03534c599fd918df03489ca71f93ec810f
MD5 | a4cec12e8dfd980221dfe314aa2707ff
BLAKE2b-256 | 6e79d55f6c2d6aa61d4f6043fb32055f128926b06d522fcf34b46f6748ca228c