A torch implementation of the Recurrent Inference Machine
This is an implementation of a Recurrent Inference Machine (see Putzky & Welling (2017)), alongside some standard neural network architectures for the types of problems a RIM can solve.
Installation
To install the package, you can use pip:
pip install torch-rim
Usage
from torch_rim import RIM, Hourglass, Unet
from torch.func import vmap

# B is the batch size
# C is the number of input channels
# dimensions are the spatial dimensions (e.g. [28, 28] for MNIST)

# Create a score_fn, e.g. a Gaussian likelihood score function
@vmap
def score_fn(x, y, A, Sigma):  # must respect the signature (x, y, *args)
    # A is a linear forward model, Sigma is the noise covariance
    return (y - A @ x) @ Sigma.inverse() @ A

# ... or a Gaussian energy function (negative log-probability, up to a constant)
@vmap
def energy_fn(x, y, F, Sigma):
    # F is a general forward model
    return (y - F(x)) @ Sigma.inverse() @ (y - F(x))

# Create a RIM instance with the Hourglass neural network backbone and the score function
net = Hourglass(C, dimensions=len(dimensions))
rim = RIM(dimensions, net, score_fn=score_fn)

# ... or with the energy function
rim = RIM(dimensions, net, energy_fn=energy_fn)

# Train the RIM and save its weights in checkpoints_directory
rim.fit(dataset, epochs=100, learning_rate=1e-4, checkpoints_directory=checkpoints_directory)

# Make a prediction on an observation y
x_hat = rim.predict(y, A, Sigma)  # or with the signature (y, F, Sigma) when using energy_fn
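For concreteness, here is a minimal sketch of simulating a toy linear Gaussian problem compatible with the score_fn above. All names (N, D, x_true) are illustrative placeholders, not part of the package's API:

import torch

# Hypothetical toy problem: recover x from y = A x + noise.
N = 1000                          # number of training examples
D = 10                            # dimension of x and y in this toy problem
x_true = torch.randn(N, D)        # ground-truth parameters (the labels)
A = torch.randn(D, D) / D**0.5    # linear forward model, shared across examples
Sigma = 0.1 * torch.eye(D)        # noise covariance
noise = torch.randn(N, D) @ torch.linalg.cholesky(Sigma).T
y = x_true @ A.T + noise          # simulated observations

These tensors provide the labels and observations from which a training dataset can be assembled.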
Background
A RIM is a gradient-based meta-learning algorithm. It is trained not as a feed-forward neural network, but rather as an optimisation algorithm. More specifically, the RIM is given a problem instance specified by a likelihood score function \(\nabla_{\mathbf{x}} \log p(\mathbf{y} \mid \mathbf{x})\), or more generally a posterior score function \(\nabla_{\mathbf{x} \mid \mathbf{y}} \equiv \nabla_{\mathbf{x}} \log p(\mathbf{x} \mid \mathbf{y})\), and an observation \(\mathbf{y}\) on which to condition that posterior.
The RIM uses this information to perform a learned gradient ascent algorithm on the posterior. Once trained, this procedure produces a maximum a posteriori (MAP) estimate of the parameters of interest \(\mathbf{x}\):
\[\mathbf{x}_{t+1} = \mathbf{x}_t + \mathbf{g}_\theta(\mathbf{x}_t, \mathbf{y}, \nabla_{\mathbf{x}_t \mid \mathbf{y}}, \mathbf{h}_t)\]
\[\mathbf{h}_{t+1} = \mathrm{GRU}_\theta(\mathbf{x}_t, \mathbf{y}, \nabla_{\mathbf{x}_t \mid \mathbf{y}}, \mathbf{h}_t)\]
for \(t \in [0, T]\).
In the first equation, \(\mathbf{g}_\theta\) is a neural network that acts as the gradient in the gradient ascent algorithm. The second equation is a hidden-state update: much like modern optimisation algorithms such as Adam (Kingma & Ba (2014)) or RMSProp (Hinton (2011)), the RIM uses a hidden state to aggregate information about the trajectory of the particle during optimisation. Here, \(\mathbf{h}\) is the hidden state of a Gated Recurrent Unit (Chung et al. (2014)).
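To make the update rule concrete, here is a minimal, self-contained sketch of the inner loop in plain PyTorch. It only illustrates the equations above and is not the package's internal implementation; D, hidden_dim, T, and the score function are placeholders, and \(\mathbf{g}_\theta\) is simplified to read only the hidden state:

import torch
from torch import nn

D, hidden_dim, T = 16, 32, 10
gru = nn.GRUCell(2 * D, hidden_dim)   # aggregates information along the trajectory
g = nn.Linear(hidden_dim, D)          # maps the hidden state to an update step

def score(x):
    # Placeholder posterior score: gradient of a standard normal log-density
    return -x

x = torch.zeros(1, D)                 # initial guess x_0
h = torch.zeros(1, hidden_dim)        # initial hidden state h_0
for t in range(T):
    h = gru(torch.cat([x, score(x)], dim=1), h)  # h_{t+1} = GRU(x_t, score, h_t)
    x = x + g(h)                                 # x_{t+1} = x_t + g_theta(h_{t+1})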
The RIM is trained using an outer-loop optimisation constructed from labels \(\mathbf{x}\) (the parameters of interest) and a simulator \(F\):
\[\mathbf{y} = F(\mathbf{x}) + \boldsymbol{\eta}, \qquad \boldsymbol{\eta} \sim \mathcal{N}(0, \boldsymbol{\Sigma}).\]
Equipped with a dataset of problem instances and their solutions \(\mathbf{x}\), we can train the RIM to make its gradient ascent trajectories as efficient as possible by minimizing a weighted mean squared loss
\[\mathcal{L}_\theta = \mathbb{E}_{(\mathbf{x}, \mathbf{y})} \left[ \sum_{t=1}^{T} \mathbf{w}_t \lVert \mathbf{x}_t - \mathbf{x} \rVert_2^2 \right],\]
where \(\mathbf{w}_t\) weighs each MSE term along the parameter trajectory.
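A rough sketch of this objective for a single example, assuming uniform weights \(\mathbf{w}_t = 1\) and a list trajectory holding the estimates \(\mathbf{x}_t\) collected during the inner loop (both names are hypothetical, not the package's API):

# Weighted MSE over the trajectory of estimates; weights defaults to w_t = 1.
def rim_loss(trajectory, x_true, weights=None):
    weights = weights if weights is not None else [1.0] * len(trajectory)
    return sum(w * ((x_t - x_true) ** 2).mean() for w, x_t in zip(weights, trajectory))

Backpropagating through this loss trains the network to take useful gradient steps at every iteration, not just at the final one.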