
A library for Meta-Learning and Few-Shot Learning with PyTorch


torchmetal


A library for few-shot learning & meta-learning in PyTorch. torchmetal contains popular meta-learning benchmarks, fully compatible with both torchvision and PyTorch's DataLoader.

Features

  • A unified interface for both few-shot classification and regression problems, making it easy to benchmark on multiple problems and to reproduce results.
  • Helper functions for some popular problems, with default arguments from the literature.
  • A thin extension of PyTorch's Module, called MetaModule, that simplifies the creation of certain meta-learning models (e.g. gradient-based meta-learning methods). See the MAML example for a model built with MetaModule.
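Gradient-based methods such as MAML rely on running a model's forward pass with an explicit set of parameters, so that an inner-loop update produces adapted, task-specific parameters without overwriting the shared ones. The plain-Python sketch below illustrates that idea on a hypothetical one-dimensional linear model. It does not use torchmetal: the names `forward`, `loss_and_grads` and `inner_update` are illustrative only, and we merely assume MetaModule supports a similar pattern (a forward pass that accepts a parameter dictionary); its exact signature is not shown here.

```python
from typing import Dict, List, Tuple

Params = Dict[str, float]


def forward(x: float, params: Params) -> float:
    """Hypothetical linear model y = w * x + b, evaluated with explicit params."""
    return params["w"] * x + params["b"]


def loss_and_grads(xs: List[float], ys: List[float], params: Params) -> Tuple[float, Params]:
    """Mean squared error and its analytic gradients with respect to w and b."""
    n = len(xs)
    loss, grad_w, grad_b = 0.0, 0.0, 0.0
    for x, y in zip(xs, ys):
        err = forward(x, params) - y
        loss += err * err / n
        grad_w += 2.0 * err * x / n
        grad_b += 2.0 * err / n
    return loss, {"w": grad_w, "b": grad_b}


def inner_update(xs: List[float], ys: List[float], params: Params,
                 step_size: float = 0.1) -> Params:
    """One inner-loop gradient step that returns NEW parameters.

    The input dictionary is left untouched, so the shared (meta) parameters
    stay separate from the task-specific adapted ones.
    """
    _, grads = loss_and_grads(xs, ys, params)
    return {name: value - step_size * grads[name] for name, value in params.items()}


meta_params = {"w": 0.0, "b": 0.0}
# A toy regression task: y = 2x + 1
task_x = [0.0, 1.0, 2.0, 3.0]
task_y = [1.0, 3.0, 5.0, 7.0]

adapted = inner_update(task_x, task_y, meta_params)
loss_before = loss_and_grads(task_x, task_y, meta_params)[0]
loss_after = loss_and_grads(task_x, task_y, adapted)[0]
print(meta_params, adapted)
print(loss_before, loss_after)
```

In a full meta-learning loop, the outer update would then differentiate the loss of `adapted` on held-out task data with respect to `meta_params`; that step needs automatic differentiation and is omitted here.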

Datasets available

Installation

You can install torchmetal either with pip, Python's package manager, or from source. To avoid conflicts with your existing Python setup, we suggest working in a virtual environment created with virtualenv. To install virtualenv, then create and activate an environment:

pip install --upgrade virtualenv
virtualenv venv
source venv/bin/activate
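To confirm that the environment is active before installing, you can compare `sys.prefix` with `sys.base_prefix` from Python's standard library: they differ inside environments created by `venv` or by recent virtualenv releases, while older virtualenv releases set `sys.real_prefix` instead. A minimal sketch; the function name `in_virtualenv` is our own:

```python
import sys


def in_virtualenv() -> bool:
    """Return True if the running interpreter belongs to a virtual environment."""
    # Legacy virtualenv (< 20) sets sys.real_prefix; venv and modern virtualenv
    # make sys.prefix differ from sys.base_prefix.
    return hasattr(sys, "real_prefix") or sys.prefix != sys.base_prefix


print("Virtual environment active:", in_virtualenv())
```

Run it with the same `python` you will use for `pip install`; if it reports `False`, the environment was probably not activated in the current shell.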

Using pip

This is the recommended way to install torchmetal:

pip install torchmetal
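After installing, you can check which version pip installed without importing torchmetal (and therefore PyTorch) by querying the package metadata with the standard library's `importlib.metadata` (Python 3.8 or later). A small sketch; the helper `installed_version` is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


print("torchmetal:", installed_version("torchmetal"))
```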

From source

You can also install torchmetal from source. This is recommended if you want to contribute to torchmetal.

git clone https://github.com/tristandeleu/pytorch-meta.git
cd pytorch-meta
python setup.py install

Example

Minimal example

The minimal example below shows how to create a dataloader for the 5-shot 5-way Omniglot dataset with torchmetal. The dataloader yields batches of randomly generated tasks, and the samples of all tasks in a batch are stacked into a single tensor. For more examples, check the examples folder.

from torchmetal.datasets.helpers import omniglot
from torchmetal.utils.data import BatchMetaDataLoader

dataset = omniglot("data", ways=5, shots=5, test_shots=15, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

for batch in dataloader:
    train_inputs, train_targets = batch["train"]
    print('Train inputs shape: {0}'.format(train_inputs.shape))    # (16, 25, 1, 28, 28)
    print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)

    test_inputs, test_targets = batch["test"]
    print('Test inputs shape: {0}'.format(test_inputs.shape))      # (16, 75, 1, 28, 28)
    print('Test targets shape: {0}'.format(test_targets.shape))    # (16, 75)
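The shapes printed above follow directly from the arguments: each task holds `ways * shots` training samples and `ways * test_shots` test samples, and the dataloader stacks `batch_size` tasks along a new leading dimension. The hypothetical helper below (not part of torchmetal) reproduces that arithmetic; the default `(1, 28, 28)` image shape assumes single-channel Omniglot images resized to 28x28, as in the example.

```python
from typing import Dict, Tuple


def expected_shapes(batch_size: int, ways: int, shots: int, test_shots: int,
                    image_shape: Tuple[int, ...] = (1, 28, 28)) -> Dict[str, Tuple[int, ...]]:
    """Compute batched input/target shapes for an N-way, K-shot episode loader."""
    n_train = ways * shots        # training (support) samples per task
    n_test = ways * test_shots    # test (query) samples per task
    return {
        "train_inputs": (batch_size, n_train, *image_shape),
        "train_targets": (batch_size, n_train),
        "test_inputs": (batch_size, n_test, *image_shape),
        "test_targets": (batch_size, n_test),
    }


for name, shape in expected_shapes(batch_size=16, ways=5, shots=5, test_shots=15).items():
    print(name, shape)
```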

Advanced example

Helper functions are only available for some of the datasets. However, all datasets are available through the unified interface provided by torchmetal. The variable dataset defined above is equivalent to the following:

from torchmetal.datasets import Omniglot
from torchmetal.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmetal.utils.data import BatchMetaDataLoader

dataset = Omniglot("data",
                   # Number of ways
                   num_classes_per_task=5,
                   # Resize the images to 28x28 and convert them to PyTorch tensors (from Torchvision)
                   transform=Compose([Resize(28), ToTensor()]),
                   # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
                   target_transform=Categorical(num_classes=5),
                   # Create new virtual classes with rotated versions of the images (from Santoro et al., 2016)
                   class_augmentations=[Rotation([90, 180, 270])],
                   meta_train=True,
                   download=True)
dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

Note that the dataloader itself is unchanged; only the way the dataset is constructed differs.
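Conceptually, each task drawn from this dataset combines three steps: pick `num_classes_per_task` classes from a pool enlarged by the rotated "virtual" classes, split each chosen class's samples into `num_train_per_class` training and `num_test_per_class` test examples, and relabel the chosen classes as `0, 1, ...`. The plain-Python sketch below mimics those steps on toy string data. It illustrates the idea only and is not torchmetal's implementation; the function `sample_task`, the data layout, and the assumption that unrotated classes stay in the pool alongside the rotated ones are all ours.

```python
import random
from typing import Dict, List, Optional, Sequence, Tuple

# A labelled sample: ((item, rotation angle in degrees), integer label).
Labelled = Tuple[Tuple[str, int], int]


def sample_task(data: Dict[str, List[str]], ways: int, shots: int, test_shots: int,
                angles: Sequence[int] = (0, 90, 180, 270),
                rng: Optional[random.Random] = None) -> Tuple[List[Labelled], List[Labelled]]:
    """Build one N-way task as (train, test) lists of ((item, angle), label) pairs."""
    if rng is None:
        rng = random.Random()
    # Class augmentation: each (class, angle) pair counts as a distinct virtual class.
    pool = [(name, angle) for name in sorted(data) for angle in angles]
    chosen = rng.sample(pool, ways)
    train: List[Labelled] = []
    test: List[Labelled] = []
    # Relabel the chosen classes as 0..ways-1, like a categorical target transform.
    for label, (name, angle) in enumerate(chosen):
        items = rng.sample(data[name], shots + test_shots)
        # Class splitting: the first `shots` items train, the remaining ones test.
        train += [((item, angle), label) for item in items[:shots]]
        test += [((item, angle), label) for item in items[shots:]]
    # Shuffle so that samples are not grouped by label.
    rng.shuffle(train)
    rng.shuffle(test)
    return train, test


toy = {f"alphabet/char{i:02d}": [f"img{i}_{j}" for j in range(20)] for i in range(3)}
train, test = sample_task(toy, ways=5, shots=1, test_shots=2, rng=random.Random(0))
print(len(train), len(test))  # 5 10
```

Sampling without replacement within each class guarantees that a class's training and test examples never overlap, which is the property the train/test split exists to provide.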



Download files


Source Distribution

torchmetal-0.1.0.tar.gz (131.8 kB)

Uploaded Source

Built Distribution

torchmetal-0.1.0-py3-none-any.whl (165.5 kB)

Uploaded Python 3

File details

Details for the file torchmetal-0.1.0.tar.gz.

File metadata

  • Download URL: torchmetal-0.1.0.tar.gz
  • Upload date:
  • Size: 131.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.6 CPython/3.7.6 Linux/5.4.0-7634-generic

File hashes

Hashes for torchmetal-0.1.0.tar.gz
Algorithm Hash digest
SHA256 98832a02b9ebaf7796acf9028a40af8925855540102d58a1b6efe6e88bfc9f86
MD5 994b212e227fef94fd4da197b7c12f39
BLAKE2b-256 889a1aac2f21d55dfc4e33152ef301bcd16e58b0ab7bcc232ad11bbe9316874f
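To check a downloaded archive against the SHA256 digest listed above, you can hash the file in chunks with Python's standard `hashlib` module. A minimal sketch; the helper name `sha256_of` is our own, and the archive path is wherever you saved the download:

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as handle:
        # iter() with a sentinel keeps reading until read() returns b"" (end of file).
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest published for torchmetal-0.1.0.tar.gz
EXPECTED = "98832a02b9ebaf7796acf9028a40af8925855540102d58a1b6efe6e88bfc9f86"

# Uncomment once the archive has been downloaded to the current directory:
# print(sha256_of("torchmetal-0.1.0.tar.gz") == EXPECTED)
```

Reading in chunks keeps memory use constant regardless of the archive's size.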


File details

Details for the file torchmetal-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torchmetal-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 165.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.6 CPython/3.7.6 Linux/5.4.0-7634-generic

File hashes

Hashes for torchmetal-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6dd71a01ccd89bcbed855beac1efccc209c28760a7c167963a05f9de0f2b47b6
MD5 0234065b70aea1309e104451a7609fc4
BLAKE2b-256 6616bfc842001abd5f0eb5bf090cfb761002b7e4428f6dc22cf57221c2dc5c5a

