PyTorch Meta-Learning Framework for Researchers
learn2learn is a PyTorch library for meta-learning implementations. It was developed during the first PyTorch Hackathon. Edit: L2L was lucky to win the hackathon!
Note: learn2learn is under active development, and many things may still break.
Installation
```shell
pip install learn2learn
```
API Demo
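```python
import torch.optim as optim
import torchvision
import learn2learn as l2l

# Net and compute_loss(task, model) are placeholders you define yourself:
# a classifier, and a function returning that model's loss on a task.
num_iterations = 1000  # example values
adaptation_steps = 1

mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True, download=True,
                                   transform=torchvision.transforms.ToTensor())
task_generator = l2l.data.TaskGenerator(mnist, ways=3)
model = Net()
maml = l2l.MAML(model, lr=1e-3, first_order=False)
opt = optim.Adam(maml.parameters(), lr=4e-3)

for iteration in range(num_iterations):
    learner = maml.new()  # Creates a clone of model
    task = task_generator.sample(shots=1)

    # Fast adapt: a few gradient steps of the clone on the sampled task
    for step in range(adaptation_steps):
        error = compute_loss(task, learner)
        learner.adapt(error)

    # Compute the adapted learner's validation loss on new samples
    # drawn from the same classes
    valid_task = task_generator.sample(shots=1, classes_to_sample=task.sampled_classes)
    valid_error = compute_loss(valid_task, learner)

    # Take the meta-learning step
    opt.zero_grad()
    valid_error.backward()
    opt.step()
```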
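The demo relies on `TaskGenerator` to build what appear to be N-way, K-shot tasks: `ways` classes, `shots` examples per class, and `classes_to_sample` to draw the validation task from the same classes as the training task. The sketch below mimics that idea in plain Python over a list of labels. `sample_task` and its `exclude` parameter are hypothetical illustrations, not part of learn2learn, whose actual sampling logic may differ.

```python
import random
from collections import defaultdict


def sample_task(labels, ways, shots, classes=None, exclude=(), rng=random):
    """Pick indices for an N-way, K-shot task (hypothetical helper).

    labels:  one class label per dataset example
    ways:    number of classes to draw when `classes` is None
    shots:   examples drawn per class
    classes: reuse these classes instead of drawing new ones
    exclude: example indices that must not be drawn again
    Returns (indices, classes).
    """
    skip = set(exclude)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        if idx not in skip:
            by_class[label].append(idx)
    if classes is None:
        # Draw `ways` distinct classes for a fresh task.
        classes = rng.sample(sorted(by_class), ways)
    indices = []
    for c in classes:
        # Draw `shots` distinct examples from each chosen class.
        indices.extend(rng.sample(by_class[c], shots))
    return indices, list(classes)


# Example: 10 classes with 10 examples each, 3-way 1-shot tasks.
labels = [i % 10 for i in range(100)]
rng = random.Random(0)
train_idx, classes = sample_task(labels, ways=3, shots=1, rng=rng)
valid_idx, _ = sample_task(labels, ways=3, shots=1, classes=classes,
                           exclude=train_idx, rng=rng)
```

Passing the training indices as `exclude` keeps the validation draw disjoint from the adaptation examples. The demo above samples the two tasks independently, so whether examples can repeat there depends on TaskGenerator's implementation.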
Changelog
The following changelog is mostly for the hackathon period.
August 12, 2019
- Basic implementation of MAML, FOMAML, Meta-SGD.
- TaskGenerator code for classification tasks.
- Environments for RL.
- Small scale examples of MAML-A2C and MAML-PPO.
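To make the difference between MAML (`first_order=False` in the demo), FOMAML (its first-order approximation), and Meta-SGD concrete, here is a hand-derived scalar sketch in plain Python. It assumes a 1-D linear model `y_hat = w * x` with squared loss and a single inner step `w' = w - alpha * L_train'(w)`. The exact MAML meta-gradient is `L_valid'(w') * (1 - alpha * L_train''(w))`; FOMAML drops the second-order factor; Meta-SGD additionally learns the inner step size `alpha`. All function names are hypothetical, and this is not learn2learn's implementation, which works on PyTorch modules through autograd.

```python
def loss(w, data):
    """Mean squared error of the 1-D linear model y_hat = w * x."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def grad(w, data):
    """dL/dw for the loss above."""
    return sum(2 * x * (w * x - y) for x, y in data) / len(data)


def hess(w, data):
    """d2L/dw2; constant in w because the loss is quadratic."""
    return sum(2 * x * x for x, _ in data) / len(data)


def adapt(w, alpha, train):
    """One inner-loop gradient step on the training data."""
    return w - alpha * grad(w, train)


def maml_grad(w, alpha, train, valid):
    """Exact meta-gradient dL_valid(w')/dw, including the second-order term."""
    w_adapted = adapt(w, alpha, train)
    return grad(w_adapted, valid) * (1 - alpha * hess(w, train))


def fomaml_grad(w, alpha, train, valid):
    """First-order approximation: treat dw'/dw as 1, dropping the Hessian term."""
    return grad(adapt(w, alpha, train), valid)


def meta_sgd_alpha_grad(w, alpha, train, valid):
    """Meta-SGD also learns alpha; this is dL_valid(w')/dalpha."""
    w_adapted = adapt(w, alpha, train)
    return grad(w_adapted, valid) * (-grad(w, train))


train = [(1.0, 2.0), (2.0, 4.1)]
valid = [(3.0, 6.2)]
w, alpha = 0.5, 0.05
print(maml_grad(w, alpha, train, valid), fomaml_grad(w, alpha, train, valid))
```

For this example data the correction factor is `1 - 0.05 * 5 = 0.75`, so the FOMAML gradient points the same way as the MAML gradient and only its scale differs. With many parameters, the Hessian term is a matrix and can change the direction as well.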
Acknowledgements
- The RL environments are copied from: https://github.com/tristandeleu/pytorch-maml-rl
Download files
Source Distribution
learn2learn-0.0.2.tar.gz (20.1 kB)
File details
Details for the file learn2learn-0.0.2.tar.gz.
File metadata
- Download URL: learn2learn-0.0.2.tar.gz
- Size: 20.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.4.2 requests/2.19.1 setuptools/40.2.0 requests-toolbelt/0.9.1 tqdm/4.26.0 CPython/3.7.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6baa9795580c47731b602b0387c1448be24443353eb403dcbacc5b2af7fd99f1 |
| MD5 | 0b1ab2d4f2089bb727ef9a2effe2cb8e |
| BLAKE2b-256 | 9242ad0c64fd5dda78bdcf476ac97fc49dca84248156e2190c03c807b30f58a0 |