PyTorch Meta-Learning Framework for Researchers

learn2learn is a software library for meta-learning research.

learn2learn builds on top of PyTorch to accelerate two aspects of the meta-learning research cycle:

  • fast prototyping, essential in letting researchers quickly try new ideas, and
  • correct reproducibility, ensuring that these ideas are evaluated fairly.

learn2learn provides low-level utilities and a unified interface to create new algorithms and domains, together with high-quality implementations of existing algorithms and standardized benchmarks. It retains compatibility with torchvision, torchaudio, torchtext, cherry, and any other PyTorch-based library you might be using.

Overview

  • learn2learn.data: TaskDataset and transforms to create few-shot tasks from any PyTorch dataset.
  • learn2learn.vision: Models, datasets, and benchmarks for computer vision and few-shot learning.
  • learn2learn.gym: Environment and utilities for meta-reinforcement learning.
  • learn2learn.algorithms: High-level wrappers for existing meta-learning algorithms.
  • learn2learn.optim: Utilities and algorithms for differentiable optimization and meta-descent.

Installation

pip install learn2learn

Snippets & Examples

The following snippets provide a sneak peek at the functionalities of learn2learn.

High-level Wrappers

Few-Shot Learning with MAML

For more algorithms (ProtoNets, ANIL, Meta-SGD, Reptile, Meta-Curvature, KFO) refer to the examples folder. Most of them can be implemented with the GBML wrapper (documentation).

import torch
import learn2learn as l2l

# `model` is any nn.Module; `compute_loss` is a user-defined function returning a scalar loss.
maml = l2l.algorithms.MAML(model, lr=0.1)
opt = torch.optim.SGD(maml.parameters(), lr=0.001)
for iteration in range(10):
    opt.zero_grad()
    task_model = maml.clone()  # torch.clone() for nn.Modules
    adaptation_loss = compute_loss(task_model)
    task_model.adapt(adaptation_loss)  # computes gradients and updates task_model in place
    evaluation_loss = compute_loss(task_model)
    evaluation_loss.backward()  # gradients w.r.t. maml.parameters()
    opt.step()
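
To make the two nested loops above concrete, here is a minimal dependency-free sketch of the same idea on a toy problem. It is not learn2learn's implementation: each "task" is assumed to be a scalar quadratic loss 0.5 * (theta - c) ** 2, the gradients are written out by hand, and every function name is hypothetical. The point it illustrates is that the meta-gradient flows back through the adaptation step.

```python
# Toy MAML on scalar tasks with loss_c(theta) = 0.5 * (theta - c) ** 2.
# All names are hypothetical and independent of learn2learn.

def grad(theta, center):
    """Gradient of 0.5 * (theta - center) ** 2 with respect to theta."""
    return theta - center


def maml_meta_gradient(theta, centers, inner_lr):
    """Average gradient of the post-adaptation loss with respect to theta."""
    total = 0.0
    for center in centers:
        adapted = theta - inner_lr * grad(theta, center)  # inner-loop (adaptation) step
        # Chain rule through the inner step: d(adapted)/d(theta) = 1 - inner_lr.
        total += grad(adapted, center) * (1.0 - inner_lr)
    return total / len(centers)


def meta_train(centers, inner_lr=0.1, outer_lr=0.5, steps=100, theta=0.0):
    """Outer loop: plain gradient descent on the meta-gradient."""
    for _ in range(steps):
        theta -= outer_lr * maml_meta_gradient(theta, centers, inner_lr)
    return theta


theta = meta_train(centers=[-1.0, 1.0, 3.0])
print(round(theta, 4))  # approaches 1.0, the mean of the task centers
```

In this toy setting the learned initialization settles at the mean of the task centers, which is the point from which one adaptation step does best on average.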

Meta-Descent with Hypergradient

Learn any kind of optimization algorithm with the LearnableOptimizer. (example and documentation)

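import torch
from torch import nn
import learn2learn as l2l

# X, y, and loss are user-provided (inputs, targets, and a loss function).
linear = nn.Linear(784, 10)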
transform = l2l.optim.ModuleTransform(l2l.nn.Scale)
metaopt = l2l.optim.LearnableOptimizer(linear, transform, lr=0.01)  # metaopt has .step()
opt = torch.optim.SGD(metaopt.parameters(), lr=0.001)  # metaopt also has .parameters()

metaopt.zero_grad()
opt.zero_grad()
error = loss(linear(X), y)
error.backward()
opt.step()  # update metaopt
metaopt.step()  # update linear
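
For intuition about what a hypergradient is, the sketch below runs the classic scalar form of hypergradient descent on f(w) = 0.5 * w ** 2, with the learning rate itself updated by gradient descent. This is an illustrative assumption, not how LearnableOptimizer is implemented; the function names and hyperparameters are hypothetical.

```python
def hypergradient_descent(w, lr, hyper_lr, steps):
    """Minimize f(w) = 0.5 * w ** 2 while adapting the learning rate online."""
    prev_grad = 0.0
    for _ in range(steps):
        g = w  # gradient of 0.5 * w ** 2 at the current w
        # d f(w_t) / d lr = -g_t * g_{t-1}, so descending on lr adds the product.
        lr += hyper_lr * g * prev_grad
        w -= lr * g
        prev_grad = g
    return w, lr


def plain_sgd(w, lr, steps):
    """Baseline: the same problem with a fixed learning rate."""
    for _ in range(steps):
        w -= lr * w
    return w


w_meta, learned_lr = hypergradient_descent(w=5.0, lr=0.01, hyper_lr=0.001, steps=100)
w_sgd = plain_sgd(w=5.0, lr=0.01, steps=100)
print(learned_lr > 0.01, abs(w_meta) < abs(w_sgd))  # prints: True True
```

Starting from a deliberately small learning rate, the hypergradient update grows it while consecutive gradients agree in sign, so the adapted run ends closer to the minimum than the fixed-rate baseline.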

Learning Domains

Custom Few-Shot Dataset

Many standardized datasets (Omniglot, mini-/tiered-ImageNet, FC100, CIFAR-FS) are readily available in learn2learn.vision.datasets. (documentation)

dataset = l2l.data.MetaDataset(MyDataset())  # any PyTorch dataset
transforms = [  # Easy to define your own transform
    l2l.data.transforms.NWays(dataset, n=5),
    l2l.data.transforms.KShots(dataset, k=1),
    l2l.data.transforms.LoadData(dataset),
]
taskset = l2l.data.TaskDataset(dataset, transforms, num_tasks=20000)
for task in taskset:
    X, y = task
    # Meta-train on the task
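
The transform pipeline above can be read as "pick N classes, pick K examples per class, then load the data". A rough pure-Python sketch of that sampling logic follows. It only mirrors the idea behind NWays and KShots, it is not the library's code, and `sample_task` along with the toy label list are hypothetical.

```python
import random
from collections import defaultdict


def sample_task(labels, n_ways, k_shots, rng):
    """Return dataset indices for one N-way K-shot task."""
    indices_by_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_label[label].append(index)
    # Only classes with at least k_shots examples can be sampled.
    eligible = [c for c, idx in indices_by_label.items() if len(idx) >= k_shots]
    classes = rng.sample(eligible, n_ways)  # analogous to the NWays step
    task = []
    for c in classes:
        task.extend(rng.sample(indices_by_label[c], k_shots))  # analogous to KShots
    return task


labels = [i % 10 for i in range(200)]  # toy dataset: 10 classes, 20 examples each
rng = random.Random(0)
task_indices = sample_task(labels, n_ways=5, k_shots=1, rng=rng)
task_labels = [labels[i] for i in task_indices]  # a LoadData step would fetch the samples
print(len(task_indices), len(set(task_labels)))  # prints: 5 5
```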

Environments and Utilities for Meta-RL

Parallelize your own meta-environments with AsyncVectorEnv, or use the standardized ones. (documentation)

def make_env():
    env = l2l.gym.HalfCheetahForwardBackwardEnv()
    env = cherry.envs.ActionSpaceScaler(env)
    return env

env = l2l.gym.AsyncVectorEnv([make_env for _ in range(16)])  # uses 16 threads
for task_config in env.sample_tasks(20):
    env.set_task(task_config)  # all threads receive the same task
    state = env.reset()  # use standard Gym API
    action = my_policy(state)
    env.step(action)
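
The loop above relies on two methods beyond the standard Gym API: `sample_tasks` and `set_task`. As a hedged illustration of that protocol, the toy environment below implements both on a 1-D world where the task decides which direction is rewarded. `ToyDirectionEnv` is hypothetical, does not subclass any learn2learn or Gym class, and is only a sketch of the interface shape.

```python
import random


class ToyDirectionEnv:
    """1-D world where the current task decides which direction is rewarded."""

    def __init__(self, horizon=10):
        self.horizon = horizon
        self.direction = 1.0
        self.position = 0.0
        self.steps = 0

    def sample_tasks(self, num_tasks, rng=random):
        """Draw task descriptions; each one is a small config dict."""
        return [{'direction': rng.choice([-1.0, 1.0])} for _ in range(num_tasks)]

    def set_task(self, task):
        """Switch the environment to the given task."""
        self.direction = task['direction']

    def reset(self):
        self.position, self.steps = 0.0, 0
        return self.position

    def step(self, action):
        self.position += action
        self.steps += 1
        reward = self.direction * action  # moving along the task direction pays off
        done = self.steps >= self.horizon
        return self.position, reward, done, {}


env = ToyDirectionEnv()
for task in env.sample_tasks(3, rng=random.Random(0)):
    env.set_task(task)
    state, done, total = env.reset(), False, 0.0
    while not done:
        state, reward, done, _ = env.step(1.0)  # fixed "always move forward" policy
        total += reward
    print(task['direction'], total)
```

A fixed policy earns opposite returns on the two tasks, which is exactly the situation a meta-RL agent has to adapt to.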

Low-Level Utilities

Differentiable Optimization

Learn and differentiate through updates of PyTorch Modules. (documentation)

model = MyModel()
transform = l2l.optim.KroneckerTransform(l2l.nn.KroneckerLinear)
learned_update = l2l.optim.ParameterUpdate(  # learnable update function
        model.parameters(), transform)
clone = l2l.clone_module(model)  # torch.clone() for nn.Modules
error = loss(clone(X), y)
updates = learned_update(  # similar API as torch.autograd.grad
    error,
    clone.parameters(),
    create_graph=True,
)
l2l.update_module(clone, updates=updates)
loss(clone(X), y).backward()  # Gradients w.r.t model.parameters() and learned_update.parameters()
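
To see what "differentiating through an update" means in the simplest case, the sketch below learns a single scalar `scale` for the update w' = w - scale * grad on a quadratic loss. The derivative of the post-update loss with respect to `scale` is written by hand and checked against a finite difference. This is a toy stand-in under those assumptions, not the ParameterUpdate implementation, and all names are hypothetical.

```python
def loss(w, target):
    return 0.5 * (w - target) ** 2


def updated_loss(w, target, scale):
    """Loss after one learned update w' = w - scale * grad."""
    grad = w - target
    return loss(w - scale * grad, target)


def scale_gradient(w, target, scale):
    """Analytic d(updated_loss)/d(scale), obtained by differentiating through the update."""
    grad = w - target
    return -(1.0 - scale) * grad ** 2


w, target, scale = 3.0, 1.0, 0.1

# Check the analytic gradient against a central finite difference.
eps = 1e-6
numeric = (updated_loss(w, target, scale + eps)
           - updated_loss(w, target, scale - eps)) / (2 * eps)
assert abs(numeric - scale_gradient(w, target, scale)) < 1e-6

for _ in range(200):  # meta-learn the scale by gradient descent
    scale -= 0.1 * scale_gradient(w, target, scale)
print(round(scale, 4))  # approaches 1.0, the step size that solves this quadratic in one update
```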

Changelog

A human-readable changelog is available in the CHANGELOG.md file.

Citation

To cite the learn2learn repository in your academic publications, please use the following reference.

Sebastien M.R. Arnold, Praateek Mahajan, Debajyoti Datta, Ian Bunner. "learn2learn". https://github.com/learnables/learn2learn, 2019.

You can also use the following Bibtex entry.

@misc{learn2learn2019,
    author       = {Sebastien M.R. Arnold and Praateek Mahajan and Debajyoti Datta and Ian Bunner},
    title        = {learn2learn},
    month        = sep,
    year         = 2019,
    url          = {https://github.com/learnables/learn2learn}
}

Acknowledgements & Friends

  1. The RL environments are adapted from Tristan Deleu's implementations and from the ProMP repository. Both are shared with permission, under the MIT License.
  2. TorchMeta is a similar library, with a focus on datasets for supervised meta-learning.
  3. higher is a PyTorch library that enables differentiating through optimization inner-loops. While they monkey-patch nn.Module to be stateless, learn2learn retains the stateful PyTorch look-and-feel. For more information, refer to their arXiv paper.
