
Perceiver IO

This repository is a PyTorch and PyTorch Lightning implementation of Perceiver IO: A General Architecture for Structured Inputs & Outputs and its predecessor Perceiver: General Perception with Iterative Attention.

The codebase is designed for easy extension to new tasks and datasets. If you are a researcher or practitioner working on new Perceiver IO models and use cases, you might find this repository useful. The integration with PyTorch Lightning supports model training at any scale. If, on the other hand, you are mainly interested in using or fine-tuning models from the Perceiver IO paper, you may want to take a look at the 🤗 Perceiver IO implementation.

Overview

The following figure maps Perceiver IO and Perceiver concepts to the core modules of the implementation (see Architecture for details).

[Figure: architecture overview]

Interfaces are defined on three levels:

  • PyTorch model API: defines generic PerceiverEncoder and PerceiverDecoder classes and task-specific InputAdapter and OutputAdapter subclasses from which PyTorch models can be constructed (a toy sketch of this encode/process/decode pattern follows this list).
  • PyTorch Lightning model API: defines wrappers for PyTorch models to support training with the PyTorch Lightning Trainer.
  • PyTorch Lightning model CLI: binds the PyTorch Lightning model API to the command line via the Lightning CLI.
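
The PyTorch model API follows the Perceiver IO encode/process/decode pattern: a learned latent array cross-attends to the adapted input, a stack of latent self-attention layers processes it, and learned output queries cross-attend to the latents to produce the output. Below is a toy sketch of that pattern in plain PyTorch; all class names and signatures are illustrative only, not the repository's actual API.

import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, q_dim, kv_dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            q_dim, num_heads, kdim=kv_dim, vdim=kv_dim, batch_first=True
        )

    def forward(self, q, kv):
        out, _ = self.attn(q, kv, kv)
        return out

class ToyPerceiverIO(nn.Module):
    def __init__(self, input_dim=64, num_latents=32, latent_dim=128, num_classes=10):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))  # learned latent array
        self.encode = CrossAttention(latent_dim, input_dim)                # encoder cross-attention
        self.process = nn.TransformerEncoder(                              # latent self-attention stack
            nn.TransformerEncoderLayer(latent_dim, nhead=4, batch_first=True),
            num_layers=3,
        )
        self.output_query = nn.Parameter(torch.randn(1, latent_dim))       # decoder output query
        self.decode = CrossAttention(latent_dim, latent_dim)               # decoder cross-attention
        self.head = nn.Linear(latent_dim, num_classes)

    def forward(self, x):                                 # x: (batch, seq_len, input_dim)
        b = x.shape[0]
        z = self.encode(self.latents.expand(b, -1, -1), x)
        z = self.process(z)
        y = self.decode(self.output_query.expand(b, -1, -1), z)
        return self.head(y.squeeze(1))                    # (batch, num_classes)

model = ToyPerceiverIO()
logits = model(torch.randn(2, 100, 64))                   # works for any seq_len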

Interface usage examples are available for two models from the Perceiver IO paper:

Model                                                            Parameters  Examples
Language model (Perceiver IO Base, SentencePiece tokenization)   223M        construction, training
Image classifier (Perceiver IO config A, 2D Fourier features)    48.4M       construction, training

Training of smaller models is shown in the section Training examples; their usage is shown in the section Inference examples.

Installation

Via pip

pip install perceiver-io[image,text]

From sources

Conda + Poetry

conda env create -f environment.yml
conda activate perceiver-io
poetry install --all-extras

This requires a Poetry installation, version 1.2.0b2 or higher. If running poetry fails with a KeyringError, refer to the keyring documentation for instructions on how to disable keyring usage.

Docker image

A perceiver-io Docker image can be built with:

docker build -t perceiver-io .

Training of Perceiver IO models with this image is described here.

Training examples

The following training examples use rather small Perceiver IO models so that they can be run on limited hardware resources. Training automatically scales to more than one GPU and was tested on 4 RTX 3080 GPUs. For GPUs with less memory, you may need to reduce --data.batch_size or turn on activation checkpointing for some of the examples.
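
Activation checkpointing trades compute for memory by recomputing intermediate activations during the backward pass instead of storing them. The examples expose it through their own command-line options; the generic PyTorch mechanism looks roughly like this:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(128, 512), torch.nn.GELU(), torch.nn.Linear(512, 128)
)
x = torch.randn(4, 128, requires_grad=True)

# Intermediate activations of `block` are not kept; they are recomputed
# during the backward pass, saving memory at the cost of extra compute.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()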

Datasets used for model training are 🤗 Datasets wrapped into PyTorch Lightning data modules (see data package). Datasets are automatically downloaded, preprocessed and cached when their corresponding Lightning data module is loaded during training. Manual dataset preprocessing is described here.
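
A data module of this kind roughly pairs 🤗 Datasets loading and caching with PyTorch DataLoader creation. The sketch below shows the general shape only; the repository's actual data modules (e.g. ImdbDataModule) additionally handle tokenization and preprocessing.

import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader

class ToyImdbDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # Called once per node: download and cache the 🤗 dataset.
        load_dataset("imdb")

    def setup(self, stage=None):
        ds = load_dataset("imdb")  # served from the cache at this point
        self.train_ds = ds["train"]
        self.val_ds = ds["test"]

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)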

An archive with training checkpoints can be downloaded here and should be extracted in project's root directory to be compatible with the example command lines below. It also contains Tensorboard logs and config files.

I didn't really tune hyperparameters, so you'll likely get better results with a bit of experimentation (see also training tips).

Masked language modeling

This section trains a very small language model (2.9M parameters) on masked language modeling with whole word masking. It is first pretrained on WikiText-103 and then adapted to the IMDb dataset. The encoder of the trained language model is then used for sentiment classification.
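
In whole word masking, all subword tokens belonging to a sampled word are masked together rather than independently. A toy illustration of the idea (my own example, not the training pipeline's implementation):

import random

tokens = ["the", "act", "##ing", "was", "great"]
word_starts = [i for i, t in enumerate(tokens) if not t.startswith("##")]

start = random.choice(word_starts)           # pick a whole word
end = start + 1
while end < len(tokens) and tokens[end].startswith("##"):
    end += 1                                 # extend over its subword pieces

masked = tokens[:start] + ["[MASK]"] * (end - start) + tokens[end:]
print(masked)  # e.g. ['the', '[MASK]', '[MASK]', 'was', 'great']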

The tokenizer is a customized BERT tokenizer (tokenizers/bert-base-uncased-10k-bookcorpus-ext), trained on BookCorpus with a vocabulary size of 10,000. You can also use any other 🤗 fast tokenizer from the 🤗 Hub with the --data.tokenizer option (see Tokenizers for details).
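
For example, a fast tokenizer from the 🤗 Hub can be loaded and inspected as follows (the tokenizer name here is chosen purely for illustration):

from transformers import AutoTokenizer

# Any fast tokenizer from the Hub works; "bert-base-uncased" is just an example.
tok = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tok.tokenize("The acting was great."))  # ['the', 'acting', 'was', 'great', '.']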

The training script is mlm.py. It implements the command line interface and defines training defaults (see also trainer.yaml for further defaults). Pretraining on WikiText-103 can be started with:

python -m perceiver.scripts.text.mlm fit \
  --model.num_latents=128 \
  --model.num_latent_channels=128 \
  --model.encoder.num_input_channels=128 \
  --model.encoder.num_cross_attention_layers=3 \
  --model.encoder.num_self_attention_layers_per_block=6 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.1 \
  --model.decoder.dropout=0.1 \
  --data=WikiTextDataModule \
  --data.tokenizer=tokenizers/bert-base-uncased-10k-bookcorpus-ext \
  --data.max_seq_len=512 \
  --data.batch_size=64 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --optimizer.weight_decay=0.01 \
  --lr_scheduler.warmup_steps=5000 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_steps=50000 \
  --trainer.check_val_every_n_epoch=5 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=mlm
Model parameters: 2.9M total, 0M frozen, 2.9M trainable.
[Figures: validation loss, mask prediction samples]

Starting from the best pretraining checkpoint, the language model is then adapted to the IMDb dataset for a further 15,000 steps. If you ran pretraining yourself, you'll need to modify the --model.ckpt value accordingly; otherwise the checkpoint from the downloaded archive is used.

python -m perceiver.scripts.text.mlm fit \
  --model.ckpt="logs/mlm/version_0/checkpoints/epoch=044-val_loss=3.917.ckpt" \
  --model.num_latents=128 \
  --model.num_latent_channels=128 \
  --model.encoder.num_input_channels=128 \
  --model.encoder.num_cross_attention_layers=3 \
  --model.encoder.num_self_attention_layers_per_block=6 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.1 \
  --model.decoder.dropout=0.1 \
  --data=ImdbDataModule \
  --data.tokenizer=tokenizers/bert-base-uncased-10k-bookcorpus-ext \
  --data.max_seq_len=512 \
  --data.batch_size=64 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --optimizer.weight_decay=0.01 \
  --lr_scheduler.warmup_steps=1000 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_steps=15000 \
  --trainer.check_val_every_n_epoch=3 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=mlm
Model parameters: 2.9M total, 0M frozen, 2.9M trainable.
[Figures: validation loss, mask prediction samples]

After adaptation to IMDb, mask prediction samples are clearly more related to movie reviews than after pretraining on WikiText-103 alone. The prediction samples are screenshots from the Tensorboard logs.

Sentiment classification

This section trains a Perceiver IO text classifier on IMDb reviews. The encoder is initialized with weights from masked language modeling (--model.mlm_ckpt option); the decoder is a randomly initialized classification decoder. In a first step, only the decoder is trained while the encoder remains frozen. The training script is classifier.py.
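
Conceptually, freezing the encoder amounts to disabling gradient updates for its parameters so that only the decoder is trained. A generic PyTorch sketch of what the --model.encoder.freeze=true option does (not the repository's code):

import torch.nn as nn

def freeze(module: nn.Module) -> None:
    # Parameters with requires_grad=False receive no gradients, so the
    # optimizer leaves them unchanged; only the decoder is updated.
    for p in module.parameters():
        p.requires_grad = False

# Usage (hypothetical model object):
#   freeze(model.encoder)
#   optimizer = torch.optim.AdamW(
#       (p for p in model.parameters() if p.requires_grad), lr=1e-4)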

python -m perceiver.scripts.text.classifier fit \
  --model.mlm_ckpt="logs/mlm/version_1/checkpoints/epoch=113-val_loss=3.904.ckpt" \
  --model.num_latents=128 \
  --model.num_latent_channels=128 \
  --model.encoder.num_input_channels=128 \
  --model.encoder.num_cross_attention_layers=3 \
  --model.encoder.num_self_attention_layers_per_block=6 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.1 \
  --model.encoder.freeze=true \
  --model.decoder.num_output_query_channels=128 \
  --model.decoder.dropout=0.1 \
  --data=ImdbDataModule \
  --data.tokenizer=tokenizers/bert-base-uncased-10k-bookcorpus-ext \
  --data.target_task=clf \
  --data.max_seq_len=512 \
  --data.batch_size=256 \
  --optimizer=AdamW \
  --optimizer.lr=1e-4 \
  --optimizer.weight_decay=0.01 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_epochs=30 \
  --trainer.log_every_n_steps=10 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=clf
Model parameters: 2.9M total, 2.8M frozen, 100K trainable.
[Figure: validation accuracy]

The small classification decoder (100K parameters) can be trained to a validation accuracy of 88% when using an encoder that has been adapted to the IMDb dataset (red line). When using an encoder that has been pretrained on WikiText-103 only, the validation accuracy saturates at 78% (pink line). Unfreezing the encoder and fine-tuning it jointly with the classification decoder further improves validation accuracy:

python -m perceiver.scripts.text.classifier fit \
  --model.clf_ckpt="logs/clf/version_0/checkpoints/epoch=028-val_loss=0.301.ckpt" \
  --model.num_latents=128 \
  --model.num_latent_channels=128 \
  --model.encoder.num_input_channels=128 \
  --model.encoder.num_cross_attention_layers=3 \
  --model.encoder.num_self_attention_layers_per_block=6 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.1 \
  --model.decoder.num_output_query_channels=128 \
  --model.decoder.dropout=0.1 \
  --data=ImdbDataModule \
  --data.tokenizer=tokenizers/bert-base-uncased-10k-bookcorpus-ext \
  --data.target_task=clf \
  --data.max_seq_len=512 \
  --data.batch_size=256 \
  --optimizer=AdamW \
  --optimizer.lr=1e-5 \
  --optimizer.weight_decay=0.01 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_epochs=30 \
  --trainer.log_every_n_steps=10 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=clf
Model parameters: 2.9M total, 0M frozen, 2.9M trainable.
[Figure: validation accuracy]

Image classification

This section trains a tiny Perceiver IO image classifier (805K parameters) on MNIST digits. The model attends to each pixel of its input images and does not use convolutional layers. In contrast to the other examples, only a single cross-attention layer is used. The training script is classifier.py.
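
The 2D Fourier features encode each pixel's (y, x) position as sines and cosines at multiple frequencies, which is what gives the attention-only model spatial information. A rough sketch of the idea (frequency spacing and scaling here are approximate and not the repository's exact implementation):

import torch

def fourier_features(h: int, w: int, num_bands: int = 32) -> torch.Tensor:
    ys = torch.linspace(-1.0, 1.0, h)
    xs = torch.linspace(-1.0, 1.0, w)
    pos = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)  # (h, w, 2)
    freqs = torch.linspace(1.0, max(h, w) / 2, num_bands)             # per-band frequencies
    scaled = pos.unsqueeze(-1) * freqs * torch.pi                     # (h, w, 2, num_bands)
    enc = torch.cat([scaled.sin(), scaled.cos()], dim=-1)             # (h, w, 2, 2 * num_bands)
    return torch.cat([enc.flatten(-2), pos], dim=-1)                  # (h, w, 4 * num_bands + 2)

feat = fourier_features(28, 28)  # MNIST-sized grid, one feature vector per pixel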

python -m perceiver.scripts.image.classifier fit \
  --model.num_latents=32 \
  --model.num_latent_channels=128 \
  --model.encoder.num_frequency_bands=32 \
  --model.encoder.num_cross_attention_layers=1 \
  --model.encoder.num_self_attention_layers_per_block=3 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.0 \
  --model.encoder.init_scale=0.1 \
  --model.decoder.num_output_query_channels=128 \
  --model.decoder.dropout=0.0 \
  --model.decoder.init_scale=0.1 \
  --data=MNISTDataModule \
  --data.batch_size=128 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --optimizer.weight_decay=0.01 \
  --trainer.accelerator=gpu \
  --trainer.devices=-1 \
  --trainer.max_epochs=20 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=exp
Model parameters: 805K total, 0K frozen, 805K trainable.
[Figure: validation accuracy]

Inference examples

  • Sentiment classification
    Open In Colab
  • Image classification
    Open In Colab
