A PyTorch Batch Sampler that buckets by input length and cuts to min size in batch
Project description
This package provides two utilities:
- cut2min_bucket.DatasetWrapper, to eliminate padding and cut to the minimum size within each batch
- cut2min_bucket.BucketBatchSampler, a batch sampler that buckets by input length

In addition, we provide a Distributed Data Parallel version of the batch sampler: cut2min_bucket.DistributedBucketBatchSampler.
A detailed motivation for this package can be found on my blog.
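To make the idea of bucketing by input length concrete, here is a minimal pure-Python sketch. It is not the package's implementation: the function name `bucket_batches` is hypothetical, and the assumed meaning of `n_partitions` (shuffle, split into that many partitions, sort each partition by length, then cut it into batches) is an illustration only and may differ from what `BucketBatchSampler` actually does.

```python
import random


def bucket_batches(seqlens, batch_size, n_partitions, seed=0):
    """Return index batches whose members have similar sequence lengths.

    Hypothetical helper for illustration; not part of cut2min_bucket.
    """
    rng = random.Random(seed)
    indices = list(range(len(seqlens)))
    rng.shuffle(indices)  # randomize which samples share a partition

    # Split the shuffled indices into n_partitions roughly equal chunks.
    chunk = -(-len(indices) // n_partitions)  # ceiling division
    batches = []
    for start in range(0, len(indices), chunk):
        part = indices[start:start + chunk]
        # Sort inside the partition so neighbours have similar lengths.
        part.sort(key=lambda i: seqlens[i])
        # Cut the sorted partition into consecutive batches.
        for b in range(0, len(part), batch_size):
            batches.append(part[b:b + batch_size])

    # Shuffle the batch order so batches are not visited sorted by length.
    rng.shuffle(batches)
    return batches


# Toy data: 1000 random lengths between 2 and 999.
length_rng = random.Random(1)
seqlens = [length_rng.randint(2, 999) for _ in range(1000)]
batches = bucket_batches(seqlens, batch_size=8, n_partitions=5)
```

In this sketch, more partitions mean more randomness across epochs but a wider spread of lengths inside each batch, while fewer partitions give tighter batches that need less padding.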
Simple example:
import cut2min_bucket
import torch
import numpy as np

# Build 10,000 variable-length sequences (lengths from 2 to 999).
X = []
for _ in range(10000):
    length = int(torch.randint(low=2, high=1000, size=()))
    X.append(torch.tensor(np.random.randn(length)))
seqlens = torch.tensor([len(x) for x in X])

# Pad to one rectangular tensor and draw random binary labels.
X = torch.nn.utils.rnn.pad_sequence(X, batch_first=True)
y = (torch.rand(10000) > 0.5).int()

dataset = torch.utils.data.TensorDataset(X, y)
# Wrap the dataset so that padding is cut per batch.
dataset = cut2min_bucket.DatasetWrapper(
    dataset, seqlens,
    index_or_key=0
)
# Group samples of similar length into the same batch.
batch_sampler = cut2min_bucket.BucketBatchSampler(
    dataset,
    seqlens,
    batch_size=8,
    n_partitions=5
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=dataset.collate_fn,
)
# Fetch a single batch.
next(iter(dataloader))
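The padding-removal step can be pictured with a small standalone sketch. It assumes that "cut to min size in batch" means trimming each padded batch to the shortest width that still holds its longest sequence. The function `trim_collate` and the `(padded_x, seqlen, y)` sample layout are hypothetical and are not the package's `collate_fn`.

```python
import torch


def trim_collate(batch):
    """Stack (padded_x, seqlen, y) samples and drop all-padding columns.

    Hypothetical collate function for illustration; not part of cut2min_bucket.
    """
    xs, lens, ys = zip(*batch)
    lens = torch.stack(lens)
    max_len = int(lens.max())         # longest real sequence in this batch
    x = torch.stack(xs)[:, :max_len]  # keep only the columns that are needed
    return x, lens, torch.stack(ys)


# Three sequences of true lengths 2, 5 and 3, each padded to width 10.
seqlens = torch.tensor([2, 5, 3])
padded = torch.zeros(3, 10)
for row, n in enumerate(seqlens):
    padded[row, :int(n)] = 1.0
labels = torch.tensor([0, 1, 0])

batch = [(padded[i], seqlens[i], labels[i]) for i in range(3)]
x, lens, y = trim_collate(batch)  # x has shape (3, 5) instead of (3, 10)
```

Combined with length bucketing, the longest sequence in a batch is close to the others, so little padding survives the trim.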
Download files
Source Distribution
- cut2min_bucket-0.1.0.tar.gz (7.4 kB)

Built Distribution
- cut2min_bucket-0.1.0-py3-none-any.whl (6.4 kB)
File details
Details for the file cut2min_bucket-0.1.0.tar.gz.
File metadata
- Download URL: cut2min_bucket-0.1.0.tar.gz
- Upload date:
- Size: 7.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | b1b631230a7f0f4daf46e9591b4353b2a9473c83be2d62eac4b817e9ef7e5e38
MD5 | 82b438e8a6d19d10b4ee92796574f60d
BLAKE2b-256 | 1fb824a8a9260c474304d5df4c400053fbf705528eefbb7e46fc5bff46f78b9d
File details
Details for the file cut2min_bucket-0.1.0-py3-none-any.whl.
File metadata
- Download URL: cut2min_bucket-0.1.0-py3-none-any.whl
- Upload date:
- Size: 6.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | ef343cc8b5e6b7611e0c4ba4e42b6229dba82292085eb34267cb8b89aeed18fb
MD5 | 1d00e60bebe4b43c26bab723eb840b31
BLAKE2b-256 | 745a408a544637686dd22cade3503b691e407892f8d97824d9b32ccb91ab3995