
Graph convolutional memory for reinforcement learning


Graph Convolutional Memory for Reinforcement Learning

Description

Graph convolutional memory (GCM) is a graph-structured memory that can be applied to reinforcement learning to solve POMDPs, as a replacement for LSTMs or attention mechanisms. GCM lets you embed your domain knowledge as edges in a knowledge graph. See the full paper for further details. This repo contains the GCM library implementation for use in your projects. To replicate the experiments from the paper, please see this repository instead.
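To make the idea of domain knowledge as edges concrete, here is a toy, pure-Python sketch of one dense graph-convolution step over a memory of three observations. The adjacency matrix plays the role of the knowledge graph: `adj[i][j] = 1` means node i aggregates the features of node j. The update follows the form x_i' = W_root x_i + W_rel * sum_j adj[i][j] x_j used by PyTorch Geometric's `DenseGraphConv` with add aggregation (bias omitted). The helper names, the identity weights, the extra "domain-knowledge" edge, and reading the belief from the newest node are illustrative assumptions, not GCM's implementation.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python product of an (n x k) matrix and a (k x m) matrix."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def dense_graph_conv(x: Matrix, adj: Matrix, w_root: Matrix, w_rel: Matrix) -> Matrix:
    """One graph-convolution step: x' = x @ w_root + (adj @ x) @ w_rel (no bias)."""
    neighbor_sum = matmul(adj, x)  # row i sums the features of the nodes that node i reads from
    root = matmul(x, w_root)
    rel = matmul(neighbor_sum, w_rel)
    return [[r + n for r, n in zip(row_r, row_n)] for row_r, row_n in zip(root, rel)]


observations = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # o_0, o_1, o_2
identity = [[1.0, 0.0], [0.0, 1.0]]

# Temporal edges (node 1 reads node 0, node 2 reads node 1) plus a
# hypothetical domain-knowledge edge (node 2 also reads node 0).
adj_with_prior = [
    [0, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
]
# The same graph with only the temporal edges.
adj_temporal_only = [
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
]

belief_with_prior = dense_graph_conv(observations, adj_with_prior, identity, identity)[-1]
belief_temporal = dense_graph_conv(observations, adj_temporal_only, identity, identity)[-1]
print(belief_with_prior)  # [2.0, 2.0]
print(belief_temporal)  # [1.0, 2.0]
```

The extra edge lets the newest node pull in information from o_0 directly, which changes the resulting belief; in GCM such priors are expressed through edge selectors rather than hand-written matrices.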

Installation

GCM is installed using pip. Its dependencies (PyTorch and torch-geometric) must be installed manually because the correct builds depend on your platform (with or without CUDA).

Conda install

First install PyTorch >= 1.8.0 and torch-geometric, then GCM:

conda install "pytorch>=1.8.0" -c pytorch
conda install pytorch-geometric -c rusty1s -c conda-forge
pip install graph-conv-memory

Pip install

Follow the torch-geometric installation guide, then run:

pip install graph-conv-memory

Quickstart

Below is a quick example of how to use GCM in a basic RL problem:

import torch
import torch_geometric
from gcm.gcm import DenseGCM
from gcm.edge_selectors.temporal import TemporalBackedge


# graph_size denotes the maximum number of observations in the graph, after which
# the oldest observations will be overwritten with newer observations. Reduce this number to
# reduce memory usage.
graph_size = 128
# Define the GNN used in GCM. The following is the one used in the paper.
# Make sure the input size of the first layer matches your observation space.
obs_size = 8
our_gnn = torch_geometric.nn.Sequential(
    "x, adj, weights, B, N",
    [
        (torch_geometric.nn.DenseGraphConv(obs_size, 32), "x, adj -> x"),
        torch.nn.Tanh(),
        (torch_geometric.nn.DenseGraphConv(32, 32), "x, adj -> x"),
        torch.nn.Tanh(),
    ],
)
# Create the GCM using our GNN and edge selection criteria. TemporalBackedge([1]) will link observation o_t to o_{t-1}.
# See `gcm.edge_selectors` for different kinds of priors suitable for your specific problem. Do not be afraid to implement your own!
gcm = DenseGCM(our_gnn, edge_selectors=TemporalBackedge([1]), graph_size=graph_size)

# Create initial state
# Shape: (batch_size, graph_size, graph_size)
edges = torch.zeros(
    (1, graph_size, graph_size), dtype=torch.float
)
# Shape: (batch_size, graph_size, obs_size)
nodes = torch.zeros((1, graph_size, obs_size))
# Shape: (batch_size, graph_size, graph_size)
weights = torch.zeros(
    (1, graph_size, graph_size), dtype=torch.float
)
# Shape: (batch_size,)
num_nodes = torch.tensor([0], dtype=torch.long)
# Our memory state (m_t in the paper)
m_t = [nodes, edges, weights, num_nodes]

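# Placeholder rollout data and actor/critic heads: replace these with your own
# environment observations and networks. The heads take 32 inputs because the
# last GNN layer above outputs 32 features.
num_timesteps = 10
num_actions = 4
obs = torch.rand((num_timesteps, 1, obs_size))
logits_nn = torch.nn.Linear(32, num_actions)
vf_nn = torch.nn.Linear(32, 1)

for t in range(num_timesteps):
    # Obs at timestep t should be of shape (batch_size, obs_size)
    belief, m_t = gcm(obs[t], m_t)
    # GCM provides a belief state: a combination of all past observational data relevant to the problem.
    # Put this state through actor and critic networks to obtain action and value estimates.
    action_logits = logits_nn(belief)
    state_value = vf_nn(belief)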
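The quickstart passes the memory state m_t = [nodes, edges, weights, num_nodes] around explicitly. The pure-Python sketch below illustrates, for a single environment, what one write into such a state might look like with a TemporalBackedge([1])-style prior: the new observation is stored in the next node slot and linked to the observation one step earlier. It is a hypothetical toy, not the gcm library's code. The class name ToyGraphMemory, the convention that adj[i][j] = 1 means node i reads from node j, the ring-buffer overwrite of the oldest slot once graph_size is reached (the quickstart only states that the oldest observations are overwritten), and the omission of the edge weights are all assumptions.

```python
from typing import List, Sequence


class ToyGraphMemory:
    """Hypothetical single-environment graph memory; NOT the gcm library's implementation."""

    def __init__(self, graph_size: int, obs_size: int, hops: Sequence[int] = (1,)):
        self.graph_size = graph_size
        self.hops = list(hops)  # e.g. [1] links o_t to o_{t-1}
        self.nodes = [[0.0] * obs_size for _ in range(graph_size)]
        self.adj = [[0] * graph_size for _ in range(graph_size)]
        self.num_nodes = 0  # total observations written so far

    def add(self, obs: List[float]) -> int:
        # Assumed ring-buffer behaviour: once full, overwrite the oldest slot.
        slot = self.num_nodes % self.graph_size
        # Drop every edge touching the slot being overwritten.
        for i in range(self.graph_size):
            self.adj[slot][i] = 0
            self.adj[i][slot] = 0
        self.nodes[slot] = list(obs)
        # Temporal backedges: the new node reads from the node written `hop` steps earlier,
        # provided that node exists and has not itself been overwritten.
        for hop in self.hops:
            if 0 < hop <= self.num_nodes and hop < self.graph_size:
                self.adj[slot][(self.num_nodes - hop) % self.graph_size] = 1
        self.num_nodes += 1
        return slot


# Write four observations into a memory that only holds three.
mem = ToyGraphMemory(graph_size=3, obs_size=1, hops=[1])
for value in [10.0, 20.0, 30.0, 40.0]:
    mem.add([value])

print(mem.nodes)  # [[40.0], [20.0], [30.0]]
print(mem.adj)  # [[0, 0, 1], [0, 0, 0], [0, 1, 0]]
```

After the fourth write, the observation 40.0 replaces the oldest one (10.0) in slot 0, the stale edge from 20.0 to 10.0 is removed, and the new node is linked back to 30.0 in slot 2.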


Download files


Source Distribution

graph-conv-memory-0.0.2.tar.gz (10.6 kB)

Built Distribution

graph_conv_memory-0.0.2-py3-none-any.whl (10.9 kB)

File details

Details for the file graph-conv-memory-0.0.2.tar.gz.

File metadata

  • Download URL: graph-conv-memory-0.0.2.tar.gz
  • Upload date:
  • Size: 10.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.55.0 CPython/3.9.1

File hashes

Hashes for graph-conv-memory-0.0.2.tar.gz
Algorithm Hash digest
SHA256 2be93654779e2158586d346af4792aa0f081ee3e4b7f444de206a96dfc93b699
MD5 2331efea3a1ec9088cd90f8557c2a551
BLAKE2b-256 14ed1d9a4ad3f61970f7a6459ca8bf7a9a680a5ae55183e7219a4402ad5be64c


File details

Details for the file graph_conv_memory-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: graph_conv_memory-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 10.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.55.0 CPython/3.9.1

File hashes

Hashes for graph_conv_memory-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 b418924083be763530123a5c09b5f226ffd3ceda2cee7636e46a15180ce1c7e0
MD5 d4fae6290dfd9a3ddee49d688192ace5
BLAKE2b-256 f7dcd59e7c9f15f4e079cd7dc63dcb8546725c192a7e28e0fe5f4bd9e5c0686e
