MMDiT

Implementation of a single layer of the MMDiT (multimodal diffusion transformer), proposed by Esser et al. for Stable Diffusion 3, in PyTorch.
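
The block takes a conditioning vector (`time_cond` in the usage example). In the SD3 paper, each modality stream is modulated by this vector in the adaLN style of DiT. The sketch below shows that modulation around one sub-layer of one stream. The class name `AdaLNModulation`, the shift/scale/gate split, and the zero initialization are assumptions for illustration, not this package's code.

```python
import torch
from torch import nn


class AdaLNModulation(nn.Module):
    """Hypothetical sketch: conditioning-driven shift, scale and gate around one sub-layer."""

    def __init__(self, dim: int, dim_cond: int):
        super().__init__()
        # LayerNorm without learned affine parameters; the conditioning supplies them instead
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        # predict shift, scale and gate from the conditioning vector in one projection
        self.to_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim_cond, dim * 3))
        # zero init: the gate starts at 0, so the residual branch is initially a no-op
        nn.init.zeros_(self.to_mod[-1].weight)
        nn.init.zeros_(self.to_mod[-1].bias)

    def forward(self, tokens: torch.Tensor, cond: torch.Tensor, branch: nn.Module) -> torch.Tensor:
        # tokens: (batch, seq, dim), cond: (batch, dim_cond)
        shift, scale, gate = self.to_mod(cond).unsqueeze(1).chunk(3, dim=-1)
        modulated = self.norm(tokens) * (1 + scale) + shift
        # gated residual update; branch stands in for attention or the feedforward
        return tokens + gate * branch(modulated)


mod = AdaLNModulation(dim=512, dim_cond=256)
image_tokens = torch.randn(2, 1024, 512)
time_cond = torch.randn(2, 256)
out = mod(image_tokens, time_cond, branch=nn.Linear(512, 512))
```

Because of the zero initialization, `out` equals `image_tokens` before any training step. The gate then learns how much each sub-layer contributes.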

Besides a straight reproduction, it will also generalize to more than two modalities, as I can envision an MMDiT for images, audio, and text.
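
With more than two modalities, the same joint-attention idea applies. Each modality keeps its own projections, and attention runs once over the concatenation of all token sequences. The sketch below is an assumption-laden illustration, not this package's implementation. `JointAttention`, the head count, and sharing the query/key RMSNorm across modalities are choices made here for brevity. The `None` entries in `masks` mirror the `modality_masks = (text_mask, None, None)` argument in the generalized usage example.

```python
import torch
import torch.nn.functional as F
from torch import nn


class RMSNorm(nn.Module):
    # minimal RMSNorm over the last dimension
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class JointAttention(nn.Module):
    """Hypothetical sketch of joint attention over any number of modalities."""

    def __init__(self, dim_modalities, dim_joint_attn=512, heads=8):
        super().__init__()
        assert dim_joint_attn % heads == 0
        self.heads = heads
        dim_head = dim_joint_attn // heads
        # each modality keeps its own projections into and out of the shared attention width
        self.to_qkv = nn.ModuleList([nn.Linear(d, dim_joint_attn * 3, bias=False) for d in dim_modalities])
        self.to_out = nn.ModuleList([nn.Linear(dim_joint_attn, d, bias=False) for d in dim_modalities])
        # query/key RMSNorm, shared across modalities here for brevity
        self.q_norm = RMSNorm(dim_head)
        self.k_norm = RMSNorm(dim_head)

    def forward(self, tokens, masks):
        batch = tokens[0].shape[0]
        lengths = [t.shape[1] for t in tokens]
        # project every modality, then concatenate along the sequence axis
        qkv = torch.cat([proj(t) for proj, t in zip(self.to_qkv, tokens)], dim=1)
        # (batch, seq, heads * dim_head) -> (batch, heads, seq, dim_head)
        q, k, v = (x.reshape(batch, x.shape[1], self.heads, -1).transpose(1, 2) for x in qkv.chunk(3, dim=-1))
        q, k = self.q_norm(q), self.k_norm(k)
        # a None mask means every token of that modality is valid
        key_mask = torch.cat([
            m if m is not None else torch.ones(batch, n, dtype=torch.bool, device=t.device)
            for m, n, t in zip(masks, lengths, tokens)
        ], dim=1)
        # True = may be attended to; broadcast to (batch, 1, 1, total_seq)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=key_mask[:, None, None, :])
        out = out.transpose(1, 2).reshape(batch, sum(lengths), -1)
        # split the joint sequence back per modality and project to each width
        return tuple(proj(o) for proj, o in zip(self.to_out, out.split(lengths, dim=1)))


torch.manual_seed(0)
attn = JointAttention(dim_modalities=(768, 512, 384))
text = torch.randn(2, 16, 768)
video = torch.randn(2, 32, 512)
audio = torch.randn(2, 8, 384)
text_mask = torch.ones(2, 16, dtype=torch.bool)
text_mask[:, 12:] = False  # last 4 text tokens are padding
text_out, video_out, audio_out = attn((text, video, audio), (text_mask, None, None))
```

Each output keeps its modality's original width. Changing the padded text tokens does not change the other modalities' outputs, because masked keys receive zero attention weight.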

It will also offer an improvised variant of self-attention that adaptively selects which weights to use through learned gating, an idea taken from the adaptive convolutions Kang et al. applied in GigaGAN.
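
The README does not spell out how the gating works. As an assumption modeled on GigaGAN's adaptive kernel selection, one reading is a bank of candidate weight matrices mixed per sample by softmax scores predicted from the conditioning. The sketch below applies that idea to a linear projection. `GatedWeightLinear`, `num_weights`, and its use as a QKV projection are hypothetical and not this package's API.

```python
import torch
from torch import nn


class GatedWeightLinear(nn.Module):
    """Hypothetical sketch: a projection whose weight is a learned, gated mix of a weight bank."""

    def __init__(self, dim_in: int, dim_out: int, dim_cond: int, num_weights: int = 4):
        super().__init__()
        # bank of candidate weight matrices
        self.weights = nn.Parameter(torch.randn(num_weights, dim_in, dim_out) * dim_in ** -0.5)
        # gating network: conditioning vector -> one logit per candidate
        self.to_gates = nn.Linear(dim_cond, num_weights)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, dim_in), cond: (batch, dim_cond)
        gates = self.to_gates(cond).softmax(dim=-1)                # (batch, num_weights)
        # per-sample weight: softmax-weighted sum over the bank
        weight = torch.einsum('bn,nio->bio', gates, self.weights)  # (batch, dim_in, dim_out)
        return torch.einsum('bsi,bio->bso', x, weight)             # (batch, seq, dim_out)


# e.g. a gated query/key/value projection conditioned on a time embedding
to_qkv = GatedWeightLinear(dim_in=512, dim_out=512 * 3, dim_cond=256)
qkv = to_qkv(torch.randn(2, 1024, 512), torch.randn(2, 256))
```

Mixing the weights before applying them costs one small einsum per sample. This is cheaper than running every candidate projection and mixing the outputs, and gives the same result because the projection is linear.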

Install

$ pip install mmdit

Usage

import torch
from mmdit import MMDiTBlock

# define an MMDiT block

block = MMDiTBlock(
    dim_joint_attn = 512,
    dim_cond = 256,
    dim_text = 768,
    dim_image = 512,
    qk_rmsnorm = True
)

# mock inputs

time_cond = torch.randn(2, 256)                 # (batch, dim_cond)

text_tokens = torch.randn(2, 512, 768)          # (batch, text_len, dim_text)
text_mask = torch.ones((2, 512)).bool()         # (batch, text_len)

image_tokens = torch.randn(2, 1024, 512)        # (batch, image_len, dim_image)

# single block forward

text_tokens_next, image_tokens_next = block(
    time_cond = time_cond,
    text_tokens = text_tokens,
    text_mask = text_mask,
    image_tokens = image_tokens
)
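
The example passes an all-True `text_mask`. When captions in a batch have different lengths, a padding mask could instead be built from per-sample lengths. This sketch assumes, as the all-True mock suggests, that True marks tokens to attend to. The helper `lengths_to_mask` and the example lengths are hypothetical.

```python
import torch


def lengths_to_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True marks real tokens, False marks padding: shape (batch, max_len)
    positions = torch.arange(max_len, device=lengths.device)
    return positions[None, :] < lengths[:, None]


# hypothetical caption lengths for a batch of two, padded to 512 tokens
text_mask = lengths_to_mask(torch.tensor([77, 512]), max_len=512)
```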

A generalized version that accepts any number of modalities can be used as follows:

import torch
from mmdit.mmdit_generalized_pytorch import MMDiT

mmdit = MMDiT(
    depth = 2,
    dim_modalities = (768, 512, 384),   # text, video, audio
    dim_joint_attn = 512,
    dim_cond = 256,
    qk_rmsnorm = True
)

# mock inputs

time_cond = torch.randn(2, 256)

text_tokens = torch.randn(2, 512, 768)
text_mask = torch.ones((2, 512)).bool()

video_tokens = torch.randn(2, 1024, 512)

audio_tokens = torch.randn(2, 256, 384)

# forward

text_tokens, video_tokens, audio_tokens = mmdit(
    modality_tokens = (text_tokens, video_tokens, audio_tokens),
    modality_masks = (text_mask, None, None),
    time_cond = time_cond,
)

Citations

@article{Esser2024ScalingRF,
    title   = {Scaling Rectified Flow Transformers for High-Resolution Image Synthesis},
    author  = {Patrick Esser and Sumith Kulal and A. Blattmann and Rahim Entezari and Jonas Muller and Harry Saini and Yam Levi and Dominik Lorenz and Axel Sauer and Frederic Boesel and Dustin Podell and Tim Dockhorn and Zion English and Kyle Lacey and Alex Goodwin and Yannik Marek and Robin Rombach},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2403.03206},
    url     = {https://api.semanticscholar.org/CorpusID:268247980}
}
@inproceedings{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timoth{\'e}e Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}

Download files

Source Distribution

mmdit-0.1.4.tar.gz (148.4 kB)

Uploaded Source

Built Distribution

mmdit-0.1.4-py3-none-any.whl (9.9 kB)

Uploaded Python 3

File details

Details for the file mmdit-0.1.4.tar.gz.

File metadata

  • Download URL: mmdit-0.1.4.tar.gz
  • Upload date:
  • Size: 148.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for mmdit-0.1.4.tar.gz
Algorithm Hash digest
SHA256 3b5359a36ff7dd0ab2c6428bd5eeeba037e075f779d2b498328f9e6e53024992
MD5 c0863ab57525e4480c641da0049b63db
BLAKE2b-256 265230e3b0d5584a800d0e8eb6de847dc6db538072640743427977330d2119f1

File details

Details for the file mmdit-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: mmdit-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 9.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.19

File hashes

Hashes for mmdit-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 2ff932c454bedec45303cd8b178c20e322aa5d6390999d8b3cf516c42c30e14e
MD5 311cff27d1cd45eed651d615d882a0db
BLAKE2b-256 dcb7a4a4910033c7f48713a3694657b0f00b9330bdbe746bf5e9eaa26165b64f
