
Transfusion in Pytorch


Transfusion - Pytorch (wip)

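PyTorch implementation of Transfusion, from the Meta AI paper "Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model" (Zhou et al., 2024). The paper trains a single transformer on mixed sequences of text and images, using next-token prediction for text tokens and diffusion for image tokens.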
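In the paper, one model is trained with two losses at once: next-token cross-entropy on text positions and a diffusion (noise-prediction) loss on image latents, summed with a weighting coefficient. The sketch below illustrates that combined objective in plain PyTorch. It is not this library's code: the names `add_noise`, `transfusion_style_loss` and `diffusion_weight` are hypothetical, and random tensors stand in for model outputs.

```python
import torch
import torch.nn.functional as F

def add_noise(latents: torch.Tensor, alpha_bar: float):
    # forward diffusion: mix clean latents with Gaussian noise at level alpha_bar in (0, 1)
    alpha_bar = torch.as_tensor(alpha_bar, dtype=latents.dtype)
    noise = torch.randn_like(latents)
    noisy = alpha_bar.sqrt() * latents + (1 - alpha_bar).sqrt() * noise
    return noisy, noise

def transfusion_style_loss(
    text_logits: torch.Tensor,      # (batch, seq, vocab) predictions at text positions
    text_targets: torch.Tensor,     # (batch, seq) next-token targets
    pred_noise: torch.Tensor,       # model's noise prediction for the noised latents
    true_noise: torch.Tensor,       # noise that add_noise actually mixed in
    diffusion_weight: float = 1.0,  # hypothetical balancing coefficient
):
    # language-modeling term: next-token cross-entropy
    lm_loss = F.cross_entropy(
        text_logits.reshape(-1, text_logits.shape[-1]),
        text_targets.reshape(-1),
    )
    # diffusion term: mean squared error on the predicted noise
    diffusion_loss = F.mse_loss(pred_noise, true_noise)
    return lm_loss + diffusion_weight * diffusion_loss, (lm_loss, diffusion_loss)

# usage with random stand-ins for model outputs
logits = torch.randn(2, 8, 256)
targets = torch.randint(0, 256, (2, 8))
noisy, noise = add_noise(torch.randn(6, 192), alpha_bar=0.5)
pred_noise = torch.randn_like(noise)  # would come from the transformer in practice
total, (lm_loss, diffusion_loss) = transfusion_style_loss(logits, targets, pred_noise, noise)
```

Returning the individual terms alongside the total makes it easy to log each loss separately while backpropagating through the sum.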

Once this is complete, I plan to extend it to flow matching, as well as to audio, video, and perhaps even policies.
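Flow matching replaces noise prediction with regression of a velocity field along a path from noise to data. The following is a minimal sketch of one common variant (a straight-line, rectified-flow style path), assuming 2D `(num_tokens, dim)` latents. The helper names are hypothetical, and this is not how the library implements (or will implement) flow matching.

```python
import torch
import torch.nn.functional as F

def flow_matching_pair(data: torch.Tensor):
    # data: (num_tokens, dim) clean latents; one random time t per token row
    noise = torch.randn_like(data)
    t = torch.rand(data.shape[0], 1, dtype=data.dtype)
    # straight-line path: pure noise at t = 0, clean data at t = 1
    x_t = (1 - t) * noise + t * data
    # velocity along that path is constant: d x_t / d t = data - noise
    velocity = data - noise
    return x_t, t, velocity

def flow_matching_loss(pred_velocity: torch.Tensor, target_velocity: torch.Tensor):
    # the model is trained to regress the velocity at (x_t, t)
    return F.mse_loss(pred_velocity, target_velocity)

data = torch.randn(6, 192)
x_t, t, velocity = flow_matching_pair(data)
loss = flow_matching_loss(torch.zeros_like(velocity), velocity)  # zeros stand in for a model output
```

Because the path is a straight line, the regression target does not depend on `t`, which keeps the objective as simple as the noise-prediction loss it replaces.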

Install

$ pip install transfusion-pytorch

Usage

One modality, say images

import torch
from transfusion_pytorch import Transfusion

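model = Transfusion(
    num_text_tokens = 256,    # text vocabulary size
    dim_latent = 192,         # feature dimension of each modality token
    transformer = dict(
        dim = 512,
        depth = 8
    )
)

# batch of 2 text sequences, 1024 tokens each
text_ids = torch.randint(0, 256, (2, 1024))

# for each batch element, a list of modality token tensors of shape (length, dim_latent)
modality_tokens = [[
    torch.randn(6, 192),
    torch.randn(4, 192)
], [
    torch.randn(5, 192),
    torch.randn(3, 192)
]]

# one (offset, length) pair per tensor above; each length matches that tensor's first dimension
modality_positions = [[
    (2, 6),
    (10, 4)
], [
    (2, 5),
    (10, 3)
]] # (offset, length)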

loss, breakdown = model(
    text_ids,
    modality_tokens = modality_tokens,
    modality_positions = modality_positions
)

loss.backward()
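The paper describes attention that is causal across the whole sequence, while tokens belonging to the same image attend to each other bidirectionally. Below is a minimal sketch of such a mask built from one sample's (offset, length) spans, interpreting each span as a range in a single combined sequence. That interpretation, and the helper name `transfusion_attention_mask`, are assumptions; whether this library builds its mask the same way internally is not shown here.

```python
import torch

def transfusion_attention_mask(seq_len: int, positions):
    # positions: list of (offset, length) spans for one sample, like one row of modality_positions
    # start from a standard causal mask: True at [i, j] means token i may attend to token j <= i
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    for offset, length in positions:
        # tokens inside the same modality span attend to each other in both directions
        mask[offset:offset + length, offset:offset + length] = True
    return mask

mask = transfusion_attention_mask(16, [(2, 6), (10, 4)])
```

A boolean mask in this form (True meaning "may attend") matches the convention of `torch.nn.functional.scaled_dot_product_attention`'s `attn_mask` argument, so it could be passed there directly.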

Multiple modalities

import torch
from transfusion_pytorch import Transfusion

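model = Transfusion(
    num_text_tokens = 256,
    dim_latent = (384, 192),  # one latent dimension per modality type: type 0 -> 384, type 1 -> 192
    transformer = dict(
        dim = 512,
        depth = 8
    )
)

text_ids = torch.randint(0, 256, (2, 1024))

# here, each tensor's last dimension matches dim_latent for its modality type
modality_tokens = [[
    torch.randn(6, 384),
    torch.randn(4, 192)
], [
    torch.randn(5, 192),
    torch.randn(3, 384)
]]

modality_positions = [[
    (0, 2, 6),
    (1, 10, 4)
], [
    (1, 2, 5),
    (0, 10, 3)
]] # (type, offset, length), where type indexes into dim_latent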

loss, breakdown = model(
    text_ids,
    modality_tokens = modality_tokens,
    modality_positions = modality_positions
)

loss.backward()
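Keeping `modality_tokens` and `modality_positions` in sync by hand is error-prone once there are several modality types. The following is a minimal sketch of a hypothetical helper, `build_modality_inputs` (not part of transfusion-pytorch), that derives each length from its tensor and checks the last dimension against `dim_latent` for that type before producing the (type, offset, length) format used above.

```python
import torch

def build_modality_inputs(samples, dim_latent):
    # samples: per batch element, a list of (modality_type, offset, tokens) triples
    # dim_latent: tuple of latent dims indexed by modality type, e.g. (384, 192)
    modality_tokens, modality_positions = [], []
    for sample in samples:
        tokens_row, positions_row = [], []
        for modality_type, offset, tokens in sample:
            expected_dim = dim_latent[modality_type]
            if tokens.ndim != 2 or tokens.shape[-1] != expected_dim:
                raise ValueError(
                    f"modality {modality_type} expects shape (length, {expected_dim}), "
                    f"got {tuple(tokens.shape)}"
                )
            tokens_row.append(tokens)
            # the length is read from the tensor itself, so it cannot drift out of sync
            positions_row.append((modality_type, offset, tokens.shape[0]))
        modality_tokens.append(tokens_row)
        modality_positions.append(positions_row)
    return modality_tokens, modality_positions

modality_tokens, modality_positions = build_modality_inputs([
    [(0, 2, torch.randn(6, 384)), (1, 10, torch.randn(4, 192))],
    [(1, 2, torch.randn(5, 192)), (0, 10, torch.randn(3, 384))],
], dim_latent = (384, 192))
```

Deriving lengths from the tensors means only the offsets and types have to be chosen by hand; the resulting lists reproduce the positions in the example above.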

Citations

@inproceedings{Zhou2024TransfusionPT,
    title  = {Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model},
    author = {Chunting Zhou and Lili Yu and Arun Babu and Kushal Tirumala and Michihiro Yasunaga and Leonid Shamis and Jacob Kahn and Xuezhe Ma and Luke Zettlemoyer and Omer Levy},
    year   = {2024},
    url    = {https://api.semanticscholar.org/CorpusID:271909855}
}
@misc{Rubin2024,
    author  = {Ohad Rubin},
    url     = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
}



Download files

Source distribution: transfusion_pytorch-0.0.5.tar.gz (346.2 kB)

Built distribution: transfusion_pytorch-0.0.5-py3-none-any.whl (10.1 kB, Python 3)
