
MetaBatch: A micro-framework for efficient batching of tasks in PyTorch.


Introduction

MetaBatch is a micro-framework for meta-learning in PyTorch. It provides convenient TaskSet and TaskLoader classes for batch-aware, online task creation in meta-learning pipelines.

Efficient batching

Training meta-learning models efficiently can be challenging, especially when random tasks of a consistent shape must be created for each batch. Task creation can be time-consuming and typically requires all tasks in a batch to have the same number of context and target points. This often becomes a bottleneck during training:

# Sample code for creating a batch of tasks with the traditional approach
import random

from torch.nn import Module
from torch.utils.data import DataLoader, Dataset

class MyTaskDataset(Dataset):
    ...
    def __getitem__(self, idx):
        task = self.task_data[idx]
        return task

class Model(Module):
    ...
    def forward(self, tasks):
        ctx_batch = tasks['context']
        tgt_batch = tasks['target']
        ...

# create dataset
task_data = [{'images': [...], 'label': 'dog'},
             {'images': [...], 'label': 'cat'}, ...]
dataset = MyTaskDataset(task_data)
dataloader = DataLoader(dataset, batch_size=16, num_workers=8,
                        collate_fn=lambda batch: batch)  # keep each task as an individual dict

for batch in dataloader:
    ...
    # Construct batch of random tasks in the training loop (bottleneck!)
    n_context = random.randint(1, 5)
    n_target = random.randint(1, 10)
    tasks = {'context': [], 'target': []}
    for task in batch:
        context_images = sample_n_images(task['images'], n_context)
        target_images = sample_n_images(task['images'], n_target)
        tasks['context'].append(context_images)
        tasks['target'].append(target_images)
    model(tasks)
    ...

Multiprocessing

Wouldn't it be better to offload task creation to the dataloader, so that it can run in parallel on multiple cores? MetaBatch lets you do just that. It provides a TaskSet wrapper in which you implement __gettask__(self, index, n_context, n_target) instead of PyTorch's __getitem__(self, index). The TaskLoader and its custom sampler take care of synchronizing n_context and n_target across all batch elements dispatched to the workers. With MetaBatch, the training bottleneck in the example above disappears:

# Sample code for creating a batch of tasks with MetaBatch
from metabatch import TaskSet, TaskLoader
from torch.nn import Module

class MyTaskSet(TaskSet):
    ...
    def __gettask__(self, idx, n_context, n_target):
        data = self.task_data[idx]
        context_images = sample_n_images(data['images'], n_context)
        target_images = sample_n_images(data['images'], n_target)
        return {
            'context': context_images,
            'target': target_images
        }

class Model(Module):
    ...
    def forward(self, tasks):
        ctx_batch = tasks['context']
        tgt_batch = tasks['target']
        ...

# create dataset
task_data = [{'images': [...], 'label': 'dog'},
             {'images': [...], 'label': 'cat'}, ...]
dataset = MyTaskSet(task_data, min_pts=1, max_ctx_pts=5, max_tgt_pts=10)
dataloader = TaskLoader(dataset, batch_size=16, workers=8)

for batch in dataloader:
    ...
    # Simply access the batch of constructed tasks (no bottleneck!)
    model(batch)
    ...
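
Because n_context and n_target are drawn once per batch and shared by every worker, all tasks in a batch have matching shapes and can be stacked directly. The snippet below is a hypothetical illustration of this, assuming sample_n_images returns a stacked image tensor and that TaskLoader applies PyTorch's default collation:

# Hypothetical illustration: every task in the batch was built with the same
# n_context and n_target, so the collated tensors have consistent shapes.
batch = next(iter(dataloader))
print(batch['context'].shape)  # e.g. (16, n_context, C, H, W)
print(batch['target'].shape)   # e.g. (16, n_target, C, H, W)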

Installation & usage

Install it: pip install metabatch

Requirements:

  • pytorch

Look at the example above for an idea of how to use TaskLoader with TaskSet, or go through the examples in examples/ (TODO).
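
For a more complete picture, here is a minimal, self-contained sketch of a 1D sine-regression task set (the kind of setup used for CNP-style models). It is illustrative only: the SineTaskSet class and its data are made up, and it assumes that TaskSet's constructor accepts the min_pts, max_ctx_pts and max_tgt_pts keyword arguments shown above and that __len__ is implemented as for a regular PyTorch Dataset.

# Hypothetical sketch: 1D sine-regression tasks, e.g. for training a CNP.
# Assumes TaskSet's constructor takes the keyword arguments shown in the
# example above and that __len__ is required, as for a standard Dataset.
import torch
from metabatch import TaskSet, TaskLoader

class SineTaskSet(TaskSet):
    def __init__(self, n_tasks, min_pts, max_ctx_pts, max_tgt_pts):
        super().__init__(min_pts=min_pts, max_ctx_pts=max_ctx_pts, max_tgt_pts=max_tgt_pts)
        # Each task is a sine curve with its own amplitude and phase.
        self.amplitudes = torch.rand(n_tasks) * 4.0 + 1.0
        self.phases = torch.rand(n_tasks) * torch.pi

    def __len__(self):
        return len(self.amplitudes)

    def __gettask__(self, idx, n_context, n_target):
        a, p = self.amplitudes[idx], self.phases[idx]
        x = torch.rand(n_context + n_target, 1) * 10.0 - 5.0
        y = a * torch.sin(x + p)
        pts = torch.cat([x, y], dim=-1)   # (n_context + n_target, 2)
        return {
            'context': pts[:n_context],   # (n_context, 2)
            'target': pts[n_context:],    # (n_target, 2)
        }

dataset = SineTaskSet(n_tasks=10_000, min_pts=1, max_ctx_pts=5, max_tgt_pts=10)
dataloader = TaskLoader(dataset, batch_size=16, workers=8)

for batch in dataloader:
    ctx, tgt = batch['context'], batch['target']  # (16, n_context, 2), (16, n_target, 2)
    ...

Since n_context and n_target are resampled for every batch, each epoch sees different context/target splits of the same underlying curves.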

Advantages

  • MetaBatch allows for efficient task creation and batching during training, resulting in more task variations since you are no longer limited to precomputed tasks.
  • Reduces boilerplate needed to precompute and load tasks.

MetaBatch is a micro-framework for meta-learning in PyTorch that provides convenient building blocks for (potentially faster) meta-training. It simplifies task creation and enables efficient batching, which makes it useful for researchers and engineers working on meta-learning projects.

How much faster?

TODO: benchmark MAML and CNP examples with typical implementation and other repos.

License

MetaBatch is released under the MIT License. See the LICENSE file for more information.

