Skip to main content

MMCA - Pytorch

Project description

Multi-Modality

Multi-Modal Causal Attention

The open source community's implementation of the all-new Multi-Modal Causal Attention from "DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention"

Paper Link

Appreciation

  • Lucidrains
  • Agorians

Install

pip install mmca

Usage

import torch 
from mmca.main import MultiModalCausalAttention


attn = MultiModalCausalAttention(dim=512, heads=8)

x = torch.randn(1, 10, 512)
y = torch.randn(1, 20, 512)

x, y = attn(x, y)

print(x)
print(y)

Architecture

Algorithmic pseudocode

Input: Visual tokens V, Textual tokens T
Output: Updated Textual tokens T'

1: procedure MMCA(V, T)
2:     for each visual token v in V do
3:         v' = self_attention(v)  // Visual tokens only attend to themselves
4:     end for
5:     for each textual token t in T do
6:         t' = attention(t, T_previous) + attention(t, V)  // Textual tokens attend to all their previous tokens AND image tokens
7:     end for
8:     return T'
9: end procedure

Multi-Modal Causal Attention: A Study

MMCA is a novel attention mechanism designed to handle multi-modal data, i.e., data that comes from different sources or formats, such as text and images. It is an extension of the causal attention mechanism, which is commonly used in transformer models for tasks like language modeling.

Causal Attention


Before diving into MMCA, let's first understand the concept of causal attention. In the context of transformers, attention is a measure of how much a model should focus on different parts of the input when producing a particular part of the output.

Causal attention, also known as autoregressive or self-attention, is a type of attention where a token can only attend to previous tokens in the sequence. This is in contrast to other types of attention where a token can attend to all other tokens in the sequence.

The causal attention mechanism can be visualized as follows:

Token1 -> |------|
Token2 -> |------|------|
Token3 -> |------|------|------|
Token4 -> |------|------|------|------|

Each token can attend to itself and all the tokens before it, but not the ones after it.


Multi-Modal Causal Attention

In a multi-modal setting, we often deal with different types of data simultaneously. For instance, in an image captioning task, the model has to process both image features and textual data. This is where MMCA comes into play.

MMCA extends the concept of causal attention to handle multi-modal data. The key idea behind MMCA is as follows:

  1. For visual tokens, they only attend to themselves, as visual tokens are encoded by the visual encoder.
  2. For textual tokens, they attend to all their previous tokens. However, they have two separate attention weight matrices for their previous textual tokens and image tokens.

This can be visualized as follows:

Visual Tokens:
V1 -> |------|
V2 -> |------|
V3 -> |------|

Textual Tokens:
T1 -> |------|------|------|------|
T2 -> |------|------|------|------|------|
T3 -> |------|------|------|------|------|------|

Here, V1V2, and V3 are visual tokens, and T1T2, and T3 are textual tokens. Each visual token only attends to itself, while each textual token attends to all previous textual and visual tokens.


Mathematical Formulation

Let's now delve into the mathematical formulation of MMCA. The attention mechanism in transformers is typically computed using the dot product of query Q and key K matrices, followed by a softmax operation. In MMCA, we have two separate attention weight matrices for textual and visual tokens.

Let Q_T and K_T be the query and key matrices for textual tokens, and Q_V and K_V be the query and key matrices for visual tokens. The attention weights for textual tokens attending to previous textual tokens (A_TT) and visual tokens (A_TV) can be computed as follows:

A_TT = softmax(Q_T * K_T^T)
A_TV = softmax(Q_T * K_V^T)

The updated textual token representations can then be computed by applying these attention weights to the value V matrices:

T' = A_TT * V_T + A_TV * V_V

Here, V_T and V_V are the value matrices for textual and visual tokens, respectively.

Conclusion

Multi-Modal Causal Attention is a powerful attention mechanism that extends the concept of causal attention to handle multi-modal data. It allows a model to process different types of data simultaneously and in a more efficient manner. By having separate attention weight matrices for different types of tokens, MMCA allows the model to focus on the most relevant parts of the input for each type of token, leading to improved performance on multi-modal tasks.


Todo

  • implement flash attention from zeta as the main attn

License

MIT


Citations

@misc{2309.14327,
Author = {Zhewei Yao and Xiaoxia Wu and Conglong Li and Minjia Zhang and Heyang Qi and Olatunji Ruwase and Ammar Ahmad Awan and Samyam Rajbhandari and Yuxiong He},
Title = {DeepSpeed-VisualChat: Multi-Round Multi-Image Interleave Chat via Multi-Modal Causal Attention},
Year = {2023},
Eprint = {arXiv:2309.14327},
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mmca-0.0.4.tar.gz (5.6 kB view details)

Uploaded Source

Built Distribution

mmca-0.0.4-py3-none-any.whl (5.1 kB view details)

Uploaded Python 3

File details

Details for the file mmca-0.0.4.tar.gz.

File metadata

  • Download URL: mmca-0.0.4.tar.gz
  • Upload date:
  • Size: 5.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0

File hashes

Hashes for mmca-0.0.4.tar.gz
Algorithm Hash digest
SHA256 0b21d1d1b5d81aa781fc86df973dc2c2aca2ef6c15d633be547017d64a60426b
MD5 1490c6376010965f5e7bcd4e8f006ac3
BLAKE2b-256 857de590e3b68976c27102da784e9d5dc70c1ff6d28e7ba1f559ddd8d2b96984

See more details on using hashes here.

File details

Details for the file mmca-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: mmca-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 5.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0

File hashes

Hashes for mmca-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 962bc0681b506e0176a114d481a01b4b49c9ac475e972d1024a72cc6877d27f4
MD5 57d31cac8e4ea255324de450f90ed77b
BLAKE2b-256 d11c086d27b1b494f6f410f738ed245e57e181b1ddb920ad78cbcfcf1cbb9fda

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page