
Utilities for efficiently iterating over mini-batches of PyTorch tensors

Project description

PyTorch batch iteration utilities

Utilities for iterating over mini-batches of tensors while avoiding the overhead of PyTorch's comparatively heavyweight DataLoader class when training lightweight models.

Installing:

pip install batch-iter

The library is based on this blog post, which gives more explanation and motivation. Below are simple examples:

Plain iteration

from batch_iter import BatchIter
import torch

X = torch.arange(15).reshape(5, 3)
y = torch.linspace(-1, 1, 5)

for Xb, yb in BatchIter(X, y, batch_size=2, shuffle=True):
  print('---')
  print('features = ', Xb)
  print('labels = ', yb)

A possible output (the batches differ between runs because shuffle=True):

---
features =  tensor([[12, 13, 14],
        [ 6,  7,  8]])
labels =  tensor([1., 0.])
---
features =  tensor([[3, 4, 5],
        [0, 1, 2]])
labels =  tensor([-0.5000, -1.0000])
---
features =  tensor([[ 9, 10, 11]])
labels =  tensor([0.5000])
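For readers curious about what such a utility has to do, here is a minimal sketch of the general technique in plain PyTorch: draw one random permutation of the sample indices, then slice every tensor with consecutive chunks of that permutation. It is not taken from batch_iter's source; the helper name iter_minibatches and its parameters are hypothetical, and BatchIter may be implemented differently.

```python
import torch


def iter_minibatches(*tensors, batch_size, shuffle=False, generator=None):
    """Yield tuples of aligned mini-batches from tensors sharing dimension 0.

    Hypothetical helper illustrating the technique; not batch_iter's source.
    """
    n = tensors[0].shape[0]
    if any(t.shape[0] != n for t in tensors):
        raise ValueError("all tensors must have the same number of samples")
    # One permutation (or the identity order) shared by every tensor keeps rows aligned.
    idx = torch.randperm(n, generator=generator) if shuffle else torch.arange(n)
    for start in range(0, n, batch_size):
        batch_idx = idx[start:start + batch_size]
        # A single indexing call gathers the whole batch; the last batch may be smaller.
        yield tuple(t[batch_idx] for t in tensors)


X = torch.arange(15).reshape(5, 3)
y = torch.linspace(-1, 1, 5)
for Xb, yb in iter_minibatches(X, y, batch_size=2, shuffle=True):
    print('---')
    print('features = ', Xb)
    print('labels = ', yb)
```

Each batch is gathered by one indexing operation on tensors already in memory, so there is no per-sample __getitem__ call and no collate step, which is presumably the kind of DataLoader overhead the library is designed to avoid.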

Grouped iteration

The library also supports iterating over groups of samples, identified by an additional group-id tensor. This is useful, for example, when training a ranking model, where each mini-batch should consist of complete queries: each query is one group.

For example:

from batch_iter import GroupBatchIter
import torch

X = torch.arange(15).reshape(5, 3)
y = torch.linspace(-1, 1, 5)
# first three samples are a group with id 1, 
# the next two samples are another group with id 2.
group_id = torch.tensor([1, 1, 1, 2, 2])

for gb, Xb, yb in GroupBatchIter(group_id, X, y, batch_size=2, shuffle=True):
  print('---')
  print('group_id = ', gb)
  print('features = ', Xb)
  print('labels = ', yb)

The output is:

---
group_id =  tensor([1, 1, 1, 2, 2])
features =  tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]])
labels =  tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])

The entire dataset forms a single mini-batch: for GroupBatchIter, batch_size counts groups rather than samples, and this dataset contains exactly two groups.
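A similar sketch for the grouped case, assuming batch_size counts groups as described above: shuffle the distinct group ids, take batch_size of them at a time, and select every sample belonging to those groups with a boolean mask. The helper iter_group_minibatches is hypothetical and is not GroupBatchIter's actual implementation.

```python
import torch


def iter_group_minibatches(group_id, *tensors, batch_size, shuffle=False, generator=None):
    """Yield (group ids, *tensors) mini-batches, each holding up to batch_size whole groups.

    Hypothetical helper illustrating the idea; not GroupBatchIter's source.
    """
    groups = torch.unique(group_id)  # sorted distinct group ids
    if shuffle:
        groups = groups[torch.randperm(len(groups), generator=generator)]
    for start in range(0, len(groups), batch_size):
        chosen = groups[start:start + batch_size]
        # Boolean mask selecting every sample whose group is in this batch.
        mask = torch.isin(group_id, chosen)
        yield (group_id[mask], *(t[mask] for t in tensors))


X = torch.arange(15).reshape(5, 3)
y = torch.linspace(-1, 1, 5)
group_id = torch.tensor([1, 1, 1, 2, 2])
for gb, Xb, yb in iter_group_minibatches(group_id, X, y, batch_size=2, shuffle=True):
    print('---')
    print('group_id = ', gb)
    print('features = ', Xb)
    print('labels = ', yb)
```

torch.isin requires PyTorch 1.10 or later. The mask keeps samples in their original dataset order inside a batch, so shuffling only changes which groups share a batch. Because every batch scans the whole group_id tensor, the cost grows with the number of batches times the dataset size; sorting samples by group once and slicing contiguous ranges would scale better when there are many small groups.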


Download files

Source Distribution

batch_iter-0.1.1.tar.gz (5.1 kB)

Built Distribution

batch_iter-0.1.1-py3-none-any.whl (5.8 kB)

File details

Details for the file batch_iter-0.1.1.tar.gz.

File metadata

  • Download URL: batch_iter-0.1.1.tar.gz
  • Upload date:
  • Size: 5.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.21

File hashes

Hashes for batch_iter-0.1.1.tar.gz
  • SHA256: 3072ef5cbc8c6ee569de32ed1d05ce7ac58f531ffa6b467c6ea0f2081a602bf5
  • MD5: 164ff506bffa85b60da9a2af329b932a
  • BLAKE2b-256: 201f5f7fd21c5645f8bc6b4639c59acd20994f1ec523f6d4aeba03ca83bbd9be

File details

Details for the file batch_iter-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: batch_iter-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 5.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.21

File hashes

Hashes for batch_iter-0.1.1-py3-none-any.whl
  • SHA256: 563892321bed21dbbdcedfa77f4494fc8dff78fd96ac44a6628be680af9d3ee6
  • MD5: 0612afde405d975a64c14017f1022a42
  • BLAKE2b-256: 181f3f7238ad057767c80d7579f914c782ef471c066e9323fec2f3f9150eb3c9
