
A JAX-only neural-network library with a PyTorch-shaped eager surface and Flax-NNX-style graph/state underneath.


specflux

True MPMD pipeline parallelism for JAX — with an eager module API.

Quick Start | Installation | MPMD Runtime | Examples | Docs


SpecFlux is a JAX-native neural network library built around true MPMD pipeline parallelism. Each physical rank compiles and runs its own XLA program: there is no shared shard_map HLO and no SPMD same-shape constraint. It supports heterogeneous stages (embed → blocks → head), multiple schedules (GPipe, 1F1B, ZeroBubble, Interleaved, DualPipe), and a unified spx.run() entry point that dispatches to SPMD or MPMD from the same training script.

The module API is eager and debuggable — subclass Module, override forward, call model(x) — but every Module is a JAX pytree, so jax.jit, jax.grad, and jax.tree.map work directly.


Why specflux?

| Capability | specflux | JAX + manual | other JAX frameworks |
| --- | --- | --- | --- |
| True MPMD | built-in (sxcall, sxjit) | hand-rolled | SPMD-only |
| Heterogeneous stages | native (different class/shape per rank) | fragile | not supported |
| Pipeline schedules | 9 schedules (GPipe→DualPipeV) | hand-rolled | limited |
| Unified runtime | spx.run(mesh) → SPMD or MPMD | separate code paths | separate APIs |
| Eager modules | model(x) + pytree-native | functional only | functional or ref-tracking |
| Dispatch overhead | ~150 µs | N/A | ~300–2000 µs |

Dispatch overhead measured on a tiny 2-layer CPU transformer. See benchmarks.


Quick Start

pip install specflux

Single-device eager training

import jax.numpy as jnp
import specflux as spx
from specflux import nn

class MLP(spx.Module):
    def __init__(self, d, h, o, *, rngs):
        super().__init__()
        self.fc1 = nn.Linear(d, h, rngs=rngs)
        self.fc2 = nn.Linear(h, o, rngs=rngs)

    def forward(self, x):
        return self.fc2(nn.gelu(self.fc1(x)))

model = MLP(16, 64, 4, rngs=spx.Rngs(0))

@spx.jit
@spx.value_and_grad
def loss_fn(m, x, y):
    return ((m(x) - y) ** 2).mean()

loss, grads = loss_fn(model, jnp.ones((8, 16)), jnp.zeros((8, 4)))

Marker-based MPMD — split a function into per-rank programs

from specflux.runtime.mpmd import sxjit, sxstage_iter
from specflux.pipeline import Std1F1B

# Define a multi-stage forward with explicit stage boundaries
@sxjit(schedule=Std1F1B(microbatches=8))
def pipeline_fwd(model, x):
    x = model.embed(x)                # stage 0
    x = sxstage_iter(x)               # boundary — split here
    x = model.blocks[0](x)            # stage 1
    x = sxstage_iter(x)               # boundary — split here
    x = model.blocks[1](x)            # stage 2
    x = sxstage_iter(x)               # boundary — split here
    return model.head(x)              # stage 3

# sxjit traces the function, splits the jaxpr at sxstage_iter markers,
# and compiles one XLA executable per stage/rank. Each rank only sees
# its own sub-graph — true MPMD, not SPMD.
output = pipeline_fwd(model, x)
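
The splitting step described in the comment above can be pictured with a toy model in plain Python. This is only a sketch of the idea, not how sxjit is implemented: it assumes a "program" is a flat list of callables, and STAGE_BOUNDARY, split_at_markers, and make_stage_fn are hypothetical names standing in for sxstage_iter and the per-stage compilation step.

```python
# Toy illustration only: a "program" is a flat list of callables, and
# STAGE_BOUNDARY is a hypothetical sentinel standing in for sxstage_iter.
STAGE_BOUNDARY = object()


def split_at_markers(ops):
    """Split a flat op list into one sub-list per stage."""
    stages, current = [], []
    for op in ops:
        if op is STAGE_BOUNDARY:
            stages.append(current)
            current = []
        else:
            current.append(op)
    stages.append(current)  # ops after the last marker form the final stage
    return stages


def make_stage_fn(stage_ops):
    """Bundle one stage's ops into a single callable (stand-in for a compiled program)."""
    def stage_fn(x):
        for op in stage_ops:
            x = op(x)
        return x
    return stage_fn


def embed(x):
    return x + 1


def block0(x):
    return x * 2


def block1(x):
    return x - 3


def head(x):
    return x * 10


program = [embed, STAGE_BOUNDARY, block0, STAGE_BOUNDARY, block1, STAGE_BOUNDARY, head]
stage_fns = [make_stage_fn(ops) for ops in split_at_markers(program)]

# Running the stages in order reproduces the unsplit program.
x = 5
for fn in stage_fns:
    x = fn(x)
print(len(stage_fns), x)  # 4 90
```

Three markers yield four stages, which matches the stage 0 to stage 3 comments in the snippet above.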

MPMD pipeline training — one call, multiple devices

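import specflux as spx
from specflux.sharding import logical_axis_rules

# model, x, y, ce_loss, and FSDP_TP_RULES are assumed to be defined elsewhere.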

# Create a 4-stage pipeline mesh
mesh = spx.create_mesh(axis_dims=(2, 1, -1, 1, 1, 1), mpmd_axis="pp")

with logical_axis_rules(FSDP_TP_RULES), mesh:
    # MPMD: each rank gets its own compiled stage
    loss, grads = spx.run(
        model, inputs=x, targets=y,
        mesh=mesh, mode="train", loss_fn=ce_loss, microbatches=8,
    )

# Drop mpmd_axis → same code runs under pure SPMD pjit
# Same model, same script, different mesh — no code changes.

Deferred initialization — infer shapes at runtime

rngs = spx.Rngs(0)
model = nn.Sequential(
    nn.Linear(None, 256, rngs=rngs),   # in_features inferred from first call
    nn.ReLU(),
    nn.Linear(256, 10, rngs=rngs),
)
y = model(jnp.zeros((8, 128)))         # weight shapes resolved here
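
For readers unfamiliar with deferred initialization, the mechanism can be sketched in plain Python. This is an illustration under stated assumptions, not the nn.Linear implementation: LazyLinear is a hypothetical class, inputs are lists of lists rather than arrays, and the weight init range is arbitrary.

```python
import random


class LazyLinear:
    """Toy linear layer whose input width is resolved on the first call."""

    def __init__(self, in_features, out_features, seed=0):
        self.in_features = in_features      # None means "infer later"
        self.out_features = out_features
        self._rng = random.Random(seed)
        self.weight = None                  # allocated once the shape is known
        if in_features is not None:
            self._materialize(in_features)

    def _materialize(self, in_features):
        self.in_features = in_features
        self.weight = [
            [self._rng.uniform(-0.1, 0.1) for _ in range(self.out_features)]
            for _ in range(in_features)
        ]

    def __call__(self, batch):
        if self.weight is None:
            # The trailing dimension of the first input fixes the weight shape.
            self._materialize(len(batch[0]))
        # zip(*self.weight) iterates over the columns of the weight matrix.
        return [
            [sum(x * w for x, w in zip(row, col)) for col in zip(*self.weight)]
            for row in batch
        ]


layer = LazyLinear(None, 4)
out = layer([[0.0] * 128 for _ in range(8)])
print(len(layer.weight), len(layer.weight[0]))  # 128 4
print(len(out), len(out[0]))                    # 8 4
```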

MPMD Runtime

SpecFlux implements true MPMD: each physical rank compiles and executes its own distinct JAX program. This is not SPMD-with-barriers — stages can have different classes, different parameter shapes, and different I/O shapes.

spx.run — unified entry point

spx.run routes to SPMD (pjit) or MPMD (sxcall) based on the mesh:

spx.run(
    model,
    inputs=x,                # microbatched along leading axis
    targets=y,
    mesh=mesh,               # SpxMesh — mpmd_axis decides the path
    mode="train",            # "forward" | "train"
    loss_fn=ce_loss,
    microbatches=8,
)

| Mesh type | What happens |
| --- | --- |
| mpmd_axis=None | Pure SPMD: pjit with FSDP/TP via logical axis rules |
| mpmd_axis="pp" | True MPMD: auto-split into stages, per-rank compilation |
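
The microbatches argument implies that a batch is cut into equal pieces along its leading axis. A minimal plain-Python sketch of that step follows. It assumes equal-sized microbatches and a mean-reduced loss. split_microbatches and mean are hypothetical helpers, and whether spx.run splits the inputs itself or expects them pre-split is not stated here.

```python
def split_microbatches(batch, n_micro):
    """Split a batch (a list of examples) into n_micro equal chunks along the leading axis."""
    if len(batch) % n_micro != 0:
        raise ValueError("batch size must be divisible by the number of microbatches")
    size = len(batch) // n_micro
    return [batch[i * size:(i + 1) * size] for i in range(n_micro)]


def mean(values):
    return sum(values) / len(values)


batch = [float(i) for i in range(16)]
micro = split_microbatches(batch, 8)

# With equal-sized microbatches, averaging the per-microbatch mean losses
# gives the same value as the mean loss over the full batch.
full_loss = mean([v ** 2 for v in batch])
accumulated = mean([mean([v ** 2 for v in mb]) for mb in micro])
print(len(micro), len(micro[0]), full_loss, accumulated)
```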

Lower-level primitives

For full control, drop below spx.run:

| Primitive | Purpose |
| --- | --- |
| sxcall | Execute a PipelineSequential under a schedule: a Python dispatch loop over pre-built per-rank callables |
| sxjit | Decorator: trace a function, split at sxstage_iter markers, compile one XLA executable per stage/rank |
| sxgrad / sxvalue_and_grad | Schedule-faithful gradients of an sxjit function |
| treduce | Schedule-driven microbatch reduction primitive that binds a body + schedule into the traced jaxpr |
| sxstage_iter | Marker primitive for stage boundaries inside sxjit |
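
The phrase "Python dispatch loop over pre-built per-rank callables" can be illustrated with a forward-only toy. This is a sketch, not sxcall itself: forward_grid and dispatch_forward are hypothetical names, everything runs in one process, and the grid is the simple staggered pattern in which rank r handles microbatch t - r at tick t.

```python
def forward_grid(n_stages, n_micro):
    """Tick-by-rank grid: entry [t][r] is the microbatch rank r runs at tick t, or None."""
    n_ticks = n_micro + n_stages - 1
    return [
        [t - r if 0 <= t - r < n_micro else None for r in range(n_stages)]
        for t in range(n_ticks)
    ]


def dispatch_forward(stage_fns, microbatches):
    """Walk the grid in order, handing each stage's output to the next rank."""
    n_stages = len(stage_fns)
    # inbox[r][m] holds the input that rank r needs for microbatch m;
    # inbox[n_stages] collects the final outputs.
    inbox = [dict() for _ in range(n_stages + 1)]
    inbox[0] = dict(enumerate(microbatches))
    for row in forward_grid(n_stages, len(microbatches)):
        for rank, m in enumerate(row):
            if m is not None:
                inbox[rank + 1][m] = stage_fns[rank](inbox[rank].pop(m))
    return [inbox[n_stages][m] for m in range(len(microbatches))]


stage_fns = [lambda x: x + 1, lambda x: x * 2, lambda x: x - 3]
outputs = dispatch_forward(stage_fns, [1, 2, 3, 4])
print(outputs)  # [1, 3, 5, 7]
```

The loop never reorders work: it executes exactly the cells the grid lists, in tick order, which is the property a schedule-faithful dispatcher needs.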

Supported schedules

| Schedule | Type | Key trait |
| --- | --- | --- |
| GPipe | Flat | All forwards, then all backwards. Simple, high memory. |
| Std1F1B | Flat | Standard 1-forward-1-backward. Peak memory O(n_stages). |
| ZeroBubbleH1 | Flat | Splits BWD into input-grad + weight-grad; weight-grad fills bubble slots. |
| InterleavedH1 | Virtual | Each rank owns v non-contiguous stages. Bubble shrinks by v. |
| DualPipeV | Virtual | V-shaped bidirectional pipeline (DeepSeek-style). |
| KimiK2 | Virtual | Interleaved with extra warmup (Moonshot K2 design). |
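
To make the memory contrast between GPipe and 1F1B concrete, here is a small plain-Python sketch. It is not SpecFlux code: gpipe_order, one_f_one_b_order, peak_in_flight, and bubble_ratio are hypothetical helpers. The 1F1B ordering follows the commonly described warmup / steady-state / cooldown pattern, and the bubble estimate is the usual textbook ratio of idle time to useful compute, assuming every stage and microbatch costs the same.

```python
def gpipe_order(n_micro):
    """Per-rank action list for GPipe: all forwards, then all backwards."""
    return [("F", m) for m in range(n_micro)] + [("B", m) for m in range(n_micro)]


def one_f_one_b_order(stage, n_stages, n_micro):
    """Per-rank action list for 1F1B: warmup forwards, alternate F/B, then drain."""
    warmup = min(n_stages - stage - 1, n_micro)
    steady = n_micro - warmup
    order = [("F", m) for m in range(warmup)]
    for i in range(steady):
        order.append(("F", warmup + i))
        order.append(("B", i))
    order += [("B", steady + i) for i in range(warmup)]
    return order


def peak_in_flight(order):
    """Largest number of microbatches whose activations are held at once."""
    live = peak = 0
    for kind, _ in order:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


def bubble_ratio(n_stages, n_micro, v=1):
    """Idle time over useful compute time; v is the number of virtual stages per rank."""
    return (n_stages - 1) / (v * n_micro)


print(peak_in_flight(gpipe_order(8)))              # 8: grows with microbatches
print(peak_in_flight(one_f_one_b_order(0, 4, 8)))  # 4: bounded by n_stages
print(bubble_ratio(4, 8), bubble_ratio(4, 8, v=2)) # 0.375 0.1875
```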

Heterogeneous stages

No same-shape constraint. Stages can be completely different:

model = PipelineSequential(
    EmbedStage(vocab, d, rngs=rngs),       # (B, S) int → (B, S, d)
    BlockStage(d, rngs=rngs),               # (B, S, d) → (B, S, d)
    BlockStage(d, rngs=rngs),               # (B, S, d) → (B, S, d)
    HeadStage(d, vocab, rngs=rngs),        # (B, S, d) → (B, S, vocab)
)

Auto-splitting is available for homogeneous stacks: spx.run detects a model.blocks attribute of type ModuleList and slices it evenly across pipeline stages.
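
Even slicing of a block list can be sketched as follows. split_evenly is a hypothetical helper, and the handling of a remainder (earlier stages take one extra block) is an assumption: the text above only says the list is sliced evenly.

```python
def split_evenly(blocks, n_stages):
    """Slice a list into n_stages contiguous chunks whose sizes differ by at most one."""
    if n_stages < 1 or n_stages > len(blocks):
        raise ValueError("need between 1 and len(blocks) stages")
    base, extra = divmod(len(blocks), n_stages)
    chunks, start = [], 0
    for stage in range(n_stages):
        # Assumption: the first `extra` stages absorb the remainder.
        size = base + (1 if stage < extra else 0)
        chunks.append(blocks[start:start + size])
        start += size
    return chunks


print(split_evenly(list(range(8)), 4))   # [[0, 1], [2, 3], [4, 5], [6, 7]]
print([len(c) for c in split_evenly(list(range(10)), 4)])  # [3, 3, 2, 2]
```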


Installation

pip install specflux

# Optional extras
pip install "specflux[contrib]"   # optax integration
pip install "specflux[cuda]"      # CUDA jaxlib
pip install "specflux[tpu]"       # TPU jaxlib

From source:

uv sync --extra dev --extra test --extra contrib

Requires Python 3.11+ and JAX >= 0.9.2.


Features

MPMD Pipeline Parallelism

True multi-program multi-data execution with per-rank compilation, schedule-faithful dispatch, and heterogeneous stage support.

Module-aware JAX transforms

| Transform | What it does |
| --- | --- |
| spx.jit | mutable= selector declares writable collections |
| spx.grad / spx.value_and_grad | wrt= selector picks the differentiated subset |
| spx.vmap | module states passed with in_axes=None automatically |
| spx.scan / spx.remat_scan | module-aware loops, optionally checkpointed |
| spx.remat | gradient checkpointing |
| spx.cond / spx.switch / spx.while_loop / spx.fori_loop | control flow over module state |

Selector DSL

One composable predicate for every "subset of the model" API:

spx.grad(loss, wrt="params")                     # by collection name
spx.grad(loss, wrt=nn.LoraParameter)             # by Variable class
spx.grad(loss, wrt=spx.path_contains("attn"))    # by path glob

sel = spx.select().at_instances_of(nn.Linear).of_type(spx.Parameter) - spx.path_contains("head")
trainable, frozen = sel.partition_state(model, state)
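
The idea of a composable predicate can be shown without the library. The sketch below is an assumption-laden toy, not the Selector implementation: Select, of_type, path_contains, Parameter, and Buffer are hypothetical stand-ins, and state is modelled as a flat dict from dotted path to variable.

```python
class Parameter:
    """Hypothetical trainable value."""
    def __init__(self, value):
        self.value = value


class Buffer:
    """Hypothetical non-trainable value."""
    def __init__(self, value):
        self.value = value


class Select:
    """A predicate over (path, variable) pairs that composes with &, | and -."""

    def __init__(self, pred):
        self.pred = pred

    def __call__(self, path, var):
        return self.pred(path, var)

    def __and__(self, other):
        return Select(lambda p, v: self(p, v) and other(p, v))

    def __or__(self, other):
        return Select(lambda p, v: self(p, v) or other(p, v))

    def __sub__(self, other):
        return Select(lambda p, v: self(p, v) and not other(p, v))

    def partition(self, state):
        """Split a flat {path: variable} dict into (matching, rest)."""
        hit = {p: v for p, v in state.items() if self(p, v)}
        miss = {p: v for p, v in state.items() if not self(p, v)}
        return hit, miss


def of_type(cls):
    return Select(lambda p, v: isinstance(v, cls))


def path_contains(fragment):
    return Select(lambda p, v: fragment in p)


state = {
    "blocks.0.attn.w": Parameter(1.0),
    "blocks.0.norm.mean": Buffer(0.0),
    "head.w": Parameter(2.0),
}
sel = of_type(Parameter) - path_contains("head")
trainable, frozen = sel.partition(state)
print(sorted(trainable))  # ['blocks.0.attn.w']
print(sorted(frozen))     # ['blocks.0.norm.mean', 'head.w']
```

Because every combinator returns another Select, the same object can be handed to any API that needs "a subset of the model".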

Sharding & SPMD

Annotate variables with logical axis names and let the mesh decide:

w = spx.Parameter(
    jnp.zeros((256, 256)),
    sharding=spx.Sharding(("data", "model")),
    axis_names=("in", "out"),
)
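
How logical names turn into a concrete layout can be sketched in a few lines. This is a generic illustration of logical axis rules, not SpecFlux's resolver: resolve_spec and shard_shape are hypothetical helpers, and the rules and mesh sizes below are made-up values.

```python
def resolve_spec(axis_names, rules):
    """Map each logical axis name to a mesh axis, or None when it stays replicated."""
    return tuple(rules.get(name) for name in axis_names)


def shard_shape(shape, spec, mesh_sizes):
    """Per-device shape after dividing each dimension by its mesh axis size."""
    local = []
    for dim, mesh_axis in zip(shape, spec):
        n = 1 if mesh_axis is None else mesh_sizes[mesh_axis]
        if dim % n != 0:
            raise ValueError(f"dimension {dim} is not divisible by mesh axis size {n}")
        local.append(dim // n)
    return tuple(local)


# Assumed example values: logical "in"/"out" map onto mesh axes "data"/"model".
rules = {"in": "data", "out": "model"}
mesh_sizes = {"data": 4, "model": 2}

spec = resolve_spec(("in", "out"), rules)
print(spec)                                       # ('data', 'model')
print(shard_shape((256, 256), spec, mesh_sizes))  # (64, 128)
```

Changing only the rules or the mesh sizes changes the layout, while the variable's annotation stays the same.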

LoRA Fine-Tuning

base = nn.Linear(768, 768, rngs=spx.Rngs(0))
model = nn.wrap_lora(base, rank=8, alpha=16, rngs=spx.Rngs(1))

@spx.jit
@spx.grad(wrt="lora")
def step(m, x, y):
    return ((m(x) - y) ** 2).mean()
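
For context, the standard LoRA formulation adds a low-rank update to a frozen weight: y = W x + (alpha / rank) * B (A x). The sketch below implements that formula in plain Python. Whether nn.wrap_lora uses exactly this scaling or initialization is not stated here, so treat the zero-initialized B and the helper names (matvec, lora_forward, lora_param_count) as assumptions.

```python
def matvec(matrix, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def lora_forward(W, A, B, x, alpha, rank):
    """y = W x + (alpha / rank) * B (A x), with W frozen and A, B trainable."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]


def lora_param_count(n_in, n_out, rank):
    """Trainable parameters in A (rank x n_in) plus B (n_out x rank)."""
    return rank * (n_in + n_out)


W = [[1, 0], [0, 1], [1, 1]]   # frozen base weight, shape (3, 2)
A = [[1, 1]]                   # shape (rank, 2) with rank = 1
x = [2, 3]

# With B at zero (a common initialization), the wrapped layer matches the base layer.
B_zero = [[0], [0], [0]]
print(lora_forward(W, A, B_zero, x, alpha=2, rank=1))   # [2.0, 3.0, 5.0]

B = [[1], [0], [-1]]
print(lora_forward(W, A, B, x, alpha=2, rank=1))        # [12.0, 3.0, -5.0]

# A 768 x 768 layer at rank 8 trains 12288 values instead of 589824.
print(lora_param_count(768, 768, 8), 768 * 768)
```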

FP8 Training

Delayed-scaling per-tensor FP8 with rolling amax history:

@spx.jit(mutable="fp8_meta")
def step(m, x, y):
    def loss(m, x, y):
        return ((m(x) - y) ** 2).mean()
    return spx.grad(loss)(m, x, y)
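
Delayed scaling means the scale applied at a given step comes from absolute maxima recorded in earlier steps, not from the tensor being cast. A minimal sketch of that bookkeeping follows. It assumes the E4M3 format (largest finite value 448), takes the max over the history window, and uses a hypothetical DelayedScale class. None of these choices are confirmed as SpecFlux's.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite value of the OCP FP8 E4M3 format


class DelayedScale:
    """Per-tensor scale computed from a rolling history of past amax values."""

    def __init__(self, history_len=16):
        self.history = deque(maxlen=history_len)  # oldest entries fall off automatically

    def scale(self):
        """Scale for the current step, derived only from earlier steps."""
        if not self.history:
            return 1.0
        amax = max(self.history)
        return E4M3_MAX / amax if amax > 0 else 1.0

    def observe(self, tensor):
        """Record this step's absolute maximum for use in later steps."""
        self.history.append(max(abs(v) for v in tensor))


meta = DelayedScale(history_len=2)
print(meta.scale())          # 1.0, no history yet
meta.observe([0.5, -2.0])
print(meta.scale())          # 224.0 = 448 / 2
meta.observe([7.0])
print(meta.scale())          # 64.0 = 448 / 7
meta.observe([1.0])
meta.observe([0.5])
print(meta.scale())          # 448.0, the 7.0 entry has rolled out of the window
```

The history is state that changes on every step, which is why a jitted step function would need to declare it writable, as mutable="fp8_meta" does above.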

Explicit graph / state seam

gdef, state = spx.export(model)          # immutable GraphDef + State dict
model2 = spx.bind(gdef, state)           # reconstruct, skips __init__
spx.update(model, state)                 # in-place state patch
clone = spx.clone(model)                 # deep-copy via export+bind
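
The graph/state seam can be mimicked in plain Python to show what "reconstruct, skips __init__" means. This is a toy under stated assumptions, not the spx.export / spx.bind implementation: Cell, Linear, MLP, export, and bind here are hypothetical, the structure is a nested (class, children) tuple, and the state is a flat dict keyed by dotted path.

```python
class Cell:
    """Hypothetical stand-in for a Variable: a mutable box around a value."""
    def __init__(self, value):
        self.value = value


class Linear:
    def __init__(self, n_in, n_out):
        self.w = Cell([[0.0] * n_out for _ in range(n_in)])
        self.b = Cell([0.0] * n_out)


class MLP:
    def __init__(self):
        self.fc1 = Linear(2, 3)
        self.fc2 = Linear(3, 1)


def export(module, prefix=""):
    """Return (graphdef, state): the module structure plus a flat path -> value dict."""
    children, state = {}, {}
    for name, attr in vars(module).items():
        path = f"{prefix}{name}"
        if isinstance(attr, Cell):
            children[name] = None              # leaf: the value lives in `state`
            state[path] = attr.value
        else:
            sub_def, sub_state = export(attr, prefix=path + ".")
            children[name] = sub_def
            state.update(sub_state)
    return (type(module), children), state


def bind(graphdef, state, prefix=""):
    """Rebuild a module from (graphdef, state) without running __init__."""
    cls, children = graphdef
    module = object.__new__(cls)               # allocate without calling __init__
    for name, sub_def in children.items():
        path = f"{prefix}{name}"
        if sub_def is None:
            setattr(module, name, Cell(state[path]))
        else:
            setattr(module, name, bind(sub_def, state, prefix=path + "."))
    return module


gdef, state = export(MLP())
print(sorted(state))                  # ['fc1.b', 'fc1.w', 'fc2.b', 'fc2.w']

state["fc1.b"] = [1.0, 2.0, 3.0]      # patch the state outside the module
model2 = bind(gdef, state)
print(type(model2).__name__, model2.fc1.b.value)   # MLP [1.0, 2.0, 3.0]
```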

Inspection

spx.inspect.summary(model, jnp.zeros((1, 128)))
spx.inspect.count_params(model)
spx.inspect.count_bytes(model)
spx.inspect.tabulate(model, example_input)

Examples

See examples/ for runnable scripts:

| Folder | Topic |
| --- | --- |
| 01_basics/ | Modules, training loops, export/bind, optimizers |
| 02_implementation_guide/ | Llama 3, Qwen 2, GPT-2, ViT, custom transformer |
| 03_transformations/ | jit, grad, vmap, remat, scan, fori_loop |
| 04_surgery/ | Selectors, LoRA injection, FP8, freezing, swapping |
| 05_shardings/ | FSDP, TP, hybrid sharding, logical axis rules |
| 06_spmd_scheduled/ | Pipeline runtime with all schedules |
| 07_mpmd/ | Real MPMD pipeline via spx.run: train, forward, decode, 3-D mesh |

python -m examples.01_basics.02_training_loop
python -m examples.02_implementation_guide.01_llama3
python -m examples.07_mpmd.01_train_homogeneous

Design

  1. True MPMD first: sxcall compiles one XLA program per physical rank. Stages can differ in class, shape, and parameters. No SPMD-same-shape constraint.
  2. Unified runtime: spx.run dispatches on the mesh. Same model, same script; change the mesh and you change the parallelism strategy.
  3. Schedule-faithful execution: the dispatch loop walks the schedule grid exactly as planned. No hidden reordering, no implicit fusing.
  4. Modules are JAX pytrees: flatten/unflatten through export/bind. jax.jit, jax.tree.map, jax.value_and_grad work directly.
  5. State lives in Variable cells: Parameter, Buffer, custom subclasses. Each tags its collection (params, buffers, lora, fp8_meta, ...).
  6. One filter DSL everywhere: Selector serves grad(wrt=...), jit(mutable=...), Optimizer(wrt=...), freeze(...), iter_variables(select=...).

Benchmarks

python -m benchmarks.bench --cases all --device cpu
python -m benchmarks.bench --cases large --device tpu     # 1.21B transformer

Results land in benchmarks/results/latest.{json,md} with per-case specflux_ms / nnx_ms ratios.

On a tiny CPU dispatch-bound benchmark (2-layer / d=64 / batch-2 transformer), specflux runs at 1.83× the speed of flax.nnx; on d=48 it hits 2.0×. On compute-bound workloads (TPU 8B) the Python gap shrinks but stays positive.


Testing

pytest tests/ -q
pytest tests/test_conformance.py

Status

v0.0.1 — alpha. API may still move; pin the version if you depend on behavioural stability.


License

Apache-2.0.
