A torch implementation of the Recurrent Inference Machine
Project description
This is an implementation of the Recurrent Inference Machine (RIM; see Putzky & Welling, 2017), together with some standard neural network architectures (e.g. Hourglass and Unet) suited to the inverse problems a RIM is designed to solve.
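A quick sketch of the core idea may help. A RIM starts from an initial guess x_0 and refines it with a learned recurrent update, x_{t+1} = x_t + g_phi(grad_x log p(y | x_t), x_t, h_t), where h_t is the network's hidden state. The code below is a minimal, hypothetical PyTorch version of that recurrence on flat vectors. The class `TinyRIM`, the use of a GRU cell plus a linear read-out for g_phi, and all sizes are assumptions for illustration. This is not how `torch_rim` implements it.

```python
import torch
from torch import nn
from torch.func import vmap


class TinyRIM(nn.Module):
    """Hypothetical minimal RIM on flat vectors: x_{t+1} = x_t + g(x_t, score_t, h_t)."""

    def __init__(self, dim: int, hidden: int = 32, steps: int = 5):
        super().__init__()
        self.dim, self.hidden, self.steps = dim, hidden, steps
        # Recurrent cell that sees the current estimate and the likelihood score
        self.cell = nn.GRUCell(2 * dim, hidden)
        # Linear read-out from the hidden state to the update Delta x
        self.out = nn.Linear(hidden, dim)

    def forward(self, y, score_fn, *args):
        batch = y.shape[0]
        x = y.new_zeros((batch, self.dim))     # initial estimate x_0 = 0
        h = y.new_zeros((batch, self.hidden))  # initial hidden state h_0 = 0
        estimates = []
        for _ in range(self.steps):
            g = score_fn(x, y, *args)          # grad_x log p(y | x_t)
            h = self.cell(torch.cat([x, g], dim=-1), h)
            x = x + self.out(h)                # x_{t+1} = x_t + Delta x_t
            estimates.append(x)
        # Return every iterate so a training loss can be summed over time steps
        return torch.stack(estimates, dim=1)


@vmap
def score_fn(x, y, A, Sigma):
    # Same Gaussian score as in the Usage section
    return (y - A @ x) @ Sigma.inverse() @ A


torch.manual_seed(0)
B, n, m = 4, 8, 6
A = torch.randn(B, m, n)
Sigma = torch.eye(m).repeat(B, 1, 1)        # unit noise covariance per sample
x_true = torch.randn(B, n)
y = (A @ x_true.unsqueeze(-1)).squeeze(-1)  # noiseless observations y = A x

rim = TinyRIM(dim=n)
xs = rim(y, score_fn, A, Sigma)
print(xs.shape)  # torch.Size([4, 5, 8])
```

All iterates are returned because the original paper trains with a loss accumulated over the time steps rather than only on the final estimate.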
Installation
To install the package, you can use pip:
pip install torch-rim
Usage
from torch_rim import RIM, Hourglass, Unet
from torch.func import vmap
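# B is the batch size
# C is the number of input channels
# dimensions is the list of spatial dimensions (e.g. [28, 28] for MNIST)
C = 1  # example value: MNIST images have a single channel
dimensions = [28, 28]  # example value for MNIST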
# Create a score_fn, e.g. a Gaussian likelihood score function
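# vmap maps over the leading (batch) dimension of every argument,
# so x, y, A and Sigma all need a batch dimension B
@vmap
def score_fn(x, y, A, Sigma):  # must respect the signature (x, y, *args)
    # A is a linear forward model, Sigma is the noise covariance;
    # for symmetric Sigma this returns grad_x log p(y | x) = A^T Sigma^{-1} (y - A x)
    return (y - A @ x) @ Sigma.inverse() @ A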
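# ... or a Gaussian energy function (the chi-squared term, equal to -2 log p(y | x) up to a constant)
def energy_fn(x, y, F, Sigma):
    # F is a general forward model (a Python callable)
    r = y - F(x)
    return r @ Sigma.inverse() @ r
# vmap cannot map over a non-tensor argument such as F, so F gets in_dims=None
energy_fn = vmap(energy_fn, in_dims=(0, 0, None, 0))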
# Create a RIM instance with the Hourglass neural network back-bone and the score function
net = Hourglass(C, dimensions=len(dimensions))
rim = RIM(dimensions, net, score_fn=score_fn)
# ... or with the energy function
rim = RIM(dimensions, net, energy_fn=energy_fn)
# Train the RIM and save its weights in checkpoints_directory
rim.fit(dataset, epochs=100, learning_rate=1e-4, checkpoints_directory=checkpoints_directory)
# Make a prediction on an observation y
x_hat = rim.predict(y, A, Sigma)  # or with the signature (y, F, Sigma) when using energy_fn
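The two likelihood terms in the usage example are related. For symmetric Sigma, `score_fn` returns A^T Sigma^{-1} (y - A x), which is the gradient with respect to x of log p(y | x) = -1/2 (y - A x)^T Sigma^{-1} (y - A x) + const. The quadratic form in `energy_fn` equals -2 log p(y | x) up to a constant, so its gradient is -2 times the score. The sketch below checks both identities numerically with `torch.func.grad` and `vmap` on random, hypothetical data, taking F(x) = A x. How `RIM` itself turns `energy_fn` into a gradient (sign or scaling) is not documented here, and this sketch does not assume it.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)
B, n, m = 3, 5, 4

# Random batched linear model and symmetric positive-definite noise covariance
A = torch.randn(B, m, n, dtype=torch.float64)
M = torch.randn(B, m, m, dtype=torch.float64)
Sigma = M @ M.transpose(-1, -2) + m * torch.eye(m, dtype=torch.float64)
x = torch.randn(B, n, dtype=torch.float64)
y = torch.randn(B, m, dtype=torch.float64)


@vmap
def score_fn(x, y, A, Sigma):
    # Closed form from the usage example: A^T Sigma^{-1} (y - A x)
    return (y - A @ x) @ Sigma.inverse() @ A


def log_likelihood(x, y, A, Sigma):
    # Gaussian log-likelihood up to an additive constant
    r = y - A @ x
    return -0.5 * r @ Sigma.inverse() @ r


def energy(x, y, A, Sigma):
    # Quadratic form used by energy_fn, with the linear model F(x) = A x
    r = y - A @ x
    return r @ Sigma.inverse() @ r


# grad differentiates w.r.t. the first argument (x); vmap batches over B
autograd_score = vmap(grad(log_likelihood))(x, y, A, Sigma)
energy_grad = vmap(grad(energy))(x, y, A, Sigma)

print(torch.allclose(score_fn(x, y, A, Sigma), autograd_score))  # True
print(torch.allclose(energy_grad, -2 * autograd_score))           # True
```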
Download files
Source Distributions
No source distribution files are available for this release.
Built Distribution
torch_rim-0.2.2-py3-none-any.whl (14.4 kB)
Hashes for torch_rim-0.2.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 3e7e6c33c5842bffc3d58f1d52acc09597474a1ddac24331ef0da7020bca1afb
MD5 | 07e82df150cfba7943d88c1712a8a3db
BLAKE2b-256 | 4499ddeb592f8981a553e60bcd56569e3d99b4e1aa3911a693bafaf9eff17e14
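You can use these digests to check that a downloaded wheel is intact. One option in Python is the standard-library `hashlib` module. The sketch below assumes the wheel was saved under its original file name in the current working directory. The helper name `sha256_of` is hypothetical.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    # Hash the file in chunks so large files need not fit in memory
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


EXPECTED = "3e7e6c33c5842bffc3d58f1d52acc09597474a1ddac24331ef0da7020bca1afb"
wheel = Path("torch_rim-0.2.2-py3-none-any.whl")  # assumed local download path
if wheel.exists():
    print("OK" if sha256_of(wheel) == EXPECTED else "hash mismatch")
```

pip can also enforce digests at install time through its hash-checking mode, using `--hash=sha256:...` entries in a requirements file together with `--require-hashes`.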