A torch implementation of the Recurrent Inference Machine

This package implements the Recurrent Inference Machine (RIM; see Putzky & Welling (2017)), alongside some standard neural network architectures for the types of problems a RIM can solve.

Installation

To install the package, you can use pip:

pip install torch-rim

Usage

from rim import RIM, Hourglass, Unet
from torch.func import vmap

# B is the batch size
# C is the input channels
# dimensions are the spatial dimensions (e.g. [28, 28] for MNIST)

# Create a score_fn, e.g. a Gaussian likelihood score function
@vmap
def score_fn(x, y, A, Sigma): # must respect the signature (x, y, *args)
    # A is a linear forward model, Sigma is the noise covariance
    return (y - A @ x) @ Sigma.inverse() @ A

# ... or a Gaussian energy function (unnormalized log probability)
@vmap
def energy_fn(x, y, F, Sigma):
    # F is a general forward model
    return (y - F(x)) @ Sigma.inverse() @ (y - F(x))

# Create a RIM instance with the Hourglass neural network back-bone and the score function
net = Hourglass(C, dimensions=len(dimensions))
rim = RIM(dimensions, net, score_fn=score_fn)

# ... or with the energy function
rim = RIM(dimensions, net, energy_fn=energy_fn)

# Train the RIM and save its weights in checkpoints_directory
rim.fit(dataset, epochs=100, learning_rate=1e-4, checkpoints_directory=checkpoints_directory)

# Make a prediction on an observation y
x_hat = rim.predict(y, A, Sigma) # or with the signature (y, F, Sigma) when using the energy_fn

Background

A RIM is a gradient-based meta-learning algorithm. It is trained not as a feed-forward neural network, but as an optimisation algorithm. More specifically, the RIM is given a problem instance specified by a likelihood score \(\nabla_{\mathbf{x}} \log p(\mathbf{y} \mid \mathbf{x})\), or more generally a posterior score function \(\nabla_{\mathbf{x} \mid \mathbf{y}} \equiv \nabla_{\mathbf{x}} \log p(\mathbf{x} \mid \mathbf{y})\), and an observation \(\mathbf{y}\) on which the posterior is conditioned.
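As a sanity check on the likelihood score, the sketch below compares the closed-form Gaussian score \(A^\top \Sigma^{-1}(\mathbf{y} - A\mathbf{x})\) with the gradient of the log-likelihood obtained by automatic differentiation. The helper names (`log_likelihood`, `analytic_score`) and the random test problem are illustrative assumptions, not part of the package.

```python
import torch
from torch.func import grad

def log_likelihood(x, y, A, Sigma):
    """Gaussian log-likelihood log p(y | x), up to an additive constant."""
    r = y - A @ x
    return -0.5 * r @ torch.linalg.solve(Sigma, r)

def analytic_score(x, y, A, Sigma):
    """Closed-form score A^T Sigma^{-1} (y - A x), valid for symmetric Sigma."""
    return A.T @ torch.linalg.solve(Sigma, y - A @ x)

# Hypothetical small problem: n parameters, m observations
torch.manual_seed(0)
n, m = 4, 3
A = torch.randn(m, n, dtype=torch.float64)
L = torch.randn(m, m, dtype=torch.float64)
Sigma = L @ L.T + m * torch.eye(m, dtype=torch.float64)  # symmetric positive definite
x = torch.randn(n, dtype=torch.float64)
y = A @ x + torch.randn(m, dtype=torch.float64)

# Differentiate the log-likelihood with respect to its first argument, x
autograd_score = grad(log_likelihood)(x, y, A, Sigma)
```

The two scores should agree to numerical precision, which is one way to test a hand-written `score_fn` before passing it to a RIM.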

The RIM uses this information to perform a learned gradient ascent algorithm on the posterior. Once the RIM is trained, this procedure produces a MAP estimate of the parameters of interest \(\mathbf{x}\).

\begin{align*} \hat{\mathbf{x}}_{t+1} &= \hat{\mathbf{x}}_t + \mathbf{g}_\theta (\hat{\mathbf{x}}_t,\, \mathbf{y},\, \nabla_{\hat{\mathbf{x}}_t \mid \mathbf{y}},\, \mathbf{h}_t)\\ \mathbf{h}_{t+1} &= \mathbf{g}_\theta(\hat{\mathbf{x}}_t,\, \mathbf{y},\, \nabla_{\hat{\mathbf{x}}_t \mid \mathbf{y}},\, \mathbf{h}_t) \end{align*}

for \(t \in \{0, \dots, T-1\}\).

In the equations above, \(\mathbf{g}_\theta\) is a neural network that acts as the gradient in the gradient ascent algorithm. The second equation updates a hidden state, much like modern optimisation algorithms such as ADAM (Kingma & Ba (2014)) or RMSProp (Hinton (2011)), which use a hidden state to aggregate information about the trajectory of the particle during optimisation. In this case, \(\mathbf{h}\) is the hidden state of a Gated Recurrent Unit (Chung et al. (2014)).
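The recurrence can be sketched with a toy network. This is not the implementation used by this package: `ToyRIMCell` and `unroll` are hypothetical names, the GRU cell plus linear read-out is an assumed minimal stand-in for \(\mathbf{g}_\theta\), and for simplicity the observation \(\mathbf{y}\) enters only through the score.

```python
import torch
from torch import nn

class ToyRIMCell(nn.Module):
    """Hypothetical stand-in for g_theta: a GRU cell followed by a linear read-out."""
    def __init__(self, n, hidden=16):
        super().__init__()
        self.gru = nn.GRUCell(input_size=2 * n, hidden_size=hidden)
        self.readout = nn.Linear(hidden, n)

    def forward(self, x, score, h):
        # Update the hidden state from the current estimate and its score,
        # then read out the step to apply to the estimate.
        h_next = self.gru(torch.cat([x, score], dim=-1), h)
        return self.readout(h_next), h_next

def unroll(cell, score_fn, y, x0, h0, T):
    """Apply T learned update steps and return the trajectory, shape (T, B, n)."""
    x, h, trajectory = x0, h0, []
    for _ in range(T):
        dx, h = cell(x, score_fn(x, y), h)
        x = x + dx  # learned gradient ascent step
        trajectory.append(x)
    return torch.stack(trajectory)

torch.manual_seed(0)
B, n, T = 2, 3, 5
y = torch.randn(B, n)
# Score of p(y | x) for an identity forward model with unit noise variance
score_fn = lambda x, y: y - x
cell = ToyRIMCell(n)
traj = unroll(cell, score_fn, y, torch.zeros(B, n), torch.zeros(B, 16), T)
```

With untrained weights the trajectory is meaningless; the point is only the data flow between the estimate, the score, and the hidden state.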

The RIM is trained using an outer-loop optimisation constructed with labels \(\mathbf{x}\) (the parameters of interest) and a simulator \(F\):

\begin{equation*} \mathbf{y} = F(\mathbf{x}) + \boldsymbol{\eta} \end{equation*}
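For a linear simulator \(F(\mathbf{x}) = A\mathbf{x}\) with isotropic Gaussian noise, generating training observations could look like the sketch below. The function name `simulate`, the noise level, and the problem sizes are assumptions chosen for illustration.

```python
import torch

def simulate(x, A, noise_std, generator=None):
    """Draw y = A x + eta for a batch x of shape (N, n), with eta ~ N(0, noise_std^2 I)."""
    eta = noise_std * torch.randn(x.shape[0], A.shape[0], generator=generator)
    return x @ A.T + eta

g = torch.Generator().manual_seed(0)
N, n, m = 8, 5, 3
A = torch.randn(m, n, generator=g)  # hypothetical linear forward model
x = torch.randn(N, n, generator=g)  # ground-truth parameters (the labels)
y = simulate(x, A, noise_std=0.1, generator=g)
```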

Equipped with a dataset of problem instances and their solutions \(\mathbf{x}\)

\begin{equation*} \mathcal{D} = \big\{\mathbf{y}^{(i)},\, \mathbf{x}^{(i)},\, \nabla_{\mathbf{x} \mid \mathbf{y}}^{(i)}\big\}_{i=1}^N \end{equation*}

we can train the RIM to make its gradient ascent trajectories as efficient as possible by minimizing a weighted mean squared error loss

\begin{equation*} \mathcal{L}_\theta = \mathbb{E}_\mathcal{D} \left[ \sum_{t=1}^T\lVert \mathbf{w}_t(\hat{\mathbf{x}}_t - \mathbf{x})\rVert^2_2 \right] \end{equation*}

where \(\mathbf{w}_t\) weights the squared error at step \(t\) of the parameter trajectory.
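A minimal sketch of this loss for a batch of trajectories is given below. It assumes one scalar weight per time step (the package may use a different weighting scheme), and `weighted_trajectory_loss` is a hypothetical helper name.

```python
import torch

def weighted_trajectory_loss(trajectory, x_true, w):
    """Weighted squared error summed over steps, averaged over the batch.

    trajectory: (T, B, n) estimates x_hat_1 ... x_hat_T
    x_true:     (B, n) ground-truth parameters
    w:          (T,) one scalar weight per step (an assumption of this sketch)
    """
    residual = w.view(-1, 1, 1) * (trajectory - x_true)
    per_example = residual.pow(2).sum(dim=(0, 2))  # sum over steps and coordinates
    return per_example.mean()                      # expectation over the batch

# Tiny worked example: T=2 steps, B=1 example, n=2 parameters
trajectory = torch.tensor([[[1.0, 1.0]], [[2.0, 2.0]]])
x_true = torch.zeros(1, 2)
w = torch.tensor([1.0, 0.5])
loss = weighted_trajectory_loss(trajectory, x_true, w)
```

Here the first step contributes \(1^2 + 1^2 = 2\) and the second contributes \((0.5 \cdot 2)^2 + (0.5 \cdot 2)^2 = 2\), so the loss is 4.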

Citations

If you find this work useful, please consider citing:

@article{Putzky2017,
    author = {{Putzky}, Patrick and {Welling}, Max},
    journal = {arXiv e-prints},
    archivePrefix = {arXiv},
    eprint = {1706.04008},
    primaryClass = {Neural and Evolutionary Computing},
    title = "{Recurrent Inference Machines for Solving Inverse Problems}",
    url = {http://arxiv.org/abs/1706.04008},
    year = 2017
}
@article{Adam2023,
     author = {{Adam}, Alexandre and {Perreault-Levasseur}, Laurence and {Hezaveh}, Yashar and {Welling}, Max},
     title = "{Pixelated Reconstruction of Foreground Density and Background Surface Brightness in Gravitational Lensing Systems Using Recurrent Inference Machines}",
     keywords = {Convolutional neural networks, Astronomical simulations, Nonparametric inference},
     journal = {\apj},
     year = 2023,
     month = jul,
     volume = {951},
     number = {1},
     eid = {6},
     pages = {6},
     doi = {10.3847/1538-4357/accf84},
     archivePrefix = {arXiv},
     eprint = {2301.04168},
     primaryClass = {astro-ph.IM},
     adsurl = {https://ui.adsabs.harvard.edu/abs/2023ApJ...951....6A},
     adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
@article{Morningstar2019,
     author = {{Morningstar}, Warren R. and {Perreault Levasseur}, Laurence and {Hezaveh}, Yashar D. and {Blandford}, Roger and {Marshall}, Phil and {Putzky}, Patrick and {Rueter}, Thomas D. and {Wechsler}, Risa and {Welling}, Max},
     title = "{Data-driven Reconstruction of Gravitationally Lensed Galaxies Using Recurrent Inference Machines}",
     journal = {\apj},
     keywords = {Convolutional neural networks, Strong gravitational lensing, Neural networks, Astrophysics - Instrumentation and Methods for Astrophysics, Astrophysics - Cosmology and Nongalactic Astrophysics, Astrophysics - Astrophysics of Galaxies},
     year = 2019,
     month = sep,
     volume = {883},
     number = {1},
     eid = {14},
     pages = {14},
     doi = {10.3847/1538-4357/ab35d7},
     archivePrefix = {arXiv},
     eprint = {1901.01359},
     primaryClass = {astro-ph.IM},
     adsurl = {https://ui.adsabs.harvard.edu/abs/2019ApJ...883...14M},
     adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
@article{Morningstar2018,
     author = {{Morningstar}, Warren R. and {Hezaveh}, Yashar D. and {Perreault Levasseur}, Laurence and {Blandford}, Roger D. and {Marshall}, Philip J. and {Putzky}, Patrick and {Wechsler}, Risa H.},
     title = "{Analyzing interferometric observations of strong gravitational lenses with recurrent and convolutional neural networks}",
     journal = {arXiv e-prints},
     keywords = {Astrophysics - Instrumentation and Methods for Astrophysics, Astrophysics - Cosmology and Nongalactic Astrophysics},
     year = 2018,
     month = jul,
     eid = {arXiv:1808.00011},
     pages = {arXiv:1808.00011},
     doi = {10.48550/arXiv.1808.00011},
     archivePrefix = {arXiv},
     eprint = {1808.00011},
     primaryClass = {astro-ph.IM},
     adsurl = {https://ui.adsabs.harvard.edu/abs/2018arXiv180800011M},
     adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
@article{Lonning2019,
    author = {L{\o}nning, Kai and Putzky, Patrick and Sonke, Jan Jakob and Reneman, Liesbeth and Caan, Matthan W.A. and Welling, Max},
    doi = {10.1016/j.media.2019.01.005},
    issn = {13618423},
    journal = {Medical Image Analysis},
    keywords = {Deep learning,Inverse problems,MRI,Reconstruction},
    month = {apr},
    pages = {64--78},
    pmid = {30703579},
    publisher = {Elsevier},
    title = {{Recurrent inference machines for reconstructing heterogeneous MRI data}},
    volume = {53},
    year = {2019}
}
@article{Modi2021,
     author = {{Modi}, Chirag and {Lanusse}, Fran{\c{c}}ois and {Seljak}, Uro{\v{s}} and {Spergel}, David N. and {Perreault-Levasseur}, Laurence},
     title = "{CosmicRIM : Reconstructing Early Universe by Combining Differentiable Simulations with Recurrent Inference Machines}",
     journal = {arXiv e-prints},
     archivePrefix = {arXiv},
     eprint = {2104.12864},
     primaryClass = {astro-ph.CO},
     keywords = {Astrophysics - Cosmology and Nongalactic Astrophysics},
     year = 2021,
     month = apr,
     eid = {arXiv:2104.12864},
     pages = {arXiv:2104.12864},
     adsurl = {https://ui.adsabs.harvard.edu/abs/2021arXiv210412864M},
     adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}

License

This package is licensed under the MIT License.
