MetaBatch: A micro-framework for efficient batching of tasks in PyTorch.
Introduction
MetaBatch is a micro-framework for meta-learning in PyTorch. It provides convenient TaskSet and TaskLoader classes for efficient batching and fast meta-training in a few-shot learning context.
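MetaBatch is distributed on PyPI and can be installed with pip:
pip install metabatch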
Efficient batching
Training meta-learning models efficiently can be challenging, especially when it comes to building batches of random tasks with a consistent shape. Task creation can be time-consuming, and all tasks in a batch typically need the same number of context and target points so that they can be collated into one tensor. When this happens inside the training loop, it becomes a bottleneck:
# Sample code for creating a batch of tasks with the traditional approach
import random

from torch.nn import Module
from torch.utils.data import Dataset, DataLoader

class MyTaskDataset(Dataset):
    ...
    def __getitem__(self, idx):
        task = self.task_data[idx]
        return task

class Model(Module):
    ...
    def forward(self, tasks):
        ctx_batch = tasks['context']
        tgt_batch = tasks['target']
        ...

# create dataset
task_data = [{'images': [...], 'label': 'dog'},
             {'images': [...], 'label': 'cat'}, ...]
dataset = MyTaskDataset(task_data)
dataloader = DataLoader(dataset, batch_size=16, num_workers=8)

for batch in dataloader:
    ...
    # Construct a batch of random tasks in the training loop (bottleneck!)
    n_context = random.randint(1, 5)
    n_target = random.randint(1, 10)
    tasks = {'context': [], 'target': []}
    for task in batch:
        context_images = sample_n_images(task['images'], n_context)
        target_images = sample_n_images(task['images'], n_target)
        tasks['context'].append(context_images)
        tasks['target'].append(target_images)
    model(tasks)
    ...
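Note that sample_n_images is not a PyTorch function; it stands in for any routine that draws a random subset of a task's images. A minimal version, assuming images is a plain list, could look like this:
# Hypothetical helper assumed by the examples in this README:
# draw n random images from a task's image list.
def sample_n_images(images, n):
    return random.sample(images, n)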
Multiprocessing
Wouldn't it be better to offload task creation to the dataloader, so that it can run in parallel on multiple cores? With MetaBatch, you can do just that.
We provide a TaskSet wrapper, in which you implement the __gettask__(self, index, n_context, n_target) method instead of PyTorch's __getitem__(self, index). Our TaskLoader and its custom sampler take care of synchronizing n_context and n_target for every batch element dispatched to the workers. With MetaBatch, the training bottleneck in the example above disappears:
# Sample code for creating a batch of tasks with MetaBatch
from metabatch import TaskSet, TaskLoader

class MyTaskSet(TaskSet):
    ...
    def __gettask__(self, idx, n_context, n_target):
        data = self.task_data[idx]
        context_images = sample_n_images(data['images'], n_context)
        target_images = sample_n_images(data['images'], n_target)
        return {
            'context': context_images,
            'target': target_images,
        }

class Model(Module):
    ...
    def forward(self, tasks):
        ctx_batch = tasks['context']
        tgt_batch = tasks['target']
        ...

# create dataset
task_data = [{'images': [...], 'label': 'dog'},
             {'images': [...], 'label': 'cat'}, ...]
dataset = MyTaskSet(task_data, min_pts=1, max_ctx_pts=5, max_tgt_pts=10)
dataloader = TaskLoader(dataset, batch_size=16, workers=8)

for batch in dataloader:
    ...
    # Simply access the batch of constructed tasks (no bottleneck!)
    model(batch)
    ...
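For intuition, the per-batch synchronization described above can be pictured as a batch sampler that draws a single (n_context, n_target) pair per batch and attaches it to every index it yields. The sketch below illustrates the idea with hypothetical names; it is not MetaBatch's actual sampler, whose implementation may differ:
# Illustrative sketch (hypothetical, not MetaBatch's actual sampler):
# a batch sampler that draws one (n_context, n_target) pair per batch
# and attaches it to every index, so all tasks in a batch share a shape.
import random
from torch.utils.data import Sampler

class SyncedTaskBatchSampler(Sampler):
    def __init__(self, n_tasks, batch_size, min_pts, max_ctx_pts, max_tgt_pts):
        self.n_tasks = n_tasks
        self.batch_size = batch_size
        self.min_pts = min_pts
        self.max_ctx_pts = max_ctx_pts
        self.max_tgt_pts = max_tgt_pts

    def __iter__(self):
        indices = list(range(self.n_tasks))
        random.shuffle(indices)
        for i in range(0, self.n_tasks, self.batch_size):
            # One size pair for the whole batch, regardless of which
            # worker process ends up loading each element.
            n_ctx = random.randint(self.min_pts, self.max_ctx_pts)
            n_tgt = random.randint(self.min_pts, self.max_tgt_pts)
            yield [(idx, n_ctx, n_tgt) for idx in indices[i:i + self.batch_size]]

    def __len__(self):
        return (self.n_tasks + self.batch_size - 1) // self.batch_size
A dataset wrapped this way would unpack each (idx, n_ctx, n_tgt) tuple in __getitem__ and forward it to __gettask__, which is how the per-batch sizes reach every worker process.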
Advantages
- MetaBatch moves task creation into the dataloader's worker processes, so batches are built in parallel and training runs faster.
- Because n_context and n_target are resampled for every batch, training sees greater task diversity, and task creation is easy to customize.
MetaBatch is a micro-framework for meta-learning in PyTorch that provides convenient tools for faster meta-training. It simplifies task creation and enables efficient batching, making it a useful tool for researchers and engineers working on meta-learning projects.
License
MetaBatch is released under the MIT License. See the LICENSE file for more information.