Perceiver IO

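This project is a PyTorch implementation of

- Perceiver: General Perception with Iterative Attention
- Perceiver IO: A General Architecture for Structured Inputs & Outputs

and supports training of Perceiver and Perceiver IO models with PyTorch Lightning at any scale.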

Installation

From sources

conda env create -f environment.yml
conda activate perceiver-io
poetry install

Via pip

pip install perceiver-io

When installing via pip, make sure that a CUDA toolkit is installed as well (see also environment.yml).

Architecture

The following sections describe how the conceptual architectures of Perceiver IO and Perceiver can be mapped to the implementation provided by this project.

Perceiver IO

[Figure: Perceiver IO, conceptual architecture and its mapping to the implementation]

Names of components shown in the implementation architecture are class names in the PyTorch model API (see also model.py). Task-specific input and output adapters are subclasses of InputAdapter and OutputAdapter, respectively (see also adapter.py). Array dimensions (M, C), (N, D), (O, F) and (O, E) have the following names in code and/or on the command line:

Array dimension   Configuration parameter name
M                 Input-specific name (e.g. max_seq_len for text input, ...)
C                 num_input_channels (property of InputAdapter)
N                 num_latents
D                 num_latent_channels
O                 Output-specific name (e.g. num_output_queries for classification output, ...)
E                 Output-specific name (e.g. num_classes for classification output, ...)
F                 num_output_query_channels (property of OutputAdapter)

The number of layers in a SelfAttentionBlock can be specified with num_self_attention_layers_per_block and the number of blocks with num_self_attention_blocks (L in the conceptual architecture). Self-attention blocks share their weights.
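As a quick check of what this sharing means, take the ImageNet configuration used below (num_self_attention_layers_per_block=6, num_self_attention_blocks=8): the latent array passes through 6 * 8 = 48 self-attention layers, but with the default sharing settings only 6 distinct sets of layer weights exist. The hypothetical helper below (not part of this project's API) simply enumerates which weight set each layer application would use under that assumption.

```python
def unrolled_self_attention_layers(num_self_attention_layers_per_block, num_self_attention_blocks):
    """List the weight set used by each self-attention layer application, in order.

    Hypothetical illustration: all blocks are assumed to share one block's weights,
    so layer k of every block maps to the same weight set "layer_k".
    """
    return [
        f"layer_{layer}"
        for _block in range(num_self_attention_blocks)
        for layer in range(num_self_attention_layers_per_block)
    ]


layers = unrolled_self_attention_layers(6, 8)
print(len(layers), "layer applications,", len(set(layers)), "distinct weight sets")
```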

Perceiver

Perceiver IO does not use repeated encoder cross-attention, as described in the Perceiver IO paper:

We omit the repeated encoder cross-attends used in Perceiver as we found these to lead to relatively small performance improvements but to significantly slow down training ...

This may hold for the very large datasets used in the Perceiver IO paper, but I found that repeated encoder cross-attention gives much better training results on smaller datasets. Therefore, the implementation provided by this project supports repeated encoder cross-attention.

[Figure: Perceiver, conceptual architecture and its mapping to the implementation]

The number of repeated cross-attentions can be specified with num_cross_attention_layers (P), which must be less than or equal to num_self_attention_blocks (L). Cross-attention layers 2 to P always share their weights, as do self-attention blocks 2 to L. Whether the first cross-attention layer also shares these weights is controlled with first_cross_attention_layer_shared; likewise, whether the first self-attention block shares weights with the other blocks is controlled with first_self_attention_block_shared. The default values of these hyperparameters are consistent with the Perceiver IO architecture (1 cross-attention layer, L self-attention blocks with weight sharing).
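The sharing rules above can be sketched as an execution schedule. This is a pure-Python illustration with a hypothetical function name (encoder_schedule), not the project's encoder code. Two points are assumptions: that cross-attention layer i is applied directly before self-attention block i (with the remaining blocks running self-attention only), as suggested by the architecture figure, and the default flag values used here (first_cross_attention_layer_shared=False, first_self_attention_block_shared=True), chosen to match the Perceiver IO defaults described above.

```python
def encoder_schedule(num_cross_attention_layers, num_self_attention_blocks,
                     first_cross_attention_layer_shared=False,
                     first_self_attention_block_shared=True):
    """Return (operation, weight_set) pairs in the order they are applied (sketch)."""
    if not 1 <= num_cross_attention_layers <= num_self_attention_blocks:
        raise ValueError("require 1 <= num_cross_attention_layers <= num_self_attention_blocks")
    schedule = []
    for i in range(num_self_attention_blocks):
        if i < num_cross_attention_layers:
            # Cross-attention layers 2..P always use "cross_shared"; layer 1 joins
            # them only if first_cross_attention_layer_shared is set.
            own = i == 0 and not first_cross_attention_layer_shared
            schedule.append(("cross_attention", "cross_first" if own else "cross_shared"))
        # Same rule for self-attention blocks 2..L and the first block.
        own = i == 0 and not first_self_attention_block_shared
        schedule.append(("self_attention_block", "self_first" if own else "self_shared"))
    return schedule


# Perceiver IO defaults: one cross-attention, then L weight-shared self-attention blocks
for step in encoder_schedule(1, 3):
    print(step)
```

For the masked language modeling example further below (num_cross_attention_layers=3, num_self_attention_blocks=3), this schedule alternates three cross-attention layers with three self-attention blocks.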

Model API

PyTorch model API

The PyTorch model API is based on generic encoder and decoder classes (PerceiverEncoder and PerceiverDecoder) and task-specific input and output adapter classes. These are defined in model.py and adapter.py, respectively. The following snippet shows how they can be used to create an ImageNet classifier as specified in Appendix A of the Perceiver IO paper (Perceiver IO, config A, with 2D Fourier Features, 48.4M parameters):

from perceiver.model import (
    PerceiverIO,
    PerceiverEncoder,
    PerceiverDecoder,
    ImageInputAdapter,
    ClassificationOutputAdapter,
)

# Fourier-encodes pixel positions and flattens along spatial dimensions
input_adapter = ImageInputAdapter(
    image_shape=(224, 224, 3),  # M = 224 * 224
    num_frequency_bands=64,
)

# Projects generic Perceiver decoder output to specified number of classes
output_adapter = ClassificationOutputAdapter(
    num_classes=1000,
    num_output_query_channels=1024,  # F
)  

# Generic Perceiver encoder
encoder = PerceiverEncoder(
    input_adapter=input_adapter,
    num_latents=512,  # N
    num_latent_channels=1024,  # D
    num_cross_attention_qk_channels=input_adapter.num_input_channels,  # C
    num_cross_attention_heads=1,
    num_self_attention_heads=8,
    num_self_attention_layers_per_block=6,
    num_self_attention_blocks=8,
    dropout=0.0,
)

# Generic Perceiver decoder
decoder = PerceiverDecoder(
    output_adapter=output_adapter,
    num_latent_channels=1024,  # D
    num_cross_attention_heads=1,
    dropout=0.0,
)

# Perceiver IO image classifier
model = PerceiverIO(encoder, decoder)
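The array dimensions from the table above can be traced through this configuration with a little arithmetic. The sketch below uses hypothetical helpers that are not part of the project API. It assumes the Fourier feature layout of the Perceiver paper (a sine and a cosine per frequency band for each spatial dimension, plus the raw coordinate, concatenated with the RGB channels) and a single output query for classification (O = 1); check input_adapter.num_input_channels and the output adapter for the values the implementation actually uses.

```python
import math


def fourier_input_channels(image_shape, num_frequency_bands):
    """C under the assumed layout: image channels + per spatial dim (2 * bands + 1)."""
    *spatial_shape, num_image_channels = image_shape
    return num_image_channels + len(spatial_shape) * (2 * num_frequency_bands + 1)


def perceiver_io_shapes(batch_size, image_shape, num_frequency_bands, num_latents,
                        num_latent_channels, num_output_queries,
                        num_output_query_channels, num_classes):
    """Array shapes at each stage of an image classifier (hypothetical sketch)."""
    *spatial_shape, _ = image_shape
    m = math.prod(spatial_shape)  # pixels flattened along spatial dimensions
    c = fourier_input_channels(image_shape, num_frequency_bands)
    return {
        "input (B, M, C)": (batch_size, m, c),
        "latents (B, N, D)": (batch_size, num_latents, num_latent_channels),
        "output queries (B, O, F)": (batch_size, num_output_queries, num_output_query_channels),
        "logits (B, O, E)": (batch_size, num_output_queries, num_classes),
    }


# Values from the snippet above; num_output_queries=1 is an assumption
for name, shape in perceiver_io_shapes(2, (224, 224, 3), 64, 512, 1024, 1, 1024, 1000).items():
    print(name, shape)
```

Under these assumptions the encoder cross-attends from N = 512 latents to M = 50,176 pixel positions with C = 261 channels each.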

PyTorch Lightning model API

Models created with the PyTorch model API are wrapped in task-specific LightningModules (e.g. LitImageClassifier) so that they can be trained with the PyTorch Lightning Trainer. They are defined in lightning.py. This API also includes task-specific configuration classes, defined in config.py.

A task-specific encoder configuration class (e.g. ImageEncoderConfig) covers the configuration of the generic encoder and its task-specific input adapter. A task-specific decoder configuration class (e.g. ClassificationDecoderConfig) covers the configuration of the generic decoder and its task-specific output adapter.
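This division of responsibilities can be illustrated with plain dataclasses. The classes below (GenericEncoderSketch, ImageEncoderConfigSketch) and the split_config helper are hypothetical and do not reproduce the definitions in config.py; they only show how a single task-specific configuration can carry both input adapter options and generic encoder options, which are then routed to the two constructors.

```python
from dataclasses import asdict, dataclass, fields
from typing import Tuple


@dataclass
class GenericEncoderSketch:
    # Hypothetical generic encoder options (defaults are illustrative only)
    num_cross_attention_heads: int = 1
    num_self_attention_heads: int = 8
    num_self_attention_layers_per_block: int = 6
    num_self_attention_blocks: int = 8
    dropout: float = 0.0


@dataclass
class ImageEncoderConfigSketch(GenericEncoderSketch):
    # Hypothetical input-adapter-specific options
    image_shape: Tuple[int, int, int] = (224, 224, 3)
    num_frequency_bands: int = 64


def split_config(cfg):
    """Split a task-specific config into (adapter_kwargs, encoder_kwargs)."""
    generic_names = {f.name for f in fields(GenericEncoderSketch)}
    values = asdict(cfg)
    encoder_kwargs = {k: v for k, v in values.items() if k in generic_names}
    adapter_kwargs = {k: v for k, v in values.items() if k not in generic_names}
    return adapter_kwargs, encoder_kwargs


adapter_kwargs, encoder_kwargs = split_config(ImageEncoderConfigSketch())
print(adapter_kwargs)  # options for the input adapter
print(encoder_kwargs)  # options for the generic encoder
```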

The same model as in the previous section, wrapped in a LitImageClassifier, can be created with:

from perceiver.model.config import ImageEncoderConfig, ClassificationDecoderConfig
from perceiver.model.lightning import LitImageClassifier

encoder_cfg = ImageEncoderConfig(
    image_shape=(224, 224, 3),
    num_frequency_bands=64,
    num_cross_attention_heads=1,
    num_self_attention_heads=8,
    num_self_attention_layers_per_block=6,
    num_self_attention_blocks=8,
    dropout=0.0,
)
decoder_cfg = ClassificationDecoderConfig(
    num_classes=1000,
    num_output_query_channels=1024,
    num_cross_attention_heads=1,
    dropout=0.0,
)

lit_model = LitImageClassifier(encoder_cfg, decoder_cfg, num_latents=512, num_latent_channels=1024)

# Wrapped PyTorch model
model = lit_model.model

PyTorch Lightning model CLI

The PyTorch Lightning model API is primarily designed for command-line binding via the Lightning CLI. For example, when implementing a command-line interface for LitImageClassifier with LightningCLI in a file named classifier.py

# File classifier.py

from pytorch_lightning.utilities.cli import LightningCLI
from perceiver.model.lightning import LitImageClassifier

if __name__ == "__main__":
    LightningCLI(model_class=LitImageClassifier)

the same classifier as before can be created with the following command line options:

python classifier.py fit \
  --model.num_latents=512 \
  --model.num_latent_channels=1024 \
  --model.encoder.image_shape=[224,224,3] \
  --model.encoder.num_frequency_bands=64 \
  --model.encoder.num_cross_attention_heads=1 \
  --model.encoder.num_self_attention_heads=8 \
  --model.encoder.num_self_attention_layers_per_block=6 \
  --model.encoder.num_self_attention_blocks=8 \
  --model.encoder.dropout=0.0 \
  --model.decoder.num_classes=1000 \
  --model.decoder.num_output_query_channels=1024 \
  --model.decoder.num_cross_attention_heads=1 \
  --model.decoder.dropout=0.0 \
  ...

Task-specific training scripts can set default values so that command lines are usually much shorter (see img_clf.py for an example of a training script and the Image classification section for a usage example).
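The command-line flags mirror the nested configuration: top-level model arguments use model.<name>, while encoder and decoder options use model.encoder.<name> and model.decoder.<name>. If you generate such command lines from a script, for example for hyperparameter sweeps, a small hypothetical helper (to_cli_flags, not part of this project) could flatten a nested dictionary into flags of the same form. The list and boolean formats follow the examples in this document ([224,224,3], true).

```python
def format_value(value):
    """Render a Python value in the flag syntax used in this document."""
    if isinstance(value, bool):  # check bool before other types: bool is a subclass of int
        return "true" if value else "false"
    if isinstance(value, (list, tuple)):
        return "[" + ",".join(format_value(v) for v in value) + "]"
    return str(value)


def to_cli_flags(prefix, config):
    """Flatten a nested dict into --prefix.key=value flags (hypothetical helper)."""
    flags = []
    for key, value in config.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flags.extend(to_cli_flags(name, value))
        else:
            flags.append(f"--{name}={format_value(value)}")
    return flags


model_config = {
    "num_latents": 512,
    "encoder": {"image_shape": (224, 224, 3), "num_frequency_bands": 64},
    "decoder": {"num_classes": 1000},
}
print(" \\\n  ".join(["python classifier.py fit"] + to_cli_flags("model", model_config)))
```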

Training examples

In the following subsections, Perceiver IO models are trained on a rather small scale (and on small datasets). In particular, hyperparameters are set such that parallel training on two NVIDIA GTX 1080 GPUs (8 GB memory each) works well. I haven't really tuned model architectures and other hyperparameters yet, so you'll probably get better results with a bit of experimentation. Support for more datasets and tasks, as well as instructions for training on a larger scale, will come soon.

Masked language modeling

Pretrain a Perceiver IO model on masked language modeling (MLM) with text from the IMDB training set. The pretrained encoder is then used for training a sentiment classification model. Predictions of masked tokens are logged to TensorBoard.
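How tokens are masked is not described here. A common choice is BERT-style masking, sketched below in pure Python as an assumption rather than a description of the mlm script: each token is selected with probability mask_prob, and a selected token is replaced by the mask token 80% of the time, by a random token 10% of the time, and left unchanged otherwise. Labels use -100 at unselected positions, which is the default ignore_index of PyTorch's CrossEntropyLoss. The function name and its parameters are hypothetical.

```python
import random


def mask_tokens(token_ids, mask_token_id, vocab_size, mask_prob=0.15, rng=None):
    """BERT-style masking (assumed, not necessarily what the mlm script does).

    Returns (inputs, labels): inputs is the corrupted sequence, labels holds the
    original token at selected positions and -100 (ignored by the loss) elsewhere.
    """
    rng = rng if rng is not None else random.Random()
    inputs = list(token_ids)
    labels = [-100] * len(inputs)
    for i, token in enumerate(token_ids):
        if rng.random() >= mask_prob:
            continue  # position not selected for prediction
        labels[i] = token
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_token_id  # 80%: replace with the mask token
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return inputs, labels


inputs, labels = mask_tokens(list(range(10, 30)), mask_token_id=1, vocab_size=100,
                             rng=random.Random(0))
print(inputs)
print(labels)
```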

python -m perceiver.scripts.mlm fit \
  --model.num_latents=64 \
  --model.num_latent_channels=64 \
  --model.encoder.num_input_channels=64 \
  --model.encoder.num_cross_attention_layers=3 \
  --model.encoder.num_self_attention_layers_per_block=6 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.dropout=0.0 \
  --model.decoder.num_output_query_channels=64 \
  --model.decoder.dropout=0.0 \
  --data=ImdbDataModule \
  --data.max_seq_len=512 \
  --data.batch_size=64 \
  --optimizer.lr=3e-3 \
  --optimizer.weight_decay=0.0 \
  --lr_scheduler.pct_start=0.1 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_steps=50000 \
  --trainer.check_val_every_n_epoch=5 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=mlm
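The --lr_scheduler.pct_start option suggests a one-cycle schedule like PyTorch's torch.optim.lr_scheduler.OneCycleLR, with --optimizer.lr as the peak learning rate and --trainer.max_steps as the total number of steps; both are assumptions, so check the mlm script for the scheduler it actually configures. The pure-Python sketch below mirrors OneCycleLR's default two-phase cosine annealing (div_factor=25, final_div_factor=1e4) to show what pct_start=0.1 means: the learning rate warms up during the first 10% of steps and then decays.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.1,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step`, mirroring OneCycleLR's default cosine schedule (sketch)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    peak_step = pct_start * total_steps - 1  # end of the warm-up phase
    if not 0 < peak_step < total_steps - 1:
        raise ValueError("pct_start * total_steps must fall strictly inside the schedule")

    def cos_anneal(start, end, pct):
        # Cosine interpolation from `start` (pct=0) to `end` (pct=1)
        return end + (start - end) / 2.0 * (1.0 + math.cos(math.pi * pct))

    if step <= peak_step:
        return cos_anneal(initial_lr, max_lr, step / peak_step)
    return cos_anneal(max_lr, min_lr, (step - peak_step) / (total_steps - 1 - peak_step))


for step in (0, 2500, 4999, 25000, 49999):
    print(step, one_cycle_lr(step, total_steps=50000, max_lr=3e-3))
```

With the MLM settings above (peak 3e-3, 50,000 steps), this sketch starts at 1.2e-4, peaks at 3e-3 at step 4,999, and ends at about 1.2e-8.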

To save GPU memory and scale model training, activation checkpointing can be enabled with --model.activation_checkpointing=true (disabled by default).

Sentiment classification

Train a classification decoder using a frozen encoder from masked language modeling. If you ran MLM yourself, you'll need to modify the --model.mlm_ckpt argument accordingly; otherwise, download checkpoints from here and extract them into the root directory of this project.

python -m perceiver.scripts.seq_clf fit \
  --model.mlm_ckpt='logs/mlm/version_0/checkpoints/epoch=254-val_loss=4.527.ckpt' \
  --model.num_latents=64 \
  --model.num_latent_channels=64 \
  --model.encoder.num_input_channels=64 \
  --model.encoder.num_cross_attention_layers=3 \
  --model.encoder.num_self_attention_layers_per_block=6 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.dropout=0.0 \
  --model.encoder.freeze=true \
  --model.decoder.num_output_query_channels=64 \
  --model.decoder.dropout=0.0 \
  --data=ImdbDataModule \
  --data.max_seq_len=512 \
  --data.batch_size=128 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --optimizer.weight_decay=0.01 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_epochs=30 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=seq_clf
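The checkpoint paths passed via --model.mlm_ckpt and --model.clf_ckpt follow the pattern epoch=<epoch>-val_loss=<loss>.ckpt. If your own run produced several checkpoints, a hypothetical helper like the one below (not part of this project) can pick the one with the lowest validation loss from the file names. It assumes that naming pattern and ignores files that don't match it, such as a last.ckpt.

```python
import re
from pathlib import Path

# Matches file names such as "epoch=254-val_loss=4.527.ckpt"
CKPT_NAME = re.compile(r"^epoch=(\d+)-val_loss=([0-9.]+)\.ckpt$")


def best_checkpoint(paths):
    """Return the path (as a string) whose file name encodes the lowest val_loss."""
    best_loss, best_path = None, None
    for path in paths:
        match = CKPT_NAME.match(Path(path).name)
        if match is None:
            continue  # e.g. "last.ckpt" or unrelated files
        val_loss = float(match.group(2))
        if best_loss is None or val_loss < best_loss:
            best_loss, best_path = val_loss, str(path)
    if best_path is None:
        raise ValueError("no file name matches epoch=<n>-val_loss=<x>.ckpt")
    return best_path
```

For example, best_checkpoint(Path("logs/mlm/version_0/checkpoints").glob("*.ckpt")) returns a path that can be passed to --model.mlm_ckpt.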

Unfreeze the encoder and fine-tune it jointly with the decoder that was trained in the previous step. If you ran the previous step yourself, you'll need to modify the --model.clf_ckpt argument accordingly; otherwise, download checkpoints from here.

python -m perceiver.scripts.seq_clf fit \
  --model.clf_ckpt='logs/seq_clf/version_0/checkpoints/epoch=009-val_loss=0.343.ckpt' \
  --model.num_latents=64 \
  --model.num_latent_channels=64 \
  --model.encoder.num_input_channels=64 \
  --model.encoder.num_cross_attention_layers=3 \
  --model.encoder.num_self_attention_layers_per_block=6 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.dropout=0.1 \
  --model.decoder.num_output_query_channels=64 \
  --model.decoder.dropout=0.1 \
  --data=ImdbDataModule \
  --data.max_seq_len=512 \
  --data.batch_size=128 \
  --optimizer=AdamW \
  --optimizer.lr=1e-4 \
  --optimizer.weight_decay=0.01 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_epochs=40 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=seq_clf

Image classification

Classify MNIST images.

python -m perceiver.scripts.img_clf fit \
  --model.num_latents=32 \
  --model.num_latent_channels=128 \
  --model.encoder.num_self_attention_layers_per_block=3 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.dropout=0.0 \
  --model.decoder.dropout=0.0 \
  --data=MnistDataModule \
  --data.batch_size=128 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --optimizer.weight_decay=0.01 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_epochs=20 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=img_clf

Inference examples

Development environment

Update the project dependencies in the conda environment:

invoke install

Install the pre-commit hooks:

invoke precommit-install

Run code quality checks:

invoke cc

Run tests:

invoke test

The structure of this project is based on the Python Project Template.

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{jaegle2021perceiverio,
    title   = {Perceiver IO: A General Architecture for Structured Inputs & Outputs},
    author  = {Andrew Jaegle and Sebastian Borgeaud and Jean-Baptiste Alayrac and Carl Doersch and Catalin Ionescu and David Ding and Skanda Koppula and Andrew Brock and Evan Shelhamer and Olivier Hénaff and Matthew M. Botvinick and Andrew Zisserman and Oriol Vinyals and João Carreira},
    year    = {2021},
    eprint  = {2107.14795},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

Download files

Source distribution: perceiver-io-0.2.1.tar.gz (482.5 kB)

Built distribution: perceiver_io-0.2.1-py3-none-any.whl (28.3 kB, Python 3)