
Datarax: A high-performance, NNX-based data pipeline framework for JAX


Datarax: A Data Pipeline Framework for JAX


Project Status: Active


Early Development - API Unstable

Datarax is in early development and undergoing rapid iteration. Breaking changes are expected. Pin to specific commits if stability is required. We recommend waiting for a stable release (v1.0) before using Datarax in production.


Datarax (Data + Array/JAX) is an extensible data pipeline framework built for JAX-based machine learning workflows. It leverages JAX's JIT compilation, automatic differentiation, and hardware acceleration to build data loading, preprocessing, and augmentation pipelines that run on CPUs, GPUs, and TPUs.

Key Features

  • JAX-Native Design: All core components built on JAX's functional paradigm with Flax NNX module system for state management
  • High Performance: JIT-compiled pipelines via XLA, with built-in profiling and roofline analysis
  • DAG Execution Engine: Graph-based pipeline construction with branching, parallel execution, caching, and rebatching nodes
  • Scalability: Multi-device and multi-host data distribution with device mesh sharding
  • Determinism: Reproducible pipelines by default using Grain's Feistel cipher shuffling (O(1) memory)
  • Extensibility: Custom data sources, operators, and augmentation strategies via composable NNX modules
  • Benchmarking Suite: Comparative benchmarks against 12+ frameworks with CalibraX-powered analysis and regression checks
  • Ecosystem Integration: Works with Flax, Optax, Orbax, HuggingFace Datasets, and TensorFlow Datasets

Why Datarax?

JAX has mature libraries for models (Flax), optimizers (Optax), and checkpointing (Orbax), but lacks a dedicated data pipeline framework that operates at the same level of abstraction. Existing options are either framework-agnostic loaders that return NumPy arrays (losing JIT/autodiff benefits) or wrappers around tf.data/PyTorch that introduce cross-framework overhead. Datarax aims to fill this gap. The framework is under active development and performance work is ongoing: the architecture is functional, but throughput and the API surface are still being refined.

JAX-Native from the Ground Up

Every component — sources, operators, batchers, samplers, sharders — is a Flax NNX module. Pipeline state is managed through NNX's variable system, which means operators can hold learnable parameters, be serialized with Orbax, and participate in JAX transformations (jit, vmap, grad) without special handling.

Differentiable Data Pipelines

Because operators are NNX modules, gradients can flow through the entire pipeline. This enables approaches that standard data loaders cannot support, such as training the parameters of a preprocessing or augmentation operator jointly with the model.

See the differentiable pipeline examples for details.
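To make the idea concrete, here is a minimal plain-JAX sketch (not the Datarax API) of a learnable preprocessing step whose parameters receive gradients from a downstream loss. The `preprocess` and `loss_fn` functions and the parameter dict are hypothetical; in Datarax such parameters would presumably live as NNX variables inside an operator module.

```python
import jax
import jax.numpy as jnp


# Hypothetical learnable preprocessing step: an affine transform whose
# scale and shift (one value per channel) are trainable parameters.
def preprocess(params, images):
    return images * params["scale"] + params["shift"]


# Toy downstream objective: squared error on each image's mean pixel value.
def loss_fn(params, images, targets):
    feats = preprocess(params, images).mean(axis=(1, 2, 3))
    return jnp.mean((feats - targets) ** 2)


key = jax.random.PRNGKey(0)
images = jax.random.uniform(key, (8, 28, 28, 1))  # values in [0, 1)
targets = jnp.zeros((8,))
params = {"scale": jnp.ones((1,)), "shift": jnp.zeros((1,))}

# Gradients flow back through the preprocessing step to its parameters.
grads = jax.grad(loss_fn)(params, images, targets)
print(grads["scale"].shape, grads["shift"].shape)  # (1,) (1,)
```

The same `grads` dict could be fed to an Optax optimizer alongside the model's gradients, which is what "training the data pipeline" amounts to.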

DAG Execution Model

Pipelines are directed acyclic graphs, not linear chains. The >> operator composes sequential steps, | creates parallel branches, and control-flow nodes (Branch, Merge, SplitField) handle conditional and multi-path logic. The DAG executor manages scheduling, caching, and rebatching across the graph.
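The composition operators can be illustrated with a small, self-contained Python sketch. `Node`, `Op`, `Sequential`, and `Parallel` below are hypothetical stand-ins written for this illustration, not Datarax's node classes; they only show how `>>` and `|` can be overloaded to build a graph that is executed later.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


class Node:
    # `a >> b` chains two nodes; `a | b` runs them side by side.
    def __rshift__(self, other: "Node") -> "Node":
        return Sequential([self, other])

    def __or__(self, other: "Node") -> "Node":
        return Parallel([self, other])

    def __call__(self, x):
        raise NotImplementedError


@dataclass
class Op(Node):
    fn: Callable

    def __call__(self, x):
        return self.fn(x)


@dataclass
class Sequential(Node):
    steps: Sequence[Node]

    def __call__(self, x):
        # Feed each step's output into the next step.
        for step in self.steps:
            x = step(x)
        return x


@dataclass
class Parallel(Node):
    branches: Sequence[Node]

    def __call__(self, x):
        # Every branch sees the same input; outputs are collected in order.
        return [branch(x) for branch in self.branches]


double = Op(lambda x: 2 * x)
inc = Op(lambda x: x + 1)
pipeline = double >> (inc | double) >> Op(sum)
print(pipeline(3))  # 19, i.e. (6 + 1) + (6 * 2)
```

Python binds `>>` more tightly than `|`, so a parallel group in the middle of a chain needs parentheses, exactly as in the branching example later in this README.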

Deterministic Reproducibility

Shuffling uses Grain's Feistel cipher permutation, which generates a full-epoch permutation in O(1) memory without materializing the index array. Combined with explicit RNG key threading through every stochastic operator, pipelines produce identical output given the same seed — across restarts, devices, and host counts.
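The idea behind an O(1)-memory shuffle can be sketched in pure Python. A keyed Feistel network is a bijection on a power-of-two domain, and "cycle-walking" restricts it to `[0, n)`. This is an illustrative sketch, not Grain's implementation; the SHA-256 round function, the four rounds, and the `feistel_index` name are assumptions chosen for clarity rather than speed.

```python
import hashlib


def _round_fn(value: int, key: int, round_idx: int, bits: int) -> int:
    # Keyed pseudo-random round function (hash-based, for illustration only).
    digest = hashlib.sha256(f"{key}:{round_idx}:{value}".encode()).digest()
    return int.from_bytes(digest[:8], "little") & ((1 << bits) - 1)


def feistel_index(i: int, n: int, seed: int, rounds: int = 4) -> int:
    """Map position i in [0, n) to a shuffled index in [0, n) using O(1) memory."""
    # Split the domain into two halves of equal bit width covering [0, n).
    half_bits = max(1, ((n - 1).bit_length() + 1) // 2)
    mask = (1 << half_bits) - 1
    x = i
    while True:
        left, right = x >> half_bits, x & mask
        for r in range(rounds):
            # Classic Feistel round: always invertible, whatever _round_fn is.
            left, right = right, left ^ _round_fn(right, seed, r, half_bits)
        x = (left << half_bits) | right
        # Cycle-walk: re-encrypt until the result falls inside [0, n).
        if x < n:
            return x


perm = [feistel_index(i, 10, seed=42) for i in range(10)]
print(sorted(perm) == list(range(10)))  # True: every index appears exactly once
```

The list comprehension materializes the permutation only to check it. A sampler would call `feistel_index(i, n, seed)` lazily for each position `i`, so memory stays constant regardless of `n`, and the same seed always yields the same order.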

Built-in Competitive Benchmarking

The benchmarking suite profiles Datarax against 12+ frameworks (Grain, tf.data, PyTorch DataLoader, DALI, Ray Data, and others) across standardized scenarios. Results are converted to CalibraX runs for direction-aware metrics, regression gating, and W&B export. This benchmark-driven loop is how Datarax tracks progress toward competitive throughput; current results and optimization status are tracked in the benchmarking documentation.
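"Direction-aware" means each metric knows whether larger values are better. A minimal sketch of a regression gate built on that idea follows; `Metric`, `is_regression`, and the 5% relative tolerance are hypothetical and do not reflect CalibraX's actual API. It assumes positive baseline values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Metric:
    name: str
    higher_is_better: bool


def is_regression(metric: Metric, baseline: float, current: float,
                  tolerance: float = 0.05) -> bool:
    """Flag a regression when `current` is worse than `baseline` by more than `tolerance` (relative)."""
    if metric.higher_is_better:
        return current < baseline * (1 - tolerance)
    return current > baseline * (1 + tolerance)


throughput = Metric("samples_per_sec", higher_is_better=True)
latency = Metric("p50_batch_latency_ms", higher_is_better=False)

print(is_regression(throughput, baseline=10_000, current=9_000))  # True: 10% slower
print(is_regression(latency, baseline=20.0, current=20.5))        # False: within 5%
```

Without the direction flag, a single "did the number go up?" rule would treat faster throughput and slower latency identically, which is the error direction-aware metrics avoid.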

Installation

# Basic installation
uv pip install datarax

# With data loading support (HuggingFace, TFDS, audio/image libs)
uv pip install "datarax[data]"

# With GPU support (CUDA 12)
uv pip install "datarax[gpu]"

# Full development installation
uv pip install "datarax[all]"

macOS / Apple Silicon

# macOS CPU mode (recommended)
uv pip install "datarax[all-cpu]"
JAX_PLATFORMS=cpu python your_script.py

# Metal GPU acceleration (experimental, M1/M2/M3+)
uv pip install jax-metal
JAX_PLATFORMS=metal python your_script.py

Note: Metal GPU acceleration is community-tested. CI runs on macOS with CPU only.
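If you cannot control the shell environment (for example in a notebook), the same `JAX_PLATFORMS` setting can be applied from Python. This is a sketch using JAX's standard environment variable; it only takes effect if it runs before JAX is first imported in the process.

```python
import os

# Must be set before JAX is imported anywhere in this process.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

print(jax.default_backend())  # cpu
print(jax.devices())
```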

Quick Start

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from datarax import build_source_pipeline
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element


def normalize(element: Element, key: jax.Array | None = None) -> Element:
    return element.update_data({"image": element.data["image"] / 255.0})


def augment(element: Element, key: jax.Array) -> Element:
    key1, _ = jax.random.split(key)
    flip = jax.random.bernoulli(key1, 0.5)
    new_image = jax.lax.cond(
        flip, lambda img: jnp.flip(img, axis=1), lambda img: img,
        element.data["image"],
    )
    return element.update_data({"image": new_image})


# Create in-memory data source
data = {
    "image": np.random.randint(0, 256, (1000, 28, 28, 1)).astype(np.float32),  # upper bound is exclusive
    "label": np.random.randint(0, 10, (1000,)).astype(np.int32),
}
source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(0))

# Build pipeline with DAG-based API
normalizer = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=normalize, rngs=nnx.Rngs(0),
)
augmenter = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="augmentations"),
    fn=augment, rngs=nnx.Rngs(42),
)

pipeline = (
    build_source_pipeline(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)

# Process batches
for i, batch in enumerate(pipeline):
    if i >= 3:
        break
    print(f"Batch {i}: images {batch['image'].shape}, labels {batch['label'].shape}")
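For `augment` to produce a different flip per image while staying reproducible, each element needs its own random key. A plain-JAX sketch of one way to do this with `jax.random.split` and `jax.vmap` follows; how `ElementOperator` actually derives per-element keys from its `rngs` stream is not documented here, so treat `augment_batch` as an assumption rather than Datarax's mechanism.

```python
import jax
import jax.numpy as jnp


def random_hflip(image, key):
    # Same logic as `augment` above, on a single (H, W, C) image.
    flip = jax.random.bernoulli(key, 0.5)
    return jax.lax.cond(flip, lambda img: jnp.flip(img, axis=1), lambda img: img, image)


@jax.jit
def augment_batch(images, key):
    # One independent subkey per element, so each image gets its own coin flip.
    keys = jax.random.split(key, images.shape[0])
    return jax.vmap(random_hflip)(images, keys)


images = jnp.arange(32 * 28 * 28, dtype=jnp.float32).reshape(32, 28, 28, 1)
out1 = augment_batch(images, jax.random.PRNGKey(42))
out2 = augment_batch(images, jax.random.PRNGKey(42))
print(out1.shape, bool(jnp.array_equal(out1, out2)))  # (32, 28, 28, 1) True
```

Reusing one key for the whole batch would flip either every image or none; splitting keeps the per-element decisions independent while the same top-level key still reproduces the same batch.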

Advanced: Branching and Parallel DAGs

from datarax.dag.nodes import OperatorNode, Merge, Branch

# Define additional operators
def invert(element: Element, key=None) -> Element:
    return element.update_data({"image": 1.0 - element.data["image"]})

inverter = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=invert, rngs=nnx.Rngs(0),
)

def is_high_contrast(element):
    return jnp.var(element.data["image"]) > 0.1

# Build a complex DAG:
# 1. Source -> Batching
# 2. Parallel: normalizer AND inverter (| creates a Parallel node)
# 3. Merge: average the two branches
# 4. Branch: conditional path based on image variance
complex_pipeline = (
    build_source_pipeline(source, batch_size=32)
    >> (OperatorNode(normalizer) | OperatorNode(inverter))
    >> Merge("mean")
    >> Branch(
        condition=is_high_contrast,
        true_path=OperatorNode(augmenter),
        false_path=OperatorNode(normalizer),
    )
)
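Semantically, the pipeline above fans the batch out to parallel branches, joins them with an element-wise mean, and then picks one path for the whole batch. A plain-JAX sketch of those three steps follows, with simplified hypothetical operators (`gamma_correct`, `hflip`) standing in for the README's operators; it is not how Datarax's executor is implemented. Under `jax.jit` the predicate is a traced value, so the branch uses `jax.lax.cond` rather than a Python `if`.

```python
import jax
import jax.numpy as jnp


def identity(x):
    return x


def gamma_correct(x):
    return jnp.sqrt(x)  # brightens dark pixels; inputs assumed in [0, 1]


def hflip(x):
    return jnp.flip(x, axis=2)  # batch layout (N, H, W, C): axis 2 is width


def is_high_contrast(x):
    return jnp.var(x) > 0.1


@jax.jit
def dag_step(batch):
    # Parallel: both branches receive the same batch.
    outputs = [identity(batch), gamma_correct(batch)]
    # Merge("mean")-style join: element-wise average of the branch outputs.
    merged = jnp.mean(jnp.stack(outputs), axis=0)
    # Branch: one predicate for the whole batch; lax.cond keeps it jit-compatible.
    return jax.lax.cond(is_high_contrast(merged), hflip, identity, merged)


batch = jnp.zeros((2, 4, 4, 1)).at[:, :, :2, :].set(1.0)  # left half white
print(dag_step(batch)[0, 0, :, 0])  # high variance, so the row is flipped: [0. 0. 1. 1.]
```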

Architecture

src/datarax/
  core/         # Base modules: DataSourceModule, OperatorModule, Element, Batcher, Sampler, Sharder
  dag/          # DAG executor and node system (source, operator, batch, cache, control flow)
  sources/      # MemorySource, TFDS (eager/streaming), HuggingFace (eager/streaming), ArrayRecord, MixedSource
  operators/    # ElementOperator, MapOperator, CompositeOperator, modality-specific (image, text)
    strategies/ # Sequential, Parallel, Branching, Ensemble, Merging execution strategies
  samplers/     # Sequential, Shuffle (Feistel cipher), Range, EpochAware samplers
  sharding/     # ArraySharder, JaxProcessSharder for multi-device distribution
  distributed/  # DeviceMesh, DataParallel for multi-host training
  batching/     # DefaultBatcher with buffer state management
  checkpoint/   # NNXCheckpointHandler with Orbax integration
  monitoring/   # Pipeline monitor, DAG monitor, reporters
  performance/  # Roofline analysis, XLA optimization utilities
  control/      # Prefetcher for asynchronous data loading
  memory/       # Shared memory manager for multi-process data sharing
  config/       # TOML-based configuration system with schema validation
  cli/          # datarax CLI entry point
  utils/        # PyTree utilities, external integration helpers
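The `sharding/` and `distributed/` modules presumably build on JAX's device-mesh sharding. The sketch below uses only core JAX (`Mesh`, `NamedSharding`, `PartitionSpec`) to split a batch's leading dimension across all local devices; it does not use Datarax's `ArraySharder` or `DeviceMesh`, whose signatures are not shown here.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all local devices, with a single "data" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Split the leading (batch) dimension across "data"; other dims are replicated.
batch_sharding = NamedSharding(mesh, P("data"))

global_batch = jnp.ones((8 * jax.device_count(), 28, 28, 1))
sharded = jax.device_put(global_batch, batch_sharding)

# Each device holds an equal slice of the batch.
for shard in sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

On a single CPU device the mesh has one entry and the whole batch lands on it; on an 8-GPU host each device would receive 8 of the 64 rows.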

Benchmarking

Datarax includes a benchmarking suite for comparison against 12+ data loading frameworks across a range of workload scenarios (vision, NLP, tabular, multimodal, distributed).

# Install benchmark dependencies (adds PyTorch, DALI, Ray, etc.)
uv sync --extra benchmark

# Optional: install CalibraX with W&B support explicitly
uv pip install "calibrax[wandb] @ git+https://github.com/avitai/calibrax.git"

# Run benchmarks locally
uv run python -m benchmarks.runners.full_runner --platform cpu --repetitions 5

# Run on cloud (SkyPilot)
sky launch benchmarks/sky/gpu-benchmark.yaml --env WANDB_API_KEY=$WANDB_API_KEY

Benchmark results are exported to W&B with charts, gap analysis, stability reports, and raw result artifacts. See Benchmarking Guide for methodology and cloud deployment.

Development Setup

Datarax uses uv as its package manager:

# Clone and setup
git clone https://github.com/avitai/datarax.git
cd datarax

# Automatic setup
./setup.sh && source activate.sh

# Or manual install
uv sync --extra dev

Running Tests

# CPU-only (most stable)
JAX_PLATFORMS=cpu uv run pytest

# Include benchmark test suite in the same run
JAX_PLATFORMS=cpu uv run pytest --all-suites

# Specific module
JAX_PLATFORMS=cpu uv run pytest tests/sources/test_memory_source_module.py

Docker

# Build and run
docker build -t datarax:latest .
docker run --rm --gpus all datarax:latest python -c "import datarax, jax; print(jax.devices())"

# Benchmark images
docker build -f benchmarks/docker/Dockerfile.gpu -t datarax-bench:gpu .

See Docker Guide for full details.

Documentation

License

Datarax is licensed under the MIT License.



Download files


Source Distribution

datarax-0.1.2.post1.tar.gz (241.0 kB)

Uploaded Source

Built Distribution


datarax-0.1.2.post1-py3-none-any.whl (298.4 kB)

Uploaded Python 3

File details

Details for the file datarax-0.1.2.post1.tar.gz.

File metadata

  • Download URL: datarax-0.1.2.post1.tar.gz
  • Size: 241.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for datarax-0.1.2.post1.tar.gz
Algorithm Hash digest
SHA256 f4f3098241c250300d3ee579e5ea3910f768f9e9989c88c8f3d7aefe153ff63c
MD5 65f8e533b1fc36f0f5726e3e0ce77ec1
BLAKE2b-256 11729d3a3bfc8daa2366ae141ff4a6e6e45e329b3154564a16d6482316d0cfcc


Provenance

The following attestation bundles were made for datarax-0.1.2.post1.tar.gz:

Publisher: publish.yml on avitai/datarax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file datarax-0.1.2.post1-py3-none-any.whl.

File metadata

  • Download URL: datarax-0.1.2.post1-py3-none-any.whl
  • Size: 298.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for datarax-0.1.2.post1-py3-none-any.whl
Algorithm Hash digest
SHA256 49ffd050049c8d162783921f4af049d429e539c89b76b7f36b5e7df316499c00
MD5 a2449aed234f71d80c27873d82fe777e
BLAKE2b-256 3a6d93fb2d07eb84c7abab96a9532ddd29f9cc7bc4cf7f50c6d92c67806e5635


Provenance

The following attestation bundles were made for datarax-0.1.2.post1-py3-none-any.whl:

Publisher: publish.yml on avitai/datarax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
