Perceiver IO
Perceiver, Perceiver IO and Perceiver AR
This repository is a PyTorch implementation of Perceiver, Perceiver IO and Perceiver AR, with PyTorch Lightning interfaces for model training and Hugging Face 🤗 interfaces for inference.
- Perceiver: General Perception with Iterative Attention (paper, video)
- Perceiver IO: A General Architecture for Structured Inputs & Outputs (paper, blog post)
- General-purpose, long-context autoregressive modeling with Perceiver AR (paper, blog post)
Overview
At the core of the perceiver-io library are backend models: lightweight PyTorch implementations of Perceiver, Perceiver IO and Perceiver AR. They can be wrapped into PyTorch Lightning modules for training (Lightning interface) and 🤗 modules for inference (Hugging Face interface). See library design for details.
The command line interface for training is implemented with Lightning CLI.
Training datasets are 🤗 datasets wrapped into PyTorch Lightning data modules.
For NLP tasks, perceiver-io supports all 🤗 fast tokenizers and the 🤗 Perceiver UTF-8 bytes tokenizer.
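The layering of backend model, Lightning module and 🤗 wrapper can be pictured with a generic sketch. The class names below are placeholders chosen for illustration only, not the library's actual interfaces (see library design for those):

import torch.nn as nn
import pytorch_lightning as pl

class BackendModel(nn.Module):
    # placeholder for a backend model such as Perceiver IO
    def forward(self, x):
        ...

class LitClassifier(pl.LightningModule):
    # placeholder for a Lightning interface that wraps a backend model for training
    def __init__(self):
        super().__init__()
        self.backend_model = BackendModel()

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.backend_model(x), y)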
Documentation
- Installation
- Getting started
- Library design
- Pretrained models
- Training examples
- Inference examples
- Model construction
- Building blocks
Installation
Via pip
pip install perceiver-io[text,vision,audio]
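A quick way to verify the installation is to query the installed package version with Python's standard importlib.metadata module (a generic check, not a project-specific command):

python -c "from importlib.metadata import version; print(version('perceiver-io'))"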
From sources
Installation from sources requires Miniconda and Poetry (1.2.0 or higher).
Create and activate the perceiver-io conda environment:
conda env create -f environment.yml
conda activate perceiver-io
Install main and test dependencies, including all extras:
# Without dependencies required for examples
poetry install --all-extras
If you want to run the examples locally, additionally use --with examples:
poetry install --all-extras --with examples
Docker image
docker pull ghcr.io/krasserm/perceiver-io:latest
See Docker image for details.
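The Docker image documentation linked above describes how the image is intended to be used. For a quick interactive session, standard Docker options can start a shell with GPU access (assuming the image provides bash; see the linked docs for the supported invocation):

docker run --rm -it --gpus all ghcr.io/krasserm/perceiver-io:latest bash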
Getting started
Inference
Optical flow
Compute the optical flow between consecutive frames of an input video and write the rendered results to an output video:
from urllib.request import urlretrieve
from transformers import pipeline
from perceiver.data.vision import video_utils
from perceiver.model.vision import optical_flow # register auto-classes and pipeline
urlretrieve(
url="https://martin-krasser.com/perceiver/flow/sintel_clip_cave_dragon_fight.mp4",
filename="sintel_clip_cave_dragon_fight.mp4",
)
# Create optical flow pipeline
optical_flow_pipeline = pipeline("optical-flow", model="krasserm/perceiver-io-optical-flow", device="cuda:0")
# load consecutive video frame pairs
frame_pairs = video_utils.read_video_frame_pairs("sintel_clip_cave_dragon_fight.mp4")
# create and render optical flow for all frame pairs
optical_flows = optical_flow_pipeline(frame_pairs, render=True, device="cuda:0")
# create video with rendered optical flows
video_utils.write_video("sintel_clip_cave_dragon_fight_output.mp4", optical_flows, fps=24)
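With render=True the pipeline returns rendered flow images that write_video consumes directly. Assuming each element of optical_flows is an (H, W, 3) RGB array (an assumption, not documented behavior quoted here), a single frame can also be saved as a still image:

import numpy as np
from PIL import Image

# Assumption: each rendered optical flow is an (H, W, 3) RGB array.
first_flow = np.asarray(optical_flows[0])
Image.fromarray(first_flow.astype(np.uint8)).save("optical_flow_frame_0.png")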
A side-by-side comparison of the input and output video is available in the project README.
Symbolic audio generation
Create audio sequences by generating symbolic (MIDI) audio data and converting the generated audio symbols to WAV output with fluidsynth (note: fluidsynth must be installed for the following example to work):
from transformers import pipeline
from pretty_midi import PrettyMIDI
from perceiver.model.audio import symbolic # auto-class registration
repo_id = "krasserm/perceiver-ar-sam-giant-midi"
prompt = PrettyMIDI("prompt.mid")
audio_generator = pipeline("symbolic-audio-generation", model=repo_id)
output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0, render=True)
with open("generated_audio.wav", "wb") as f:
f.write(output["generated_audio_wav"])
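The example above expects a prompt.mid file in the working directory. If no MIDI prompt is at hand, a minimal one can be created with pretty_midi, for example (the notes chosen here are arbitrary):

from pretty_midi import PrettyMIDI, Instrument, Note

# Build a short single-instrument prompt: four quarter notes of a C major arpeggio.
prompt_midi = PrettyMIDI()
piano = Instrument(program=0)  # acoustic grand piano
for i, pitch in enumerate([60, 64, 67, 72]):
    piano.notes.append(Note(velocity=100, pitch=pitch, start=0.5 * i, end=0.5 * (i + 1)))
prompt_midi.instruments.append(piano)
prompt_midi.write("prompt.mid")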
Examples of generated audio sequences are available on the 🤗 hub.
See inference examples for more examples.
Training
Train a small Perceiver IO image classifier (907K parameters) on MNIST from the command line. The classifier cross-attends to individual pixels of input images with repeated cross-attention. See image classification training example for more details.
python -m perceiver.scripts.vision.image_classifier fit \
--model.num_latents=32 \
--model.num_latent_channels=128 \
--model.encoder.num_frequency_bands=32 \
--model.encoder.num_cross_attention_layers=2 \
--model.encoder.num_self_attention_blocks=3 \
--model.encoder.num_self_attention_layers_per_block=3 \
--model.encoder.first_self_attention_block_shared=false \
--model.encoder.dropout=0.1 \
--model.encoder.init_scale=0.1 \
--model.decoder.num_output_query_channels=128 \
--model.decoder.dropout=0.1 \
--model.decoder.init_scale=0.1 \
--data=MNISTDataModule \
--data.batch_size=64 \
--optimizer=AdamW \
--optimizer.lr=1e-3 \
--lr_scheduler.warmup_steps=500 \
--trainer.accelerator=gpu \
--trainer.devices=1 \
--trainer.max_epochs=30 \
--trainer.logger=TensorBoardLogger \
--trainer.logger.save_dir=logs \
--trainer.logger.name=logs
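With the TensorBoardLogger settings above, metrics are written to the logs directory, and training progress can be monitored with TensorBoard (assuming it is installed in the environment):

tensorboard --logdir logs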
Model construction describes how to implement model-specific command line interfaces with the Lightning CLI. Training checkpoints are written to the logs/img_clf/version_0/checkpoints directory. Assuming a checkpoint with filename epoch=025-val_loss=0.065.ckpt exists, it can be converted to a perceiver-io 🤗 model with
from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint
convert_mnist_classifier_checkpoint(
save_dir="example/mnist-classifier",
ckpt_url="logs/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
)
so that it can be used in a 🤗 image classification pipeline:
from datasets import load_dataset
from transformers import pipeline
mnist_dataset = load_dataset("mnist", split="test")[:9]
images = mnist_dataset["image"]
labels = mnist_dataset["label"]
classifier = pipeline("image-classification", model="example/mnist-classifier")
predictions = [pred[0]["label"] for pred in classifier(images)]
print(f"Labels: {labels}")
print(f"Predictions: {predictions}")
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
or loaded directly:
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
model = AutoModelForImageClassification.from_pretrained("example/mnist-classifier")
processor = AutoImageProcessor.from_pretrained("example/mnist-classifier")
inputs = processor(images, return_tensors="pt")
with torch.no_grad():
# use perceiver-io Hugging Face model
output_1 = model(**inputs).logits
with torch.no_grad():
# or use perceiver-io backend model directly
output_2 = model.backend_model(inputs.pixel_values)
print(f"Predictions: {output_1.argmax(dim=-1).numpy().tolist()}")
print(f"Predictions: {output_2.argmax(dim=-1).numpy().tolist()}")
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
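As an additional sanity check that is not part of the original example, the same pipeline can be used to estimate accuracy on the full MNIST test split:

from datasets import load_dataset
from transformers import pipeline

mnist_test = load_dataset("mnist", split="test")
classifier = pipeline("image-classification", model="example/mnist-classifier")

correct = 0
for label, preds in zip(mnist_test["label"], classifier(mnist_test["image"], batch_size=64)):
    # each prediction is a list of {"label", "score"} dicts, best score first
    if int(preds[0]["label"]) == label:
        correct += 1

print(f"Test accuracy: {correct / len(mnist_test):.4f}")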
See training examples for more examples.
Articles
Articles referencing this repository:
- Training compute-optimal Perceiver AR language models
- A gentle introduction to Rotary Position Embedding
Other implementations
Project details
Release history
Download files
Download the file for your platform.
Source Distribution: perceiver-io-0.11.0.tar.gz
Built Distribution: perceiver_io-0.11.0-py3-none-any.whl
File details
Details for the file perceiver-io-0.11.0.tar.gz.
File metadata
- Download URL: perceiver-io-0.11.0.tar.gz
- Upload date:
- Size: 15.4 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.2.1 CPython/3.8.10 Linux/5.15.0-73-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8ccef357d29e49d093b44417e9449d47fcf10c8a5ec2f4e17b4b06730842462a
MD5 | 6cf995abb61a9530b202cb3036890d49
BLAKE2b-256 | 04fb10a6c2e5c567269e2f5c09557fed8f5f571b13aa452e5e8186365e6598dd
File details
Details for the file perceiver_io-0.11.0-py3-none-any.whl.
File metadata
- Download URL: perceiver_io-0.11.0-py3-none-any.whl
- Upload date:
- Size: 87.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.2.1 CPython/3.8.10 Linux/5.15.0-73-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9ab8398fb5b1120af7be73b80d982d71b65756f75b949236cc0326be45031bb6
MD5 | a5513393d23c71b6c0c5631cf1f4ef0e
BLAKE2b-256 | ed254b8956dc2190e3655b7e6a6dd72fb8b2ca5f5d9a38ca623d09aebd70af9a