
M2PT - PyTorch

Project description


Multi-Modal Pathway Transformer


Implementation of M2PT in PyTorch from the paper "Multimodal Pathway: Improve Transformers with Irrelevant Data from Other Modalities" (https://arxiv.org/abs/2401.14405).

Install

pip3 install -U m2pt

Usage

MPTransformerBlock

  • Implementation of Figure 2 from the paper and the Multimodal Pathway Transformer block

  • Reusable and modular.

  • Combines linear projections from multiple models (a rough sketch of this idea follows the example below)

import torch
from torch import nn
from m2pt import MPTransformerBlock

# Create an instance of the MPTransformerBlock class with the specified parameters
model = MPTransformerBlock(
    dim=512,  # Dimension of the input and output tensors
    dim_head=64,  # Dimension of each attention head
    heads=8,  # Number of attention heads
    dropout=0.1,  # Dropout rate
    ff_mult=4,  # Multiplier for the dimension of the feed-forward network
    original_linear=nn.Linear(512, 512),  # Linear layer for the original input tensor
    auxiliar_linear=nn.Linear(512, 512),  # Linear layer for the auxiliary input tensor
    ffn_original_linear=nn.Linear,  # Linear layer class for the original input in the feed-forward network
    ffn_auxiliar_linear=nn.Linear,  # Linear layer class for the auxiliary input in the feed-forward network
    ffn_original_last_linear=nn.Linear,  # Linear layer class for the last original layer in the feed-forward network
    ffn_aux_last_linear=nn.Linear,  # Linear layer class for the last auxiliary layer in the feed-forward network
)

# Create a 3D tensor with shape B x S x D
x = torch.randn(1, 512, 512)

# Pass the input tensor through the model
out = model(x)

# Print the shape of the output tensor
print(out.shape)
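
For intuition, here is a rough sketch of the pathway idea the block implements: each linear projection's output is combined with the output of a corresponding projection taken from a model of another modality, scaled by a cross-modal scale. This is an illustrative sketch only; names such as PathwayLinearSketch and cross_modal_scale are assumptions, not the m2pt API.

import torch
from torch import nn

class PathwayLinearSketch(nn.Module):
    """Illustrative only: adds a scaled auxiliary-modality projection to the original one."""

    def __init__(self, original_linear: nn.Linear, auxiliary_linear: nn.Linear, cross_modal_scale: float = 0.5):
        super().__init__()
        self.original_linear = original_linear
        self.auxiliary_linear = auxiliary_linear  # e.g. taken from a model pretrained on another modality
        self.cross_modal_scale = cross_modal_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pathway: original projection plus a scaled projection borrowed from the auxiliary modality
        return self.original_linear(x) + self.cross_modal_scale * self.auxiliary_linear(x)

layer = PathwayLinearSketch(nn.Linear(512, 512), nn.Linear(512, 512))
print(layer(torch.randn(1, 512, 512)).shape)  # torch.Size([1, 512, 512])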

CrossModalReparameterization

  • Implementation of Cross-Modal Reparameterization from Figure 2 and Section 3.2 of the paper

  • Combines the linear layers of models from different modalities by adding the auxiliary layer's weights, scaled by a constant lambda (the Cross-Modal Scale), to the original layer's weights; see the sketch below

  • Modular and reusable: simply plug in linear layers from any models!
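
Before the full example, here is a minimal sketch of what the reparameterization amounts to: the auxiliary weight is folded into the original weight as W' = W + lambda * W_aux, and merge_parameters() bakes that sum into the original layer. The class below (CrossModalReparamSketch) is an illustrative assumption, not the m2pt implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossModalReparamSketch(nn.Module):
    """Illustrative sketch of cross-modal reparameterization; not the m2pt implementation."""

    def __init__(self, original: nn.Linear, auxiliary: nn.Linear, cross_modal_scale: float = 0.5):
        super().__init__()
        assert original.weight.shape == auxiliary.weight.shape
        self.original = original
        self.auxiliary = auxiliary  # linear layer borrowed from a model of another modality
        self.cross_modal_scale = cross_modal_scale  # the lambda / Cross-Modal Scale
        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.merged:
            return self.original(x)
        # Equivalent weight during training: W' = W_original + lambda * W_auxiliary
        weight = self.original.weight + self.cross_modal_scale * self.auxiliary.weight
        return F.linear(x, weight, self.original.bias)

    @torch.no_grad()
    def merge_parameters(self):
        # Fold the scaled auxiliary weight into the original layer for inference
        self.original.weight += self.cross_modal_scale * self.auxiliary.weight
        self.merged = True

Full usage example, plugging in classifier layers from pretrained BERT and ViT models: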

import torch
import torch.nn as nn
from transformers import BertModel, BertConfig, ViTModel, ViTConfig
from m2pt import CrossModalReparameterization

# Define a simple Transformer model for text
class TextTransformerModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-uncased'):
        super(TextTransformerModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)

        # Assume we're reparameterizing the first linear layer of the classifier
        self.classifier = nn.Linear(self.bert.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# Define a simple Transformer model for images (using ViT for example)
class ImageTransformerModel(nn.Module):
    def __init__(self, vit_model_name='google/vit-base-patch16-224'):
        super(ImageTransformerModel, self).__init__()
        self.vit = ViTModel.from_pretrained(vit_model_name)

        # Assume we're using the first linear layer of the classifier as the auxiliary layer
        self.classifier = nn.Linear(self.vit.config.hidden_size, 2)

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# Example usage
# Initialize both models
text_model = TextTransformerModel()
image_model = ImageTransformerModel()

# Assume we want to reparameterize the classifier layer of the text model
# using the classifier layer of the image model
cross_modal_layer = CrossModalReparameterization(text_model.classifier, image_model.classifier)

# Replace the classifier in the text model with the cross-modal layer
text_model.classifier = cross_modal_layer

# Example input (batch_size, sequence_length)
input_ids = torch.randint(0, 1000, (8, 512))
attention_mask = torch.ones(8, 512)

# Forward pass through the reparameterized model
logits = text_model(input_ids, attention_mask)
print(logits)

# Train the text model as usual...

# After training, merge the parameters for inference
text_model.classifier.merge_parameters()
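
Per the paper, merging folds the scaled auxiliary weights into the original ones, so the merged model keeps the same architecture and inference cost as the unmodified text model.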

Citation

@misc{zhang2024multimodal,
    title={Multimodal Pathway: Improve Transformers with Irrelevant Data from Other Modalities}, 
    author={Yiyuan Zhang and Xiaohan Ding and Kaixiong Gong and Yixiao Ge and Ying Shan and Xiangyu Yue},
    year={2024},
    eprint={2401.14405},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

License

MIT

