High-performance, PyTree-native data loading and processing for JAX/Flax
JaxFlow
JaxFlow is a high-performance, PyTree-native data loading and processing library designed specifically for the JAX and Flax ecosystem.
Unlike generic data loaders, JaxFlow is built from the ground up for JAX's specific needs: handling arbitrary PyTrees, prefetching batches efficiently to accelerator devices (GPU/TPU), and integrating cleanly with jax.jit and jax.pmap.
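As a small illustration of what PyTree handling looks like in practice, the sketch below builds by hand a batch of the kind a loader might yield (a dict of arrays) and maps a function over every leaf. It uses only standard JAX calls (jax.tree_util.tree_map and jax.tree_util.tree_leaves). The batch is a hand-made stand-in, not something produced by JaxFlow.

```python
import jax
import jax.numpy as jnp

# Hand-made stand-in for a loader batch: a dict (PyTree) of arrays.
batch = {
    "image": jnp.ones((32, 28, 28, 1), dtype=jnp.float32),
    "label": jnp.zeros((32,), dtype=jnp.int32),
}

# Apply one function to every leaf; the dict structure is preserved.
as_float = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), batch)

# Flatten the PyTree into a plain list of its leaf arrays.
leaves = jax.tree_util.tree_leaves(batch)

print(sorted(as_float.keys()), len(leaves))
```

Because the structure is preserved, the same pattern works unchanged if the batch grows extra keys or nested containers.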
🚀 Key Features
- PyTree Native: Data loaders yield PyTrees (dicts, tuples, lists, custom classes) directly, ready for jax.tree_util.tree_map.
- JAX Device Prefetching: Automatically prefetches batches to the target device (GPU/TPU) to minimize host-device transfer bottlenecks.
- Torch-like API: Familiar Dataset and Loader API for those coming from PyTorch, but optimized for JAX.
- Multiprocessing: Robust multiprocessing workers for parallel data loading and augmentation.
- Composability: Flexible transforms module for composing image and data augmentations.
- Visualization: Built-in tools in jaxflow.viz to quickly inspect batches and training curves.
- CLI Tools: Includes a command-line interface for benchmarking system performance (python -m jaxflow.cli benchmark).
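To make the prefetching idea concrete, here is a minimal, library-independent sketch of the general technique: a background thread prepares items ahead of the consumer and hands them over through a bounded queue. The names prefetch, transfer, and buffer_size are hypothetical and are not JaxFlow's API. In a JAX program, transfer could be something like jax.device_put, but that is an assumption about usage, not a description of how JaxFlow works internally.

```python
import queue
import threading

_SENTINEL = object()  # marks the end of the stream


def prefetch(iterable, transfer=lambda x: x, buffer_size=2):
    """Yield items from `iterable`, preparing up to `buffer_size` ahead.

    `transfer` is applied in the background thread (hypothetical hook;
    e.g. a host-to-device copy). Sketch only: an exception in the
    producer ends the stream early instead of being re-raised here.
    """
    q = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in iterable:
                q.put(transfer(item))  # blocks while the buffer is full
        finally:
            q.put(_SENTINEL)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = q.get()
        if item is _SENTINEL:
            return
        yield item


# Usage: order is preserved, and `transfer` runs off the main thread.
prefetched = list(prefetch(range(5), transfer=lambda x: x * 2))
print(prefetched)
```

The bounded queue is the key design choice: it lets loading overlap with compute while capping how many prepared batches sit in memory at once.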
📦 Installation
# Install from PyPI
pip install jaxflow-lib
# Install with visualization support
pip install "jaxflow-lib[viz]"
# Install from source
pip install .
⚡ Quick Start
Here's how to create a simple dataset and iterate over it:
import jax.numpy as jnp
import numpy as np
from jaxflow import Dataset, Loader, transforms

# 1. Define a custom dataset
class RandomDataset(Dataset):
    def __init__(self, length=1000):
        self.length = length
        self.transform = transforms.Compose([
            transforms.ToArray(dtype=jnp.float32),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Return a dict (PyTree)
        image = np.random.rand(28, 28, 1)
        label = np.random.randint(0, 10)
        return {
            "image": self.transform(image),
            "label": label
        }

# 2. Create a loader
dataset = RandomDataset()
loader = Loader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    drop_last=True
)

# 3. Iterate (batches are automatically prefetched to device if available)
print("Starting training loop...")
for batch in loader:
    images = batch["image"]  # Shape: (32, 28, 28, 1)
    labels = batch["label"]  # Shape: (32,)
    # Your JAX training step here...
    # params = train_step(params, images, labels)
    print(f"Batch shape: {images.shape}")
    break
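The train_step placeholder in the loop above is left to the user. One possible shape for it, as a sketch in plain JAX under stated assumptions: the model is a single linear layer, the loss is softmax cross-entropy, the optimizer is plain SGD, and LEARNING_RATE, loss_fn, and the params layout are all illustrative choices that do not come from JaxFlow. Unlike the commented-out call above, this version also returns the loss. The batch here is hand-made with the Quick Start shapes so the block runs on its own.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value


def loss_fn(params, images, labels):
    # Flatten each image and apply one linear layer (illustrative model).
    x = images.reshape(images.shape[0], -1)
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    # Cross-entropy: pick the log-probability of the true class per example.
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=1)
    return -jnp.mean(picked)


@jax.jit
def train_step(params, images, labels):
    loss, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    # Plain SGD update applied leaf-by-leaf over the params PyTree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Hand-made batch with the Quick Start shapes (not produced by a Loader).
params = {"w": jnp.zeros((28 * 28, 10)), "b": jnp.zeros((10,))}
images = jnp.ones((32, 28, 28, 1), dtype=jnp.float32)
labels = jnp.zeros((32,), dtype=jnp.int32)

params, first_loss = train_step(params, images, labels)
params, second_loss = train_step(params, images, labels)
print(float(first_loss), float(second_loss))
```

Since both the parameters and the gradients are PyTrees with the same structure, the update is a single tree_map call regardless of how many layers the model has.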
🛠️ CLI Usage
JaxFlow comes with a handy CLI to check your environment and run benchmarks.
# Check system info
jaxflow info --json
# Run a matrix multiplication benchmark to test device performance
jaxflow benchmark --device gpu --size 4096 --iters 100
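For readers curious what a matrix multiplication benchmark of this kind typically measures, here is a rough standalone sketch in plain JAX. It is not the implementation behind the jaxflow CLI. The function name matmul_benchmark, its parameters, and the returned keys are hypothetical. It runs on whatever default device JAX selects and uses small defaults so it finishes quickly on CPU.

```python
import time

import jax
import jax.numpy as jnp


def matmul_benchmark(size=256, iters=5):
    """Time `iters` square matmuls of shape (size, size); requires iters >= 1."""
    key = jax.random.PRNGKey(0)
    a = jax.random.normal(key, (size, size), dtype=jnp.float32)
    matmul = jax.jit(lambda x: x @ x)

    # Warm-up call so compilation time is excluded from the measurement.
    matmul(a).block_until_ready()

    start = time.perf_counter()
    for _ in range(iters):
        out = matmul(a)
    # JAX dispatches asynchronously, so wait for the last result to finish.
    out.block_until_ready()
    elapsed = time.perf_counter() - start

    flops = 2.0 * size**3 * iters  # multiply-adds in a dense matmul
    return {
        "seconds_per_iter": elapsed / iters,
        "gflops": flops / elapsed / 1e9,
    }


result = matmul_benchmark()
print(result)
```

The warm-up call and the final block_until_ready matter most here: without them the timing would include compilation or stop before the device has finished its work.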
🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
📄 License
This project is licensed under the MIT License - see the LICENSE file for details.