SpectraX
True MPMD pipeline parallelism for JAX — with an eager module API.
Quick Start · Installation · MPMD Runtime · Sharding · Benchmark · Examples · Docs
SpectraX is a JAX-native neural-network library built around true MPMD
pipeline parallelism. Each physical rank compiles and runs its own XLA
program — no shared shard_map HLO, no SPMD-same-shape constraint.
It supports heterogeneous stages (embed → blocks → head), nine pipeline schedules
(GPipe, 1F1B, ZeroBubble, Interleaved, DualPipeV, …), and a unified
spx.run()/spx.jit() entry point that dispatches to SPMD or MPMD from the same
training script.
The module API is eager and debuggable — subclass Module, override
forward, call model(x) — but every Module is a JAX pytree, so
jax.jit, jax.grad, and jax.tree.map work directly.
Used in production. EasyDeL is built on SpectraX: 77 model families (Llama, DeepSeek-V3, Qwen3, Gemma, GPT-OSS, Mamba, Whisper, …) with shared sharding, KV-cache, and pipeline plumbing. SpectraX is the JAX-only NN core; EasyDeL is the model zoo + training/serving stack.
Why SpectraX?
| Capability | SpectraX | JAX + manual | Other JAX frameworks |
|---|---|---|---|
| True MPMD | Built-in (sxcall, sxjit, spx.run) | Hand-rolled | SPMD-only |
| Heterogeneous stages | Native (different class/shape per rank) | Fragile | Not supported |
| Pipeline schedules | 9 schedules (GPipe → DualPipeV) | Hand-rolled | Limited |
| Unified runtime | spx.run(mesh) → SPMD or MPMD | Separate code paths | Separate APIs |
| Eager modules | model(x) + pytree-native | Functional only | Functional or ref-tracking |
| Symbolic sharding | BATCH, EMBED, TP, FSDP, … | String-typed P | Manual PartitionSpec |
| Dispatch overhead | ~150 µs | N/A | ~300–2000 µs |
Dispatch overhead measured on a tiny 2-layer CPU transformer. See Benchmark.
Quick Start
```bash
pip install spectrax-lib
```
1 · Single-device eager training
```python
import jax.numpy as jnp
import spectrax as spx
from spectrax import nn

class MLP(spx.Module):
    def __init__(self, d, h, o, *, rngs):
        super().__init__()
        self.fc1 = nn.Linear(d, h, rngs=rngs)
        self.fc2 = nn.Linear(h, o, rngs=rngs)

    def forward(self, x):
        return self.fc2(nn.gelu(self.fc1(x)))

model = MLP(16, 64, 4, rngs=spx.Rngs(0))

@spx.jit
@spx.value_and_grad
def loss_fn(m, x, y):
    return ((m(x) - y) ** 2).mean()

loss, grads = loss_fn(model, jnp.ones((8, 16)), jnp.zeros((8, 4)))
```
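Because both the module and its gradients are pytrees, a bare-bones SGD update needs no framework-specific optimizer. A minimal sketch, assuming the gradients mirror the model's pytree structure:

```python
import jax

# Plain pytree SGD step: subtract a scaled gradient from every parameter leaf.
model = jax.tree.map(lambda p, g: p - 1e-2 * g, model, grads)
```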
2 · Marker-based MPMD — split a function into per-rank programs
```python
from spectrax.runtime.mpmd import sxjit, sxstage_iter
from spectrax.runtime.schedules import Std1F1B
from jax.sharding import PartitionSpec as P

mesh = spx.create_mesh(axis_dims=(4, 1, -1, 1, 1, 1), mpmd_axis="pp")

@sxjit(mesh=mesh, schedule=Std1F1B(microbatches=8))
def pipeline_fwd(model, x):
    x = model.embed(x)                                   # stage 0
    x = sxstage_iter(x, sharding=P("fsdp", None, "tp"))  # boundary — declares the
                                                         # activation sharding the
                                                         # next stage receives
    x = model.blocks[0](x)                               # stage 1
    x = sxstage_iter(x, sharding=P("fsdp", None, "tp"))  # same contract again
    x = model.blocks[1](x)                               # stage 2
    x = sxstage_iter(x)                                  # plain identity boundary
    return model.head(x)                                 # stage 3
```
sxjit traces, splits the jaxpr at sxstage_iter markers, and
compiles one XLA executable per stage / rank. Each rank only
sees its own sub-graph — true MPMD, not SPMD.
Each marker is functionally the identity but accepts two optional keywords:
| Arg | Purpose |
|---|---|
| stage= | Integer hint for the stage index (validation / debugging only — the cluster splitter partitions purely by marker order). |
| sharding= | PartitionSpec declaring the activation layout that flows across this boundary. The compiler honors it during cross-rank transport. |
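Both keywords can sit on one marker; the call is still the identity on its input:

```python
# stage= is a debug hint only; sharding= is the cross-rank transport contract.
x = sxstage_iter(x, stage=2, sharding=P("fsdp", None, "tp"))
```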
This is the same pattern EasyDeL uses to keep activation sharding
stable across stage boundaries — see
easydel.infra.base_module._maybe_emit_stage_boundary.
3 · One-call MPMD training — spx.run
```python
from spectrax.sharding import logical_axis_rules

mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1), mpmd_axis="pp")

with logical_axis_rules(FSDP_TP_RULES), mesh:
    loss, grads = spx.run(
        model,
        inputs=ids,
        targets=labels,
        mesh=mesh,
        mode="train",
        loss_fn=cross_entropy,
        microbatches=8,
    )
```
Drop mpmd_axis="pp" and the same code path runs under pure SPMD —
same model, same script, no rewrites.
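A sketch of the SPMD variant, assuming create_mesh simply takes no mpmd_axis in that case:

```python
# Same call, SPMD path: without mpmd_axis, spx.run routes to pjit.
mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1))

with logical_axis_rules(FSDP_TP_RULES), mesh:
    loss, grads = spx.run(
        model, inputs=ids, targets=labels, mesh=mesh,
        mode="train", loss_fn=cross_entropy, microbatches=8,
    )
```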
4 · Deferred initialization — infer shapes at runtime
```python
model = nn.Sequential(
    nn.Linear(None, 256, rngs=rngs),  # in_features inferred from first call
    nn.ReLU(),
    nn.Linear(256, 10, rngs=rngs),
)
y = model(jnp.zeros((8, 128)))  # weight shapes resolved here
```
MPMD Runtime
SpectraX implements true MPMD: each physical rank compiles and executes its own distinct JAX program. This is not SPMD-with-barriers — stages can have different classes, different parameter shapes, and different I/O shapes.
spx.run — the unified entry point
spx.run routes to SPMD (pjit) or MPMD (sxcall) based on the mesh:
```python
spx.run(
    model,
    inputs=x,               # microbatched along leading axis
    targets=y,
    mesh=mesh,              # SpxMesh — mpmd_axis decides the path
    mode="train",           # "forward" | "train"
    loss_fn=ce_loss,
    microbatches=8,
    schedule=Std1F1B(...),  # optional; defaults to GPipe
)
```
| Mesh type | What happens |
|---|---|
| mpmd_axis=None | Pure SPMD — pjit with FSDP/TP via logical axis rules |
| mpmd_axis="pp" | True MPMD — auto-split into stages, per-rank compilation |
For mode="train", inputs/targets accept arrays, tuples, or dicts;
SpectraX threads them into model.forward(...) and loss_fn(...).
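A sketch of the dict form (the field names here are hypothetical; the assumption is that keys match model.forward's keyword arguments):

```python
loss, grads = spx.run(
    model,
    inputs={"input_ids": ids, "attention_mask": mask},  # hypothetical keys
    targets=labels,
    mesh=mesh, mode="train", loss_fn=ce_loss, microbatches=8,
)
```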
Lower-level primitives
| Primitive | Purpose |
|---|---|
| sxcall | Execute a PipelineSequential under a schedule — Python dispatch loop over per-rank compiled callables. |
| sxjit | Decorator: trace → split at sxstage_iter markers → compile one XLA executable per stage / rank. |
| sxgrad / sxvalue_and_grad | Schedule-faithful gradients of an sxjit function. |
| treduce | Schedule-driven microbatch reduction — binds a body + schedule into the traced jaxpr. |
| sxstage_iter | Marker primitive for stage boundaries inside sxjit. |
| spx.assign_stage | Context manager that stamps subsequently-created variables with a (current, total) stage tag. |
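As a sketch of how the gradient primitives compose with sxjit (the sxvalue_and_grad import path is an assumption, and the calling convention is assumed to mirror jax.value_and_grad):

```python
from spectrax.runtime.mpmd import sxjit, sxstage_iter, sxvalue_and_grad  # path assumed

@sxjit(mesh=mesh, schedule=Std1F1B(microbatches=8))  # mesh/schedule as above
def pipeline_loss(model, x, y):
    x = model.embed(x)
    x = sxstage_iter(x)  # stage boundary
    return ((model.head(x) - y) ** 2).mean()

# Schedule-faithful value and gradients of the staged function.
loss, grads = sxvalue_and_grad(pipeline_loss)(model, x, y)
```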
Supported schedules
| Schedule | Type | Key trait |
|---|---|---|
| GPipe | Flat | All forwards, then all backwards. Simple, high memory. |
| Std1F1B | Flat | Standard 1-forward-1-backward. Peak memory O(n_stages). |
| Eager1F1B | Flat | 1F1B variant that aggressively starts the backward pipe. |
| ZeroBubbleH1 | Flat | Splits BWD into input-grad + weight-grad; weight-grad fills bubble slots. |
| InterleavedH1 | Virtual | Each rank owns v non-contiguous stages. Bubble shrinks by v. |
| InterleavedGPipe | Virtual | GPipe analog of InterleavedH1. |
| Interleaved1F1BPlusOne | Virtual | Interleaved 1F1B with one extra warmup microbatch. |
| DualPipeV | Virtual | V-shaped bidirectional pipeline (DeepSeek-style). |
| KimiK2 | Virtual | Interleaved with extra warmup (Moonshot K2 design). |
Choosing & using a schedule
All schedules share the same constructor shape: microbatches= is
required; virtual-stage schedules also take virtual_stages=. Every
schedule object exposes bubble_ratio(n_stages),
peak_activations(n_stages), and total_steps(n_stages) so you can
compare them analytically before launching:
```python
from spectrax.runtime.schedules import (
    GPipe, Std1F1B, ZeroBubbleH1, InterleavedH1, DualPipeV,
)

n = 4   # pipeline depth
m = 16  # microbatches

for sc in [
    GPipe(m),
    Std1F1B(m),
    ZeroBubbleH1(m),
    InterleavedH1(m, virtual_stages=2),
    DualPipeV(m),
]:
    print(f"{type(sc).__name__:14s} bubble={sc.bubble_ratio(n):.3f} "
          f"steps={sc.total_steps(n)} peak_acts={sc.peak_activations(n)}")
```
Pass the chosen schedule to spx.run, spx.jit, sxcall, or sxjit:
```python
# spx.run — picks SPMD or MPMD by mesh
loss, grads = spx.run(
    model, inputs=x, targets=y, mesh=mesh,
    mode="train", loss_fn=ce_loss,
    microbatches=m, schedule=Std1F1B(m),
)

# sxcall — explicit MPMD on a PipelineSequential
loss, grads = sxcall(
    model, (x, y),
    mesh=mpmd_mesh, schedule=ZeroBubbleH1(m), loss_fn=ce_loss,
)

# sxjit — marker-based MPMD compile
@sxjit(mesh=mesh, schedule=InterleavedH1(m, virtual_stages=2))
def step(model, x):
    ...
```
Two optional fusion flags on spx.run trim residual schedule overhead when
applicable:
```python
spx.run(
    model, ..., schedule=Std1F1B(m),
    fuse_1f1b=True,  # collapse the 1F1B steady-state into a single fused stage call
    fuse_zb=True,    # pair ZeroBubble's BWD_I + BWD_W when the schedule allows it
)
```
fuse_1f1b only applies to 1F1B-family schedules; fuse_zb only to
ZeroBubbleH1. Leave them at None (default) to let spx.run pick.
Rough rule of thumb:
| Goal | Pick |
|---|---|
| Simplest, lowest dispatch | GPipe |
| Steady-state memory O(n_stages) | Std1F1B (+ fuse_1f1b=True) |
| Fill the 1F1B bubble | ZeroBubbleH1 (+ fuse_zb=True) |
| Shrink the bubble further at extra transport cost | InterleavedH1(virtual_stages=v) |
| DeepSeek-style V-shape | DualPipeV |
| Long-context / Moonshot-K2-style | KimiK2 |
Heterogeneous stages
No same-shape constraint. Stages can be completely different:
```python
model = PipelineSequential(
    EmbedStage(vocab, d, rngs=rngs),  # (B, S) int → (B, S, d)
    BlockStage(d, rngs=rngs),         # (B, S, d) → (B, S, d)
    BlockStage(d, rngs=rngs),         # (B, S, d) → (B, S, d)
    HeadStage(d, vocab, rngs=rngs),   # (B, S, d) → (B, S, vocab)
)
```
spx.run also auto-splits homogeneous stacks: it detects a
ModuleList named blocks and slices it evenly across pipeline
stages — no pipeline-aware bookkeeping in user code.
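As a concrete sketch of that convention (TinyLM is hypothetical; the constructor shapes are illustrative and only the blocks attribute name matters):

```python
# Hypothetical homogeneous stack: spx.run slices `blocks` evenly across
# pipeline stages; the module itself carries no pipeline bookkeeping.
class TinyLM(spx.Module):
    def __init__(self, vocab, d, n_layers, *, rngs):
        super().__init__()
        self.embed = nn.Embed(vocab, d, rngs=rngs)  # assumed ctor shape
        self.blocks = nn.ModuleList(
            [BlockStage(d, rngs=rngs) for _ in range(n_layers)]
        )
        self.head = nn.Linear(d, vocab, rngs=rngs)
```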
Sharding: the symbolic-axis system
SpectraX separates what a tensor axis means (BATCH, EMBED, TP, FSDP)
from which mesh axis it lands on. Layer authors write symbolic tokens;
the runtime picks the actual PartitionSpec based on the active mesh
and runtime mode (training vs. autoregressive decode).
```python
from spectrax import PartitionAxis, PartitionManager, common_types as ct

paxis = PartitionAxis()  # defaults: dp/fsdp/tp/sp/ep

with PartitionManager(paxis):
    h = ct.apply_logical_sharding(h, dynamic_axes=ct.HiddenStateSharding)
    # h is now constrained to (BATCH, QUERY_LENGTH, EMBED)
    # → resolved at runtime to whatever mesh the user is running on
```
Pre-baked symbolic shapes ship for the common cases:
HiddenStateSharding, AttnQSharding, AttnKVSharding, RowWise,
ColumnWise, Replicated, plus an Expert* family for MoE.
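A sketch using one of the pre-baked layouts, assuming RowWise is accepted through the same dynamic_axes= hook shown above:

```python
with PartitionManager(paxis):
    # Weight-style layout from the pre-baked family; like HiddenStateSharding,
    # it resolves against the active mesh at runtime.
    w = ct.apply_logical_sharding(w, dynamic_axes=ct.RowWise)
```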
This is the same axis system EasyDeL uses to ship 77 model families
without per-model sharding code — see
easydel.axis
for an example of registering a domain-specific axis token (ATTN_DP)
that shadows DP only for KV-cache placement.
Module-aware JAX transforms
| Transform | What it does |
|---|---|
| spx.jit | mutable= + mesh= + schedule= — one entry, SPMD/MPMD |
| spx.grad / spx.value_and_grad | wrt= selector picks the differentiated subset |
| spx.vmap | module states passed with in_axes=None automatically |
| spx.scan / spx.remat_scan | module-aware loops, optionally checkpointed |
| spx.remat | gradient checkpointing (function- and class-aware) |
| spx.cond / spx.switch / spx.while_loop / spx.fori_loop | module-aware control flow |
| spx.eval_shape | shape-inference without running ops |
| spx.jvp / spx.vjp | forward / reverse-mode lifts that respect mutable= |
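A short sketch, assuming these lifts mirror their jax counterparts' signatures:

```python
# Module state is threaded with in_axes=None automatically (see table),
# so only the data argument is mapped over its leading axis.
ys = spx.vmap(lambda m, x: m(x))(model, xs)

# Checkpointed forward: activations are recomputed during the backward pass.
fwd_ckpt = spx.remat(lambda m, x: m(x))
```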
spx.jit in depth
spx.jit accepts every keyword jax.jit does, plus three SpectraX-only
ones: mutable=, mesh=, and schedule=. The mesh decides where the
call goes:
```python
@spx.jit                         # plain jax.jit — no mesh
def fwd(model, x):
    return model(x)

@spx.jit(mutable="batch_stats")  # writes to "batch_stats" survive the call
def train_bn(model, x):
    return (model(x) ** 2).mean()

@spx.jit(mesh=spmd_mesh)         # SPMD pjit — uses spmd_mesh
def fwd(model, x):
    return model(x)

@spx.jit(mesh=mpmd_mesh, schedule=Std1F1B(microbatches=8))
def step(model, x):              # MPMD — internally calls sxjit
    x = model.embed(x)
    x = sxstage_iter(x)
    x = model.blocks[0](x)
    ...
```
| Keyword | Meaning |
|---|---|
| mutable= | Selector / collection-name sugar: which collections may be written back after the call. Anything else throws IllegalMutationError. Default () (read-only). |
| mesh= | SpxMesh. If mesh.is_mpmd, the call is forwarded to sxjit (true MPMD). Otherwise plain jax.jit. |
| schedule= | Schedule instance, only meaningful when mesh is MPMD. Forwarded to sxjit. |
| static_argnums / static_argnames | Same as jax.jit. Static Module args are recommended (graph baked into the trace). |
| donate_argnums | Same as jax.jit. Useful for handing off optimizer state in-place under MPMD. |
| in_shardings / out_shardings | Same as jax.jit. Forwarded verbatim under SPMD; under MPMD they are resolved against per-stage sub-meshes. |
| keep_unused, device, backend, inline, compiler_options | Forwarded to jax.jit. Rejected with a clear error when mesh is MPMD (no analog in sxjit). |
When mutable= is set, the compiled function takes the form
(states, stripped_args, stripped_kwargs) internally — static_argnums
and donate_argnums index into that 3-tuple, so prefer the
*_argnames variants whenever possible.
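A minimal sketch of that recommendation (the train flag is illustrative):

```python
# Named statics stay unambiguous under mutable=; positional indices would
# count into the internal (states, args, kwargs) tuple instead.
@spx.jit(mutable="batch_stats", static_argnames=("train",))
def apply(model, x, train=True):
    return model(x)
```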
The cache is keyed on the input modules' GraphDef snapshot. Mutating
parameter values does not invalidate the cache; structural changes
(adding/removing/replacing a Module or Variable attribute) do.
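A sketch of that rule in action (tweaked_state stands in for any value-only State patch):

```python
fwd(model, x)                             # first call: trace + compile
spx.update(model, tweaked_state)          # value-only patch, cache hit
fwd(model, x)                             # reuses the compiled executable
model.extra = nn.Linear(4, 4, rngs=rngs)  # structural change: new attribute
fwd(model, x)                             # GraphDef changed, retrace
```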
Selector DSL
One composable predicate for every "subset of the model" API:
```python
spx.grad(loss, wrt="parameters")               # by collection name
spx.grad(loss, wrt=nn.LoraParameter)           # by Variable class
spx.grad(loss, wrt=spx.path_contains("attn"))  # by path glob

# Compose with set algebra
sel = (
    spx.select()
    .at_instances_of(nn.Linear)
    .of_type(spx.Parameter)
    - spx.path_contains("head")
)
trainable, frozen = sel.partition_state(model, state)
```
The same DSL backs grad(wrt=...), jit(mutable=...),
Optimizer(wrt=...), freeze(...), and iter_variables(select=...).
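The composed selector from the snippet above drops straight into the gradient API:

```python
# Differentiate only Linear parameters outside "head"; everything else
# is treated as frozen by the transform.
grads = spx.grad(loss, wrt=sel)(model, x, y)
```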
Features
LoRA fine-tuning
```python
base = nn.Linear(768, 768, rngs=spx.Rngs(0))
model = nn.wrap_lora(base, rank=8, alpha=16, rngs=spx.Rngs(1))

@spx.jit
@spx.grad(wrt="lora")
def step(m, x, y):
    return ((m(x) - y) ** 2).mean()
```
FP8 training (delayed scaling, rolling amax history)
```python
@spx.jit(mutable="fp8_meta")  # rolling amax history is written back each step
def step(m, x, y):
    def loss(m, x, y):
        return ((m(x) - y) ** 2).mean()
    return spx.grad(loss)(m, x, y)
```
Explicit graph / state seam
```python
gdef, state = spx.export(model)  # immutable GraphDef + State dict
model2 = spx.bind(gdef, state)   # reconstruct, skips __init__
spx.update(model, state)         # in-place state patch
clone = spx.clone(model)         # deep-copy via export+bind
```
Inspection
```python
spx.inspect.summary(model, jnp.zeros((1, 128)))
spx.inspect.count_parameters(model)
spx.inspect.count_bytes(model)
spx.inspect.tabulate(model, example_input)
```
Built-in layers
Linear, Conv1d/2d/3d + ConvTranspose*, MultiheadAttention,
LayerNorm, RMSNorm, BatchNorm, InstanceNorm, GroupNorm, Dropout,
Embed, MLPBlock, MoE primitives, RNN/GRU/LSTM cells, FP8 path,
LoRA path, plus containers (Sequential, ModuleList, ModuleDict,
StackedModuleList for repeated transformer blocks under lax.scan).
Installation
```bash
pip install spectrax-lib

# Optional extras
pip install "spectrax-lib[contrib]"  # optax integration
pip install "spectrax-lib[cuda]"     # CUDA jaxlib
pip install "spectrax-lib[tpu]"      # TPU jaxlib
```
From source:
```bash
git clone https://github.com/erfanzar/spectrax
cd spectrax
uv sync --extra dev --extra test --extra contrib
```
Requires Python 3.11+ and JAX ≥ 0.9.2.
Examples
Seven topic folders under examples/ progress from
single-Module forward passes to multi-device MPMD pipeline training:
| Folder | Topic |
|---|---|
| 01_basics/ | Modules, training loops, export/bind, optimizers |
| 02_implementation_guide/ | Llama 3, Qwen 2, GPT-2, ViT, custom transformer |
| 03_transformations/ | jit, grad, vmap, remat, scan, fori_loop |
| 04_surgery/ | Selectors, LoRA injection, FP8, freezing, swapping |
| 05_shardings/ | FSDP, TP, hybrid sharding, logical axis rules |
| 06_spmd_scheduled/ | Pipeline runtime with all schedules |
| 07_mpmd/ | Real MPMD pipeline via spx.run — train, forward, decode, 3-D mesh |
```bash
python -m examples.01_basics.02_training_loop
python -m examples.02_implementation_guide.01_llama3
python -m examples.07_mpmd.01_train_homogeneous
```
Most examples run on CPU with small configs; sharding and pipeline examples benefit from multi-device TPU / GPU but fall back cleanly to one device.
Used by
- EasyDeL — production
training/serving framework for LLMs, multimodal models, and vision
models. SpectraX is the NN core; EasyDeL adds the model zoo
(Llama, Qwen3, DeepSeek-V3, Gemma, GPT-OSS, Mamba, Whisper, …),
trainers, KV cache, and inference engine. EasyDeL leans on
spx.Module, spx.Rngs, spx.Parameter, spx.assign_stage, spx.sxstage_iter, the PartitionAxis registry, and the common_types symbolic-axis tokens.
If your project uses SpectraX, open a PR to add it here.
Design
- True MPMD first. sxcall compiles one XLA program per physical rank. Stages can differ in class, shape, and parameters. No SPMD-same-shape constraint.
- Unified runtime. spx.run dispatches on the mesh. Same model, same script; change the mesh and you change the parallelism strategy.
- Schedule-faithful execution. The dispatch loop walks the schedule grid exactly as planned. No hidden reordering, no implicit fusing.
- Modules are JAX pytrees. Flatten/unflatten via export/bind. jax.jit, jax.tree.map, and jax.value_and_grad work directly.
- State lives in Variable cells. Parameter, Buffer, and user subclasses; each tags its collection (parameters, buffers, lora, fp8_meta, …).
- One filter DSL everywhere. Selector serves grad(wrt=...), jit(mutable=...), Optimizer(wrt=...), freeze(...), and iter_variables(select=...).
- Symbolic sharding tokens. Layer code says BATCH, EMBED, TP; PartitionManager resolves the active mesh at runtime. Decode-mode overrides flow through the same tokens — no per-mode forking in layer code.
Benchmark
```bash
python -m benchmarks.bench --cases all --device cpu
python -m benchmarks.bench --cases large --device tpu   # 1.21B transformer
python -m benchmarks.llama_transforms_3way --preset small --device tpu --plots
```
The CPU dispatch benchmark writes benchmarks/results/latest.{json,md} with
per-case spectrax_ms / nnx_ms ratios. On a tiny dispatch-bound CPU case
(2-layer / d=64 / batch-2 transformer), SpectraX runs at 1.83× the speed of
flax.nnx; at d=48 it reaches 2.0×. On compute-bound workloads (an 8B model on
TPU) the Python-dispatch gap shrinks but stays positive.
The plots below are from the TPU small three-way Llama transform
benchmark, comparing raw JAX, Flax NNX, and SpectraX across the
transform surface used by real training code: jit, grad,
value_and_grad, jvp, vjp, vmap, control flow, scan, and
remat. Ratios are runtime against raw JAX, so 1.0 is parity,
lower is faster, higher is slower.
Compile latency is tracked separately because first-call XLA costs can dominate short benchmark cases and tell a different story from steady-state runtime.
Testing
```bash
pytest tests/ -q
pytest tests/test_conformance.py
```
Status
v0.0.3 — alpha. API may still move; pin the version if you depend
on behavioural stability. See CHANGELOG.md for the
release log.
License
This project is licensed under AGPL-3.0-or-later.
If you use EasyDeL or SpectraX in research, infrastructure, benchmarks, or derivative systems, please provide attribution and cite the project.
Citation
```bibtex
@software{easydel,
  author = {Erfan Zare Chavoshi},
  title  = {EasyDeL},
  year   = {2023},
  url    = {https://github.com/erfanzar/EasyDeL}
}
```