Perceiver, Perceiver IO and Perceiver AR

This repository is a PyTorch implementation of Perceiver, Perceiver IO and Perceiver AR, with PyTorch Lightning interfaces for model training and Hugging Face 🤗 interfaces for inference.

Perceiver: General Perception with Iterative Attention (paper, video)
Perceiver IO: A General Architecture for Structured Inputs & Outputs (paper, blog post)
General-purpose, long-context autoregressive modeling with Perceiver AR (paper, blog post)

Overview

The core of the perceiver-io library consists of backend models: lightweight PyTorch implementations of Perceiver, Perceiver IO and Perceiver AR. They can be wrapped into PyTorch Lightning modules for training (Lightning interface) and 🤗 modules for inference (Hugging Face interface). See library design for details.

[Figure: library design]

The command line interface for training is implemented with Lightning CLI. Training datasets are 🤗 datasets wrapped into PyTorch Lightning data modules. For NLP tasks, perceiver-io supports all 🤗 fast tokenizers and the 🤗 Perceiver UTF-8 bytes tokenizer.
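As a rough illustration of what a UTF-8 bytes tokenizer does, the sketch below maps each byte of the encoded text to one token id. The number of reserved special-token ids (`NUM_SPECIAL_TOKENS = 6`) and the helper names `encode_bytes` and `decode_bytes` are assumptions made for this example, not the library's API; consult the tokenizer you actually use for its real id layout.

```python
from typing import List

# Assumption for illustration: the first ids are reserved for special tokens
# (padding, masking, ...), so byte values are shifted past them.
NUM_SPECIAL_TOKENS = 6


def encode_bytes(text: str, offset: int = NUM_SPECIAL_TOKENS) -> List[int]:
    """Map each UTF-8 byte of `text` to a token id shifted by `offset`."""
    return [byte + offset for byte in text.encode("utf-8")]


def decode_bytes(token_ids: List[int], offset: int = NUM_SPECIAL_TOKENS) -> str:
    """Invert encode_bytes, skipping reserved ids below `offset`."""
    data = bytes(i - offset for i in token_ids if i >= offset)
    return data.decode("utf-8", errors="replace")


ids = encode_bytes("héllo")
print(len(ids))           # 6, because "é" takes two bytes in UTF-8
print(decode_bytes(ids))  # héllo
```

Because the vocabulary is just the 256 byte values plus a few reserved ids, no vocabulary file is needed, at the cost of longer token sequences than with subword tokenizers.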

Documentation

Installation

Via pip

pip install perceiver-io[text,vision,audio]

From sources

Installing from sources requires Miniconda and Poetry (1.2.0 or higher).

Create and activate the perceiver-io conda environment:

conda env create -f environment.yml
conda activate perceiver-io

Install main and test dependencies, including all extras:

# Without dependencies required for examples
poetry install --all-extras

If you want to run the examples locally, additionally use --with examples:

poetry install --all-extras --with examples

Docker image

docker pull ghcr.io/krasserm/perceiver-io:latest

See Docker image for details.

Getting started

Inference

Optical flow

Compute the optical flow between consecutive frames of an input video and write the rendered results to an output video:

from urllib.request import urlretrieve
from transformers import pipeline

from perceiver.data.vision import video_utils
from perceiver.model.vision import optical_flow  # register auto-classes and pipeline

urlretrieve(
    url="https://martin-krasser.com/perceiver/flow/sintel_clip_cave_dragon_fight.mp4",
    filename="sintel_clip_cave_dragon_fight.mp4",
)

# Create optical flow pipeline
optical_flow_pipeline = pipeline("optical-flow", model="krasserm/perceiver-io-optical-flow", device="cuda:0")

# load consecutive video frame pairs
frame_pairs = video_utils.read_video_frame_pairs("sintel_clip_cave_dragon_fight.mp4")

# create and render optical flow for all frame pairs
optical_flows = optical_flow_pipeline(frame_pairs, render=True, device="cuda:0")

# create video with rendered optical flows
video_utils.write_video("sintel_clip_cave_dragon_fight_output.mp4", optical_flows, fps=24)

Here is a side-by-side comparison of the input and output video:

[Video: side-by-side comparison of input and rendered optical flow]
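The comment "load consecutive video frame pairs" in the example above suggests that each frame is paired with its successor. A minimal sketch of that pairing, assuming this is what `read_video_frame_pairs` does (the helper `consecutive_pairs` below is hypothetical and not part of the library):

```python
from typing import List, Sequence, Tuple, TypeVar

Frame = TypeVar("Frame")


def consecutive_pairs(frames: Sequence[Frame]) -> List[Tuple[Frame, Frame]]:
    """Return (frame_t, frame_t+1) for every pair of adjacent frames."""
    return list(zip(frames, frames[1:]))


# Stand-in labels; real frames would be image arrays.
frames = ["f0", "f1", "f2", "f3"]
pairs = consecutive_pairs(frames)
print(pairs)  # [('f0', 'f1'), ('f1', 'f2'), ('f2', 'f3')]
```

Under this assumption, a video with N frames yields N - 1 pairs and therefore N - 1 optical flow fields.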

Symbolic audio generation

Create audio sequences by generating symbolic (MIDI) audio data and converting the generated audio symbols to WAV output with fluidsynth. Note: fluidsynth must be installed for the following example to work.

from transformers import pipeline
from pretty_midi import PrettyMIDI
from perceiver.model.audio import symbolic  # auto-class registration

repo_id = "krasserm/perceiver-ar-sam-giant-midi"

prompt = PrettyMIDI("prompt.mid")
audio_generator = pipeline("symbolic-audio-generation", model=repo_id)

output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0, render=True)

with open("generated_audio.wav", "wb") as f:
    f.write(output["generated_audio_wav"])
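The arguments `do_sample=True`, `top_p=0.95` and `temperature=1.0` in the call above select nucleus (top-p) sampling. The following is a plain-Python sketch of what these parameters conventionally mean, not the code the pipeline actually runs; the function names are hypothetical.

```python
import math
import random
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Convert logits to probabilities; temperature < 1 sharpens, > 1 flattens."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs: List[float], top_p: float) -> List[int]:
    """Return indices of the most probable tokens whose total mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return kept


def sample_token(logits: List[float], top_p: float = 0.95,
                 temperature: float = 1.0, rng=random) -> int:
    """Sample one token id from the top-p nucleus of the distribution."""
    probs = softmax(logits, temperature)
    kept = top_p_filter(probs, top_p)
    weights = [probs[i] for i in kept]  # choices() renormalizes the weights
    return rng.choices(kept, weights=weights, k=1)[0]


print(top_p_filter([0.5, 0.3, 0.15, 0.05], top_p=0.9))  # [0, 1, 2]
```

Lowering `top_p` or `temperature` makes the generated sequence more predictable; raising them makes it more varied.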

Examples of generated audio sequences are available on the 🤗 hub.

See inference examples for more examples.

Training

Train a small Perceiver IO image classifier (907K parameters) on MNIST from the command line. The classifier cross-attends to individual pixels of input images with repeated cross-attention. See image classification training example for more details.

python -m perceiver.scripts.vision.image_classifier fit \
  --model.num_latents=32 \
  --model.num_latent_channels=128 \
  --model.encoder.num_frequency_bands=32 \
  --model.encoder.num_cross_attention_layers=2 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.num_self_attention_layers_per_block=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.1 \
  --model.encoder.init_scale=0.1 \
  --model.decoder.num_output_query_channels=128 \
  --model.decoder.dropout=0.1 \
  --model.decoder.init_scale=0.1 \
  --data=MNISTDataModule \
  --data.batch_size=64 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --lr_scheduler.warmup_steps=500 \
  --trainer.accelerator=gpu \
  --trainer.devices=1 \
  --trainer.max_epochs=30 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=logs
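To see why cross-attending to individual pixels is affordable, it helps to count attention scores. The sketch below uses `num_latents=32` from the command above and the 28 x 28 resolution of MNIST images; counts are per attention head and per image, and the variable names are chosen for this example only.

```python
# MNIST images are 28 x 28 pixels; the command above sets num_latents=32.
num_pixels = 28 * 28
num_latents = 32

# One cross-attention layer: every latent attends to every pixel.
cross_attention_scores = num_latents * num_pixels

# One latent self-attention layer: every latent attends to every latent.
latent_self_attention_scores = num_latents * num_latents

# For comparison: self-attention applied directly to the pixels.
pixel_self_attention_scores = num_pixels * num_pixels

print(cross_attention_scores)        # 25088
print(latent_self_attention_scores)  # 1024
print(pixel_self_attention_scores)   # 614656
```

Cross-attention cost grows linearly with the number of pixels, and the self-attention layers operate only on the 32 latents, independent of input size.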

Model construction describes how to implement model-specific command line interfaces with the Lightning CLI. Training checkpoints are written to the logs/img_clf/version_0/checkpoints directory. Assuming a checkpoint with filename epoch=025-val_loss=0.065.ckpt exists, it can be converted to a perceiver-io 🤗 model with

from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint

convert_mnist_classifier_checkpoint(
    save_dir="example/mnist-classifier",
    ckpt_url="logs/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
)

so that it can be used in a 🤗 image classification pipeline:

from datasets import load_dataset
from transformers import pipeline

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

classifier = pipeline("image-classification", model="example/mnist-classifier")
predictions = [pred[0]["label"] for pred in classifier(images)]

print(f"Labels:      {labels}")
print(f"Predictions: {predictions}")

Labels:      [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]

or loaded directly:

import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor

model = AutoModelForImageClassification.from_pretrained("example/mnist-classifier")
processor = AutoImageProcessor.from_pretrained("example/mnist-classifier")

inputs = processor(images, return_tensors="pt")

with torch.no_grad():
    # use perceiver-io Hugging Face model
    output_1 = model(**inputs).logits

with torch.no_grad():
    # or use perceiver-io backend model directly
    output_2 = model.backend_model(inputs.pixel_values)

print(f"Predictions: {output_1.argmax(dim=-1).numpy().tolist()}")
print(f"Predictions: {output_2.argmax(dim=-1).numpy().tolist()}")

Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
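The checkpoint filename used in the conversion step above encodes the epoch and the validation loss. When several checkpoints exist, one way to pick the one to convert is to parse these filenames. This is a sketch that assumes all checkpoints follow the pattern of the example filename `epoch=025-val_loss=0.065.ckpt`; the helper `best_checkpoint` is hypothetical and not part of the library.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed pattern, based on the example filename epoch=025-val_loss=0.065.ckpt
CKPT_PATTERN = re.compile(r"epoch=(\d+)-val_loss=(\d+\.\d+)\.ckpt$")


def best_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Return the checkpoint with the lowest val_loss in its filename, or None."""
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    candidates = []
    for path in directory.glob("*.ckpt"):
        match = CKPT_PATTERN.search(path.name)
        if match:  # ignore files that do not follow the assumed pattern
            candidates.append((float(match.group(2)), path))
    if not candidates:
        return None
    return min(candidates, key=lambda item: item[0])[1]


# Prints None if the directory does not exist or holds no matching files.
print(best_checkpoint("logs/img_clf/version_0/checkpoints"))
```

The returned path can then be passed as `ckpt_url` to `convert_mnist_classifier_checkpoint`.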

See training examples for more examples.

Articles

Articles referencing this repository:

Other implementations
