

Project description

MMDiT-PyTorch

MMDiT-PyTorch is a lightweight and standalone PyTorch implementation of a single block from the Multimodal Diffusion Transformer (MMDiT), originally proposed in Scaling Rectified Flow Transformers for High-Resolution Image Synthesis.

[Figure: MMDiT architecture diagram]

This project focuses on simplicity and minimal dependencies to allow easy understanding and extensibility for research and experimentation.


🔍 Overview

MMDiT introduces a scalable and efficient Transformer-based architecture tailored for high-resolution image synthesis through rectified flows. This repository implements a single MMDiT block for educational and experimental purposes.

  • 📦 Single-block MMDiT in PyTorch
  • 🧠 Minimal and readable implementation
  • 🛠️ No training framework dependency
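Rectified flows, the training framework MMDiT was designed for, interpolate along a straight line between a data sample and a noise sample and regress the constant velocity between them. A minimal sketch of that interpolation in plain Python (illustrative only, not part of this package):

```python
def rectified_flow_pair(x0, x1, t):
    """Straight-line interpolation x_t = (1 - t) * x0 + t * x1
    and its constant velocity target x1 - x0, element-wise.

    x0: data sample, x1: noise sample, t: scalar in [0, 1].
    """
    xt = [(1 - t) * a + t * b for a, b in zip(x0, x1)]
    velocity = [b - a for a, b in zip(x0, x1)]
    return xt, velocity

# At t = 0 the interpolant is the data sample; at t = 1 it is the noise.
xt, v = rectified_flow_pair([0.0, 2.0], [1.0, 0.0], 0.5)
# xt == [0.5, 1.0], v == [1.0, -2.0]
```

In practice the model (here, a stack of MMDiT blocks) is trained to predict this velocity from the noisy interpolant and the timestep.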

📦 Installation

Make sure you have Python 3.12+ installed.

Using pip

pip install mmdit-pytorch

From source (requires Poetry)

git clone https://github.com/KennyStryker/mmdit-pytorch.git
cd mmdit-pytorch
poetry install

🚀 Usage

A minimal example running a single MMDiT block:

import torch
from mmdit import MMDiTBlock

# Set embedding dimensions for each modality
dim_txt = 768         # Dimension of text embeddings
dim_img = 512         # Dimension of image embeddings
dim_timestep = 256    # Dimension of timestep embeddings (e.g., for conditioning)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the multimodal transformer block
mmdit_block = MMDiTBlock(
    dim_txt=dim_txt,
    dim_img=dim_img,
    dim_timestep=dim_timestep,
    qk_rmsnorm=True  # Use RMSNorm on query/key in attention (optional setting)
).to(device)

# Generate random embeddings for demonstration
txt_emb = torch.randn(1, 512, dim_txt, device=device)    # 512 text tokens
img_emb = torch.randn(1, 1024, dim_img, device=device)   # 1024 image tokens
time_emb = torch.randn(1, dim_timestep, device=device)   # one timestep embedding per sample

# Forward pass through the multimodal transformer block
txt_out, img_out = mmdit_block(txt_emb, img_emb, time_emb)

print(f"Text output shape: {txt_out.shape}")
print(f"Image output shape: {img_out.shape}")
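The qk_rmsnorm flag above applies RMSNorm to queries and keys before attention, a common trick for stabilizing large-scale training. A rough sketch of RMSNorm itself in plain Python (not this package's implementation, shown only to illustrate the operation):

```python
import math

def rms_norm(x, gain=None, eps=1e-6):
    """RMSNorm: scale x by the reciprocal of its root mean square.
    Unlike LayerNorm, there is no mean subtraction and no bias term."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    gain = gain or [1.0] * len(x)  # learnable per-feature scale in practice
    return [g * v / rms for g, v in zip(gain, x)]

out = rms_norm([3.0, 4.0])
# RMS of [3, 4] is sqrt(12.5) ≈ 3.5355, so out ≈ [0.8485, 1.1314]
```

Normalizing queries and keys this way bounds the magnitude of attention logits, which helps prevent attention entropy collapse as model size grows.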

Citations

@article{esser2024scaling,
    title   = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis},
    author  = {Patrick Esser and Sumith Kulal and Andreas Blattmann and Rahim Entezari and Jonas Müller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach},
    journal = {arXiv preprint arXiv:2403.03206},
    year    = {2024},
    url     = {https://arxiv.org/abs/2403.03206}
}

Project details


Download files

Download the file for your platform.

Source Distribution

mmdit_pytorch-0.2.7.tar.gz (5.7 kB)

Uploaded Source

Built Distribution


mmdit_pytorch-0.2.7-py3-none-any.whl (6.6 kB)

Uploaded Python 3

File details

Details for the file mmdit_pytorch-0.2.7.tar.gz.

File metadata

  • Download URL: mmdit_pytorch-0.2.7.tar.gz
  • Size: 5.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.0 CPython/3.12.9 Linux/5.15.167.4-microsoft-standard-WSL2

File hashes

Hashes for mmdit_pytorch-0.2.7.tar.gz:

  • SHA256: 72a2c0272946b4a0fe671bd52d5b8d6c40b0f2106bfeb868ff79cd1a8c441e24
  • MD5: 55d95fe26306c53b12bd3e2d842621e4
  • BLAKE2b-256: d8e8659e30565c23e492d9a85ab37066a8c70de15624e024ca674aaaa38178c9


File details

Details for the file mmdit_pytorch-0.2.7-py3-none-any.whl.

File metadata

  • Download URL: mmdit_pytorch-0.2.7-py3-none-any.whl
  • Size: 6.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.0 CPython/3.12.9 Linux/5.15.167.4-microsoft-standard-WSL2

File hashes

Hashes for mmdit_pytorch-0.2.7-py3-none-any.whl:

  • SHA256: bddd21c405947ecfbd2db5a110d61784c79b9e931019c354a7b12a42c7ae4900
  • MD5: b3568b98f74a9bc91842ae762a7da88d
  • BLAKE2b-256: 6dca32d856fcb7602f443f315b4869db59c15b83094df25a9e6c59bc0984cc7f

