SpectraX

True MPMD pipeline parallelism for JAX — with an eager module API.

Quick Start | Installation | MPMD Runtime | Benchmark | Examples | Docs


SpectraX is a JAX-native neural network library built around true MPMD pipeline parallelism. Each physical rank compiles and runs its own XLA program — no shared shard_map HLO, no SPMD-same-shape constraint. Heterogeneous stages (embed → blocks → head), multiple schedules (GPipe, 1F1B, ZeroBubble, Interleaved, DualPipe), and a unified spx.run() entry point that dispatches to SPMD or MPMD from the same training script.

The module API is eager and debuggable — subclass Module, override forward, call model(x) — but every Module is a JAX pytree, so jax.jit, jax.grad, and jax.tree.map work directly.
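The pytree claim rests on a flatten/unflatten contract. As a rough illustration (plain Python, not SpectraX's actual implementation), an eager module doubles as a pytree by declaring which attributes are numeric leaves and which are static structure:

```python
# Hypothetical sketch of the pytree contract: leaves are numeric state,
# the auxiliary data is everything a transform should not touch.
class TinyModule:
    def __init__(self, w, b):
        self.w = w          # numeric leaf
        self.b = b          # numeric leaf
        self.name = "fc"    # static metadata, not a leaf

    def tree_flatten(self):
        # leaves in a fixed order, plus static auxiliary data
        return [self.w, self.b], {"name": self.name}

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        m = cls(*leaves)
        m.name = aux["name"]
        return m

m = TinyModule(1.5, -0.5)
leaves, aux = m.tree_flatten()
# a transform like jax.tree.map rebuilds the module from mapped leaves
doubled = TinyModule.tree_unflatten(aux, [2 * x for x in leaves])
```

In real JAX this pair would be registered via `jax.tree_util.register_pytree_node`, which is what lets `jax.jit` and `jax.grad` accept the module directly as an argument.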


Why SpectraX?

| Capability | SpectraX | JAX + manual | Other JAX frameworks |
| --- | --- | --- | --- |
| True MPMD | built-in (sxcall, sxjit) | hand-rolled | SPMD-only |
| Heterogeneous stages | native (different class/shape per rank) | fragile | not supported |
| Pipeline schedules | 9 schedules (GPipe → DualPipeV) | hand-rolled | limited |
| Unified runtime | spx.run(mesh) → SPMD or MPMD | separate code paths | separate APIs |
| Eager modules | model(x) + pytree-native | functional only | functional or ref-tracking |
| Dispatch overhead | ~150 µs | N/A | ~300–2000 µs |

Dispatch overhead measured on a tiny 2-layer CPU transformer. See benchmark.


Quick Start

pip install spectrax-lib

Single-device eager training

import jax.numpy as jnp
import spectrax as spx
from spectrax import nn

class MLP(spx.Module):
    def __init__(self, d, h, o, *, rngs):
        super().__init__()
        self.fc1 = nn.Linear(d, h, rngs=rngs)
        self.fc2 = nn.Linear(h, o, rngs=rngs)

    def forward(self, x):
        return self.fc2(nn.gelu(self.fc1(x)))

model = MLP(16, 64, 4, rngs=spx.Rngs(0))

@spx.jit
@spx.value_and_grad
def loss_fn(m, x, y):
    return ((m(x) - y) ** 2).mean()

loss, grads = loss_fn(model, jnp.ones((8, 16)), jnp.zeros((8, 4)))

Marker-based MPMD — split a function into per-rank programs

from spectrax.runtime.mpmd import sxjit, sxstage_iter
from spectrax.pipeline import Std1F1B

# Define a multi-stage forward with explicit stage boundaries
@sxjit(schedule=Std1F1B(microbatches=8))
def pipeline_fwd(model, x):
    x = model.embed(x)                # stage 0
    x = sxstage_iter(x)               # boundary — split here
    x = model.blocks[0](x)            # stage 1
    x = sxstage_iter(x)               # boundary — split here
    x = model.blocks[1](x)            # stage 2
    x = sxstage_iter(x)               # boundary — split here
    return model.head(x)              # stage 3

# sxjit traces the function, splits the jaxpr at sxstage_iter markers,
# and compiles one XLA executable per stage/rank. Each rank only sees
# its own sub-graph — true MPMD, not SPMD.
output = pipeline_fwd(model, x)
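The splitting step can be pictured without any tracing machinery. A toy version (illustration only, not the library's code) treats the traced function as a flat op list and cuts it at marker entries, the way sxjit conceptually cuts the jaxpr at each sxstage_iter:

```python
# Toy model of marker-based splitting: each sublist becomes the program
# compiled for one rank.
MARKER = "stage_boundary"

def split_at_markers(ops):
    stages, current = [], []
    for op in ops:
        if op == MARKER:
            stages.append(current)   # close the current stage
            current = []
        else:
            current.append(op)
    stages.append(current)           # final stage after the last marker
    return stages

ops = ["embed", MARKER, "block0", MARKER, "block1", MARKER, "head"]
stages = split_at_markers(ops)
# → [["embed"], ["block0"], ["block1"], ["head"]] — one program per rank
```

Three markers yield four stages, matching the four-stage pipeline above.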

MPMD pipeline training — one call, multiple devices

from spectrax.sharding import logical_axis_rules

# Create a 4-stage pipeline mesh
mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1), mpmd_axis="pp")

with logical_axis_rules(FSDP_TP_RULES), mesh:
    # MPMD: each rank gets its own compiled stage
    loss, grads = spx.run(
        model, inputs=x, targets=y,
        mesh=mesh, mode="train", loss_fn=ce_loss, microbatches=8,
    )

# Drop mpmd_axis → same code runs under pure SPMD pjit
# Same model, same script, different mesh — no code changes.

Deferred initialization — infer shapes at runtime

model = nn.Sequential(
    nn.Linear(None, 256, rngs=rngs),   # in_features inferred from first call
    nn.ReLU(),
    nn.Linear(256, 10, rngs=rngs),
)
y = model(jnp.zeros((8, 128)))         # weight shapes resolved here
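The mechanics of deferred initialization are simple to sketch (plain Python, not the library's implementation): the layer records `None` for the unknown dimension and materializes its weight on the first call.

```python
# Minimal sketch of a lazily-initialized linear layer: in_features is
# resolved from the first input, and the weight is built at that moment.
class LazyLinear:
    def __init__(self, in_features, out_features):
        self.in_features = in_features      # may be None
        self.out_features = out_features
        self.weight = None                  # built on first call

    def __call__(self, x):                  # x: flat list of floats
        if self.weight is None:
            if self.in_features is None:
                self.in_features = len(x)   # infer from the input
            # zero-init stands in for a real initializer
            self.weight = [[0.0] * self.out_features
                           for _ in range(self.in_features)]
        return [sum(x[i] * self.weight[i][j] for i in range(self.in_features))
                for j in range(self.out_features)]

layer = LazyLinear(None, 4)
y = layer(list(range(128)))   # weight shape resolved here: 128 x 4
```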

MPMD Runtime

SpectraX implements true MPMD: each physical rank compiles and executes its own distinct JAX program. This is not SPMD-with-barriers — stages can have different classes, different parameter shapes, and different I/O shapes.

spx.run — unified entry point

spx.run routes to SPMD (pjit) or MPMD (sxcall) based on the mesh:

spx.run(
    model,
    inputs=x,                # microbatched along leading axis
    targets=y,
    mesh=mesh,               # SpxMesh — mpmd_axis decides the path
    mode="train",            # "forward" | "train"
    loss_fn=ce_loss,
    microbatches=8,
)
| Mesh type | What happens |
| --- | --- |
| mpmd_axis=None | Pure SPMD — pjit with FSDP/TP via logical axis rules |
| mpmd_axis="pp" | True MPMD — auto-split into stages, per-rank compilation |

Lower-level primitives

For full control, drop below spx.run:

| Primitive | Purpose |
| --- | --- |
| sxcall | Execute a PipelineSequential under a schedule — Python dispatch loop over pre-built per-rank callables |
| sxjit | Decorator: trace a function, split at sxstage_iter markers, compile one XLA executable per stage/rank |
| sxgrad / sxvalue_and_grad | Schedule-faithful gradients of an sxjit function |
| treduce | Schedule-driven microbatch reduction primitive — binds a body + schedule into the traced jaxpr |
| sxstage_iter | Marker primitive for stage boundaries inside sxjit |

Supported schedules

| Schedule | Type | Key trait |
| --- | --- | --- |
| GPipe | Flat | All forwards, then all backwards. Simple, high memory. |
| Std1F1B | Flat | Standard 1-forward-1-backward. Peak memory O(n_stages). |
| ZeroBubbleH1 | Flat | Splits BWD into input-grad + weight-grad; weight-grad fills bubble slots. |
| InterleavedH1 | Virtual | Each rank owns v non-contiguous stages. Bubble shrinks by v. |
| DualPipeV | Virtual | V-shaped bidirectional pipeline (DeepSeek-style). |
| KimiK2 | Virtual | Interleaved with extra warmup (Moonshot K2 design). |
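The memory difference between GPipe and 1F1B can be made concrete with a toy reconstruction of the two flat schedules (not the library's scheduler — just the shape of the per-stage op grids):

```python
# Per-stage op lists of ("F", m) / ("B", m) on microbatch m.
def gpipe(n_stages, n_micro):
    # every stage: all forwards, then all backwards
    return {s: [("F", m) for m in range(n_micro)] +
               [("B", m) for m in range(n_micro)]
            for s in range(n_stages)}

def one_f_one_b(n_stages, n_micro):
    grid = {}
    for s in range(n_stages):
        warm = min(n_stages - s - 1, n_micro)   # warmup forwards
        ops = [("F", m) for m in range(warm)]
        nxt = warm
        for b in range(n_micro):                # steady state + drain
            if nxt < n_micro:
                ops.append(("F", nxt))
                nxt += 1
            ops.append(("B", b))
        grid[s] = ops
    return grid

def peak_in_flight(ops):
    # max microbatches whose activations are live at once
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak

peak_in_flight(gpipe(4, 8)[0])        # → 8: grows with microbatch count
peak_in_flight(one_f_one_b(4, 8)[0])  # → 4: bounded by n_stages
```

This is the O(n_stages) peak-memory property the Std1F1B row refers to: stage 0 never holds more than n_stages live microbatches, while GPipe holds all of them.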

Heterogeneous stages

No same-shape constraint. Stages can be completely different:

model = PipelineSequential(
    EmbedStage(vocab, d, rngs=rngs),       # (B, S) int → (B, S, d)
    BlockStage(d, rngs=rngs),               # (B, S, d) → (B, S, d)
    BlockStage(d, rngs=rngs),               # (B, S, d) → (B, S, d)
    HeadStage(d, vocab, rngs=rngs),        # (B, S, d) → (B, S, vocab)
)

Auto-splitting is available for homogeneous stacks: spx.run detects model.blocks: ModuleList and slices it evenly across pipeline stages.
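The even slicing itself is a small partitioning step. A plausible sketch (hypothetical helper, not the library's code) distributes N blocks over S contiguous stage slices, earlier stages absorbing the remainder:

```python
def split_evenly(blocks, n_stages):
    # contiguous slices; the first (len % n_stages) stages get one extra
    q, r = divmod(len(blocks), n_stages)
    out, start = [], 0
    for s in range(n_stages):
        size = q + (1 if s < r else 0)
        out.append(blocks[start:start + size])
        start += size
    return out

split_evenly(list(range(8)), 4)   # → [[0, 1], [2, 3], [4, 5], [6, 7]]
split_evenly(list(range(7)), 4)   # → [[0, 1], [2, 3], [4, 5], [6]]
```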


Installation

pip install spectrax-lib

# Optional extras
pip install "spectrax-lib[contrib]"   # optax integration
pip install "spectrax-lib[cuda]"      # CUDA jaxlib
pip install "spectrax-lib[tpu]"       # TPU jaxlib

From source:

uv sync --extra dev --extra test --extra contrib

Requires Python 3.11+ and JAX >= 0.9.2.


Features

MPMD Pipeline Parallelism

True multi-program multi-data execution with per-rank compilation, schedule-faithful dispatch, and heterogeneous stage support.

Module-aware JAX transforms

| Transform | What it does |
| --- | --- |
| spx.jit | mutable= selector declares writable collections |
| spx.grad / spx.value_and_grad | wrt= selector picks the differentiated subset |
| spx.vmap | module states passed with in_axes=None automatically |
| spx.scan / spx.remat_scan | module-aware loops, optionally checkpointed |
| spx.remat | gradient checkpointing |
| spx.cond / spx.switch / spx.while_loop / spx.fori_loop | control flow over module state |

Selector DSL

One composable predicate for every "subset of the model" API:

spx.grad(loss, wrt="parameters")                 # by collection name
spx.grad(loss, wrt=nn.LoraParameter)             # by Variable class
spx.grad(loss, wrt=spx.path_contains("attn"))    # by path glob

sel = spx.select().at_instances_of(nn.Linear).of_type(spx.Parameter) - spx.path_contains("head")
trainable, frozen = sel.partition_state(model, state)
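The essence of such a DSL is a composable predicate over (path, variable) pairs. A toy version (illustrative only, much smaller than the real Selector) shows how subtraction and partitioning fall out:

```python
# Minimal composable selector: a predicate with set-difference and a
# partition over a flat path -> variable state dict.
class Select:
    def __init__(self, pred):
        self.pred = pred

    def __sub__(self, other):
        return Select(lambda p, v: self.pred(p, v) and not other.pred(p, v))

    def partition(self, state):
        hit = {p: v for p, v in state.items() if self.pred(p, v)}
        miss = {p: v for p, v in state.items() if p not in hit}
        return hit, miss

def path_contains(frag):
    return Select(lambda p, v: frag in p)

state = {"blocks.0.attn.w": 1, "blocks.0.mlp.w": 2, "head.w": 3}
trainable, frozen = (path_contains("w") - path_contains("head")).partition(state)
# trainable: both block weights; frozen: the head weight
```

Because every subset API consumes the same predicate type, "trainable minus head" composes the same way for grad, freezing, and optimizers.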

Sharding & SPMD

Annotate variables with logical axis names and let the mesh decide:

w = spx.Parameter(
    jnp.zeros((256, 256)),
    sharding=spx.Sharding(("data", "model")),
    axis_names=("in", "out"),
)
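Rule resolution is essentially an ordered lookup from logical axis name to mesh axis. A rough sketch (rule names here are hypothetical, matching the ("in", "out") names above) of how a mesh could resolve them:

```python
# Map each logical axis name to a mesh axis via an ordered rule list;
# names with no matching rule stay unsharded (None).
def resolve(axis_names, rules):
    mapping = dict(rules)
    return tuple(mapping.get(name) for name in axis_names)

RULES = [("in", "data"), ("out", "model")]   # hypothetical rule set

resolve(("in", "out"), RULES)    # → ("data", "model")
resolve(("in", "vocab"), RULES)  # → ("data", None): "vocab" unsharded
```

Swapping the rule list then reshards the whole model without touching any layer definition, which is the point of annotating with logical names rather than mesh axes.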

LoRA Fine-Tuning

base = nn.Linear(768, 768, rngs=spx.Rngs(0))
model = nn.wrap_lora(base, rank=8, alpha=16, rngs=spx.Rngs(1))

@spx.jit
@spx.grad(wrt="lora")
def step(m, x, y):
    return ((m(x) - y) ** 2).mean()
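The idea behind wrap_lora can be shown with a scalar sketch (rank-1, plain Python — not the library's implementation): the wrapper adds a low-rank correction scaled by alpha/rank, and with B zero-initialized the wrapped layer reproduces the base exactly at init.

```python
# Scalar stand-in for a LoRA-wrapped linear layer.
class ScalarLinear:
    def __init__(self, w):
        self.w = w
    def __call__(self, x):
        return x * self.w

class ScalarLoRA:
    def __init__(self, base, a, b, alpha, rank):
        self.base, self.a, self.b = base, a, b
        self.scale = alpha / rank
    def __call__(self, x):
        # frozen base path + scaled low-rank correction
        return self.base(x) + self.scale * (x * self.a * self.b)

base = ScalarLinear(0.7)
lora = ScalarLoRA(base, a=0.3, b=0.0, alpha=16, rank=8)  # b starts at zero
lora(2.0) == base(2.0)   # True: zero-init B leaves the base unchanged
```

Only a and b belong to the "lora" collection, which is why grad(wrt="lora") trains the adapters while the base weight stays frozen.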

FP8 Training

Delayed-scaling per-tensor FP8 with rolling amax history:

@spx.jit(mutable="fp8_meta")
def step(m, x, y):
    def loss(m, x, y):
        return ((m(x) - y) ** 2).mean()
    return spx.grad(loss)(m, x, y)
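Delayed scaling means the quantization scale for step t comes from a rolling history of past amax values, not the current tensor — which is why the fp8_meta collection must be declared mutable under jit. A minimal sketch of that bookkeeping (not the library's code):

```python
from collections import deque

FP8_E4M3_MAX = 448.0   # largest normal value in the e4m3 format

class Fp8Meta:
    def __init__(self, history_len=4):
        # rolling amax history, seeded so the first scale is defined
        self.amax_history = deque([1.0], maxlen=history_len)

    def scale(self):
        # map the historical amax onto the format's max representable value
        return FP8_E4M3_MAX / max(self.amax_history)

    def update(self, tensor_amax):
        self.amax_history.append(tensor_amax)

meta = Fp8Meta()
meta.update(2.0)   # amax observed on step 1
meta.update(8.0)   # amax observed on step 2
s = meta.scale()   # 448.0 / 8.0 == 56.0
```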

Explicit graph / state seam

gdef, state = spx.export(model)          # immutable GraphDef + State dict
model2 = spx.bind(gdef, state)           # reconstruct, skips __init__
spx.update(model, state)                 # in-place state patch
clone = spx.clone(model)                 # deep-copy via export+bind
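The seam can be pictured as splitting an object into an immutable structure plus a flat state dict. A plain-Python sketch of the contract (hypothetical helpers, not SpectraX's internals):

```python
# export: structure + values apart; bind: rebuild without __init__.
def export(obj):
    gdef = (type(obj), tuple(sorted(vars(obj))))   # structure only
    state = dict(vars(obj))                        # values only
    return gdef, state

def bind(gdef, state):
    cls, fields = gdef
    obj = cls.__new__(cls)                         # skip __init__
    for f in fields:
        setattr(obj, f, state[f])
    return obj

class Layer:
    def __init__(self):
        self.w, self.b = 1.0, 0.0

m = Layer()
gdef, state = export(m)
state["w"] = 5.0
m2 = bind(gdef, state)   # rebuilt with patched state; m is untouched
```

Cloning falls out for free as export followed by bind, and state surgery (checkpointing, weight swapping) operates on the dict without ever running module constructors.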

Inspection

spx.inspect.summary(model, jnp.zeros((1, 128)))
spx.inspect.count_parameters(model)
spx.inspect.count_bytes(model)
spx.inspect.tabulate(model, example_input)

Examples

See examples/ for runnable scripts:

| Folder | Topic |
| --- | --- |
| 01_basics/ | Modules, training loops, export/bind, optimizers |
| 02_implementation_guide/ | Llama 3, Qwen 2, GPT-2, ViT, custom transformer |
| 03_transformations/ | jit, grad, vmap, remat, scan, fori_loop |
| 04_surgery/ | Selectors, LoRA injection, FP8, freezing, swapping |
| 05_shardings/ | FSDP, TP, hybrid sharding, logical axis rules |
| 06_spmd_scheduled/ | Pipeline runtime with all schedules |
| 07_mpmd/ | Real MPMD pipeline via spx.run — train, forward, decode, 3-D mesh |
python -m examples.01_basics.02_training_loop
python -m examples.02_implementation_guide.01_llama3
python -m examples.07_mpmd.01_train_homogeneous

Design

  1. True MPMD first — sxcall compiles one XLA program per physical rank. Stages can differ in class, shape, and parameters. No SPMD-same-shape constraint.
  2. Unified runtime — spx.run dispatches on the mesh. Same model, same script; change the mesh and you change the parallelism strategy.
  3. Schedule-faithful execution — the dispatch loop walks the schedule grid exactly as planned. No hidden reordering, no implicit fusing.
  4. Modules are JAX pytrees — flatten/unflatten through export/bind. jax.jit, jax.tree.map, jax.value_and_grad work directly.
  5. State lives in Variable cells — Parameter, Buffer, custom subclasses. Each tags its collection (parameters, buffers, lora, fp8_meta, ...).
  6. One filter DSL everywhere — Selector serves grad(wrt=...), jit(mutable=...), Optimizer(wrt=...), freeze(...), iter_variables(select=...).

Benchmark

python -m benchmarks.bench --cases all --device cpu
python -m benchmarks.bench --cases large --device tpu     # 1.21B transformer
python -m benchmarks.llama_transforms_3way --preset small --device tpu --plots

The CPU dispatch benchmark writes benchmarks/results/latest.{json,md} with per-case spectrax_ms / nnx_ms ratios. On a tiny CPU dispatch-bound benchmark (2-layer / d=64 / batch-2 transformer), SpectraX runs at 1.83× the speed of flax.nnx; on d=48 it hits 2.0×. On compute-bound workloads (TPU 8B) the Python gap shrinks but stays positive.

The plots below are from the TPU small three-way Llama transform benchmark, which compares raw JAX, Flax NNX, and SpectraX across the transform surface used by real training code: jit, grad, value_and_grad, jvp, vjp, vmap, control-flow primitives, scan, and remat. These are runtime ratios against raw JAX, so 1.0 means parity, lower is faster, and higher is slower.

(Figure: TPU small runtime ratio vs raw JAX)

Compile latency is tracked separately because first-call XLA costs can dominate short benchmark cases and tell a different story from steady-state runtime.

(Figure: TPU small compile time)


Testing

pytest tests/ -q
pytest tests/test_conformance.py

Status

v0.0.1 — alpha. API may still move; pin the version if you depend on behavioural stability.


License

Apache-2.0.
