Dataloaders for meta-learning in PyTorch
Project description
torchmeta
A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch. The package contains popular meta-learning benchmarks, fully compatible with both torchvision and PyTorch's DataLoader.
Example
The minimal example below shows how to create a dataloader for the 5-shot 5-way Omniglot dataset with torchmeta. Each iteration of the dataloader yields a batch of randomly generated tasks; every task has 5 classes, with 5 training examples and 15 test examples per class. For more examples, check the examples folder.
from torchmeta.datasets import Omniglot
from torchmeta.transforms import Categorical, ClassSplitter
from torchvision.transforms import Resize, ToTensor, Compose
from torchmeta.utils.data import BatchMetaDataLoader

# Each task draws 5 classes from the meta-training split of Omniglot.
dataset = Omniglot('data', num_classes_per_task=5,
                   transform=Compose([Resize(28), ToTensor()]),
                   target_transform=Categorical(num_classes=5),
                   meta_train=True, download=True)
# Split every task into 5 training and 15 test examples per class.
dataset = ClassSplitter(dataset, num_train_per_class=5, num_test_per_class=15)
# Group 16 tasks into each batch.
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

for batch in dataloader:
    train_inputs, train_targets = batch['train']
    print('Train inputs shape: {0}'.format(train_inputs.shape))
    print('Train targets shape: {0}'.format(train_targets.shape))
    # Train inputs shape: torch.Size([16, 25, 1, 28, 28])
    # Train targets shape: torch.Size([16, 25])

    test_inputs, test_targets = batch['test']
    print('Test inputs shape: {0}'.format(test_inputs.shape))
    print('Test targets shape: {0}'.format(test_targets.shape))
    # Test inputs shape: torch.Size([16, 75, 1, 28, 28])
    # Test targets shape: torch.Size([16, 75])
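To make the shapes above concrete, here is a minimal pure-Python sketch of how one batch of N-way tasks could be sampled and split into train and test examples. It is not torchmeta's implementation: `sample_task` and the toy `labels` list (20 classes with 20 examples each) are hypothetical names and data introduced only for illustration.

```python
import random
from collections import defaultdict


def sample_task(labels, num_ways, num_train_per_class, num_test_per_class, rng):
    """Sample one task as (train, test) lists of (example_index, task_label).

    Hypothetical helper for illustration; not part of torchmeta.
    """
    # Group example indices by their original class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    # Pick `num_ways` distinct classes for this task.
    classes = rng.sample(sorted(by_class), num_ways)

    train, test = [], []
    per_class = num_train_per_class + num_test_per_class
    for task_label, cls in enumerate(classes):
        # Draw disjoint train and test examples for this class, and
        # relabel the class to 0..num_ways-1 within the task.
        chosen = rng.sample(by_class[cls], per_class)
        train += [(i, task_label) for i in chosen[:num_train_per_class]]
        test += [(i, task_label) for i in chosen[num_train_per_class:]]
    return train, test


# Toy dataset (an assumption): 20 classes, 20 examples per class.
labels = [cls for cls in range(20) for _ in range(20)]

rng = random.Random(0)
# A batch of 16 tasks, each 5-way with 5 train and 15 test examples per class.
batch = [sample_task(labels, 5, 5, 15, rng) for _ in range(16)]

train, test = batch[0]
print(len(batch), len(train), len(test))  # 16 25 75
```

The counts mirror the tensor shapes printed in the example: 16 tasks per batch, 5 x 5 = 25 training examples and 5 x 15 = 75 test examples per task.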
Download files
Source Distribution
torchmeta-1.1.0rc1.tar.gz (104.3 kB)
Built Distribution
Hashes for torchmeta-1.1.0rc1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 711b8cf23d00c262ab7abb7f69669c903203d1031a738c3510e01344d576e6ad
MD5 | 162e3d6772093cc43743e226155ebdc0
BLAKE2b-256 | d3f46394a464a48d8f657533f34dfa2f8217edd6a92f55fc274981dfdb3f2dde