MM1 - PyTorch
MM1
PyTorch implementation of the paper "MM1: Methods, Analysis & Insights from Multimodal LLM Pre-training".
Architecture: img -> image encoder -> vision-language connector -> LLM -> output tokens
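The one-line pipeline above can be sketched as a toy PyTorch module. Everything in this sketch is a hypothetical stand-in chosen for illustration: the `ToyMM1` class, a strided-conv patch embedding in place of a ViT image encoder, a single linear layer as the connector, and a tiny causal Transformer as the LLM. It is not the `mm1_torch` implementation.

```python
import torch
from torch import nn


class ToyMM1(nn.Module):
    """Hypothetical toy model: image -> encoder -> connector -> LLM -> token logits."""

    def __init__(self, vocab_size=100, dim=64, patch=16, llm_depth=2, heads=4):
        super().__init__()
        # Image encoder stand-in: split the image into patches and embed each one
        self.encoder = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        # Connector: map visual features into the LLM embedding space
        self.connector = nn.Linear(dim, dim)
        # "LLM": token embedding + small Transformer stack + output head
        # (no positional embeddings, for brevity)
        self.tok_emb = nn.Embedding(vocab_size, dim)
        layer = nn.TransformerEncoderLayer(dim, heads, dim_feedforward=4 * dim, batch_first=True)
        self.llm = nn.TransformerEncoder(layer, num_layers=llm_depth)
        self.to_logits = nn.Linear(dim, vocab_size)

    def forward(self, tokens, img):
        feats = self.encoder(img)                   # (b, dim, H/patch, W/patch)
        feats = feats.flatten(2).transpose(1, 2)    # (b, num_patches, dim)
        img_tokens = self.connector(feats)          # (b, num_patches, dim)
        txt = self.tok_emb(tokens)                  # (b, seq, dim)
        seq = torch.cat([img_tokens, txt], dim=1)   # image tokens prefix the text
        n = seq.shape[1]
        # Causal mask: -inf above the diagonal, 0 elsewhere
        causal = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
        h = self.llm(seq, mask=causal)
        # Keep only the text positions and project to vocabulary logits
        return self.to_logits(h[:, img_tokens.shape[1]:])


model = ToyMM1()
tokens = torch.randint(0, 100, (1, 32))
img = torch.randn(1, 3, 224, 224)
logits = model(tokens, img)
print(logits.shape)  # torch.Size([1, 32, 100])
```

With 16x16 patches, a 224x224 image becomes 14 x 14 = 196 visual tokens that are prepended to the 32 text tokens; only the text positions are returned as logits.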
Install
pip3 install mm1-torch
Usage
import torch
from mm1_torch.main import MM1
# Dummy inputs: one sequence of 512 token ids in [0, 100) and one 224x224 RGB image
x = torch.randint(0, 100, (1, 512))
img = torch.randn(1, 3, 224, 224)
# Create a model
model = MM1(
    dim=512,
    depth=12,
    heads=8,
    dim_head=64,
    dropout=0.1,
    num_experts=4,
    num_experts_per_tok=2,
)
# Forward pass on the text tokens and the image
out = model(x, img)
print(out.shape)
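The `num_experts=4` and `num_experts_per_tok=2` arguments suggest mixture-of-experts layers in which each token is routed to 2 of 4 experts. The sketch below shows one common way to implement top-k routing; the `TopKMoE` class, softmax over the selected router logits, and the two-layer GELU expert MLP are assumptions for illustration, not taken from `mm1_torch`.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TopKMoE(nn.Module):
    """Hypothetical top-k mixture-of-experts feed-forward layer (illustrative only)."""

    def __init__(self, dim, num_experts=4, num_experts_per_tok=2, hidden_mult=4):
        super().__init__()
        self.k = num_experts_per_tok
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(dim, hidden_mult * dim),
                nn.GELU(),
                nn.Linear(hidden_mult * dim, dim),
            )
            for _ in range(num_experts)
        )

    def forward(self, x):
        b, n, d = x.shape
        flat = x.reshape(-1, d)                       # (tokens, dim)
        logits = self.router(flat)                    # (tokens, num_experts)
        weights, idx = logits.topk(self.k, dim=-1)    # k best experts per token
        weights = F.softmax(weights, dim=-1)          # renormalize over the chosen k
        out = torch.zeros_like(flat)
        for e, expert in enumerate(self.experts):
            # Which tokens picked expert e, and in which of their k slots
            token_ids, slot = (idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            contrib = weights[token_ids, slot].unsqueeze(-1) * expert(flat[token_ids])
            out.index_add_(0, token_ids, contrib)     # weighted sum of expert outputs
        return out.reshape(b, n, d)


moe = TopKMoE(dim=512, num_experts=4, num_experts_per_tok=2)
y = moe(torch.randn(1, 16, 512))
print(y.shape)  # torch.Size([1, 16, 512])
```

Only k experts run per token, so compute grows with `num_experts_per_tok` rather than with the total number of experts.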
CAbstractor
import torch
from mm1_torch.main import CAbstractor
# Dummy input: one 224x224 RGB image
x = torch.randn(1, 3, 224, 224)
# Create a model
model = CAbstractor(
    dim=512,
    depth=12,
    heads=8,
)
# Forward
out = model(x)
print(out.shape)
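C-Abstractor is one of the vision-language connectors compared in the MM1 paper; it originates in Honeybee and, roughly, applies convolutional (ResNet-style) blocks and adaptive average pooling to turn the image encoder's patch grid into a fixed number of visual tokens. The sketch below illustrates that idea under stated assumptions: `ToyCAbstractor`, `ResBlock`, `num_queries`, and the use of GroupNorm are hypothetical, and unlike the `CAbstractor` usage above it takes ViT patch features rather than raw images.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class ResBlock(nn.Module):
    """Hypothetical residual conv block standing in for a ResNet block."""

    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.GroupNorm(1, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x):
        return F.gelu(x + self.block(x))


class ToyCAbstractor(nn.Module):
    """Conv blocks -> adaptive avg pool -> conv blocks -> linear projection."""

    def __init__(self, in_dim, out_dim, num_queries=64, depth=2):
        super().__init__()
        side = math.isqrt(num_queries)
        assert side * side == num_queries, "num_queries must be a perfect square"
        self.pre = nn.Sequential(*[ResBlock(in_dim) for _ in range(depth)])
        self.pool = nn.AdaptiveAvgPool2d(side)    # any grid -> side x side
        self.post = nn.Sequential(*[ResBlock(in_dim) for _ in range(depth)])
        self.proj = nn.Linear(in_dim, out_dim)    # into the LLM embedding size

    def forward(self, feats):
        # feats: (b, num_patches, in_dim), num_patches a perfect square
        b, n, c = feats.shape
        h = w = math.isqrt(n)
        x = feats.transpose(1, 2).reshape(b, c, h, w)  # back to a 2D patch grid
        x = self.post(self.pool(self.pre(x)))          # (b, c, side, side)
        x = x.flatten(2).transpose(1, 2)               # (b, num_queries, c)
        return self.proj(x)


vit_feats = torch.randn(1, 196, 256)   # e.g. 14x14 patch features
abstractor = ToyCAbstractor(in_dim=256, out_dim=512, num_queries=64)
visual_tokens = abstractor(vit_feats)
print(visual_tokens.shape)  # torch.Size([1, 64, 512])
```

Because the pooling is adaptive, the number of visual tokens passed to the LLM is set by `num_queries` regardless of the input resolution, while the convolutions preserve local spatial structure.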
License
MIT
Download files
Source distribution: mm1_torch-0.0.3.tar.gz (7.1 kB)
Built distribution: mm1_torch-0.0.3-py3-none-any.whl
Hashes for mm1_torch-0.0.3-py3-none-any.whl:
Algorithm | Hash digest
---|---
SHA256 | 372d936fbf7ae3d391688b49a6393ad95a57aea6b64eff23c0aa4902f113f1a1
MD5 | 3676736b9d6417dc83a28fb8d529528e
BLAKE2b-256 | 04d1a491533df83ef34dfa322f46da13a8223b511753eb42a860f7c9aab3d166
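A downloaded wheel can be checked against the SHA256 digest listed above with Python's standard `hashlib` module. The helper names `sha256_of` and `verify` are hypothetical; the expected digest is the one given for `mm1_torch-0.0.3-py3-none-any.whl`.

```python
import hashlib

# SHA256 listed for mm1_torch-0.0.3-py3-none-any.whl
EXPECTED = "372d936fbf7ae3d391688b49a6393ad95a57aea6b64eff23c0aa4902f113f1a1"


def sha256_of(path, chunk_size=1 << 20):
    """Return the SHA256 hex digest of a file, reading it in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected=EXPECTED):
    """Return True if the file's SHA256 digest matches `expected` (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("mm1_torch-0.0.3-py3-none-any.whl")` returns `True` only for an unmodified copy of the wheel. pip's hash-checking mode offers the same guarantee at install time via a requirements line such as `mm1-torch==0.0.3 --hash=sha256:<digest>`.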