Dynamic-batch split replay and split training runtime for PyTorch models.
Ariadne
Ariadne is a dynamic-batch split replay runtime for PyTorch models. It traces real PyTorch forward execution into a symbolic split plan, generates efficient prefix/suffix segment callables, and supports boundary-based replay and split training.
Ariadne uses TorchLens-style tracing for broad PyTorch model compatibility, but it does not use node-by-node interpreted replay as the default runtime. Instead, it lowers traced graphs into generated prefix/suffix segment callables.
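The sketch below illustrates the difference between node-by-node interpreted replay and a generated segment callable, using plain Python instead of tensors. The node format (output name, function name, input names) and the helpers interpret and generate_segment are hypothetical. They are not Ariadne's internal representation. The generator lowers the node list once into straight-line source, so no per-node lookup happens at call time.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical node format: (output_name, function_name, input_names).
Node = Tuple[str, str, Tuple[str, ...]]

# A tiny op library standing in for traced operations.
OPS: Dict[str, Callable[..., float]] = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "tanh": math.tanh,
}


def interpret(nodes: List[Node], inputs: Dict[str, float], output: str) -> float:
    """Interpreted replay: look up every op and operand on every call."""
    env = dict(inputs)
    for out, fn, args in nodes:
        env[out] = OPS[fn](*(env[a] for a in args))
    return env[output]


def generate_segment(nodes: List[Node], input_names: List[str], output: str) -> Callable[..., float]:
    """Lower the node list once into straight-line Python source and compile it."""
    lines = [f"def segment({', '.join(input_names)}):"]
    for out, fn, args in nodes:
        lines.append(f"    {out} = {fn}({', '.join(args)})")
    lines.append(f"    return {output}")
    namespace: Dict[str, object] = dict(OPS)  # op names become globals of the generated function
    exec("\n".join(lines), namespace)
    return namespace["segment"]  # type: ignore[return-value]


nodes = [("h", "mul", ("x", "w")), ("z", "add", ("h", "b")), ("y", "tanh", ("z",))]
segment = generate_segment(nodes, ["x", "w", "b"], "y")
print(segment(2.0, 3.0, -6.0), interpret(nodes, {"x": 2.0, "w": 3.0, "b": -6.0}, "y"))
```

Both paths compute the same value; the generated callable simply moves the graph walk from run time to preparation time.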
Why Ariadne Exists
Split inference and split training need a clean boundary between offline model preparation and lightweight runtime execution. Ariadne prepares a symbolic TracePlan, validates a declarative SplitSpec, chooses a valid frontier, and generates executable segments so runtime work stays simple:
- run the prefix segment
- package a boundary payload
- run the suffix segment
- optionally train the suffix and backpropagate boundary gradients into the prefix
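The first three runtime steps can be illustrated with plain PyTorch, independent of Ariadne's API: a model is cut by hand into a prefix and a suffix, the prefix output is packaged into a boundary payload, and the suffix consumes it. The dict payload layout and the run_prefix/run_suffix helpers here are assumptions for illustration, not Ariadne's boundary format.

```python
import torch
from torch import nn

torch.manual_seed(0)

# A toy model cut by hand at an assumed split point.
model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 3))
prefix = model[:3]  # layers before the boundary
suffix = model[3:]  # layers after the boundary


def run_prefix(x: torch.Tensor) -> dict:
    # Package the boundary tensor together with minimal shape metadata.
    h = prefix(x)
    return {"tensor": h, "batch": h.shape[0], "shape": tuple(h.shape[1:])}


def run_suffix(payload: dict) -> torch.Tensor:
    # Check the non-batch dimensions before consuming the payload.
    if tuple(payload["tensor"].shape[1:]) != payload["shape"]:
        raise ValueError("boundary tensor does not match its declared shape")
    return suffix(payload["tensor"])


x = torch.randn(8, 5)
with torch.no_grad():
    split_out = run_suffix(run_prefix(x))
    full_out = model(x)
print(torch.allclose(split_out, full_out))  # True
```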
Design Principles
- Trace observed PyTorch forward behavior while keeping default metadata light.
- Treat the first dimension of the first input tensor as the symbolic batch dimension B.
- Keep concrete batch size out of trace and runtime cache keys.
- Make split declarations explicit and verifiable.
- Use generated eager segments as the default execution path.
- Keep node-by-node interpretation only for debug_interpreter validation.
- Use torch.compile only as an optional optimizer for generated segments.
Installation With uv
uv sync
Dependencies are declared only in pyproject.toml, and uv.lock is committed
for reproducible installs.
Development
uv sync
uv run pytest
uv run ruff check .
uv run mypy src
Useful demo commands:
uv run python examples/split_inference_demo.py
uv run python examples/split_training_demo.py
Optional real-model smoke checks install the integration extra and may download
YOLO weights. The current integration suite covers YOLOv8n, RF-DETR Nano,
timm resnet50, timm swin_tiny_patch4_window7_224, torchvision mobilenet_v3_large,
and torchvision deeplabv3_resnet50:
uv run --extra integration python examples/real_model_functional_test.py
uv run --extra integration python examples/real_model_timing.py --iterations 3 --warmup 1
ARIADNE_RUN_REAL_MODELS=1 uv run --extra integration pytest tests/integration -m integration
Basic Split Inference
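import torch
from torch import nn
from ariadne import SplitSpec, prepare_split

# Hypothetical toy model for illustration; assumes "after:layer3" refers to
# the output of the submodule named layer3.
class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Linear(5, 16)
        self.layer2 = nn.ReLU()
        self.layer3 = nn.Linear(16, 16)
        self.head = nn.Linear(16, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.layer3(self.layer2(self.layer1(x))))

model = ToyModel()
spec = SplitSpec(
    boundary="after:layer3",
    batch_symbol="B",
    dynamic_batch=(2, 64),  # batch sizes accepted at runtime
    trainable=True,
    trace_batch_mode="batch_gt1",  # the example batch (4) must be > 1
)
runtime = prepare_split(
    model,
    example_inputs=(torch.randn(4, 5),),  # traced with batch size 4
    split=spec,
    mode="generated_eager",
)
# Replay with a batch size (8) that differs from the traced one.
boundary = runtime.run_prefix(torch.randn(8, 5))
output = runtime.run_suffix(boundary)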
Basic Split Training
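import torch.nn.functional as F

# Continues the split inference example. x_batch is an input batch whose size
# lies in spec.dynamic_batch, targets are the matching labels, and
# suffix_optimizer / prefix_optimizer are torch.optim optimizers over the
# suffix and prefix parameters.
boundary = runtime.run_prefix(x_batch)
# Train the suffix on the boundary payload; returns the loss and the
# gradients with respect to the boundary.
loss, boundary_grads = runtime.train_suffix(
    boundary,
    targets,
    loss_fn=F.mse_loss,
    optimizer=suffix_optimizer,
)
# Backpropagate the boundary gradients through the prefix and update it
# with prefix_optimizer.
runtime.backward_prefix(
    x_batch,
    boundary_grads=boundary_grads,
    optimizer=prefix_optimizer,
)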
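The gradient handoff between suffix training and prefix backward can be sketched in plain PyTorch. This is not Ariadne's implementation: it assumes the boundary tensor is detached before it crosses to the suffix, and the toy modules and SGD optimizers are illustrative. The final check compares the split gradients with an ordinary end-to-end backward pass.

```python
import copy

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Toy prefix/suffix modules and plain SGD optimizers (illustrative choices).
prefix = nn.Sequential(nn.Linear(5, 16), nn.ReLU())
suffix = nn.Sequential(nn.Linear(16, 3))
# Untouched copies, kept only for the sanity check at the bottom.
ref_prefix, ref_suffix = copy.deepcopy(prefix), copy.deepcopy(suffix)
prefix_optimizer = torch.optim.SGD(prefix.parameters(), lr=0.1)
suffix_optimizer = torch.optim.SGD(suffix.parameters(), lr=0.1)

x_batch = torch.randn(8, 5)
targets = torch.randn(8, 3)

# Prefix side: run forward and keep the autograd graph for later.
prefix_out = prefix(x_batch)

# Boundary: the suffix receives a detached leaf tensor that collects its own gradient.
boundary = prefix_out.detach().requires_grad_(True)

# Suffix side: compute the loss, update the suffix, hand back d(loss)/d(boundary).
suffix_optimizer.zero_grad()
loss = F.mse_loss(suffix(boundary), targets)
loss.backward()
suffix_optimizer.step()
boundary_grads = boundary.grad

# Prefix side: resume backpropagation from the received boundary gradients.
prefix_optimizer.zero_grad()
prefix_out.backward(boundary_grads)
prefix_optimizer.step()

# Sanity check: prefix gradients match an unsplit backward pass on the original weights.
ref_loss = F.mse_loss(ref_suffix(ref_prefix(x_batch)), targets)
ref_loss.backward()
print(torch.allclose(prefix[0].weight.grad, ref_prefix[0].weight.grad))
```

Detaching at the boundary is what lets the two sides run in separate processes or devices: only the boundary tensor travels forward and only its gradient travels back.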
Dynamic Batch
By default, Ariadne treats the first dimension of the first tensor input as the
symbolic batch dimension B. If an example trace uses batch size 4, tensors like
(4, 256, 14, 14) are recorded as ("B", 256, 14, 14).
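The recording rule above can be written as a pair of small functions. symbolize_shape and materialize_shape are hypothetical helpers, not part of Ariadne's API, and they only cover the direct case where a leading dimension equals the traced batch size.

```python
from typing import Tuple, Union

Dim = Union[int, str]


def symbolize_shape(shape: Tuple[int, ...], traced_batch: int, symbol: str = "B") -> Tuple[Dim, ...]:
    """Replace a leading dimension equal to the traced batch size with the batch symbol."""
    if shape and shape[0] == traced_batch:
        return (symbol,) + tuple(shape[1:])
    return tuple(shape)


def materialize_shape(sym_shape: Tuple[Dim, ...], batch: int, symbol: str = "B") -> Tuple[int, ...]:
    """Substitute a concrete batch size back into a symbolic shape."""
    return tuple(batch if d == symbol else int(d) for d in sym_shape)


print(symbolize_shape((4, 256, 14, 14), traced_batch=4))  # ('B', 256, 14, 14)
print(materialize_shape(("B", 256, 14, 14), batch=8))  # (8, 256, 14, 14)
```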
SplitSpec.trace_batch_mode makes the preparation strategy explicit:
- batch_1: requires example_inputs with batch size 1 and a dynamic_batch range that includes 1. Ariadne performs prepare-time provenance probing and can prepare a batch>1 structural variant when the singleton and non-singleton aten paths differ.
- batch_gt1: requires example_inputs with batch size greater than 1 and a dynamic_batch range that starts at 2 or greater. Ariadne stays in the non-singleton regime, derives affine batch shapes such as 4*B, and avoids the extra singleton structural variant used by batch_1.
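One way to see why a dimension such as 4*B counts as affine: if the same dimension is observed at two distinct batch sizes, the coefficient and offset of d = a*B + c can be solved exactly. The hypothetical helpers fit_affine_dim and format_dim below show this two-point fit in plain Python. It is an illustration of affine batch shapes, not a description of how Ariadne derives its shape expressions.

```python
from fractions import Fraction
from typing import Optional, Tuple


def fit_affine_dim(b1: int, d1: int, b2: int, d2: int) -> Optional[Tuple[int, int]]:
    """Fit d = a*B + c from two observations (B=b1 -> d1, B=b2 -> d2).

    Returns integer (a, c), or None when no integer affine fit exists.
    """
    if b1 == b2:
        raise ValueError("need two distinct batch sizes")
    a = Fraction(d2 - d1, b2 - b1)
    c = d1 - a * b1
    if a.denominator != 1 or c.denominator != 1:
        return None
    return int(a), int(c)


def format_dim(a: int, c: int, symbol: str = "B") -> str:
    """Render a*B + c compactly, e.g. 'B', '4*B', '4*B+1', or a plain constant."""
    if a == 0:
        return str(c)
    term = symbol if a == 1 else f"{a}*{symbol}"
    return term if c == 0 else f"{term}{c:+d}"


# A reshape that folds a size-4 axis into the batch, observed at B=2 and B=3.
print(format_dim(*fit_affine_dim(2, 8, 3, 12)))  # 4*B
# A batch-independent dimension.
print(format_dim(*fit_affine_dim(2, 256, 3, 256)))  # 256
```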
For real YOLO and RF-DETR smoke tests, Ariadne uses batch_gt1 mode and verifies
cross-batch split replay plus split training on batch sizes 2 and 3.
At runtime, Ariadne materializes B from the actual input batch size, validates
that it is inside SplitSpec.dynamic_batch, and checks that non-batch dimensions
match the prepared boundary schema. The concrete batch size is intentionally not
part of RuntimeCacheKey.
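The runtime check described above can be sketched as follows. BoundarySchema, CacheKey, and materialize_batch are hypothetical names, and the fields of the real RuntimeCacheKey are not documented here. The sketch also assumes dynamic_batch is an inclusive (low, high) range. The point it shows is that the cache key depends on the split and mode, while the batch size is only validated.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class BoundarySchema:
    # Non-batch dimensions recorded at prepare time, e.g. (256, 14, 14).
    feature_shape: Tuple[int, ...]


@dataclass(frozen=True)
class CacheKey:
    # Deliberately has no batch-size field, so every B in range maps to one entry.
    split_id: str
    mode: str
    dtype: str


def materialize_batch(shape: Tuple[int, ...], schema: BoundarySchema, dynamic_batch: Tuple[int, int]) -> int:
    """Bind B from a concrete tensor shape and check it against the prepared schema."""
    batch, rest = shape[0], tuple(shape[1:])
    low, high = dynamic_batch  # treated as an inclusive range (assumption)
    if not low <= batch <= high:
        raise ValueError(f"batch size {batch} is outside dynamic_batch {dynamic_batch}")
    if rest != schema.feature_shape:
        raise ValueError(f"non-batch dims {rest} do not match schema {schema.feature_shape}")
    return batch


schema = BoundarySchema(feature_shape=(256, 14, 14))
cache: Dict[CacheKey, str] = {CacheKey("after:layer3", "generated_eager", "float32"): "prepared segments"}

for shape in [(2, 256, 14, 14), (64, 256, 14, 14)]:
    b = materialize_batch(shape, schema, dynamic_batch=(2, 64))
    key = CacheKey("after:layer3", "generated_eager", "float32")
    print(b, cache[key])  # both batch sizes hit the same cache entry
```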
Execution Modes
- debug_interpreter: slow interpreter execution for validation and debugging.
- generated_eager: default generated prefix/suffix segment execution.
- compiled: applies torch.compile to generated segments after preparation.
Example:
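# x and spec as in the split inference example above.
runtime = prepare_split(
    model,
    example_inputs=(x,),
    split=spec,
    mode="compiled",
    # torch.compile options, applied to the generated segments.
    compile_options={"backend": "inductor", "mode": "reduce-overhead", "dynamic": True},
)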
Benchmarking
ariadne.benchmark includes utilities for measuring direct PyTorch forward,
debug/generated/compiled prefix+suffix execution, suffix-only replay, suffix
training, and prefix backward. Results include average latency, total latency,
batch size, split id, execution mode, and optional CUDA peak memory.
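A minimal timing loop that reports these kinds of numbers can be sketched with plain PyTorch. time_callable is a hypothetical helper and does not reproduce the ariadne.benchmark API. It synchronizes CUDA before reading the clock so that asynchronous kernels are counted, and it reads peak memory with torch.cuda.max_memory_allocated only when CUDA is available.

```python
import time
from typing import Callable, Dict, Optional

import torch


def time_callable(fn: Callable[[], object], iterations: int = 10, warmup: int = 2) -> Dict[str, Optional[float]]:
    """Return average/total latency in milliseconds and optional CUDA peak memory in bytes."""
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):  # untimed warmup runs
        fn()
    if use_cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    total_ms = (time.perf_counter() - start) * 1000.0
    return {
        "avg_ms": total_ms / iterations,
        "total_ms": total_ms,
        "cuda_peak_bytes": float(torch.cuda.max_memory_allocated()) if use_cuda else None,
    }


model = torch.nn.Linear(5, 3)
x = torch.randn(8, 5)
with torch.no_grad():  # forward-only timing; grad mode is left to the caller
    stats = time_callable(lambda: model(x), iterations=5, warmup=1)
print(stats)
```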
Current Limitations
- The default tracer uses TorchDispatchMode runtime interception and records the observed forward path.
- Alias, mutation, RNG, FLOP, and memory metadata are intentionally lightweight.
- Segment generation currently supports tensor boundaries from prepared observed paths, including a batch>1 structural variant when batch=1 tracing needs it.
- Dynamic non-batch dimensions are reserved for future SplitSpec extensions.
- Shape expressions cover direct and affine batch-derived dimensions; more complex non-affine shape arithmetic is still limited.
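The observed-path tracing mentioned in the first limitation can be illustrated with PyTorch's TorchDispatchMode, which intercepts aten operator calls while the mode is active. The recorder below is a minimal sketch that only logs operator names; it is not Ariadne's tracer. TorchDispatchMode lives in the private module torch.utils._python_dispatch, so its import path may change between PyTorch releases.

```python
import torch
from torch import nn
from torch.utils._python_dispatch import TorchDispatchMode


class OpRecorder(TorchDispatchMode):
    """Record the name of each aten op executed while the mode is active."""

    def __init__(self) -> None:
        super().__init__()
        self.ops: list[str] = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))  # e.g. "aten.relu.default"
        return func(*args, **(kwargs or {}))


model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 3))
x = torch.randn(4, 5)  # created outside the mode so factory ops are not logged
recorder = OpRecorder()
with torch.no_grad(), recorder:
    model(x)

# Only ops on the path actually executed are recorded; untaken branches never appear.
print(recorder.ops)
```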
Download files
File details
Details for the file ariadne_split-0.1.0.tar.gz.
File metadata
- Download URL: ariadne_split-0.1.0.tar.gz
- Size: 213.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.10.10 (Ubuntu 24.04)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 57290863d31b790cbeaad8aee1264437d4f13a21f0c6dc48e2063677e501afb4 |
| MD5 | 45925bf06b26525daa0c49161dd8d766 |
| BLAKE2b-256 | 3ff8c86c7e2496b42d538ee1b6b2b596bc26295e04ec5650a92a30ad102421d1 |
File details
Details for the file ariadne_split-0.1.0-py3-none-any.whl.
File metadata
- Download URL: ariadne_split-0.1.0-py3-none-any.whl
- Size: 36.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.10.10 (Ubuntu 24.04)
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 72c8318e7c6bdb53cb4af2882763a8777c15f57da3c0a41095a412ce360ed3b8 |
| MD5 | df6afbcb6dff2c7ccb6c29a73ac69e22 |
| BLAKE2b-256 | 39c0a74e4bf1ef7a1fd71ce5ecd9cb8fdb626eb930a3d5a3beb889c605db8ea7 |