A PyTorch Batch Sampler that buckets by input length and cuts to min size in batch

Project description

cut2min-bucket

Motivation and documentation are a work in progress.

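Because every batch is cut to the length of its shortest sequence, the batch contains no padding, so "vanilla" flash attention can be used without an attention mask. I have found this to be vastly superior.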
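A minimal sketch of the "cut to min" step, assuming each dataset item is a zero-padded 1-D sequence plus its true length and a label. `cut_to_min_collate` is a hypothetical name, not part of cut2min_bucket, and it keeps the first `min_len` tokens of each sequence (the library may cut differently). The attention call uses PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a FlashAttention kernel on supported GPUs with fp16/bf16 inputs and otherwise falls back to another backend.

```python
import torch
import torch.nn.functional as F


def cut_to_min_collate(batch):
    """Hypothetical collate_fn: cut every sequence to the shortest length in the batch.

    Each item is (padded_sequence, true_length, label). Keeps the first
    `min_len` tokens of every sequence.
    """
    seqs, lengths, labels = zip(*batch)
    min_len = int(min(lengths))
    # After the cut every row holds only real tokens, so no padding mask is needed.
    x = torch.stack([s[:min_len] for s in seqs])
    return x, torch.tensor(labels)


# Toy batch: four sequences zero-padded to length 10, true lengths 7, 9, 6, 8.
lengths = [7, 9, 6, 8]
batch = [(torch.cat([torch.randn(n), torch.zeros(10 - n)]), n, i % 2)
         for i, n in enumerate(lengths)]
x, y = cut_to_min_collate(batch)
print(x.shape)  # torch.Size([4, 6])

# Embed each scalar token and run attention with no attn_mask.
emb = torch.nn.Linear(1, 16)
h = emb(x.unsqueeze(-1))    # (batch, seq, 16)
q = k = v = h.unsqueeze(1)  # (batch, heads=1, seq, 16)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([4, 1, 6, 16])
```

The trade-off is that tokens beyond the shortest sequence in each batch are discarded, which is why grouping sequences of similar length into the same batch matters.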

Simple example:

import cut2min_bucket
import torch
import numpy as np

# 10,000 random 1-D sequences with lengths between 2 and 999
X = []
for _ in range(10000):
    length = int(torch.randint(low=2, high=1000, size=()))
    X.append(torch.tensor(np.random.randn(length)))

seqlens = torch.tensor([len(x) for x in X])

X = torch.nn.utils.rnn.pad_sequence(X, batch_first=True)
y = (torch.rand(10000) > 0.5).int()

dataset = torch.utils.data.TensorDataset(X, y)

dataset = cut2min_bucket.DatasetWrapper(
    dataset, seqlens,
    index_or_key=0
)

batch_sampler = cut2min_bucket.BucketBatchSampler(
    dataset,
    seqlens,
    batch_size=8,
    n_partitions=5
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=dataset.collate_fn,
)

next(iter(dataloader))
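The README does not document how `BucketBatchSampler` uses `n_partitions`. The sketch below is a plain-PyTorch illustration of length bucketing under the assumption that `n_partitions` is the number of length buckets; `LengthBucketBatchSampler` and `tokens_lost` are hypothetical names, not the library's code. It also counts how many tokens cutting to the minimum length would discard with bucketed versus randomly formed batches.

```python
import random

import torch
from torch.utils.data import Sampler


class LengthBucketBatchSampler(Sampler):
    """Hypothetical length-bucketing batch sampler (not cut2min_bucket's code).

    Indices are sorted by sequence length and split into `n_partitions`
    contiguous buckets of similar length. Each epoch, indices are shuffled
    within each bucket, grouped into batches of `batch_size`, and the
    resulting batches are shuffled across buckets.
    """

    def __init__(self, seqlens, batch_size, n_partitions, seed=0):
        self.batch_size = batch_size
        order = torch.argsort(torch.as_tensor(seqlens))
        # tensor_split gives n_partitions chunks whose sizes differ by at most 1.
        self.buckets = [chunk.tolist() for chunk in torch.tensor_split(order, n_partitions)]
        self.rng = random.Random(seed)

    def __iter__(self):
        batches = []
        for bucket in self.buckets:
            shuffled = bucket[:]  # copy so the stored order is left untouched
            self.rng.shuffle(shuffled)
            batches += [shuffled[i:i + self.batch_size]
                        for i in range(0, len(shuffled), self.batch_size)]
        self.rng.shuffle(batches)
        return iter(batches)

    def __len__(self):
        # Ceiling division per bucket: a short final batch is kept, not dropped.
        return sum(-(-len(b) // self.batch_size) for b in self.buckets)


def tokens_lost(batches, seqlens):
    """Tokens discarded when every batch is cut to its shortest sequence."""
    return sum(int(seqlens[b].sum() - len(b) * seqlens[b].min()) for b in batches)


torch.manual_seed(0)
seqlens = torch.randint(low=2, high=1000, size=(10000,))
sampler = LengthBucketBatchSampler(seqlens, batch_size=8, n_partitions=5)
bucketed = list(sampler)

perm = torch.randperm(len(seqlens)).tolist()
random_batches = [perm[i:i + 8] for i in range(0, len(perm), 8)]
print(tokens_lost(bucketed, seqlens) < tokens_lost(random_batches, seqlens))  # True
```

Passed as `DataLoader(dataset, batch_sampler=sampler, ...)`, each yielded index list becomes one batch. Shuffling within buckets and then across batches keeps the batch order random while keeping lengths inside a batch close, so the cut to the shortest sequence throws away far fewer tokens.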

Download files


Source Distribution

cut2min_bucket-0.0.1.tar.gz (7.2 kB)

Uploaded Source

Built Distribution

cut2min_bucket-0.0.1-py3-none-any.whl (6.2 kB)

Uploaded Python 3

File details

Details for the file cut2min_bucket-0.0.1.tar.gz.

File metadata

  • Download URL: cut2min_bucket-0.0.1.tar.gz
  • Upload date:
  • Size: 7.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.12.5

File hashes

Hashes for cut2min_bucket-0.0.1.tar.gz
Algorithm Hash digest
SHA256 2621ca6066bbe4cad3f35b301ca2904fa74d29ffa639d2cefd2fa43c5d326ec4
MD5 9858c309176ce2b2f387d7629084a056
BLAKE2b-256 b3bc5cdb78eb79cdac5dcf7d2701633eb7498d68fc8ed3d2e710fa2e7e849093

File details

Details for the file cut2min_bucket-0.0.1-py3-none-any.whl.

File hashes

Hashes for cut2min_bucket-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 81b0296301c92774fc25653c706ad2424b19485dae15c61cf95edd450158c385
MD5 fdbe6f54606af6341d9fb17b30591fba
BLAKE2b-256 c5882c9234356e83598dbf111d771652b92dff126ee6fd7ddc85a273a28832ee
