A research-oriented package for diffusion-based generative modeling with modular components
Diffuse
A JAX-based Python package for research on diffusion-based generative modeling, built from modular components (models, predictors, integrators, timers, denoisers) that can be swapped and combined for experimentation.
Quick Start
Unconditional Generation
import jax
from diffuse import Flow, Predictor, Denoiser
from diffuse.integrators import EulerIntegrator
from diffuse.timer import VpTimer
# Define flow matching model
flow = Flow(tf=1.0)
# Create a predictor that wraps your neural network (network_fn is user-supplied)
predictor = Predictor(
model=flow,
network=network_fn,
prediction_type="velocity", # or "noise", "sample"
)
# Setup timer and integrator
timer = VpTimer(n_steps=100, eps=0.001, tf=1.0)
integrator = EulerIntegrator(model=flow, timer=timer)
# Create the denoiser; x0_shape is the shape of one sample (height, width, channels are placeholders)
denoiser = Denoiser(
integrator=integrator,
model=flow,
predictor=predictor,
x0_shape=(height, width, channels),
)
# Generate samples
key = jax.random.PRNGKey(42)
state, trajectory = denoiser.generate(
rng_key=key,
n_steps=100,
n_particles=10,
keep_history=False,
)
# Single denoising step (x_t -> x_{t-1}), using a fresh subkey
key, step_key = jax.random.split(key)
next_state = denoiser.step(step_key, state)
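To make the Euler integration step concrete, here is a minimal pure-JAX sketch that does not use Diffuse's classes. It assumes the linear flow-matching path x_t = (1 - t) x0 + t eps, with pure noise at t = tf and data near t = eps, so the network predicts the velocity v = eps - x0 and sampling integrates backward in time. The function `euler_flow_sample` and the toy point-mass velocity field are hypothetical, chosen only for illustration; they are not part of the library's API.

```python
import jax
import jax.numpy as jnp


def euler_flow_sample(velocity_fn, x_init, tf=1.0, eps=1e-3, n_steps=100):
    """Integrate dx/dt = v(x, t) from t = tf down to t = eps with explicit Euler steps."""
    ts = jnp.linspace(tf, eps, n_steps + 1)  # uniform, decreasing time grid

    def step(x, times):
        t_cur, t_next = times
        dt = t_next - t_cur  # negative, because time runs backward
        x_next = x + dt * velocity_fn(x, t_cur)
        return x_next, x_next

    x_final, trajectory = jax.lax.scan(step, x_init, (ts[:-1], ts[1:]))
    return x_final, trajectory


# Toy target: every sample equals c, so the exact velocity is v(x, t) = (x - c) / t.
c = jnp.array([2.0, -1.0])
velocity_fn = lambda x, t: (x - c) / t

key = jax.random.PRNGKey(42)
x1 = jax.random.normal(key, (10, 2))  # 10 particles of pure noise at t = tf
x0, trajectory = euler_flow_sample(velocity_fn, x1)
print(x0.shape, trajectory.shape)  # (10, 2) (100, 10, 2)
```

For this point-mass target the Euler update happens to be exact, so every particle ends at c + (eps / tf)(x1 - c). With a trained network, `velocity_fn` would wrap the network call; `jax.lax.scan` keeps the loop JIT-friendly and returns the whole trajectory.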
Conditional Generation with DPS
import jax
from diffuse import Flow, Predictor
from diffuse.integrators import EulerIntegrator
from diffuse.denoisers import DPSDenoiser
from diffuse.timer import VpTimer
# Define flow matching model
flow = Flow(tf=1.0)
# Create predictor
predictor = Predictor(
model=flow,
network=network_fn,
prediction_type="velocity",
)
# Setup timer and integrator
timer = VpTimer(n_steps=100, eps=0.001, tf=1.0)
integrator = EulerIntegrator(model=flow, timer=timer)
# DPS denoiser for conditional sampling (forward_model is your user-supplied measurement operator)
dps = DPSDenoiser(
integrator=integrator,
model=flow,
predictor=predictor,
forward_model=forward_model,
x0_shape=(height, width, channels),
)
# Generate conditional samples given the observed measurement (measurement_state is user-supplied)
key = jax.random.PRNGKey(42)
state, trajectory = dps.generate(
rng_key=key,
measurement_state=measurement_state,
n_steps=100,
n_particles=10,
)
# Single conditional step (x_t -> x_{t-1}), using a fresh subkey
key, step_key = jax.random.split(key)
next_state = dps.step(step_key, state, measurement_state)
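The idea behind DPS can be sketched in a few lines of plain JAX: take an unconditional step, then nudge the sample down the gradient of the measurement residual ||y - A(x0_hat(x_t))||^2, where x0_hat is the model's estimate of the clean sample. This is not Diffuse's `DPSDenoiser`. It assumes the linear path x_t = (1 - t) x0 + t eps (so x0_hat = x_t - t v), a differentiable operator `forward_fn` standing in for `forward_model`, and a fixed step size `zeta`; all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def dps_guidance(x_t, t, velocity_fn, forward_fn, y):
    """Gradient w.r.t. x_t of the squared measurement residual ||y - A(x0_hat)||^2."""
    def residual(x):
        # Clean-sample estimate for the path x_t = (1 - t) x0 + t eps, given v = eps - x0
        x0_hat = x - t * velocity_fn(x, t)
        return jnp.sum((y - forward_fn(x0_hat)) ** 2)

    return jax.grad(residual)(x_t)


def dps_euler_step(x_t, t, dt, velocity_fn, forward_fn, y, zeta=1.0):
    """Unconditional Euler step followed by a DPS-style residual-gradient correction."""
    x_uncond = x_t + dt * velocity_fn(x_t, t)
    return x_uncond - zeta * dps_guidance(x_t, t, velocity_fn, forward_fn, y)


# Toy check: zero velocity (so x0_hat = x_t) and an identity forward model.
zero_velocity = lambda x, t: jnp.zeros_like(x)
identity = lambda x: x
x_t = jnp.array([1.0, 2.0])
y = jnp.array([0.0, 0.0])

grad = dps_guidance(x_t, 0.5, zero_velocity, identity, y)  # equals 2 * (x_t - y)
x_next = dps_euler_step(x_t, 0.5, -0.01, zero_velocity, identity, y, zeta=0.5)
print(grad, x_next)
```

With a trained network the gradient flows back through `velocity_fn`, which is why DPS relies on automatic differentiation through the denoiser; practical implementations usually tune or adapt `zeta` per step.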
Features
- Flow Matching & Diffusion: Support for both flow-based models and SDE-based diffusion processes
- Flexible Prediction Types: Velocity, noise, and sample prediction for different model architectures
- Timer-aware Integration: Configurable time-discretization schedules (VpTimer, HeunTimer, FluxTimer) that control where sampling steps are placed
- Multiple Integrators: EulerIntegrator, DDIMIntegrator, DPM++, Heun, and more
- Conditional Sampling: DPS (Diffusion Posterior Sampling) for inverse problems and conditioning
- Modular Design: Mix and match models, predictors, denoisers, integrators, and timers
- JAX-Powered: Efficient computation with automatic differentiation and JIT compilation
- Research-Focused: Built for experimentation with new diffusion and flow matching techniques
- Examples: MNIST, Gaussian mixtures, text-to-image generation, and more
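Once a noising path is fixed, the three prediction types carry the same information and can be converted into one another. The sketch below assumes the linear path x_t = (1 - t) x0 + t eps, for which the velocity target is v = eps - x0; other paths (for example variance-preserving SDEs) need different formulas. `to_x0_eps_v` is a hypothetical helper, not Diffuse's `Predictor`.

```python
import jax.numpy as jnp


def to_x0_eps_v(pred, x_t, t, prediction_type):
    """Recover (x0, eps, v) from a network output, assuming x_t = (1 - t) x0 + t eps."""
    if prediction_type == "velocity":
        v = pred
        x0 = x_t - t * v
        eps = x_t + (1.0 - t) * v
    elif prediction_type == "noise":
        eps = pred
        x0 = (x_t - t * eps) / (1.0 - t)  # ill-conditioned as t -> 1
        v = eps - x0
    elif prediction_type == "sample":
        x0 = pred
        eps = (x_t - (1.0 - t) * x0) / t  # ill-conditioned as t -> 0
        v = eps - x0
    else:
        raise ValueError(f"unknown prediction_type: {prediction_type!r}")
    return x0, eps, v


# Round trip: build x_t from known (x0, eps), then recover them from each prediction type.
x0 = jnp.array([1.0, -2.0])
eps = jnp.array([0.5, 0.3])
t = 0.3
x_t = (1.0 - t) * x0 + t * eps
v = eps - x0

for kind, pred in [("velocity", v), ("noise", eps), ("sample", x0)]:
    rx0, reps, rv = to_x0_eps_v(pred, x_t, t, kind)
    print(kind, jnp.allclose(rx0, x0), jnp.allclose(reps, eps), jnp.allclose(rv, v))
```

The divisions by 1 - t and by t show why the choice matters numerically: noise prediction degrades near pure noise and sample prediction near clean data, while velocity prediction needs no division on this path.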
Installation
pip install diffuse-jax
Or, with uv:
uv add diffuse-jax
Examples
See the examples/ directory for implementations including:
- MNIST digit generation
- Gaussian mixture modeling
- Conditional sampling demonstrations
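Gaussian mixtures make a convenient test case because the ideal velocity field is available in closed form, so a sampler can be checked without training a network. The sketch below computes it for an isotropic mixture under the linear path x_t = (1 - t) x0 + t eps with eps ~ N(0, I). It is an illustration under that assumption, not the code in the examples/ directory, and `gmm_velocity` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def gmm_velocity(x_t, t, means, stds, weights):
    """Exact velocity E[eps - x0 | x_t] for x0 ~ sum_k weights[k] N(means[k], stds[k]^2 I).

    Shapes: x_t (d,), means (K, d), stds (K,), weights (K,).
    """
    d = x_t.shape[-1]
    var_t = (1.0 - t) ** 2 * stds**2 + t**2  # (K,) variance of x_t within each component
    diff = x_t - (1.0 - t) * means  # (K, d)
    # log w_k + log N(x_t; (1 - t) m_k, var_t I), dropping a constant shared by all k
    log_resp = (jnp.log(weights)
                - 0.5 * jnp.sum(diff**2, axis=-1) / var_t
                - 0.5 * d * jnp.log(var_t))
    resp = jnp.exp(log_resp - logsumexp(log_resp))  # posterior component probabilities
    # Gaussian conditioning: E[x0 | x_t, k] for each component
    x0_k = means + ((1.0 - t) * stds**2 / var_t)[:, None] * diff
    x0_hat = jnp.sum(resp[:, None] * x0_k, axis=0)
    return (x_t - x0_hat) / t  # since eps - x0 = (x_t - x0) / t on this path


# Single Gaussian N(m, I): at t = 0.5 the exact velocity is -m for every x_t.
m = jnp.array([[2.0, -1.0]])
v_single = gmm_velocity(jnp.array([0.3, 0.7]), 0.5, m, jnp.array([1.0]), jnp.array([1.0]))

# Symmetric two-component mixture: by symmetry the velocity vanishes at the origin.
means = jnp.array([[1.0, 0.0], [-1.0, 0.0]])
stds = jnp.array([0.5, 0.5])
weights = jnp.array([0.5, 0.5])
v_sym = gmm_velocity(jnp.zeros(2), 0.3, means, stds, weights)

# Batched evaluation over particles with jax.vmap.
batched = jax.vmap(gmm_velocity, in_axes=(0, None, None, None, None))
v_batch = batched(jnp.ones((8, 2)), 0.3, means, stds, weights)
```

Because the result is an ordinary function of (x, t), it can stand in for a trained velocity network when testing an integrator.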
Citation
If you use Diffuse in your research, please cite the library:
@software{diffuse2024,
title = {Diffuse: A modular diffusion model library},
author = {Iollo, J. and Oudoumanessah, G.},
year = {2025},
url = {https://github.com/jcopo/diffuse}
}
Download files
File details
Details for the file diffuse_jax-0.1.1.tar.gz.
File metadata
- Download URL: diffuse_jax-0.1.1.tar.gz
- Upload date:
- Size: 41.0 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5954e6f12ab35e9f7d57a03bba472e88fdf0a23fe95d65c4ce9cd9a649c5bb13 |
| MD5 | 408f4e3f0cf1dd6b4d49e8a436c6e37e |
| BLAKE2b-256 | 49585f96a6e5e9011e7c8348b09fc6fa4657d439dd42495f66c42b573a037ff0 |
File details
Details for the file diffuse_jax-0.1.1-py3-none-any.whl.
File metadata
- Download URL: diffuse_jax-0.1.1-py3-none-any.whl
- Upload date:
- Size: 101.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 66b151732c88e21f91e98da13ac23d2fb573847faf023400e4290765ea06dd85 |
| MD5 | 80e5481a7be7445497dbf023e3f0a1fc |
| BLAKE2b-256 | 987326fbdc5a210cd184fa07ae6f9c58f112b5c55fd3537d9f2c5426dfaa3cb7 |