
Reinforcement learning extensions for PyTorch

Project description

Welcome to the torchagent repository. This repository contains the sources for the torchagent library.

What is it?

torchagent is a library that implements reinforcement learning algorithms for PyTorch. You can combine it with OpenAI Gym to build reinforcement learning solutions.

Which algorithms are included?

Currently the following algorithms are implemented:

  • Deep Q-Learning (DQN)
  • Double Q-Learning
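The two algorithms differ mainly in how they compute the learning target for a transition. The sketch below uses plain PyTorch and follows the common textbook formulations, with Double Q-Learning in its Double DQN form. torchagent's own implementation may differ in details. The function names dqn_targets and double_dqn_targets, and the separate online and target networks, are assumptions made for illustration and are not part of the torchagent API.

```python
import torch


# Hypothetical helpers for illustration only; not part of torchagent.
def dqn_targets(rewards, next_states, dones, target_net, gamma=0.99):
    """Deep Q-Learning target: r + gamma * max_a' Q_target(s', a')."""
    with torch.no_grad():
        # The same network both selects and evaluates the best next action.
        next_q = target_net(next_states).max(dim=1).values
    # Terminal transitions (done == 1) get no bootstrapped future value.
    return rewards + gamma * next_q * (1.0 - dones)


def double_dqn_targets(rewards, next_states, dones, online_net, target_net, gamma=0.99):
    """Double Q-Learning target: the online net picks a', the target net scores it."""
    with torch.no_grad():
        best_actions = online_net(next_states).argmax(dim=1, keepdim=True)
        next_q = target_net(next_states).gather(1, best_actions).squeeze(1)
    return rewards + gamma * next_q * (1.0 - dones)
```

Taking the maximum over the same noisy Q-estimates that choose the action tends to overestimate values. Double Q-Learning reduces this bias by letting one network select the next action and another evaluate it.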

Installation

You can install the library using the following command:

pip install torchagent

Usage

The following example trains a basic Deep Q-Learning agent on the OpenAI Gym Atari environment Assault-v0 for 50 episodes.

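from itertools import count

import gym
import torch
import torch.nn as nn
import torch.optim as optim

from torchagent.memory import SequentialMemory
from torchagent.agents import DQNAgent


class PolicyNetwork(nn.Module):
    def __init__(self, num_actions):
        super().__init__()  # must run before any submodule is assigned
        # Maps a flattened 210x160 frame to one Q-value per action.
        self.linear = nn.Linear(210 * 160, num_actions)

    def forward(self, x):
        return self.linear(x)


env = gym.make('Assault-v0')
num_actions = env.action_space.n

policy_network = PolicyNetwork(num_actions)
memory = SequentialMemory(20)

# Assumption: the first argument of DQNAgent is the number of actions,
# so it must match the size of the network's output layer.
agent = DQNAgent(num_actions, policy_network, nn.MSELoss(),
                 optim.Adam(policy_network.parameters()), memory)

# Classic Gym API (gym < 0.26): reset() returns the observation and
# step() returns (observation, reward, done, info).
for episode in range(50):
    state = env.reset()

    for t in count():
        # Choose the action once and reuse it, so the recorded transition
        # matches the action that was actually executed.
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)

        agent.record(state, action, next_state, reward, done)
        agent.train()

        state = next_state

        if done:
            break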
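The example relies on the agent to store transitions (record) and to pick actions (act). As a rough mental model only, the sketch below shows a fixed-capacity replay memory and epsilon-greedy action selection, the two mechanisms DQN agents commonly use for these steps. ReplayMemory, Transition and epsilon_greedy are hypothetical names; this is not the implementation of torchagent's SequentialMemory or DQNAgent.act.

```python
import random
from collections import deque, namedtuple

# Hypothetical types for illustration; not part of torchagent.
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward", "done"])


class ReplayMemory:
    """Fixed-capacity buffer: once full, the oldest transition is dropped."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def record(self, state, action, next_state, reward, done):
        self.buffer.append(Transition(state, action, next_state, reward, done))

    def sample(self, batch_size):
        # Uniform random sampling without replacement breaks up the strong
        # correlation between consecutive steps of an episode.
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick a random action with probability epsilon, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

If SequentialMemory(20) does mean a capacity of 20 transitions, only the 20 most recent steps would be available for training; a larger memory generally yields more varied training batches.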


Download files

Files for torchagent, version 0.2.3:

Filename                               File type  Python version  Size
torchagent-0.2.3-py2.py3-none-any.whl  Wheel      py2.py3         8.4 kB
torchagent-0.2.3.tar.gz                Source     n/a             6.1 kB
