Dynamic-batch split replay and split training runtime for PyTorch models.
Ariadne
Ariadne is a dynamic-batch split replay runtime for PyTorch models. It traces real PyTorch forward execution into a symbolic split plan, generates efficient prefix/suffix segment callables, and supports boundary-based replay and split training.
Ariadne uses TorchLens-style tracing for broad PyTorch model compatibility, but it does not use node-by-node interpreted replay as the default runtime. Instead, it lowers traced graphs into generated prefix/suffix segment callables.
Why Ariadne Exists
Split inference and split training need a clean boundary between offline model preparation and lightweight runtime execution. Ariadne prepares a symbolic TracePlan, validates a declarative SplitSpec, chooses a valid frontier, and generates executable segments so runtime work stays simple:
- run the prefix segment
- package a boundary payload
- run the suffix segment
- optionally train the suffix and backpropagate boundary gradients into the prefix
Design Principles
- Trace observed PyTorch forward behavior while keeping default metadata light.
- Treat the first dimension of the first input tensor as the symbolic batch dimension B.
- Keep the concrete batch size out of trace and runtime cache keys.
- Make split declarations explicit and verifiable.
- Use generated eager segments as the default execution path.
- Keep node-by-node interpretation only for debug_interpreter validation.
- Use torch.compile only as an optional optimizer for generated segments.
Installation
Quick Start (Users)
Install Ariadne from PyPI with uv:
uv add ariadne-split
If you are not using a uv project, pip install ariadne-split works too.
After installation, you can import and use Ariadne:
from ariadne import SplitSpec, prepare_split
import torch
# See usage examples in "Basic Split Inference" and "Basic Split Training" sections below
Optional Dependencies
For real-model integration tests (YOLOv8n, RF-DETR, timm models, torchvision), install with the integration extra:
uv add "ariadne-split[integration]"
For operation-level TracePlan and split-candidate visualization, install with the
visualization extra:
uv add "ariadne-split[visualization]"
The visualization extra installs the Python graphviz package. Rendering SVG,
PDF, or PNG files also requires the Graphviz system executable on PATH. DOT
export does not require the system executable.
Development Setup (Contributors)
If you're developing Ariadne or want to run the full test suite:
uv sync
Dependencies are declared only in pyproject.toml, and uv.lock is committed
for reproducible installs.
Development & Testing
uv sync
uv run pytest
uv run ruff check .
uv run mypy src
Useful demo commands:
uv run python examples/split_inference_demo.py
uv run python examples/split_training_demo.py
uv run --extra integration --extra visualization python examples/visualize_trace_demo.py --model resnet18
uv run --extra integration python examples/percent_split_real_model_test.py --models all
Optional real-model smoke checks install the integration extra and may download
YOLO weights. The current integration suite covers YOLOv8n, RF-DETR Nano,
timm resnet50, timm swin_tiny_patch4_window7_224, torchvision mobilenet_v3_large,
and torchvision deeplabv3_resnet50:
uv run --extra integration python examples/real_model_functional_test.py
uv run --extra integration python examples/real_model_timing.py --iterations 3 --warmup 1
ARIADNE_RUN_REAL_MODELS=1 uv run --extra integration pytest tests/integration -m integration
Visualization also has optional real-model smoke tests using torchvision
resnet18 and vgg11:
ARIADNE_RUN_REAL_MODELS=1 uv run --extra integration --extra visualization pytest tests/integration/test_visualization_real_models.py -m integration
Basic Split Inference
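from collections import OrderedDict
import torch
from torch import nn
from ariadne import SplitSpec, prepare_split
# Illustrative toy model: any nn.Module with a submodule named "layer3" fits
# boundary="after:layer3"; the suffix (layer4) keeps trainable parameters.
model = nn.Sequential(OrderedDict([
    ("layer1", nn.Linear(5, 32)),
    ("layer2", nn.ReLU()),
    ("layer3", nn.Linear(32, 32)),
    ("layer4", nn.Linear(32, 1)),
]))
spec = SplitSpec(
    boundary="after:layer3",
    batch_symbol="B",
    dynamic_batch=(2, 64),  # runtime batch sizes allowed for B
    trainable=True,
    trace_batch_mode="batch_gt1",  # example batch size (4) must be > 1
)
runtime = prepare_split(
    model,
    example_inputs=(torch.randn(4, 5),),
    split=spec,
    mode="generated_eager",
)
# Batch 8 differs from the traced batch 4 but lies inside dynamic_batch.
boundary = runtime.run_prefix(torch.randn(8, 5))
output = runtime.run_suffix(boundary)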
SplitSpec.boundary also accepts percentage positions such as "percent:50"
or "50%". Ariadne maps the percentage onto the nearest valid traced frontier
candidate, and if trainable=True, it chooses the nearest candidate whose
suffix still has trainable parameters.
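The percentage rule can be illustrated with a small standalone sketch. It assumes each validated candidate exposes a hypothetical `position` (the fraction of traced operations placed in the prefix) and a `suffix_trainable` flag; `Candidate`, `parse_percent_boundary`, and `pick_percent_candidate` are illustrative names rather than Ariadne's API, and Ariadne's real distance measure may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Candidate:
    """Hypothetical stand-in for a validated split frontier."""
    name: str
    position: float  # fraction of traced operations in the prefix, 0.0..1.0
    suffix_trainable: bool


def parse_percent_boundary(boundary: str) -> float:
    """Accept "percent:50" or "50%" and return a fraction in [0, 1]."""
    if boundary.startswith("percent:"):
        text = boundary[len("percent:"):]
    elif boundary.endswith("%"):
        text = boundary[:-1]
    else:
        raise ValueError(f"not a percentage boundary: {boundary!r}")
    pct = float(text)
    if not 0.0 <= pct <= 100.0:
        raise ValueError(f"percentage out of range: {pct}")
    return pct / 100.0


def pick_percent_candidate(candidates, boundary, trainable):
    """Return the candidate nearest to the requested percentage."""
    target = parse_percent_boundary(boundary)
    # With trainable=True, only frontiers whose suffix has parameters qualify.
    pool = [c for c in candidates if c.suffix_trainable] if trainable else list(candidates)
    if not pool:
        raise ValueError("no candidate satisfies the trainable constraint")
    # min() keeps the earliest candidate on ties.
    return min(pool, key=lambda c: abs(c.position - target))


candidates = [
    Candidate("layer1", 0.25, True),
    Candidate("layer2", 0.45, True),
    Candidate("pool", 0.90, False),  # nothing trainable after this point
]
print(pick_percent_candidate(candidates, "90%", trainable=False).name)  # pool
print(pick_percent_candidate(candidates, "90%", trainable=True).name)   # layer2
```

The trainable filter runs before the distance search, so a request near the end of the network falls back to the closest earlier frontier that still leaves something to train.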
Split frontiers are selected from verified operation-output groups. Ariadne
keeps multi-output ATen/module calls together, carries all live values needed by
the suffix as explicit boundary or passthrough values, and validates each
candidate with replay before auto or percentage selection. When a user names a
specific boundary, Ariadne still validates that boundary and raises a clear
reason if it cannot be replayed safely.
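To see why a frontier may need to carry more than one value, here is a minimal liveness sketch over a linearized op list: every value that exists before the cut (a graph input or a prefix output) and is read after the cut must travel across it. The `(outputs, inputs)` op encoding and `live_across_cut` are assumptions for illustration; Ariadne's own analysis runs over its TracePlan.

```python
def live_across_cut(ops, cut, graph_inputs):
    """Values that must cross a cut placed just before ops[cut].

    ops: list of (outputs, inputs) pairs in execution order. A multi-output
    op is a single entry, so a cut can never separate its outputs.
    """
    available = set(graph_inputs)
    for outputs, _ in ops[:cut]:
        available.update(outputs)
    read_later = set()
    for _, inputs in ops[cut:]:
        read_later.update(inputs)
    return sorted(available & read_later)


# Residual block: x -> conv -> a -> relu -> b, then add(b, x) -> c.
ops = [(["a"], ["x"]), (["b"], ["a"]), (["c"], ["b", "x"])]
print(live_across_cut(ops, 2, ["x"]))  # ['b', 'x']: the skip input must cross too
```

Cutting after the ReLU looks like a single-tensor boundary, but the residual add still reads `x`, which is exactly the kind of extra live value the text describes carrying explicitly.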
Boundary payloads always use the structured v2 protocol. Tensors are stored by
label, while the boundary carries a serializable value tree for None,
booleans, numbers, strings, bytes, slice, nested list/tuple/dict, and
supported sequence values such as batch-polymorphic split, chunk, or
unbind outputs.
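A value tree of this kind can be sketched as a recursive encoder that swaps tensors for labels and tags every other node with its Python kind, so the tree itself stays plain JSON. The `kind` tags, `encode_value`, and `decode_value` are hypothetical and are not the v2 wire format; `FakeTensor` stands in for `torch.Tensor` so the sketch runs without PyTorch.

```python
import json


def encode_value(value, tensors, is_tensor):
    """Encode a value into a JSON-friendly tree; tensors are stored by label."""
    if is_tensor(value):
        label = f"t{len(tensors)}"
        tensors[label] = value
        return {"kind": "tensor", "label": label}
    if value is None or isinstance(value, (bool, int, float, str)):
        return {"kind": "const", "value": value}
    if isinstance(value, bytes):
        return {"kind": "bytes", "hex": value.hex()}
    if isinstance(value, slice):
        parts = (value.start, value.stop, value.step)
        return {"kind": "slice", "parts": [encode_value(p, tensors, is_tensor) for p in parts]}
    if isinstance(value, (list, tuple)):
        kind = "tuple" if isinstance(value, tuple) else "list"
        return {"kind": kind, "items": [encode_value(v, tensors, is_tensor) for v in value]}
    if isinstance(value, dict):
        items = [[encode_value(k, tensors, is_tensor), encode_value(v, tensors, is_tensor)]
                 for k, v in value.items()]
        return {"kind": "dict", "items": items}
    raise TypeError(f"unserializable boundary value: {type(value).__name__}")


def decode_value(tree, tensors):
    """Rebuild the original Python structure from a tree and its tensor table."""
    kind = tree["kind"]
    if kind == "tensor":
        return tensors[tree["label"]]
    if kind == "const":
        return tree["value"]
    if kind == "bytes":
        return bytes.fromhex(tree["hex"])
    if kind == "slice":
        return slice(*(decode_value(p, tensors) for p in tree["parts"]))
    if kind in ("list", "tuple"):
        items = [decode_value(v, tensors) for v in tree["items"]]
        return tuple(items) if kind == "tuple" else items
    if kind == "dict":
        return {decode_value(k, tensors): decode_value(v, tensors) for k, v in tree["items"]}
    raise ValueError(f"unknown node kind: {kind}")


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


t0, t1 = FakeTensor(), FakeTensor()
value = {"feats": (t0, t1), "idx": slice(0, None, 2), "meta": [None, True, 3, b"\x01"]}
tensors = {}
tree = encode_value(value, tensors, lambda v: isinstance(v, FakeTensor))
wire = json.dumps(tree)  # the tree itself is plain JSON
restored = decode_value(json.loads(wire), tensors)
```

Keeping tuples and lists distinct matters for replay: a suffix that unpacks a `chunk` result expects the same container type the prefix produced.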
Basic Split Training
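import torch.nn.functional as F
# Assumes `runtime` was prepared with trainable=True. x_batch and targets are a
# user-supplied training batch; suffix_optimizer and prefix_optimizer are
# user-created optimizers over the suffix and prefix parameters.
boundary = runtime.run_training_prefix(x_batch)
loss, boundary_grads = runtime.train_suffix(
    boundary,
    targets,
    loss_fn=F.mse_loss,
    optimizer=suffix_optimizer,
)
# Apply boundary gradients to the retained prefix graph (no prefix recompute).
runtime.backward_prefix(
    boundary,
    boundary_grads=boundary_grads,
    optimizer=prefix_optimizer,
)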
For split training, use run_training_prefix() so the boundary keeps the
original prefix autograd graph. Then backward_prefix(boundary, ...) applies
the suffix boundary gradients directly to that graph without recomputing prefix
operations. This matters for RNG-sensitive training operations such as
nn.Dropout, where recomputing prefix would sample a different random mask.
The lightweight run_prefix() path is still available for split replay and
inference-style boundary generation.
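The underlying PyTorch technique can be sketched without Ariadne: keep the prefix output attached to its autograd graph, hand the suffix a detached leaf, and push that leaf's gradient back through the retained graph with `Tensor.backward(gradient)`. This is plain PyTorch for illustration, not Ariadne's implementation; `split_train_step` and the toy `prefix`/`suffix` modules are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


def split_train_step(prefix, suffix, x, targets, prefix_opt, suffix_opt):
    """One split-training step that never recomputes the prefix."""
    # Prefix forward: keep the autograd graph, including the dropout mask.
    boundary = prefix(x)

    # Suffix side: a detached leaf, so suffix backward stops at the boundary.
    suffix_input = boundary.detach().requires_grad_(True)
    loss = F.mse_loss(suffix(suffix_input), targets)
    suffix_opt.zero_grad()
    loss.backward()
    suffix_opt.step()

    # Prefix side: feed the boundary gradient into the retained graph.
    prefix_opt.zero_grad()
    boundary.backward(suffix_input.grad)
    prefix_opt.step()
    return loss.item(), suffix_input.grad


torch.manual_seed(0)
prefix = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Dropout(p=0.5))
suffix = nn.Linear(16, 1)
prefix_opt = torch.optim.SGD(prefix.parameters(), lr=0.01)
suffix_opt = torch.optim.SGD(suffix.parameters(), lr=0.01)
loss, boundary_grad = split_train_step(
    prefix, suffix, torch.randn(8, 5), torch.randn(8, 1), prefix_opt, suffix_opt
)
```

Because the prefix graph from the original forward is reused, the dropout mask applied during the forward is the same one the backward pass differentiates through; re-running `prefix(x)` would sample a new mask.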
Visualization
Ariadne can export offline operation-level views of a captured TracePlan and
the selected split candidate. Visualization reads metadata already stored in
TracePlan, TraceNode, and SplitCandidate; it does not re-trace the model,
does not run generated segments, and does not store real activations or tensors.
The most direct path is through a prepared runtime:
runtime.visualize(view="trace", outpath="trace_graph", fileformat="svg")
runtime.visualize(view="split", outpath="split_graph", fileformat="svg")
For tests, notebooks, and debugging environments where the Graphviz system binary is unavailable, request DOT source instead:
trace_dot = runtime.visualize(view="trace", return_dot=True)
split_dot = runtime.visualize(view="split", return_dot=True)
The public export helpers can also be used directly:
from ariadne.visualization import (
export_split_candidates_table,
export_split_dot,
export_trace_dot,
)
trace_dot = export_trace_dot(runtime.trace_plan)
split_dot = export_split_dot(runtime.trace_plan, runtime.candidate)
candidates = export_split_candidates_table(runtime.trace_plan)
Visualizations default to a module-structure view inspired by TorchLens: traced
operations are folded back into nn.Module paths, and deep module hierarchies
are collapsed to a readable nesting depth by default. For example, ResNet-style
blocks render as layer1.0 BasicBlock clusters containing conv, bn,
relu, residual add, and downsample nodes, instead of a raw ATen operator
stream or a single opaque block. Labels omit ATen targets, trace node indices,
mutation/debug markers, dtype, raw byte counts, and buffers by default. Node
rows stay compact: module/type on the first line, symbolic shape plus activation
memory in MB on the second line, and shortened parameter rows such as
params: weight(64x3x7x7), bias(x64) or params: 74.0K for explicitly
collapsed modules. Trainable parameters use parentheses and frozen parameters
use square brackets.
When you need low-level debugging, request the operation view explicitly:
runtime.visualize(
view="trace",
view_detail="operation",
show_operation_targets=True,
show_debug_markers=True,
show_node_indices=True,
outpath="trace_ops",
)
Use max_module_depth to collapse deep model hierarchies into coarser module
nodes, similar to TorchLens nesting-depth controls. Pass None to expand the
nested module clusters:
runtime.visualize(view="split", max_module_depth=2) # collapse BasicBlock nodes
runtime.visualize(view="trace", max_module_depth=None)
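Assuming depth counts the dot-separated components of a module path (so `max_module_depth=2` folds `layer1.0.conv1` into the `layer1.0` BasicBlock), collapsing amounts to truncating paths and grouping nodes by the result. `collapse_module_path` and `group_nodes` are hypothetical helpers, not Ariadne's visualization code.

```python
def collapse_module_path(path, max_depth):
    """Truncate a dotted nn.Module path to at most max_depth components."""
    if max_depth is None:
        return path  # None keeps the full nesting
    if max_depth < 1:
        raise ValueError("max_depth must be >= 1 or None")
    return ".".join(path.split(".")[:max_depth])


def group_nodes(node_paths, max_depth):
    """Group traced node module paths into display clusters."""
    groups = {}
    for path in node_paths:
        groups.setdefault(collapse_module_path(path, max_depth), []).append(path)
    return groups


paths = ["conv1", "layer1.0.conv1", "layer1.0.bn1", "layer1.1.conv1"]
print(group_nodes(paths, 2))
# {'conv1': ['conv1'], 'layer1.0': ['layer1.0.conv1', 'layer1.0.bn1'], 'layer1.1': ['layer1.1.conv1']}
```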
Split visualizations mark prefix, suffix, boundary, and passthrough nodes and
include lightweight cost information such as boundary_bytes, prefix/suffix
node counts, and whether the suffix is trainable.
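For readers new to Graphviz, the following sketch shows how role-colored split nodes and a cost label map onto plain DOT text. The role-to-color mapping, `split_to_dot`, and the label layout are assumptions for illustration and do not reproduce Ariadne's styling; the resulting string can be rendered with the Graphviz `dot` executable.

```python
ROLE_COLORS = {
    "prefix": "lightblue",
    "boundary": "gold",
    "passthrough": "lightgrey",
    "suffix": "lightsalmon",
}


def split_to_dot(nodes, edges, boundary_bytes):
    """Render (name, role) nodes and (src, dst) edges as Graphviz DOT text."""
    lines = [
        "digraph split {",
        f'  label="boundary_bytes={boundary_bytes}";',
        "  node [shape=box, style=filled];",
    ]
    for name, role in nodes:
        # "\\n" becomes a literal \n, which Graphviz renders as a line break.
        lines.append(f'  "{name}" [fillcolor={ROLE_COLORS[role]}, label="{name}\\n{role}"];')
    for src, dst in edges:
        lines.append(f'  "{src}" -> "{dst}";')
    lines.append("}")
    return "\n".join(lines)


dot = split_to_dot(
    nodes=[("conv", "prefix"), ("relu", "boundary"), ("fc", "suffix")],
    edges=[("conv", "relu"), ("relu", "fc")],
    boundary_bytes=4096,
)
print(dot)
```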
Dynamic Batch
By default, Ariadne treats the first dimension of the first tensor input as the
symbolic batch dimension B. If an example trace uses batch size 4, tensors like
(4, 256, 14, 14) are recorded as ("B", 256, 14, 14).
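The recording rule can be sketched in a few lines. Only the leading dimension is replaced, so an unrelated dimension that happens to equal the trace batch size is not mistaken for B. `symbolize_shape` and `materialize_shape` are hypothetical helpers, not Ariadne's internals.

```python
def symbolize_shape(shape, batch_size, batch_symbol="B"):
    """Record a traced shape with its leading batch dimension made symbolic."""
    if not shape or shape[0] != batch_size:
        raise ValueError(f"leading dim of {shape} is not the batch size {batch_size}")
    return (batch_symbol,) + tuple(shape[1:])


def materialize_shape(symbolic_shape, batch_size, batch_symbol="B"):
    """Substitute a concrete batch size back into a symbolic shape."""
    return tuple(batch_size if dim == batch_symbol else dim for dim in symbolic_shape)


recorded = symbolize_shape((4, 256, 14, 14), batch_size=4)
print(recorded)                         # ('B', 256, 14, 14)
print(materialize_shape(recorded, 16))  # (16, 256, 14, 14)
```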
SplitSpec.trace_batch_mode makes the preparation strategy explicit:
- batch_1: requires an example_inputs batch size of 1 and a dynamic_batch range that includes 1. Ariadne performs prepare-time provenance probing and can prepare a batch>1 structural variant when the singleton and non-singleton ATen paths differ.
- batch_gt1: requires an example_inputs batch size greater than 1 and a dynamic_batch range that contains the runtime batch sizes you want to allow. Ariadne traces in the non-singleton regime, derives affine batch shapes such as 4*B, and validates any singleton runtime batch in the range as an explicit safe-frontier condition.
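To make "affine batch shapes such as 4*B" concrete: a dimension d is affine in B when d = a*B + c for integers a and c. The sketch below recovers a and c from one dimension observed at two batch sizes. It only illustrates the shape expression, not how Ariadne derives it; two observations always fit a line, so a real check would confirm the fit at a third batch size. `fit_affine_dim` and `format_affine` are hypothetical.

```python
from fractions import Fraction


def fit_affine_dim(b1, d1, b2, d2):
    """Fit d = a*B + c from a dimension observed at two distinct batch sizes.

    Returns (a, c) as ints, or None if the relation is not integer-affine.
    """
    if b1 == b2:
        raise ValueError("need two distinct batch sizes")
    a = Fraction(d2 - d1, b2 - b1)
    c = d1 - a * b1
    if a.denominator != 1 or c.denominator != 1:
        return None
    return int(a), int(c)


def format_affine(a, c, symbol="B"):
    """Render (a, c) as a shape expression such as '4*B' or 'B + 1'."""
    if a == 0:
        return str(c)
    term = symbol if a == 1 else f"{a}*{symbol}"
    if c == 0:
        return term
    return f"{term} + {c}" if c > 0 else f"{term} - {-c}"


print(format_affine(*fit_affine_dim(2, 8, 3, 12)))     # 4*B
print(format_affine(*fit_affine_dim(2, 256, 3, 256)))  # 256
print(fit_affine_dim(2, 5, 4, 6))                      # None
```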
For real YOLO and RF-DETR smoke tests, Ariadne uses batch_gt1 mode and verifies
cross-batch split replay plus split training on batch sizes 2 and 3.
At runtime, Ariadne materializes B from the actual input batch size, validates
that it is inside SplitSpec.dynamic_batch, and checks that non-batch dimensions
match the prepared boundary schema. The concrete batch size is intentionally not
part of RuntimeCacheKey.
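The runtime checks described above can be sketched as follows. `CacheKey` is a hypothetical stand-in for RuntimeCacheKey whose fields are invented for illustration; the point it shows is that two calls with different batch sizes reuse the same key because the key carries no batch field.

```python
from typing import NamedTuple


class CacheKey(NamedTuple):
    """Hypothetical runtime cache key; deliberately has no batch-size field."""
    model_id: str
    split_id: str
    mode: str


def runtime_batch(input_shape, symbolic_schema, dynamic_batch, batch_symbol="B"):
    """Materialize B from an input shape and validate it against the schema."""
    lo, hi = dynamic_batch
    batch = input_shape[0]
    if not lo <= batch <= hi:
        raise ValueError(f"batch {batch} outside dynamic_batch {dynamic_batch}")
    # Non-batch dimensions must match the prepared schema exactly.
    expected = tuple(batch if dim == batch_symbol else dim for dim in symbolic_schema)
    if tuple(input_shape) != expected:
        raise ValueError(f"shape {tuple(input_shape)} does not match {symbolic_schema}")
    return batch


schema = ("B", 256, 14, 14)
key = CacheKey("resnet50", "after:layer3", "generated_eager")
print(runtime_batch((2, 256, 14, 14), schema, (2, 64)))   # 2
print(runtime_batch((48, 256, 14, 14), schema, (2, 64)))  # 48, same `key` reused
```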
Dynamic batch support covers tensor shapes, affine batch-derived dimensions, and
batch-derived Python integer arguments. It also supports whole-sequence
consumption of batch-polymorphic split, chunk, and unbind results. When a
validated split frontier must transport such a sequence, Ariadne uses
BoundaryPayload's structured value tree so the suffix can rebuild the Python
structure before replay or split training. If dynamic Python control flow
changes the traced op count,
indexes or reorders an unsupported dynamic sequence pattern, or captures an
unserializable external object, Ariadne marks that candidate unsafe. auto and
percentage splits skip unsafe candidates; an explicit boundary reports the
rejection reason during preparation instead of failing later inside replay.
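The difference between automatic and explicit selection can be sketched as a small resolver. Here "auto" simply means the first safe candidate in list order, which is an assumption; the README does not describe Ariadne's ranking. `CheckedCandidate` and `resolve_boundary` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CheckedCandidate:
    """Hypothetical result of validating one split frontier."""
    name: str
    safe: bool
    reason: str = ""


def resolve_boundary(checked, boundary):
    """'auto' skips unsafe candidates; a named boundary fails loudly."""
    if boundary == "auto":
        for cand in checked:
            if cand.safe:
                return cand
        raise ValueError("no safe split candidate found")
    for cand in checked:
        if cand.name == boundary:
            if not cand.safe:
                # Surface the rejection reason at preparation time.
                raise ValueError(f"boundary {boundary!r} rejected: {cand.reason}")
            return cand
    raise KeyError(f"unknown boundary {boundary!r}")


checked = [
    CheckedCandidate("after:stem", False, "dynamic control flow changed the traced op count"),
    CheckedCandidate("after:layer3", True),
]
print(resolve_boundary(checked, "auto").name)  # after:layer3
```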
For smoke checks outside the test suite, use validate_dynamic_batches() to run
prepared prefix/suffix replay across selected batch sizes:
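from ariadne import validate_dynamic_batches
# x: the example input used with prepare_split(). Every batch size listed here
# must lie inside SplitSpec.dynamic_batch (batch size 1 needs a range that includes 1).
validate_dynamic_batches(runtime, example_inputs=(x,), batch_sizes=(1, 4, 20))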
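Conceptually, a cross-batch smoke check compares the unsplit output with prefix-then-suffix replay at each batch size. The generic helper below is hypothetical and uses plain lists so it runs without PyTorch; with tensors, `close` would typically be `torch.allclose`. It does not describe what validate_dynamic_batches checks internally.

```python
def check_split_replay(full, prefix, suffix, make_input, batch_sizes, close):
    """Return the batch sizes where suffix(prefix(x)) disagrees with full(x)."""
    mismatches = []
    for batch in batch_sizes:
        x = make_input(batch)
        if not close(full(x), suffix(prefix(x))):
            mismatches.append(batch)
    return mismatches


# Toy list-based "model" so the sketch runs without PyTorch.
def full(xs):
    return [2 * v + 1 for v in xs]


def prefix(xs):
    return [2 * v for v in xs]


def good_suffix(hs):
    return [h + 1 for h in hs]


def bad_suffix(hs):
    return [hs[0] + 1]  # silently assumes batch size 1


def make_input(batch):
    return list(range(batch))


def same(a, b):
    return a == b


print(check_split_replay(full, prefix, good_suffix, make_input, (1, 4, 20), same))  # []
print(check_split_replay(full, prefix, bad_suffix, make_input, (1, 4, 20), same))   # [4, 20]
```

The `bad_suffix` case shows why several batch sizes are worth testing: a frontier that only works for the traced batch size passes at batch 1 and fails elsewhere.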
Execution Modes
Ariadne lets users choose between non-compiled generated execution and
torch.compile optimized execution at preparation time:
- debug_interpreter: slow interpreter execution for validation and debugging.
- generated_eager: default generated prefix/suffix segment execution. Use this for short-lived scripts, CPU-only environments, correctness checks, or when the expected number of calls is too small to pay back the compile cost.
- compiled: applies torch.compile to generated segments after preparation. Use this for long-lived GPU services or repeated same-model, same-shape workloads where startup warmup can happen outside the online request path.
Split retain training uses prepare_split(...) in either mode:
runtime = prepare_split(
model,
example_inputs=(x,),
split=spec,
mode="generated_eager", # or "compiled"
)
Inference-only split replay can use prepare_split_replay(...). Compiled replay
uses segment-level compilation, so run_prefix() still returns a lightweight
ReplayBoundary and run_suffix(boundary) consumes that explicit intermediate
feature object:
from ariadne import prepare_split_replay
replay_runtime = prepare_split_replay(
model,
example_inputs=(x,),
split=spec,
mode="compiled", # or "generated_eager"
)
replay_runtime.warmup(x) # trigger compile before measuring/serving
boundary = replay_runtime.run_prefix(x)
output = replay_runtime.run_suffix(boundary)
The default compiled replay options target low-overhead GPU inference with Inductor. Custom options can still be passed when a deployment needs a different backend or stricter control:
runtime = prepare_split(
model,
example_inputs=(x,),
split=spec,
mode="compiled",
compile_options={"backend": "inductor", "mode": "reduce-overhead", "dynamic": True},
)
Choosing Eager or Compiled
torch.compile changes where time is spent: the steady-state calls can be
faster, but the first compiled call pays a cold-start cost. Ariadne benchmarks
therefore report both steady-state latency and compile overhead.
- Choose generated_eager when the process handles only a few batches, when startup latency matters more than steady-state throughput, or when the target CPU/GPU toolchain does not compile reliably.
- Choose compiled when the runtime is reused for many calls, especially on CUDA GPUs. For replay runtimes, call warmup(...) during service startup; for split retain, run one representative training round before measuring or serving latency-sensitive traffic.
- Benchmark the target machine before setting a global default. CPU torch.compile may work in some environments, but it is not automatically faster and may require a working native compiler stack.
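The break-even point can be estimated by hand: compilation pays back once n * eager_ms >= overhead_ms + n * compiled_ms, that is, n >= overhead_ms / (eager_ms - compiled_ms). The numbers in the usage lines are made up for illustration, not Ariadne measurements, and `break_even_calls` is a hypothetical helper.

```python
import math


def break_even_calls(compile_overhead_ms, eager_call_ms, compiled_call_ms):
    """Smallest n with n * eager >= overhead + n * compiled.

    Returns None when compiled calls are not faster, since compilation
    never pays back in that case.
    """
    saving = eager_call_ms - compiled_call_ms
    if saving <= 0:
        return None
    return math.ceil(compile_overhead_ms / saving)


print(break_even_calls(30_000, 12, 8))  # 7500
print(break_even_calls(30_000, 8, 9))   # None
```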
Measure steady-state replay optimization:
uv run --extra integration python examples/replay_optimization_timing.py \
--models resnet50 mobilenet --batches 4 32 256 \
--iterations 50 --warmup 20 --backend torch_compile
Measure compile overhead and break-even iterations:
uv run --extra integration python examples/compile_overhead_timing.py \
--models resnet50 mobilenet --modes replay retain --batches 4 32 256 \
--iterations 20 --warmup 5 --require-cuda
Benchmarking
ariadne.benchmark includes utilities for measuring direct PyTorch forward,
debug/generated/compiled prefix+suffix execution, suffix-only replay, suffix
training, and prefix backward. Results include average latency, total latency,
batch size, split id, execution mode, and optional CUDA peak memory.
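A minimal timing loop in the same spirit (warmup calls excluded, then average and total latency) looks like the sketch below. `time_calls` is hypothetical and is not part of ariadne.benchmark. On CUDA, kernels run asynchronously, so a synchronization callable such as `torch.cuda.synchronize` should be passed as `sync` to make the clock include the queued GPU work.

```python
import time


def time_calls(fn, iterations, warmup=0, sync=None):
    """Run fn `warmup` times untimed, then time `iterations` calls.

    sync: optional callable invoked before each clock read so queued
    asynchronous work (e.g. CUDA kernels) is counted.
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    if sync is not None:
        sync()
    total = time.perf_counter() - start
    return {"iterations": iterations, "total_s": total, "avg_s": total / iterations}


stats = time_calls(lambda: sum(range(1000)), iterations=50, warmup=5)
print(stats["avg_s"] >= 0.0)  # True
```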
Current Limitations
- The default tracer uses TorchDispatchMode runtime interception and records the observed forward path.
- Alias, mutation, RNG, FLOP, and memory metadata are intentionally lightweight.
- Segment generation supports verified structured boundaries from prepared observed paths, including a batch>1 structural variant when batch=1 tracing needs it.
- Whole dynamic split/chunk/unbind sequences can replay inside a segment or cross a split boundary through the structured payload. Unsupported element-wise dynamic sequence use is rejected at preparation time.
- Dynamic non-batch dimensions are reserved for future SplitSpec extensions.
- Shape expressions cover direct and affine batch-derived dimensions; more complex non-affine shape arithmetic is still limited.
Download files
File details
Details for the file ariadne_split-0.1.2.tar.gz.
File metadata
- Download URL: ariadne_split-0.1.2.tar.gz
- Upload date:
- Size: 264.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fc89855020d012c281ebbd408a5e9c3f4b60b9948377a350602679628b8c5d04 |
| MD5 | bc0d6533c89202ebc1f8123bea3e9a3b |
| BLAKE2b-256 | 778866b2ad2b50db386b3168e2fd3a1176b7da3a51e9608e604975e75d6757bb |
File details
Details for the file ariadne_split-0.1.2-py3-none-any.whl.
File metadata
- Download URL: ariadne_split-0.1.2-py3-none-any.whl
- Upload date:
- Size: 73.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a552dc0792c35c14d249b2e9dee7227dc3e10472fdc375284f9c1871954f2c47 |
| MD5 | 85de71ab0fb974c7638dc88da11dc520 |
| BLAKE2b-256 | 554ad36499de1e087cca0b972f20f170f71c7af5e158d398877214990e03ca37 |