Online Dynamic Batching (ODB) — a PyTorch DataLoader-side integration that dynamically groups sequences by length and adjusts batch sizes on-the-fly.

Online Dynamic Batching

Online Dynamic Batching (ODB) is a PyTorch DataLoader-side dynamic batching library for variable-length LLM and VLM training.

ODB waits until each sample has passed through the real input pipeline (tokenization, chat templates, image-token expansion, truncation, and augmentation) and is ready for collation. It then forms token-budgeted batches online: short examples get larger batches and long examples get smaller ones, while your model, optimizer, attention kernels, and dataset format stay as they are.

[Figure: ODB online grouping animation]

import odb

dataloader = odb.ODBDataLoader(
    dataset,
    token_budget=16384,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    prefetch_factor=64,
    collate_fn=collate_fn,
    loss_scaling="exact",
    join=True,                  # default; set join=False only when needed
)

for batch in dataloader:
    info = odb.pop_step_info(batch, loss_scaling="exact")
    loss = model(**batch).loss
    loss = loss * info.loss_scale
    loss.backward()

Why ODB

Modern training pipelines often do not know the true training length at dataset index time. A multimodal or instruction-tuning sample may change length after:

  • applying a chat template;
  • expanding images into vision tokens;
  • truncating to a cutoff;
  • adding stochastic augmentation;
  • mixing multiple data sources with different processors.

Classic fixed-size batching wastes compute on padding, because every sample in a batch is padded to the longest one. Offline length caches can help, but they need a separate preprocessing pass and can go stale when the runtime input pipeline changes. ODB moves batching to the point where the real length is already observable: the DataLoader/collate boundary.
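To make the padding cost concrete, here is a minimal sketch that compares the padding fraction of fixed-size batches with that of length-aware batches. The lengths and the helpers padded_tokens and padding_fraction are hypothetical illustrations, not part of ODB, and the length-aware grouping is written by hand rather than produced by the library.

```python
from typing import List


def padded_tokens(batch_lengths: List[int]) -> int:
    """Slots a padded batch occupies: every row is padded to the longest one."""
    return max(batch_lengths) * len(batch_lengths)


def padding_fraction(batches: List[List[int]]) -> float:
    """Share of all padded slots that hold padding rather than real tokens."""
    real = sum(sum(batch) for batch in batches)
    padded = sum(padded_tokens(batch) for batch in batches)
    return 1.0 - real / padded


# Hypothetical post-pipeline lengths, in arrival order.
lengths = [100, 4000, 120, 3800, 90, 110, 4100, 130]

# Fixed-size batching: four samples per batch, in arrival order.
fixed = [lengths[i:i + 4] for i in range(0, len(lengths), 4)]

# Length-aware batching: similar lengths share a batch (grouped by hand here).
grouped = [[90, 100, 110, 120, 130], [3800, 4000, 4100]]

print(f"fixed-size padding:   {padding_fraction(fixed):.1%}")
print(f"length-aware padding: {padding_fraction(grouped):.1%}")
```

In this toy mix, most slots in the fixed-size batches are padding, while the length-aware batches are almost entirely real tokens.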

What You Get

  • DataLoader replacement path: use ODBDataLoader(...) when you control DataLoader construction.
  • Existing DataLoader path: use odb.apply(dataloader, ...) when a framework has already created the DataLoader.
  • DDP-ready dynamic batching: ODB aligns grouping across ranks with a small metadata exchange.
  • Default join-mode protocol: strict identity-coverage termination for final DDP training runs; set join=False only for constrained runtimes that cannot support drain-before-finish semantics.
  • Correct loss scaling: odb.pop_step_info(...) returns the current all-rank sample count and the per-rank loss multiplier.
  • Trainer integrations: PyTorch loops, HuggingFace Trainer, LLaMA-Factory-style trainers, Accelerate loops, and Lightning modules.
  • Production-shaped benchmark coverage: text, multimodal, LoRA/full FT, single-node, multi-node, oracle baselines, and high-variance production mixes.

Installation

From PyPI:

pip install online-dynamic-batching

# HuggingFace Trainer / LLaMA-Factory adapters
pip install "online-dynamic-batching[hf]"

# Accelerate or Lightning adapters
pip install "online-dynamic-batching[accelerate]"
pip install "online-dynamic-batching[lightning]"

From GitHub:

pip install "online-dynamic-batching @ git+https://github.com/online-dynamic-batching/online-dynamic-batching.git"

Local development:

git clone https://github.com/online-dynamic-batching/online-dynamic-batching.git
cd online-dynamic-batching
pip install -e ".[dev,all]"
pytest

Quick Start

Replace DataLoader Construction

Use this when you own the DataLoader code.

import odb

dataloader = odb.ODBDataLoader(
    dataset,
    token_budget=16384,
    batch_size=1,              # ODB forms the real batch dynamically
    shuffle=True,
    num_workers=4,             # ODB requires worker prefetching
    prefetch_factor=64,
    collate_fn=collate_fn,
    loss_scaling="exact",      # "none", "approx", or "exact"
    join=True,                  # default; set join=False only when needed
)

Patch An Existing DataLoader

Use this when a framework constructs the DataLoader for you.

from torch.utils.data import DataLoader
import odb

dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    prefetch_factor=64,
    collate_fn=collate_fn,
)

handle = odb.apply(
    dataloader,
    token_budget=16384,
    loss_scaling="exact",
    join=True,                  # default; set join=False only when needed
)

Consume ODB Metadata Before Forward

ODB adds trainer-facing metadata to each yielded batch. Remove it before model(**batch) and use it for correct progress/loss accounting.

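emitted_samples = 0  # running all-rank sample count

for batch in dataloader:
    # Pop ODB metadata so the batch holds only model inputs.
    info = odb.pop_step_info(batch, loss_scaling="exact")

    loss = model(**batch).loss
    loss = loss * info.loss_scale
    loss.backward()

    emitted_samples += info.all_samples_this_step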

info.all_samples_this_step is the all-rank emitted sample count for the current micro-step. info.loss_scale is the current-rank multiplier that makes DDP gradient averaging match the intended global sample/token weighting.
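The following sketch shows why a per-rank multiplier is needed when ranks hold different numbers of samples. It assumes one plausible form of the multiplier, world_size * local_samples / all_samples, and simulates DDP averaging in plain Python. The formula and the helper names are assumptions for illustration; the source does not state how ODB computes info.loss_scale.

```python
from typing import List


def exact_loss_scale(local_samples: int, all_samples: int, world_size: int) -> float:
    """Hypothetical multiplier: world_size * local_samples / all_samples."""
    return world_size * local_samples / all_samples


def ddp_effective_loss(per_rank_losses: List[List[float]], scaled: bool) -> float:
    """Simulate DDP: average each rank's (optionally scaled) mean loss over ranks."""
    world_size = len(per_rank_losses)
    all_samples = sum(len(losses) for losses in per_rank_losses)
    total = 0.0
    for losses in per_rank_losses:
        local_mean = sum(losses) / len(losses)
        scale = exact_loss_scale(len(losses), all_samples, world_size) if scaled else 1.0
        total += local_mean * scale
    return total / world_size


# Rank 0 drew four short samples, rank 1 drew one long sample.
per_rank = [[1.0, 1.0, 1.0, 1.0], [6.0]]
global_mean = sum(sum(losses) for losses in per_rank) / sum(len(losses) for losses in per_rank)

print("global per-sample mean:", global_mean)
print("unscaled rank average: ", ddp_effective_loss(per_rank, scaled=False))
print("scaled rank average:   ", ddp_effective_loss(per_rank, scaled=True))
```

Without scaling, the single long sample on rank 1 counts as much as the four samples on rank 0 combined. With the assumed multiplier, the rank average equals the global per-sample mean.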

Trainer Integration Modes

ODB supports three trainer integration styles. They are deliberately separate so framework authors and training-stack owners can choose the least invasive path.

| Mode | Best For | What ODB Handles |
| --- | --- | --- |
| Manual contract | custom PyTorch loops | You call pop_step_info, scale loss, and update sample progress. |
| Configure existing trainer | existing HuggingFace-style trainer instances | configure_trainer(...) registers callbacks, scheduler/progress semantics, and optional compute-loss wrapping. |
| Native trainer/mixin | framework forks or new trainers | ODBTrainerMixin consumes metadata inside compute_loss; configure_trainer(...) handles runtime callbacks. |

HuggingFace Trainer

from transformers import Trainer

import odb
from odb.integrations.hf import configure_trainer

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
dataloader = trainer.get_train_dataloader()

handle = odb.apply(dataloader, token_budget=16384, loss_scaling="exact")
configure_trainer(
    trainer,
    dataloader=dataloader,
    handle=handle,
    sample_budget=len(dataset) * training_args.num_train_epochs,
    max_steps_policy="overwrite",
)

trainer.train()
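Because each step emits a different number of samples, a sample budget such as len(dataset) * num_train_epochs cannot be converted to a step count in advance. The sketch below shows the idea of progressing by emitted samples. The helper run_until_sample_budget and the per-step counts are hypothetical; how configure_trainer enforces the budget internally is not described here.

```python
from typing import Iterable, Tuple


def run_until_sample_budget(step_sample_counts: Iterable[int], sample_budget: int) -> Tuple[int, int]:
    """Count steps until the emitted-sample total reaches the budget.

    Returns (steps_taken, samples_emitted).
    """
    steps = 0
    emitted = 0
    for count in step_sample_counts:
        if emitted >= sample_budget:
            break  # budget reached: stop before consuming another step
        emitted += count
        steps += 1
    return steps, emitted


# Hypothetical all-rank sample counts per step (they vary with sample length).
counts = [48, 6, 30, 12, 4, 40, 8]
dataset_size, num_epochs = 100, 1

steps, emitted = run_until_sample_budget(counts, dataset_size * num_epochs)
print(f"stopped after {steps} steps and {emitted} samples")
```

The same 100 samples could take more or fewer steps with a different length mix, which is why progress is tracked in samples rather than steps.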

Native Trainer Class

from odb.integrations.hf import ODBTrainerMixin

# CustomTrainer stands for your existing Trainer subclass.
class MyTrainer(ODBTrainerMixin, CustomTrainer):
    pass

LLaMA-Factory-Style Trainers

from odb.integrations.llamafactory import configure_trainer

The LLaMA-Factory adapter resolves common training arguments such as token_budget, join, loss_scaling, sample_budget, and max_steps before delegating to the HuggingFace Trainer bridge.

See docs/integration-guides for PyTorch, HuggingFace Trainer, LLaMA-Factory, Accelerate, and Lightning integration details. The 0.1.0 release validation matrix is summarized in docs/validation.md.

Try It Without Private Data

Run a CPU/single-GPU synthetic benchmark that compares fixed-size batching and ODB on a long-tail sequence distribution:

python examples/synthetic_benchmark.py --device auto --num-samples 2048

For a copy-paste learning path, open examples/notebooks/odb_single_gpu_demo.ipynb.

How It Works

ODB changes batching without changing your model forward path:

  1. DataLoader workers produce fully processed single samples.
  2. ODB buffers the samples and observes their true runtime lengths.
  3. Samples with similar length are grouped under a token budget.
  4. DDP ranks exchange lightweight grouping metadata.
  5. Your original collate_fn collates each dynamic group.
  6. The trainer consumes ODBStepInfo for progress and loss scaling.

The resulting step size varies in samples but is much more stable in tokens. That is the useful operating point for long-tail instruction and multimodal training.
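Steps 2 and 3 can be illustrated with a minimal greedy sketch: sort a buffered window by length, then close a group whenever adding the next sample would push the padded token count over the budget. This is an illustration under stated assumptions, not ODB's actual grouping policy. The function group_by_token_budget is hypothetical, and it assumes a group costs max(length) * group_size padded tokens.

```python
from typing import List


def group_by_token_budget(lengths: List[int], token_budget: int) -> List[List[int]]:
    """Greedily group a buffered window of sample lengths under a padded-token budget.

    Assumes the cost of a group is max(length) * group_size (padded tokens).
    A sample longer than the budget still forms a group of its own.
    """
    groups: List[List[int]] = []
    current: List[int] = []
    for length in sorted(lengths):  # sorting makes similar lengths neighbours
        # Lengths ascend, so the newest sample sets the padded width.
        if current and length * (len(current) + 1) > token_budget:
            groups.append(current)
            current = []
        current.append(length)
    if current:
        groups.append(current)
    return groups


# Hypothetical window: many short samples, a few medium ones, one long one.
window = [500] * 32 + [4000] * 4 + [16000]
groups = group_by_token_budget(window, token_budget=16384)

for group in groups:
    print(f"{len(group):>2} samples, {max(group) * len(group)} padded tokens")
```

In this window the groups hold 32, 4, and 1 samples, yet each occupies 16000 padded tokens, which is the "varies in samples, stable in tokens" behaviour described above.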

API At A Glance

odb.ODBDataLoader(dataset, token_budget=..., **dataloader_kwargs)
odb.apply(dataloader, token_budget=..., loss_scaling="exact")
odb.pop_step_info(batch, loss_scaling="exact")
odb.integrations.hf.configure_trainer(...)
odb.integrations.hf.ODBTrainerMixin
odb.integrations.hf.ODBTrainer
odb.integrations.accelerate.configure_accelerator(...)
odb.integrations.lightning.configure_lightning_module(...)

Key Parameters

| Parameter | Meaning |
| --- | --- |
| token_budget | Target maximum total input length per dynamic group. Legacy name: max_input_length. |
| loss_scaling | "none", "approx", or "exact". Use "exact" for strict token-weighted DDP loss scaling. |
| join | Enables the ODB join-mode protocol; defaults to True. Legacy name: join_mode. |
| buffer_size | Number of prefetched single samples available to the online grouping window. |
| max_patches | Optional multimodal compute cap for image-heavy workloads. |
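As a rough guide to choosing token_budget, the arithmetic below shows how many uniform-length samples fit in one group for a given budget. It assumes the budget bounds padded tokens (length times group size); the helper samples_per_group is hypothetical and not an ODB function.

```python
def samples_per_group(sample_length: int, token_budget: int) -> int:
    """Uniform-length samples that fit under the budget (always at least one)."""
    return max(1, token_budget // sample_length)


for length in (256, 1024, 4096, 16384):
    print(f"length {length:>5}: {samples_per_group(length, 16384)} samples per group")
```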

Benchmark Snapshot

Representative 8xH20 Qwen3-VL full fine-tuning results:

| Workload | Length CV | Standard (samples/s) | ODB (samples/s) | Speedup |
| --- | --- | --- | --- | --- |
| UltraChat 200K, 8B Full FT | 0.48 | 5.77 | 10.23 | 1.77x |
| LLaVA 150K, 8B Full FT | 0.29 | 14.38 | 24.87 | 1.73x |
| ShareGPT4o 57K, 8B Full FT | 1.00 | 2.37 | 5.83 | 2.46x |
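The sketch below shows how the table's columns relate. It assumes "Length CV" means the coefficient of variation of sample lengths (standard deviation divided by mean); whether the reported figures use the population or sample standard deviation is not stated, so the population form here is an assumption. The speedup column is the ODB throughput divided by the standard throughput.

```python
from statistics import mean, pstdev
from typing import Sequence


def length_cv(lengths: Sequence[int]) -> float:
    """Coefficient of variation: population standard deviation over mean."""
    return pstdev(lengths) / mean(lengths)


def speedup(standard: float, with_odb: float) -> float:
    """Throughput ratio of ODB over the standard baseline."""
    return with_odb / standard


# Throughputs in samples/s, copied from the table above.
rows = {
    "UltraChat 200K": (5.77, 10.23),
    "LLaVA 150K": (14.38, 24.87),
    "ShareGPT4o 57K": (2.37, 5.83),
}
for name, (standard, with_odb) in rows.items():
    print(f"{name}: {speedup(standard, with_odb):.2f}x")

# Toy length distribution to show the CV computation itself.
print(f"CV of [100, 100, 300, 300]: {length_cv([100, 100, 300, 300]):.2f}")
```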

Quality is reported alongside throughput in the paper experiments. The intended claim is a better throughput-quality operating point under variable-length training, not identical optimizer-update geometry.

See docs/benchmarks.md for reporting policy and benchmark notes.

Integration Checklist

Use this as a quick audit before opening a PR in a training stack:

  • DataLoader emits one fully processed sample at a time: batch_size=1.
  • DataLoader uses worker prefetching: num_workers > 0.
  • ODB is applied after the framework has selected sampler/shuffle behavior.
  • Trainer removes ODB metadata before model forward.
  • Trainer uses info.loss_scale when DDP ranks can process different local sample/token counts.
  • Trainer progresses/stops by emitted samples when doing epoch-based training.
  • Default join=True is paired with DDP Join or the framework's equivalent uneven-input handling; use join=False only when that runtime support is not available.
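The first two checklist items are mechanical and can be checked in code before a run starts. The helper below is a hypothetical sketch that inspects a plain dictionary of DataLoader settings; it is not part of ODB and does not cover the trainer-side items.

```python
from typing import Any, Dict, List


def audit_dataloader_settings(settings: Dict[str, Any]) -> List[str]:
    """Return the checklist items that a DataLoader configuration violates."""
    problems: List[str] = []
    if settings.get("batch_size") != 1:
        problems.append("batch_size must be 1 so each item is a single processed sample")
    if settings.get("num_workers", 0) <= 0:
        problems.append("num_workers must be > 0 so workers can prefetch")
    return problems


print(audit_dataloader_settings({"batch_size": 8, "num_workers": 0}))
print(audit_dataloader_settings({"batch_size": 1, "num_workers": 4}))
```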

Project Layout

src/odb/                     # core package
src/odb/integrations/        # trainer adapters
examples/                    # minimal PyTorch/HF examples and synthetic benchmarks
docs/integration-guides/     # framework-specific integration notes
docs/benchmarks.md           # benchmark reporting policy
agent-skills/                # Codex / Claude Code assisted integration skill

Build And Verify

python -m pip install -U build twine
python -m build
python -m twine check dist/*
python -m pip install dist/online_dynamic_batching-*.whl
python -c "import odb; print(odb.__version__)"
pytest

Engineering Roadmap

ODB's roadmap is focused on runtime capabilities: stronger distributed-training semantics, clearer trainer interfaces, additional batching policies, structured observability, and reproducible benchmarking. See ROADMAP.md.

Requirements

  • Python 3.9+
  • PyTorch 2.0+
  • Optional: transformers>=4.40 for HuggingFace Trainer integration

Citation

If you find ODB useful, please cite the technical report:

@techreport{odb2025,
  title = {Online Dynamic Batching: Adaptive Batch Sizing for Variable-Length Sequence Training},
  year = {2025}
}

License

Apache-2.0
