
High-performance, PyTree-native data loading and processing for JAX/Flax


JaxFlow

PyPI version License: MIT Python 3.8+ Code style: black

JaxFlow is a high-performance, PyTree-native data loading and processing library designed specifically for the JAX and Flax ecosystem.

Unlike generic data loaders, JaxFlow is built from the ground up for JAX's specific needs: handling arbitrary PyTrees, prefetching batches efficiently to devices (GPU/TPU), and integrating cleanly with jax.jit and jax.pmap.

🚀 Key Features

  • PyTree Native: Data loaders yield PyTrees (dicts, tuples, lists, custom classes) directly, ready for jax.tree_util.tree_map (exposed as jax.tree_map in older JAX releases).
  • JAX Device Prefetching: Automatically prefetches batches to the target device (GPU/TPU) to minimize host-device transfer bottlenecks.
  • Torch-like API: Familiar Dataset and Loader API for those coming from PyTorch, but optimized for JAX.
  • Multiprocessing: Robust multiprocessing workers for parallel data loading and augmentation.
  • Composability: Flexible transforms module for composing image and data augmentations.
  • Visualization: Built-in tools in jaxflow.viz to quickly inspect batches and training curves.
  • CLI Tools: Includes a command-line interface for benchmarking system performance (python -m jaxflow.cli benchmark).
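
To make the "PyTree native" idea concrete, here is a minimal sketch in plain JAX and NumPy. It does not use JaxFlow itself, and the `collate` helper is a hypothetical name introduced for illustration, not part of the library's API. It shows how per-sample dicts can be stacked leaf by leaf into a batch, and how one function can then be applied to every leaf of that batch.

```python
import jax
import jax.numpy as jnp
import numpy as np


def collate(samples):
    """Stack a list of identically structured PyTrees leaf by leaf.

    Hypothetical helper for illustration; not JaxFlow's own collate logic.
    """
    return jax.tree_util.tree_map(lambda *leaves: np.stack(leaves), *samples)


# Four samples with the same structure as the Quick Start dataset items.
samples = [
    {"image": np.full((28, 28, 1), i, dtype=np.float32), "label": np.int32(i)}
    for i in range(4)
]
batch = collate(samples)

# Apply one function to every leaf: convert the NumPy batch to JAX arrays.
device_batch = jax.tree_util.tree_map(jnp.asarray, batch)

# Inspect the per-leaf shapes while keeping the dict structure.
shapes = jax.tree_util.tree_map(lambda x: x.shape, device_batch)
print(shapes)
```

Because the batch keeps the structure of a single sample, the same keys ("image", "label") work for both, and any leaf-wise operation is a single tree_map call.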

📦 Installation

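# Install from PyPI
# (the 0.1.0 release files are named jaxflow_lib-0.1.0, so check the exact
# project name on PyPI if this command does not find the package)
pip install jaxflow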

# Install with visualization support (quoted so shells like zsh do not expand the brackets)
pip install "jaxflow[viz]"

# Install from source
pip install .

⚡ Quick Start

Here's how to create a simple dataset and iterate over it:

import jax.numpy as jnp
import numpy as np
from jaxflow import Dataset, Loader, transforms

# 1. Define a custom dataset
class RandomDataset(Dataset):
    def __init__(self, length=1000):
        self.length = length
        self.transform = transforms.Compose([
            transforms.ToArray(dtype=jnp.float32),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Return a dict (PyTree)
        image = np.random.rand(28, 28, 1)
        label = np.random.randint(0, 10)
        return {
            "image": self.transform(image),
            "label": label
        }

# 2. Create a loader
dataset = RandomDataset()
loader = Loader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    drop_last=True
)

# 3. Iterate (batches are automatically prefetched to device if available)
print("Starting training loop...")
for batch in loader:
    images = batch["image"]  # Shape: (32, 28, 28, 1)
    labels = batch["label"]  # Shape: (32,)

    # Your JAX training step here...
    # params = train_step(params, images, labels)

    print(f"Batch shape: {images.shape}")
    break
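
For readers wondering what device prefetching and a JAX training step might look like around such a loop, the following is a rough sketch in plain JAX. It is not JaxFlow's implementation: `prefetch_to_device`, `train_step`, and `fake_batches` are hypothetical names, the linear classifier is a stand-in model, and the generator of random batches replaces a real loader so the example runs on its own.

```python
import collections
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def prefetch_to_device(batches, size=2, device=None):
    """Yield batches while keeping up to `size` of them already on the device."""
    if device is None:
        device = jax.devices()[0]
    iterator = iter(batches)
    queue = collections.deque()

    def enqueue(n):
        for batch in itertools.islice(iterator, n):
            # device_put returns without waiting for the copy, so the
            # transfer can overlap with compute on earlier batches.
            queue.append(jax.device_put(batch, device))

    enqueue(size)
    while queue:
        yield queue.popleft()
        enqueue(1)


@jax.jit
def train_step(params, batch):
    """One SGD step for a linear classifier on a dict-shaped batch."""

    def loss_fn(p):
        x = batch["image"].reshape(batch["image"].shape[0], -1)
        logits = x @ p["w"] + p["b"]
        log_probs = jax.nn.log_softmax(logits)
        picked = jnp.take_along_axis(log_probs, batch["label"][:, None], axis=1)
        return -jnp.mean(picked)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


def fake_batches(n, batch_size=32, seed=0):
    """Stand-in for a data loader: yields `n` random dict batches."""
    rng = np.random.default_rng(seed)
    for _ in range(n):
        yield {
            "image": rng.random((batch_size, 28, 28, 1), dtype=np.float32),
            "label": rng.integers(0, 10, size=batch_size, dtype=np.int32),
        }


params = {"w": jnp.zeros((28 * 28, 10)), "b": jnp.zeros(10)}
losses = []
for batch in prefetch_to_device(fake_batches(5)):
    params, loss = train_step(params, batch)
    losses.append(float(loss))
print(losses)
```

The whole batch dict is passed to the jitted function as a single argument, so adding a new field to the dataset does not change the signature of the training step.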

🛠️ CLI Usage

JaxFlow comes with a handy CLI to check your environment and run benchmarks.

# Check system info
jaxflow info --json

# Run a matrix multiplication benchmark to test device performance
jaxflow benchmark --device gpu --size 4096 --iters 100
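
As a rough idea of what a matrix multiplication benchmark measures, here is a small standalone sketch in plain JAX. It is an assumption about the general technique, not the code behind the jaxflow benchmark command, and `benchmark_matmul` is a hypothetical name. The important details are the warm-up call, which keeps compilation out of the timing, and block_until_ready, which waits for JAX's asynchronous dispatch to finish before the clock stops.

```python
import time

import jax
import jax.numpy as jnp


def benchmark_matmul(size=1024, iters=10, device=None):
    """Return the mean seconds per (size x size) float32 matmul on `device`."""
    if device is None:
        device = jax.devices()[0]
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.device_put(jax.random.normal(key_a, (size, size), dtype=jnp.float32), device)
    b = jax.device_put(jax.random.normal(key_b, (size, size), dtype=jnp.float32), device)

    matmul = jax.jit(jnp.matmul)
    matmul(a, b).block_until_ready()  # warm-up run: excludes compile time

    start = time.perf_counter()
    for _ in range(iters):
        out = matmul(a, b)
    out.block_until_ready()  # wait for the queued work to actually finish
    elapsed = time.perf_counter() - start
    return elapsed / iters


size = 512
seconds = benchmark_matmul(size=size, iters=5)
gflops = 2 * size**3 / seconds / 1e9  # a dense matmul costs about 2 * n^3 FLOPs
print(f"{seconds * 1e3:.3f} ms per matmul, ~{gflops:.1f} GFLOP/s")
```

Without the final block_until_ready, the loop would mostly time how fast work is queued rather than how fast the device computes it.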

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.
