
Datarax: A Data Pipeline Framework for JAX


Project Status: Active

Note: This project is in early development; the API is unstable and breaking changes should be expected.

Datarax (Data + Array/JAX) is a high-performance, extensible data pipeline framework engineered for JAX-based machine learning workflows. It simplifies the development of efficient, scalable data loading, preprocessing, and augmentation pipelines, leveraging JAX's just-in-time (JIT) compilation, automatic differentiation, and hardware acceleration throughout.
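The functional style this paradigm implies — building new PyTrees instead of mutating existing ones — can be illustrated with a plain-Python sketch. The `update_data` helper and dict-of-arrays layout here are illustrative only, mirroring the shape of the Quick Start example below rather than the actual Datarax API:

```python
# Illustrative only: a minimal immutable-update helper in plain Python,
# mirroring the style JAX's PyTree model encourages. Not Datarax API.

def update_data(element: dict, updates: dict) -> dict:
    """Return a new element with `updates` merged in; the original is untouched."""
    return {**element, **updates}

original = {"image": [255.0, 128.0], "label": 7}
normalized = update_data(original, {"image": [v / 255.0 for v in original["image"]]})

print(normalized["image"])  # the new element holds the normalized values
print(original["image"])    # the original element is unchanged
```

Because every transformation returns a fresh structure, pipelines stay compatible with JAX transformations such as `jit` and `vmap`, which assume pure functions over immutable inputs.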

Key Features

  • High Performance: Leverages JAX's JIT compilation and XLA backend for fast data processing on CPUs, GPUs, and TPUs.
  • JAX-Native Design: All core components and operations are designed with JAX's functional programming paradigm and immutable data structures (PyTrees) in mind.
  • Scalability: Supports efficient data loading and processing for large datasets and distributed training scenarios, including multi-host and multi-device setups.
  • Extensibility: Easily define and integrate custom data sources, transformations, and augmentation operations.
  • Usability: Provides a clear, intuitive Python API and a flexible configuration system (TOML-based) for defining and managing pipelines.
  • Determinism: Pipeline runs are deterministic by default, which is crucial for reproducibility in research and production.
  • Caching Optimization: Multiple caching strategies for performance improvement, including function caching, transformer caching, and checkpointing.
  • Complete Feature Set: Supports common data pipeline operations including diverse data source handling, advanced transformations, data augmentation, batching, sharding, checkpointing, and caching.
  • Ecosystem Integration: Facilitates smooth integration with other JAX libraries like Flax, Optax, and Orbax.
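The determinism guarantee comes down to explicit, seeded random state rather than hidden globals. The principle can be shown with the standard library alone (the `shuffled_indices` helper is illustrative, not Datarax's sampler API):

```python
import random

def shuffled_indices(num_samples: int, seed: int) -> list[int]:
    """Deterministically shuffle dataset indices from an explicit seed."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # local RNG: no hidden global state
    return indices

# Re-running with the same seed reproduces the epoch's sample order exactly.
run_a = shuffled_indices(10, seed=42)
run_b = shuffled_indices(10, seed=42)
assert run_a == run_b

# A different seed yields its own, equally reproducible, order.
run_c = shuffled_indices(10, seed=7)
```

In Datarax the same idea is carried by `nnx.Rngs` seeds threaded through samplers and stochastic operators, as the Quick Start example shows.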

Installation

Install Datarax using pip:

# Basic installation
pip install datarax

# With data loading support (HuggingFace, TFDS, Audio/Image libs)
pip install datarax[data]

# With GPU support (CUDA 12)
pip install datarax[gpu]

# Full installation with all optional dependencies
pip install datarax[all]

macOS / Apple Silicon

Datarax supports macOS on both Intel and Apple Silicon (M1/M2/M3) processors.

# Install for macOS (CPU mode - recommended)
pip install datarax[all-cpu]

# Run with explicit CPU backend
JAX_PLATFORMS=cpu python your_script.py

Metal GPU Acceleration (Experimental): JAX supports Apple's Metal backend for GPU acceleration on M1+ chips:

pip install jax-metal
JAX_PLATFORMS=metal python your_script.py

Note: Metal GPU acceleration is community-tested. CI runs on macOS with CPU only. If you can help validate Metal support, please open an issue with your results.

Known Limitations:

  • TensorFlow Datasets: May hang on import on macOS ARM64. Datarax uses lazy imports to handle this.
  • JAX Profiler: Tracing is automatically disabled on macOS due to TensorBoard compatibility.
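The lazy-import workaround mentioned above is a standard Python pattern: defer the heavy import until the dependency is actually used, so merely importing the package cannot hang. A minimal sketch (this `LazyModule` proxy is illustrative, not Datarax's internal code; a standard-library module stands in for tensorflow_datasets so the sketch runs anywhere):

```python
import importlib

class LazyModule:
    """Proxy that defers importing a module until an attribute is accessed."""

    def __init__(self, name: str):
        self._name = name
        self._module = None

    def __getattr__(self, attr):
        # Only called for attributes not found on the proxy itself,
        # so the real import happens on first use of the wrapped module.
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)

# In Datarax this pattern guards tensorflow_datasets; demonstrated here
# with the standard-library json module.
json = LazyModule("json")
print(json.dumps({"ok": True}))  # triggers the real import on first use
```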

Development Setup

For contributors and developers, Datarax uses uv as its package manager:

# Install uv
pip install uv

# Clone and setup
git clone https://github.com/avitai/datarax.git
cd datarax

# Run automatic setup (creates venv & installs dependencies)
./setup.sh

# Activate the environment
source activate.sh

# Or install manually with uv
uv pip install -e ".[dev]"

See the Development Environment Guide for detailed instructions on:

  • Creating a virtual environment with uv venv
  • Installing dependencies through pyproject.toml
  • Running tests through pytest
  • Building and packaging Datarax

Current Status

Datarax is currently in active development with a Flax NNX-based architecture that provides robust state management and checkpointing:

Core Architecture

  • NNX Module System: All components built on flax.nnx.Module for robust state management
  • Integrated Checkpointing: Seamless Orbax integration for stateful pipeline persistence
  • Type Safety: Complete type annotations and runtime validation
  • Composability: Modular design enabling flexible pipeline construction

Implemented Components

  • Core NNX Modules: DataraxModule, OperatorModule, StructuralModule
  • Data Sources: MemorySource, TFDSSource, HFSource (inheriting from StructuralModule/DataSourceModule)
  • Operators: Element-wise operators, MapOperator (inheriting from OperatorModule)
  • DAG Execution Engine: DAGExecutor and pipeline API for constructing flexible, graph-based data processing flows
  • Node System: A rich set of nodes for building pipelines:
    • DataSourceNode: entry point for data
    • OperatorNode: for transformations
    • BatchNode & RebatchNode: for batching control
    • ShuffleNode, CacheNode: for data management
    • Control flow nodes: Sequential, Parallel, Branch, Merge
  • Stateful Components:
    • MemorySourceModule/ArrayRecordSourceModule/HFSourceModule for diverse data ingestion
    • Range/ShuffleSamplerModule with reproducible random state
    • DefaultBatcherModule with buffer state management
    • ArraySharderModule for device-aware data distribution
  • External Integrations:
    • HuggingFace Datasets with stateful iteration
    • TensorFlow Datasets with checkpoint support
  • Advanced Features:
    • Complete caching strategies with state preservation
    • Differentiable rebatching (DifferentiableRebatchImpl)
    • NNXCheckpointHandler for production-grade checkpointing

Upcoming features include:

  • Image transformation library with JAX-native operations
  • Advanced sharding strategies for multi-device and multi-host scenarios
  • Performance optimization suite with benchmarking tools
  • Extended monitoring and metrics capabilities
  • Additional external data source integrations

Quick Start

Here's a simple example of using Datarax's DAG-based architecture to create a data pipeline for image classification:

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element


def normalize(element: Element, key: jax.Array | None = None) -> Element:
    """Normalize the image in the element."""
    # element.data is a dict-like PyTree containing the actual data arrays
    # Return a new Element with updated data (immutable update)
    # update_data provides a clean API for partial updates
    return element.update_data({"image": element.data["image"] / 255.0})


def apply_augmentation(element: Element, key: jax.Array) -> Element:
    """Apply a simple augmentation (flipping) to the image."""
    key1, _ = jax.random.split(key)
    flip_horizontal = jax.random.bernoulli(key1, 0.5)

    # Use jax.lax.cond for JAX-compatible conditional execution
    def flip_image(img):
        return jnp.flip(img, axis=1)

    def no_flip(img):
        return img

    new_image = jax.lax.cond(
        flip_horizontal,
        flip_image,
        no_flip,
        element.data["image"]
    )
    return element.update_data({"image": new_image})


# Create some dummy data (28x28 images)
num_samples = 1000
image_shape = (28, 28, 1)
data = {
    "image": np.random.randint(0, 255, (num_samples, *image_shape)).astype(np.float32),
    "label": np.random.randint(0, 10, (num_samples,)).astype(np.int32),
}

# Create the data source using config-based API
# MemorySource is a standard DataSource implementation for in-memory data
source_config = MemorySourceConfig()
source = MemorySource(source_config, data=data, rngs=nnx.Rngs(0))

# Create operators using the unified ElementOperator API
# Normalizer is deterministic (no random key needed)
normalizer_config = ElementOperatorConfig(stochastic=False)
normalizer = ElementOperator(normalizer_config, fn=normalize, rngs=nnx.Rngs(0))

# Augmenter is stochastic: requires stream_name for proper RNG management
augmenter_config = ElementOperatorConfig(stochastic=True, stream_name="augmentations")
augmenter = ElementOperator(augmenter_config, fn=apply_augmentation, rngs=nnx.Rngs(42))

# Create the data pipeline using the DAG-based API
# from_source() initializes the pipeline with:
# 1. The DataSourceNode (data loading)
# 2. A BatchNode (automatic batching)
# The >> operator chains transformation operators
pipeline = (
    from_source(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)

# Alternative: Method Chaining
# You can also build the pipeline using the fluent .add() method:
# pipeline = (
#     from_source(source, batch_size=32)
#     .add(OperatorNode(normalizer))
#     .add(OperatorNode(augmenter))
# )

# Create an iterator and process batches
# The pipeline handles data streaming, batching, state management, and execution
for i, batch in enumerate(pipeline):
    if i >= 3:
        break

    # Get the shape and stats for each component in the batch
    # batch['key'] provides direct access to data arrays
    image_batch = batch["image"]
    label_batch = batch["label"]

    print(f"Batch {i}:")
    print(f"  Image shape: {image_batch.shape}")
    print(f"  Label batch size: {label_batch.shape[0]}")
    print(f"  Image min/max: {image_batch.min():.3f}/{image_batch.max():.3f}")

print("Pipeline processing completed!")

Advanced Pipeline

For more complex workflows, Datarax supports branching and parallel execution:

# Continuing from the Quick Start example
# (source, normalizer, augmenter, and jnp are already in scope)
from datarax.dag.nodes import Branch, Merge

# Define an additional deterministic operator, wrapped like the others
def invert(element, key=None):
    return element.update_data({"image": 1.0 - element.data["image"]})

inverter_config = ElementOperatorConfig(stochastic=False)
inverter = ElementOperator(inverter_config, fn=invert, rngs=nnx.Rngs(0))

def is_high_contrast(element):
    # Condition: check if image variance is high
    return jnp.var(element.data["image"]) > 0.1

# Build a complex DAG:
# 1. Source -> Batching
# 2. Parallel (via |): normalized version AND inverted version
# 3. Merge: average them (simple ensemble)
# 4. Branch: augment ONLY if high contrast, otherwise normalize again
complex_pipeline = (
    from_source(source, batch_size=32)
    >> (OperatorNode(normalizer) | OperatorNode(inverter))
    >> Merge("mean")
    >> Branch(
           condition=is_high_contrast,
           true_path=OperatorNode(augmenter),
           false_path=OperatorNode(normalizer),
       )
)
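The `>>` and `|` composition syntax is ordinary Python operator overloading over pipeline nodes. A deliberately simplified toy sketch of the mechanism (this `Node` class is hypothetical, not Datarax's actual node or DAGExecutor implementation):

```python
class Node:
    """Toy pipeline node: wraps a function and composes via >> and |."""

    def __init__(self, fn):
        self.fn = fn

    def __rshift__(self, other):
        # a >> b: sequential composition, b runs on a's output
        return Node(lambda x: other.fn(self.fn(x)))

    def __or__(self, other):
        # a | b: parallel composition, both branches see the same input
        return Node(lambda x: (self.fn(x), other.fn(x)))

double = Node(lambda x: 2 * x)
inc = Node(lambda x: x + 1)
merge_sum = Node(lambda pair: sum(pair))

# (double | inc) >> merge_sum: run both branches, then combine the results
pipeline = (double | inc) >> merge_sum
print(pipeline.fn(3))  # → 10  (double(3)=6, inc(3)=4, merged by sum)
```

The real framework builds a graph of node objects rather than nested closures, which is what enables features like caching, checkpointing, and conditional branches, but the composition operators work on the same principle.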

Documentation

For complete documentation, please visit datarax.readthedocs.io.

Testing

Datarax ships with a comprehensive test suite that supports both CPU and GPU testing:

  • Tests are organized to mirror the src/datarax package structure for easier navigation
  • All GitHub CI workflows run tests exclusively on CPU
  • Local test runs automatically use GPU when available
  • The testing infrastructure handles environment configuration to ensure consistency

To run tests:

# Run all tests using CPU only (most stable)
JAX_PLATFORMS=cpu python -m pytest

# Run specific test module
JAX_PLATFORMS=cpu python -m pytest tests/sources/test_memory_source.py

For more information on the test organization and how to run tests, see the Testing Guide.
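The JAX_PLATFORMS variable can also be set from Python rather than the shell, provided it happens before jax is first imported in the process; a common place for this is a test-suite `conftest.py`. A minimal sketch using only the standard library:

```python
import os

# Must run before the first `import jax` anywhere in the process:
# JAX picks its backend from JAX_PLATFORMS when it initializes.
os.environ["JAX_PLATFORMS"] = "cpu"

# ...later imports of jax in this process will use the CPU backend.
```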

License

Datarax is licensed under the MIT License.
