Transfusion in Pytorch
Transfusion - Pytorch (wip)
Pytorch implementation of Transfusion, "Predict the Next Token and Diffuse Images with One Multi-Modal Model", from MetaAI.
Once completed, this will also be extended to flow matching, as well as to audio, video, and perhaps even policies.
Install
$ pip install transfusion-pytorch
Usage
One modality, say images
import torch
from transfusion_pytorch import Transfusion

model = Transfusion(
    num_text_tokens = 256,
    dim_latent = 192,
    transformer = dict(
        dim = 512,
        depth = 8
    )
)

text_ids = torch.randint(0, 256, (2, 1024))

modality_tokens = [[
    torch.randn(6, 192),
    torch.randn(4, 192)
], [
    torch.randn(5, 192),
    torch.randn(3, 192)
]]

modality_positions = [[
    (2, 6),
    (10, 4)
], [
    (2, 5),
    (10, 3)
]] # (offset, length)

loss, breakdown = model(
    text_ids,
    modality_tokens = modality_tokens,
    modality_positions = modality_positions
)

loss.backward()
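The `(offset, length)` pairs can be sanity-checked before a forward pass. The sketch below uses a hypothetical helper, `validate_positions`, which is not part of `transfusion_pytorch`. It assumes that each span must fit inside the text sequence and that spans within one sample must not overlap. The example above satisfies both conditions, but whether the library enforces them is not stated here.

```python
def validate_positions(positions, seq_len):
    """Check the (offset, length) spans of one sample.

    Hypothetical helper, not part of transfusion_pytorch. Assumes spans
    must lie inside [0, seq_len) and must not overlap one another.
    """
    end_of_previous = 0
    for offset, length in sorted(positions):
        if offset < 0:
            raise ValueError(f"negative offset {offset}")
        if length <= 0:
            raise ValueError(f"span at offset {offset} has non-positive length {length}")
        if offset < end_of_previous:
            raise ValueError(f"span at offset {offset} overlaps the previous span")
        if offset + length > seq_len:
            raise ValueError(f"span at offset {offset} runs past sequence length {seq_len}")
        end_of_previous = offset + length
    return True


# Same positions as the single-modality example, with 1024 text tokens per sample
modality_positions = [[(2, 6), (10, 4)], [(2, 5), (10, 3)]]
all_valid = all(validate_positions(sample, seq_len=1024) for sample in modality_positions)
```

Running a check like this on each batch keeps indexing mistakes out of the model call, where they would be harder to trace.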
Multiple modalities
import torch
from transfusion_pytorch import Transfusion

model = Transfusion(
    num_text_tokens = 256,
    dim_latent = (384, 192),
    transformer = dict(
        dim = 512,
        depth = 8
    )
)

text_ids = torch.randint(0, 256, (2, 1024))

modality_tokens = [[
    torch.randn(6, 384),
    torch.randn(4, 192)
], [
    torch.randn(5, 192),
    torch.randn(3, 384)
]]

modality_positions = [[
    (0, 2, 6),
    (1, 10, 4)
], [
    (1, 2, 5),
    (0, 10, 3)
]] # (type, offset, length)

loss, breakdown = model(
    text_ids,
    modality_tokens = modality_tokens,
    modality_positions = modality_positions
)

loss.backward()
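In the example above, the leading `type` entry of each position appears to select an entry of `dim_latent`: type `0` pairs with 384-dimensional tokens and type `1` with 192-dimensional ones. The sketch below makes that reading explicit with a hypothetical helper, `expected_token_shapes`, which is not part of `transfusion_pytorch`. The indexing rule is an assumption inferred from the example rather than a documented guarantee.

```python
def expected_token_shapes(positions, dim_latent):
    """Return the (length, dim) shape expected for each modality tensor in one sample.

    Hypothetical helper, not part of transfusion_pytorch. Assumes the first
    element of each (type, offset, length) tuple indexes into dim_latent.
    """
    shapes = []
    for modality_type, _offset, length in positions:
        if not 0 <= modality_type < len(dim_latent):
            raise ValueError(f"unknown modality type {modality_type}")
        shapes.append((length, dim_latent[modality_type]))
    return shapes


# Values copied from the multi-modality example
dim_latent = (384, 192)
modality_positions = [
    [(0, 2, 6), (1, 10, 4)],
    [(1, 2, 5), (0, 10, 3)],
]
shapes = [expected_token_shapes(sample, dim_latent) for sample in modality_positions]
```

Each entry of `shapes` can then be compared against `tuple(t.shape)` for the matching tensor in `modality_tokens` to catch a mismatched type index before calling the model.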
Citations
@inproceedings{Zhou2024TransfusionPT,
    title   = {Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model},
    author  = {Chunting Zhou and Lili Yu and Arun Babu and Kushal Tirumala and Michihiro Yasunaga and Leonid Shamis and Jacob Kahn and Xuezhe Ma and Luke Zettlemoyer and Omer Levy},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:271909855}
}

@misc{Rubin2024,
    author  = {Ohad Rubin},
    url     = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
}