Dataloader for JAX
Overview
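jax_dataloader brings a pytorch-like dataloader API to jax. It supports:
- 4 kinds of datasets to download and pre-process data: jax_dataloader datasets, pytorch datasets, huggingface datasets, and tensorflow datasets
- 3 backends to iteratively load batches: "jax", "pytorch", and "tensorflow"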
A minimal jax-dataloader example:
import jax.numpy as jnp
import jax_dataloader as jdl
# Example dataset: 10 samples with 10 features each, plus 10 labels
dataset = jdl.ArrayDataset(jnp.arange(100).reshape(10, 10), jnp.arange(10))
dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend='jax', # Use 'jax' backend for loading data
    batch_size=32, # Batch size
    shuffle=True, # Shuffle the dataloader every iteration or not
    drop_last=False, # Drop the last batch or not
)
batch = next(iter(dataloader)) # Fetch the first batch
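
To make the roles of batch_size, shuffle, and drop_last concrete, here is a minimal pure-Python sketch of how sample indices could be split into batches. It is not jax-dataloader's implementation: the helper name batch_indices and its seed parameter are hypothetical, and it assumes the pytorch-style meaning of drop_last (discard a final batch that is smaller than batch_size).

```python
import random


def batch_indices(n_samples, batch_size, shuffle=False, drop_last=False, seed=0):
    """Split range(n_samples) into lists of indices, one list per batch.

    Hypothetical helper for illustration only; not part of jax-dataloader.
    """
    indices = list(range(n_samples))
    if shuffle:
        # Shuffle a copy of the indices with a seeded generator for reproducibility.
        random.Random(seed).shuffle(indices)
    # Slice the index list into consecutive chunks of at most batch_size.
    batches = [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]
    # Assumed drop_last semantics: discard a trailing batch that is not full.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


full = batch_indices(10, batch_size=4)
trimmed = batch_indices(10, batch_size=4, drop_last=True)
print([len(b) for b in full])     # [4, 4, 2]
print([len(b) for b in trimmed])  # [4, 4]
```

With 10 samples and batch_size=4, keeping the last batch yields three batches, while dropping it yields two full ones.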
Installation
The latest jax-dataloader release can be installed directly from PyPI:
pip install jax-dataloader
or install directly from the repository:
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
Note
We keep jax-dataloader's dependencies minimal: installing it only installs jax and plum-dispatch (for backend dispatching). If you wish to use the integration with pytorch, huggingface datasets, or tensorflow, we highly recommend installing those dependencies manually. You can also run
pip install jax-dataloader[all]
to install everything (not recommended).
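
Since the optional integrations are installed separately, it can help to check which of them are present in the current environment. The following is a small sketch using only the standard library; the helper name optional_integrations and the dictionary labels are hypothetical, while torch, datasets, and tensorflow are the import names of the respective packages.

```python
import importlib.util

# Labels (keys) are chosen for this sketch; values are the import names.
OPTIONAL_MODULES = {
    "pytorch": "torch",
    "huggingface": "datasets",
    "tensorflow": "tensorflow",
}


def optional_integrations():
    """Return a dict mapping each optional integration to True if importable."""
    status = {}
    for label, module_name in OPTIONAL_MODULES.items():
        try:
            # find_spec returns None when a top-level module cannot be found.
            status[label] = importlib.util.find_spec(module_name) is not None
        except (ImportError, ValueError):
            status[label] = False
    return status


for label, installed in optional_integrations().items():
    print(f"{label}: {'installed' if installed else 'missing'}")
```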
Usage
jax_dataloader.core.DataLoader follows an API similar to the pytorch dataloader.
- The dataset should be an instance of a subclass of jax_dataloader.core.Dataset, torch.utils.data.Dataset, (the huggingface) datasets.Dataset, or tf.data.Dataset.
- The backend should be one of "jax", "pytorch", or "tensorflow". This argument specifies which backend dataloader is used to load batches.
Note that not every dataset is compatible with every backend. See the compatibility table below:
| | jdl.Dataset | torch_data.Dataset | tf.data.Dataset | datasets.Dataset |
|---|---|---|---|---|
| "jax" | ✅ | ❌ | ❌ | ✅ |
| "pytorch" | ✅ | ✅ | ❌ | ✅ |
| "tensorflow" | ✅ | ❌ | ✅ | ✅ |
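
The compatibility table can also be written down as a small lookup, which is handy when choosing a backend programmatically. This is only a sketch transcribed from the table above: COMPATIBILITY, is_compatible, and backends_for are hypothetical names and are not part of jax-dataloader, and the dataset kinds are plain strings rather than real classes.

```python
# Backend -> dataset kinds marked as supported in the compatibility table.
COMPATIBILITY = {
    "jax": {"jdl.Dataset", "datasets.Dataset"},
    "pytorch": {"jdl.Dataset", "torch_data.Dataset", "datasets.Dataset"},
    "tensorflow": {"jdl.Dataset", "tf.data.Dataset", "datasets.Dataset"},
}


def is_compatible(backend, dataset_kind):
    """Return True if the table marks dataset_kind as loadable by backend."""
    if backend not in COMPATIBILITY:
        raise ValueError(f"Unknown backend: {backend!r}")
    return dataset_kind in COMPATIBILITY[backend]


def backends_for(dataset_kind):
    """List the backends (sorted by name) that can load dataset_kind."""
    return sorted(b for b, kinds in COMPATIBILITY.items() if dataset_kind in kinds)


print(is_compatible("jax", "torch_data.Dataset"))  # False
print(backends_for("torch_data.Dataset"))          # ['pytorch']
print(backends_for("jdl.Dataset"))                 # ['jax', 'pytorch', 'tensorflow']
```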
Using ArrayDataset
The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy.array objects into one Dataset. For example, we can create an ArrayDataset as follows:
import jax.numpy as jnp
import jax_dataloader as jdl
# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)
This arr_ds can be loaded by every backend.
# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
Using Huggingface Datasets
The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports directly passing huggingface datasets.
from datasets import load_dataset
For example, we load the "squad" dataset from datasets:
hf_ds = load_dataset("squad")
Then, we can use jax_dataloader to load batches of hf_ds.
# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)
Using Pytorch Datasets
The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) support many built-in datasets. jax_dataloader supports directly passing the pytorch Dataset.
Note
Unfortunately, the pytorch Dataset only works with backend='pytorch'. See the example below.
from torchvision.datasets import MNIST
import numpy as np
We load the MNIST dataset from torchvision. The transform argument converts each image to a numpy.array.
pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)
This pt_ds can only be loaded via the "pytorch" backend.
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
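
The image-to-array conversion used for the transform can also be written as a named callable instead of a lambda. The sketch below is an assumption about what such a ToNumpy object could look like; this class is defined here for illustration and is not taken from jax-dataloader or torchvision. A nested list stands in for an image so that the example runs with numpy alone.

```python
import numpy as np


class ToNumpy:
    """Callable transform that converts an image-like object to a numpy array.

    Hypothetical class for illustration; equivalent in effect to
    `lambda x: np.array(x, dtype=float)`.
    """

    def __init__(self, dtype=float):
        self.dtype = dtype

    def __call__(self, pic):
        return np.array(pic, dtype=self.dtype)


to_numpy = ToNumpy()
image = [[0, 128], [255, 64]]  # stand-in for an image; converts the same way
arr = to_numpy(image)
print(arr.shape, arr.dtype)  # (2, 2) float64
```

Because torchvision datasets accept any callable as transform, an instance such as ToNumpy() could be passed in place of the lambda. A named class can also be pickled, which a lambda cannot, and that may matter if the data is loaded in worker processes.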
Using Tensorflow Datasets
jax_dataloader supports directly passing tensorflow datasets.
import tensorflow_datasets as tfds
import tensorflow as tf
For instance, we can load the MNIST dataset from tensorflow_datasets:
tf_ds = tfds.load('mnist', split='test', as_supervised=True)
and use jax_dataloader to iterate over the dataset:
dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)