Hackable Diffusion is a library for custom diffusion models.

Hackable Diffusion

Hackable Diffusion is a modular toolbox, written in JAX, for experimenting with and teaching diffusion modeling.

Philosophy

The core philosophy of this library is hackability. It is designed from the ground up to be modular, composable, and easy to modify, enabling rapid experimentation with new research ideas. Key principles include:

  • Composition over Configuration: Build models and training loops by composing small, well-defined Python objects.
  • Clear Separation of Concerns: The codebase is organized into logical sub-libraries for architecture, corruption, inference, loss, and sampling.
  • Native Multimodality: The library has first-class support for handling multimodal data (e.g., images and text) through a consistent "Nested" component pattern that applies different diffusion parameters to different parts of the data.
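The composition and "Nested" ideas above can be illustrated with a minimal, standard-library-only sketch. Every name in it (GaussianCorruption, NestedCorruption, the schedules, the batch keys) is hypothetical and not taken from the Hackable Diffusion API, and both modalities are treated as continuous for simplicity. It only shows the shape of the pattern: small objects composed into a larger one that applies a different noise schedule to each named part of the data.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, Mapping

# NOTE: all names below are hypothetical illustrations, not the library's API.


@dataclass(frozen=True)
class GaussianCorruption:
    """Forward process x_t = alpha(t) * x_0 + sigma(t) * eps, eps ~ N(0, 1)."""

    alpha: Callable[[float], float]
    sigma: Callable[[float], float]

    def corrupt(self, x0, t, rng):
        a, s = self.alpha(t), self.sigma(t)
        return [a * x + s * rng.gauss(0.0, 1.0) for x in x0]


@dataclass(frozen=True)
class NestedCorruption:
    """Applies a separate corruption process to each named part of the data."""

    parts: Mapping[str, GaussianCorruption]

    def corrupt(self, x0, t, rng):
        return {
            name: process.corrupt(x0[name], t, rng)
            for name, process in self.parts.items()
        }


# Two small schedules, composed rather than configured.
cosine = GaussianCorruption(
    alpha=lambda t: math.cos(0.5 * math.pi * t),
    sigma=lambda t: math.sin(0.5 * math.pi * t),
)
linear = GaussianCorruption(alpha=lambda t: 1.0 - t, sigma=lambda t: t)

# Different diffusion parameters for different parts of one example.
process = NestedCorruption(parts={"image": cosine, "label": linear})
batch = {"image": [0.1, 0.5, 0.9, 0.3], "label": [1.0, 0.0]}
noisy = process.corrupt(batch, t=0.3, rng=random.Random(0))
```

Swapping the schedule for one modality means replacing one object in the parts mapping; nothing else in the sketch changes. A real implementation would operate on JAX arrays rather than Python lists.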

Tutorials

The notebooks/ directory contains several tutorials to get you started:

  • 2d_training.ipynb: A minimal example on a 2D toy dataset.
  • mnist.ipynb: Standard image diffusion on MNIST.
  • mnist_discrete.ipynb: An example of discrete diffusion.
  • mnist_multimodal.ipynb: A showcase of the multimodal capabilities, generating images and labels jointly.

Training configs

The kdiff/configs/ directory contains example configurations for training:

  • mnist_unet.py: Standard diffusion training configuration on MNIST.

To run a config locally, create a small launcher script (e.g. train.py):

import os

# Must be set before JAX is imported (see the note below).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import importlib.util
import multiprocessing

from kauldron import konfig


def main():
    # Load the config file as a module directly from its path.
    spec = importlib.util.spec_from_file_location(
        "config", "kdiff/configs/mnist_unet.py"
    )
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)

    cfg = config_module.get_config()
    cfg.workdir = "/tmp/mnist_workdir"
    # Resolve the config into a trainer object, then start training.
    trainer = konfig.resolve(cfg)
    trainer.train()


if __name__ == "__main__":
    # Use "spawn" so that worker processes start from a fresh interpreter.
    multiprocessing.set_start_method("spawn", force=True)
    main()

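Note: XLA_PYTHON_CLIENT_PREALLOCATE=false must be set before importing JAX to prevent GPU memory preallocation conflicts with data loading workers. The if __name__ == "__main__" guard is required because the "spawn" start method re-imports the launcher module in every child process; without the guard, each child would run the launcher code again.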
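If you launch more than one config, the path-based import in the launcher can be factored into a small helper. The following is a sketch under stated assumptions: load_config_module is a hypothetical name, not part of Hackable Diffusion or Kauldron, and the demo uses a throwaway stand-in file rather than a real training config.

```python
import importlib.util
import os
import pathlib
import tempfile

# Set before anything that imports JAX is loaded, as the note above requires.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def load_config_module(path, module_name="config"):
    """Import a Python file by path and return the resulting module.

    Hypothetical helper: it only wraps the standard importlib calls.
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


if __name__ == "__main__":
    # Demo with a stand-in config file; a real config would return a
    # config object from get_config() instead of a plain dict.
    with tempfile.TemporaryDirectory() as tmp:
        demo = pathlib.Path(tmp) / "demo_config.py"
        demo.write_text("def get_config():\n    return {'workdir': None}\n")
        cfg = load_config_module(str(demo)).get_config()
        cfg["workdir"] = "/tmp/mnist_workdir"
        print(cfg)
```

The explicit check on spec guards against paths that importlib cannot map to a loader (for example, a file without a .py suffix), which would otherwise surface as a less helpful AttributeError.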

Installation

To install the necessary dependencies, you can use pip with the provided pyproject.toml file:

pip install -e .

To install development dependencies (for running tests), use the following command; the quotes keep shells such as zsh from interpreting the square brackets:

pip install -e ".[dev]"

This will install libraries such as JAX, Flax, and other utilities required to run the code.
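A quick way to confirm that the installation worked is to check that the expected top-level modules can be found. This is a minimal sketch: missing_modules is a hypothetical helper, and the import names jax, flax, and hackable_diffusion are assumptions based on the package names mentioned here, so adjust them if your environment differs.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed import names; they are not confirmed by this document.
missing = missing_modules(["jax", "flax", "hackable_diffusion"])
print("missing:", missing or "none")
```

Using find_spec locates the modules without importing them, so the check is fast and does not initialize JAX or touch the GPU.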

Disclaimer

Copyright 2025 Google LLC
All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.

Citing Hackable Diffusion

If Hackable Diffusion was helpful for a publication, please cite this repository (authors are listed in alphabetical order by last name):

@software{hackable_diffusion2026github,
  author = {Crepy, Clement and De Bortoli, Valentin and Galashov, Alexandre and Greff, Klaus and Korshunova, Ira},
  title = {{Hackable Diffusion}: A modular toolbox written in Jax to experiment and educate around Diffusion modeling.},
  url = {https://github.com/google/hackable_diffusion},
  version = {1.0.1},
  year = {2026},
  note = {Authors listed in alphabetical order by the last name},
}

This is not an officially supported Google product.
