Perceiver IO
This library is a PyTorch and PyTorch Lightning implementation of
- Perceiver IO: A General Architecture for Structured Inputs & Outputs and
- Perceiver: General Perception with Iterative Attention
An introduction to the model interfaces provided by the library is given in Interfaces. Further implementation details are described in Architecture. The codebase was designed for easy extension to new tasks and datasets. The integration with PyTorch Lightning supports model training at any scale. The command line interface is implemented with the Lightning CLI.
Datasets used for model training are 🤗 Datasets wrapped into PyTorch Lightning data modules (see data package). Datasets are automatically downloaded, preprocessed and cached when their corresponding Lightning data module is loaded during training. For larger datasets, however, it is recommended to do this prior to training as described here.
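The following sketch shows how a data module can also be used directly in Python. The import path and constructor arguments are assumptions derived from the data package and the training examples below; the actual interface may differ.

# Minimal sketch; the import path and constructor arguments are assumptions
from perceiver.data.text import ImdbDataModule  # assumed module path (see data package)

dm = ImdbDataModule(tokenizer="deepmind/language-perceiver", max_seq_len=2048, batch_size=32)
dm.prepare_data()        # standard LightningDataModule hook: download and preprocess the dataset
dm.setup(stage="fit")    # build the train/validation splits
batch = next(iter(dm.train_dataloader()))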
For NLP tasks, this library also supports 🤗 fast tokenizers (see this list for an overview of which 🤗 tokenizers are fast tokenizers). It additionally provides special support for the 🤗 UTF-8 bytes Perceiver tokenizer so that it can be used here for masked language modeling with whole word masking.
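As an illustration, the Perceiver tokenizer can be loaded directly with 🤗 transformers; this snippet is independent of this library and only demonstrates the tokenizer side.

# Load the Perceiver UTF-8 bytes tokenizer used in the language model examples below
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")
encoding = tokenizer("Perceiver IO attends to raw UTF-8 bytes.")
print(len(encoding.input_ids))             # sequence length in bytes plus special tokens
print(tokenizer.decode(encoding.input_ids))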
Overview
Installation
Via pip
pip install perceiver-io[image,text]
From sources
Installation from sources requires Miniconda and Poetry (1.2.0b2 or higher).
conda env create -f environment.yml
conda activate perceiver-io
poetry install --all-extras
Pretrained models
Parameters of pretrained models can be imported from the 🤗 Hub as described in the following subsections.
Language model
Perceiver IO language model (UTF-8 bytes tokenization, vocabulary size of 262, 201M parameters) specified in Section 4 (Table 1) and Appendix F (Table 11) of the Perceiver IO paper. See Interfaces for further details.
from transformers import AutoConfig
from perceiver.model.text.language import convert_config, LanguageModel, LitLanguageModel
# Import and convert the language model configuration from the Hugging Face Hub
config = convert_config(AutoConfig.from_pretrained("deepmind/language-perceiver"))
# Construct a PyTorch model and load pretrained parameters
model = LanguageModel(config)
# Alternatively, construct a PyTorch Lightning module and load pretrained parameters
lit_model = LitLanguageModel.create(config)
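As an optional cross-check of the imported weights, the same checkpoint can also be run through the 🤗 transformers reference implementation (shown for illustration only; this is not the library's own interface).

# Forward pass with the 🤗 transformers reference implementation of the same checkpoint
import torch
from transformers import PerceiverTokenizer, PerceiverForMaskedLM

hf_tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")
hf_model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver")

encoding = hf_tokenizer("This is an incomplete sentence where some words are missing.",
                        padding="max_length", return_tensors="pt")
with torch.no_grad():
    logits = hf_model(**encoding).logits  # shape (batch, seq_len, vocab_size=262)
print(logits.shape)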
On the command line, the pretrained model can be loaded with the --model.params=deepmind/language-perceiver option.
python -m perceiver.scripts.text.lm fit \
--model.params=deepmind/language-perceiver \
...
Image classifier
Perceiver IO image classifier (config A, 2D Fourier features, 48.8M parameters) specified in Appendix A of the Perceiver IO paper.
from transformers import AutoConfig
from perceiver.model.image.classifier import convert_config, ImageClassifier, LitImageClassifier
# Import and convert the image classifier configuration from the Hugging Face Hub
config = convert_config(AutoConfig.from_pretrained("deepmind/vision-perceiver-fourier"))
# Construct a PyTorch model and load pretrained parameters
model = ImageClassifier(config)
# Alternatively, construct a PyTorch Lightning module and load pretrained parameters
lit_model = LitImageClassifier.create(config)
On the command line, the pretrained model can be loaded with the --model.params=deepmind/vision-perceiver-fourier option.
python -m perceiver.scripts.image.classifier fit \
--model.params=deepmind/vision-perceiver-fourier \
...
Training examples
Here are some command line examples showing how to train Perceiver IO models with this library. If a model must be initialized with parameters from a previous run, it references a checkpoint from that run with the --model.params option. You can download these checkpoints here (2.3 GB) or create your own by running the examples yourself. Training results are used in the Inference examples.
These examples were tested on a machine with 4x RTX 3080ti GPUs (12 GB memory each). You'll need to adjust some settings (batch size, ...) when running them on a different hardware configuration. Furthermore, I didn't really tune these examples, so you'll likely get better results with a bit of experimentation.
Dataset preprocessing
Although data modules automatically download and preprocess datasets if needed, it is recommended to preprocess at least IMDb and WikiText prior to training:
python -m perceiver.scripts.text.preproc imdb \
--tokenizer=deepmind/language-perceiver \
--max_seq_len=2048 \
--add_special_tokens=true
python -m perceiver.scripts.text.preproc wikitext \
--tokenizer=bert-base-uncased \
--max_seq_len=128 \
--add_special_tokens=true \
--filter_empty=true \
--filter_headers=true
Language model fine-tuning
Fine-tune a pretrained deepmind/language-perceiver model with masked language modeling and whole word masking on the IMDb dataset (unsupervised split). This prepares the language model for better performance on IMDb sentiment classification. The tokenizer is a UTF-8 bytes tokenizer and the model attends to the raw bytes of the input. Word masking is done dynamically at data loading time, i.e. each epoch has a different set of masked words.
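The sketch below illustrates the idea of dynamic whole word masking; it is not this library's implementation. With the UTF-8 bytes tokenizer, masking a word means masking all of its bytes.

# Illustrative sketch of dynamic whole word masking (not this library's implementation)
import random

def mask_whole_words(words, mask_prob=0.15, mask_token="[MASK]"):
    masked, labels = [], []
    for word in words:
        if random.random() < mask_prob:
            masked.append(mask_token)  # byte-level models mask every byte of the word
            labels.append(word)        # only masked words contribute to the MLM loss
        else:
            masked.append(word)
            labels.append(None)
    return masked, labels

# Called at data loading time, so each epoch masks a different set of words
print(mask_whole_words("perceiver io attends to raw utf-8 bytes".split()))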
python -m perceiver.scripts.text.lm fit \
--model.params=deepmind/language-perceiver \
--model.activation_checkpointing=true \
--data=ImdbDataModule \
--data.target_task=mlm \
--data.tokenizer=deepmind/language-perceiver \
--data.add_special_tokens=true \
--data.max_seq_len=2048 \
--data.batch_size=32 \
--optimizer=AdamW \
--optimizer.lr=2e-5 \
--optimizer.weight_decay=0.01 \
--lr_scheduler.warmup_steps=1000 \
--trainer.max_steps=5200 \
--trainer.accelerator=gpu \
--trainer.precision=16 \
--trainer.devices=2 \
--trainer.strategy=ddp_sharded \
--trainer.log_every_n_steps=20 \
--trainer.logger.save_dir=logs \
--trainer.logger=TensorBoardLogger \
--trainer.logger.name=mlm
Sentiment classification
Train a text classification model on the IMDb dataset (train split). The encoder of the classifier is the fine-tuned language model encoder from the previous run (--model.encoder.params=...); the decoder is a randomly initialized classification decoder (see TextClassifier and LitTextClassifier in classifier.py).
First, only the decoder is trained while the encoder is frozen (--model.encoder.freeze=true):
python -m perceiver.scripts.text.classifier fit \
--model.encoder.params="logs/mlm/version_0/checkpoints/epoch=009-val_loss=1.174.ckpt" \
--model.encoder.freeze=true \
--model.encoder.dropout=0.0 \
--model.decoder.dropout=0.1 \
--data=ImdbDataModule \
--data.target_task=clf \
--data.tokenizer=deepmind/language-perceiver \
--data.add_special_tokens=true \
--data.max_seq_len=2048 \
--data.batch_size=64 \
--optimizer=AdamW \
--optimizer.lr=1e-3 \
--optimizer.weight_decay=0.01 \
--trainer.accelerator=gpu \
--trainer.precision=16 \
--trainer.devices=4 \
--trainer.max_epochs=12 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=txt_clf_dec
Then, we unfreeze the encoder, initialize the classifier parameters with a checkpoint from the first classifier training run (--model.params=...) and fine-tune the encoder and decoder together on the IMDb training set for another 4 epochs.
python -m perceiver.scripts.text.classifier fit \
--model.params="logs/txt_clf_dec/version_1/checkpoints/epoch=010-val_loss=0.212.ckpt" \
--model.activation_checkpointing=true \
--model.encoder.freeze=false \
--model.encoder.dropout=0.1 \
--model.decoder.dropout=0.1 \
--data=ImdbDataModule \
--data.target_task=clf \
--data.tokenizer=deepmind/language-perceiver \
--data.add_special_tokens=true \
--data.max_seq_len=2048 \
--data.batch_size=16 \
--data.num_workers=3 \
--optimizer=AdamW \
--optimizer.lr=5e-6 \
--optimizer.weight_decay=0.01 \
--trainer.accelerator=gpu \
--trainer.precision=16 \
--trainer.devices=4 \
--trainer.max_epochs=4 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=txt_clf_all
The validation accuracy of these two runs can be obtained with
python -m perceiver.scripts.text.classifier validate \
--config=logs/txt_clf_dec/version_1/config.yaml \
--model.encoder.params=null \
--ckpt_path="logs/txt_clf_dec/version_1/checkpoints/epoch=010-val_loss=0.212.ckpt"
──────────────────────────────────────────────────
     Validate metric           DataLoader 0
──────────────────────────────────────────────────
         val_acc            0.9162399768829346
         val_loss           0.21216852962970734
──────────────────────────────────────────────────
and
python -m perceiver.scripts.text.classifier validate \
--config=logs/txt_clf_all/version_0/config.yaml \
--model.params=null \
--ckpt_path="logs/txt_clf_all/version_0/checkpoints/epoch=002-val_loss=0.156.ckpt"
──────────────────────────────────────────────────
     Validate metric           DataLoader 0
──────────────────────────────────────────────────
         val_acc            0.9444400072097778
         val_loss           0.15595446527004242
──────────────────────────────────────────────────
When training only the classification decoder, the validation accuracy is 91.6%. Fine-tuning encoder and decoder on the classification task further increases validation accuracy to 94.4%.
Language model pretraining
Pretrain a smaller language model (45.2M parameters) with masked language modeling and whole word masking on the WikiText-103 dataset. This is a toy example demonstrating how to use a custom model configuration/architecture and another 🤗 tokenizer (bert-base-uncased, a WordPiece tokenizer with a vocabulary size of 30,522). To speed up training, --data.max_seq_len=128 and --model.num_latents=64 are used (a quarter of the default values).
python -m perceiver.scripts.text.lm fit \
--model.activation_checkpointing=true \
--model.num_latents=64 \
--model.num_latent_channels=768 \
--model.encoder.num_input_channels=512 \
--model.encoder.num_cross_attention_v_channels=768 \
--model.encoder.num_self_attention_v_channels=768 \
--model.encoder.num_self_attention_layers_per_block=6 \
--model.encoder.cross_attention_widening_factor=2 \
--model.encoder.self_attention_widening_factor=2 \
--model.encoder.dropout=0.0 \
--model.decoder.num_cross_attention_v_channels=512 \
--model.decoder.cross_attention_widening_factor=2 \
--model.decoder.dropout=0.0 \
--data=WikiTextDataModule \
--data.tokenizer=bert-base-uncased \
--data.add_special_tokens=true \
--data.filter_empty=true \
--data.filter_headers=true \
--data.max_seq_len=128 \
--data.batch_size=128 \
--optimizer=AdamW \
--optimizer.lr=1e-4 \
--optimizer.weight_decay=0.01 \
--lr_scheduler.warmup_steps=1000 \
--trainer.max_steps=25000 \
--trainer.accelerator=gpu \
--trainer.precision=16 \
--trainer.devices=4 \
--trainer.strategy=ddp_sharded \
--trainer.accumulate_grad_batches=2 \
--trainer.val_check_interval=0.5 \
--trainer.log_every_n_steps=20 \
--trainer.logger.save_dir=logs \
--trainer.logger=TensorBoardLogger \
--trainer.logger.name=mlm_pre
Image classification
Train a tiny image classifier (805K parameters) on the MNIST dataset. The model attends to individual pixels of the input image and uses Fourier position encodings. This is another toy example demonstrating how to use a custom model configuration that differs from the defaults in classifier.py. The Fourier position encodings are sketched below, followed by the training command.
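The Fourier position encodings follow the scheme from the Perceiver paper: each pixel coordinate is scaled to [-1, 1] and expanded into sine and cosine features at a set of linearly spaced frequency bands, plus the raw coordinate. This is a minimal sketch of that computation; the library's own implementation may differ in details.

# Sketch of 2D Fourier position encodings as described in the Perceiver paper
import math
import torch

def fourier_position_encodings(height, width, num_bands=32, max_freq=None):
    # Pixel coordinates normalized to [-1, 1] per spatial dimension
    ys = torch.linspace(-1.0, 1.0, height)
    xs = torch.linspace(-1.0, 1.0, width)
    pos = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)  # (H, W, 2)
    # Frequencies spaced linearly from 1 to the Nyquist frequency (half the resolution)
    max_freq = max_freq or max(height, width)
    freqs = torch.linspace(1.0, max_freq / 2.0, num_bands)            # (num_bands,)
    scaled = pos[..., None] * freqs * math.pi                         # (H, W, 2, num_bands)
    enc = torch.cat([scaled.sin(), scaled.cos(), pos[..., None]], dim=-1)
    return enc.flatten(start_dim=2)  # (H, W, 2 * (2 * num_bands + 1))

enc = fourier_position_encodings(28, 28, num_bands=32)
print(enc.shape)  # torch.Size([28, 28, 130]): 130 position channels per pixel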
python -m perceiver.scripts.image.classifier fit \
--model.num_latents=32 \
--model.num_latent_channels=128 \
--model.encoder.num_frequency_bands=32 \
--model.encoder.num_self_attention_layers_per_block=3 \
--model.encoder.num_self_attention_blocks=3 \
--model.encoder.first_self_attention_block_shared=false \
--model.encoder.dropout=0.0 \
--model.encoder.init_scale=0.1 \
--model.decoder.num_output_query_channels=128 \
--model.decoder.dropout=0.0 \
--model.decoder.init_scale=0.1 \
--data=MNISTDataModule \
--data.batch_size=128 \
--optimizer=AdamW \
--optimizer.lr=1e-3 \
--optimizer.weight_decay=0.01 \
--trainer.accelerator=gpu \
--trainer.devices=2 \
--trainer.max_epochs=20 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=img_clf
The validation accuracy is 98.1%:
python -m perceiver.scripts.image.classifier validate \
--config=logs/img_clf/version_0/config.yaml \
--ckpt_path="logs/img_clf/version_0/checkpoints/epoch=015-val_loss=0.068.ckpt"
──────────────────────────────────────────────────
     Validate metric           DataLoader 0
──────────────────────────────────────────────────
         val_acc            0.9807000160217285
         val_loss           0.06775263696908951
──────────────────────────────────────────────────