
Argmax Model Optimization Toolkit for Diffusion Models.


DiffusionKit


Run Diffusion Models on Apple Silicon with Core ML and MLX

This repository comprises:

  • diffusionkit, a Python package for converting PyTorch models to Core ML format and performing image generation with MLX in Python
  • DiffusionKit, a Swift package for on-device inference of diffusion models using Core ML and MLX

Installation

The following installation steps are required for:

  • MLX inference
  • PyTorch to Core ML model conversion

Python Environment Setup

conda create -n diffusionkit python=3.11 -y
conda activate diffusionkit
cd /path/to/diffusionkit/repo
pip install -e .
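DiffusionKit targets Apple Silicon, so a quick sanity check after installation can save debugging time later. The sketch below is a hypothetical helper (not part of the package) that uses only the standard library. It flags a non-macOS or non-arm64 interpreter (for example, an x86_64 Python running under Rosetta) and reports whether the diffusionkit package is importable. The checks reflect this README's Apple Silicon focus rather than an official requirements list.

```python
import importlib.util
import platform
from typing import List


def environment_problems(system: str, machine: str, has_package: bool) -> List[str]:
    """Return human-readable problems with the setup (an empty list means OK).

    Hypothetical helper: the checks mirror this README's Apple Silicon focus.
    """
    problems = []
    if system != "Darwin":
        problems.append(f"expected macOS (Darwin), found {system}")
    if machine != "arm64":
        # An x86_64 interpreter on an M-series Mac usually means Rosetta.
        problems.append(f"expected an arm64 interpreter, found {machine}")
    if not has_package:
        problems.append("diffusionkit is not importable; re-run `pip install -e .`")
    return problems


if __name__ == "__main__":
    found = importlib.util.find_spec("diffusionkit") is not None
    issues = environment_problems(platform.system(), platform.machine(), found)
    print("OK" if not issues else "\n".join(issues))
```

Run it inside the activated diffusionkit environment so it inspects the same interpreter that pip installed into.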

Hugging Face Hub Credentials


Stable Diffusion 3 requires users to accept the terms before downloading the checkpoint.

FLUX.1-dev also requires users to accept the terms before downloading the checkpoint.

Once you have accepted the terms, sign in with your Hugging Face Hub READ token as shown below:

Important: If you use a fine-grained token, you must also edit its permissions to allow read access to the contents of all public gated repos you can access.

huggingface-cli login --token YOUR_HF_HUB_TOKEN
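The same login can be done from Python with huggingface_hub.login, from the huggingface_hub library that provides the huggingface-cli command above. The sketch below assumes you keep the token in the HF_TOKEN environment variable (which huggingface_hub also reads on its own) rather than hard-coding it; resolve_hf_token and hf_login are hypothetical helper names, not DiffusionKit APIs.

```python
import os
from typing import Mapping, Optional


def resolve_hf_token(explicit: Optional[str] = None,
                     env: Mapping[str, str] = os.environ) -> str:
    """Pick the token to log in with: an explicit value wins, else HF_TOKEN."""
    token = explicit or env.get("HF_TOKEN")
    if not token:
        raise RuntimeError("No token given and HF_TOKEN is not set")
    return token


def hf_login(explicit: Optional[str] = None) -> None:
    """Log in to the Hugging Face Hub (needs network access and huggingface_hub)."""
    # Imported lazily so resolve_hf_token works without huggingface_hub installed.
    from huggingface_hub import login

    login(token=resolve_hf_token(explicit))
```

Calling hf_login() once per machine has the same effect as the CLI command; the token is cached locally afterwards.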

Converting Models from PyTorch to Core ML


Step 1: Follow the installation steps from the previous section.

Step 2: Verify that you have accepted the StabilityAI license terms and allowed gated access on your Hugging Face token.

Step 3: Prepare the denoiser model (MMDiT) Core ML model files (.mlpackage)

python -m python.src.diffusionkit.tests.torch2coreml.test_mmdit --sd3-ckpt-path stabilityai/stable-diffusion-3-medium --model-version 2b -o <output-mlpackages-directory> --latent-size {64, 128}

Step 4: Prepare the VAE Decoder Core ML model files (.mlpackage)

python -m python.src.diffusionkit.tests.torch2coreml.test_vae --sd3-ckpt-path stabilityai/stable-diffusion-3-medium -o <output-mlpackages-directory> --latent-size {64, 128}

Note:

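  • --sd3-ckpt-path can be the ID of any Hugging Face repo (e.g. stabilityai/stable-diffusion-3-medium) OR a path to a local sd3_medium.safetensors file
  • --latent-size takes a single value, 64 or 128; the latent is 1/8 of the image resolution, so these correspond to 512x512 and 1024x1024 images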
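Steps 3 and 4 share the checkpoint, output directory, and latent size, so they are easy to drive from one script. The sketch below builds both commands as argument lists; conversion_commands is a hypothetical helper, and it assumes square images at 512x512 or 1024x1024 (latent sizes 64 and 128). It uses sys.executable so the conversion runs in the same Python environment as the script.

```python
import sys
from typing import List

# Image side in pixels -> latent size (image side / 8) accepted by the scripts.
SUPPORTED = {512: 64, 1024: 128}
MODULE_BASE = "python.src.diffusionkit.tests.torch2coreml"


def conversion_commands(ckpt: str, out_dir: str, resolution: int) -> List[List[str]]:
    """Return the MMDiT (Step 3) and VAE decoder (Step 4) commands, in order."""
    if resolution not in SUPPORTED:
        raise ValueError(f"resolution must be one of {sorted(SUPPORTED)}")
    latent = str(SUPPORTED[resolution])
    mmdit = [sys.executable, "-m", f"{MODULE_BASE}.test_mmdit",
             "--sd3-ckpt-path", ckpt, "--model-version", "2b",
             "-o", out_dir, "--latent-size", latent]
    vae = [sys.executable, "-m", f"{MODULE_BASE}.test_vae",
           "--sd3-ckpt-path", ckpt, "-o", out_dir, "--latent-size", latent]
    return [mmdit, vae]
```

Each list can then be passed to subprocess.run(cmd, check=True), presumably from the repository root so that the python.src... module path resolves.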

Image Generation with Python MLX


CLI

Simplest usage:

diffusionkit-cli --prompt "a photo of a cat" --output-path </path/to/output/image.png>

Notable optional arguments:

  • For reproducibility of results, use --seed
  • For image-to-image, use --image-path (path to input image) and --denoise (value between 0. and 1.)
  • To enable the T5 encoder in SD3, use --t5 (FLUX always uses T5, regardless of this argument)
  • For different resolutions, use --height and --width
  • To use a local checkpoint, use --local-ckpt </path/to/ckpt.safetensors> (e.g. ~/models/stable-diffusion-3-medium/sd3_medium.safetensors)

Please refer to the help menu for all available arguments: diffusionkit-cli -h.

Note: When using FLUX.1-dev, verify that you have accepted the FLUX.1-dev license and allowed gated access on your Hugging Face token.
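When scripting many generations, it can help to assemble the diffusionkit-cli invocation programmatically. The sketch below uses only the flags listed above; the helper name cli_args is hypothetical, it assumes --t5 is a boolean switch that takes no value, and its range check on --denoise mirrors the 0 to 1 range stated above.

```python
from typing import List, Optional


def cli_args(prompt: str, output_path: str, *, seed: Optional[int] = None,
             image_path: Optional[str] = None, denoise: Optional[float] = None,
             t5: bool = False, height: Optional[int] = None,
             width: Optional[int] = None,
             local_ckpt: Optional[str] = None) -> List[str]:
    """Build a diffusionkit-cli argument list from the documented options."""
    args = ["diffusionkit-cli", "--prompt", prompt, "--output-path", output_path]
    if seed is not None:
        args += ["--seed", str(seed)]
    if image_path is not None:
        # Image-to-image: pair the input image with a denoise value in [0, 1].
        if denoise is None or not 0.0 <= denoise <= 1.0:
            raise ValueError("image-to-image needs denoise between 0 and 1")
        args += ["--image-path", image_path, "--denoise", str(denoise)]
    if t5:
        args.append("--t5")  # assumed to be a boolean switch
    if height is not None:
        args += ["--height", str(height)]
    if width is not None:
        args += ["--width", str(width)]
    if local_ckpt is not None:
        args += ["--local-ckpt", local_ckpt]
    return args
```

For example, subprocess.run(cli_args("a photo of a cat", "cat.png", seed=42), check=True) runs the simplest command above with a fixed seed.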

Code

For Stable Diffusion 3:

from diffusionkit.mlx import DiffusionPipeline
pipeline = DiffusionPipeline(
  shift=3.0,
  use_t5=False,
  model_version="argmaxinc/mlx-stable-diffusion-3-medium",
  low_memory_mode=True,
  a16=True,
  w16=True,
)

For FLUX:

from diffusionkit.mlx import FluxPipeline
pipeline = FluxPipeline(
  shift=1.0,
  model_version="argmaxinc/mlx-FLUX.1-schnell", # model_version="argmaxinc/mlx-FLUX.1-dev" for FLUX.1-dev
  low_memory_mode=True,
  a16=True,
  w16=True,
)
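Because the two pipelines differ only in class and a few constructor arguments, a small factory can pick between them from the model ID. The sketch below is hypothetical (pipeline_spec and make_pipeline are not part of DiffusionKit); it reuses exactly the constructor arguments shown above and assumes that any model ID containing "FLUX" should use FluxPipeline.

```python
from typing import Any, Dict, Tuple


def pipeline_spec(model_version: str, use_t5: bool = False) -> Tuple[str, Dict[str, Any]]:
    """Return (class name in diffusionkit.mlx, constructor kwargs) for a model ID."""
    common = {"model_version": model_version, "low_memory_mode": True,
              "a16": True, "w16": True}
    if "FLUX" in model_version:
        # The README's FLUX example passes no use_t5 (FLUX always uses T5).
        return "FluxPipeline", {"shift": 1.0, **common}
    return "DiffusionPipeline", {"shift": 3.0, "use_t5": use_t5, **common}


def make_pipeline(model_version: str, use_t5: bool = False):
    """Instantiate the matching pipeline (requires diffusionkit and MLX)."""
    # Lazy import: only needed when actually building a pipeline.
    import diffusionkit.mlx as dk_mlx

    cls_name, kwargs = pipeline_spec(model_version, use_t5)
    return getattr(dk_mlx, cls_name)(**kwargs)
```

For example, make_pipeline("argmaxinc/mlx-FLUX.1-schnell") builds the same FluxPipeline as the snippet above.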

Finally, to generate the image, call the pipeline's generate_image() method:

HEIGHT = 512
WIDTH = 512
NUM_STEPS = 4  # 4 for FLUX.1-schnell, 50 for SD3 and FLUX.1-dev
CFG_WEIGHT = 0.  # 0. for FLUX.1-schnell, 5. for SD3

image, _ = pipeline.generate_image(
  "a photo of a cat",
  cfg_weight=CFG_WEIGHT,
  num_steps=NUM_STEPS,
  latent_size=(HEIGHT // 8, WIDTH // 8),
)
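The constants above depend on the model, and latent_size is the image size divided by 8. The sketch below collects the values stated in this README into one lookup; the SETTINGS table and helper names are hypothetical. The README gives no cfg_weight for FLUX.1-dev, so the sketch leaves it unset and asks the caller for one rather than guessing. It also checks that height and width divide evenly by 8.

```python
from typing import Dict, Optional, Tuple

# (num_steps, cfg_weight) per model, as stated in this README.
# None means the README does not give a value.
SETTINGS: Dict[str, Tuple[int, Optional[float]]] = {
    "argmaxinc/mlx-FLUX.1-schnell": (4, 0.0),
    "argmaxinc/mlx-FLUX.1-dev": (50, None),
    "argmaxinc/mlx-stable-diffusion-3-medium": (50, 5.0),
}


def latent_size(height: int, width: int) -> Tuple[int, int]:
    """Pixel size -> latent size, using the HEIGHT // 8, WIDTH // 8 rule above."""
    if height % 8 or width % 8:
        raise ValueError("height and width must be multiples of 8")
    return height // 8, width // 8


def generate_kwargs(model_version: str, height: int, width: int,
                    cfg_weight: Optional[float] = None) -> dict:
    """Keyword arguments for pipeline.generate_image(prompt, **kwargs)."""
    num_steps, default_cfg = SETTINGS[model_version]
    cfg = cfg_weight if cfg_weight is not None else default_cfg
    if cfg is None:
        raise ValueError(f"no cfg_weight documented for {model_version}; pass one")
    return {"cfg_weight": cfg, "num_steps": num_steps,
            "latent_size": latent_size(height, width)}
```

Usage: image, _ = pipeline.generate_image("a photo of a cat", **generate_kwargs("argmaxinc/mlx-FLUX.1-schnell", 512, 512)) reproduces the call above.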

Some notable optional arguments:

  • For image-to-image, use the image_path (path to input image) and denoise (value between 0. and 1.) keyword arguments.
  • For reproducibility, use the seed keyword argument.
  • For a negative prompt, use the negative_text keyword argument.
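The optional arguments combine with the required ones as ordinary keyword arguments. The sketch below (the helper name optional_kwargs is hypothetical) passes an argument only when it is set, and checks that denoise is between 0 and 1 whenever image_path is given, as described above.

```python
from typing import Any, Dict, Optional


def optional_kwargs(*, seed: Optional[int] = None,
                    negative_text: Optional[str] = None,
                    image_path: Optional[str] = None,
                    denoise: Optional[float] = None) -> Dict[str, Any]:
    """Collect generate_image's optional keyword arguments, skipping unset ones."""
    kwargs: Dict[str, Any] = {}
    if seed is not None:
        kwargs["seed"] = seed
    if negative_text is not None:
        kwargs["negative_text"] = negative_text
    if image_path is not None:
        # Image-to-image needs both the input image and a denoise value in [0, 1].
        if denoise is None or not 0.0 <= denoise <= 1.0:
            raise ValueError("image-to-image needs denoise between 0 and 1")
        kwargs["image_path"] = image_path
        kwargs["denoise"] = denoise
    return kwargs
```

For example, pipeline.generate_image("a photo of a cat", cfg_weight=CFG_WEIGHT, num_steps=NUM_STEPS, latent_size=(HEIGHT // 8, WIDTH // 8), **optional_kwargs(seed=42)) fixes the seed for a reproducible result.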

The generated image can be saved with:

image.save("path/to/save.png")
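The save call suggests the returned image is a PIL Image; assuming so, note that PIL does not create missing directories. The hypothetical helper below creates the parent folder first and, unless overwrite is set, avoids clobbering an existing file by appending a counter to the name.

```python
from pathlib import Path


def save_image(image, path: str, overwrite: bool = False) -> Path:
    """Save image (anything with a .save(path) method, e.g. PIL.Image) safely."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    if not overwrite:
        # Try cat.png, then cat-1.png, cat-2.png, ... until a name is free.
        counter = 1
        while target.exists():
            target = target.with_name(f"{Path(path).stem}-{counter}{target.suffix}")
            counter += 1
    image.save(str(target))
    return target
```

The returned Path tells you which file name was actually used.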

Image Generation with Swift


Core ML Swift

Apple Core ML Stable Diffusion is the initial Core ML backend for DiffusionKit. Stable Diffusion 3 support has been upstreamed to that repository while we build the holistic Swift inference package.

MLX Swift

🚧

License

DiffusionKit is released under the MIT License. See LICENSE for more details.

Citation

If you use DiffusionKit for something cool or just find it useful, please drop us a note at info@takeargmax.com!

If you use DiffusionKit for academic work, here is the BibTeX:

@misc{diffusionkit-argmax,
   title = {DiffusionKit},
   author = {Argmax, Inc.},
   year = {2024},
   URL = {https://github.com/argmaxinc/DiffusionKit}
}



Download files


Source Distribution

diffusionkit-0.5.2.tar.gz (45.7 kB)

Uploaded Source

Built Distribution


diffusionkit-0.5.2-py3-none-any.whl (53.3 kB)

Uploaded Python 3

File details

Details for the file diffusionkit-0.5.2.tar.gz.

File metadata

  • Download URL: diffusionkit-0.5.2.tar.gz
  • Upload date:
  • Size: 45.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for diffusionkit-0.5.2.tar.gz
Algorithm Hash digest
SHA256 0293195a3852c128771a8339450664efa254418583831ac4a24a13f1db72a4f6
MD5 74ce0e6bff146e6f9f379031310b15c0
BLAKE2b-256 578b71df5c4a3de012eab4e69e25ddac2a6161fb17870b7a880a5e6499c9c3c7


File details

Details for the file diffusionkit-0.5.2-py3-none-any.whl.

File metadata

  • Download URL: diffusionkit-0.5.2-py3-none-any.whl
  • Upload date:
  • Size: 53.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for diffusionkit-0.5.2-py3-none-any.whl
Algorithm Hash digest
SHA256 199ec8e637d53be0b7ab92b62a4604f31e4d7c41da1bb1b29e0c4658932070c3
MD5 fec897a33d818d9193bd27a10f3aa8ca
BLAKE2b-256 8ad3585e48dc35263bd44ee74ab3a3ae9e63d3adc57c17f329d78addd9f9c33c

