Graph convolutional memory for reinforcement learning

Description

Graph convolutional memory (GCM) is a graph-structured memory for reinforcement learning. It can replace LSTMs or attention mechanisms as the memory component of an agent solving partially observable Markov decision processes (POMDPs). GCM lets you embed your domain knowledge as connections in a knowledge graph. See the full paper for further details.

Installation

gcm is installed using pip. The dependencies must be installed manually, as they target your specific architecture (with or without CUDA).

Conda install

First install torch >= 1.8.0 and the torch-geometric dependencies, then gcm:

conda install pytorch
conda install pytorch-geometric -c rusty1s -c conda-forge
pip install graph-conv-memory

Pip install

Follow the torch-geometric install guide, then run:

pip install graph-conv-memory
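
Because the dependencies are installed by hand, it can help to confirm that everything is importable before running the Quickstart. The following is a minimal sketch, assuming the import names `torch`, `torch_geometric`, and `gcm` used in the Quickstart; the helper `missing_modules` is hypothetical and not part of gcm.

```python
import importlib.util


def missing_modules(names):
    """Return the module names from `names` that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Import names assumed from the Quickstart: torch, torch_geometric, gcm
    missing = missing_modules(["torch", "torch_geometric", "gcm"])
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All GCM dependencies are importable")
```

This only checks that the modules can be located. It does not verify the torch version or CUDA support.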

Quickstart

Below is a quick example of how to use GCM in a basic RL problem:

import torch
import torch_geometric
from gcm.gcm import DenseGCM
from gcm.edge_selectors.temporal import TemporalBackedge


# Define the GNN used in GCM. The following is the one used in the paper
# Make sure you define the first layer to match your observation space
obs_size = 8
our_gnn = torch_geometric.nn.Sequential(
    "x, adj, weights, B, N",
    [
        (torch_geometric.nn.DenseGraphConv(obs_size, 32), "x, adj -> x"),
        torch.nn.Tanh(),
        (torch_geometric.nn.DenseGraphConv(32, 32), "x, adj -> x"),
        torch.nn.Tanh(),
    ],
)
# graph_size denotes the maximum number of observations in the graph, after which
# the oldest observations will be overwritten
gcm = DenseGCM(our_gnn, edge_selectors=TemporalBackedge([1]), graph_size=128)

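# Create the initial (empty) memory state. The leading dimension is the
# batch size (1 here) and 128 must match graph_size above
nodes = torch.zeros((1, 128, obs_size))
edges = torch.zeros((1, 128, 128), dtype=torch.float)
weights = torch.zeros((1, 128, 128), dtype=torch.float)
# Number of observations currently stored in the graph
num_nodes = torch.tensor([0], dtype=torch.long)
# Our memory state
m_t = [nodes, edges, weights, num_nodes]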

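# `train_timestep`, `obs`, `logits`, and `vf` are placeholders for your own
# rollout loop, observations, policy head, and value head
for t in train_timestep:
    state, m_t = gcm(obs[t], m_t)
    # Use the state however you like, typically to compute
    # the action logits and the value estimate
    action_logits = logits(state)
    state_value = vf(state)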
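
The `graph_size` comment above says that once the graph is full, the oldest observations are overwritten. The snippet below is a minimal sketch of that first-in, first-out behaviour using only the standard library. It illustrates the semantics and makes no claim about how gcm implements them internally; the helper `remember` is hypothetical.

```python
from collections import deque


def remember(memory, observation):
    """Append an observation; a full deque drops its oldest entry first."""
    memory.append(observation)
    return list(memory)


graph_size = 3  # tiny capacity so the overwrite is easy to see
memory = deque(maxlen=graph_size)
for t in range(5):
    contents = remember(memory, f"obs_{t}")

print(contents)  # ['obs_2', 'obs_3', 'obs_4']
```

With a capacity of 3 and five observations, the two oldest entries are gone by the end of the loop.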

See gcm.edge_selectors for different kinds of priors suitable to your specific problem. Do not be afraid to implement your own!
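
To make the idea of a prior concrete, here is a minimal sketch of a temporal prior written as a dense adjacency matrix in plain Python. It assumes that such a prior links each observation to the one seen a fixed number of steps earlier; the edge direction and the function `temporal_adjacency` are illustrative assumptions, not the gcm.edge_selectors API.

```python
def temporal_adjacency(num_nodes, hops, graph_size):
    """Build a dense graph_size x graph_size adjacency as nested lists.

    adj[t][t - h] = 1.0 links observation t back to the one seen h steps
    earlier, for each h in hops. Slots at or beyond num_nodes stay empty.
    """
    if num_nodes > graph_size:
        raise ValueError("num_nodes cannot exceed graph_size")
    adj = [[0.0] * graph_size for _ in range(graph_size)]
    for t in range(num_nodes):
        for h in hops:
            if h > 0 and t - h >= 0:
                adj[t][t - h] = 1.0
    return adj


# Four stored observations in a graph with five slots,
# each linked to the observations one and two steps back
adj = temporal_adjacency(num_nodes=4, hops=[1, 2], graph_size=5)
for row in adj:
    print(row)
```

A different prior (for example, linking observations that are spatially close) would change only the condition that decides which entries are set to 1.0.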

Download files

Source Distribution

graph-conv-memory-0.0.1.tar.gz (9.2 kB)

Uploaded Source

Built Distribution

graph_conv_memory-0.0.1-py3-none-any.whl (10.5 kB)

Uploaded Python 3

File details

Details for the file graph-conv-memory-0.0.1.tar.gz.

File metadata

  • Download URL: graph-conv-memory-0.0.1.tar.gz
  • Upload date:
  • Size: 9.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.6.9

File hashes

Hashes for graph-conv-memory-0.0.1.tar.gz
  • SHA256: be5828b940d9b66f99d9e35c3c033a3582a1f00a9130d439a7a55029ea3fd66b
  • MD5: f519cdb2268c3cf5a16df499eb4544f8
  • BLAKE2b-256: 73cb5200a78b21e245c1cfae82e5ed7120eb6ed30227fd4cb5aeefb41d44ba6b

File details

Details for the file graph_conv_memory-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: graph_conv_memory-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 10.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.6.9

File hashes

Hashes for graph_conv_memory-0.0.1-py3-none-any.whl
  • SHA256: cc75bd75b3d7c928e5d1a2d602ce8ad0c00912a03b0d4b2330a13501fcf9ad92
  • MD5: c09e36dbd7396fddf646bf24ceeeb58b
  • BLAKE2b-256: e1f8e7cbef516145aac19a496d68dccffc5b7fdf3f37d9a87e7b6c8086b3a8a6
