Perceiver IO
Perceiver, Perceiver IO and Perceiver AR
This repository is a PyTorch implementation of Perceiver, Perceiver IO and Perceiver AR, with PyTorch Lightning interfaces for model training and Hugging Face 🤗 interfaces for inference.
- Perceiver: General Perception with Iterative Attention (paper, video)
- Perceiver IO: A General Architecture for Structured Inputs & Outputs (paper, blog post)
- General-purpose, long-context autoregressive modeling with Perceiver AR (paper, blog post)
Overview
The core of the perceiver-io library consists of backend models, lightweight PyTorch implementations of Perceiver, Perceiver IO and Perceiver AR. They can be wrapped into PyTorch Lightning modules for training (Lightning interface) and 🤗 modules for inference (Hugging Face interface). See library design for details.
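To make this layering concrete, here is a schematic, stdlib-only sketch of a backend model wrapped by an inference interface. The wrapper returns an output object with a logits field and still exposes the wrapped model as backend_model, the same attribute used in the MNIST example further down. All class names are hypothetical, and the arithmetic stands in for a real network; actual perceiver-io backends are PyTorch modules.

```python
# Schematic sketch of the backend/wrapper layering. Hypothetical classes:
# real perceiver-io backends are torch nn.Modules and the wrappers are
# Lightning and Hugging Face modules.
from dataclasses import dataclass
from typing import List


class BackendClassifier:
    """Stand-in for a backend model: maps an input vector to class logits."""

    def __init__(self, weights: List[List[float]]):
        self.weights = weights  # one weight row per class

    def __call__(self, x: List[float]) -> List[float]:
        # Linear map: one dot product per class row.
        return [sum(w * v for w, v in zip(row, x)) for row in self.weights]


@dataclass
class ClassifierOutput:
    logits: List[float]


class InferenceWrapper:
    """Stand-in for a Hugging Face wrapper: delegates to backend_model and
    packs the result into an output object with a .logits field."""

    def __init__(self, backend_model: BackendClassifier):
        self.backend_model = backend_model

    def __call__(self, pixel_values: List[float]) -> ClassifierOutput:
        return ClassifierOutput(logits=self.backend_model(pixel_values))


model = InferenceWrapper(BackendClassifier([[1.0, 0.0], [0.0, 2.0]]))
print(model([3.0, 1.0]).logits)         # [3.0, 2.0]
print(model.backend_model([3.0, 1.0]))  # same logits, backend called directly
```

In this sketch the backend knows nothing about its wrapper, so the same instance could be placed behind a training interface and an inference interface without changes.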
The command line interface for training is implemented with Lightning CLI. Training datasets are 🤗 datasets wrapped into PyTorch Lightning data modules. For NLP tasks, perceiver-io supports all 🤗 fast tokenizers and the 🤗 Perceiver UTF-8 bytes tokenizer.
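A UTF-8 bytes tokenizer needs no learned vocabulary: every byte of the UTF-8 encoded text is one token. The sketch below assumes the id layout of the Hugging Face PerceiverTokenizer (six special tokens [PAD], [BOS], [EOS], [MASK], [CLS], [SEP] at ids 0 to 5, byte b at id b + 6, and [CLS] ... [SEP] around each sequence). encode and decode are hypothetical helpers, not the tokenizer's API.

```python
# Minimal sketch of byte-level tokenization. Assumed id layout (modeled on the
# Hugging Face PerceiverTokenizer): six special tokens at ids 0-5, then each
# UTF-8 byte b at id b + 6. encode/decode are hypothetical helpers.
from typing import List

SPECIAL_TOKENS = ["[PAD]", "[BOS]", "[EOS]", "[MASK]", "[CLS]", "[SEP]"]
NUM_SPECIAL = len(SPECIAL_TOKENS)       # 6
CLS_ID = SPECIAL_TOKENS.index("[CLS]")  # 4
SEP_ID = SPECIAL_TOKENS.index("[SEP]")  # 5
VOCAB_SIZE = 256 + NUM_SPECIAL          # 262: all byte values plus specials


def encode(text: str) -> List[int]:
    """Wrap the shifted UTF-8 bytes of text in [CLS] ... [SEP]."""
    return [CLS_ID] + [b + NUM_SPECIAL for b in text.encode("utf-8")] + [SEP_ID]


def decode(ids: List[int]) -> str:
    """Drop special-token ids, undo the shift and decode the bytes."""
    data = bytes(i - NUM_SPECIAL for i in ids if i >= NUM_SPECIAL)
    return data.decode("utf-8", errors="replace")


ids = encode("héllo")
print(ids)          # [4, 110, 201, 175, 114, 114, 117, 5]
print(decode(ids))  # héllo
```

A non-ASCII character such as é becomes two tokens here, so sequence length is measured in bytes rather than characters.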
Documentation
- Installation
- Getting started
- Library design
- Pretrained models
- Training examples
- Inference examples
- Model construction
- Building blocks
Installation
Via pip
pip install perceiver-io[text,vision]
From sources
Installation from sources requires Miniconda and Poetry (1.2.0 or higher).
Create and activate the perceiver-io conda environment:
conda env create -f environment.yml
conda activate perceiver-io
Install main and test dependencies, including all extras:
# Without dependencies required for examples
poetry install --all-extras
If you want to run the examples locally, additionally use --with examples:
poetry install --all-extras --with examples
Docker image
docker pull ghcr.io/krasserm/perceiver-io:latest
See Docker image for details.
Getting started
Inference
Compute the optical flow between consecutive frames of an input video and write the rendered results to an output video:
from urllib.request import urlretrieve
from transformers import pipeline
from perceiver.data.vision import video_utils
from perceiver.model.vision import optical_flow # register auto-classes and pipeline
urlretrieve(
url="https://martin-krasser.com/perceiver/flow/sintel_clip_cave_dragon_fight.mp4",
filename="sintel_clip_cave_dragon_fight.mp4",
)
# Create optical flow pipeline
optical_flow_pipeline = pipeline("optical-flow", model="krasserm/perceiver-io-optical-flow", device="cuda:0")
# load consecutive video frame pairs
frame_pairs = video_utils.read_video_frame_pairs("sintel_clip_cave_dragon_fight.mp4")
# create and render optical flow for all frame pairs
optical_flows = optical_flow_pipeline(frame_pairs, render=True, device="cuda:0")
# create video with rendered optical flows
video_utils.write_video("sintel_clip_cave_dragon_fight_output.mp4", optical_flows, fps=24)
Here is a side-by-side comparison of the input and output video:
See inference examples for more examples.
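Optical flow is estimated between two frames at a time. Assuming read_video_frame_pairs pairs each frame with the next one (overlapping pairs), the pairing itself reduces to the sketch below. consecutive_pairs is a hypothetical helper that works on any frame sequence; it is not the implementation of read_video_frame_pairs.

```python
# Minimal sketch: pair each frame with its successor, assuming overlapping
# pairs (f0, f1), (f1, f2), ... as input for optical flow estimation.
from typing import List, Sequence, Tuple, TypeVar

Frame = TypeVar("Frame")


def consecutive_pairs(frames: Sequence[Frame]) -> List[Tuple[Frame, Frame]]:
    """Return [(f0, f1), (f1, f2), ...]; n frames yield n - 1 pairs."""
    return list(zip(frames[:-1], frames[1:]))


print(consecutive_pairs(["f0", "f1", "f2", "f3"]))
# [('f0', 'f1'), ('f1', 'f2'), ('f2', 'f3')]
```

Under this assumption, a video with n frames yields n - 1 frame pairs and therefore n - 1 optical flow images.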
Training
Train a small Perceiver IO image classifier (907K parameters) on MNIST from the command line. The classifier cross-attends to individual pixels of input images with repeated cross-attention. See image classification training example for more details.
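Attending to individual pixels is affordable because cross-attention from a small latent array scales with the product of the number of latents and inputs, not with the square of the number of inputs. The back-of-the-envelope count below assumes 28 × 28 MNIST images with one input element per pixel and the 32 latents set by --model.num_latents=32 in the command that follows. It counts attention-score entries per head and layer and ignores channels and the output query.

```python
# Attention-score entries per head and layer for the MNIST configuration.
# Assumes one input element per pixel of a 28 x 28 image and 32 latents.
height, width = 28, 28
num_inputs = height * width   # 784 pixels, each an input element
num_latents = 32              # --model.num_latents

cross_attention = num_latents * num_inputs  # latents attend to all pixels
latent_self_attention = num_latents ** 2    # latents attend to each other
pixel_self_attention = num_inputs ** 2      # full self-attention over pixels

print(cross_attention)                         # 25088
print(latent_self_attention)                   # 1024
print(pixel_self_attention)                    # 614656
print(pixel_self_attention / cross_attention)  # 24.5
```

The self-attention layers after the cross-attention operate only on the 32 latents, so their cost does not depend on the image size at all.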
python -m perceiver.scripts.vision.image_classifier fit \
--model.num_latents=32 \
--model.num_latent_channels=128 \
--model.encoder.num_frequency_bands=32 \
--model.encoder.num_cross_attention_layers=2 \
--model.encoder.num_self_attention_blocks=3 \
--model.encoder.num_self_attention_layers_per_block=3 \
--model.encoder.first_self_attention_block_shared=false \
--model.encoder.dropout=0.1 \
--model.encoder.init_scale=0.1 \
--model.decoder.num_output_query_channels=128 \
--model.decoder.dropout=0.1 \
--model.decoder.init_scale=0.1 \
--data=MNISTDataModule \
--data.batch_size=64 \
--optimizer=AdamW \
--optimizer.lr=1e-3 \
--lr_scheduler.warmup_steps=500 \
--trainer.accelerator=gpu \
--trainer.devices=1 \
--trainer.max_epochs=30 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=img_clf
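The --model.encoder.num_frequency_bands=32 option sets the number of frequency bands of the Fourier position encoding that tells the model where each pixel is located. The sketch below follows the scheme described in the Perceiver paper: coordinates scaled to [-1, 1], K frequencies spaced linearly between 1 and max_freq / 2, and sine and cosine features plus the raw coordinate for each spatial dimension. Using max_freq = 28 (the image size) and this feature ordering are assumptions; the library's implementation may differ in detail.

```python
# Minimal sketch of a Perceiver-style Fourier position encoding for one pixel.
# Assumptions: coordinates in [-1, 1], frequencies spaced linearly from 1 to
# max_freq / 2, raw coordinate kept; max_freq = 28 and the feature order
# [x, sin..., cos...] per dimension are illustrative choices.
import math
from typing import List, Sequence


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Return num evenly spaced values from start to stop, both inclusive."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def fourier_encode(
    position: Sequence[float], num_bands: int, max_freq: float
) -> List[float]:
    """Encode each coordinate x as [x, sin(f_k*pi*x)..., cos(f_k*pi*x)...]."""
    freqs = linspace(1.0, max_freq / 2, num_bands)
    features: List[float] = []
    for x in position:
        features.append(x)
        features.extend(math.sin(f * math.pi * x) for f in freqs)
        features.extend(math.cos(f * math.pi * x) for f in freqs)
    return features


def pixel_position(row: int, col: int, height: int, width: int) -> List[float]:
    """Scale integer pixel indices to coordinates in [-1, 1]."""
    return [2 * row / (height - 1) - 1, 2 * col / (width - 1) - 1]


encoding = fourier_encode(pixel_position(0, 27, 28, 28), num_bands=32, max_freq=28)
print(len(encoding))  # 130 = 2 dimensions * (2 * 32 + 1)
```

Under these assumptions each pixel receives 2 × (2 × 32 + 1) = 130 position channels, which in the paper's formulation are concatenated with the pixel's grayscale value.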
Model construction describes how to implement model-specific command line interfaces with the Lightning CLI. Training checkpoints are written to the logs/img_clf/version_0/checkpoints directory. Assuming a checkpoint with filename epoch=025-val_loss=0.065.ckpt exists, it can be converted to a perceiver-io 🤗 model with:
from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint
convert_mnist_classifier_checkpoint(
save_dir="example/mnist-classifier",
ckpt_url="logs/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
)
so that it can be used in a 🤗 image classification pipeline:
from datasets import load_dataset
from transformers import pipeline
mnist_dataset = load_dataset("mnist", split="test")[:9]
images = mnist_dataset["image"]
labels = mnist_dataset["label"]
classifier = pipeline("image-classification", model="example/mnist-classifier")
predictions = [int(pred[0]["label"]) for pred in classifier(images)]
print(f"Labels: {labels}")
print(f"Predictions: {predictions}")
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
or loaded directly:
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
model = AutoModelForImageClassification.from_pretrained("example/mnist-classifier")
processor = AutoImageProcessor.from_pretrained("example/mnist-classifier")
inputs = processor(images, return_tensors="pt")
with torch.no_grad():
# use perceiver-io Hugging Face model
output_1 = model(**inputs).logits
with torch.no_grad():
# or use perceiver-io backend model directly
output_2 = model.backend_model(inputs.pixel_values)
print(f"Predictions: {output_1.argmax(dim=-1).numpy().tolist()}")
print(f"Predictions: {output_2.argmax(dim=-1).numpy().tolist()}")
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
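The conversion step above names one checkpoint explicitly. Because the checkpoint filenames in this setup contain the validation loss, the best checkpoint can also be picked programmatically. best_checkpoint and best_in_dir are hypothetical helpers, the pattern only matches names of the epoch=025-val_loss=0.065.ckpt form, and all example filenames other than that one are made up.

```python
# Minimal sketch: pick the checkpoint with the lowest val_loss encoded in its
# filename. Hypothetical helpers; assumes the epoch=EEE-val_loss=L.ckpt form.
import re
from pathlib import Path
from typing import Iterable, Optional

PATTERN = re.compile(r"epoch=(\d+)-val_loss=(\d+(?:\.\d+)?)\.ckpt$")


def best_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Return the filename with the smallest val_loss, or None if none match."""
    scored = []
    for name in filenames:
        match = PATTERN.search(name)
        if match:  # names like last.ckpt are skipped
            scored.append((float(match.group(2)), name))
    return min(scored)[1] if scored else None


def best_in_dir(directory: str) -> Optional[str]:
    """Apply best_checkpoint to all *.ckpt files in a directory."""
    return best_checkpoint(str(p) for p in Path(directory).glob("*.ckpt"))


names = [
    "epoch=010-val_loss=0.092.ckpt",  # made-up example
    "epoch=025-val_loss=0.065.ckpt",
    "epoch=029-val_loss=0.071.ckpt",  # made-up example
    "last.ckpt",
]
print(best_checkpoint(names))  # epoch=025-val_loss=0.065.ckpt
```

For the training run above, best_in_dir("logs/img_clf/version_0/checkpoints") would return the path to pass as ckpt_url.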
See training examples for more examples.
Articles
Articles referencing this repository:
- Training compute-optimal Perceiver AR language models
- A gentle introduction to Rotary Position Embedding
Hashes for perceiver_io-0.9.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 58b419787ee618c9d2219a51f4d929b287004d1a894b6bf1e83c286ededd7d6c
MD5 | 13ac607be7744cd7deb0a993b76df2b3
BLAKE2b-256 | e54fa7661830b61547a3d15343df1d4024b39177e4d6e64213f639239edcdbb7