MM1 - PyTorch
MM1
PyTorch Implementation of the paper "MM1: Methods, Analysis & Insights from Multimodal LLM Pre-training".
img -> encoder -> connector -> llm -> tokens
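The pipeline above can be illustrated with a toy model. This is a minimal sketch and not the `mm1_torch` implementation: the class name `ToyMultimodalLM`, its parameters, the patch-embedding encoder, the pooling connector, and the small unmasked Transformer standing in for the LLM are all assumptions chosen for brevity.

```python
import torch
from torch import nn


class ToyMultimodalLM(nn.Module):
    """Hypothetical toy model for illustration; not the mm1_torch code."""

    def __init__(self, vocab_size=100, dim=64, patch_size=16, num_image_tokens=16):
        super().__init__()
        # Encoder: split the image into patches and embed each patch.
        self.encoder = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        # Connector: pool the patch sequence down to a fixed number of
        # image tokens, then project them into the LLM embedding space.
        self.pool = nn.AdaptiveAvgPool1d(num_image_tokens)
        self.proj = nn.Linear(dim, dim)
        # LLM stand-in: token embedding, a small Transformer stack
        # (no causal mask, for brevity), and a vocabulary head.
        self.token_emb = nn.Embedding(vocab_size, dim)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)
        self.llm = nn.TransformerEncoder(layer, num_layers=2)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, text, img):
        patches = self.encoder(img).flatten(2)         # (B, dim, num_patches)
        pooled = self.pool(patches).transpose(1, 2)    # (B, num_image_tokens, dim)
        image_tokens = self.proj(pooled)               # (B, num_image_tokens, dim)
        text_tokens = self.token_emb(text)             # (B, T, dim)
        # Image tokens are prepended to the text tokens before the LLM.
        seq = torch.cat([image_tokens, text_tokens], dim=1)
        return self.head(self.llm(seq))                # (B, num_image_tokens + T, vocab)


model = ToyMultimodalLM()
text = torch.randint(0, 100, (1, 32))
img = torch.randn(1, 3, 224, 224)
logits = model(text, img)
print(logits.shape)  # torch.Size([1, 48, 100])
```

The 16 image tokens are concatenated with the 32 text tokens, so the output holds one row of vocabulary logits for each of the 48 positions.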
Install
pip3 install mm1-torch
Usage
import torch
from mm1_torch.main import MM1
# Example inputs: one sequence of 512 integer token IDs and one 224x224 RGB image
x = torch.randint(0, 100, (1, 512))
img = torch.randn(1, 3, 224, 224)

# Create a model
model = MM1(
    dim=512,
    depth=12,
    heads=8,
    dim_head=64,
    dropout=0.1,
    num_experts=4,
    num_experts_per_tok=2,
    encoder_dim=512,
    encoder_depth=12,
    encoder_heads=8,
)

# Forward pass on the token IDs and the image
out = model(x, img)
print(out.shape) # torch.Size([2, 3, 512])
print(out)
CAbstractor
import torch
from mm1_torch.main import CAbstractor
# Example input: one sequence of 100 feature vectors of width 512
x = torch.randn(1, 100, 512)

# Create a model
model = CAbstractor(
    dim=512,
    depth=12,
    heads=8,
)

# Forward pass
out = model(x)
print(out.shape)
License
MIT