TorchSlicer
Split learning framework for PyTorch — partition a model across multiple devices or machines and run pipeline-parallel training end-to-end, without any single node ever holding the full model.
import torchslicer as ts
sliced = ts.slice(model, n=4)
sliced.train(loader, optimizer_cfg, criterion_cfg, epochs=10)
What it does
TorchSlicer splits an nn.Module into sequential partitions, assigns each to a device or remote worker, and runs forward and backward passes across them in a pipeline. It handles:
- Gradient flow across partition boundaries
- Mixed-precision training (bfloat16 autocast)
- GPipe micro-batch pipelining
- Fault tolerance with automatic recovery from worker checkpoints
- Run logging to structured JSONL for offline analysis
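Conceptually, slicing starts by cutting the model's layer sequence into n contiguous partitions. A minimal sketch of that partitioning step in plain Python — illustrative only, not TorchSlicer's actual implementation:

```python
def partition(layers, n):
    """Split a sequence of layers into n contiguous, near-equal chunks."""
    k, r = divmod(len(layers), n)
    parts, start = [], 0
    for i in range(n):
        size = k + (1 if i < r else 0)  # first r chunks take one extra layer
        parts.append(layers[start:start + size])
        start += size
    return parts

# e.g. 10 layers over 4 partitions -> chunk sizes [3, 3, 2, 2]
print([len(p) for p in partition(list(range(10)), 4)])
```

Each chunk then becomes one pipeline stage, so activations only cross a boundary at the chunk edges.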
Installation
# Core (local training, no networking)
pip install torchslicer
# With gRPC distributed support
pip install 'torchslicer[grpc]'
# All extras (gRPC + OpenTelemetry monitoring + LoRA/PEFT)
pip install 'torchslicer[grpc,monitor,peft]'
Requires Python ≥ 3.10 and PyTorch (not declared as a dependency — install the build that matches your hardware).
Execution modes
| Mode | Use case |
|---|---|
| Local | Single process, all partitions on one machine. Fast iteration. |
| Centralized | Coordinator + gRPC workers. Each worker runs on a separate host/GPU. |
| P2P | No coordinator. The driver discovers peers directly. |
Transport layer is configurable: gRPC (default) or raw TCP for lower serialization overhead (~2.4× faster on plain split, ~1.6× faster with GPipe).
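With GPipe-style pipelining, each mini-batch is split into micro-batches that flow through the stages in a staggered schedule: stage i works on micro-batch j while stage i+1 works on micro-batch j−1, keeping all stages busy. A toy forward-pass schedule showing the stagger (illustrative only, not TorchSlicer's scheduler):

```python
def forward_schedule(n_stages, n_micro):
    """Per clock tick, list the (stage, micro_batch) pairs active in the forward pass."""
    ticks = []
    for t in range(n_stages + n_micro - 1):
        active = [(s, t - s) for s in range(n_stages) if 0 <= t - s < n_micro]
        ticks.append(active)
    return ticks

# 3 stages, 4 micro-batches: after a 2-tick warm-up, all stages run in parallel
for t, active in enumerate(forward_schedule(3, 4)):
    print(t, active)
```

The pipeline "bubble" is visible at the ends of the schedule: only the warm-up and drain ticks leave stages idle.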
Quick start — local
import torchslicer as ts
import torchvision
model = torchvision.models.resnet18()
sliced = ts.slice(model, n=2)
optimizer_cfg = {"name": "SGD", "params": {"lr": 0.01, "momentum": 0.9}}
criterion_cfg = {"name": "CrossEntropyLoss", "params": {}}
sliced.train(train_loader, optimizer_cfg, criterion_cfg, epochs=5, verbose=True)
Quick start — distributed (coordinator + workers)
# coordinator
from torchslicer import RunConfig
from torchslicer.executors.distributed import DistributedExecutor
from torchslicer.discovery import CoordinatorDiscovery
import torchslicer as ts
cfg = RunConfig.load("experiments/resnet18_4gpu.yaml")
executor = DistributedExecutor(
    discovery=CoordinatorDiscovery(run_id=cfg.run_id),
    coordinator_addr="0.0.0.0:50054",
    run_config=cfg,
)
sliced = ts.slice(model, n=cfg.discovery.n_workers, executor=executor)
sliced.train(loader, run_config=cfg)
Workers are started separately — see examples/train/ for full entrypoints and the Docker Compose setup in the repository.
Custom architectures
For models torch.fx cannot trace (HuggingFace LLMs, MoE, etc.), supply a pack function:
def pack_qwen(model):
    return [
        ts.SimpleEmbedStage(model.model.embed_tokens),
        *[ts.BlockStage(layer) for layer in model.model.layers],
        ts.CausalLMHeadStage(model.model.norm, model.lm_head),
    ]
sliced = ts.slice(model, n=4, pack=pack_qwen)
Built-in stage types: BlockStage, GPT2EmbedStage, SimpleEmbedStage, CausalLMHeadStage, AuxInputStage, MoEBlockStage.
Configuration
All parameters are controlled via RunConfig with a 4-level priority: Python kwargs > YAML file > environment variables > defaults.
from torchslicer import RunConfig
cfg = RunConfig.load("experiments/resnet18_4gpu.yaml") # YAML + env overrides
cfg.to_yaml("runs/my_run/resolved_config.yaml") # save for reproducibility
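The 4-level priority amounts to a layered dict merge in which later layers win. A minimal sketch of that resolution order (plain Python, not RunConfig's actual code; the TORCHSLICER_* env-var naming here is an assumption):

```python
import os

def resolve(defaults, env_prefix, yaml_cfg, kwargs):
    """Merge settings: kwargs > YAML file > environment variables > defaults."""
    cfg = dict(defaults)
    for key in defaults:  # e.g. TORCHSLICER_LR overrides the default lr
        val = os.environ.get(f"{env_prefix}{key.upper()}")
        if val is not None:
            cfg[key] = val
    cfg.update(yaml_cfg)   # YAML file overrides env vars
    cfg.update(kwargs)     # explicit Python kwargs win over everything
    return cfg

os.environ["TORCHSLICER_LR"] = "0.1"
print(resolve({"lr": "0.01", "epochs": "5"}, "TORCHSLICER_",
              {"epochs": "10"}, {"lr": "0.001"}))
# lr comes from kwargs, epochs from YAML
```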
Monitoring
When torchslicer[monitor] is installed, runs emit OpenTelemetry traces. Point the exporter at any OTLP-compatible backend by setting:
OTEL_EXPORTER_OTLP_ENDPOINT=http://your-backend:4317
Compatible with Arize Phoenix, Jaeger, Grafana Tempo, Honeycomb, and any other OTLP collector. No library changes needed per backend.
Run metrics are also written to runs/<run_id>/metrics.jsonl for offline analysis.
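Because metrics.jsonl is one JSON object per line, offline analysis needs nothing beyond the standard library. A sketch using an in-memory stand-in for the file (the field names here are illustrative, not TorchSlicer's actual schema):

```python
import io
import json

# Stand-in for runs/<run_id>/metrics.jsonl: one record per training step
raw = io.StringIO(
    '{"step": 1, "loss": 2.31}\n'
    '{"step": 2, "loss": 1.87}\n'
)

records = [json.loads(line) for line in raw if line.strip()]
losses = [r["loss"] for r in records]
print(f"steps={len(records)} final_loss={losses[-1]}")
```

Swap the StringIO for `open("runs/<run_id>/metrics.jsonl")` to analyze a real run.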
License
GPL-3.0-or-later. See LICENSE.