Diffusion models in PyTorch

Azula - Diffusion models in PyTorch

Azula is a Python package that implements diffusion models in PyTorch. Its goal is to unify the different formalisms and notations of the generative diffusion models literature into a single, convenient and hackable interface.

In the Avatar cartoon, Azula is a powerful fire and lightning bender ⚡️

Installation

The azula package is available on PyPI, which means it is installable via pip.

pip install azula

Alternatively, if you need the latest features, you can install it from the repository.

pip install git+https://github.com/probabilists/azula

Getting started

In Azula's formalism, a diffusion model is the composition of three elements: a noise schedule, a denoiser and a sampler.

  • A noise schedule is a mapping from a time $t \in [0, 1]$ to the signal scale $\alpha_t$ and the noise scale $\sigma_t$ in a perturbation kernel $p(X_t \mid X) = \mathcal{N}(X_t \mid \alpha_t X, \sigma_t^2 I)$ from a "clean" random variable $X \sim p(X)$ to a "noisy" random variable $X_t$.

  • A denoiser is a neural network trained to predict $X$ given $X_t$.

  • A sampler defines a series of transition kernels $q(X_s \mid X_t)$ from $t$ to $s < t$ based on a noise schedule and a denoiser. Simulating these transitions from $t = 1$ down to $t = 0$ yields approximate samples from $p(X)$.
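To make the three roles concrete, here is a minimal, library-independent sketch in plain PyTorch. It assumes a cosine variance-preserving schedule ($\alpha_t = \cos(\pi t / 2)$, $\sigma_t = \sin(\pi t / 2)$), which is not necessarily the one Azula's `VPSchedule` uses, and a toy target $p(X) = \mathcal{N}(0, I)$ for which the optimal denoiser $\mathbb{E}[X \mid X_t] = \alpha_t X_t$ is known in closed form. The sampler is a deterministic DDIM-style update. All function names (`schedule`, `perturb`, `denoiser`, `sample`) are hypothetical and not part of Azula's API.

```python
import math

import torch


def schedule(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical cosine VP schedule, with alpha_t**2 + sigma_t**2 = 1."""
    alpha = torch.cos(math.pi * t / 2)
    sigma = torch.sin(math.pi * t / 2)
    return alpha, sigma


def perturb(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Draw X_t ~ N(alpha_t X, sigma_t^2 I) from the perturbation kernel."""
    alpha, sigma = schedule(t)
    return alpha * x + sigma * torch.randn_like(x)


def denoiser(x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Optimal denoiser E[X | X_t] for the toy target p(X) = N(0, I)."""
    alpha, sigma = schedule(t)
    return alpha / (alpha**2 + sigma**2) * x_t


def sample(x1: torch.Tensor, steps: int = 64) -> torch.Tensor:
    """Deterministic DDIM-style transitions from t = 1 down to t = 0."""
    times = torch.linspace(1.0, 0.0, steps + 1)
    x = x1
    for t, s in zip(times[:-1], times[1:]):
        alpha_t, sigma_t = schedule(t)
        alpha_s, sigma_s = schedule(s)
        x_hat = denoiser(x, t)
        eps_hat = (x - alpha_t * x_hat) / sigma_t  # noise implied by x_hat
        x = alpha_s * x_hat + sigma_s * eps_hat  # jump to the lower time s
    return x


torch.manual_seed(0)
# At t = 1, alpha_1 is ~0 and sigma_1 = 1, so X_1 ~ N(0, I).
x1 = torch.randn(10_000, 5)
x0 = sample(x1)
print(x0.mean().item(), x0.std().item())  # mean near 0, std near 1
```

The sampler only touches the schedule and the denoiser, which is why the three pieces can be swapped independently of one another.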

This formalism is closely followed by Azula's API.

import torch

from azula.denoise import KarrasDenoiser
from azula.noise import VPSchedule
from azula.sample import DDPMSampler

# Choose the variance preserving (VP) noise schedule
schedule = VPSchedule()

# Initialize a denoiser around your own backbone network (CustomNN is a placeholder)
denoiser = KarrasDenoiser(
    backbone=CustomNN(in_features=5, out_features=5),
    schedule=schedule,
)

# Train to predict x given x_t
optimizer = torch.optim.Adam(denoiser.parameters(), lr=1e-3)

# train_loader is assumed to yield batches x of shape (batch_size, 5)
for x in train_loader:
    t = torch.rand((x.shape[0],))

    loss = denoiser.loss(x, t)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

# Generate 64 points in 1000 steps
sampler = DDPMSampler(denoiser.eval(), steps=1000)

x1 = sampler.init((64, 5))
x0 = sampler(x1)
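The call `denoiser.loss(x, t)` hides the training objective. As a rough, hypothetical illustration (not Azula's actual implementation, which may use a different parametrization or weighting), a plain denoising loss perturbs `x` with the kernel $p(X_t \mid X)$ and regresses the network's output onto the clean `x`. The cosine schedule, the `ToyDenoiser` network and the `denoising_loss` function are assumptions made only for this sketch.

```python
import math

import torch
import torch.nn as nn


class ToyDenoiser(nn.Module):
    """Hypothetical denoiser: a small MLP that sees x_t and the time t."""

    def __init__(self, features: int = 5, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(features + 1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, features),
        )

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Append t as an extra input feature: (B, features) -> (B, features + 1)
        return self.net(torch.cat([x_t, t[:, None]], dim=-1))


def denoising_loss(model: nn.Module, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Mean squared error between the denoiser's estimate of x and x itself."""
    # Assumed cosine VP schedule, broadcast over the feature dimension.
    alpha = torch.cos(math.pi * t / 2)[:, None]
    sigma = torch.sin(math.pi * t / 2)[:, None]
    x_t = alpha * x + sigma * torch.randn_like(x)  # sample from p(X_t | X)
    return (model(x_t, t) - x).square().mean()


torch.manual_seed(0)
model = ToyDenoiser()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

data = torch.randn(256, 5) * 0.5 + 1.0  # stand-in for one train_loader batch
losses = []
for _ in range(300):
    t = torch.rand((data.shape[0],))
    loss = denoising_loss(model, data, t)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    losses.append(loss.item())
```

Methods in the diffusion literature differ mainly in the schedule, in how the network output is parametrized, and in how this error is weighted across times $t$.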

Alternatively, Azula's plugin interface lets you load pre-trained models and use them through the same interface.

import torch

from azula.plugins import adm
from azula.sample import DDIMSampler

# Download weights from openai/guided-diffusion
denoiser = adm.load_model("imagenet_256x256")
denoiser.to("cuda")

# Generate a batch of 4 images
sampler = DDIMSampler(denoiser, steps=64)

x1 = sampler.init((4, 3, 256, 256), device="cuda")
x0 = sampler(x1)

# Rescale from the model's [-1, 1] range to [0, 1]
images = torch.clip((x0 + 1) / 2, min=0, max=1)
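The last line of the example above maps the sampled images from the $[-1, 1]$ range used by the model to $[0, 1]$. To save or display them, a common next step is converting them to 8-bit, channel-last arrays. A minimal sketch, assuming a batch tensor of shape `(N, 3, H, W)`; the helper name `to_uint8` is hypothetical and not part of Azula.

```python
import torch


def to_uint8(x0: torch.Tensor) -> torch.Tensor:
    """Map a (N, C, H, W) batch in [-1, 1] to uint8 images of shape (N, H, W, C)."""
    images = torch.clip((x0 + 1) / 2, min=0, max=1)  # [-1, 1] -> [0, 1]
    images = (images * 255).round().to(torch.uint8)  # [0, 1] -> {0, ..., 255}
    return images.permute(0, 2, 3, 1).cpu()  # channels last, on the CPU


# Example with a fake 1x5 RGB "image"; values outside [-1, 1] are clipped.
x0 = torch.tensor([-1.5, -1.0, 0.0, 1.0, 2.0]).reshape(1, 1, 1, 5).expand(1, 3, 1, 5)
images = to_uint8(x0)
print(images.shape)  # torch.Size([1, 1, 5, 3])
print(images[0, 0, :, 0].tolist())  # [0, 0, 128, 255, 255]
```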

For more information, check out the documentation and tutorials at azula.readthedocs.io.

Contributing

If you have a question, an issue or would like to contribute, please read our contributing guidelines.

Download files

Download the file for your platform.

Source Distribution

azula-0.10.2.tar.gz (67.3 kB)

Uploaded Source

Built Distribution

azula-0.10.2-py3-none-any.whl (92.1 kB)

Uploaded Python 3

File details

Details for the file azula-0.10.2.tar.gz.

File metadata

  • Download URL: azula-0.10.2.tar.gz
  • Upload date:
  • Size: 67.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.3

File hashes

Hashes for azula-0.10.2.tar.gz
Algorithm Hash digest
SHA256 1adb9ef1653e0866e6bede98e3e08286c5537ae98e3ae4dfb474688f470c27c7
MD5 dd2261f9699b170a9a0fc5a8131e7ef3
BLAKE2b-256 095c5f2d8fffe0512d9273fa2707b7f4682994257b5c604b874363723ee052b1

File details

Details for the file azula-0.10.2-py3-none-any.whl.

File metadata

  • Download URL: azula-0.10.2-py3-none-any.whl
  • Upload date:
  • Size: 92.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.3

File hashes

Hashes for azula-0.10.2-py3-none-any.whl
Algorithm Hash digest
SHA256 d08d9dd45efa0ec3c65c3474b08e2778a9ee7aedd6dbdd36070c8f8b15998fdc
MD5 77fd1990b2428e40482a46d076d37284
BLAKE2b-256 24b8d0c665539c86f2a9dd1db1874c276c6cece7c70386951326719cae1cda76
