
MMDiT

Implementation of a single layer of the MMDiT (Multimodal Diffusion Transformer), proposed by Esser et al. for Stable Diffusion 3, in PyTorch.

Besides a straight reproduction, it also provides a version generalized to more than two modalities, since I can envision an MMDiT for images, audio, and text.
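The core mechanism of an MMDiT layer is joint attention. Each modality keeps its own projection weights, but the projected queries, keys, and values of all modalities are concatenated along the sequence axis, attended over together, and then split back per modality. The minimal pure-Python sketch below shows that concatenate, attend, split pattern for any number of modalities. It assumes the per-modality projections have already been applied, uses a single head, and omits masking and normalization; the function names are hypothetical and this is an illustration, not the code inside `mmdit`.

```python
import math

def matmul(a, b):
    # (n x k) @ (k x m), with matrices stored as lists of rows
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def joint_attention(qs, ks, vs):
    """Single-head joint attention over any number of modalities.

    qs, ks, vs: one (seq_len x d) matrix per modality, assumed to be already
    projected with that modality's own weights.
    Returns one (seq_len x d) output per modality.
    """
    lengths = [len(q) for q in qs]
    # concatenate every modality along the sequence axis
    q = [row for m in qs for row in m]
    k = [row for m in ks for row in m]
    v = [row for m in vs for row in m]
    scale = 1 / math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]
    # every token attends to every token of every modality
    attn = [softmax([s * scale for s in row]) for row in matmul(q, k_t)]
    out = matmul(attn, v)
    # split the joint output back into per-modality sequences
    outputs, start = [], 0
    for n in lengths:
        outputs.append(out[start:start + n])
        start += n
    return outputs

# usage: 2 text tokens and 3 image tokens, head dim 2
text = [[1.0, 0.0], [0.0, 1.0]]
image = [[0.5, 0.5], [1.0, 1.0], [0.0, 0.0]]
text_out, image_out = joint_attention([text, image], [text, image], [text, image])
```

Because the text queries attend over the image keys and values, changing the image tokens changes the text output. This is what lets information flow between modalities inside a single block while each modality still keeps separate weights.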

It will also offer an improvised variant of self-attention that adaptively selects which weights to use through learned gating. The idea comes from the adaptive convolutions Kang et al. applied in GigaGAN.
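As a rough illustration of adaptive weight selection, the sketch below follows the GigaGAN recipe of keeping a bank of candidate weights and blending them with softmax gates computed from a conditioning vector. The names `select_weights`, `gate_proj`, and `weight_bank` are hypothetical, and how the planned attention variant will actually compute its gates, and which projections it will gate, is an assumption here rather than something this package documents.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def select_weights(cond, gate_proj, weight_bank):
    """Blend a bank of weight matrices using gates computed from a condition.

    cond:        conditioning vector of length c
    gate_proj:   hypothetical learned gating matrix, one row of length c per bank entry
    weight_bank: candidate weight matrices, all of the same shape
    """
    # one gating logit per candidate weight matrix
    logits = [sum(g * x for g, x in zip(row, cond)) for row in gate_proj]
    gates = softmax(logits)
    rows, cols = len(weight_bank[0]), len(weight_bank[0][0])
    # convex combination of the candidates, weighted by the gates
    return [
        [sum(gate * w[i][j] for gate, w in zip(gates, weight_bank)) for j in range(cols)]
        for i in range(rows)
    ]

# usage: a bank of two 2x2 candidates
eye = [[1.0, 0.0], [0.0, 1.0]]
double = [[2.0, 0.0], [0.0, 2.0]]
gate_proj = [[1.0, 0.0], [0.0, 1.0]]
blended = select_weights([0.0, 0.0], gate_proj, [eye, double])  # uniform gates: 1.5 * identity
```

With a zero conditioning vector the gates are uniform, so the result is the average of the bank; a large condition value pushes the gates toward a single entry. In an attention layer, the blended matrix could then serve, for example, as the query, key, or value projection for that sample.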

Install

$ pip install mmdit

Usage

import torch
from mmdit import MMDiTBlock

# define mm dit block

block = MMDiTBlock(
    dim_cond = 256,
    dim_text = 768,
    dim_image = 512,
    qk_rmsnorm = True
)

# mock inputs

time_cond = torch.randn(2, 256)

text_tokens = torch.randn(2, 512, 768)
text_mask = torch.ones((2, 512)).bool()

image_tokens = torch.randn(2, 1024, 512)

# single block forward

text_tokens_next, image_tokens_next = block(
    time_cond = time_cond,
    text_tokens = text_tokens,
    text_mask = text_mask,
    image_tokens = image_tokens
)
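The `qk_rmsnorm = True` flag presumably enables RMS normalization of the queries and keys before the attention dot product, which the SD3 paper uses to keep attention logits from growing during training. The minimal sketch below, with hypothetical helper names, shows the effect for a single query/key pair: after normalization the logit no longer depends on the raw magnitude of the query, and with unit gains its absolute value is bounded by the square root of the head dimension.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    # rescale to unit root-mean-square, then apply a learned per-channel gain
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gamma, x)]

def qk_rmsnorm_logit(q, k, gamma_q, gamma_k):
    # attention logit for one query/key pair, normalizing both before the dot product
    qn = rms_norm(q, gamma_q)
    kn = rms_norm(k, gamma_k)
    return sum(a * b for a, b in zip(qn, kn)) / math.sqrt(len(q))

# usage: scaling the query by 1000 leaves the logit (essentially) unchanged
gain = [1.0] * 4
key = [0.5, -1.0, 2.0, 0.0]
small = qk_rmsnorm_logit([1.0, 2.0, 3.0, 4.0], key, gain, gain)
large = qk_rmsnorm_logit([1000.0, 2000.0, 3000.0, 4000.0], key, gain, gain)
```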

A generalized version, which accepts any number of modalities, can be used as follows:

import torch
from mmdit.mmdit_generalized_pytorch import MMDiT

mmdit = MMDiT(
    depth = 2, 
    dim_modalities = (768, 512, 384),
    dim_cond = 256,
    qk_rmsnorm = True
)

# mock inputs

time_cond = torch.randn(2, 256)

text_tokens = torch.randn(2, 512, 768)
text_mask = torch.ones((2, 512)).bool()

video_tokens = torch.randn(2, 1024, 512)

audio_tokens = torch.randn(2, 256, 384)

# forward

text_tokens, video_tokens, audio_tokens = mmdit(
    modality_tokens = (text_tokens, video_tokens, audio_tokens),
    modality_masks = (text_mask, None, None),
    time_cond = time_cond,
)
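In the generalized call, `modality_masks` holds one entry per modality, with `None` passed for modalities that have no padding. A plausible way to apply such masks, sketched below under that assumption, is to expand each entry into a per-token boolean list, concatenate the lists into one key mask matching the joint sequence, and exclude masked keys from the softmax. The helper names are hypothetical, and the sketch assumes at least one valid key per query.

```python
import math

def joint_key_mask(modality_lengths, modality_masks):
    """Concatenate per-modality masks into one mask over the joint sequence.

    Assumed convention: a mask of None means every token of that modality is valid.
    """
    joint = []
    for n, mask in zip(modality_lengths, modality_masks):
        joint.extend([True] * n if mask is None else list(mask))
    return joint

def masked_softmax(scores, key_mask):
    # masked keys get a score of -inf, hence exactly zero weight;
    # assumes at least one key is valid
    vals = [s if keep else -math.inf for s, keep in zip(scores, key_mask)]
    m = max(vals)
    exps = [math.exp(v - m) for v in vals]
    total = sum(exps)
    return [e / total for e in exps]

# usage: text (3 tokens, last one padding), video (2 tokens), audio (2 tokens)
mask = joint_key_mask((3, 2, 2), ([True, True, False], None, None))
weights = masked_softmax([0.0] * 7, mask)
```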

Citations

@article{Esser2024ScalingRF,
    title   = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis},
    author  = {Patrick Esser and Sumith Kulal and A. Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2403.03206},
    url     = {https://api.semanticscholar.org/CorpusID:268247980}
}
@inproceedings{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timoth{\'e}e Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}
@article{Zhu2024HyperConnections,
    title   = {Hyper-Connections},
    author  = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2409.19606},
    url     = {https://api.semanticscholar.org/CorpusID:272987528}
}

Download files

Source Distribution

mmdit-0.3.0.tar.gz (149.0 kB)

Uploaded Source

Built Distribution

mmdit-0.3.0-py3-none-any.whl (10.5 kB)

Uploaded Python 3

File details

Details for the file mmdit-0.3.0.tar.gz.

File metadata

  • Download URL: mmdit-0.3.0.tar.gz
  • Upload date:
  • Size: 149.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.25

File hashes

Hashes for mmdit-0.3.0.tar.gz
Algorithm Hash digest
SHA256 977101788fda71d25407e1fd73d5f3858147a9cda50d997d5b82892decec5016
MD5 99f7e9c95d21af1b46b3fa6203d67f73
BLAKE2b-256 d55d6f22c46093f1ef675144aaef8971ab16cac98b39357160e1e1330add0d80

File details

Details for the file mmdit-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: mmdit-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 10.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.9.25

File hashes

Hashes for mmdit-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 0059b18d9739af7e78edc585a57a9dedc893309e83b5c4c9e9c0c29e520f1959
MD5 815fb2ff0881b4b3b08a6d3efaf5755b
BLAKE2b-256 73f9f6a628053afa0423acb1f6293cf80249bb864d67cb4fe86b759def28ff7a
