SpectraX: a JAX-only neural-network library with a PyTorch-shaped eager surface and Flax-NNX-style graph/state underneath.
True MPMD pipeline parallelism for JAX — with an eager module API.
Quick Start | Installation | MPMD Runtime | Benchmark | Examples | Docs
SpectraX is a JAX-native neural network library built around true MPMD
pipeline parallelism. Each physical rank compiles and runs its own XLA
program — no shared shard_map HLO, no SPMD-same-shape constraint.
Heterogeneous stages (embed → blocks → head), multiple schedules
(GPipe, 1F1B, ZeroBubble, Interleaved, DualPipe), and a unified
spx.run() entry point that dispatches to SPMD or MPMD from the same
training script.
The module API is eager and debuggable — subclass Module, override
forward, call model(x) — but every Module is a JAX pytree, so
jax.jit, jax.grad, and jax.tree.map work directly.
Why SpectraX?
| Capability | SpectraX | JAX + manual | Other JAX frameworks |
|---|---|---|---|
| True MPMD | Built-in (sxcall, sxjit) | Hand-rolled | SPMD-only |
| Heterogeneous stages | Native (different class/shape per rank) | Fragile | Not supported |
| Pipeline schedules | 9 schedules (GPipe → DualPipeV) | Hand-rolled | Limited |
| Unified runtime | spx.run(mesh) → SPMD or MPMD | Separate code paths | Separate APIs |
| Eager modules | model(x) + pytree-native | Functional only | Functional or ref-tracking |
| Dispatch overhead | ~150 µs | N/A | ~300–2000 µs |
Dispatch overhead measured on a tiny 2-layer CPU transformer. See benchmark.
Quick Start
```bash
pip install spectrax-lib
```
Single-device eager training
```python
import jax.numpy as jnp
import spectrax as spx
from spectrax import nn

class MLP(spx.Module):
    def __init__(self, d, h, o, *, rngs):
        super().__init__()
        self.fc1 = nn.Linear(d, h, rngs=rngs)
        self.fc2 = nn.Linear(h, o, rngs=rngs)

    def forward(self, x):
        return self.fc2(nn.gelu(self.fc1(x)))

model = MLP(16, 64, 4, rngs=spx.Rngs(0))

@spx.jit
@spx.value_and_grad
def loss_fn(m, x, y):
    return ((m(x) - y) ** 2).mean()

loss, grads = loss_fn(model, jnp.ones((8, 16)), jnp.zeros((8, 4)))
```
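Because the model is a plain JAX pytree, the gradients above can be applied with stock JAX tree utilities. A minimal sketch of a manual SGD step, assuming the gradient pytree mirrors the model's leaves and that all leaves are arrays (an spx optimizer would normally handle this):

```python
import jax

# Hedged sketch: plain SGD via jax.tree.map, assuming `grads` has the same
# pytree structure as `model` and every leaf is an array.
lr = 1e-2
new_model = jax.tree.map(lambda p, g: p - lr * g, model, grads)

# The module is also introspectable with ordinary pytree utilities
# (again assuming array leaves).
n_params = sum(leaf.size for leaf in jax.tree.util.tree_leaves(model))
```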
Marker-based MPMD — split a function into per-rank programs
```python
from spectrax.runtime.mpmd import sxjit, sxstage_iter
from spectrax.pipeline import Std1F1B

# Define a multi-stage forward with explicit stage boundaries
@sxjit(schedule=Std1F1B(microbatches=8))
def pipeline_fwd(model, x):
    x = model.embed(x)       # stage 0
    x = sxstage_iter(x)      # boundary — split here
    x = model.blocks[0](x)   # stage 1
    x = sxstage_iter(x)      # boundary — split here
    x = model.blocks[1](x)   # stage 2
    x = sxstage_iter(x)      # boundary — split here
    return model.head(x)     # stage 3

# sxjit traces the function, splits the jaxpr at sxstage_iter markers,
# and compiles one XLA executable per stage/rank. Each rank only sees
# its own sub-graph — true MPMD, not SPMD.
output = pipeline_fwd(model, x)
```
MPMD pipeline training — one call, multiple devices
```python
from spectrax.sharding import logical_axis_rules

# Create a 4-stage pipeline mesh
mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1), mpmd_axis="pp")

with logical_axis_rules(FSDP_TP_RULES), mesh:
    # MPMD: each rank gets its own compiled stage
    loss, grads = spx.run(
        model, inputs=x, targets=y,
        mesh=mesh, mode="train", loss_fn=ce_loss, microbatches=8,
    )

# Drop mpmd_axis → same code runs under pure SPMD pjit.
# Same model, same script, different mesh — no code changes.
```
Deferred initialization — infer shapes at runtime
```python
model = nn.Sequential(
    nn.Linear(None, 256, rngs=rngs),  # in_features inferred from first call
    nn.ReLU(),
    nn.Linear(256, 10, rngs=rngs),
)
y = model(jnp.zeros((8, 128)))  # weight shapes resolved here
```
MPMD Runtime
SpectraX implements true MPMD: each physical rank compiles and executes its own distinct JAX program. This is not SPMD-with-barriers — stages can have different classes, different parameter shapes, and different I/O shapes.
spx.run — unified entry point
spx.run routes to SPMD (pjit) or MPMD (sxcall) based on the mesh:
```python
spx.run(
    model,
    inputs=x,        # microbatched along leading axis
    targets=y,
    mesh=mesh,       # SpxMesh — mpmd_axis decides the path
    mode="train",    # "forward" | "train"
    loss_fn=ce_loss,
    microbatches=8,
)
```
| Mesh type | What happens |
|---|---|
| mpmd_axis=None | Pure SPMD — pjit with FSDP/TP via logical axis rules |
| mpmd_axis="pp" | True MPMD — auto-split into stages, per-rank compilation |
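In other words, the execution path is selected purely by mesh construction. A sketch, reusing the axis_dims layout from the MPMD example above and assuming mpmd_axis may simply be omitted for the SPMD path:

```python
# Sketch only: mesh choice selects the execution path (per the table above).
# The axis_dims tuple is copied from the earlier example; its per-axis
# meaning is not spelled out here, so treat it as a placeholder.
spmd_mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1))                  # no mpmd_axis → pjit / SPMD
mpmd_mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1), mpmd_axis="pp")  # MPMD pipeline

loss, grads = spx.run(
    model, inputs=x, targets=y,
    mesh=mpmd_mesh, mode="train", loss_fn=ce_loss, microbatches=8,
)
```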
Lower-level primitives
For full control, drop below spx.run:
| Primitive | Purpose |
|---|---|
| sxcall | Execute a PipelineSequential under a schedule — Python dispatch loop over pre-built per-rank callables |
| sxjit | Decorator: trace a function, split at sxstage_iter markers, compile one XLA executable per stage/rank |
| sxgrad / sxvalue_and_grad | Schedule-faithful gradients of an sxjit function |
| treduce | Schedule-driven microbatch reduction primitive — binds a body + schedule into the traced jaxpr |
| sxstage_iter | Marker primitive for stage boundaries inside sxjit |
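As a rough sketch of how these compose, a staged scalar loss can be wrapped with sxvalue_and_grad. The import location and exact signature of sxvalue_and_grad are assumptions here, not confirmed by the table:

```python
# Hypothetical composition sketch; signatures and import paths are
# assumptions, not the documented API.
from spectrax.runtime.mpmd import sxjit, sxstage_iter, sxvalue_and_grad
from spectrax.pipeline import Std1F1B

@sxjit(schedule=Std1F1B(microbatches=8))
def pipeline_loss(model, x, y):
    h = model.embed(x)
    h = sxstage_iter(h)          # stage boundary
    h = model.blocks[0](h)
    h = sxstage_iter(h)          # stage boundary
    logits = model.head(h)
    return ((logits - y) ** 2).mean()

# Assumed: sxvalue_and_grad wraps the sxjit function the same way
# spx.value_and_grad wraps an eager function.
loss, grads = sxvalue_and_grad(pipeline_loss)(model, x, y)
```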
Supported schedules
| Schedule | Type | Key trait |
|---|---|---|
| GPipe | Flat | All forwards, then all backwards. Simple, high memory. |
| Std1F1B | Flat | Standard 1-forward-1-backward. Peak memory O(n_stages). |
| ZeroBubbleH1 | Flat | Splits BWD into input-grad + weight-grad; weight-grad fills bubble slots. |
| InterleavedH1 | Virtual | Each rank owns v non-contiguous stages. Bubble shrinks by v. |
| DualPipeV | Virtual | V-shaped bidirectional pipeline (DeepSeek-style). |
| KimiK2 | Virtual | Interleaved with extra warmup (Moonshot K2 design). |
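Switching schedules is intended to be a constructor swap at the sxjit or spx.run boundary. A minimal sketch, assuming the other flat schedules take the same microbatches argument as Std1F1B:

```python
# Assumption: GPipe and ZeroBubbleH1 share Std1F1B's constructor signature.
from spectrax.pipeline import GPipe, Std1F1B, ZeroBubbleH1

schedule = GPipe(microbatches=8)          # simple, higher activation memory
# schedule = Std1F1B(microbatches=8)      # standard 1F1B
# schedule = ZeroBubbleH1(microbatches=8) # bubble-filling weight-grad

# pipeline_body is a hypothetical staged function containing sxstage_iter markers.
pipeline_fwd = sxjit(schedule=schedule)(pipeline_body)
```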
Heterogeneous stages
No same-shape constraint. Stages can be completely different:
```python
model = PipelineSequential(
    EmbedStage(vocab, d, rngs=rngs),  # (B, S) int → (B, S, d)
    BlockStage(d, rngs=rngs),         # (B, S, d) → (B, S, d)
    BlockStage(d, rngs=rngs),         # (B, S, d) → (B, S, d)
    HeadStage(d, vocab, rngs=rngs),   # (B, S, d) → (B, S, vocab)
)
```
Auto-splitting is available for homogeneous stacks: spx.run detects
model.blocks: ModuleList and slices it evenly across pipeline stages.
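A sketch of the homogeneous layout this refers to. The embed / blocks / head attribute names follow the convention used above; nn.Embed, Block, and the spx.ModuleList namespace are stand-ins, not confirmed class names:

```python
# Sketch: a homogeneous stack that spx.run can auto-split. `Block` is a
# hypothetical user-defined module; the `blocks` ModuleList attribute is
# what spx.run looks for, per the note above.
class TinyLM(spx.Module):
    def __init__(self, vocab, d, n_layers, *, rngs):
        super().__init__()
        self.embed = nn.Embed(vocab, d, rngs=rngs)  # assumed embedding layer name
        self.blocks = spx.ModuleList(               # namespace assumed
            [Block(d, rngs=rngs) for _ in range(n_layers)]
        )
        self.head = nn.Linear(d, vocab, rngs=rngs)

    def forward(self, tokens):
        h = self.embed(tokens)
        for block in self.blocks:
            h = block(h)
        return self.head(h)
```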
Installation
```bash
pip install spectrax-lib

# Optional extras
pip install "spectrax-lib[contrib]"  # optax integration
pip install "spectrax-lib[cuda]"     # CUDA jaxlib
pip install "spectrax-lib[tpu]"      # TPU jaxlib
```
From source:
```bash
uv sync --extra dev --extra test --extra contrib
```
Requires Python 3.11+ and JAX >= 0.9.2.
Features
MPMD Pipeline Parallelism
True multi-program multi-data execution with per-rank compilation, schedule-faithful dispatch, and heterogeneous stage support.
Module-aware JAX transforms
| Transform | What it does |
|---|---|
| spx.jit | mutable= selector declares writable collections |
| spx.grad / spx.value_and_grad | wrt= selector picks the differentiated subset |
| spx.vmap | Module states passed with in_axes=None automatically |
| spx.scan / spx.remat_scan | Module-aware loops, optionally checkpointed |
| spx.remat | Gradient checkpointing |
| spx.cond / spx.switch / spx.while_loop / spx.fori_loop | Control flow over module state |
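For instance, spx.vmap can batch an eager module call. A minimal sketch: per the table, the module state is broadcast with in_axes=None automatically, so the explicit in_axes shown for the data argument is an assumption about the API:

```python
# Sketch: batching model application with spx.vmap. The in_axes handling
# for the module argument is assumed; the library may do it automatically.
batched_apply = spx.vmap(lambda m, x: m(x), in_axes=(None, 0))
ys = batched_apply(model, jnp.ones((32, 8, 16)))  # 32 independent inputs
```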
Selector DSL
One composable predicate for every "subset of the model" API:
spx.grad(loss, wrt="parameters") # by collection name
spx.grad(loss, wrt=nn.LoraParameter) # by Variable class
spx.grad(loss, wrt=spx.path_contains("attn")) # by path glob
sel = spx.select().at_instances_of(nn.Linear).of_type(spx.Parameter) - spx.path_contains("head")
trainable, frozen = sel.partition_state(model, state)
Sharding & SPMD
Annotate variables with logical axis names and let the mesh decide:
```python
w = spx.Parameter(
    jnp.zeros((256, 256)),
    sharding=spx.Sharding(("data", "model")),
    axis_names=("in", "out"),
)
```
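The logical names in the annotation are resolved against the active mesh through logical_axis_rules, as in the MPMD example earlier. A sketch of what such rules might look like; the (logical name, mesh axis) pair format and the mesh axis names are assumptions, not the library's documented format:

```python
# Hypothetical rule set mapping logical names onto mesh axes; format and
# axis names ("fsdp", "tp") are assumptions for illustration only.
from spectrax.sharding import logical_axis_rules

rules = (
    ("data", "fsdp"),   # shard "data"-tagged dimensions over the fsdp axis
    ("model", "tp"),    # shard "model"-tagged dimensions over the tp axis
)

with logical_axis_rules(rules), mesh:
    # Forward-only call; arguments trimmed for mode="forward" (assumption).
    y = spx.run(model, inputs=x, mesh=mesh, mode="forward")
```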
LoRA Fine-Tuning
```python
base = nn.Linear(768, 768, rngs=spx.Rngs(0))
model = nn.wrap_lora(base, rank=8, alpha=16, rngs=spx.Rngs(1))

@spx.jit
@spx.grad(wrt="lora")
def step(m, x, y):
    return ((m(x) - y) ** 2).mean()
```
FP8 Training
Delayed-scaling per-tensor FP8 with rolling amax history:
@spx.jit(mutable="fp8_meta")
def step(m, x, y):
def loss(m, x, y):
return ((m(x) - y) ** 2).mean()
return spx.grad(loss)(m, x, y)
Explicit graph / state seam
```python
gdef, state = spx.export(model)  # immutable GraphDef + State dict
model2 = spx.bind(gdef, state)   # reconstruct, skips __init__
spx.update(model, state)         # in-place state patch
clone = spx.clone(model)         # deep-copy via export+bind
```
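The same seam makes it easy to drive the model through plain jax.jit as a pure function. A minimal sketch, assuming the State object is itself a pytree of arrays that jax.jit can trace through:

```python
import jax

# Sketch: functional apply built from the export/bind seam.
gdef, state = spx.export(model)

@jax.jit
def apply(state, x):
    m = spx.bind(gdef, state)  # rebuild the module (skips __init__)
    return m(x)

y = apply(state, jnp.ones((8, 16)))
```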
Inspection
```python
spx.inspect.summary(model, jnp.zeros((1, 128)))
spx.inspect.count_parameters(model)
spx.inspect.count_bytes(model)
spx.inspect.tabulate(model, example_input)
```
Examples
See examples/ for runnable scripts:
| Folder | Topic |
|---|---|
| 01_basics/ | Modules, training loops, export/bind, optimizers |
| 02_implementation_guide/ | Llama 3, Qwen 2, GPT-2, ViT, custom transformer |
| 03_transformations/ | jit, grad, vmap, remat, scan, fori_loop |
| 04_surgery/ | Selectors, LoRA injection, FP8, freezing, swapping |
| 05_shardings/ | FSDP, TP, hybrid sharding, logical axis rules |
| 06_spmd_scheduled/ | Pipeline runtime with all schedules |
| 07_mpmd/ | Real MPMD pipeline via spx.run — train, forward, decode, 3-D mesh |
```bash
python -m examples.01_basics.02_training_loop
python -m examples.02_implementation_guide.01_llama3
python -m examples.07_mpmd.01_train_homogeneous
```
Design
- True MPMD first — sxcall compiles one XLA program per physical rank. Stages can differ in class, shape, and parameters. No SPMD-same-shape constraint.
- Unified runtime — spx.run dispatches on the mesh. Same model, same script; change the mesh and you change the parallelism strategy.
- Schedule-faithful execution — the dispatch loop walks the schedule grid exactly as planned. No hidden reordering, no implicit fusing.
- Modules are JAX pytrees — flatten/unflatten through export/bind. jax.jit, jax.tree.map, and jax.value_and_grad work directly.
- State lives in Variable cells — Parameter, Buffer, custom subclasses. Each tags its collection (parameters, buffers, lora, fp8_meta, ...).
- One filter DSL everywhere — Selector serves grad(wrt=...), jit(mutable=...), Optimizer(wrt=...), freeze(...), and iter_variables(select=...).
Benchmark
```bash
python -m benchmarks.bench --cases all --device cpu
python -m benchmarks.bench --cases large --device tpu  # 1.21B transformer
python -m benchmarks.llama_transforms_3way --preset small --device tpu --plots
```
The CPU dispatch benchmark writes benchmarks/results/latest.{json,md} with
per-case spectrax_ms / nnx_ms ratios. On a tiny CPU dispatch-bound benchmark
(2-layer / d=64 / batch-2 transformer), SpectraX runs at 1.83× the speed of
flax.nnx; on d=48 it hits 2.0×. On compute-bound workloads (TPU 8B) the
Python gap shrinks but stays positive.
The plots below are from the TPU small three-way Llama transform benchmark,
which compares raw JAX, Flax NNX, and SpectraX across the transform surface used
by real training code: jit, grad, value_and_grad, jvp, vjp, vmap,
control-flow primitives, scan, and remat. These are runtime ratios against
raw JAX, so 1.0 means parity, lower is faster, and higher is slower.
Compile latency is tracked separately because first-call XLA costs can dominate short benchmark cases and tell a different story from steady-state runtime.
Testing
```bash
pytest tests/ -q
pytest tests/test_conformance.py
```
Status
v0.0.1 — alpha. API may still move; pin the version if you depend on
behavioural stability.
License
Apache-2.0.