Skip to main content

Dataloader for jax

Project description

Dataloader for JAX

Python CI status Docs pypi GitHub License Downloads

Overview

jax_dataloader brings pytorch-like dataloader API to jax. It supports

A minimum jax-dataloader example:

import jax_dataloader as jdl

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend='jax', # Use 'jax' backend for loading data
    batch_size=32, # Batch size 
    shuffle=True, # Shuffle the dataloader every iteration or not
    drop_last=False, # Drop the last batch or not
)

batch = next(iter(dataloader)) # iterate next batch

Installation

The latest jax-dataloader release can directly be installed from PyPI:

pip install jax-dataloader

or install directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note

We keep jax-dataloader’s dependencies minimum, which only install jax and plum-dispatch (for backend dispatching) when installing. If you wish to use integration of pytorch, huggingface datasets, or tensorflow, we highly recommend manually install those dependencies.

You can also run pip install jax-dataloader[all] to install everything (not recommended).

Usage

jax_dataloader.core.DataLoader follows similar API as the pytorch dataloader.

  • The dataset should be an object of the subclass of jax_dataloader.core.Dataset or torch.utils.data.Dataset or (the huggingface) datasets.Dataset or tf.data.Dataset.
  • The backend should be one of "jax" or "pytorch" or "tensorflow". This argument specifies which backend dataloader to load batches.

Note that not every dataset is compatible with every backend. See the compatibility table below:

jdl.Dataset torch_data.Dataset tf.data.Dataset datasets.Dataset
"jax"
"pytorch"
"tensorflow"

Using ArrayDataset

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy.array into one Dataset. For example, we can create an ArrayDataset as follows:

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This arr_ds can be loaded by every backends.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)

Using Huggingface Datasets

The huggingface datasets is a morden library for downloading, pre-processing, and sharing datasets. jax_dataloader supports directly passing the huggingface datasets.

from datasets import load_dataset

For example, We load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

Then, we can use jax_dataloader to load batches of hf_ds.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)

Using Pytorch Datasets

The pytorch Dataset and its ecosystems (e.g., torchvision, torchtext, torchaudio) supports many built-in datasets. jax_dataloader supports directly passing the pytorch Dataset.

Note

Unfortuantely, the pytorch Dataset can only work with backend=pytorch. See the belowing example.

from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The ToNumpy object transforms images to numpy.array.

pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)

This pt_ds can only be loaded via "pytorch" dataloaders.

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)

Using Tensowflow Datasets

jax_dataloader supports directly passing the tensorflow datasets.

import tensorflow_datasets as tfds
import tensorflow as tf

For instance, we can load the MNIST dataset from tensorflow_datasets

tf_ds = tfds.load('mnist', split='test', as_supervised=True)

and use jax_dataloader for iterating the dataset.

dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jax-dataloader-0.1.2.tar.gz (18.6 kB view details)

Uploaded Source

Built Distribution

jax_dataloader-0.1.2-py3-none-any.whl (19.9 kB view details)

Uploaded Python 3

File details

Details for the file jax-dataloader-0.1.2.tar.gz.

File metadata

  • Download URL: jax-dataloader-0.1.2.tar.gz
  • Upload date:
  • Size: 18.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.13

File hashes

Hashes for jax-dataloader-0.1.2.tar.gz
Algorithm Hash digest
SHA256 3f77a2f9770c1898ede5b1202749b268c5fe64fbedf5e558bb33a02e5cae5d3d
MD5 27bf48b3e04a885ef8b0d0e150c0de04
BLAKE2b-256 63f6a4abff866c443271f02283033ff28bd7a05480b42f47586ab6fc792138fd

See more details on using hashes here.

File details

Details for the file jax_dataloader-0.1.2-py3-none-any.whl.

File metadata

File hashes

Hashes for jax_dataloader-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 a8904bf7c7f81e847627e66fd7797dc3a400ddd6b5b0cc90232c7ed2bd068357
MD5 64b8dd81b70a92569e27c40796d1e274
BLAKE2b-256 2a53785703960043ff4f132696a2830d07114c53f792b848badd8e024050f058

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page