Equiareal batch sampler
Standard practice in deep learning is to train models on batches of data while keeping the number of samples per batch (the "batch size") constant. However, what you usually need is a constant memory footprint per batch, e.g. to match it to the memory of your GPU. If your samples have different sizes (common in text and time series data), a constant batch size leads to a highly variable memory footprint. This package provides a batch sampler that instead keeps the "batch area" roughly constant, where the batch area is the sum of the lengths of the samples in the batch. The batch area mostly corresponds to the memory footprint of the batch, although padding will increase the footprint slightly.
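To make the motivation concrete, here is a small sketch in plain Python (it does not use this package). The helper names `batch_area` and `padded_footprint` are hypothetical and introduced only for illustration, and the padded footprint assumes every sample is padded to the longest sample in its batch.

```python
# Sample lengths, e.g. character counts of strings of very different sizes.
lengths = [6, 10, 7, 58, 9, 7, 10]

def batch_area(batch):
    """Sum of sample lengths in a batch (the "batch area")."""
    return sum(batch)

def padded_footprint(batch):
    """Footprint if every sample is padded to the longest one in the batch."""
    return len(batch) * max(batch)

# Constant batch size of 3: the footprint varies a lot between batches.
batch_size = 3
fixed_batches = [lengths[i:i + batch_size] for i in range(0, len(lengths), batch_size)]

areas = [batch_area(b) for b in fixed_batches]
padded = [padded_footprint(b) for b in fixed_batches]
print(fixed_batches)  # [[6, 10, 7], [58, 9, 7], [10]]
print(areas)          # [23, 74, 10]
print(padded)         # [30, 174, 10]
```

With a fixed batch size, a single long sample inflates the second batch well beyond the others, which is the situation a cap on batch area is meant to avoid.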
Installation
pip install equibatch
Usage
from equibatch import EquiarealBatchSampler

data = [
    'London',
    'Birmingham',
    'Glasgow',
    'Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch',
    'Liverpool',
    'Bristol',
    'Manchester'
]

batch_sampler = EquiarealBatchSampler(
    sampler = range(len(data)),  # the order in which the dataset is traversed
    len_checker = lambda ix: len(data[ix]),  # definition of "footprint of a sample"
    max_size = 10,  # maximum number of samples in a batch
    max_footprint = 60  # maximum cumulative footprint of a batch
)

for batch in batch_sampler:
    sample = [data[ix] for ix in batch]
    print(sample)
This will print
['London', 'Birmingham', 'Glasgow']
['Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch']
['Liverpool', 'Bristol', 'Manchester']
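The grouping above is consistent with a simple greedy rule: keep adding samples to the current batch until the next one would exceed either limit, then start a new batch. The following is a minimal sketch of that rule, not the package's actual implementation; the function name `greedy_batches` and its parameters are assumptions made for illustration.

```python
from typing import Callable, Iterable, Iterator, List

def greedy_batches(
    indices: Iterable[int],
    footprint: Callable[[int], int],
    max_size: int,
    max_footprint: int,
) -> Iterator[List[int]]:
    """Greedily group indices so each batch stays within both limits.

    Hypothetical helper for illustration; a sample whose footprint alone
    exceeds max_footprint is still emitted, as a batch of its own.
    """
    batch: List[int] = []
    area = 0
    for ix in indices:
        size = footprint(ix)
        # Close the current batch if adding this sample would break a limit.
        if batch and (len(batch) >= max_size or area + size > max_footprint):
            yield batch
            batch, area = [], 0
        batch.append(ix)
        area += size
    if batch:
        yield batch

data = [
    'London',
    'Birmingham',
    'Glasgow',
    'Llanfairpwllgwyngyllgogerychwyrndrobwllllantysiliogogogoch',
    'Liverpool',
    'Bristol',
    'Manchester',
]

batches = list(greedy_batches(range(len(data)), lambda ix: len(data[ix]), 10, 60))
print(batches)  # [[0, 1, 2], [3], [4, 5, 6]]
```

The first three names have a combined length of 23; adding the 58-character name would bring the area to 81, above the limit of 60, so it starts a new batch.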
PyTorch support
If you have PyTorch installed, EquiarealBatchSampler will be a subclass of torch.utils.data.sampler.BatchSampler, and you can use it in your torch.utils.data.DataLoader like this:
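from torch.utils.data import DataLoader

# model, loss_fn and optimizer are assumed to be defined elsewhere
dataloader = DataLoader(
    dataset = data,
    batch_sampler = batch_sampler
)

for batch in dataloader:
    optimizer.zero_grad()  # clear gradients left over from the previous step
    output = model(batch)
    loss = loss_fn(output)
    loss.backward()
    optimizer.step()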