A PyTorch Batch Sampler that buckets by input length and cuts to min size in batch
cut2min-bucket
Motivation and documentation WIP.
Bucketing by input length and cutting each batch to its shortest sequence leaves no padding inside a batch. This allows the use of "vanilla" flash attention (no padding mask is needed), which I have found to be vastly superior.
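The bucketing-and-cutting idea can be sketched in plain PyTorch. This is a minimal illustration and not the package's implementation: the helper names `bucket_batches` and `cut_to_min` are hypothetical, and the partitioning scheme (shuffle, split into partitions, sort each partition by length) is an assumption about how such a sampler might work.

```python
import torch


def bucket_batches(seqlens, batch_size, n_partitions, generator=None):
    """Group indices of similar length into batches (illustrative only)."""
    n = len(seqlens)
    # Shuffle, split into partitions, then sort each partition by length so
    # that neighbouring indices have similar lengths.
    perm = torch.randperm(n, generator=generator)
    batches = []
    for part in torch.chunk(perm, n_partitions):
        order = torch.argsort(seqlens[part])
        part = part[order]
        batches.extend(b.tolist() for b in torch.split(part, batch_size))
    return batches


def cut_to_min(padded, seqlens, batch_indices):
    """Truncate every sequence in the batch to the shortest true length."""
    idx = torch.tensor(batch_indices)
    min_len = int(seqlens[idx].min())
    return padded[idx, :min_len]


torch.manual_seed(0)
seqlens = torch.randint(low=2, high=50, size=(64,))
padded = torch.zeros(64, int(seqlens.max()))
for i, n in enumerate(seqlens.tolist()):
    padded[i, :n] = 1.0  # mark real (non-padding) positions with 1

batches = bucket_batches(seqlens, batch_size=8, n_partitions=2)
batch = cut_to_min(padded, seqlens, batches[0])
```

Because lengths within a bucket are similar, cutting to the minimum discards relatively few real positions, and the resulting batch contains no padding at all.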
Simple example:
    import cut2min_bucket
    import numpy as np
    import torch

    # Build 10,000 random 1-D sequences with lengths in [2, 1000).
    X = []
    for _ in range(10000):
        length = torch.randint(size=(), low=2, high=1000).item()
        X.append(torch.tensor(np.random.randn(length)))

    # Record the true lengths, then pad everything to the longest sequence.
    seqlens = torch.tensor([len(x) for x in X])
    X = torch.nn.utils.rnn.pad_sequence(X, batch_first=True)
    y = (torch.rand(10000) > 0.5).int()  # random binary labels

    dataset = torch.utils.data.TensorDataset(X, y)
    # index_or_key=0 presumably points at the sequence tensor (X) in each item.
    dataset = cut2min_bucket.DatasetWrapper(
        dataset, seqlens,
        index_or_key=0
    )

    batch_sampler = cut2min_bucket.BucketBatchSampler(
        dataset,
        seqlens,
        batch_size=8,
        n_partitions=5
    )

    # Pass the sampler as batch_sampler and use the wrapper's collate_fn.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=dataset.collate_fn,
    )

    next(iter(dataloader))
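Once every sequence in a batch has the same length, attention can be called without a mask. The following is a minimal sketch under stated assumptions: it uses PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` as a stand-in for a flash attention kernel, and the batch shape, embedding layer, and head dimensions are hypothetical values chosen only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for one batch after cutting: 8 sequences, all of length 17.
batch = torch.randn(8, 17)

# Hypothetical model dimensions, chosen only for illustration.
n_heads, head_dim = 4, 16
embed = torch.nn.Linear(1, n_heads * head_dim)

# (batch, seq) -> (batch, heads, seq, head_dim)
h = embed(batch.unsqueeze(-1))
h = h.view(8, 17, n_heads, head_dim).transpose(1, 2)

# No attn_mask: every position is a real token, so none is needed.
out = F.scaled_dot_product_attention(h, h, h)
```

With padded batches, the same call would need an `attn_mask` to keep padding positions from being attended to.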
File details
Details for the file cut2min_bucket-0.0.1.tar.gz.
File metadata
- Download URL: cut2min_bucket-0.0.1.tar.gz
- Upload date:
- Size: 7.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2621ca6066bbe4cad3f35b301ca2904fa74d29ffa639d2cefd2fa43c5d326ec4
MD5 | 9858c309176ce2b2f387d7629084a056
BLAKE2b-256 | b3bc5cdb78eb79cdac5dcf7d2701633eb7498d68fc8ed3d2e710fa2e7e849093
File details
Details for the file cut2min_bucket-0.0.1-py3-none-any.whl.
File metadata
- Download URL: cut2min_bucket-0.0.1-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.0 CPython/3.12.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 81b0296301c92774fc25653c706ad2424b19485dae15c61cf95edd450158c385
MD5 | fdbe6f54606af6341d9fb17b30591fba
BLAKE2b-256 | c5882c9234356e83598dbf111d771652b92dff126ee6fd7ddc85a273a28832ee