cut2min-bucket

A PyTorch Batch Sampler that buckets by input length and cuts to min size in batch

This package provides two utilities:

  1. cut2min_bucket.DatasetWrapper, a dataset wrapper that eliminates padding by cutting each batch to the minimum sequence length in that batch.
  2. cut2min_bucket.BucketBatchSampler, a batch sampler that buckets by input length.

In addition, we provide a Distributed Data Parallel version of the batch sampler: cut2min_bucket.DistributedBucketBatchSampler.
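To make the bucketing idea concrete, here is a minimal, library-independent sketch of how a batch sampler might group indices by length. It is not the package's implementation: the function name bucket_batches is hypothetical, and reading n_partitions as "number of contiguous length buckets" is an assumption made for illustration.

```python
import random


def bucket_batches(seqlens, batch_size, n_partitions, seed=0):
    """Return batches of indices whose sequences have similar lengths.

    Hypothetical helper; treating n_partitions as the number of length
    buckets is an assumption, not the documented behaviour of the package.
    """
    rng = random.Random(seed)
    # Sort indices by sequence length so neighbours have similar lengths.
    order = sorted(range(len(seqlens)), key=lambda i: seqlens[i])
    # Split the sorted indices into n_partitions contiguous buckets.
    bucket_size = max(1, -(-len(order) // n_partitions))  # ceiling division
    buckets = [order[i:i + bucket_size] for i in range(0, len(order), bucket_size)]
    batches = []
    for bucket in buckets:
        # Shuffle inside each bucket so batches differ between seeds.
        rng.shuffle(bucket)
        for start in range(0, len(bucket), batch_size):
            batches.append(bucket[start:start + batch_size])
    # Shuffle the batch order so training does not see lengths in sorted order.
    rng.shuffle(batches)
    return batches


length_rng = random.Random(1)
seqlens = [length_rng.randint(2, 999) for _ in range(1000)]
batches = bucket_batches(seqlens, batch_size=8, n_partitions=5)
```

Each element of batches is a list of dataset indices, which is the shape of output a PyTorch batch sampler yields. Fewer partitions give more randomness per batch; more partitions give tighter length ranges and therefore less truncation or padding.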

A detailed motivation for this package can be found on my blog.

Simple example:

import cut2min_bucket
import torch
import numpy as np

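# Build 10,000 variable-length sequences with lengths in [2, 1000).
X = []
for _ in range(10000):
    length = int(torch.randint(low=2, high=1000, size=()))
    X.append(torch.tensor(np.random.randn(length)))

# Record the true (unpadded) length of every sequence.
seqlens = torch.tensor([len(x) for x in X])

# Pad to a single (10000, max_len) tensor and create random binary labels.
X = torch.nn.utils.rnn.pad_sequence(X, batch_first=True)
y = (torch.rand(10000) > 0.5).int()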

dataset = torch.utils.data.TensorDataset(X, y)

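# Wrap the dataset so that batches are cut to the minimum length in the batch.
dataset = cut2min_bucket.DatasetWrapper(
    dataset, seqlens,
    index_or_key=0
)

# Sample batches of indices bucketed by input length.
batch_sampler = cut2min_bucket.BucketBatchSampler(
    dataset,
    seqlens,
    batch_size=8,
    n_partitions=5
)

# Pass the sampler as batch_sampler and use the wrapper's collate function.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=dataset.collate_fn,
)

# Fetch one batch.
batch = next(iter(dataloader))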
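For intuition about what "cut to min size in batch" does to a batch, the following sketch truncates every padded row to the shortest true length in that batch. It uses plain Python lists and a hypothetical helper, cut_to_min; it assumes trailing padding and is not the collate function shipped with the package.

```python
def cut_to_min(padded_rows, lengths):
    """Truncate every row to the shortest true length in the batch.

    Hypothetical helper for illustration; assumes padding sits at the end
    of each row, as produced by pad_sequence with batch_first=True.
    """
    min_len = min(lengths)
    return [row[:min_len] for row in padded_rows]


# Three sequences padded with zeros to a common length of 5.
padded = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 0, 0],
    [9, 10, 11, 12, 0],
]
lengths = [5, 3, 4]  # true lengths before padding

cut = cut_to_min(padded, lengths)
# No row contains padding any more, at the cost of dropping the tail of
# the longer sequences.
```

Because longer sequences lose their tails, this only makes sense when the lengths within a batch are close to each other, which is what bucketing by length provides.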
