Diffusion meets sampling
fusions
Diffusion meets (nested) sampling
A minimal implementation of generative diffusion models in JAX (Flax), tuned for building emulators of scientific models, particularly where MCMC sampling is tractable and already in use.
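The example below uses a class named `CFM`, which we assume stands for conditional flow matching, a close relative of diffusion models. The following is a generic sketch of that technique, not the fusions implementation or API. It uses NumPy in place of JAX/Flax, and all function names (`cfm_pair`, `cfm_loss`, `gaussian_velocity`, `sample`) are hypothetical. A velocity model is regressed onto `x1 - x0` along straight lines between base samples `x0 ~ N(0, I)` and data samples `x1`. New samples are then generated by integrating the learned velocity field from t = 0 to t = 1. To avoid training a network, the sketch uses a Gaussian target, for which the optimal velocity is known in closed form.

```python
import numpy as np

# Hypothetical helpers illustrating generic conditional flow matching (CFM);
# this is NOT the fusions API, and NumPy stands in for JAX/Flax.

def cfm_pair(x0, x1, t):
    """Point on the straight-line path from x0 to x1, and its velocity."""
    tt = t[:, None]
    xt = (1.0 - tt) * x0 + tt * x1
    ut = x1 - x0  # regression target for the velocity model
    return xt, ut


def cfm_loss(velocity, x0, x1, t):
    """Monte Carlo estimate of E ||v(x_t, t) - (x1 - x0)||^2."""
    xt, ut = cfm_pair(x0, x1, t)
    return np.mean(np.sum((velocity(xt, t) - ut) ** 2, axis=-1))


def gaussian_velocity(m, s):
    """Exact loss minimiser when x0 ~ N(0, I) and x1 ~ N(m, s^2 I) are independent."""
    def v(x, t):
        tt = np.reshape(t, (-1, 1))
        var = (1.0 - tt) ** 2 + (tt * s) ** 2  # variance of x_t per dimension
        return m + (tt * s**2 - (1.0 - tt)) / var * (x - tt * m)
    return v


def sample(velocity, x0, n_steps=200):
    """Generate samples by Euler-integrating dx/dt = v(x, t) from t = 0 to 1."""
    x = x0.copy()
    dt = 1.0 / n_steps
    for i in range(n_steps):
        x = x + dt * velocity(x, np.full(len(x), i * dt))
    return x


rng = np.random.default_rng(0)
m, s = np.array([2.0, -1.0]), 0.5
v_exact = gaussian_velocity(m, s)

# Training signal: the exact velocity scores better than a trivial zero model
x0 = rng.standard_normal((5000, 2))
x1 = m + s * rng.standard_normal((5000, 2))
t = rng.uniform(size=5000)
loss_exact = cfm_loss(v_exact, x0, x1, t)
loss_zero = cfm_loss(lambda x, tau: np.zeros_like(x), x0, x1, t)

# Sampling: push fresh base samples through the (here exact) velocity field
samples = sample(v_exact, rng.standard_normal((5000, 2)))
print(samples.mean(axis=0), samples.std(axis=0))  # approximately m and s
```

In a real emulator, `v_exact` would be replaced by a neural network trained by minimising `cfm_loss` on samples such as MCMC output. The straight-line path keeps the regression target simple.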
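from fusions.cfm import CFM
from lsbi.model import LinearMixtureModel
from anesthetic import MCMCSamples
import matplotlib.pyplot as plt
import numpy as np

# Toy problem: a 5-dimensional linear mixture model with two components,
# whose model matrices are +I and -I, and data covariance 0.1 * I
dims = 5
Model = LinearMixtureModel(
    M=np.stack([np.eye(dims), -np.eye(dims)]),
    C=np.eye(dims)*0.1,
)
# Draw one simulated dataset from the evidence (the marginal distribution of the data)
data = Model.evidence().rvs(1)

# CFM (conditional flow matching) generative model, constructed from the prior
diffusion = CFM(Model.prior())
# diffusion = CFM(dims)  # alternative: construct from the dimensionality alone
# Train on 1000 samples drawn from the exact posterior
diffusion.train(Model.posterior(data).rvs(1000))

# Overlay exact posterior samples and samples generated by the trained model
a = MCMCSamples(Model.posterior(data).rvs(500)).plot_2d(np.arange(dims))
MCMCSamples(diffusion.rvs(500)).plot_2d(a)
plt.show()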
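The posterior that the model is trained on in the example has a closed form, which shows why it is a useful test case. This sketch assumes lsbi's defaults of a standard-normal prior (theta ~ N(0, I)), zero offset m = 0 and equal mixture weights; check the lsbi documentation before relying on it. Under those assumptions, both components give the data the same evidence N(D; 0, 1.1 I). The posterior is therefore an equal-weight mixture of two mirror-image Gaussian modes at ±(10/11) D, each with covariance I/11. The functions below (`toy_posterior`, `sample_posterior`) are hypothetical NumPy helpers, not part of lsbi or fusions.

```python
import numpy as np

# Hypothetical NumPy sketch of the example's target posterior, ASSUMING lsbi's
# defaults: prior theta ~ N(0, I), offset m = 0 and equal mixture weights.

def toy_posterior(data, c=0.1, signs=(1.0, -1.0)):
    """Posterior of D = M_k theta + noise with M_k = sign_k * I and noise ~ N(0, c I)."""
    data = np.asarray(data, dtype=float)
    precision = 1.0 + 1.0 / c  # Sigma^-1 + M^T C^-1 M, using M^T M = I
    cov = 1.0 / precision  # isotropic posterior variance
    means = np.array([sgn * data / (c * precision) for sgn in signs])
    # The evidence N(D; 0, (1 + c) I) is identical for both signs,
    # so the two posterior modes carry equal weight
    weights = np.full(len(signs), 1.0 / len(signs))
    return weights, means, cov


def sample_posterior(weights, means, cov, n, rng):
    """Draw n samples from the Gaussian mixture posterior."""
    k = rng.choice(len(weights), size=n, p=weights)
    return means[k] + np.sqrt(cov) * rng.standard_normal((n, means.shape[1]))


rng = np.random.default_rng(1)
dims = 5
data = rng.normal(0.0, np.sqrt(1.1), size=dims)  # stand-in for the simulated data
weights, means, cov = toy_posterior(data)
samples = sample_posterior(weights, means, cov, 1000, rng)
```

A generative emulator that captures only one of the two modes would be easy to spot in the overlaid corner plot.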
Download files
Source distribution: fusions-0.4.0.tar.gz (10.5 kB)
Built distribution: fusions-0.4.0-py3-none-any.whl (12.5 kB)