
M2PT - Pytorch

Project description


Multi-Modal Pathway Transformer


Implementation of M2PT in PyTorch from the paper: "Multimodal Pathway: Improve Transformers with Irrelevant Data from Other Modalities" (arXiv:2401.14405). The result is striking: simply by merging the projections of models from other modalities into your base model, you can increase its performance. It is a small but effective technique that can be plugged into almost any model with a minor change.

Install

pip3 install -U m2pt

Usage

M2PT

A fully ready-to-train implementation of the M2PT model that can merge in the linear layers from any multi-modal model: just plug them in. It takes tokenized text (integer token IDs), embeds it, passes the embeddings through the transformer blocks, and at the end projects them and applies a softmax.

import torch
from torch import nn
from m2pt.main import M2PT

# Create an instance of the M2PT model class with the specified parameters
model = M2PT(
    dim=512,  # Dimension of the input and output tensors
    num_tokens=10000,  # Number of tokens in the vocabulary
    depth=6,  # Number of transformer blocks
    dim_head=64,  # Dimension of each attention head
    heads=8,  # Number of attention heads
    dropout=0.1,  # Dropout rate
    ff_mult=4,  # Multiplier for the dimension of the feed-forward network
    original_linear=nn.Linear(512, 512),  # Linear layer for the original input tensor
    auxiliar_linear=nn.Linear(512, 512),  # Linear layer for the auxiliary input tensor
    ffn_original_linear=nn.Linear,  # Linear layer for the original input tensor in the feed-forward network
    ffn_auxiliar_linear=nn.Linear,  # Linear layer for the auxiliary input tensor in the feed-forward network
    ffn_original_last_linear=nn.Linear,  # Last linear layer for the original input tensor in the feed-forward network
    ffn_aux_last_linear=nn.Linear,  # Last linear layer for the auxiliary input tensor in the feed-forward network
)

# Create a 2D tensor of token IDs with shape B x S
x = torch.randint(0, 10000, (1, 512))

# Pass the input tensor through the model
out = model(x)

# Print the shape of the output tensor
print(out.shape)
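
Since the model is described as ready to train, a minimal training-step sketch might look like the following. This is a hypothetical sketch, not part of the library: it assumes the model returns per-token probabilities of shape (batch, seq_len, num_tokens), as described above, so the loss is computed on log-probabilities; adapt it to the actual output semantics.

import torch
import torch.nn.functional as F

# Continuing from the example above (reuses `model` and `x`)
# Hypothetical next-token targets over the same 10000-token vocabulary
targets = torch.randint(0, 10000, (1, 512))

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

probs = model(x)  # assumed shape: (1, 512, 10000), probabilities after softmax
loss = F.nll_loss(probs.clamp_min(1e-9).log().view(-1, 10000), targets.view(-1))

loss.backward()
optimizer.step()
optimizer.zero_grad()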

MPTransformerBlock

  • Implementation of Figure 2 from the paper: the Multimodal Pathway Transformer block with a cross-modal FFN. Plug in and play with your own FFN.

  • Reusable and modular.

  • Combines linear projections from multiple models.

import torch
from torch import nn
from m2pt import MPTransformerBlock

# Create an instance of the MPTransformerBlock class with the specified parameters
model = MPTransformerBlock(
    dim=512,  # Dimension of the input and output tensors
    dim_head=64,  # Dimension of each attention head
    heads=8,  # Number of attention heads
    dropout=0.1,  # Dropout rate
    ff_mult=4,  # Multiplier for the dimension of the feed-forward network
    original_linear=nn.Linear(512, 512),  # Linear layer for the original input tensor
    auxiliar_linear=nn.Linear(512, 512),  # Linear layer for the auxiliary input tensor
    ffn_original_linear=nn.Linear,  # Linear layer for the original input tensor in the feed-forward network
    ffn_auxiliar_linear=nn.Linear,  # Linear layer for the auxiliary input tensor in the feed-forward network
    ffn_original_last_linear=nn.Linear,  # Last linear layer for the original input tensor in the feed-forward network
    ffn_aux_last_linear=nn.Linear,  # Last linear layer for the auxiliary input tensor in the feed-forward network
)

# Create a 3D tensor with shape B x S x D
x = torch.randn(1, 512, 512)

# Pass the input tensor through the model
out = model(x)

# Print the shape of the output tensor
print(out.shape)
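
Because each block maps a (batch, seq, dim) tensor to a tensor of the same shape, several blocks can be stacked, for example with nn.Sequential. This is only an illustrative sketch that reuses the constructor arguments shown above and assumes the block's forward takes a single tensor.

import torch
from torch import nn
from m2pt import MPTransformerBlock

# Stack several identically configured blocks (illustrative only)
blocks = nn.Sequential(
    *[
        MPTransformerBlock(
            dim=512,
            dim_head=64,
            heads=8,
            dropout=0.1,
            ff_mult=4,
            original_linear=nn.Linear(512, 512),
            auxiliar_linear=nn.Linear(512, 512),
            ffn_original_linear=nn.Linear,
            ffn_auxiliar_linear=nn.Linear,
            ffn_original_last_linear=nn.Linear,
            ffn_aux_last_linear=nn.Linear,
        )
        for _ in range(6)
    ]
)

out = blocks(torch.randn(1, 512, 512))  # still (1, 512, 512)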

CrossModalReparameterization

  • Implementation of the Cross-Modal Reparameterization from Figure 2 and Section 3.2 of the paper.

  • It combines the linear layers of different multi-modal models, merging them through addition weighted by a scalar lambda, the Cross-Modal Scale (see the sketch after this list).

  • Modular and reusable: simply plug in the linear layers from any models!

import torch
import torch.nn as nn
from transformers import BertModel, ViTModel

from m2pt import CrossModalReparameterization

# Define a simple Transformer model for text
class TextTransformerModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-uncased'):
        super(TextTransformerModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)

        # Assume we're reparameterizing the first linear layer of the classifier
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# Define a simple Transformer model for images (using ViT for example)
class ImageTransformerModel(nn.Module):
    def __init__(self, vit_model_name='google/vit-base-patch16-224'):
        super(ImageTransformerModel, self).__init__()
        self.vit = ViTModel.from_pretrained(vit_model_name)

        # Assume we're using the first linear layer of the classifier as the auxiliary layer
        self.classifier = nn.Linear(self.vit.config.hidden_size, 2)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# Example usage
# Initialize both models
text_model = TextTransformerModel()
image_model = ImageTransformerModel()

# Assume we want to reparameterize the classifier layer of the text model
# using the classifier layer of the image model
cross_modal_layer = CrossModalReparameterization(text_model.classifier, image_model.classifier)

# Replace the classifier in the text model with the cross-modal layer
text_model.classifier = cross_modal_layer

# Example input (batch_size, sequence_length)
input_ids = torch.randint(0, 1000, (8, 512))
attention_mask = torch.ones(8, 512)

# Forward pass through the reparameterized model
logits = text_model(input_ids, attention_mask)
print(logits)

# Train the text model as usual...

# After training, merge the parameters for inference
text_model.classifier.merge_parameters()

Citation

@misc{zhang2024multimodal,
    title={Multimodal Pathway: Improve Transformers with Irrelevant Data from Other Modalities}, 
    author={Yiyuan Zhang and Xiaohan Ding and Kaixiong Gong and Yixiao Ge and Ying Shan and Xiangyu Yue},
    year={2024},
    eprint={2401.14405},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

License

MIT
