
Deep Q-learning with PyTorch

Project description

Welcome to the torchagent repository. This repository contains the sources for the torchagent library.

What is it?

torchagent is a library that implements various reinforcement learning algorithms for PyTorch. You can use it in combination with OpenAI Gym to build reinforcement learning solutions.

Which algorithms are included?

Currently the following algorithms are implemented:

  • Deep Q Learning
  • Double Q Learning
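The main difference between the two algorithms is how they compute the bootstrapped target for a transition. The sketch below shows the standard textbook formulation in plain Python. It is not torchagent's internal code, and the function names and the `gamma` default of 0.99 are illustrative assumptions. Deep Q-Learning takes the maximum of the target network's Q-values for the next state. Double Q-Learning instead selects the next action with the online network and evaluates that action with the target network, which reduces overestimation of action values.

```python
def dqn_target(reward, done, next_q_target, gamma=0.99):
    """DQN target: r + gamma * max_a Q_target(s', a)."""
    if done:
        return reward  # no bootstrapping from a terminal state
    return reward + gamma * max(next_q_target)


def double_dqn_target(reward, done, next_q_online, next_q_target, gamma=0.99):
    """Double DQN target: the online net picks a', the target net evaluates it."""
    if done:
        return reward
    best_action = max(range(len(next_q_online)), key=lambda a: next_q_online[a])
    return reward + gamma * next_q_target[best_action]


# Illustrative Q-values for the next state from the two networks.
q_online = [1.0, 3.0, 2.0]
q_target = [4.0, 1.5, 2.5]

print(dqn_target(1.0, False, q_target, gamma=0.5))                   # 3.0
print(double_dqn_target(1.0, False, q_online, q_target, gamma=0.5))  # 1.75
```

With these numbers the max-based DQN target is larger, because it picks whichever action the target network happens to rate highest.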

Installation

You can install the library using the following command:

pip install torchagent

Usage

The following code shows a basic agent that uses Deep Q Learning.

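from itertools import count

import gym
import torch.nn as nn
import torch.optim as optim

from torchagent.memory import SequentialMemory
from torchagent.agents import DQNAgent

# Assault-v0 requires Gym's Atari extras (pip install gym[atari]).
env = gym.make('Assault-v0')
num_actions = env.action_space.n

# Assault-v0 observations are 210 x 160 RGB frames.
INPUT_SIZE = 210 * 160 * 3

class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions):
        super().__init__()  # must run before any submodule is assigned
        self.num_inputs = num_inputs
        self.linear = nn.Linear(num_inputs, num_actions)

    def forward(self, x):
        # Flatten each frame into one vector of float pixel values
        # (assumes the agent passes states to the network as torch tensors).
        x = x.float().reshape(-1, self.num_inputs)
        return self.linear(x)

policy_network = PolicyNetwork(INPUT_SIZE, num_actions)
memory = SequentialMemory(20)

# The first argument is assumed to be the number of actions; it must match
# the number of outputs of the policy network.
agent = DQNAgent(num_actions, policy_network, nn.MSELoss(),
                 optim.Adam(policy_network.parameters()), memory)

# Uses the classic Gym API: reset() returns the observation, step() returns 4 values.
for _ in range(50):  # train for 50 episodes
    state = env.reset()

    for t in count():
        action = agent.act(state)
        # Step with the action chosen above; calling agent.act() again could pick a different one.
        next_state, reward, done, _ = env.step(action)

        agent.record(state, action, next_state, reward, done)
        agent.train()

        state = next_state

        if done:
            break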
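SequentialMemory(20) stores the transitions passed to agent.record() so the agent can train on them; 20 appears to be its capacity. The sketch below illustrates the general idea of a fixed-capacity replay memory in plain Python. ReplayMemory, Transition, append and sample are hypothetical names, not torchagent's API, and the library's class may store or sample transitions differently.

```python
import random
from collections import deque, namedtuple

# One step of experience, mirroring the arguments passed to agent.record(...).
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward", "done"])


class ReplayMemory:
    """Hypothetical fixed-capacity replay memory (not torchagent's SequentialMemory)."""

    def __init__(self, capacity):
        # A deque with maxlen silently drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def append(self, state, action, next_state, reward, done):
        self.buffer.append(Transition(state, action, next_state, reward, done))

    def sample(self, batch_size):
        # Uniform random sampling breaks the correlation between consecutive steps.
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


memory = ReplayMemory(3)
for step in range(5):
    memory.append(state=step, action=0, next_state=step + 1, reward=1.0, done=False)

print(len(memory))                       # 3
print([t.state for t in memory.buffer])  # [2, 3, 4]
batch = memory.sample(2)                 # two random transitions for one training step
```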


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Files for torchagent, version 0.2.4

  • torchagent-0.2.4.tar.gz (6.0 kB): source distribution
  • torchagent-0.2.4-py3-none-any.whl (6.6 kB): wheel for Python 3
