Graph convolutional memory for reinforcement learning
Description
Graph convolutional memory (GCM) is a graph-structured memory that can be applied to reinforcement learning to solve partially observable Markov decision processes (POMDPs), in place of LSTMs or attention mechanisms. GCM lets you embed your domain knowledge in the form of edges (connections) in a knowledge graph. See the full paper for further details. This repository contains the GCM library implementation for use in your own projects. To replicate the experiments from the paper, please see the separate experiments repository instead.
If you use GCM, please cite the paper!
@article{morad2021graph,
  title={Graph Convolutional Memory for Deep Reinforcement Learning},
  author={Morad, Steven D and Liwicki, Stephan and Prorok, Amanda},
  journal={arXiv preprint arXiv:2106.14117},
  year={2021}
}
Installation
GCM is installed using pip. The dependencies must be installed manually, as they target your specific architecture (with or without CUDA).
Conda install
First install torch >= 1.8.0 and the torch-geometric dependencies, then gcm:
conda install pytorch -c pytorch
conda install pytorch-geometric -c rusty1s -c conda-forge
pip install graph-conv-memory
Pip install
Please follow the torch-geometric install guide, then
pip install graph-conv-memory
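Because the dependencies are installed by hand, it can help to confirm what actually ended up in your environment. The following is a minimal sketch using only the standard library's `importlib.metadata`; the helper names (`parse_version`, `meets_minimum`, `check_dependencies`) are hypothetical and not part of GCM. The version comparison is deliberately simple (numeric release components only, local suffixes such as `+cu111` ignored) and is not a full PEP 440 implementation.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text: str) -> tuple:
    """Turn '1.8.0+cu111' into (1, 8, 0); non-numeric suffixes are ignored."""
    release = text.split("+")[0]  # drop local build tags like +cu111
    parts = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if installed >= minimum, comparing numeric release components."""
    a, b = parse_version(installed), parse_version(minimum)
    # Pad with zeros so "1.8" compares equal to "1.8.0".
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_dependencies() -> dict:
    """Map each distribution GCM relies on to its installed version, or None."""
    report = {}
    for dist in ("torch", "torch-geometric", "graph-conv-memory"):
        try:
            report[dist] = version(dist)
        except PackageNotFoundError:
            report[dist] = None
    return report


if __name__ == "__main__":
    report = check_dependencies()
    for dist, ver in report.items():
        print(f"{dist}: {ver or 'NOT INSTALLED'}")
    torch_ver = report["torch"]
    if torch_ver is not None and not meets_minimum(torch_ver, "1.8.0"):
        print("torch is older than 1.8.0; GCM expects torch >= 1.8.0")
```

Comparing parsed tuples rather than raw strings matters here: as strings, "1.10.2" sorts before "1.8.0", which would wrongly flag a newer torch as too old.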
Quickstart
Below is a quick example of how to use GCM in a basic RL problem:
import torch
import torch_geometric
from gcm.gcm import DenseGCM
from gcm.edge_selectors.temporal import TemporalBackedge

# graph_size denotes the maximum number of observations in the graph, after which
# the oldest observations will be overwritten with newer observations. Reduce this
# number to reduce memory usage.
graph_size = 128

# Define the GNN used in GCM. The following is the one used in the paper.
# Make sure the first layer's input size matches your observation space.
obs_size = 8
our_gnn = torch_geometric.nn.Sequential(
    "x, adj, weights, B, N",
    [
        (torch_geometric.nn.DenseGraphConv(obs_size, 32), "x, adj -> x"),
        torch.nn.Tanh(),
        (torch_geometric.nn.DenseGraphConv(32, 32), "x, adj -> x"),
        torch.nn.Tanh(),
    ],
)

# Create the GCM using our GNN and edge selection criteria. TemporalBackedge([1]) will link observation o_t to o_{t-1}.
# See `gcm.edge_selectors` for different kinds of priors suitable for your specific problem. Do not be afraid to implement your own!
gcm = DenseGCM(our_gnn, edge_selectors=TemporalBackedge([1]), graph_size=graph_size)

# Placeholder actor and critic heads; replace them with your own networks.
# Their input size must match the GNN's output size (32 here).
num_actions = 4
logits_nn = torch.nn.Linear(32, num_actions)
vf_nn = torch.nn.Linear(32, 1)

# If the hidden state m_t is None, GCM will initialize one for you.
# Only do this at the beginning, as GCM must track and update the hidden
# state to function correctly.
m_t = None
batch_size = 1
num_timesteps = 10
for t in range(num_timesteps):
    # Obs at timestep t should be a tensor of shape (batch_size, obs_size).
    # Random placeholder; in practice use something like obs = my_env.step()
    obs = torch.randn(batch_size, obs_size)
    belief, m_t = gcm(obs, m_t)
    # GCM provides a belief state: a combination of all past observational data relevant to the problem.
    # You will likely want to pass this state through actor and critic networks to obtain
    # action and value estimates.
    action_logits = logits_nn(belief)
    state_value = vf_nn(belief)
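To build intuition for what each `gcm(obs, m_t)` call does, here is a toy, pure-Python sketch of the same idea: store each observation as a node, keep at most `graph_size` nodes (dropping the oldest), link nodes with temporal backedges, apply one graph convolution with tanh, and read out the newest node as the belief. This is an illustration under stated assumptions, not GCM's implementation: `ToyGraphMemory` is a hypothetical class, and the layer follows the sum-aggregation form commonly described for `DenseGraphConv`, x_i' = W_root x_i + W_rel * sum_j adj[i][j] x_j, with bias terms omitted for brevity.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Dense matrix-vector product w @ x."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


class ToyGraphMemory:
    """Hypothetical stand-in for DenseGCM (not the library's code).

    Keeps at most `graph_size` node features, links each node to the nodes
    `hops` steps before it (like TemporalBackedge), runs one graph-conv
    layer with tanh, and returns the newest node's output as the belief.
    """

    def __init__(self, w_root: Matrix, w_rel: Matrix, graph_size: int, hops=(1,)):
        self.w_root, self.w_rel = w_root, w_rel
        self.graph_size = graph_size
        self.hops = hops
        self.nodes: List[Vector] = []

    def step(self, obs: Vector) -> Vector:
        self.nodes.append(list(obs))
        if len(self.nodes) > self.graph_size:
            self.nodes.pop(0)  # drop the oldest observation
        n = len(self.nodes)
        # adj[i][j] = 1 means node i aggregates features from node j.
        adj = [[0.0] * n for _ in range(n)]
        for i in range(n):
            for h in self.hops:
                if i - h >= 0:
                    adj[i][i - h] = 1.0
        d = len(obs)
        out = []
        for i in range(n):
            # Sum-aggregate neighbour features, then combine with the node itself.
            agg = [sum(adj[i][j] * self.nodes[j][k] for j in range(n)) for k in range(d)]
            root = matvec(self.w_root, self.nodes[i])
            rel = matvec(self.w_rel, agg)
            out.append([math.tanh(r + a) for r, a in zip(root, rel)])
        return out[-1]  # belief = output at the newest node
```

With a single layer, only the newest node and its direct neighbours shape the belief; stacking layers, as the two-layer GNN above does, lets information flow further back along the edges.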
We provide a few edge selectors, which we briefly detail here:
gcm.edge_selectors.temporal.TemporalBackedge
# Connections to the past. Give it [1,2,4] to connect each
# observation t to t-1, t-2, and t-4.
gcm.edge_selectors.dense.DenseEdge
# Connections to all past observations
# observation t is connected to t-1, t-2, ... 0
gcm.edge_selectors.distance.EuclideanEdge
# Connections to observations within some max_distance
# e.g. if l2_norm(o_t, o_k) < max_distance, create an edge
gcm.edge_selectors.distance.CosineEdge
# Like EuclideanEdge, but using cosine similarity instead
gcm.edge_selectors.distance.SpatialEdge
# Euclidean distance, but only compares slices of the observation.
# This is useful if you have 'x' and 'y' dimensions in your observation
# and only want to connect spatially nearby entries.
#
# You can also implement the identity priors with this selector by setting
# max_distance to a very small value such as 1e-6
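The rules above can be made concrete by listing which past observations o_k a new observation o_t would link to. The sketch below is plain Python with hypothetical helper functions that mirror the prose descriptions; they are not the `gcm.edge_selectors` classes or their signatures. In particular, the cosine rule's `min_similarity` threshold is an assumption about how "cosine similarity instead" could be applied.

```python
import math
from typing import List, Sequence

Obs = Sequence[float]


def temporal_backedge(t: int, hops: Sequence[int]) -> List[int]:
    """Link o_t to o_{t-h} for each h in hops, e.g. hops=[1, 2, 4]."""
    return [t - h for h in hops if t - h >= 0]


def dense_edge(t: int) -> List[int]:
    """Link o_t to every past observation: o_{t-1}, o_{t-2}, ..., o_0."""
    return list(range(t - 1, -1, -1))


def euclidean_edge(obs: List[Obs], t: int, max_distance: float) -> List[int]:
    """Link o_t to o_k (k < t) when l2_norm(o_t - o_k) < max_distance."""
    return [k for k in range(t) if math.dist(obs[t], obs[k]) < max_distance]


def cosine_similarity(a: Obs, b: Obs) -> float:
    norm_a, norm_b = math.hypot(*a), math.hypot(*b)
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat zero vectors as dissimilar to everything
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def cosine_edge(obs: List[Obs], t: int, min_similarity: float) -> List[int]:
    """Assumed rule: link o_t to o_k when cosine similarity exceeds min_similarity."""
    return [k for k in range(t) if cosine_similarity(obs[t], obs[k]) > min_similarity]


def spatial_edge(obs: List[Obs], t: int, max_distance: float, dims: Sequence[int]) -> List[int]:
    """Euclidean rule restricted to the observation entries listed in dims (e.g. x, y)."""
    def pick(o: Obs) -> List[float]:
        return [o[i] for i in dims]
    return [k for k in range(t) if math.dist(pick(obs[t]), pick(obs[k])) < max_distance]


if __name__ == "__main__":
    # Each observation is laid out as [x, y, other_feature].
    history = [[0.0, 0.0, 5.0], [0.1, 0.0, -5.0], [3.0, 4.0, 5.0], [0.0, 0.0, 5.0]]
    t = 3
    print(temporal_backedge(t, [1, 2, 4]))             # [2, 1]
    print(dense_edge(t))                               # [2, 1, 0]
    print(euclidean_edge(history, t, 1.0))             # [0]
    print(cosine_edge(history, t, 0.99))               # [0]
    print(spatial_edge(history, t, 1.0, dims=(0, 1)))  # [0, 1]
    print(euclidean_edge(history, t, 1e-6))            # [0] (identity prior)
```

The example shows why SpatialEdge is useful: o_1 is far from o_3 in full Euclidean distance because its third feature differs, yet it sits at almost the same (x, y) position, so only the spatial rule links them. A tiny `max_distance` keeps only exact repeats, which is the identity-prior behaviour mentioned above.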