
PyTorch Meta-Learning Framework for Researchers


learn2learn is a PyTorch library for meta-learning implementations.

The goal of meta-learning is to enable agents to learn how to learn. That is, we would like our agents to become better learners as they solve more and more tasks. For example, the animation accompanying this project shows an agent that learns to run after only one parameter update.

Features

learn2learn provides high- and low-level utilities for meta-learning. The high-level utilities let any user take advantage of existing meta-learning algorithms. The low-level utilities enable researchers to develop new and better meta-learning algorithms.

Some features of learn2learn include:

  • Modular API: implement your own training loops with our low-level utilities.
  • Provides various meta-learning algorithms (e.g. MAML, FOMAML, MetaSGD, ProtoNets, DiCE)
  • Task generator with unified API, compatible with torchvision, torchtext, torchaudio, and cherry.
  • Provides standardized meta-learning tasks for vision (Omniglot, mini-ImageNet), reinforcement learning (Particles, Mujoco), and even text (news classification).
  • 100% compatible with PyTorch -- use your own modules, datasets, or libraries!
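
The job of a task generator can be illustrated with a small sketch of N-way K-shot sampling: pick `ways` distinct classes, then draw `shots` support examples and some query examples from each. This is not learn2learn's `TaskGenerator` implementation. `sample_task` and its `queries` argument are hypothetical names chosen for illustration, and the dataset is represented only by its list of labels.

```python
import random
from collections import defaultdict


def sample_task(labels, ways, shots, queries, classes=None, rng=random):
    """Hypothetical N-way K-shot sampler (illustration only, not learn2learn's API).

    labels: list where labels[i] is the class of example i.
    Returns (support, query): lists of (example_index, task_label) pairs,
    where task_label is in range(ways).
    """
    # Group example indices by class, i.e. a label -> indices map.
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    # Restrict to the allowed classes (if given), then pick `ways` distinct ones.
    pool = sorted(classes if classes is not None else by_class)
    chosen = rng.sample(pool, ways)

    support, query = [], []
    for task_label, cls in enumerate(chosen):
        # Draw shots + queries distinct examples so support and query never overlap.
        picked = rng.sample(by_class[cls], shots + queries)
        support += [(i, task_label) for i in picked[:shots]]
        query += [(i, task_label) for i in picked[shots:]]
    return support, query
```

For instance, `sample_task(labels, ways=3, shots=1, queries=2, classes=[0, 1, 4, 6, 8, 9])` mirrors the `ways` and `classes` arguments used in the API demo. Relabelling the chosen classes as 0..ways-1 lets the same classifier head serve every task.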

Installation

pip install learn2learn

API Demo

The following is an example of using the high-level MAML implementation on MNIST. For more algorithms and lower-level utilities, please refer to the documentation or the examples.

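import torchvision
from torch import optim
import learn2learn as l2l

# Net, compute_loss, num_iterations and adaptation_steps are user-defined:
# Net is any torch.nn.Module, and compute_loss(learner, task) returns the
# loss of `learner` on the samples of `task`.

mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True, download=True,
                                   transform=torchvision.transforms.ToTensor())

mnist = l2l.data.MetaDataset(mnist)
task_generator = l2l.data.TaskGenerator(mnist,
                                        ways=3,
                                        classes=[0, 1, 4, 6, 8, 9],
                                        tasks=10)
model = Net()
maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
opt = optim.Adam(maml.parameters(), lr=4e-3)

for iteration in range(num_iterations):
    learner = maml.clone()  # Differentiable copy of the model for this task
    adaptation_task = task_generator.sample(shots=1)

    # Fast adapt: a few gradient steps of the learner on the adaptation samples
    for step in range(adaptation_steps):
        error = compute_loss(learner, adaptation_task)
        learner.adapt(error)

    # Evaluate the adapted learner on new samples from the same task
    evaluation_task = task_generator.sample(shots=1,
                                            task=adaptation_task.sampled_task)
    evaluation_error = compute_loss(learner, evaluation_task)

    # Meta-update the model parameters
    opt.zero_grad()
    evaluation_error.backward()
    opt.step()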
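
The `first_order` flag passed to `l2l.algorithms.MAML` decides whether the meta-gradient is differentiated through the adaptation step. The toy sketch below is plain Python rather than learn2learn code. It shows the difference on hypothetical 1-D tasks whose loss is (theta - c)**2 for a task-specific center c. `inner_step`, `meta_gradient` and `meta_train` are illustrative names, and treating `alpha` as the counterpart of the inner-loop `lr` in the demo is an assumption made for the analogy.

```python
def inner_step(theta, c, alpha):
    # One adaptation step on L(theta) = (theta - c)**2, whose gradient is 2 * (theta - c).
    return theta - alpha * 2 * (theta - c)


def meta_gradient(theta, c, alpha, first_order=False):
    # Gradient of the post-adaptation loss L(inner_step(theta)) with respect to theta.
    adapted = inner_step(theta, c, alpha)
    outer_grad = 2 * (adapted - c)  # dL/d(adapted)
    if first_order:
        # First-order MAML: treat d(adapted)/d(theta) as 1 (no second-order term).
        return outer_grad
    # Exact MAML: chain rule through the inner step, d(adapted)/d(theta) = 1 - 2 * alpha.
    return outer_grad * (1 - 2 * alpha)


def meta_train(task_centers, alpha=0.1, meta_lr=0.05, steps=200, first_order=False):
    # Outer loop: average the meta-gradients over tasks, then take one step
    # (the role played by opt.step() in the demo).
    theta = 0.0
    for _ in range(steps):
        g = sum(meta_gradient(theta, c, alpha, first_order) for c in task_centers)
        theta -= meta_lr * g / len(task_centers)
    return theta
```

With tasks centered at -1.0 and 3.0, both variants converge to theta = 1.0, the mean of the centers. They follow different gradients, though: exact MAML multiplies each first-order gradient by the inner-step Jacobian (1 - 2 * alpha). Because that factor is the same for every task in this toy, the fixed points coincide. For real networks, where the Jacobian differs across tasks, the two methods generally do not agree.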

Acknowledgements

  1. The RL environments are adapted from Tristan Deleu's implementations and from the ProMP repository. Both are shared with permission, under the MIT License.

Download files


Source Distribution

learn2learn-0.0.4.tar.gz (24.0 kB)


File details

Details for the file learn2learn-0.0.4.tar.gz.

File metadata

  • Download URL: learn2learn-0.0.4.tar.gz
  • Size: 24.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.14.0 pkginfo/1.4.2 requests/2.19.1 setuptools/40.2.0 requests-toolbelt/0.9.1 tqdm/4.26.0 CPython/3.7.1

File hashes

Hashes for learn2learn-0.0.4.tar.gz
Algorithm    Hash digest
SHA256       2848943968deeb39fd747991b5c3a7e84edb0486f908ee6aa877be4d9b472f4c
MD5          c594b859e11fd87ef631c8ff785d0f41
BLAKE2b-256  84197b49dfb8958443ae97f2a714d5e7fbe3bcb4509d9234af04e4c10eb9aa20

