Datarax: A high-performance, NNX-based data pipeline framework for JAX
Datarax: A Data Pipeline Framework for JAX
Note: This project is in early development. API may change. Expect breaking changes.
Datarax (Data + Array/JAX) is a high-performance, extensible data pipeline framework specifically engineered for JAX-based machine learning workflows. It simplifies and accelerates the development of efficient and scalable data loading, preprocessing, and augmentation pipelines for JAX, leveraging the full potential of JAX's Just-In-Time (JIT) compilation, automatic differentiation, and hardware acceleration capabilities.
Key Features
- High Performance: Leverages JAX's JIT compilation and XLA backend to achieve near-optimal data processing speeds on CPUs, GPUs, and TPUs.
- JAX-Native Design: All core components and operations are designed with JAX's functional programming paradigm and immutable data structures (PyTrees) in mind.
- Scalability: Supports efficient data loading and processing for large datasets and distributed training scenarios, including multi-host and multi-device setups.
- Extensibility: Easily define and integrate custom data sources, transformations, and augmentation operations.
- Usability: Provides a clear, intuitive Python API and a flexible configuration system (TOML-based) for defining and managing pipelines.
- Determinism: Pipeline runs are deterministic by default, crucial for reproducibility in research and production.
- Caching Optimization: Multiple caching strategies for performance improvement, including function caching, transformer caching, and checkpointing.
- Complete Feature Set: Supports common data pipeline operations including diverse data source handling, advanced transformations, data augmentation, batching, sharding, checkpointing, and caching.
- Ecosystem Integration: Facilitates smooth integration with other JAX libraries like Flax, Optax, and Orbax.
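To make the function-caching idea from the feature list concrete, here is a minimal sketch using Python's standard `functools.lru_cache`. It is not Datarax's caching API: `expensive_transform` and its arguments are hypothetical names chosen for illustration.

```python
from functools import lru_cache

call_count = 0  # counts how often the underlying function really runs


@lru_cache(maxsize=128)
def expensive_transform(sample_id: int, scale: float) -> tuple:
    """Hypothetical preprocessing step; arguments must be hashable to be cached."""
    global call_count
    call_count += 1
    return tuple(scale * x for x in range(sample_id, sample_id + 4))


first = expensive_transform(3, 0.5)   # computed on the first call
second = expensive_transform(3, 0.5)  # served from the cache
print(first == second, call_count, expensive_transform.cache_info().hits)
```

The second call never reaches the function body, which is the effect a pipeline-level cache aims for when the same element is requested again.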
Installation
Install Datarax using pip:
# Basic installation
pip install datarax
# With data loading support (HuggingFace, TFDS, Audio/Image libs)
# Quote the extras so shells such as zsh do not treat the brackets as a glob
pip install "datarax[data]"
# With GPU support (CUDA 12)
pip install "datarax[gpu]"
# Full installation with all optional dependencies
pip install "datarax[all]"
macOS / Apple Silicon
Datarax supports macOS on both Intel and Apple Silicon (M1/M2/M3) processors.
# Install for macOS (CPU mode - recommended)
# Quotes keep zsh, the default macOS shell, from expanding the brackets
pip install "datarax[all-cpu]"
# Run with explicit CPU backend
JAX_PLATFORMS=cpu python your_script.py
Metal GPU Acceleration (Experimental): JAX supports Apple's Metal backend for GPU acceleration on M1+ chips:
pip install jax-metal
JAX_PLATFORMS=metal python your_script.py
Note: Metal GPU acceleration is community-tested. CI runs on macOS with CPU only. If you can help validate Metal support, please open an issue with your results.
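The `JAX_PLATFORMS` variable used in the commands above can also be set from inside a script, provided this happens before JAX is first imported. The following is a minimal sketch; `select_platform` is a hypothetical helper, not part of Datarax or JAX.

```python
import os


def select_platform(default: str = "cpu") -> str:
    """Set JAX_PLATFORMS unless the user already chose a backend.

    Hypothetical helper: call it before the first `import jax`.
    """
    return os.environ.setdefault("JAX_PLATFORMS", default)


platform = select_platform("cpu")
print(f"JAX_PLATFORMS={platform}")

# Import JAX only after the variable is set, for example:
# import jax
# print(jax.devices())
```

Using `setdefault` keeps any value already exported in the shell, so the command-line form still takes precedence.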
Known Limitations:
- TensorFlow Datasets: May hang on import on macOS ARM64. Datarax uses lazy imports to handle this.
- JAX Profiler: Tracing is automatically disabled on macOS due to TensorBoard compatibility.
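The lazy-import approach mentioned for TensorFlow Datasets can be sketched with the standard library. This shows the general technique only and is an assumption about how such a mechanism might look, not Datarax's actual code: `LazyModule` is a hypothetical name, and `json` stands in for a heavy optional dependency.

```python
import importlib


class LazyModule:
    """Defer importing a module until one of its attributes is first used."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._module = None  # filled in on first attribute access

    def __getattr__(self, attr: str):
        # Called only for attributes not found on the proxy itself.
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)


# "json" stands in for a slow or fragile import such as tensorflow_datasets.
lazy_json = LazyModule("json")
loaded_before = lazy_json._module is not None  # nothing imported yet
encoded = lazy_json.dumps({"ok": True})        # first use triggers the import
loaded_after = lazy_json._module is not None
print(loaded_before, encoded, loaded_after)
```

With this pattern, importing the package itself stays cheap, and a problematic dependency is only touched by code paths that actually need it.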
Development Setup
For contributors and developers, Datarax uses uv as its package manager:
# Install uv
pip install uv
# Clone and setup
git clone https://github.com/avitai/datarax.git
cd datarax
# Run automatic setup (creates venv & installs dependencies)
./setup.sh
# Activate the environment
source activate.sh
# Or install manually with uv
uv pip install -e ".[dev]"
See the Development Environment Guide for detailed instructions on:
- Creating a virtual environment with uv venv
- Installing dependencies through pyproject.toml
- Running tests through pytest
- Building and packaging Datarax
Current Status
Datarax is currently in active development with a Flax NNX-based architecture that provides robust state management and checkpointing:
Core Architecture
- NNX Module System: All components built on flax.nnx.Module for robust state management
- Integrated Checkpointing: Seamless Orbax integration for stateful pipeline persistence
- Type Safety: Complete type annotations and runtime validation
- Composability: Modular design enabling flexible pipeline construction
Implemented Components
- Core NNX Modules: DataraxModule, OperatorModule, StructuralModule
- Data Sources: MemorySource, TFDSSource, HFSource (inheriting from StructuralModule/DataSourceModule)
- Operators: Element-wise operators, MapOperator (inheriting from OperatorModule)
- DAG Execution Engine: DAGExecutor and pipeline API for constructing flexible, graph-based data processing flows.
- Node System: A rich set of nodes for building pipelines:
  - DataSourceNode: entry point for data
  - OperatorNode: for transformations
  - BatchNode & RebatchNode: for batching control
  - ShuffleNode, CacheNode: for data management
  - Control flow nodes: Sequential, Parallel, Branch, Merge
- Stateful Components:
  - MemorySourceModule/ArrayRecordSourceModule/HFSourceModule for diverse data ingestion
  - Range/ShuffleSamplerModule with reproducible random state
  - DefaultBatcherModule with buffer state management
  - ArraySharderModule for device-aware data distribution
- External Integrations:
  - HuggingFace Datasets with stateful iteration
  - TensorFlow Datasets with checkpoint support
- Advanced Features:
  - Complete caching strategies with state preservation
  - Differentiable rebatching (DifferentiableRebatchImpl)
  - NNXCheckpointHandler for production-grade checkpointing
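The stateful components listed above pair iteration with state that can be saved and restored. The following is a minimal sketch of that idea in plain Python, not the Datarax or Orbax API: `ShuffleSampler`, `get_state`, and `set_state` are hypothetical names, and the standard `random` module stands in for JAX's key-based RNG.

```python
import random


class ShuffleSampler:
    """Yields a seeded permutation of indices and can resume mid-epoch."""

    def __init__(self, num_records: int, seed: int) -> None:
        self.num_records = num_records
        self.seed = seed
        self.position = 0  # number of indices already yielded
        self._order = self._permutation(seed)

    def _permutation(self, seed: int) -> list:
        # Rebuilding from the seed makes the order reproducible across runs.
        order = list(range(self.num_records))
        random.Random(seed).shuffle(order)
        return order

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self.position >= self.num_records:
            raise StopIteration
        index = self._order[self.position]
        self.position += 1
        return index

    def get_state(self) -> dict:
        # Seed plus position is enough to reconstruct the iterator.
        return {"seed": self.seed, "position": self.position}

    def set_state(self, state: dict) -> None:
        self.seed = state["seed"]
        self.position = state["position"]
        self._order = self._permutation(self.seed)


sampler = ShuffleSampler(num_records=8, seed=0)
head = [next(sampler) for _ in range(3)]
checkpoint = sampler.get_state()  # would be persisted alongside model state
tail = list(sampler)

restored = ShuffleSampler(num_records=8, seed=0)
restored.set_state(checkpoint)
resumed_tail = list(restored)
print(tail == resumed_tail)
```

Because the permutation is derived from the seed, only two small values need to be checkpointed, and a restored sampler continues with exactly the records the interrupted one had not yet produced.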
Upcoming features include:
- Image transformation library with JAX-native operations
- Advanced sharding strategies for multi-device and multi-host scenarios
- Performance optimization suite with benchmarking tools
- Extended monitoring and metrics capabilities
- Additional external data source integrations
Quick Start
Here's a simple example of using Datarax's DAG-based architecture to create a data pipeline for image classification:
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element


def normalize(element: Element, key: jax.Array | None = None) -> Element:
    """Normalize the image in the element."""
    # element.data is a dict-like PyTree containing the actual data arrays
    # Return a new Element with updated data (immutable update)
    # update_data provides a clean API for partial updates
    return element.update_data({"image": element.data["image"] / 255.0})


def apply_augmentation(element: Element, key: jax.Array) -> Element:
    """Apply a simple augmentation (flipping) to the image."""
    key1, _ = jax.random.split(key)
    flip_horizontal = jax.random.bernoulli(key1, 0.5)

    # Use jax.lax.cond for JAX-compatible conditional execution
    def flip_image(img):
        return jnp.flip(img, axis=1)

    def no_flip(img):
        return img

    new_image = jax.lax.cond(
        flip_horizontal,
        flip_image,
        no_flip,
        element.data["image"]
    )
    return element.update_data({"image": new_image})


# Create some dummy data (28x28 images)
num_samples = 1000
image_shape = (28, 28, 1)
data = {
    "image": np.random.randint(0, 255, (num_samples, *image_shape)).astype(np.float32),
    "label": np.random.randint(0, 10, (num_samples,)).astype(np.int32),
}

# Create the data source using config-based API
# MemorySource is a standard DataSource implementation for in-memory data
source_config = MemorySourceConfig()
source = MemorySource(source_config, data=data, rngs=nnx.Rngs(0))

# Create operators using the unified ElementOperator API
# Normalizer is deterministic (no random key needed)
normalizer_config = ElementOperatorConfig(stochastic=False)
normalizer = ElementOperator(normalizer_config, fn=normalize, rngs=nnx.Rngs(0))

# Augmenter is stochastic: requires stream_name for proper RNG management
augmenter_config = ElementOperatorConfig(stochastic=True, stream_name="augmentations")
augmenter = ElementOperator(augmenter_config, fn=apply_augmentation, rngs=nnx.Rngs(42))

# Create the data pipeline using the DAG-based API
# from_source() initializes the pipeline with:
# 1. The DataSourceNode (data loading)
# 2. A BatchNode (automatic batching)
# The >> operator chains transformation operators
pipeline = (
    from_source(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)

# Alternative: Method Chaining
# You can also build the pipeline using the fluent .add() method:
# pipeline = (
#     from_source(source, batch_size=32)
#     .add(OperatorNode(normalizer))
#     .add(OperatorNode(augmenter))
# )

# Create an iterator and process batches
# The pipeline handles data streaming, batching, state management, and execution
for i, batch in enumerate(pipeline):
    if i >= 3:
        break
    # Get the shape and stats for each component in the batch
    # batch['key'] provides direct access to data arrays
    image_batch = batch["image"]
    label_batch = batch["label"]
    print(f"Batch {i}:")
    print(f"  Image shape: {image_batch.shape}")
    print(f"  Label batch size: {label_batch.shape[0]}")
    print(f"  Image min/max: {image_batch.min():.3f}/{image_batch.max():.3f}")

print("Pipeline processing completed!")
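The `>>` chaining in the example relies on Python operator overloading. As a rough sketch of how such an operator can be built (this is not Datarax's implementation; `Node` and `Chain` are hypothetical classes), `__rshift__` can return a composite that applies each stage in order:

```python
from typing import Any, Callable


class Node:
    """Wraps a function so that nodes can be chained with `>>`."""

    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn

    def __call__(self, value: Any) -> Any:
        return self.fn(value)

    def __rshift__(self, other: "Node") -> "Chain":
        return Chain([self, other])


class Chain(Node):
    """A sequence of nodes applied left to right."""

    def __init__(self, stages: list) -> None:
        self.stages = stages

    def __call__(self, value: Any) -> Any:
        for stage in self.stages:
            value = stage(value)
        return value

    def __rshift__(self, other: Node) -> "Chain":
        # Build a new chain so existing chains are not mutated.
        return Chain([*self.stages, other])


scale = Node(lambda batch: [x / 255.0 for x in batch])
shift = Node(lambda batch: [x - 0.5 for x in batch])
chained = scale >> shift
result = chained([0.0, 255.0])
print(result)
```

Returning a new `Chain` from every `>>` keeps partial pipelines reusable: the same prefix can be extended in several directions without one extension affecting another.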
Advanced Pipeline
For more complex workflows, Datarax supports branching and parallel execution:
from datarax.dag.nodes import Parallel, Merge, Branch


# Define additional operators
def invert(element, key=None):
    return element.update_data({"image": 1.0 - element.data["image"]})


def is_high_contrast(element):
    # Condition: check if image variance is high
    return jnp.var(element.data["image"]) > 0.1


# Build a complex DAG:
# 1. Source -> Batching
# 2. Parallel: Normal version AND Inverted version
# 3. Merge: Average them (Simple Ensemble)
# 4. Branch: Apply the augmenter ONLY if high contrast, otherwise normalize again
complex_pipeline = (
    from_source(source, batch_size=32)
    >> (OperatorNode(normalizer) | OperatorNode(invert))
    >> Merge("mean")
    >> Branch(
        condition=is_high_contrast,
        true_path=OperatorNode(augmenter),
        false_path=OperatorNode(normalizer)
    )
)
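To clarify the data flow of the DAG above, here is a framework-independent sketch of parallel, merge, and branch semantics on plain Python lists. All helper names are hypothetical, and the actual Datarax nodes may behave differently, for example by operating on batched PyTrees.

```python
from statistics import mean, pvariance
from typing import Callable, Sequence


def parallel(ops: Sequence[Callable]) -> Callable:
    """Apply every op to the same input and collect the results."""
    return lambda image: [op(image) for op in ops]


def merge_mean(outputs: Sequence[list]) -> list:
    """Average the parallel outputs element by element."""
    return [mean(values) for values in zip(*outputs)]


def branch(condition: Callable, true_path: Callable, false_path: Callable) -> Callable:
    """Route the input through one of two paths based on a predicate."""
    return lambda image: true_path(image) if condition(image) else false_path(image)


def identity(image: list) -> list:
    return list(image)


def invert_values(image: list) -> list:
    return [1.0 - x for x in image]


def flip(image: list) -> list:
    return image[::-1]


def high_variance(image: list) -> bool:
    # Population variance as a stand-in for the contrast check.
    return pvariance(image) > 0.1


image = [0.0, 0.25, 1.0]
merged = merge_mean(parallel([identity, invert_values])(image))
routed = branch(high_variance, flip, identity)(image)
print(merged, routed)
```

Each stage takes the previous stage's output as its only input, which is what allows the pieces to be composed freely into a graph.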
Documentation
For complete documentation, please visit datarax.readthedocs.io.
Testing
Datarax has a test suite that supports both CPU and GPU testing:
- Tests are organized to mirror the src/datarax package structure for easier navigation
- All GitHub CI workflows run tests exclusively on CPU
- Local test runs automatically use GPU when available
- The testing infrastructure handles environment configuration to ensure consistency
To run tests:
# Run all tests using CPU only (most stable)
JAX_PLATFORMS=cpu python -m pytest
# Run specific test module
JAX_PLATFORMS=cpu python -m pytest tests/sources/test_memory_source.py
For more information on the test organization and how to run tests, see the Testing Guide.
License
Datarax is licensed under the MIT License.