A library for Meta-Learning and Few-Shot Learning with PyTorch
torchmetal
A library for few-shot learning & meta-learning in PyTorch.
torchmetal contains popular meta-learning benchmarks, fully compatible with both torchvision and PyTorch's DataLoader.
Features
- A unified interface for both few-shot classification and regression problems, to allow easy benchmarking on multiple problems and reproducibility.
- Helper functions for some popular problems, with default arguments from the literature.
- A thin extension of PyTorch's Module, called MetaModule, that simplifies the creation of certain meta-learning models (e.g. gradient-based meta-learning methods). See the MAML example for an example using MetaModule.
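Gradient-based meta-learning needs a model whose forward pass can be evaluated with parameters other than the ones stored in the model, so that an inner-loop update yields adapted parameters without overwriting the shared ones. The following is a framework-free sketch of that idea in plain Python. It is not torchmetal's MetaModule API, and the linear model, the names `forward`, `mse_grad` and `inner_update`, and the step size are hypothetical choices for illustration. Gradients are derived by hand here, whereas in PyTorch autograd would compute them.

```python
from typing import Dict, List

# Hypothetical parameter container: a plain dict of named scalars.
Params = Dict[str, float]


def forward(x: float, params: Params) -> float:
    """Linear model y = w * x + b, evaluated with explicitly passed parameters."""
    return params["w"] * x + params["b"]


def mse_grad(xs: List[float], ys: List[float], params: Params) -> Params:
    """Hand-derived gradient of the mean squared error w.r.t. w and b."""
    n = len(xs)
    grad_w = 0.0
    grad_b = 0.0
    for x, y in zip(xs, ys):
        err = forward(x, params) - y
        grad_w += 2.0 * err * x / n
        grad_b += 2.0 * err / n
    return {"w": grad_w, "b": grad_b}


def inner_update(xs: List[float], ys: List[float], params: Params,
                 step_size: float = 0.1) -> Params:
    """One gradient step on a task's training set; returns a NEW params dict."""
    grads = mse_grad(xs, ys, params)
    return {name: value - step_size * grads[name] for name, value in params.items()}


# Meta-parameters shared across tasks; the inner update does not modify them.
meta_params = {"w": 0.0, "b": 0.0}
support_x = [1.0, 2.0]
support_y = [2.0, 4.0]  # toy task: points on y = 2x
adapted = inner_update(support_x, support_y, meta_params)
print(meta_params)  # the original parameters are unchanged
print(adapted)
```

In MAML, the outer loop would then evaluate `forward` with `adapted` on the task's test samples and differentiate that loss with respect to `meta_params` (through the inner update), which this sketch does not attempt.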
Datasets available
- Few-shot regression (toy problems):
- Sine waves (Finn et al., 2017)
- Harmonic functions (Lacoste et al., 2018)
- Sinusoid & lines (Finn et al., 2018)
- Few-shot classification (image classification):
- Omniglot (Lake et al., 2015, 2019)
- Mini-ImageNet (Vinyals et al., 2016, Ravi et al., 2017)
- Tiered-ImageNet (Ren et al., 2018)
- CIFAR-FS (Bertinetto et al., 2018)
- Fewshot-CIFAR100 (Oreshkin et al., 2018)
- Caltech-UCSD Birds (Hilliard et al., 2019, Wah et al., 2011)
- Double MNIST (Sun, 2019)
- Triple MNIST (Sun, 2019)
- Few-shot segmentation (semantic segmentation):
- Pascal5i 1-way Setup
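To make the notion of a few-shot regression task concrete, here is a plain-Python sketch of sampling sine-wave tasks in the style of Finn et al. (2017): each task is a sinusoid with amplitude drawn uniformly from [0.1, 5.0] and phase from [0, π], and inputs are drawn uniformly from [-5, 5]. These ranges reflect our understanding of that paper's setup; the function names are hypothetical, and this is not torchmetal's own sine-wave dataset class.

```python
import math
import random
from typing import List, Tuple

Task = Tuple[float, float]  # (amplitude, phase)


def sample_sine_task(rng: random.Random) -> Task:
    """Draw one task; ranges assumed from Finn et al. (2017)."""
    amplitude = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, math.pi)
    return amplitude, phase


def sample_points(task: Task, num_points: int,
                  rng: random.Random) -> List[Tuple[float, float]]:
    """Draw (x, y) pairs from a task, with x uniform in [-5, 5]."""
    amplitude, phase = task
    points = []
    for _ in range(num_points):
        x = rng.uniform(-5.0, 5.0)
        points.append((x, amplitude * math.sin(x - phase)))
    return points


rng = random.Random(0)
task = sample_sine_task(rng)
support = sample_points(task, num_points=5, rng=rng)   # 5-shot training set
query = sample_points(task, num_points=15, rng=rng)    # held-out test points
print(task, len(support), len(query))
```

Every task shares the same functional form, so a meta-learner can exploit that structure to fit a new sinusoid from only a handful of points.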
Installation
You can install torchmetal either with Python's package manager pip or from source. To avoid conflicts with your existing Python setup, it is suggested to work in a virtual environment created with virtualenv. To install virtualenv, then create and activate a virtual environment named venv:
pip install --upgrade virtualenv
virtualenv venv
source venv/bin/activate
Using pip
This is the recommended way to install torchmetal:
pip install torchmetal
From source
You can also install torchmetal from source. This is recommended if you want to contribute to torchmetal.
git clone https://github.com/tristandeleu/pytorch-meta.git
cd pytorch-meta
python setup.py install
Example
Minimal example
The minimal example below shows how to create a dataloader for the 5-shot 5-way Omniglot dataset with torchmetal. Each iteration of the dataloader yields a batch of randomly generated tasks, with the samples of all tasks in the batch stacked into tensors. For more examples, check the examples folder.
from torchmetal.datasets.helpers import omniglot
from torchmetal.utils.data import BatchMetaDataLoader
dataset = omniglot("data", ways=5, shots=5, test_shots=15, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)
for batch in dataloader:
    train_inputs, train_targets = batch["train"]
    print('Train inputs shape: {0}'.format(train_inputs.shape))    # (16, 25, 1, 28, 28)
    print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)
    test_inputs, test_targets = batch["test"]
    print('Test inputs shape: {0}'.format(test_inputs.shape))      # (16, 75, 1, 28, 28)
    print('Test targets shape: {0}'.format(test_targets.shape))    # (16, 75)
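The shapes in the comments follow from simple arithmetic: each task has ways × shots training samples (5 × 5 = 25) and ways × test_shots test samples (5 × 15 = 75), and the batch stacks 16 tasks. A small hypothetical helper, not part of torchmetal, that reproduces these numbers for single-channel 28×28 images:

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expected_shapes(batch_size: int, ways: int, shots: int, test_shots: int,
                    image_shape: Shape = (1, 28, 28)):
    """Return ((inputs, targets) train shapes, (inputs, targets) test shapes)."""
    num_train = ways * shots       # training (support) samples per task
    num_test = ways * test_shots   # test (query) samples per task
    train = ((batch_size, num_train) + image_shape, (batch_size, num_train))
    test = ((batch_size, num_test) + image_shape, (batch_size, num_test))
    return train, test


train, test = expected_shapes(batch_size=16, ways=5, shots=5, test_shots=15)
print(train)  # ((16, 25, 1, 28, 28), (16, 25))
print(test)   # ((16, 75, 1, 28, 28), (16, 75))
```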
Advanced example
Helper functions are only available for some of the datasets. However, all of the datasets are available through the unified interface provided by torchmetal. The variable dataset defined above is equivalent to the following:
from torchmetal.datasets import Omniglot
from torchmetal.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmetal.utils.data import BatchMetaDataLoader
dataset = Omniglot("data",
                   # Number of ways
                   num_classes_per_task=5,
                   # Resize the images to 28x28 and convert them to PyTorch tensors (from torchvision)
                   transform=Compose([Resize(28), ToTensor()]),
                   # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
                   target_transform=Categorical(num_classes=5),
                   # Create new virtual classes with rotated versions of the images (from Santoro et al., 2016)
                   class_augmentations=[Rotation([90, 180, 270])],
                   meta_train=True,
                   download=True)
dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)
Note that the dataloader is unchanged: it receives the dataset exactly as before.
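Conceptually, Categorical relabels the classes of a task as 0, 1, ..., N-1, and ClassSplitter divides each class's samples into a training part and a test part. The plain-Python sketch below mimics that behaviour on a toy 2-way task. The function `split_task`, the dict-of-lists task layout, the relabeling order and the sample IDs are all assumptions for illustration, not torchmetal's implementation.

```python
import random
from typing import Dict, List, Tuple

Sample = Tuple[str, int]  # (sample id, integer label)


def split_task(task: Dict[str, List[str]], num_train_per_class: int,
               num_test_per_class: int, rng: random.Random,
               shuffle: bool = True) -> Tuple[List[Sample], List[Sample]]:
    """Hypothetical helper: relabel classes as 0..N-1 and split each class."""
    train: List[Sample] = []
    test: List[Sample] = []
    needed = num_train_per_class + num_test_per_class
    for label, class_name in enumerate(task):  # Categorical-style relabeling
        samples = list(task[class_name])
        if len(samples) < needed:
            raise ValueError(f"class {class_name!r} has fewer than {needed} samples")
        if shuffle:
            rng.shuffle(samples)
        # First samples go to the training split, the next ones to the test split.
        train += [(s, label) for s in samples[:num_train_per_class]]
        test += [(s, label) for s in samples[num_train_per_class:needed]]
    return train, test


# Toy 2-way task with made-up sample IDs (20 per class).
task = {
    "Glagolitic/character01": [f"g{i}" for i in range(20)],
    "Sanskrit/character14": [f"s{i}" for i in range(20)],
}
train, test = split_task(task, num_train_per_class=5, num_test_per_class=15,
                         rng=random.Random(0))
print(len(train), len(test))  # 10 30
```

Because the training and test samples of a class are drawn without replacement from the same shuffled list, the two splits never share a sample.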
Download files
Source Distribution
Built Distribution
File details
Details for the file torchmetal-0.1.0.tar.gz.
File metadata
- Download URL: torchmetal-0.1.0.tar.gz
- Upload date:
- Size: 131.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.6 CPython/3.7.6 Linux/5.4.0-7634-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 98832a02b9ebaf7796acf9028a40af8925855540102d58a1b6efe6e88bfc9f86
MD5 | 994b212e227fef94fd4da197b7c12f39
BLAKE2b-256 | 889a1aac2f21d55dfc4e33152ef301bcd16e58b0ab7bcc232ad11bbe9316874f
File details
Details for the file torchmetal-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torchmetal-0.1.0-py3-none-any.whl
- Upload date:
- Size: 165.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.1.6 CPython/3.7.6 Linux/5.4.0-7634-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6dd71a01ccd89bcbed855beac1efccc209c28760a7c167963a05f9de0f2b47b6
MD5 | 0234065b70aea1309e104451a7609fc4
BLAKE2b-256 | 6616bfc842001abd5f0eb5bf090cfb761002b7e4428f6dc22cf57221c2dc5c5a
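The published digests can be used to check that a downloaded file is intact. A minimal sketch using Python's standard hashlib module, assuming the source distribution has been saved as torchmetal-0.1.0.tar.gz in the current directory (the helper names `file_digest` and `verify` are hypothetical):

```python
import hashlib
from pathlib import Path


def file_digest(path: str, algorithm: str = "sha256", chunk_size: int = 1 << 16) -> str:
    """Compute the hex digest of a file, reading it in chunks to bound memory use."""
    h = hashlib.new(algorithm)
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: str, expected_sha256: str) -> bool:
    """Return True if the file's SHA256 digest matches the published one."""
    return file_digest(path) == expected_sha256.lower()


# SHA256 of torchmetal-0.1.0.tar.gz as listed in the file hashes for that file.
EXPECTED_SHA256 = "98832a02b9ebaf7796acf9028a40af8925855540102d58a1b6efe6e88bfc9f86"
archive = "torchmetal-0.1.0.tar.gz"  # assumed download location
if Path(archive).exists():
    print(verify(archive, EXPECTED_SHA256))
else:
    print(f"{archive} not found; download it first")
```

Alternatively, pip's hash-checking mode can enforce the digest at install time via a requirements line such as `torchmetal==0.1.0 --hash=sha256:<digest>`.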