A torch implementation of the Recurrent Inference Machine
Project description
This is an implementation of a Recurrent Inference Machine (RIM; see Putzky & Welling (2017)), together with some standard neural network architectures (Hourglass, Unet) suited to the kind of inverse problem a RIM is designed to solve.
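The central idea of a RIM, as described by Putzky & Welling (2017), is iterative refinement. At each step a recurrent network receives the current estimate x_t and the likelihood gradient ∇ log p(y | x_t), and proposes an additive update x_{t+1} = x_t + g(x_t, ∇ log p(y | x_t), h_t). The sketch below illustrates that loop on flat vectors. It is not the torch_rim `RIM` class: the names `ToyRIM` and `rim_loss`, the GRU-cell updater, the zero initial estimate and the step-averaged squared-error loss are all illustrative assumptions.

```python
import torch
from torch import nn


class ToyRIM(nn.Module):
    """Hypothetical minimal RIM: a GRU cell turns (estimate, likelihood gradient) into updates."""

    def __init__(self, dim: int, hidden: int = 32, steps: int = 6):
        super().__init__()
        self.dim, self.hidden, self.steps = dim, hidden, steps
        self.cell = nn.GRUCell(2 * dim, hidden)   # input: [x_t, grad_t] concatenated
        self.to_update = nn.Linear(hidden, dim)   # hidden state -> additive update

    def forward(self, y, score_fn):
        batch = y.shape[0]
        x = torch.zeros(batch, self.dim)          # assumed initial estimate x_0 = 0
        h = torch.zeros(batch, self.hidden)       # recurrent memory h_0
        estimates = []
        for _ in range(self.steps):
            g = score_fn(x, y)                    # gradient of log p(y | x) at x_t
            h = self.cell(torch.cat([x, g], dim=-1), h)
            x = x + self.to_update(h)             # x_{t+1} = x_t + update
            estimates.append(x)
        return estimates


def rim_loss(estimates, x_true):
    # Average the squared error over every intermediate estimate, so that
    # each step of the recurrence is pushed towards the true solution
    return torch.stack([((x - x_true) ** 2).mean() for x in estimates]).mean()
```

For a linear Gaussian problem with identity noise covariance, the batched score is `(y - x @ A.T) @ A`, which can be passed as `score_fn`; training then amounts to minimizing `rim_loss` over pairs (x_true, y) with any optimizer. Feeding the gradient in at every step, rather than learning a direct mapping from y to x, is what lets the network exploit the known forward model.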
Installation
To install the package, you can use pip:
pip install torch-rim
Usage
from functools import partial

from torch_rim import RIM, Hourglass, Unet
from torch.func import vmap

# B is the batch size
# C is the number of input channels
# dimensions are the spatial dimensions (e.g. [28, 28] for MNIST)

# Create a score_fn, e.g. a Gaussian likelihood score function
@vmap
def score_fn(x, y, A, Sigma):  # must respect the signature (x, y, *args)
    # A is a linear forward model, Sigma is the noise covariance
    return (y - A @ x) @ Sigma.inverse() @ A

# ... or a Gaussian energy function (unnormalized log probability).
# F is a Python callable, not a tensor, so vmap must not map over it (in_dims=None)
@partial(vmap, in_dims=(0, 0, None, 0))
def energy_fn(x, y, F, Sigma):
    # F is a general forward model
    r = y - F(x)
    return r @ Sigma.inverse() @ r

# Create a RIM instance with the Hourglass neural network backbone and the score function
net = Hourglass(C, dimensions=len(dimensions))
rim = RIM(dimensions, net, score_fn=score_fn)
# ... or with the energy function
rim = RIM(dimensions, net, energy_fn=energy_fn)

# Train the RIM and save its weights in checkpoints_directory
rim.fit(dataset, epochs=100, learning_rate=1e-4, checkpoints_directory=checkpoints_directory)

# Make a prediction on an observation y
x_hat = rim.predict(y, A, Sigma)  # or with the signature (y, F, Sigma) when using energy_fn
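For a Gaussian likelihood with a linear forward model F(x) = A x, the two options above are closely related. The energy as written is r^T Σ^{-1} r with r = y - A x; its gradient with respect to x is -2 A^T Σ^{-1} r, so for a symmetric covariance Σ the score expression equals -½ times the energy gradient. The sketch below checks this numerically with `torch.func.grad` and `torch.func.vmap`. It only exercises the two formulas from the usage example (with F replaced by a matrix A) and assumes nothing about how `RIM` consumes them internally; the names `batched_score`, `batched_energy_grad` and `shared_operator_score` are hypothetical.

```python
import torch
from torch.func import grad, vmap


def score_fn(x, y, A, Sigma):
    # Per-example score from the usage example: A^T Sigma^{-1} (y - A x) for symmetric Sigma
    return (y - A @ x) @ Sigma.inverse() @ A


def energy_fn(x, y, A, Sigma):
    # Usage example's energy_fn with the linear forward model F(x) = A @ x
    r = y - A @ x
    return r @ Sigma.inverse() @ r


# grad differentiates with respect to the first argument (x);
# vmap then maps every argument over its leading (batch) dimension
batched_score = vmap(score_fn)
batched_energy_grad = vmap(grad(energy_fn))

# in_dims=None shares an argument across the batch instead of mapping over it
shared_operator_score = vmap(score_fn, in_dims=(0, 0, None, None))
```

With the default `in_dims=0`, every argument (including A and Sigma) must carry a leading batch dimension; passing `in_dims=None` for an argument lets a single forward model or covariance be shared by the whole batch without copying it.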
Download files
Source Distribution
torch_rim-0.2.3.tar.gz (17.3 kB)
Built Distribution
torch_rim-0.2.3-py3-none-any.whl (14.4 kB)
Hashes for torch_rim-0.2.3-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 427b96d8ca51f8d0a78fc30b9a8c626ef6563c242c51aea45a6e0dd7b3746739
MD5 | 70efb5a415470ac3be33f7f68ed6dc0b
BLAKE2b-256 | c03e31a46eac24dfe4081ddcd11169617d812c5367c180489c7f2bb28be52654
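The published digests let you confirm that a downloaded wheel is intact. A minimal sketch using only the standard library's `hashlib` follows; the helper names `sha256_of` and `matches_published_hash` are hypothetical, and the default expected value is the SHA256 digest listed above for the wheel.

```python
import hashlib

# SHA256 digest published for torch_rim-0.2.3-py3-none-any.whl
EXPECTED_SHA256 = "427b96d8ca51f8d0a78fc30b9a8c626ef6563c242c51aea45a6e0dd7b3746739"


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 equals the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

Assuming the wheel was downloaded to the current directory, `matches_published_hash("torch_rim-0.2.3-py3-none-any.whl")` should return `True`. pip can also enforce this check itself in hash-checking mode, via `--hash=sha256:...` entries in a requirements file.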