
A research-oriented package for diffusion-based generative modeling with modular components


Diffuse

[Figure: Denoising process]

A Python package designed for research in diffusion-based generative modeling with modular components that can be easily swapped and combined for experimentation.

Quick Start

Unconditional Generation

import jax
from diffuse import Flow, Predictor, Denoiser
from diffuse.integrators import EulerIntegrator
from diffuse.timer import VpTimer

# Define flow matching model
flow = Flow(tf=1.0)

# Create predictor; network_fn is a placeholder for your own neural network
predictor = Predictor(
    model=flow,
    network=network_fn,
    prediction_type="velocity",  # or "noise", "sample"
)

# Setup timer and integrator
timer = VpTimer(n_steps=100, eps=0.001, tf=1.0)
integrator = EulerIntegrator(model=flow, timer=timer)

# Create denoiser
denoiser = Denoiser(
    integrator=integrator,
    model=flow,
    predictor=predictor,
    x0_shape=(height, width, channels),  # shape of one sample; define these yourself
)

# Generate samples
key = jax.random.PRNGKey(42)
state, trajectory = denoiser.generate(
    rng_key=key,
    n_steps=100,
    n_particles=10,
    keep_history=False,
)

# Single denoising step (x_t -> x_{t-1}), using a fresh subkey
key, step_key = jax.random.split(key)
next_state = denoiser.step(step_key, state)
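To make the moving parts above concrete, here is a self-contained sketch of what a velocity-prediction Euler sampler does, written in plain JAX without the diffuse API. It assumes a linear interpolation path x_t = (1 - t) x0 + t eps with data at t = 0 and Gaussian noise at t = tf = 1, a uniform time grid from tf down to eps = 0.001 (VpTimer may space its steps differently), and a 1-D Gaussian target whose exact velocity field stands in for a trained network_fn. MU, SIGMA, velocity and euler_sample are illustrative names, not part of the package.

```python
import jax
import jax.numpy as jnp

# Toy target (assumption): 1-D data x0 ~ N(MU, SIGMA^2).
MU, SIGMA = 3.0, 0.5


def velocity(x, t):
    """Exact marginal velocity E[eps - x0 | x_t = x] for the linear path
    x_t = (1 - t) * x0 + t * eps, eps ~ N(0, 1) (t = 0 data, t = 1 noise)."""
    var_t = (1.0 - t) ** 2 * SIGMA**2 + t**2  # Var[x_t]
    cov = t - (1.0 - t) * SIGMA**2            # Cov[eps - x0, x_t]
    return -MU + cov / var_t * (x - (1.0 - t) * MU)


def euler_sample(key, n_particles=10_000, n_steps=100, eps=1e-3, tf=1.0):
    """Integrate dx/dt = velocity(x, t) backwards from t = tf (noise) to t = eps."""
    ts = jnp.linspace(tf, eps, n_steps + 1)  # uniform grid (assumption)
    x = jax.random.normal(key, (n_particles,))  # x_tf ~ N(0, 1)

    def step(x, t_pair):
        t, t_next = t_pair
        x = x + (t_next - t) * velocity(x, t)  # Euler step; dt is negative
        return x, None

    x, _ = jax.lax.scan(step, x, (ts[:-1], ts[1:]))
    return x


samples = euler_sample(jax.random.PRNGKey(42))
print(samples.mean(), samples.std())  # should be close to MU and SIGMA
```

Because the toy velocity is exact, only the Euler discretization error separates the samples from N(MU, SIGMA^2); with a learned network the loop stays the same and only velocity changes.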

Conditional Generation with DPS

import jax
from diffuse import Flow, Predictor
from diffuse.integrators import EulerIntegrator
from diffuse.denoisers import DPSDenoiser
from diffuse.timer import VpTimer

# Define flow matching model
flow = Flow(tf=1.0)

# Create predictor
predictor = Predictor(
    model=flow,
    network=network_fn,
    prediction_type="velocity",
)

# Setup timer and integrator
timer = VpTimer(n_steps=100, eps=0.001, tf=1.0)
integrator = EulerIntegrator(model=flow, timer=timer)

# DPS denoiser for conditional sampling; forward_model is a placeholder for your measurement model
dps = DPSDenoiser(
    integrator=integrator,
    model=flow,
    predictor=predictor,
    forward_model=forward_model,
    x0_shape=(height, width, channels),
)

# Generate conditional samples; measurement_state is a placeholder for your observed measurement
key = jax.random.PRNGKey(42)
state, trajectory = dps.generate(
    rng_key=key,
    measurement_state=measurement_state,
    n_steps=100,
    n_particles=10,
)

# Single conditional step (x_t -> x_{t-1}), using a fresh subkey
key, step_key = jax.random.split(key)
next_state = dps.step(step_key, state, measurement_state)
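The idea behind DPS is to add, at each step, the gradient of the measurement residual ||y - A(x0_hat(x_t))|| with respect to x_t, where x0_hat is the model's current estimate of the clean sample. The sketch below shows one such step for the linear flow path x_t = (1 - t) x0 + t eps in plain JAX. velocity_fn, forward_fn, y and the step size zeta are hypothetical stand-ins; DPSDenoiser's actual update, weighting and signature may differ.

```python
import jax
import jax.numpy as jnp


def x0_from_velocity(x_t, t, v):
    # Linear path x_t = (1 - t) * x0 + t * eps with v = eps - x0  =>  x0 = x_t - t * v
    return x_t - t * v


def dps_euler_step(x_t, t, t_next, velocity_fn, forward_fn, y, zeta=1.0):
    """One Euler step of the flow ODE followed by a DPS-style correction.

    velocity_fn(x, t) stands in for the trained network, forward_fn(x0) for the
    measurement operator A, y for the observed measurement and zeta for the
    guidance step size. All four are hypothetical placeholders.
    """

    def residual_norm(x):
        # ||y - A(x0_hat(x))||, differentiated through the network w.r.t. x_t
        x0_hat = x0_from_velocity(x, t, velocity_fn(x, t))
        return jnp.linalg.norm(y - forward_fn(x0_hat))

    v = velocity_fn(x_t, t)
    grad = jax.grad(residual_norm)(x_t)
    x_next = x_t + (t_next - t) * v  # unconditional Euler update (dt < 0)
    return x_next - zeta * grad      # nudge toward measurement consistency


# Toy usage: observe only the first two coordinates of a 4-D sample.
def vel(x, t):
    return -x  # hypothetical stand-in velocity


def fwd(x0):
    return x0[:2]  # hypothetical masking operator A


x = jnp.ones(4)
y = jnp.zeros(2)
print(dps_euler_step(x, 1.0, 0.99, vel, fwd, y, zeta=0.1))
```

Coordinates that the forward model does not observe receive no gradient, so they follow the unconditional update unchanged.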

Features

  • Flow Matching & Diffusion: Support for both flow-based models and SDE-based diffusion processes
  • Flexible Prediction Types: Velocity, noise, and sample prediction for different model architectures
  • Timer-aware Integration: Time-discretization schedules (VpTimer, HeunTimer, FluxTimer) that control how sampling steps are spaced between tf and eps
  • Multiple Integrators: EulerIntegrator, DDIMIntegrator, DPM++, Heun, and more
  • Conditional Sampling: DPS (Diffusion Posterior Sampling) for inverse problems and conditioning
  • Modular Design: Mix and match models, predictors, denoisers, integrators, and timers
  • JAX-Powered: Efficient computation with automatic differentiation and JIT compilation
  • Research-Focused: Built for experimentation with new diffusion and flow matching techniques
  • Examples: MNIST, Gaussian mixtures, text-to-image generation, and more
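Once the interpolation path is fixed, the three prediction types carry the same information. Assuming the linear path x_t = (1 - t) x0 + t eps (the library's Flow and SDE models may parameterize time and scale differently), each prediction can be converted into estimates of the clean sample and the noise, as in this hypothetical helper to_x0_eps:

```python
import jax.numpy as jnp


def to_x0_eps(x_t, t, pred, prediction_type):
    """Convert a network prediction into (x0_hat, eps_hat).

    Assumes the linear path x_t = (1 - t) * x0 + t * eps, so the velocity is
    v = eps - x0. The noise and sample branches divide by (1 - t) and t,
    so they need 0 < t < 1.
    """
    if prediction_type == "velocity":
        x0 = x_t - t * pred
        eps = x_t + (1.0 - t) * pred
    elif prediction_type == "noise":
        eps = pred
        x0 = (x_t - t * eps) / (1.0 - t)
    elif prediction_type == "sample":
        x0 = pred
        eps = (x_t - (1.0 - t) * x0) / t
    else:
        raise ValueError(f"unknown prediction_type: {prediction_type!r}")
    return x0, eps


x0 = jnp.array([1.0, -2.0, 0.5])
eps = jnp.array([0.3, 0.7, -1.2])
t = 0.3
x_t = (1 - t) * x0 + t * eps
print(to_x0_eps(x_t, t, eps - x0, "velocity"))  # recovers x0 and eps up to rounding
```

The velocity branch involves no division, whereas the noise and sample branches become ill-conditioned near t = 1 and t = 0 respectively.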

Installation

pip install diffuse-jax

or with uv:

uv add diffuse-jax

Examples

See the examples/ directory for implementations including:

  • MNIST digit generation
  • Gaussian mixture modeling
  • Conditional sampling demonstrations
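For a quick toy dataset of the kind a Gaussian mixture experiment uses, here is a hypothetical sampler in plain JAX. The sample_gmm name and the mixture parameters are assumptions; the code in examples/ may build its data differently.

```python
import jax
import jax.numpy as jnp


def sample_gmm(key, weights, means, stds, n):
    """Draw n samples from sum_k weights[k] * N(means[k], stds[k]^2 I)."""
    k_comp, k_noise = jax.random.split(key)
    # Pick a mixture component for every sample according to the weights.
    comp = jax.random.choice(k_comp, weights.shape[0], shape=(n,), p=weights)
    noise = jax.random.normal(k_noise, (n, means.shape[1]))
    # Shift and scale standard normal noise by the chosen component's parameters.
    return means[comp] + stds[comp, None] * noise


weights = jnp.array([0.5, 0.5])
means = jnp.array([[-4.0, 0.0], [4.0, 0.0]])
stds = jnp.array([0.5, 0.5])
xs = sample_gmm(jax.random.PRNGKey(0), weights, means, stds, 5000)
```

Well-separated components like these make it easy to see by eye whether a sampler recovers both modes.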

Citation

If you use Diffuse in your research, please cite the library:

@software{diffuse2024,
  title = {Diffuse: A modular diffusion model library},
  author = {Iollo, J. and Oudoumanessah, G.},
  year = {2025},
  url = {https://github.com/jcopo/diffuse}
}



Download files


Source Distribution

diffuse_jax-0.1.1.tar.gz (41.0 MB)

Built Distribution


diffuse_jax-0.1.1-py3-none-any.whl (101.9 kB)

File details

Details for the file diffuse_jax-0.1.1.tar.gz.

File metadata

  • Download URL: diffuse_jax-0.1.1.tar.gz
  • Size: 41.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for diffuse_jax-0.1.1.tar.gz
  • SHA256: 5954e6f12ab35e9f7d57a03bba472e88fdf0a23fe95d65c4ce9cd9a649c5bb13
  • MD5: 408f4e3f0cf1dd6b4d49e8a436c6e37e
  • BLAKE2b-256: 49585f96a6e5e9011e7c8348b09fc6fa4657d439dd42495f66c42b573a037ff0


File details

Details for the file diffuse_jax-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: diffuse_jax-0.1.1-py3-none-any.whl
  • Size: 101.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for diffuse_jax-0.1.1-py3-none-any.whl
  • SHA256: 66b151732c88e21f91e98da13ac23d2fb573847faf023400e4290765ea06dd85
  • MD5: 80e5481a7be7445497dbf023e3f0a1fc
  • BLAKE2b-256: 987326fbdc5a210cd184fa07ae6f9c58f112b5c55fd3537d9f2c5426dfaa3cb7

