rllib
Reinforcement Learning Library
Installation
pip install pytorch-rllib
Usage
Implemented agents:
- CrossEntropy
- Value / Policy Iteration
- Q-Learning
- Expected Value SARSA
- Approximate Q-Learning
- DQN
- Rainbow
- REINFORCE
- A2C
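Two of the agents listed above, Q-Learning and Expected Value SARSA, differ mainly in the target they bootstrap from. The sketch below shows the textbook forms of both targets in plain Python. It is not taken from rllib's source: the function names are hypothetical, and it assumes an epsilon-greedy policy for the Expected Value SARSA expectation.

```python
from typing import List, Sequence


def epsilon_greedy_probs(q_values: Sequence[float], epsilon: float) -> List[float]:
    """Action probabilities of an epsilon-greedy policy over q_values."""
    n_actions = len(q_values)
    best_action = max(range(n_actions), key=lambda a: q_values[a])
    # Every action gets epsilon / n_actions; the greedy action gets the rest.
    probs = [epsilon / n_actions] * n_actions
    probs[best_action] += 1.0 - epsilon
    return probs


def q_learning_target(
    reward: float, next_q_values: Sequence[float], discount: float, done: bool
) -> float:
    """r + discount * max_a Q(s', a); no bootstrapping from terminal states."""
    if done:
        return reward
    return reward + discount * max(next_q_values)


def expected_sarsa_target(
    reward: float,
    next_q_values: Sequence[float],
    discount: float,
    epsilon: float,
    done: bool,
) -> float:
    """r + discount * E_pi[Q(s', a)] under an epsilon-greedy policy pi."""
    if done:
        return reward
    probs = epsilon_greedy_probs(next_q_values, epsilon)
    expected_value = sum(p * q for p, q in zip(probs, next_q_values))
    return reward + discount * expected_value


def tabular_update(old_value: float, target: float, alpha: float) -> float:
    """Move the stored Q-value a fraction alpha of the way toward the target."""
    return old_value + alpha * (target - old_value)


if __name__ == "__main__":
    next_q = [1.0, 3.0]
    print(q_learning_target(1.0, next_q, discount=0.99, done=False))
    print(expected_sarsa_target(1.0, next_q, discount=0.99, epsilon=0.5, done=False))
```

With epsilon set to 0 the two targets coincide, because the epsilon-greedy policy then always picks the highest-valued action.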
import gym
import numpy as np
import torch
from rllib.qlearning import ApproximateQLearningAgent
from rllib.trainer import TrainerTorch as Trainer
from rllib.utils import set_global_seed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# init environment
env = gym.make("CartPole-v0")
set_global_seed(seed=42, env=env)
n_actions = env.action_space.n
n_state = env.observation_space.shape[0]
# init torch model
model = torch.nn.Sequential()
model.add_module("layer1", torch.nn.Linear(n_state, 128))
model.add_module("relu1", torch.nn.ReLU())
model.add_module("layer2", torch.nn.Linear(128, 64))
model.add_module("relu2", torch.nn.ReLU())
model.add_module("values", torch.nn.Linear(64, n_actions))
model = model.to(device)
# init agent
agent = ApproximateQLearningAgent(
    model=model,
    alpha=0.5,
    epsilon=0.5,
    discount=0.99,
    n_actions=n_actions,
)
# train
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
trainer = Trainer(env=env)
train_rewards = trainer.train(
    agent=agent,
    optimizer=optimizer,
    n_epochs=20,
    n_sessions=100,
)
# train results
print(f"Mean train reward: {np.mean(train_rewards[-10:])}") # reward: 120.318
# inference
inference_reward = trainer.play_session(
    agent=agent,
    t_max=10**4,
)
# inference results
print(f"Inference reward: {inference_reward}") # reward: 171.0
More examples are available in the project repository: https://github.com/dayyass/rllib
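The usage example summarizes training with the mean of the last 10 entries of train_rewards. To look at the whole learning curve rather than only its tail, a moving average can help. The sketch below assumes train_rewards is a flat sequence of numbers, which the indexing in the example suggests but this page does not state; moving_average is a hypothetical helper, not part of rllib.

```python
from typing import List, Sequence


def moving_average(values: Sequence[float], window: int) -> List[float]:
    """Return the mean of each full window of consecutive values."""
    if window <= 0:
        raise ValueError("window must be positive")
    if len(values) < window:
        return []  # not enough data for a single full window
    # Keep a running sum so each step costs O(1) instead of O(window).
    running_sum = sum(values[:window])
    averages = [running_sum / window]
    for i in range(window, len(values)):
        running_sum += values[i] - values[i - window]
        averages.append(running_sum / window)
    return averages


if __name__ == "__main__":
    # Stand-in data; replace with the train_rewards returned by trainer.train.
    rewards = [10.0, 20.0, 30.0, 40.0, 50.0]
    print(moving_average(rewards, window=2))
```

The last element of moving_average(train_rewards, 10) equals the value printed in the example, up to floating-point rounding, provided at least 10 rewards were recorded.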
Requirements
Python >= 3.7
Citation
If you use rllib in a scientific publication, we would appreciate a citation using the following BibTeX entry:
@misc{dayyass2022rllib,
    author       = {El-Ayyass, Dani},
    title        = {Reinforcement Learning Library},
    howpublished = {\url{https://github.com/dayyass/rllib}},
    year         = {2022}
}