A JAX-only neural-network library with a PyTorch-shaped eager surface and Flax-NNX-style graph/state underneath.

spectrux

True MPMD pipeline parallelism for JAX — with an eager module API.

Quick Start | Installation | MPMD Runtime | Examples | Docs


Spectrux is a JAX-native neural network library built around true MPMD pipeline parallelism. Each physical rank compiles and runs its own XLA program — no shared shard_map HLO, no SPMD-same-shape constraint. Heterogeneous stages (embed → blocks → head), multiple schedules (GPipe, 1F1B, ZeroBubble, Interleaved, DualPipe), and a unified spx.run() entry point that dispatches to SPMD or MPMD from the same training script.

The module API is eager and debuggable — subclass Module, override forward, call model(x) — but every Module is a JAX pytree, so jax.jit, jax.grad, and jax.tree.map work directly.
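
As a dependency-free illustration of that graph/state seam, here is a toy export/bind pair over nested dicts (hypothetical helper names, not the spectrux implementation): a module splits into a static structure plus a flat state, and reconstructs without re-running `__init__`.

```python
# Toy graph/state seam: a "module" (nested dict) flattens into
# (static structure, leaf state) and rebuilds from the pair.
# Hypothetical sketch, not spectrux internals.

def export(module):
    """Split a nested dict into a leaf-free structure and a state dict."""
    structure, state = {}, {}

    def walk(node, path):
        for key, val in node.items():
            p = path + (key,)
            if isinstance(val, dict):
                structure[p] = "subtree"
                walk(val, p)
            else:
                structure[p] = "leaf"
                state[p] = val

    walk(module, ())
    return tuple(sorted(structure.items())), state

def bind(structure, state):
    """Rebuild the nested dict from structure + state, skipping __init__."""
    root = {}
    for path, kind in structure:
        node = root
        for key in path[:-1]:
            node = node.setdefault(key, {})
        if kind == "subtree":
            node.setdefault(path[-1], {})
        else:
            node[path[-1]] = state[path]
    return root

model = {"fc1": {"w": [1.0, 2.0]}, "fc2": {"w": [3.0]}}
gdef, state = export(model)
clone = bind(gdef, state)
assert clone == model and clone is not model   # round-trip deep copy
```

The real library flattens into JAX pytree leaves rather than dict entries, but the round-trip contract (export, then bind, yields an equivalent module) is the same idea.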


Why spectrux?

| Capability | spectrux | JAX + manual | other JAX frameworks |
| --- | --- | --- | --- |
| True MPMD | built-in (sxcall, sxjit) | hand-rolled | SPMD-only |
| Heterogeneous stages | native (different class/shape per rank) | fragile | not supported |
| Pipeline schedules | 9 schedules (GPipe → DualPipeV) | hand-rolled | limited |
| Unified runtime | spx.run(mesh) → SPMD or MPMD | separate code paths | separate APIs |
| Eager modules | model(x) + pytree-native | functional only | functional or ref-tracking |
| Dispatch overhead | ~150 µs | N/A | ~300–2000 µs |

Dispatch overhead measured on a tiny 2-layer CPU transformer. See benchmarks.


Quick Start

pip install spectrux

Single-device eager training

import jax.numpy as jnp
import spectrux as spx
from spectrux import nn

class MLP(spx.Module):
    def __init__(self, d, h, o, *, rngs):
        super().__init__()
        self.fc1 = nn.Linear(d, h, rngs=rngs)
        self.fc2 = nn.Linear(h, o, rngs=rngs)

    def forward(self, x):
        return self.fc2(nn.gelu(self.fc1(x)))

model = MLP(16, 64, 4, rngs=spx.Rngs(0))

@spx.jit
@spx.value_and_grad
def loss_fn(m, x, y):
    return ((m(x) - y) ** 2).mean()

loss, grads = loss_fn(model, jnp.ones((8, 16)), jnp.zeros((8, 4)))

Marker-based MPMD — split a function into per-rank programs

from spectrux.runtime.mpmd import sxjit, sxstage_iter
from spectrux.pipeline import Std1F1B

# Define a multi-stage forward with explicit stage boundaries
@sxjit(schedule=Std1F1B(microbatches=8))
def pipeline_fwd(model, x):
    x = model.embed(x)                # stage 0
    x = sxstage_iter(x)               # boundary — split here
    x = model.blocks[0](x)            # stage 1
    x = sxstage_iter(x)               # boundary — split here
    x = model.blocks[1](x)            # stage 2
    x = sxstage_iter(x)               # boundary — split here
    return model.head(x)              # stage 3

# sxjit traces the function, splits the jaxpr at sxstage_iter markers,
# and compiles one XLA executable per stage/rank. Each rank only sees
# its own sub-graph — true MPMD, not SPMD.
output = pipeline_fwd(model, x)
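
The comments above describe the mechanism; as a toy analogue (plain Python op lists, not jaxprs), marker-based splitting looks like this:

```python
# Toy version of marker-based splitting: a traced "program" is a flat
# list of ops; marker entries cut it into per-stage sub-programs, each
# "compiled" (here: composed) independently. Illustrative sketch only,
# not the actual jaxpr-splitting machinery.
MARKER = "stage_boundary"

def split_at_markers(ops):
    stages, current = [], []
    for op in ops:
        if op == MARKER:
            stages.append(current)
            current = []
        else:
            current.append(op)
    stages.append(current)
    return stages

def compile_stage(ops):
    # Stand-in for per-rank XLA compilation: fuse ops into one callable.
    def run(x):
        for f in ops:
            x = f(x)
        return x
    return run

traced = [lambda x: x + 1, MARKER, lambda x: x * 3, MARKER, lambda x: x - 2]
programs = [compile_stage(s) for s in split_at_markers(traced)]
assert len(programs) == 3          # one program per stage/rank

x = 5
for p in programs:                 # driver hands activations rank to rank
    x = p(x)
assert x == (5 + 1) * 3 - 2        # 16
```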

MPMD pipeline training — one call, multiple devices

from spectrux.sharding import logical_axis_rules

# Create a 4-stage pipeline mesh
mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1), mpmd_axis="pp")

with logical_axis_rules(FSDP_TP_RULES), mesh:  # FSDP_TP_RULES: user-defined logical-axis rules (not shown)
    # MPMD: each rank gets its own compiled stage
    loss, grads = spx.run(
        model, inputs=x, targets=y,
        mesh=mesh, mode="train", loss_fn=ce_loss, microbatches=8,
    )

# Drop mpmd_axis → same code runs under pure SPMD pjit
# Same model, same script, different mesh — no code changes.

Deferred initialization — infer shapes at runtime

model = nn.Sequential(
    nn.Linear(None, 256, rngs=rngs),   # in_features inferred from first call
    nn.ReLU(),
    nn.Linear(256, 10, rngs=rngs),
)
y = model(jnp.zeros((8, 128)))         # weight shapes resolved here
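
The pattern behind deferred initialization can be sketched in plain Python (a hypothetical `LazyLinear`, not the spectrux class): leave `in_features` unset and resolve it from the first input's trailing dimension.

```python
# Sketch of deferred initialization: weight shapes are materialized on
# the first call, inferring in_features from the input. Zero init and
# list-of-lists arrays keep the sketch dependency-free; illustrative
# only, not the spectrux implementation.
class LazyLinear:
    def __init__(self, in_features, out_features):
        self.in_features = in_features      # may be None (deferred)
        self.out_features = out_features
        self.weight = None                  # materialized on first call

    def __call__(self, x):                  # x: list of rows, shape (B, in)
        if self.weight is None:
            if self.in_features is None:
                self.in_features = len(x[0])   # infer from the first input
            self.weight = [[0.0] * self.out_features
                           for _ in range(self.in_features)]
        # plain matmul: (B, in) @ (in, out) -> (B, out)
        return [[sum(xi * w for xi, w in zip(row, col))
                 for col in zip(*self.weight)]
                for row in x]

layer = LazyLinear(None, 4)
y = layer([[1.0] * 16 for _ in range(8)])   # shapes resolved here
assert layer.in_features == 16
assert len(y) == 8 and len(y[0]) == 4
```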

MPMD Runtime

Spectrux implements true MPMD: each physical rank compiles and executes its own distinct JAX program. This is not SPMD-with-barriers — stages can have different classes, different parameter shapes, and different I/O shapes.
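
A toy dispatch loop makes the contrast concrete. Each "rank" below owns a distinct program (a plain Python callable standing in for a per-rank XLA executable) and only ever sees its own sub-graph; note the stages differ in signature and output type, which a shared SPMD program could not express. Hypothetical sketch, not spectrux internals.

```python
# One distinct program per rank -- heterogeneous I/O across stages.
rank_programs = [
    lambda tok: [float(t) for t in tok],   # "embed": ints -> floats
    lambda h: [v * 2.0 for v in h],        # block 0
    lambda h: [v + 1.0 for v in h],        # block 1
    lambda h: sum(h),                      # "head": sequence -> scalar
]

def mpmd_forward(x):
    # Driver loop: hand activations rank to rank; no shared program.
    act = x
    for program in rank_programs:
        act = program(act)
    return act

assert mpmd_forward([1, 2, 3]) == 15.0
```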

spx.run — unified entry point

spx.run routes to SPMD (pjit) or MPMD (sxcall) based on the mesh:

spx.run(
    model,
    inputs=x,                # microbatched along leading axis
    targets=y,
    mesh=mesh,               # SpxMesh — mpmd_axis decides the path
    mode="train",            # "forward" | "train"
    loss_fn=ce_loss,
    microbatches=8,
)

| Mesh type | What happens |
| --- | --- |
| mpmd_axis=None | Pure SPMD — pjit with FSDP/TP via logical axis rules |
| mpmd_axis="pp" | True MPMD — auto-split into stages, per-rank compilation |

Lower-level primitives

For full control, drop below spx.run:

| Primitive | Purpose |
| --- | --- |
| sxcall | Execute a PipelineSequential under a schedule — Python dispatch loop over pre-built per-rank callables |
| sxjit | Decorator: trace a function, split at sxstage_iter markers, compile one XLA executable per stage/rank |
| sxgrad / sxvalue_and_grad | Schedule-faithful gradients of an sxjit function |
| treduce | Schedule-driven microbatch reduction primitive — binds a body + schedule into the traced jaxpr |
| sxstage_iter | Marker primitive for stage boundaries inside sxjit |

Supported schedules

| Schedule | Type | Key trait |
| --- | --- | --- |
| GPipe | Flat | All forwards, then all backwards. Simple, high memory. |
| Std1F1B | Flat | Standard 1-forward-1-backward. Peak memory O(n_stages). |
| ZeroBubbleH1 | Flat | Splits BWD into input-grad + weight-grad; weight-grad fills bubble slots. |
| InterleavedH1 | Virtual | Each rank owns v non-contiguous stages. Bubble shrinks by v. |
| DualPipeV | Virtual | V-shaped bidirectional pipeline (DeepSeek-style). |
| KimiK2 | Virtual | Interleaved with extra warmup (Moonshot K2 design). |
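
The memory traits above can be made concrete with a toy schedule generator. This sketches the generic GPipe and 1F1B op orderings (not spectrux's schedule classes) and counts live activations per stage:

```python
# Per-stage op order for GPipe vs 1F1B with m microbatches.
# Illustrative scheduling logic only, not spectrux's schedule classes.
def gpipe(m):
    # All forwards, then all backwards: every activation stays live.
    return [("F", i) for i in range(m)] + [("B", i) for i in range(m)]

def one_f_one_b(stage, n_stages, m):
    # Warm up with (n_stages - stage) forwards, then alternate B/F.
    warmup = min(n_stages - stage, m)
    ops = [("F", i) for i in range(warmup)]
    f = warmup
    for b in range(m):
        ops.append(("B", b))
        if f < m:
            ops.append(("F", f))
            f += 1
    return ops

def peak_in_flight(ops):
    # Forwards hold an activation; backwards release one.
    cur = peak = 0
    for kind, _ in ops:
        cur += 1 if kind == "F" else -1
        peak = max(peak, cur)
    return peak

assert peak_in_flight(gpipe(8)) == 8               # O(microbatches)
assert peak_in_flight(one_f_one_b(0, 4, 8)) == 4   # O(n_stages)
assert peak_in_flight(one_f_one_b(3, 4, 8)) == 1   # last stage
```

This is why the table calls GPipe "high memory" and Std1F1B "peak memory O(n_stages)": 1F1B caps live activations at the pipeline depth rather than the microbatch count.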

Heterogeneous stages

No same-shape constraint. Stages can be completely different:

model = PipelineSequential(
    EmbedStage(vocab, d, rngs=rngs),       # (B, S) int → (B, S, d)
    BlockStage(d, rngs=rngs),               # (B, S, d) → (B, S, d)
    BlockStage(d, rngs=rngs),               # (B, S, d) → (B, S, d)
    HeadStage(d, vocab, rngs=rngs),        # (B, S, d) → (B, S, vocab)
)

Auto-splitting is available for homogeneous stacks: spx.run detects model.blocks: ModuleList and slices it evenly across pipeline stages.
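
An even split with remainder spread over the leading stages can be sketched as follows (an assumed policy for illustration; the actual spx.run split may differ):

```python
# Slice n_blocks homogeneous blocks across n_stages pipeline stages,
# giving earlier stages one extra block when the division is uneven.
# Illustrative sketch of an even auto-split, not spectrux's exact policy.
def split_blocks(n_blocks, n_stages):
    base, rem = divmod(n_blocks, n_stages)
    sizes = [base + (1 if s < rem else 0) for s in range(n_stages)]
    bounds, start = [], 0
    for size in sizes:
        bounds.append(range(start, start + size))
        start += size
    return bounds

assert [len(r) for r in split_blocks(12, 4)] == [3, 3, 3, 3]
assert [len(r) for r in split_blocks(10, 4)] == [3, 3, 2, 2]
```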


Installation

pip install spectrux

# Optional extras
pip install "spectrux[contrib]"   # optax integration
pip install "spectrux[cuda]"      # CUDA jaxlib
pip install "spectrux[tpu]"       # TPU jaxlib

From source:

uv sync --extra dev --extra test --extra contrib

Requires Python 3.11+ and JAX >= 0.9.2.


Features

MPMD Pipeline Parallelism

True multi-program multi-data execution with per-rank compilation, schedule-faithful dispatch, and heterogeneous stage support.

Module-aware JAX transforms

| Transform | What it does |
| --- | --- |
| spx.jit | mutable= selector declares writable collections |
| spx.grad / spx.value_and_grad | wrt= selector picks the differentiated subset |
| spx.vmap | module states passed with in_axes=None automatically |
| spx.scan / spx.remat_scan | module-aware loops, optionally checkpointed |
| spx.remat | gradient checkpointing |
| spx.cond / spx.switch / spx.while_loop / spx.fori_loop | control flow over module state |

Selector DSL

One composable predicate for every "subset of the model" API:

spx.grad(loss, wrt="params")                     # by collection name
spx.grad(loss, wrt=nn.LoraParameter)             # by Variable class
spx.grad(loss, wrt=spx.path_contains("attn"))    # by path glob

sel = spx.select().at_instances_of(nn.Linear).of_type(spx.Parameter) - spx.path_contains("head")
trainable, frozen = sel.partition_state(model, state)
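
The subtraction in the last line composes predicates. A minimal sketch of that behaviour, as predicates over (path, variable) pairs (illustrative only, not the spectrux Selector):

```python
# Composable selectors: each wraps a predicate; subtraction builds
# "matches left AND NOT right". Hypothetical sketch, not spectrux's DSL.
class Selector:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, path, var):
        return self.fn(path, var)

    def __sub__(self, other):
        return Selector(lambda p, v: self.fn(p, v) and not other.fn(p, v))

def path_contains(frag):
    return Selector(lambda p, v: frag in p)

def collection(name):
    return Selector(lambda p, v: v.get("collection") == name)

state = {
    "attn/w":      {"collection": "params"},
    "head/w":      {"collection": "params"},
    "attn/lora_a": {"collection": "lora"},
}
sel = collection("params") - path_contains("head")
picked = [p for p, v in state.items() if sel(p, v)]
assert picked == ["attn/w"]   # params, minus anything under "head"
```

One predicate type serving every "subset of the model" API is what lets grad, jit, optimizers, and freezing all share the same selector objects.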

Sharding & SPMD

Annotate variables with logical axis names and let the mesh decide:

w = spx.Parameter(
    jnp.zeros((256, 256)),
    sharding=spx.Sharding(("data", "model")),
    axis_names=("in", "out"),
)

LoRA Fine-Tuning

base = nn.Linear(768, 768, rngs=spx.Rngs(0))
model = nn.wrap_lora(base, rank=8, alpha=16, rngs=spx.Rngs(1))

@spx.jit
@spx.grad(wrt="lora")
def step(m, x, y):
    return ((m(x) - y) ** 2).mean()

FP8 Training

Delayed-scaling per-tensor FP8 with rolling amax history:

@spx.jit(mutable="fp8_meta")
def step(m, x, y):
    def loss(m, x, y):
        return ((m(x) - y) ** 2).mean()
    return spx.grad(loss)(m, x, y)
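
The "delayed scaling with rolling amax history" idea can be sketched in a few lines. The window length (16) and the E4M3 max (448) below are illustrative assumptions, not spectrux constants:

```python
# Delayed per-tensor scaling: the quantization scale for step t comes
# from a rolling window of *past* amax observations, not the current
# tensor. Sketch only; window size and format max are assumptions.
FP8_MAX = 448.0  # largest normal value of float8 E4M3

def update_scale(amax_history, new_amax, window=16):
    history = (amax_history + [new_amax])[-window:]   # rolling amax window
    amax = max(history) or 1.0                        # guard all-zero history
    return history, FP8_MAX / amax                    # scale for the next step

history, scale = [], 1.0
for step_amax in [2.0, 4.0, 1.0]:
    history, scale = update_scale(history, step_amax)
assert scale == FP8_MAX / 4.0   # governed by the window max, not the last step
```

Keeping the scale one step behind the data is what makes the fp8_meta collection mutable state that the jit boundary must be told about, hence mutable="fp8_meta" above.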

Explicit graph / state seam

gdef, state = spx.export(model)          # immutable GraphDef + State dict
model2 = spx.bind(gdef, state)           # reconstruct, skips __init__
spx.update(model, state)                 # in-place state patch
clone = spx.clone(model)                 # deep-copy via export+bind

Inspection

spx.inspect.summary(model, jnp.zeros((1, 128)))
spx.inspect.count_params(model)
spx.inspect.count_bytes(model)
spx.inspect.tabulate(model, example_input)

Examples

See examples/ for runnable scripts:

| Folder | Topic |
| --- | --- |
| 01_basics/ | Modules, training loops, export/bind, optimizers |
| 02_implementation_guide/ | Llama 3, Qwen 2, GPT-2, ViT, custom transformer |
| 03_transformations/ | jit, grad, vmap, remat, scan, fori_loop |
| 04_surgery/ | Selectors, LoRA injection, FP8, freezing, swapping |
| 05_shardings/ | FSDP, TP, hybrid sharding, logical axis rules |
| 06_spmd_scheduled/ | Pipeline runtime with all schedules |
| 07_mpmd/ | Real MPMD pipeline via spx.run — train, forward, decode, 3-D mesh |

python -m examples.01_basics.02_training_loop
python -m examples.02_implementation_guide.01_llama3
python -m examples.07_mpmd.01_train_homogeneous

Design

  1. True MPMD first — sxcall compiles one XLA program per physical rank. Stages can differ in class, shape, and parameters. No SPMD-same-shape constraint.
  2. Unified runtime — spx.run dispatches on the mesh. Same model, same script; change the mesh and you change the parallelism strategy.
  3. Schedule-faithful execution — the dispatch loop walks the schedule grid exactly as planned. No hidden reordering, no implicit fusing.
  4. Modules are JAX pytrees — they flatten/unflatten through export/bind. jax.jit, jax.tree.map, jax.value_and_grad work directly.
  5. State lives in Variable cells — Parameter, Buffer, custom subclasses. Each tags its collection (params, buffers, lora, fp8_meta, ...).
  6. One filter DSL everywhere — Selector serves grad(wrt=...), jit(mutable=...), Optimizer(wrt=...), freeze(...), iter_variables(select=...).

Benchmarks

python -m benchmarks.bench --cases all --device cpu
python -m benchmarks.bench --cases large --device tpu     # 1.21B transformer

Results land in benchmarks/results/latest.{json,md} with per-case spectrux_ms / nnx_ms ratios.

On a tiny dispatch-bound CPU benchmark (2-layer, d=64, batch-2 transformer), spectrux runs at 1.83× the speed of flax.nnx; at d=48 it reaches 2.0×. On compute-bound workloads (8B transformer on TPU) the Python-dispatch gap shrinks but remains in spectrux's favour.


Testing

pytest tests/ -q
pytest tests/test_conformance.py

Status

v0.0.1 — alpha. API may still move; pin the version if you depend on behavioural stability.


License

Apache-2.0.
