Datarax: A high-performance, NNX-based data pipeline framework for JAX

Datarax: A Data Pipeline Framework for JAX

Project Status: Active

Note: This project is in early development. The API may change, so expect breaking changes.

Datarax (Data + Array/JAX) is a high-performance, extensible data pipeline framework built for JAX-based machine learning workflows. It is meant to simplify building efficient, scalable data loading, preprocessing, and augmentation pipelines, and it relies on JAX's just-in-time (JIT) compilation, automatic differentiation, and hardware acceleration.

Key Features

  • High Performance: Leverages JAX's JIT compilation and XLA backend to achieve near-optimal data processing speeds on CPUs, GPUs, and TPUs.
  • JAX-Native Design: All core components and operations are designed with JAX's functional programming paradigm and immutable data structures (PyTrees) in mind.
  • Scalability: Supports efficient data loading and processing for large datasets and distributed training scenarios, including multi-host and multi-device setups.
  • Extensibility: Easily define and integrate custom data sources, transformations, and augmentation operations.
  • Usability: Provides a clear, intuitive Python API and a flexible configuration system (TOML-based) for defining and managing pipelines.
  • Determinism: Pipeline runs are deterministic by default, crucial for reproducibility in research and production.
  • Caching Optimization: Multiple caching strategies for performance improvement, including function caching, transformer caching, and checkpointing.
  • Complete Feature Set: Supports common data pipeline operations including diverse data source handling, advanced transformations, data augmentation, batching, sharding, checkpointing, and caching.
  • Ecosystem Integration: Facilitates smooth integration with other JAX libraries like Flax, Optax, and Orbax.
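
As an illustration of the determinism and scalability points above, here is a minimal sketch in plain Python. It is not Datarax's API: the helper `host_indices` and its parameters are hypothetical and only show the general idea of a seeded shuffle combined with strided per-host sharding.

```python
import random


def host_indices(num_samples: int, seed: int, epoch: int,
                 host_id: int, num_hosts: int) -> list[int]:
    """Return the sample indices one host reads in a given epoch.

    Hypothetical helper, not part of Datarax: every host derives the same
    permutation from (seed, epoch), then takes a disjoint strided slice.
    """
    order = list(range(num_samples))
    # A dedicated Random instance keeps the shuffle independent of global state.
    random.Random(f"{seed}:{epoch}").shuffle(order)
    return order[host_id::num_hosts]


# Same inputs always give the same order; different hosts never overlap.
first = host_indices(10, seed=0, epoch=0, host_id=0, num_hosts=2)
second = host_indices(10, seed=0, epoch=0, host_id=1, num_hosts=2)
```

Because the permutation depends only on the seed and the epoch, rerunning a job or adding a host does not change which samples exist in an epoch, only how they are split.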

Installation

IMPORTANT: Datarax uses uv as its package manager for all installation, development, and deployment tasks.

# Install uv if you don't have it already
pip install uv

# Run the automatic setup script (creates environment & installs dependencies)
./setup.sh

# Activate the environment
source activate.sh

# Install via uv
uv pip install datarax

# Install with optional dependencies
# (extras are quoted so shells such as zsh do not expand the brackets)
uv pip install "datarax[data]"   # For data loading (HF, TFDS, Audio/Image libs)
uv pip install "datarax[gpu]"    # For GPU support (CUDA 12)

# For development
uv pip install "datarax[dev]"

# For all dependencies
uv pip install "datarax[all]"

Development Environment

Datarax development requires using uv for all package management operations. See the Development Environment Guide for detailed instructions on setting up a development environment, including:

  • Creating a virtual environment with uv venv
  • Installing dependencies through pyproject.toml
  • Running tests through pytest
  • Building and packaging Datarax

Current Status

Datarax is currently in active development with a Flax NNX-based architecture that provides robust state management and checkpointing:

Core Architecture

  • NNX Module System: All components built on flax.nnx.Module for robust state management
  • Integrated Checkpointing: Seamless Orbax integration for stateful pipeline persistence
  • Type Safety: Complete type annotations and runtime validation
  • Composability: Modular design enabling flexible pipeline construction

Implemented Components

  • Core NNX Modules: DataraxModule, OperatorModule, StructuralModule
  • Data Sources: MemorySource, TFDSSource, HFSource (inheriting from StructuralModule/DataSourceModule)
  • Operators: Element-wise operators, MapOperator (inheriting from OperatorModule)
  • DAG Execution Engine: DAGExecutor and pipeline API for constructing flexible, graph-based data processing flows.
  • Node System: A rich set of nodes for building pipelines:
    • DataSourceNode: entry point for data
    • OperatorNode: for transformations
    • BatchNode & RebatchNode: for batching control
    • ShuffleNode, CacheNode: for data management
    • Control flow nodes: Sequential, Parallel, Branch, Merge
  • Stateful Components:
    • MemorySourceModule/ArrayRecordSourceModule/HFSourceModule for diverse data ingestion
    • Range/ShuffleSamplerModule with reproducible random state
    • DefaultBatcherModule with buffer state management
    • ArraySharderModule for device-aware data distribution
  • External Integrations:
    • HuggingFace Datasets with stateful iteration
    • TensorFlow Datasets with checkpoint support
  • Advanced Features:
    • Complete caching strategies with state preservation
    • Differentiable rebatching (DifferentiableRebatchImpl)
    • NNXCheckpointHandler for production-grade checkpointing
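
The idea behind stateful components with reproducible random state can be sketched in plain Python. The class below is a toy, not Datarax's sampler: `ShuffleSampler`, `get_state`, and `set_state` are hypothetical names used only to show how saving a seed and a position lets iteration resume exactly where it stopped.

```python
import random


class ShuffleSampler:
    """Toy sampler with a seeded order and a resumable position."""

    def __init__(self, num_samples: int, seed: int) -> None:
        self.num_samples = num_samples
        self.seed = seed
        self.position = 0  # number of indices emitted so far
        self._order = self._permutation()

    def _permutation(self) -> list[int]:
        order = list(range(self.num_samples))
        random.Random(self.seed).shuffle(order)
        return order

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self.position >= self.num_samples:
            raise StopIteration
        index = self._order[self.position]
        self.position += 1
        return index

    def get_state(self) -> dict[str, int]:
        return {"seed": self.seed, "position": self.position}

    def set_state(self, state: dict[str, int]) -> None:
        self.seed = state["seed"]
        self.position = state["position"]
        self._order = self._permutation()  # rebuild the order for the restored seed


sampler = ShuffleSampler(num_samples=8, seed=7)
head = [next(sampler) for _ in range(3)]
checkpoint = sampler.get_state()   # save after three samples
tail = list(sampler)               # consume the rest

resumed = ShuffleSampler(num_samples=8, seed=0)
resumed.set_state(checkpoint)      # restore seed and position
resumed_tail = list(resumed)       # continues with the same remaining indices
```

Only the seed and the position are stored, so the checkpoint stays small regardless of dataset size. A real implementation would also need to persist buffer contents for components such as batchers.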

Upcoming features include:

  • Image transformation library with JAX-native operations
  • Advanced sharding strategies for multi-device and multi-host scenarios
  • Performance optimization suite with benchmarking tools
  • Extended monitoring and metrics capabilities
  • Additional external data source integrations

Quick Start

Here's a simple example of using Datarax's DAG-based architecture to create a data pipeline for image classification:

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element


def normalize(element: Element, key: jax.Array | None = None) -> Element:
    """Normalize the image in the element."""
    # element.data is a dict-like PyTree containing the actual data arrays
    # Return a new Element with updated data (immutable update)
    # update_data provides a clean API for partial updates
    return element.update_data({"image": element.data["image"] / 255.0})


def apply_augmentation(element: Element, key: jax.Array) -> Element:
    """Apply a simple augmentation (flipping) to the image."""
    key1, _ = jax.random.split(key)
    flip_horizontal = jax.random.bernoulli(key1, 0.5)

    # Use jax.lax.cond for JAX-compatible conditional execution
    def flip_image(img):
        return jnp.flip(img, axis=1)

    def no_flip(img):
        return img

    new_image = jax.lax.cond(
        flip_horizontal,
        flip_image,
        no_flip,
        element.data["image"]
    )
    return element.update_data({"image": new_image})


# Create some dummy data (28x28 images)
num_samples = 1000
image_shape = (28, 28, 1)
data = {
    # randint's upper bound is exclusive, so 256 covers the full 0-255 pixel range
    "image": np.random.randint(0, 256, (num_samples, *image_shape)).astype(np.float32),
    "label": np.random.randint(0, 10, (num_samples,)).astype(np.int32),
}

# Create the data source using config-based API
# MemorySource is a standard DataSource implementation for in-memory data
source_config = MemorySourceConfig()
source = MemorySource(source_config, data=data, rngs=nnx.Rngs(0))

# Create operators using the unified ElementOperator API
# Normalizer is deterministic (no random key needed)
normalizer_config = ElementOperatorConfig(stochastic=False)
normalizer = ElementOperator(normalizer_config, fn=normalize, rngs=nnx.Rngs(0))

# Augmenter is stochastic: requires stream_name for proper RNG management
augmenter_config = ElementOperatorConfig(stochastic=True, stream_name="augmentations")
augmenter = ElementOperator(augmenter_config, fn=apply_augmentation, rngs=nnx.Rngs(42))

# Create the data pipeline using the DAG-based API
# from_source() initializes the pipeline with:
# 1. The DataSourceNode (data loading)
# 2. A BatchNode (automatic batching)
# The >> operator chains transformation operators
pipeline = (
    from_source(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)

# Alternative: Method Chaining
# You can also build the pipeline using the fluent .add() method:
# pipeline = (
#     from_source(source, batch_size=32)
#     .add(OperatorNode(normalizer))
#     .add(OperatorNode(augmenter))
# )

# Create an iterator and process batches
# The pipeline handles data streaming, batching, state management, and execution
for i, batch in enumerate(pipeline):
    if i >= 3:
        break

    # Get the shape and stats for each component in the batch
    # batch['key'] provides direct access to data arrays
    image_batch = batch["image"]
    label_batch = batch["label"]

    print(f"Batch {i}:")
    print(f"  Image shape: {image_batch.shape}")
    print(f"  Label batch size: {label_batch.shape[0]}")
    print(f"  Image min/max: {image_batch.min():.3f}/{image_batch.max():.3f}")

print("Pipeline processing completed!")
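
For readers curious how `>>` chaining and the fluent `.add()` form can be equivalent, here is a minimal sketch in plain Python. It is an assumption-laden toy, not Datarax's implementation: the `Pipeline` class and the `scale` and `add_one` stages are hypothetical and operate on single floats rather than batches.

```python
from typing import Callable, Iterable, Iterator


class Pipeline:
    """Toy pipeline: a source of items plus an ordered tuple of stages."""

    def __init__(self, source: Iterable, stages: tuple[Callable, ...] = ()):
        self.source = source
        self.stages = stages

    def __rshift__(self, stage: Callable) -> "Pipeline":
        # Return a new Pipeline instead of mutating this one, so a partial
        # pipeline can be reused as the base of several variants.
        return Pipeline(self.source, self.stages + (stage,))

    def add(self, stage: Callable) -> "Pipeline":
        # Method-chaining spelling of the same operation.
        return self >> stage

    def __iter__(self) -> Iterator:
        for item in self.source:
            for stage in self.stages:
                item = stage(item)
            yield item


def scale(x: float) -> float:
    return x / 255.0


def add_one(x: float) -> float:
    return x + 1.0


chained = Pipeline([0.0, 255.0]) >> scale >> add_one
fluent = Pipeline([0.0, 255.0]).add(scale).add(add_one)
print(list(chained))  # [1.0, 2.0]
```

Stages run in the order they were chained, which is why the normalizer appears before the augmenter in the Quick Start pipeline.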

Advanced Pipeline

For more complex workflows, Datarax supports branching and parallel execution:

# Reuses jnp, from_source, OperatorNode, source, normalizer, and augmenter from the Quick Start
from datarax.dag.nodes import Parallel, Merge, Branch

# Define additional operators
def invert(element, key=None):
    return element.update_data({"image": 1.0 - element.data["image"]})

def is_high_contrast(element):
    # Condition: check if image variance is high
    return jnp.var(element.data["image"]) > 0.1

# Build a complex DAG:
# 1. Source -> Batching
# 2. Parallel: Normal version AND Inverted version
# 3. Merge: Average them (Simple Ensemble)
# 4. Branch: Apply the flip augmentation ONLY if high contrast, otherwise normalize again
complex_pipeline = (
    from_source(source, batch_size=32)
    >> (OperatorNode(normalizer) | OperatorNode(invert))
    >> Merge("mean")
    >> Branch(
           condition=is_high_contrast,
           true_path=OperatorNode(augmenter),
           false_path=OperatorNode(normalizer)
       )
)
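
The data flow of a parallel, merge, and branch graph can be spelled out with ordinary functions. The sketch below is plain Python on flat lists of floats, not Datarax code: `parallel`, `merge_mean`, `branch`, and the small stage functions are hypothetical stand-ins that only illustrate the semantics assumed here.

```python
from statistics import fmean, pvariance
from typing import Callable

Image = list[float]  # a flattened image with values in [0, 1]


def parallel(*paths: Callable[[Image], Image]) -> Callable[[Image], list[Image]]:
    """Run every path on the same input and collect the outputs."""
    return lambda image: [path(image) for path in paths]


def merge_mean(outputs: list[Image]) -> Image:
    """Average the parallel outputs element by element."""
    return [fmean(values) for values in zip(*outputs)]


def branch(condition, true_path, false_path) -> Callable[[Image], Image]:
    """Route the input through one of two paths based on a predicate."""
    return lambda image: true_path(image) if condition(image) else false_path(image)


def identity(image: Image) -> Image:
    return list(image)


def darken(image: Image) -> Image:
    return [0.5 * v for v in image]


def invert(image: Image) -> Image:
    return [1.0 - v for v in image]


def is_high_contrast(image: Image) -> bool:
    return pvariance(image) > 0.1


def run(image: Image) -> Image:
    merged = merge_mean(parallel(identity, darken)(image))
    return branch(is_high_contrast, invert, identity)(merged)


print(run([0.0, 1.0, 0.0, 1.0]))  # high contrast after merging, so inverted
print(run([0.5, 0.5, 0.5, 0.5]))  # uniform image, so passed through unchanged
```

Note that the predicate is evaluated on the merged result, not on the original input, so the choice of merge strategy affects which branch runs.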

Documentation

For complete documentation, please visit datarax.readthedocs.io.

Testing

Datarax has a comprehensive test suite that supports both CPU and GPU testing:

  • Tests are organized to mirror the src/datarax package structure for easier navigation
  • All GitHub CI workflows run tests exclusively on CPU
  • Local test runs automatically use GPU when available
  • The testing infrastructure handles environment configuration to ensure consistency

To run tests:

# Run all tests using CPU only (most stable)
JAX_PLATFORMS=cpu python -m pytest

# Run specific test module
JAX_PLATFORMS=cpu python -m pytest tests/sources/test_memory_source.py

For more information on the test organization and how to run tests, see the Testing Guide.
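
To mirror the CPU-only CI behaviour locally without typing the environment variable each time, the variable can be set from Python before JAX is imported. The following is a hedged sketch of what such a hook might look like in a conftest.py; the helper `force_cpu` is hypothetical, and whether Datarax's own test infrastructure works this way is not stated here.

```python
import os


def force_cpu(env: dict[str, str]) -> dict[str, str]:
    """Return a copy of `env` with JAX_PLATFORMS defaulting to "cpu".

    An existing JAX_PLATFORMS value is kept, so an explicit choice such as
    JAX_PLATFORMS=cuda on the command line still wins.
    """
    updated = dict(env)
    updated.setdefault("JAX_PLATFORMS", "cpu")
    return updated


# In a conftest.py this would run at import time, before any test imports jax.
os.environ.update(force_cpu(dict(os.environ)))
```

Using a default rather than an unconditional assignment keeps local GPU runs possible for anyone who sets the variable explicitly.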

License

Datarax is licensed under the MIT License.
