
Easy no-frills implementations of common abstractions for diffusion models.

Project description


Documentation · llms.txt

pip install diffusionlab


What is DiffusionLab?

TL;DR: DiffusionLab is a laboratory for quickly and easily experimenting with diffusion models.

  • DiffusionLab IS:
    • A lightweight and flexible set of JAX APIs for smaller-scale diffusion model training and sampling.
    • An implementation of the mathematical foundations of diffusion models.
  • DiffusionLab IS NOT:
    • A replacement for HuggingFace Diffusers.
    • A codebase for SoTA diffusion model training or inference.

Example

The following code compares three sample sets:

  • One drawn from the ground truth distribution, which is a Gaussian mixture model;
  • One sampled using DDIM with the ground-truth denoiser for the Gaussian mixture model;
  • One sampled using DDIM with the ground-truth denoiser for the empirical distribution of the first sample set.
import jax
from jax import numpy as jnp, vmap

from diffusionlab.dynamics import VariancePreservingProcess
from diffusionlab.schedulers import UniformScheduler
from diffusionlab.samplers import DDMSampler
from diffusionlab.distributions.gmm.gmm import GMM
from diffusionlab.distributions.empirical import EmpiricalDistribution
from diffusionlab.vector_fields import VectorFieldType

key = jax.random.key(1)

dim = 10
num_samples_ground_truth = 100
num_samples_ddim = 50

# Set up a random Gaussian mixture model with equal-weight components.
num_components = 3
priors = jnp.ones(num_components) / num_components
key, subkey = jax.random.split(key)
means = jax.random.normal(subkey, (num_components, dim))
key, subkey = jax.random.split(key)
cov_factors = jax.random.normal(subkey, (num_components, dim, dim))
covs = vmap(lambda A: A @ A.T)(cov_factors)  # PSD covariance matrices

gmm = GMM(means, covs, priors)

# Draw the ground-truth sample set.
key, subkey = jax.random.split(key)
X_ground_truth, y_ground_truth = gmm.sample(subkey, num_samples_ground_truth)

# Discretize the diffusion process on a uniform time grid.
num_steps = 100
t_min = 0.001
t_max = 0.999

diffusion_process = VariancePreservingProcess()
scheduler = UniformScheduler()
ts = scheduler.get_ts(t_min=t_min, t_max=t_max, num_steps=num_steps)

# Initial noise and per-step noise, shared by both sampling runs.
key, subkey = jax.random.split(key)
X_noise = jax.random.normal(subkey, (num_samples_ddim, dim))
key, subkey = jax.random.split(key)
zs = jax.random.normal(subkey, (num_samples_ddim, num_steps, dim))

# DDIM (deterministic DDM sampling) with the ground-truth GMM denoiser.
ground_truth_sampler = DDMSampler(
    diffusion_process,
    lambda x, t: gmm.x0(x, t, diffusion_process),
    VectorFieldType.X0,
    use_stochastic_sampler=False,
)
X_ddim_ground_truth = vmap(
    lambda x_init, z: ground_truth_sampler.sample(x_init, z, ts)
)(X_noise, zs)

# DDIM with the denoiser of the empirical distribution of the first sample set.
empirical_distribution = EmpiricalDistribution([(X_ground_truth, y_ground_truth)])
empirical_sampler = DDMSampler(
    diffusion_process,
    lambda x, t: empirical_distribution.x0(x, t, diffusion_process),
    VectorFieldType.X0,
    use_stochastic_sampler=False,
)
X_ddim_empirical = vmap(
    lambda x_init, z: empirical_sampler.sample(x_init, z, ts)
)(X_noise, zs)

# For each DDIM sample, the distance to its nearest ground-truth sample.
min_distance_to_gt_empirical = vmap(
    lambda x: jnp.min(vmap(lambda x_gt: jnp.linalg.norm(x - x_gt))(X_ground_truth))
)(X_ddim_empirical)
min_distance_to_gt_ground_truth = vmap(
    lambda x: jnp.min(vmap(lambda x_gt: jnp.linalg.norm(x - x_gt))(X_ground_truth))
)(X_ddim_ground_truth)

print(f"Min distance to ground truth samples from DDIM samples using empirical denoiser: {min_distance_to_gt_empirical}")
print(f"Min distance to ground truth samples from DDIM samples using ground truth denoiser: {min_distance_to_gt_ground_truth}")

Note on Frameworks

DiffusionLab versions < 3.0 use a PyTorch backbone; a permalink to the GitHub Pages documentation and llms.txt for the old version is available.

DiffusionLab versions >= 3.0 use a JAX backbone.
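Since both lines are distributed on PyPI under the same package name, the backbone can be selected with a standard pip version constraint. A minimal sketch (adjust the pin to your environment):

```shell
# Current JAX-backed line (versions >= 3.0)
pip install "diffusionlab>=3.0"

# Legacy PyTorch-backed line (versions < 3.0)
pip install "diffusionlab<3.0"
```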

Citation Information

You can use the following BibTeX entry:

@Misc{pai25diffusionlab,
    author = {Pai, Druv},
    title = {DiffusionLab},
    howpublished = {\url{https://github.com/DruvPai/DiffusionLab}},
    year = {2025}
}

Many thanks!

Download files

Source Distribution

diffusionlab-3.0.4.tar.gz (190.1 kB)

Built Distribution

diffusionlab-3.0.4-py3-none-any.whl (31.3 kB)

File details

Details for the file diffusionlab-3.0.4.tar.gz.

File metadata

  • File name: diffusionlab-3.0.4.tar.gz
  • Size: 190.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.14

File hashes

Hashes for diffusionlab-3.0.4.tar.gz:

Algorithm   Hash digest
SHA256      45245f317815c54009aa28ce860b845896329cd46d9c5eaca4f271f939c36ee2
MD5         1bcfc038f4e1268c577139ca867259c6
BLAKE2b-256 50b9410ae95e7de80a20e59902950b7e535ba9437f9ef81fb3a6ae717ac95a11

These hashes can be used to verify the integrity of a downloaded file.
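As a concrete illustration of how such hashes are used, here is a small standard-library sketch that recomputes a file's SHA256 digest for comparison against the value published above (the filename and expected digest in the commented example are taken from this page):

```python
import hashlib

def sha256_of_file(path: str, chunk_size: int = 8192) -> str:
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Example: verify a downloaded sdist against the digest published above.
# expected = "45245f317815c54009aa28ce860b845896329cd46d9c5eaca4f271f939c36ee2"
# assert sha256_of_file("diffusionlab-3.0.4.tar.gz") == expected
```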

File details

Details for the file diffusionlab-3.0.4-py3-none-any.whl.

File hashes

Hashes for diffusionlab-3.0.4-py3-none-any.whl:

Algorithm   Hash digest
SHA256      95276c88fbf09d7526467cab70e6ddbe640f58238ce7a912fd2d152df8931dcf
MD5         acaa12b414dc936fcf4a50b843f0479f
BLAKE2b-256 d332d9d165c1a9fd459f9ef1a1b60103ad87b6d4a94b259bb5fa25eade6371cd
