Utilities for efficiently iterating over mini-batches of PyTorch tensors
PyTorch batch iteration utilities
Utilities for iterating over mini-batches of tensors while avoiding the overhead of the heavyweight DataLoader class, which can be significant when training lightweight models.
Installing:
pip install batch-iter
The library is based on this blog post, which gives more explanation and motivation. Simple examples follow:
Plain iteration
from batch_iter import BatchIter
import torch
X = torch.arange(15).reshape(5, 3)
y = torch.linspace(-1, 1, 5)
for Xb, yb in BatchIter(X, y, batch_size=2, shuffle=True):
    print('---')
    print('features = ', Xb)
    print('labels = ', yb)
The output is:
---
features = tensor([[12, 13, 14],
        [ 6,  7,  8]])
labels = tensor([1., 0.])
---
features = tensor([[3, 4, 5],
        [0, 1, 2]])
labels = tensor([-0.5000, -1.0000])
---
features = tensor([[ 9, 10, 11]])
labels = tensor([0.5000])
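The idea behind plain iteration can be illustrated with a minimal sketch in pure Python. This is not the library's actual implementation: the helper name `batch_indices` and its `seed` parameter are hypothetical and exist only for this illustration. The sketch shuffles the row indices once and then cuts them into consecutive chunks, which is why the last mini-batch above holds a single row.

```python
import random


def batch_indices(n, batch_size, shuffle=False, seed=None):
    """Yield lists of row indices that partition range(n) into mini-batches.

    Hypothetical helper for illustration only; it is not part of batch_iter.
    """
    idx = list(range(n))
    if shuffle:
        # Shuffle once per pass over the data; seed makes the order repeatable.
        random.Random(seed).shuffle(idx)
    for start in range(0, n, batch_size):
        # The last chunk may be shorter than batch_size.
        yield idx[start:start + batch_size]


batches = list(batch_indices(5, batch_size=2, shuffle=True, seed=0))
print([len(b) for b in batches])  # → [2, 2, 1]
```

With tensors, each index list can select a whole mini-batch in one indexing operation, for example `X[idx]` and `y[idx]`, so no per-sample collation is needed.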
Grouped iteration
The library also supports iterating over groups of samples, identified by an additional group-id tensor. This is useful, for example, when training a ranking model, where each query is a group and every mini-batch should consist of full queries.
For example:
from batch_iter import GroupBatchIter
import torch
X = torch.arange(15).reshape(5, 3)
y = torch.linspace(-1, 1, 5)
# first three samples are a group with id 1,
# the next two samples are another group with id 2.
group_id = torch.tensor([1, 1, 1, 2, 2])
for gb, Xb, yb in GroupBatchIter(group_id, X, y, batch_size=2, shuffle=True):
    print('---')
    print('group_id = ', gb)
    print('features = ', Xb)
    print('labels = ', yb)
The output is:
---
group_id = tensor([1, 1, 1, 2, 2])
features = tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]])
labels = tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])
The entire dataset forms a single mini-batch here: the batch size counts groups rather than samples, so a batch size of two covers both groups.
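The grouping logic can be sketched in plain Python, assuming (as the example above suggests) that the batch size counts groups and that every group is kept whole. The helper `group_batch_indices` and its `seed` parameter are hypothetical; this is an illustration of the idea, not the library's implementation.

```python
import random
from collections import defaultdict


def group_batch_indices(group_ids, batch_size, shuffle=False, seed=None):
    """Yield lists of sample indices, each covering up to batch_size whole groups.

    Hypothetical helper for illustration only; it is not part of batch_iter.
    """
    # Collect the sample indices that belong to each group id.
    members = defaultdict(list)
    for i, g in enumerate(group_ids):
        members[g].append(i)

    groups = list(members)  # distinct group ids, in order of first appearance
    if shuffle:
        # Shuffle the order of groups, never the samples inside a group.
        random.Random(seed).shuffle(groups)

    for start in range(0, len(groups), batch_size):
        chunk = groups[start:start + batch_size]
        yield [i for g in chunk for i in members[g]]


ids = [1, 1, 1, 2, 2]
print(list(group_batch_indices(ids, batch_size=2)))  # → [[0, 1, 2, 3, 4]]
print(list(group_batch_indices(ids, batch_size=1)))  # → [[0, 1, 2], [3, 4]]
```

With a batch size of one, the same five samples split into two mini-batches, one per query, and no query is ever divided across batches.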