
GraB'n Go: Optimal Permutation-based SGD Data Sampler for PyTorch


grabngo is an efficient PyTorch-based sampler that supports GraB-style example ordering via Online Gradient Balancing. The GraB algorithm takes O(d) extra memory and O(1) extra time compared with Random Reshuffling, where d is the number of model parameters.

Proposed in the paper GraB: Finding Provably Better Data Permutations than Random Reshuffling, GraB (Gradient Balancing) is a data-permutation algorithm that greedily chooses data orderings based on per-sample gradients to further speed up the convergence of neural network training. The more recent paper Tighter Lower Bounds for Shuffling SGD: Random Permutations and Beyond shows that GraB provably achieves the optimal convergence rate among arbitrary data permutations for SGD. Empirically, GraB not only minimizes the empirical risk faster but also lets the model generalize better on multiple deep learning tasks.

Supported GraB Algorithms

  • Mean Balance (Vanilla GraB, default)
  • Pair Balance
  • Recursive Balance
  • Recursive Pair Balance
  • Random Reshuffling (RR)
  • Various experimental balance algorithms that are not proven to outperform Mean Balance

In terms of balancing, all of the above algorithms support (see the construction sketch after this list):

  • Deterministic Balancing (default)
  • Probabilistic Balancing
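
The project description does not show how to select among these options at construction time. A minimal sketch, assuming the sampler accepts a balance_type-style keyword argument (the keyword name and its values here are hypothetical; check the grabngo API for the actual signature):

from grabngo import GraBSampler

# Hypothetical: `balance_type` and its string values are illustrative
# assumptions, not confirmed grabngo API.
sampler = GraBSampler(
    dataset,
    params,
    balance_type="pair_balance",  # assumed; e.g. "mean_balance" (default)
)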

Per-sample gradients, PyTorch 2, and functional programming

The GraB algorithm requires per-sample gradients while solving the herding problem. In general, these are hard to obtain in the vanilla PyTorch Automatic Differentiation (AD) framework, because the C++ kernels average the per-sample gradients within a batch, so the individual gradients are never materialized.
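
For contrast, the standard workaround in eager PyTorch is one backward pass per example. A minimal sketch (correct, but it forfeits batched computation and is therefore slow):

def per_sample_grads_naive(model, loss_fn, inputs, targets):
    # One backward pass per example: O(batch_size) passes instead of one.
    all_grads = []
    for x, y in zip(inputs, targets):
        model.zero_grad()
        loss = loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0))
        loss.backward()
        all_grads.append({name: p.grad.detach().clone()
                          for name, p in model.named_parameters()})
    return all_grads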

PyTorch 2 integrates functorch (as torch.func), which supports efficient computation of per-sample gradients. Alas, it requires a functional programming style of coding: the model must behave as a pure function, which disallows layers that introduce randomness (e.g., Dropout) or store inter-batch statistics (e.g., BatchNorm).
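
A common workaround for the BatchNorm restriction (an assumption of this sketch, not something grabngo prescribes) is to swap running-statistics layers for stateless ones before functionalizing the model:

import torch.nn as nn

def batchnorm_to_groupnorm(module: nn.Module, max_groups: int = 32) -> None:
    # GroupNorm keeps no inter-batch state, so it stays a pure function
    # under vmap, unlike BatchNorm2d, which updates running statistics.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            groups = max_groups
            while child.num_features % groups != 0:
                groups -= 1  # largest divisor of num_features <= max_groups
            setattr(module, name, nn.GroupNorm(groups, child.num_features))
        else:
            batchnorm_to_groupnorm(child, max_groups)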

Example Usage

To train a PyTorch model in a functional programming style using per-sample gradients, one typically writes a script like the following:

import torch
import torchopt
from torch.func import (
    grad_and_value, vmap, functional_call
)
from functools import partial

from grabngo import GraBSampler

# Initialize model, loss function, and dataset
model = ...
loss_fn = ...
dataset = ...

# Extract parameters and buffers for functional-style calls
# https://pytorch.org/docs/master/func.migrating.html#functorch-make-functional
# https://pytorch.org/docs/stable/generated/torch.func.functional_call.html
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())

# Initialize the optimizer using the torchopt package
optimizer = torchopt.sgd(...)
opt_state = optimizer.init(params)  # init optimizer

###############################################################################
# Initialize GraB sampler and dataloader
sampler = GraBSampler(dataset, params)  # <- add this init of the GraB sampler
dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler)
###############################################################################


# pure function
def compute_loss(model, loss_fn, params, buffers, inputs, targets):
    prediction = functional_call(model, (params, buffers), (inputs,))

    return loss_fn(prediction, targets)


# Compute per-sample gradients and loss for the whole batch at once
ft_compute_sample_grad_and_loss = vmap(
    grad_and_value(partial(compute_loss, model, loss_fn)),
    in_dims=(None, None, 0, 0)
)  # inputs and targets are batched along dim 0; params and buffers are shared

for epoch in range(...):
    for _, (x, y) in enumerate(dataloader):
        ft_per_sample_grads, batch_loss = ft_compute_sample_grad_and_loss(
            params, buffers, x, y
        )

        #######################################################################
        sampler.step(ft_per_sample_grads)  # <- run the GraB balancing step
        #######################################################################

        # The following is equivalent to
        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        grads = {k: g.mean(dim=0) for k, g in ft_per_sample_grads.items()}
        updates, opt_state = optimizer.update(
            grads, opt_state, params=params
        )  # get updates
        params = torchopt.apply_updates(
            params, updates
        )  # update model parameters


How does grabngo work?

The reordering of the data permutation happens at the beginning of each training epoch, whenever a new iterator of the dataloader is created: e.g., for _ in enumerate(dataloader): internally calls the sampler's __iter__(), which applies the updated data ordering.
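
A minimal sketch of that mechanism, for illustration only (not grabngo's actual implementation):

from torch.utils.data import Sampler

class PermutationSampler(Sampler):
    # Yields a stored permutation; a GraB-style sampler would replace
    # self.perm with the freshly balanced ordering between epochs.
    def __init__(self, dataset):
        self.perm = list(range(len(dataset)))

    def __iter__(self):
        # Invoked once per epoch, when the DataLoader builds its iterator.
        return iter(self.perm)

    def __len__(self):
        return len(self.perm)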
