Skip to main content

A fork of perceiver-pytorch that supports multiple modalities for the Perceiver architecture.

Project description

PyPI PyPI PyPI

Multi Modality Perceiver - Pytorch

Implementation of Perceiver, with support for multi-modality inputs. Fork of (lucidrains repo)[https://github.com/lucidrains/perceiver-pytorch] extended for multi-modality and support for text embedding splits chunking across layers. This repo also is closer to the Perceiver preprint because you can use GELU activation in feedforward, while Lucidrains' repo substitutes GEGLU instead. Set use_gelu to true in the MultiModalityPerceiver constructor.

MultiModalityPerceiver also provides means to customize pooling method. You can subclass perceiver_pytorch.multi_modality_perceiver.MultiModalityPerceiver and override the pool() method, or use the perceiver_pytorch.multi_modality_perceiver.MultiModalityPerceiverNoPooling implementation that returns the hidden representation without any pooling. This is useful if you need to train multitask models and want to experiment with, say, using the first 3 latent outputs to predict each a different task.

Install

To install the Perceiver implementation with multi-modality (also includes without multi-modality):

$ pip install perceiver-multi-modality-pytorch

Import with:

from perceiver_pytorch.modalities import modality_encoding
from perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality

See tests/test_multimodality_perceiver.py or

from perceiver_pytorch.modalities import InputModalityWithEmbedding
from perceiver_pytorch.multi_modality_with_text_perceiver import MultiModalityWithTextPerceiver

See tests/test_multimodality_with_text_perceiver.py

To install the Perceiver implementation, follow instructions at the (lucidrains repo)[https://github.com/lucidrains/perceiver-pytorch]:

Usage

import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels=3,  # number of channels for each token of the input
    input_axis=2,  # number of axis for input data (2 for images, 3 for video)
    num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
    max_freq=10.,  # maximum frequency, hyperparameter depending on how fine the data is
    depth=6,  # depth of net
    num_latents=256,
    # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim=512,  # latent dimension
    cross_heads=1,  # number of heads for cross attention. paper said 1
    latent_heads=8,  # number of heads for latent self attention, 8
    cross_dim_head=64,
    latent_dim_head=64,
    num_classes=1000,  # output number of classes
    attn_dropout=0.,
    ff_dropout=0.,
    weight_tie_layers=False  # whether to weight tie layers (optional, as indicated in the diagram)
)

img = torch.randn(1, 224, 224, 3)  # 1 imagenet image, pixelized

model(img)  # (1, 1000)

Multi-modality perceiver

An attractive feature of the perceiver architecture is that it can process multiple modalities of data in the same batch. This is not obvious from the perceiver forward signature shown above, but a relatively modest change can support processing video, images and audio with a single model, in one forward. This feature is demonstrated by the MultiModalityPerceiver, contributed by Fabien Campagne.

from perceiver_pytorch.multi_modality_perceiver import MultiModalityPerceiver, InputModality

image_inputs = torch.rand(size=(3, 260, 260, 3), requires_grad=True)
video_inputs = torch.rand(size=(3, 32, 260, 260, 3), requires_grad=True)
audio_inputs = torch.rand(size=(3, 44100, 1), requires_grad=True)

video_modality = InputModality(
    name='video',
    input_channels=3,  # number of channels for each token of the input
    input_axis=3,  # number of axes, 3 for video)
    num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
    max_freq=4.,  # maximum frequency, hyperparameter depending on how fine the data is
)
image_modality = InputModality(
    name='image',
    input_channels=3,  # number of channels for each token of the input
    input_axis=2,  # number of axes, 2 for images
    num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
    max_freq=4.,  # maximum frequency, hyperparameter depending on how fine the data is
)
audio_modality = InputModality(
    name='audio',
    input_channels=1,  # number of channels for mono audio
    input_axis=1,  # number of axes, 2 for images
    num_freq_bands=6,  # number of freq bands, with original value (2 * K + 1)
    max_freq=8.,  # maximum frequency, hyperparameter depending on how fine the data is
)
model = MultiModalityPerceiver(
    modalities=(video_modality, image_modality, audio_modality),
    depth=8,  # depth of net, combined with num_latent_blocks_per_layer to produce full Perceiver
    num_latents=12,
    # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim=64,  # latent dimension
    cross_heads=1,  # number of heads for cross attention. paper said 1
    latent_heads=8,  # number of heads for latent self attention, 8
    cross_dim_head=64,
    latent_dim_head=64,
    num_classes=1000,  # output number of classes
    attn_dropout=0.,
    ff_dropout=0.,
    weight_tie_layers=True,
    num_latent_blocks_per_layer=6  # Note that this parameter is 1 in the original Lucidrain implementation
    # whether to weight tie layers (optional, as indicated in the diagram)
)
result = model({'image': image_inputs,
                'video': video_inputs,
                'audio': audio_inputs})

Text perceiver

While the Perceiver architecture described by [jaegle2021perceiver] could support text if text was embedded and each dimension of the embedding provided as a channel in the input, this introduces a mismatch between the text embedding dimension (typically large, 512/768 or more) and the number of channels used for video and images (typically 3 channels, one for red, green and blue), or audio (1 for mono or 2 for stereo channels). When training text embeddings from scratch, this creates an opportunity, because there should be no need for the perceiver to attend to the entire text embedding in each layer. If we split the text embedding into as many chunks as there are layers in a perceiver, we reduce how much we need to pad other modalities, and introduce a structure to the learned embeddings, were parts of the text embedding can specialize according to the needs of each layer. The perceiver implementation provided in this repo can be used to explore the question of whether splitting text embeddings across layers is beneficial (you would compare the performance of MultiModalityWithTextPerceiver with that of MultiModalityPerceiver).

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{campagne2021textperceiver,
    title   = {Adapting Perceiver for learning with text modalities},
    author  = {Fabien Campagne},
    year    = {2021},
    eprint  = {unpublished results},
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

perceiver-multi-modality-pytorch-1.2.0.tar.gz (14.0 kB view details)

Uploaded Source

Built Distribution

File details

Details for the file perceiver-multi-modality-pytorch-1.2.0.tar.gz.

File metadata

File hashes

Hashes for perceiver-multi-modality-pytorch-1.2.0.tar.gz
Algorithm Hash digest
SHA256 9b091b7ec8909025cef88b7a445168cc7bd1bcd32d2f1a28d9bb3aaa9b8e0472
MD5 a364ce7f1255194920925252640be1cf
BLAKE2b-256 1eb6977b9d89b04474849d57dafe8d02d47668903ff144d2994b480fcab90be8

See more details on using hashes here.

File details

Details for the file perceiver_multi_modality_pytorch-1.2.0-py3-none-any.whl.

File metadata

File hashes

Hashes for perceiver_multi_modality_pytorch-1.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 ff4313c071e3efbc87023b8f301463680157c0b3582547abf2e6902fb5a3d37f
MD5 3efb9ee906fcb550632ef5b55e484055
BLAKE2b-256 d5cc24ce6a1cbff9d9524e8a62fb9cff0df0f96a757f948928133c5cccfe72a0

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page