M2PT - PyTorch
Multi-Modal Pathway Transformer
Implementation of M2PT in PyTorch from the paper "Multimodal Pathway: Improve Transformers with Irrelevant Data from Other Modalities" (https://arxiv.org/abs/2401.14405). The striking finding is that simply merging the linear projections of models trained on other, even irrelevant, modalities into your base model can increase its performance. It is a small but effective technique that can be plugged into any model with minor changes.
Install
pip3 install -U m2pt
Usage
M2PT
A training-ready implementation of the M2PT model that can be merged with the linear layers of any multi-modal model: just plug them in. It takes tokenized text (integer token IDs), embeds the tokens, passes them through the transformer blocks, and at the end projects them over the vocabulary and applies a softmax.
import torch
from torch import nn
from m2pt.main import M2PT
# Create an instance of the M2PT model class with the specified parameters
model = M2PT(
dim=512, # Dimension of the input and output tensors
num_tokens=10000, # Number of tokens in the vocabulary
depth=6, # Number of transformer blocks
dim_head=64, # Dimension of each attention head
heads=8, # Number of attention heads
dropout=0.1, # Dropout rate
ff_mult=4, # Multiplier for the dimension of the feed-forward network
original_linear=nn.Linear(512, 512), # Linear layer for the original input tensor
auxiliar_linear=nn.Linear(512, 512), # Linear layer for the auxiliary input tensor
ffn_original_linear=nn.Linear, # Linear layer for the original input tensor in the feed-forward network
ffn_auxiliar_linear=nn.Linear, # Linear layer for the auxiliary input tensor in the feed-forward network
ffn_original_last_linear=nn.Linear, # Last linear layer for the original input tensor in the feed-forward network
ffn_aux_last_linear=nn.Linear, # Last linear layer for the auxiliary input tensor in the feed-forward network
)
# Create a batch of token IDs with shape (B, S)
x = torch.randint(0, 10000, (1, 512))
# Pass the input tensor through the model
out = model(x)
# Print the shape of the output tensor
print(out.shape)
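Assuming the head projects over the vocabulary as described above, the printed shape should be torch.Size([1, 512, 10000]), i.e. (batch, sequence_length, num_tokens); check against your installed version.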
MPTransformerBlock
- Implementation of Figure 2, the Multimodal Pathway Transformer block with a cross-modal FFN; plug in and play your own FFN
- Re-usable and modular
- Combines linear projections from multiple models
import torch
from torch import nn
from m2pt import MPTransformerBlock
# Create an instance of the MPTransformerBlock class with the specified parameters
model = MPTransformerBlock(
dim=512, # Dimension of the input and output tensors
dim_head=64, # Dimension of each attention head
heads=8, # Number of attention heads
dropout=0.1, # Dropout rate
ff_mult=4, # Multiplier for the dimension of the feed-forward network
original_linear=nn.Linear(512, 512), # Linear layer for the original input tensor
auxiliar_linear=nn.Linear(512, 512), # Linear layer for the auxiliary input tensor
ffn_original_linear=nn.Linear, # Linear layer for the original input tensor in the feed-forward network
ffn_auxiliar_linear=nn.Linear, # Linear layer for the auxiliary input tensor in the feed-forward network
ffn_original_last_linear=nn.Linear, # Last linear layer for the original input tensor in the feed-forward network
ffn_aux_last_linear=nn.Linear, # Last linear layer for the auxiliary input tensor in the feed-forward network
)
# Create a 3D tensor of shape (B, S, D)
x = torch.randn(1, 512, 512)
# Pass the input tensor through the model
out = model(x)
# Print the shape of the output tensor
print(out.shape)
CrossModalReparameterization
- Implementation of Cross-Modal Reparameterization from Figure 2 and Section 3.2 of the paper
- Combines the linear layers of different modality-specific models: the auxiliary layer's weights are added to the original layer's weights, scaled by a learnable constant λ, the Cross-Modal Scale (see the sketch after this list)
- Modular and re-usable: simply plug in the linear layers from any models!
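For intuition, here is a minimal sketch of that computation, assuming the reparameterization y = (W + λW')x + (b + λb') from Section 3.2 of the paper; the names original, auxiliary, and cross_modal_scale are illustrative, not the library's internals.
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative only: a linear layer augmented by an auxiliary layer from
# another modality, scaled by a learnable Cross-Modal Scale (lambda).
original = nn.Linear(512, 512)   # layer from the target-modality model
auxiliary = nn.Linear(512, 512)  # layer from the auxiliary-modality model
cross_modal_scale = nn.Parameter(torch.zeros(1))  # lambda, trained with the model

x = torch.randn(1, 10, 512)
# Equivalent to F.linear(x, W + lambda * W_aux, b + lambda * b_aux)
out = original(x) + cross_modal_scale * F.linear(x, auxiliary.weight, auxiliary.bias)
print(out.shape)  # torch.Size([1, 10, 512])
A fuller example that reparameterizes a layer of a pretrained text model with a layer from a pretrained vision model: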
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig, ViTModel, ViTConfig
from m2pt import CrossModalReparameterization
# Define a simple Transformer model for text
class TextTransformerModel(nn.Module):
def __init__(self, bert_model_name='bert-base-uncased'):
super(TextTransformerModel, self).__init__()
self.bert = BertModel.from_pretrained(bert_model_name)
# Assume we're reparameterizing the first linear layer of the classifier
self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
logits = self.classifier(pooled_output)
return logits
# Define a simple Transformer model for images (using ViT for example)
class ImageTransformerModel(nn.Module):
def __init__(self, vit_model_name='google/vit-base-patch16-224'):
super(ImageTransformerModel, self).__init__()
self.vit = ViTModel.from_pretrained(vit_model_name)
# Assume we're using the first linear layer of the classifier as the auxiliary layer
self.classifier = nn.Linear(self.vit.config.hidden_size, 2)
def forward(self, pixel_values):
outputs = self.vit(pixel_values=pixel_values)
pooled_output = outputs.pooler_output
logits = self.classifier(pooled_output)
return logits
# Example usage
# Initialize both models
text_model = TextTransformerModel()
image_model = ImageTransformerModel()
# Assume we want to reparameterize the classifier layer of the text model
# using the classifier layer of the image model
cross_modal_layer = CrossModalReparameterization(text_model.classifier, image_model.classifier)
# Replace the classifier in the text model with the cross-modal layer
text_model.classifier = cross_modal_layer
# Example input (batch_size, sequence_length)
input_ids = torch.randint(0, 1000, (8, 512))
attention_mask = torch.ones(8, 512)
# Forward pass through the reparameterized model
logits = text_model(input_ids, attention_mask)
print(logits)
# Train the text model as usual...
# After training, merge the parameters for inference
text_model.classifier.merge_parameters()
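Merging folds the scaled auxiliary weights into the original layer, so inference costs no more than a single plain linear layer. A minimal sketch of what such a merge computes, assuming the reparameterization above (the merge function below is illustrative, not the library's implementation):
import torch
import torch.nn as nn

# Illustrative merge: fold W + lambda * W_aux (and likewise the biases)
# into a single plain nn.Linear for inference.
def merge(original: nn.Linear, auxiliary: nn.Linear, cross_modal_scale: float) -> nn.Linear:
    merged = nn.Linear(original.in_features, original.out_features)
    with torch.no_grad():
        merged.weight.copy_(original.weight + cross_modal_scale * auxiliary.weight)
        merged.bias.copy_(original.bias + cross_modal_scale * auxiliary.bias)
    return merged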
Citation
@misc{zhang2024multimodal,
    title={Multimodal Pathway: Improve Transformers with Irrelevant Data from Other Modalities},
    author={Yiyuan Zhang and Xiaohan Ding and Kaixiong Gong and Yixiao Ge and Ying Shan and Xiangyu Yue},
    year={2024},
    eprint={2401.14405},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
License
MIT