Probabilistic programming framework for signal inference algorithms that operate regardless of the underlying grids and their resolutions

NIFTy - Numerical Information Field Theory

NIFTy project homepage: ift.pages.mpcdf.de/nifty | Found a bug? github.com/nifty-ppl/nifty/issues | Need help? github.com/nifty-ppl/nifty/discussions

Summary

Description

NIFTy, "Numerical Information Field Theory", is a Bayesian imaging library. It is designed to infer the million- to billion-dimensional posterior distribution in the image space from noisy input data. At the core of NIFTy lies a set of powerful Gaussian Process (GP) models and accurate Variational Inference (VI) algorithms.

Gaussian Processes

One standard tool from the NIFTy toolbox is the set of structured GP models. These models usually rely on the harmonic domain being easily accessible. For example, for pixels spaced on a regular Cartesian grid, the natural choice to represent a stationary kernel is the Fourier domain. The following example initializes a non-parametric GP prior for a $128 \times 128$ space with unit volume.

from nifty8 import re as jft

dims = (128, 128)
cfm = jft.CorrelatedFieldMaker("cf")
cfm.set_amplitude_total_offset(offset_mean=2, offset_std=(1e-1, 3e-2))
cfm.add_fluctuations(  # Axis over which the kernel is defined
  dims,
  distances=tuple(1.0 / d for d in dims),
  fluctuations=(1.0, 5e-1),
  loglogavgslope=(-3.0, 2e-1),
  flexibility=(1e0, 2e-1),
  asperity=(5e-1, 5e-2),
  prefix="ax1",
  non_parametric_kind="power",
)
correlated_field = cfm.finalize()  # forward model for a GP prior
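The plain-JAX sketch below shows why the harmonic domain matters. It draws one sample of a stationary GP on a regular grid by scaling white noise mode by mode in Fourier space. It is not NIFTy's implementation. The correlated field above learns a non-parametric power spectrum, whereas this sketch assumes a fixed power law P(|k|) ∝ |k|^slope. The hypothetical `slope=-3` loosely mirrors the prior mean of `loglogavgslope`. The function name `sample_stationary_field` is also made up for illustration.

```python
import jax.numpy as jnp
from jax import random


def sample_stationary_field(key, shape, distances, slope=-3.0):
    """Draw one zero-mean stationary GP realisation on a regular grid.

    Hypothetical sketch: the covariance is assumed diagonal in the Fourier
    basis with a fixed power spectrum P(|k|) proportional to |k|**slope.
    """
    # |k| for every Fourier mode of the grid
    freqs = [jnp.fft.fftfreq(n, d=dx) for n, dx in zip(shape, distances)]
    k_axes = jnp.meshgrid(*freqs, indexing="ij")
    k_abs = jnp.sqrt(sum(k**2 for k in k_axes))
    # Amplitude sqrt(P(k)); the k = 0 mode is set to zero to avoid the
    # divergence of the power law, which also fixes the field mean to zero
    safe_k = jnp.where(k_abs > 0, k_abs, 1.0)
    amplitude = jnp.where(k_abs > 0, safe_k ** (slope / 2.0), 0.0)
    # Correlate white noise: FFT -> scale each mode -> inverse FFT
    xi = random.normal(key, shape)
    return jnp.fft.ifftn(amplitude * jnp.fft.fftn(xi)).real


dims = (128, 128)
field = sample_stationary_field(
    random.PRNGKey(0), dims, distances=tuple(1.0 / d for d in dims)
)
print(field.shape)  # (128, 128)
```

The covariance is diagonal in Fourier space, so a draw costs two FFTs, i.e. O(N log N) for N pixels. A dense Cholesky factorization of the same covariance would cost O(N^3).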

Not all problems are well described by regularly spaced pixels. For more complicated pixel spacings, NIFTy features Iterative Charted Refinement, a GP model for arbitrarily deformed spaces. This model exploits nearest-neighbor relations on various coarsenings of the discretized modeled space and runs very efficiently on GPUs. For one-dimensional problems with arbitrarily spaced pixels, NIFTy also implements multiple flavors of Gauss-Markov processes.
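The simplest Gauss-Markov process on arbitrarily spaced points is the Wiener process (integrated white noise). The plain-JAX sketch below maps standard-normal excitations to such a process. It scales each increment by the square root of its gap width. This only illustrates the idea. The function `wiener_process` and its parameters are hypothetical and not part of NIFTy's API.

```python
import jax
import jax.numpy as jnp
from jax import random


def wiener_process(xi, x, sigma=1.0, f0=0.0):
    """Hypothetical sketch: map standard-normal excitations to a Wiener process.

    x: sorted, arbitrarily spaced 1D positions of length n.
    xi: standard-normal excitations of length n - 1, one per gap.
    Increments are independent (hence Markovian) and the increment over a gap
    of width dx has variance sigma**2 * dx, so any spacing is handled exactly.
    """
    dx = jnp.diff(x)
    increments = sigma * jnp.sqrt(dx) * xi
    return jnp.concatenate([jnp.array([f0]), f0 + jnp.cumsum(increments)])


# Irregular, made-up positions
x = jnp.array([0.0, 0.1, 0.15, 0.7, 1.3, 2.0])
xi_batch = random.normal(random.PRNGKey(1), (20_000, x.size - 1))
paths = jax.vmap(lambda xi: wiener_process(xi, x))(xi_batch)
# Var[f(x_n) - f(x_0)] should be close to sigma**2 * (x_n - x_0) = 2.0
print(paths[:, -1].var())
```

Each increment depends only on its own gap, so irregular spacing adds no cost. The model is a cumulative sum, linear in the number of points.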

Building Up Complex Models

Models are rarely just a GP prior. Commonly, a model contains at least several non-linearities that transform the GP prior or combine it with other random variables. For building more complex models, NIFTy provides a Model class that offers a somewhat familiar object-oriented design yet is fully JAX compatible and functional under the hood. The following code builds a slightly more involved model from the objects of the previous example.

from jax import numpy as jnp


class Forward(jft.Model):
  def __init__(self, correlated_field):
    self._cf = correlated_field
    # Track a method with which to draw a random input for the model. This is
    # not strictly required but is usually handy when building deep models.
    super().__init__(init=correlated_field.init)

  def __call__(self, x):
    # NOTE: any kind of masking of the output, as well as any non-linear or
    # linear transformation, could be carried out here. Models can also be
    # combined and nested in any way and form.
    return jnp.exp(self._cf(x))


forward = Forward(correlated_field)

data = jnp.load("data.npy")
lh = jft.Poissonian(data).amend(forward)

All GP models in NIFTy, as well as all likelihoods, are models. Their attributes are exposed to JAX, meaning JAX understands what it means if a computation involves self or other models. In other words, correlated_field, forward, and lh from the code snippets shown here are all so-called pytrees in JAX. For example, the following is valid code: jax.jit(lambda l, x: l(x))(lh, x0), with x0 some arbitrarily chosen valid input to lh. Inspired by equinox, individual attributes of the class can be marked as non-static or static for the purpose of compiling via dataclasses.field(metadata=dict(static=...)). Depending on the value, JAX will either treat the attribute as an unknown placeholder or as a known concrete value, and may inline it during compilation. Likelihoods use this mechanism extensively to avoid inlining large constants such as the data and to avoid expensive re-compiles whenever possible.
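The effect of static versus non-static attributes on compilation can be reproduced in plain JAX. The sketch below makes no assumptions about NIFTy's internals. It hand-registers a hypothetical `ToyGaussianLikelihood` dataclass as a pytree. Fields without `static=True` metadata become traced leaves. Static fields go into the pytree's auxiliary data, which JAX compares when looking up a cached compilation. A Python counter records how often the jitted function is traced.

```python
from dataclasses import dataclass, field, fields

import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
@dataclass
class ToyGaussianLikelihood:
    """Hypothetical toy likelihood, not NIFTy's implementation."""

    data: jax.Array  # dynamic: traced, never inlined into the compiled code
    noise_std: float = field(metadata=dict(static=True))  # static: baked in

    def __call__(self, x):
        return 0.5 * jnp.sum(((self.data - x) / self.noise_std) ** 2)

    def tree_flatten(self):
        # Split attributes by the `static` flag in their field metadata
        dyn = tuple(f.name for f in fields(self) if not f.metadata.get("static"))
        sta = tuple(f.name for f in fields(self) if f.metadata.get("static"))
        children = tuple(getattr(self, n) for n in dyn)
        aux = (dyn, sta, tuple(getattr(self, n) for n in sta))
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        dyn, sta, sta_values = aux
        return cls(**dict(zip(dyn, children)), **dict(zip(sta, sta_values)))


n_traces = 0


@jax.jit
def evaluate(lh, x):
    global n_traces
    n_traces += 1  # Python side effect: only runs while JAX traces
    return lh(x)


x = jnp.zeros(3)
lh1 = ToyGaussianLikelihood(jnp.array([1.0, 2.0, 3.0]), noise_std=0.5)
lh2 = ToyGaussianLikelihood(jnp.array([4.0, 5.0, 6.0]), noise_std=0.5)
lh3 = ToyGaussianLikelihood(jnp.array([4.0, 5.0, 6.0]), noise_std=2.0)
print(evaluate(lh1, x), n_traces)  # first call traces
print(evaluate(lh2, x), n_traces)  # new data, same static value: cached
print(evaluate(lh3, x), n_traces)  # new static value: traces again
```

Swapping in new data reuses the compiled function, while changing the static `noise_std` triggers a fresh trace. This trade-off is why large data arrays are kept non-static.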

Variational Inference

NIFTy is built for models with millions to billions of degrees of freedom. To probe the posterior efficiently and accurately, NIFTy relies on VI. At the core of the VI methods lies an alternating procedure. It switches between optimizing the Kullback–Leibler divergence for a fixed shape of the variational posterior and updating that shape.
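The alternating scheme can be tried on a toy problem with a known answer. The plain-JAX sketch below follows the spirit of Metric Gaussian Variational Inference (MGVI) but is not NIFTy's implementation. It assumes a hypothetical linear model with Gaussian noise and a standard-normal prior. Each iteration has two steps:

1. Shape update: draw antithetically mirrored residuals from the inverse curvature at the current mean.
2. KL update: minimize the sample-averaged negative log joint over the mean with one Newton step.

The covariance is held fixed during the KL update, so the entropy term does not depend on the mean. That sample average is therefore the KL up to a constant. For a linear model the curvature is constant, so the mean should land on the exact posterior mean.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical linear-Gaussian toy problem: d = R @ xi + noise, xi ~ N(0, I)
n_data, n_params, sigma, n_pairs = 20, 5, 0.1, 3
k_resp, k_true, k_noise, key = random.split(random.PRNGKey(0), 4)
R = random.normal(k_resp, (n_data, n_params))
d = R @ random.normal(k_true, (n_params,)) + sigma * random.normal(k_noise, (n_data,))


def hamiltonian(xi):
    """Negative log joint -log p(d, xi), up to a constant."""
    return 0.5 * jnp.sum((d - R @ xi) ** 2) / sigma**2 + 0.5 * jnp.sum(xi**2)


def metric(xi):
    # Curvature of the Hamiltonian; for this linear model the Hessian equals
    # the Fisher metric plus identity that MGVI uses, and it is constant
    return jax.hessian(hamiltonian)(xi)


def draw_residuals(key, mean):
    """Shape update: residuals ~ N(0, metric(mean)^-1), antithetically mirrored."""
    chol = jnp.linalg.cholesky(jnp.linalg.inv(metric(mean)))
    delta = random.normal(key, (n_pairs, mean.size)) @ chol.T
    return jnp.concatenate([delta, -delta])


def kl_estimate(mean, residuals):
    """Sample average of the Hamiltonian, i.e. the KL up to mean-independent terms."""
    return jnp.mean(jax.vmap(lambda r: hamiltonian(mean + r))(residuals))


kl_grad = jax.jit(jax.grad(kl_estimate))
mean = jnp.zeros(n_params)
for i in range(3):
    key, subkey = random.split(key)
    residuals = draw_residuals(subkey, mean)  # update the posterior shape
    # Optimize the KL for fixed residuals; one Newton step is exact here
    mean = mean - jnp.linalg.solve(metric(mean), kl_grad(mean, residuals))

# Closed-form posterior mean of the linear-Gaussian model for comparison
exact = jnp.linalg.solve(R.T @ R / sigma**2 + jnp.eye(n_params), R.T @ d / sigma**2)
print(jnp.max(jnp.abs(mean - exact)))
```

The mirrored residuals cancel exactly in the gradient of a quadratic Hamiltonian, so the toy mean does not depend on sampling noise. For non-linear models this no longer holds, and the shape update starts to matter.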

A typical minimization with NIFTy is shown in the following. It retrieves six independent, antithetically mirrored samples from the approximate posterior via 25 iterations of alternating between optimization and sample adaptation. The final result is stored in the samples variable. A convenient one-shot wrapper for the loop below is jft.optimize_kl. Because all modeling tools in NIFTy are written in JAX, NIFTy tools can also be combined with blackjax or any other posterior sampler in the JAX ecosystem.

from jax import random

key = random.PRNGKey(42)
key, sk = random.split(key, 2)
# NIFTy is agnostic w.r.t. the type of input it gets as long as it supports core
# arithmetic properties. Tell NIFTy to treat our parameter dictionary as a
# vector.
samples = jft.Samples(pos=jft.Vector(lh.init(sk)), samples=None, keys=None)

delta = 1e-4
absdelta = delta * jft.size(samples.pos)

opt_vi = jft.OptimizeVI(lh, n_total_iterations=25)
opt_vi_st = opt_vi.init_state(
  key,
  # Typically on the order of 2-12
  n_samples=lambda i: 1 if i < 2 else (2 if i < 4 else 6),
  # Arguments for the conjugate gradient method used to draw samples from
  # an implicit covariance matrix
  draw_linear_kwargs=dict(
    cg_name="SL", cg_kwargs=dict(absdelta=absdelta / 10.0, maxiter=100)
  ),
  # Arguments for the minimizer in the nonlinear updating of the samples
  nonlinearly_update_kwargs=dict(
    minimize_kwargs=dict(
      name="SN", xtol=delta, cg_kwargs=dict(name=None), maxiter=5
    )
  ),
  # Arguments for the minimizer of the KL-divergence cost potential
  kl_kwargs=dict(minimize_kwargs=dict(name="M", xtol=delta, maxiter=35)),
  sample_mode=lambda i: "nonlinear_resample" if i < 3 else "nonlinear_update",
)
for i in range(opt_vi.n_total_iterations):
  print(f"Iteration {i+1:04d}")
  # Continuously updates the samples of the approximate posterior distribution
  samples, opt_vi_st = opt_vi.update(samples, opt_vi_st)
  print(opt_vi.get_status_message(samples, opt_vi_st))
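The iteration-dependent settings above can be inspected by evaluating the two schedules in plain Python. This sketch assumes that NIFTy passes a zero-based iteration index to `n_samples` and `sample_mode`. That matches the loop above but is not stated in the text. The schedules are rewritten as named functions with the same logic as the lambdas.

```python
def n_samples(i):
    # Same schedule as the lambda passed to init_state above
    return 1 if i < 2 else (2 if i < 4 else 6)


def sample_mode(i):
    # Resample from scratch early on, then only update the samples nonlinearly
    return "nonlinear_resample" if i < 3 else "nonlinear_update"


schedule = [(i, n_samples(i), sample_mode(i)) for i in range(25)]
for it, n, mode in schedule[:5]:
    print(f"iteration {it}: n_samples={n}, sample_mode={mode}")
```

From the fifth iteration on, the run settles on six samples with nonlinear updates. This matches the six samples mentioned in the description of the minimization.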

Installation

If you only want to use NIFTy in your projects, but not change its source code, the easiest way to install NIFTy is via pip:

pip install --user 'nifty8[re]'

The line above installs the optional JAX backend, NIFTy.re, in addition to the NumPy-based NIFTy.

If you want to adapt the NIFTy source code, we suggest installing NIFTy as an editable Python package using the following commands:

git clone -b NIFTy_8 https://gitlab.mpcdf.mpg.de/ift/nifty.git
cd nifty
pip install --user --editable '.[re]'

First Steps

For a quick start, you can browse through the informal introduction or dive into NIFTy by running one of the demonstrations, e.g.:

python demos/0_intro.py

Contributing

Contributions are very welcome! Feel free to reach out early in the development process, e.g. by opening a draft PR or filing an issue. We are happy to help with the development and provide feedback along the way. Please open an issue first if you think your PR changes current code substantially. Please format your code according to the existing style used in the file, or with black for new files. If your PR affects the public API, please update the public documentation and the ChangeLog to advertise your changes. Please add appropriate tests to your PR.

Building the Documentation

NIFTy's documentation is generated via Sphinx and is available online at ift.pages.mpcdf.de/nifty.

To build the documentation locally, run:

sudo apt-get install dvipng jupyter-nbconvert texlive-latex-base texlive-latex-extra
pip install --user sphinx jupytext pydata-sphinx-theme myst-parser
cd <nifty_directory>
bash docs/generate.sh

To view the documentation, open docs/build/index.html in your browser.

Note: Make sure that you reinstall NIFTy after each change, since Sphinx imports NIFTy from the Python path.

Run the tests

To run the tests, install all optional requirements 'nifty8[all]' and afterwards run pytest (and create a coverage report) via

pytest --cov=nifty8 test

If you are writing your own tests, it is often sufficient to install just the optional test dependencies 'nifty8[test]'. However, running the full test suite, including tests of optional functionality, assumes that all optional dependencies are installed.

Licensing terms

Most of NIFTy is licensed under the terms of the GPLv3 license with NIFTy.re being a notable exception. NIFTy.re is licensed under GPL-2.0+ OR BSD-2-Clause. All of NIFTy is distributed without any warranty.

Citing NIFTy

To cite the probabilistic programming framework NIFTy, please use the citation provided below. In addition to citing NIFTy itself, please consider crediting the Gaussian process models you used and the inference machinery. See the corresponding entry on citing NIFTy in the documentation for further details.

@article{niftyre,
  title     = {Re-Envisioning Numerical Information Field Theory (NIFTy.re): A Library for Gaussian Processes and Variational Inference},
  author    = {Gordian Edenhofer and Philipp Frank and Jakob Roth and Reimar H. Leike and Massin Guerdi and Lukas I. Scheel-Platz and Matteo Guardiani and Vincent Eberle and Margret Westerkamp and Torsten A. Enßlin},
  year      = {2024},
  journal   = {Journal of Open Source Software},
  publisher = {The Open Journal},
  volume    = {9},
  number    = {98},
  pages     = {6593},
  doi       = {10.21105/joss.06593},
  url       = {https://doi.org/10.21105/joss.06593},
}


Download files

Source Distribution

nifty8-8.5.2.tar.gz (1.2 MB)

Built Distribution

nifty8-8.5.2-py3-none-any.whl (386.9 kB, Python 3)
