Online Dynamic Batching (ODB) — a PyTorch DataLoader-side integration that dynamically groups sequences by length and adjusts batch sizes on-the-fly.
Online Dynamic Batching
Online Dynamic Batching (ODB) is a PyTorch DataLoader-side dynamic batching library for variable-length LLM and VLM training.
It waits until each sample has passed through the real input pipeline: tokenization, chat templates, image-token expansion, truncation, augmentation, and collation inputs. ODB then forms token-budgeted batches online. Short examples get larger batches, long examples get smaller batches, and your model, optimizer, attention kernels, and dataset format can stay where they are.
import odb

dataloader = odb.ODBDataLoader(
    dataset,
    token_budget=16384,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    prefetch_factor=64,
    collate_fn=collate_fn,
    loss_scaling="exact",
    join=True,  # default; set join=False only when needed
)

for batch in dataloader:
    info = odb.pop_step_info(batch, loss_scaling="exact")
    loss = model(**batch).loss
    loss = loss * info.loss_scale
    loss.backward()
Why ODB
Modern training pipelines often do not know the true training length at dataset index time. A multimodal or instruction-tuning sample may change length after:
- applying a chat template;
- expanding images into vision tokens;
- truncating to a cutoff;
- adding stochastic augmentation;
- mixing multiple data sources with different processors.
Classic fixed-size batching wastes padding. Offline length caches can help, but they need a separate preprocessing pass and can go stale when the runtime input pipeline changes. ODB moves batching to the point where real length is already observable: the DataLoader/collate boundary.
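To make the padding cost concrete, here is a minimal sketch that measures the share of padded tokens for fixed-size batches in arrival order versus the same batch size after sorting by length. The log-normal length distribution, the batch size of 8, and the helper name `padded_fraction` are illustrative assumptions, not ODB code or benchmark data.

```python
import random


def padded_fraction(batches):
    """Fraction of tokens in the padded batches that are padding."""
    real = sum(sum(b) for b in batches)
    padded = sum(max(b) * len(b) for b in batches)
    return 1.0 - real / padded


random.seed(0)
# Hypothetical long-tail length distribution (illustrative only).
lengths = [min(8192, int(random.lognormvariate(6.0, 1.0)) + 1) for _ in range(1024)]

# Fixed-size batching: 8 samples per batch, in arrival order.
fixed = [lengths[i:i + 8] for i in range(0, len(lengths), 8)]

# Same batch size, but after sorting by observed length.
ordered = sorted(lengths)
grouped = [ordered[i:i + 8] for i in range(0, len(ordered), 8)]

fixed_waste = padded_fraction(fixed)
grouped_waste = padded_fraction(grouped)
print(f"fixed: {fixed_waste:.1%} padding, length-grouped: {grouped_waste:.1%} padding")
```

Sorting needs the lengths to be known, which is why the point in the pipeline where the real length becomes observable matters.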
What You Get
- DataLoader replacement path: use `ODBDataLoader(...)` when you control DataLoader construction.
- Existing DataLoader path: use `odb.apply(dataloader, ...)` when a framework has already created the DataLoader.
- DDP-ready dynamic batching: ODB aligns grouping across ranks with a small metadata exchange.
- Default join-mode protocol: strict identity-coverage termination for final DDP training runs; set `join=False` only for constrained runtimes that cannot support drain-before-finish semantics.
- Correct loss scaling: `odb.pop_step_info(...)` returns the current all-rank sample count and the per-rank loss multiplier.
- Trainer integrations: PyTorch loops, HuggingFace Trainer, LLaMA-Factory-style trainers, Accelerate loops, and Lightning modules.
- Production-shaped benchmark coverage: text, multimodal, LoRA/full FT, single-node, multi-node, oracle baselines, and high-variance production mixes.
Installation
From PyPI:
pip install online-dynamic-batching
# HuggingFace Trainer / LLaMA-Factory adapters
pip install "online-dynamic-batching[hf]"
# Accelerate or Lightning adapters
pip install "online-dynamic-batching[accelerate]"
pip install "online-dynamic-batching[lightning]"
From GitHub:
pip install "online-dynamic-batching @ git+https://github.com/online-dynamic-batching/online-dynamic-batching.git"
Local development:
git clone https://github.com/online-dynamic-batching/online-dynamic-batching.git
cd online-dynamic-batching
pip install -e ".[dev,all]"
pytest
Quick Start
Replace DataLoader Construction
Use this when you own the DataLoader code.
import odb

dataloader = odb.ODBDataLoader(
    dataset,
    token_budget=16384,
    batch_size=1,           # ODB forms the real batch dynamically
    shuffle=True,
    num_workers=4,          # ODB requires worker prefetching
    prefetch_factor=64,
    collate_fn=collate_fn,
    loss_scaling="exact",   # "none", "approx", or "exact"
    join=True,              # default; set join=False only when needed
)
Patch An Existing DataLoader
Use this when a framework constructs the DataLoader for you.
from torch.utils.data import DataLoader
import odb

dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    prefetch_factor=64,
    collate_fn=collate_fn,
)

handle = odb.apply(
    dataloader,
    token_budget=16384,
    loss_scaling="exact",
    join=True,  # default; set join=False only when needed
)
Consume ODB Metadata Before Forward
ODB adds trainer-facing metadata to each yielded batch. Remove it before
model(**batch) and use it for correct progress/loss accounting.
emitted_samples = 0  # running all-rank sample count

for batch in dataloader:
    info = odb.pop_step_info(batch, loss_scaling="exact")
    loss = model(**batch).loss
    loss = loss * info.loss_scale
    loss.backward()
    emitted_samples += info.all_samples_this_step
info.all_samples_this_step is the all-rank emitted sample count for the
current micro-step. info.loss_scale is the current-rank multiplier that makes
DDP gradient averaging match the intended global sample/token weighting.
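The arithmetic behind a per-rank multiplier can be sketched in plain Python. This is a sample-weighted illustration under stated assumptions (each rank computes a mean loss over its own samples, and the results are averaged across ranks the way DDP averages gradients); the helper name `sample_weighted_scales` is hypothetical, and the formula ODB actually uses for `"approx"` or `"exact"` may differ, for example by weighting tokens rather than samples.

```python
def sample_weighted_scales(local_counts):
    """Per-rank multipliers so that an average over ranks equals the global mean.

    Assumes each rank reports a mean loss over its own samples and that
    the per-rank values are then averaged across ranks.
    """
    world_size = len(local_counts)
    total = sum(local_counts)
    return [world_size * n / total for n in local_counts]


# Two ranks with different dynamic batch sizes in the same micro-step.
per_rank_losses = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [4.0, 4.0]]
local_counts = [len(x) for x in per_rank_losses]
local_means = [sum(x) / len(x) for x in per_rank_losses]

# Unscaled: averaging the per-rank means over-weights samples in the small batch.
naive = sum(local_means) / len(local_means)

# Scaled: each rank multiplies its mean loss before the cross-rank average.
scales = sample_weighted_scales(local_counts)
scaled = sum(s * m for s, m in zip(scales, local_means)) / len(local_means)

# Reference: mean over every sample on every rank.
flat = [v for x in per_rank_losses for v in x]
global_mean = sum(flat) / len(flat)

print(naive, scaled, global_mean)
```

With equal per-rank counts every multiplier is 1, so the scaling only changes the result when ranks process different amounts of data in a step.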
Trainer Integration Modes
ODB supports three trainer integration styles. They are deliberately separate so framework authors and training-stack owners can choose the least invasive path.
| Mode | Best For | What ODB Handles |
|---|---|---|
| Manual contract | custom PyTorch loops | You call pop_step_info, scale loss, and update sample progress. |
| Configure existing trainer | existing HuggingFace-style trainer instances | configure_trainer(...) registers callbacks, scheduler/progress semantics, and optional compute-loss wrapping. |
| Native trainer/mixin | framework forks or new trainers | ODBTrainerMixin consumes metadata inside compute_loss; configure_trainer(...) handles runtime callbacks. |
HuggingFace Trainer
import odb
from odb.integrations.hf import configure_trainer
from transformers import Trainer

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
dataloader = trainer.get_train_dataloader()
handle = odb.apply(dataloader, token_budget=16384, loss_scaling="exact")
configure_trainer(
    trainer,
    dataloader=dataloader,
    handle=handle,
    sample_budget=len(dataset) * training_args.num_train_epochs,
    max_steps_policy="overwrite",
)
trainer.train()
Native Trainer Class
from odb.integrations.hf import ODBTrainerMixin

# CustomTrainer stands for your existing trainer class.
class MyTrainer(ODBTrainerMixin, CustomTrainer):
    pass
LLaMA-Factory-Style Trainers
from odb.integrations.llamafactory import configure_trainer
The LLaMA-Factory adapter resolves common training arguments such as
token_budget, join, loss_scaling, sample_budget, and max_steps before
delegating to the HuggingFace Trainer bridge.
See docs/integration-guides for PyTorch, HuggingFace Trainer, LLaMA-Factory, Accelerate, and Lightning integration details. The 0.1.1 release validation matrix is summarized in docs/validation.md.
Try It Without Private Data
Run a CPU/single-GPU synthetic benchmark that compares fixed-size batching and ODB on a long-tail sequence distribution:
python examples/synthetic_benchmark.py --device auto --num-samples 2048
For a copy-paste learning path, open examples/notebooks/odb_single_gpu_demo.ipynb.
How It Works
ODB changes batching without changing your model forward path:
- DataLoader workers produce fully processed single samples.
- ODB buffers the samples and observes their true runtime lengths.
- Samples with similar length are grouped under a token budget.
- DDP ranks exchange lightweight grouping metadata.
- Your original `collate_fn` collates each dynamic group.
- The trainer consumes `ODBStepInfo` for progress and loss scaling.
The resulting step size varies in samples but is much more stable in tokens. That is the useful operating point for long-tail instruction and multimodal training.
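The grouping step can be illustrated with a small greedy sketch: sort the buffered samples by length, then close a group whenever adding the next sample would push the padded size (longest length times group size) over the budget. This is one plausible policy written for illustration; the function name `group_by_token_budget` and the padded-cost definition are assumptions, not ODB's actual grouping algorithm.

```python
def group_by_token_budget(lengths, token_budget):
    """Greedy length-sorted grouping under a padded-token budget.

    A group's cost is max_length * group_size, i.e. its size after padding.
    Returns lists of indices into `lengths`. A single sample longer than
    the budget still forms its own group.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    groups, current = [], []
    for idx in order:
        # Samples are visited in ascending length, so the candidate is the
        # longest member of the group it would join.
        cost_if_added = lengths[idx] * (len(current) + 1)
        if current and cost_if_added > token_budget:
            groups.append(current)
            current = []
        current.append(idx)
    if current:
        groups.append(current)
    return groups


buffer_lengths = [120, 4000, 95, 130, 2100, 110, 8000, 100]
groups = group_by_token_budget(buffer_lengths, token_budget=8192)
for g in groups:
    sizes = [buffer_lengths[i] for i in g]
    print(len(g), "samples, padded tokens:", max(sizes) * len(g))
```

On this toy buffer the short samples share one group of five while the longest sample is emitted alone, which is the "larger batches for short examples, smaller batches for long ones" behavior described above.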
API At A Glance
odb.ODBDataLoader(dataset, token_budget=..., **dataloader_kwargs)
odb.apply(dataloader, token_budget=..., loss_scaling="exact")
odb.pop_step_info(batch, loss_scaling="exact")
odb.integrations.hf.configure_trainer(...)
odb.integrations.hf.ODBTrainerMixin
odb.integrations.hf.ODBTrainer
odb.integrations.accelerate.configure_accelerator(...)
odb.integrations.lightning.configure_lightning_module(...)
Key Parameters
| Parameter | Meaning |
|---|---|
| `token_budget` | Target maximum total input length per dynamic group. Legacy name: `max_input_length`. |
| `loss_scaling` | `"none"`, `"approx"`, or `"exact"`. Use `"exact"` for strict token-weighted DDP loss scaling. |
| `join` | Enables the ODB join-mode protocol; defaults to `True`. Legacy name: `join_mode`. |
| `buffer_size` | Number of prefetched single samples available to the online grouping window. |
| `max_patches` | Optional multimodal compute cap for image-heavy workloads. |
Benchmark Snapshot
Representative 8xH20 Qwen3-VL full fine-tuning results:
| Workload | Length CV | Standard | ODB | Speedup |
|---|---|---|---|---|
| UltraChat 200K, 8B Full FT | 0.48 | 5.77 sam/s | 10.23 sam/s | 1.77x |
| LLaVA 150K, 8B Full FT | 0.29 | 14.38 sam/s | 24.87 sam/s | 1.73x |
| ShareGPT4o 57K, 8B Full FT | 1.00 | 2.37 sam/s | 5.83 sam/s | 2.46x |
Quality is reported alongside throughput in the paper experiments. The intended claim is a better throughput-quality operating point under variable-length training, not identical optimizer-update geometry.
See docs/benchmarks.md for reporting policy and benchmark notes.
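As a quick sanity check, the Speedup column can be recomputed from the two throughput columns. The sketch below also shows one common definition of a coefficient of variation (population standard deviation divided by the mean); that "Length CV" uses exactly this definition is an assumption, and the toy length list is not benchmark data.

```python
import statistics


def coefficient_of_variation(lengths):
    """Population standard deviation divided by the mean."""
    return statistics.pstdev(lengths) / statistics.fmean(lengths)


# (workload, standard sam/s, ODB sam/s) copied from the table above.
rows = [
    ("UltraChat 200K", 5.77, 10.23),
    ("LLaVA 150K", 14.38, 24.87),
    ("ShareGPT4o 57K", 2.37, 5.83),
]
speedups = {name: round(odb_rate / std_rate, 2) for name, std_rate, odb_rate in rows}
print(speedups)

# CV on a toy length list (not benchmark data): identical lengths give 0.
print(coefficient_of_variation([512, 512, 512]))
```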
Integration Checklist
Use this as a quick audit before opening a PR in a training stack:
- DataLoader emits one fully processed sample at a time: `batch_size=1`.
- DataLoader uses worker prefetching: `num_workers > 0`.
- ODB is applied after the framework has selected sampler/shuffle behavior.
- Trainer removes ODB metadata before model forward.
- Trainer uses `info.loss_scale` when DDP ranks can process different local sample/token counts.
- Trainer progresses/stops by emitted samples when doing epoch-based training.
- Default `join=True` is paired with DDP Join or the framework's equivalent uneven-input handling; use `join=False` only when that runtime support is not available.
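The "progress by emitted samples" item is easy to get wrong when a loop counts steps instead. A minimal sketch, assuming the per-step counts come from `info.all_samples_this_step` in a real loop; the helper `run_until_sample_budget` and the step sizes below are hypothetical stand-ins.

```python
def run_until_sample_budget(step_sample_counts, sample_budget):
    """Consume micro-steps until the emitted-sample budget is reached.

    `step_sample_counts` stands in for the all-rank sample count reported
    at each micro-step. Returns (steps_taken, samples_emitted).
    """
    emitted = 0
    steps = 0
    for count in step_sample_counts:
        if emitted >= sample_budget:
            break
        emitted += count
        steps += 1
    return steps, emitted


# Hypothetical run: 100-sample dataset, 2 epochs, variable step sizes.
dataset_size, num_epochs = 100, 2
sample_budget = dataset_size * num_epochs
step_sizes = [48, 12, 64, 30, 8, 56, 40, 16]

steps, emitted = run_until_sample_budget(step_sizes, sample_budget)
print(steps, emitted)
```

Because step sizes vary, the final step can overshoot the budget slightly; a fixed step count would instead cover a different number of samples on every run.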
Project Layout
src/odb/ # core package
src/odb/integrations/ # trainer adapters
examples/ # minimal PyTorch/HF examples and synthetic benchmarks
docs/integration-guides/ # framework-specific integration notes
docs/benchmarks.md # benchmark reporting policy
agent-skills/ # Codex / Claude Code assisted integration skill
Build And Verify
python -m pip install -U build twine
python -m build
python -m twine check dist/*
python -m pip install dist/online_dynamic_batching-*.whl
python -c "import odb; print(odb.__version__)"
pytest
Engineering Roadmap
ODB's roadmap is focused on runtime capabilities: stronger distributed-training semantics, clearer trainer interfaces, additional batching policies, structured observability, and reproducible benchmarking. See ROADMAP.md.
Requirements
- Python 3.9+
- PyTorch 2.0+
- Optional: `transformers>=4.40` for HuggingFace Trainer integration
Citation
If you find ODB useful, please cite the technical report:
@techreport{odb2025,
  title = {Online Dynamic Batching: Adaptive Batch Sizing for Variable-Length Sequence Training},
  year  = {2025}
}
License
Apache-2.0