Perceiver IO
This repository is a PyTorch and PyTorch Lightning implementation of
- Perceiver IO: A General Architecture for Structured Inputs & Outputs and
- Perceiver: General Perception with Iterative Attention
The codebase is designed for easy extension to new tasks and datasets. If you are a researcher or practitioner working on new Perceiver IO models and use cases, you might find this repository useful. The integration with PyTorch Lightning supports model training at any scale. On the other hand, if you are mainly interested in using or fine-tuning models from the Perceiver IO paper, you may want to take a look at the 🤗 Perceiver IO implementation.
Overview
The following figure maps Perceiver IO and Perceiver concepts to the core modules of the implementation (see Architecture for details).
Interfaces are defined on three levels:
- PyTorch model API: defines generic PerceiverEncoder and PerceiverDecoder classes and task-specific InputAdapter and OutputAdapter subclasses from which PyTorch models can be constructed.
- PyTorch Lightning model API: defines wrappers for PyTorch models to support training with the PyTorch Lightning Trainer.
- PyTorch Lightning model CLI: binds the PyTorch Lightning model API to the command line via the Lightning CLI.
Interface usage examples are available for two models from the Perceiver IO paper:
Model | Parameters | Construction example | Training example |
---|---|---|---|
Language model (Perceiver IO Base, SentencePiece tokenization) | 223M | construction | training |
Image classifier (Perceiver IO config A, 2D Fourier Features) | 48.4M | construction | training |
Training of smaller models is shown in section Training examples and their usage in section Inference examples.
Installation
Via pip
pip install perceiver-io[image,text]
From sources
Conda + Poetry
conda env create -f environment.yml
conda activate perceiver-io
poetry install --all-extras
This requires a Poetry installation, version 1.2.0b2 or higher.
If running poetry fails with a KeyringError, refer to the keyring documentation on how to disable keyring usage.
Docker image
A perceiver-io Docker image can be built with:
docker build -t perceiver-io .
Training of Perceiver IO models with this image is described here.
Training examples
This section uses rather small Perceiver IO models so that the following training examples can be run on limited hardware resources. Training automatically scales to more than one GPU and was tested on 4 RTX 3080 GPUs. For GPUs with less memory, you may need to reduce --data.batch_size or turn on activation checkpointing for some of the examples.
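Under PyTorch Lightning's usual DDP strategy, each GPU process loads its own batches, so the number of samples per optimizer step grows with the number of devices. The sketch below assumes that --data.batch_size is the per-device batch size (an assumption, not verified against this codebase) and shows how the Trainer's accumulate_grad_batches argument (on the command line presumably --trainer.accumulate_grad_batches) could keep the effective batch size constant when the per-device batch size is reduced. The helper name effective_batch_size is hypothetical.

```python
def effective_batch_size(batch_size, num_devices, accumulate_grad_batches=1):
    """Samples per optimizer step, assuming each device loads `batch_size` samples (DDP)."""
    return batch_size * num_devices * accumulate_grad_batches


# MLM pretraining command (--data.batch_size=64) on the 4 GPUs used for testing
print(effective_batch_size(64, num_devices=4))  # 256
# Halving the per-device batch size and accumulating 2 batches keeps 256 samples per step
print(effective_batch_size(32, num_devices=4, accumulate_grad_batches=2))  # 256
```

Gradient accumulation matches the number of samples per optimizer step, but each step then takes roughly twice as many forward/backward passes per device.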
Datasets used for model training are 🤗 Datasets wrapped into PyTorch Lightning data modules (see data package). Datasets are automatically downloaded, preprocessed and cached when their corresponding Lightning data module is loaded during training. Manual dataset preprocessing is described here.
An archive with training checkpoints can be downloaded here and should be extracted in the project's root directory to be compatible with the example command lines below. It also contains TensorBoard logs and config files.
I didn't really tune hyperparameters, so you'll likely get better results with a bit of experimentation (see also training tips).
Masked language modeling
This section trains a very small language model (2.9M parameters) on masked language modeling with whole word masking. It is first pretrained on WikiText-103 and then adapted to the IMDb dataset. The encoder of the trained language model is then used for sentiment classification.
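Whole word masking selects words rather than individual subword tokens for masking. A minimal sketch, assuming WordPiece-style tokens in which a leading ## marks a continuation of the previous word (as in BERT tokenizers): every token of a selected word is replaced by [MASK]. It simplifies BERT-style masking, which replaces 80% of the selected tokens with [MASK], 10% with random tokens and leaves 10% unchanged, and it is not the masking code used by this repository's data modules. The function names and example tokens are hypothetical.

```python
import random

MASK = "[MASK]"


def group_words(tokens):
    """Group token indices into words; a leading '##' continues the previous word."""
    words = []
    for i, tok in enumerate(tokens):
        if tok.startswith("##") and words:
            words[-1].append(i)
        else:
            words.append([i])
    return words


def whole_word_mask(tokens, mask_prob=0.15, rng=None):
    """Replace all tokens of randomly selected words with [MASK].

    Returns the masked tokens and labels (the original token at masked
    positions, None elsewhere, i.e. positions excluded from the loss).
    """
    if rng is None:
        rng = random.Random()
    masked = list(tokens)
    labels = [None] * len(tokens)
    for word in group_words(tokens):
        # One random draw per word, so sub-tokens are never masked separately.
        if rng.random() < mask_prob:
            for i in word:
                labels[i] = tokens[i]
                masked[i] = MASK
    return masked, labels


tokens = ["the", "perce", "##iver", "attends", "to", "pix", "##els"]  # made-up tokens
print(group_words(tokens))  # [[0], [1, 2], [3], [4], [5, 6]]
masked, labels = whole_word_mask(tokens, mask_prob=0.3, rng=random.Random(0))
print(masked)
```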
The tokenizer is a customized BERT tokenizer (tokenizers/bert-base-uncased-10k-bookcorpus-ext), trained on BookCorpus with a vocabulary size of 10,000. You can also use any other 🤗 fast tokenizer from the 🤗 Hub with the --data.tokenizer option (see Tokenizers for details).
The training script is mlm.py. It implements the command line interface and defines training defaults (see also trainer.yaml for further defaults). Pretraining on WikiText-103 can be started with:
python -m perceiver.scripts.text.mlm fit \
--model.num_latents=128 \
--model.num_latent_channels=128 \
--model.encoder.num_input_channels=128 \
--model.encoder.num_cross_attention_layers=3 \
--model.encoder.num_self_attention_layers_per_block=6 \
--model.encoder.num_self_attention_blocks=3 \
--model.encoder.first_self_attention_block_shared=false \
--model.encoder.dropout=0.1 \
--model.decoder.dropout=0.1 \
--data=WikiTextDataModule \
--data.tokenizer=tokenizers/bert-base-uncased-10k-bookcorpus-ext \
--data.max_seq_len=512 \
--data.batch_size=64 \
--optimizer=AdamW \
--optimizer.lr=1e-3 \
--optimizer.weight_decay=0.01 \
--lr_scheduler.warmup_steps=5000 \
--trainer.accelerator=gpu \
--trainer.devices=-1 \
--trainer.max_steps=50000 \
--trainer.check_val_every_n_epoch=5 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=mlm
Model parameters | Validation loss | Mask prediction samples |
---|---|---|
Total params: 2.9M | (figure) | (figure) |
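The pretraining command combines --optimizer.lr=1e-3 with --lr_scheduler.warmup_steps=5000. A minimal sketch of a linear warmup, assuming the learning rate ramps linearly from 0 and then stays constant; the scheduler actually configured by mlm.py may decay the learning rate after warmup, so treat the post-warmup part as an assumption. The function name warmup_lr is hypothetical.

```python
def warmup_lr(step, base_lr=1e-3, warmup_steps=5000):
    """Linearly increase the learning rate from 0 to base_lr over warmup_steps.

    Assumption: the rate stays constant afterwards; the scheduler configured
    by mlm.py may instead decay it after warmup.
    """
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, step / warmup_steps)


for step in (0, 1000, 2500, 5000, 50000):
    print(step, warmup_lr(step))
```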
Starting from the best pretraining checkpoint, the language model is then adapted to the IMDb dataset for a further 15,000 steps. If you ran pretraining yourself, you'll need to modify the --model.ckpt value accordingly; otherwise, the checkpoint from the downloaded archive is used.
python -m perceiver.scripts.text.mlm fit \
--model.ckpt="logs/mlm/version_0/checkpoints/epoch=044-val_loss=3.917.ckpt" \
--model.num_latents=128 \
--model.num_latent_channels=128 \
--model.encoder.num_input_channels=128 \
--model.encoder.num_cross_attention_layers=3 \
--model.encoder.num_self_attention_layers_per_block=6 \
--model.encoder.num_self_attention_blocks=3 \
--model.encoder.first_self_attention_block_shared=false \
--model.encoder.dropout=0.1 \
--model.decoder.dropout=0.1 \
--data=ImdbDataModule \
--data.tokenizer=tokenizers/bert-base-uncased-10k-bookcorpus-ext \
--data.max_seq_len=512 \
--data.batch_size=64 \
--optimizer=AdamW \
--optimizer.lr=1e-3 \
--optimizer.weight_decay=0.01 \
--lr_scheduler.warmup_steps=1000 \
--trainer.accelerator=gpu \
--trainer.devices=-1 \
--trainer.max_steps=15000 \
--trainer.check_val_every_n_epoch=3 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=mlm
Model parameters | Validation loss | Mask prediction samples |
---|---|---|
Total params: 2.9M | (figure) | (figure) |
After adaptation to IMDb, mask prediction samples are noticeably more related to movie reviews than after pretraining on WikiText-103 only. Prediction samples are screenshots from TensorBoard logs.
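The checkpoint paths in the commands above encode the epoch and validation loss in the filename (for example epoch=044-val_loss=3.917.ckpt). If you ran pretraining yourself, a small helper like the one below could pick the checkpoint with the lowest validation loss for --model.ckpt. It is a sketch that assumes this naming scheme; files that do not match it, such as a last.ckpt, are ignored. The helper name best_checkpoint and the epoch=039 and last.ckpt example filenames are hypothetical.

```python
import re
from pathlib import Path

# Matches filenames such as "epoch=044-val_loss=3.917.ckpt".
CKPT_PATTERN = re.compile(r"epoch=(\d+)-val_loss=(\d+(?:\.\d+)?)\.ckpt$")


def best_checkpoint(paths):
    """Return the path whose filename encodes the lowest val_loss."""
    scored = []
    for p in paths:
        match = CKPT_PATTERN.search(Path(p).name)
        if match:  # skip files like last.ckpt that carry no val_loss
            scored.append((float(match.group(2)), str(p)))
    if not scored:
        raise ValueError("no checkpoint filename contains a val_loss")
    return min(scored)[1]


candidates = [
    "logs/mlm/version_0/checkpoints/epoch=039-val_loss=3.951.ckpt",  # hypothetical
    "logs/mlm/version_0/checkpoints/epoch=044-val_loss=3.917.ckpt",
    "logs/mlm/version_0/checkpoints/last.ckpt",  # hypothetical
]
print(best_checkpoint(candidates))
# logs/mlm/version_0/checkpoints/epoch=044-val_loss=3.917.ckpt
```

To scan a run directory on disk, pass a glob instead of a list, e.g. best_checkpoint(Path("logs/mlm/version_0/checkpoints").glob("*.ckpt")).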
Sentiment classification
This section trains a Perceiver IO text classifier on IMDb reviews. The encoder is initialized with weights from masked language modeling (--model.mlm_ckpt option), and the decoder is a randomly initialized classification decoder. In a first step, only the decoder is trained while the encoder is frozen. The training script is classifier.py.
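A Perceiver IO classification decoder can be thought of as a single output query that cross-attends to the encoder's latent array, followed by a projection to class logits. The pure-Python sketch below shows that computation with simplifications that are assumptions of this example: identity query/key/value projections, a single head, no layer norm or MLP, and an output query with the same number of channels as the latents (in the command below, --model.decoder.num_output_query_channels=128 and --model.num_latent_channels=128 happen to match). The function names are hypothetical. With --model.encoder.freeze=true, only decoder parameters (in this sketch, the query and the output weights) would be updated.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def classify(latents, query, out_weights):
    """Single-query scaled dot-product cross-attention, then a linear output layer.

    latents: N latent vectors of size D (encoder output)
    query: one learned output query of size D
    out_weights: one weight vector of size D per class
    """
    d = len(query)
    scores = [dot(query, z) / math.sqrt(d) for z in latents]
    weights = softmax(scores)  # attention of the query over the N latents
    pooled = [sum(w * z[j] for w, z in zip(weights, latents)) for j in range(d)]
    logits = [dot(row, pooled) for row in out_weights]
    return logits, weights


latents = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # N=3 latents, D=2 channels (toy values)
query = [0.0, 0.0]                              # a zero query attends uniformly
out_weights = [[1.0, 0.0], [0.0, 1.0]]          # 2 classes
logits, weights = classify(latents, query, out_weights)
print(weights)  # each weight is 1/3
print(logits)
```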
python -m perceiver.scripts.text.classifier fit \
--model.mlm_ckpt="logs/mlm/version_1/checkpoints/epoch=113-val_loss=3.904.ckpt" \
--model.num_latents=128 \
--model.num_latent_channels=128 \
--model.encoder.num_input_channels=128 \
--model.encoder.num_cross_attention_layers=3 \
--model.encoder.num_self_attention_layers_per_block=6 \
--model.encoder.num_self_attention_blocks=3 \
--model.encoder.first_self_attention_block_shared=false \
--model.encoder.dropout=0.1 \
--model.encoder.freeze=true \
--model.decoder.num_output_query_channels=128 \
--model.decoder.dropout=0.1 \
--data=ImdbDataModule \
--data.tokenizer=tokenizers/bert-base-uncased-10k-bookcorpus-ext \
--data.target_task=clf \
--data.max_seq_len=512 \
--data.batch_size=256 \
--optimizer=AdamW \
--optimizer.lr=1e-4 \
--optimizer.weight_decay=0.01 \
--trainer.accelerator=gpu \
--trainer.devices=-1 \
--trainer.max_epochs=30 \
--trainer.log_every_n_steps=10 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=clf
Model parameters | Validation accuracy |
---|---|
Total params: 2.9M | (figure) |
The small classification decoder (100K parameters) can be trained to a validation accuracy of 88% when using an encoder that has been adapted to the IMDb dataset (red line). When using an encoder that has been pretrained on WikiText-103 only, the validation accuracy saturates at 78% (pink line). Unfreezing the encoder and fine-tuning it jointly with the classification decoder further improves validation accuracy:
python -m perceiver.scripts.text.classifier fit \
--model.clf_ckpt="logs/clf/version_0/checkpoints/epoch=028-val_loss=0.301.ckpt" \
--model.num_latents=128 \
--model.num_latent_channels=128 \
--model.encoder.num_input_channels=128 \
--model.encoder.num_cross_attention_layers=3 \
--model.encoder.num_self_attention_layers_per_block=6 \
--model.encoder.num_self_attention_blocks=3 \
--model.encoder.first_self_attention_block_shared=false \
--model.encoder.dropout=0.1 \
--model.decoder.num_output_query_channels=128 \
--model.decoder.dropout=0.1 \
--data=ImdbDataModule \
--data.tokenizer=tokenizers/bert-base-uncased-10k-bookcorpus-ext \
--data.target_task=clf \
--data.max_seq_len=512 \
--data.batch_size=256 \
--optimizer=AdamW \
--optimizer.lr=1e-5 \
--optimizer.weight_decay=0.01 \
--trainer.accelerator=gpu \
--trainer.devices=-1 \
--trainer.max_epochs=30 \
--trainer.log_every_n_steps=10 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=clf
Model parameters | Validation accuracy |
---|---|
Total params: 2.9M | (figure) |
Image classification
This section trains a tiny Perceiver IO image classifier (805K parameters) on MNIST digits. The model attends to each pixel in input images and does not use convolutional layers. In contrast to the other examples, only a single cross-attention layer is used. The training script is classifier.py.
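Attending to every pixel is affordable because only the cross-attention layer touches the input: its cost grows with the number of latents times the number of input elements, while the self-attention layers operate on the small latent array. A back-of-the-envelope sketch for the MNIST command below, counting query-key dot products per attention head and ignoring projections and MLPs; the pixel-to-pixel self-attention figure is a hypothetical comparison, not something this model computes.

```python
def attention_scores(num_queries, num_keys):
    """Query-key dot products computed by one attention layer (per head)."""
    return num_queries * num_keys


num_pixels = 28 * 28  # MNIST: one input element per pixel
num_latents = 32      # --model.num_latents=32
num_self_attention_layers = 3 * 3  # 3 blocks x 3 layers per block

cross = attention_scores(num_latents, num_pixels)  # single cross-attention layer
latent_self = num_self_attention_layers * attention_scores(num_latents, num_latents)
pixel_self = attention_scores(num_pixels, num_pixels)  # hypothetical: one self-attention layer over pixels

print(cross, latent_self, pixel_self)  # 25088 9216 614656
```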
python -m perceiver.scripts.image.classifier fit \
--model.num_latents=32 \
--model.num_latent_channels=128 \
--model.encoder.num_frequency_bands=32 \
--model.encoder.num_cross_attention_layers=1 \
--model.encoder.num_self_attention_layers_per_block=3 \
--model.encoder.num_self_attention_blocks=3 \
--model.encoder.first_self_attention_block_shared=false \
--model.encoder.dropout=0.0 \
--model.encoder.init_scale=0.1 \
--model.decoder.num_output_query_channels=128 \
--model.decoder.dropout=0.0 \
--model.decoder.init_scale=0.1 \
--data=MNISTDataModule \
--data.batch_size=128 \
--optimizer=AdamW \
--optimizer.lr=1e-3 \
--optimizer.weight_decay=0.01 \
--trainer.accelerator=gpu \
--trainer.devices=-1 \
--trainer.max_epochs=20 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=exp
Model parameters | Validation accuracy |
---|---|
Total params: 805K | (figure) |
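The --model.encoder.num_frequency_bands=32 option controls the Fourier position encoding that is combined with the pixel values. A minimal sketch following the Perceiver paper's description: positions along each spatial dimension are scaled to [-1, 1], frequencies are spaced linearly from 1 to half the resolution of that dimension, and each position contributes its raw coordinate plus sin and cos features. The feature ordering and exact frequency range used by this repository are assumptions, as are the function names. For 28 × 28 MNIST images, the sketch yields 2 × (2 × 32 + 1) = 130 position channels per pixel.

```python
import itertools
import math


def linspace(start, stop, num):
    """num evenly spaced values from start to stop (inclusive)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def fourier_features(spatial_shape, num_frequency_bands):
    """Return one position feature vector per grid position (row-major order)."""
    coords = [linspace(-1.0, 1.0, n) for n in spatial_shape]
    # Frequencies from 1 up to half the resolution (Nyquist) of each dimension.
    freqs = [linspace(1.0, n / 2, num_frequency_bands) for n in spatial_shape]
    features = []
    for index in itertools.product(*(range(n) for n in spatial_shape)):
        vec = []
        for d, i in enumerate(index):
            x = coords[d][i]
            vec.append(x)  # raw coordinate
            vec.extend(math.sin(math.pi * f * x) for f in freqs[d])
            vec.extend(math.cos(math.pi * f * x) for f in freqs[d])
        features.append(vec)
    return features


features = fourier_features((28, 28), num_frequency_bands=32)
print(len(features), len(features[0]))  # 784 130
```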
Inference examples