
JAX-inspired dataframe transforms with Bodo-backed execution

Project description

DataJAX (prototype)

DataJAX explores how to bring JAX-style program transforms (djit, vmap, pjit, scan) to tabular workloads by leaning on Bodo's SPMD compiler. The goal is a “JAX for data” experience: trace pandas-like code, optimise it, and run it across a cluster with predictable sharding semantics.

  • Why: JAX offers composable transforms for array workloads. Applying the same abstractions to tabular data makes it possible to stage DataFrame pipelines, reason about sharding, and target high-performance distributed runtimes.
  • Scope: This prototype implements a lightweight IR for DataFrame operations, stages that IR into execution plans, and lowers those plans onto pandas or Bodo. It does not yet handle UDF-heavy workloads or production-grade error handling.

Core Idea

DataJAX works in three stages:

  1. Trace: A Frame object wraps a pandas DataFrame and records supported operations (column arithmetic, filters, joins, grouped aggregations) into a lightweight intermediate representation (IR).
  2. Plan: A planner groups the IR into a series of execution stages. It reasons about data sharding and backend capabilities.
  3. Execute: The plan is lowered to a backend. By default, this is a pandas-based stub for fast iteration. With a real Bodo installation enabled (see Running With Real Bodo), it can generate and execute optimised, parallel code.
[pandas code] -> [Frame wrapper] -> [IR graph] -> [Planner] -> [Execution]
                                                               (pandas or Bodo)
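The three stages can be sketched in plain Python and pandas. This is a toy model under stated assumptions, not DataJAX's implementation: TracedFrame, Assign, Filter, GroupSum, plan, and execute are hypothetical names, and the planner simply closes a stage at each grouped reduction, since that is where data would have to be regrouped by key.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable

import pandas as pd


# Hypothetical IR nodes; DataJAX's real node classes live in datajax/ir and may differ.
@dataclass(frozen=True)
class Assign:
    column: str
    fn: Callable[[pd.DataFrame], Any]  # computes the new column from the current frame


@dataclass(frozen=True)
class Filter:
    predicate: Callable[[pd.DataFrame], pd.Series]  # returns a boolean mask


@dataclass(frozen=True)
class GroupSum:
    key: str
    column: str


@dataclass
class TracedFrame:
    """Records operations instead of running them (the Trace step)."""
    ops: list = field(default_factory=list)

    def assign(self, column, fn):
        return TracedFrame(self.ops + [Assign(column, fn)])

    def filter(self, predicate):
        return TracedFrame(self.ops + [Filter(predicate)])

    def group_sum(self, key, column):
        return TracedFrame(self.ops + [GroupSum(key, column)])


def plan(ops):
    """Groups row-local ops into stages; a grouped reduction closes a stage (the Plan step)."""
    stages, current = [], []
    for op in ops:
        current.append(op)
        if isinstance(op, GroupSum):  # output is keyed by op.key, so end the stage here
            stages.append(current)
            current = []
    if current:
        stages.append(current)
    return stages


def execute(stages, df):
    """Replays each stage eagerly with pandas (the Execute step, pandas flavour)."""
    for stage in stages:
        for op in stage:
            if isinstance(op, Assign):
                df = df.assign(**{op.column: op.fn(df)})
            elif isinstance(op, Filter):
                df = df[op.predicate(df)]
            elif isinstance(op, GroupSum):
                df = df.groupby(op.key, as_index=False)[op.column].sum()
    return df


traced = (
    TracedFrame()
    .assign("x2", lambda d: d["x"] * 2)
    .filter(lambda d: d["x2"] > 0.3)
    .group_sum("user_id", "x2")
    .filter(lambda d: d["x2"] > 1.3)  # runs after the reduction, so it lands in a second stage
)
stages = plan(traced.ops)
df = pd.DataFrame({"user_id": [1, 2, 1, 2, 1], "x": [0.1, 0.2, 0.3, 0.4, 0.5]})
result = execute(stages, df)
print(result)
```

Nothing touches data until execute() runs, which is what lets a real planner inspect the whole pipeline before choosing a backend.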

Getting Started

python -m venv .venv
source .venv/bin/activate
pip install datajax  # or, inside a clone of this repo: pip install -e ".[dev]"
datajax-export-wavespec --help
pytest -q

The published wheel installs the pandas-backed stub by default so you can experiment immediately. Running pip install -e ".[dev]" from a clone installs the package in editable mode along with developer tooling (pytest, ruff, pyright).

The CLI entry points expose the offline tooling directly:

datajax-export-wavespec --logs my_logs.parquet --key user_id --out wavespec.json
datajax-replay-tuner --trace trace.json --sample sample.parquet --out policy.json

Both commands work in stub mode. When replaying real traces, you can optionally supply runtime counters through the DATAJAX_RUNTIME_METRICS environment variable.

Running With Real Bodo

If you have a licensed Bodo installation and want to exercise the real backend, set up the environment before running tests or benchmarks:

export DATAJAX_USE_BODO_STUB=0
export DATAJAX_ALLOW_BODO_IMPORT=1
export DATAJAX_EXECUTOR=bodo
# Optional: enable native LazyPlan lowering instead of pandas replay
export DATAJAX_NATIVE_BODO=1
# Recommended when running under mpiexec
export BODO_SPAWN_MODE=0
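From a Python script or notebook, the same settings can be applied through os.environ. We assume DataJAX reads these variables when it selects or initialises a backend, so the sketch sets them before the (commented-out) import as the safest ordering; it uses no DataJAX API, and BODO_ENV is just a local name.

```python
import os

# Same settings as the shell exports above.
BODO_ENV = {
    "DATAJAX_USE_BODO_STUB": "0",
    "DATAJAX_ALLOW_BODO_IMPORT": "1",
    "DATAJAX_EXECUTOR": "bodo",
    "DATAJAX_NATIVE_BODO": "1",  # optional: native LazyPlan lowering
    "BODO_SPAWN_MODE": "0",      # recommended when running under mpiexec
}
os.environ.update(BODO_ENV)

# import datajax  # import only after the environment is configured
```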

Recommended validation commands:

pytest tests/api/test_djit_pipeline.py -k bodo -vv
pytest tests/runtime/test_bodo_plan.py -vv
pytest tests/runtime/test_mesh_plan.py -vv

To benchmark the native execution path you will typically need to launch Bodo under MPI. For example, on a workstation with two ranks available:

mpiexec -n 2 python benchmarks/feature_pipeline.py --mode native --spmd

Unset DATAJAX_NATIVE_BODO or switch to --mode replay if you prefer the pandas replay path, which is compiled through bodo.jit without native LazyPlan lowering.

Example Usage

Here is a simple example of how to use djit to define a sharded, just-in-time compiled feature engineering pipeline.

import pandas as pd
from datajax.api import djit, shard

# Define a function that takes a DataFrame and returns a transformed one
@djit(
    in_shardings=(shard.replicated(),),
    out_shardings=shard.by_key("user_id"),
)
def compute_features(df):
    df["x2"] = df["x"] * 2
    df_agg = df.groupby("user_id").agg(
        total_x=pd.NamedAgg(column="x", aggfunc="sum"),
        mean_x2=pd.NamedAgg(column="x2", aggfunc="mean"),
    )
    return df_agg

# Create a sample DataFrame
df = pd.DataFrame({
    "user_id": [1, 2, 1, 2, 1],
    "x": [0.1, 0.2, 0.3, 0.4, 0.5],
})

# Execute the djit-compiled function
result = compute_features(df)
print(result)

The @djit decorator traces the pandas operations, plans the execution, and runs it on the selected backend. The in_shardings argument declares the input DataFrame as replicated, and the out_shardings argument ensures the output data is partitioned by user_id.
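To check what the pipeline should produce independently of any backend, here is a plain-pandas reference of the same computation. compute_features_reference is a hypothetical helper for comparison, not part of datajax.api. Unlike the body of compute_features, which assigns df["x2"] in place, it uses DataFrame.assign so the caller's frame is left unmodified.

```python
import pandas as pd


def compute_features_reference(df: pd.DataFrame) -> pd.DataFrame:
    """Plain-pandas version of compute_features, without djit or sharding."""
    df = df.assign(x2=df["x"] * 2)  # assign() returns a copy; the input frame is untouched
    return df.groupby("user_id").agg(
        total_x=pd.NamedAgg(column="x", aggfunc="sum"),
        mean_x2=pd.NamedAgg(column="x2", aggfunc="mean"),
    )


df = pd.DataFrame({
    "user_id": [1, 2, 1, 2, 1],
    "x": [0.1, 0.2, 0.3, 0.4, 0.5],
})
expected = compute_features_reference(df)
print(expected)
```

If djit returns a pandas DataFrame in your setup (an assumption; it may return a wrapper type), pd.testing.assert_frame_equal(result, expected, check_exact=False) compares the two with a floating-point tolerance.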

Current Capabilities

  • IR & Planner: The Frame wrapper traces column arithmetic, filters, joins, and grouped reductions into a compact IR. A stage planner groups these operations and tracks schemas and sharding.
  • Execution Backends:
    • Default (stub): An embedded Bodo stub executes plans using pandas for instant, deterministic results.
    • Real Bodo: Set DATAJAX_USE_BODO_STUB=0 and DATAJAX_ALLOW_BODO_IMPORT=1 to use a real Bodo installation. This requires an MPI-capable environment.
    • Pandas: Set DATAJAX_EXECUTOR=pandas to bypass Bodo entirely.
  • Developer Experience:
    • djit, vmap, pjit, and scan are available under datajax.api.
    • A test suite covers tracing, planning, and backend execution.
    • A benchmark (benchmarks/feature_pipeline.py) compares pandas, the Bodo stub, and native Bodo execution.
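The backend rules listed above can be summarised as a small decision function. This is a hypothetical restatement for readers, not DataJAX's selection code (which lives under datajax/runtime/): resolve_backend and its returned labels are invented for illustration, and the defaults assumed for unset variables (stub enabled, Bodo import disallowed) are inferred from the description of the stub as the default.

```python
from __future__ import annotations

import os
from collections.abc import Mapping


def resolve_backend(env: Mapping[str, str] | None = None) -> str:
    """Hypothetical mirror of the documented backend rules; not DataJAX's actual code."""
    env = os.environ if env is None else env
    if env.get("DATAJAX_EXECUTOR") == "pandas":
        return "pandas"  # bypass Bodo entirely
    # Assumed defaults: the stub is on and real Bodo imports are off unless set.
    use_stub = env.get("DATAJAX_USE_BODO_STUB", "1") != "0"
    allow_import = env.get("DATAJAX_ALLOW_BODO_IMPORT", "0") == "1"
    if not use_stub and allow_import:
        return "bodo"  # real Bodo; needs an MPI-capable environment
    return "bodo-stub"  # default: embedded stub that executes with pandas


print(resolve_backend({}))
print(resolve_backend({"DATAJAX_USE_BODO_STUB": "0", "DATAJAX_ALLOW_BODO_IMPORT": "1"}))
```

Both real-Bodo flags must be set together in this reading; setting only one falls back to the stub.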

Next Steps

Our immediate focus is on hardening the prototype and moving towards a polished "JAX for data" experience.

  • Core Planner & Execution:
    • Improve Bodo-native lowering to remove Python UDFs and add a native repartition operator.
    • Enforce Resource meshes for multi-axis data layouts.
    • Implement a planner optimiser for fusion, pushdowns, and cost-based choices.
  • Feature Coverage:
    • Add IR nodes and lowering rules for window functions, multi-aggregation pipelines, and advanced join strategies.
    • Improve I/O for Arrow/Parquet with sharding hints.
  • Developer Experience:
    • Tighten JAX interoperability and data interchange (DLPack).
    • Improve plan introspection, profiling, and error reporting.
    • Expand CI to cover MPI environments and performance regressions.
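As an illustration of the kind of rewrite a planner optimiser could perform, the sketch below does predicate pushdown on a toy IR: a filter moves ahead of any column assignment that does not produce a column the filter reads. The node classes and push_down_filters are hypothetical and independent of DataJAX's IR, and the rewrite is only valid under the stated assumption that assignments are row-local (no window functions or whole-column statistics).

```python
from __future__ import annotations

from dataclasses import dataclass


# Toy IR nodes for illustration only; they are not DataJAX's IR classes.
@dataclass(frozen=True)
class Assign:
    column: str
    inputs: frozenset  # columns read to compute `column`


@dataclass(frozen=True)
class Filter:
    inputs: frozenset  # columns read by the predicate


def push_down_filters(ops: list) -> list:
    """Moves each Filter ahead of Assigns that do not produce a column it reads.

    Assumes every Assign is row-local, so dropping rows earlier leaves the
    surviving rows' values unchanged while later Assigns process fewer rows.
    """
    ops = list(ops)
    changed = True
    while changed:  # repeat adjacent swaps until no filter can move further forward
        changed = False
        for i in range(1, len(ops)):
            prev, cur = ops[i - 1], ops[i]
            if isinstance(cur, Filter) and isinstance(prev, Assign) and prev.column not in cur.inputs:
                ops[i - 1], ops[i] = cur, prev
                changed = True
    return ops


pipeline = [
    Assign("x2", frozenset({"x"})),
    Assign("ratio", frozenset({"x", "x2"})),
    Filter(frozenset({"x"})),
    Filter(frozenset({"x2"})),
]
optimized = push_down_filters(pipeline)
print(optimized)
```

The filter on x moves to the front, while the filter on x2 stops right after the assignment that creates x2.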

For more detail, see the contributor guidelines (AGENTS.md), the development plan (docs/development_plan.md), and the offline intelligence guide (docs/offline_intelligence.md).

Repository Layout

datajax/
  api/            # djit/vmap/pjit/scan frontends and sharding descriptors
  frame/          # traced Frame/Series wrappers and IR builders
  ir/             # IR node definitions
  planner/        # Stage planner and executor
  runtime/        # Backend selection and compilation
  io/             # Data loading helpers
tests/            # Pytest suite
docs/             # Roadmap and development notes

Download files


Source Distribution

datajax-0.0.1.tar.gz (51.2 kB)

Uploaded Source

Built Distribution


datajax-0.0.1-py3-none-any.whl (45.1 kB)

Uploaded Python 3

File details

Details for the file datajax-0.0.1.tar.gz.

File metadata

  • Download URL: datajax-0.0.1.tar.gz
  • Upload date:
  • Size: 51.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.6

File hashes

Hashes for datajax-0.0.1.tar.gz
Algorithm Hash digest
SHA256 e7a82a3785ac8f67dfa88058d79ef39e0c45b5c1131aaef6e0eda984b7dec684
MD5 e916be17f7fa160514da5e769ce7e800
BLAKE2b-256 1bfdeaf9c77a69d591f5c67ea31638c8250f8d10aff2300c32111a06308785bd


File details

Details for the file datajax-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: datajax-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 45.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.6

File hashes

Hashes for datajax-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b982befcb8e2913280133b913ee7309ab02bbe581c5af566329b28b2861491d3
MD5 6d9f4f25925cea9d9ddc87fdf92095e4
BLAKE2b-256 2f9099e908ef69dc490da8e53f4582ccca8a81b0910707a42bb701f1420ab45c

