Dataloader for JAX

Overview

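jax_dataloader brings a PyTorch-like dataloader API to JAX. It supports:

  • 4 types of datasets: jax_dataloader.core.Dataset, PyTorch torch.utils.data.Dataset, Hugging Face datasets.Dataset, and TensorFlow tf.data.Dataset;
  • 3 backends for iteratively loading batches: "jax", "pytorch", and "tensorflow".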

A minimal jax-dataloader example:

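import jax.numpy as jnp
import jax_dataloader as jdl

# Any supported dataset works here: a jdl.Dataset (such as jdl.ArrayDataset),
# or a pytorch, huggingface, or tensorflow dataset compatible with the backend
dataset = jdl.ArrayDataset(jnp.arange(100).reshape(10, 10), jnp.arange(10))

dataloader = jdl.DataLoader(
    dataset,
    backend='jax', # Use 'jax' for loading data (also supports 'pytorch' and 'tensorflow')
    batch_size=5,
)

batch = next(iter(dataloader)) # fetch the next batch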

Installation

The latest jax-dataloader release can be installed directly from PyPI:

pip install jax-dataloader

or install directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note

We keep jax-dataloader’s dependencies minimal: it installs only JAX-related dependencies, plus plum-dispatch for backend dispatching. If you wish to use the PyTorch, Hugging Face datasets, or TensorFlow integrations, we recommend installing those dependencies manually.

You can also run pip install jax-dataloader[all] to install everything (not recommended).

Usage

jax_dataloader.core.DataLoader follows an API similar to the PyTorch DataLoader.

  • The dataset should be an instance of a subclass of jax_dataloader.core.Dataset, torch.utils.data.Dataset, (Hugging Face) datasets.Dataset, or tf.data.Dataset.
  • The backend should be one of "jax", "pytorch", or "tensorflow". This argument specifies which backend's dataloader is used to load batches.
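
Conceptually, the backend string selects which loader implementation produces the batches. The sketch below illustrates that idea with a plain dictionary lookup. The names LOADER_FACTORIES and make_loader and the placeholder factories are hypothetical; jax_dataloader itself uses plum-dispatch for dispatching, not code like this.

```python
from typing import Any, Callable, Dict

# Hypothetical placeholder factories; a real loader would return an iterable
# of batches backed by JAX, PyTorch, or TensorFlow instead of a string.
def _jax_loader(dataset: Any, batch_size: int) -> str:
    return f"jax loader (batch_size={batch_size})"

def _pytorch_loader(dataset: Any, batch_size: int) -> str:
    return f"pytorch loader (batch_size={batch_size})"

def _tensorflow_loader(dataset: Any, batch_size: int) -> str:
    return f"tensorflow loader (batch_size={batch_size})"

# Map each accepted backend name to the factory that builds its loader
LOADER_FACTORIES: Dict[str, Callable[[Any, int], str]] = {
    "jax": _jax_loader,
    "pytorch": _pytorch_loader,
    "tensorflow": _tensorflow_loader,
}

def make_loader(dataset: Any, backend: str, batch_size: int = 1) -> str:
    """Pick a loader factory by backend name (hypothetical helper)."""
    try:
        factory = LOADER_FACTORIES[backend]
    except KeyError:
        raise ValueError(
            f"backend must be one of {sorted(LOADER_FACTORIES)}, got {backend!r}"
        ) from None
    return factory(dataset, batch_size)
```

Rejecting unknown names up front gives a clear error instead of a failure deep inside a backend.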

Note that not every dataset is compatible with every backend. See the compatibility table below:

Backend        jdl.Dataset   torch_data.Dataset   tf.data.Dataset   datasets.Dataset
"jax"          Yes           No                   No                Yes
"pytorch"      Yes           Yes                  No                Yes
"tensorflow"   Yes           No                   Yes               Yes

Using ArrayDataset

jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy arrays into a single Dataset. For example, we can create an ArrayDataset as follows:

import jax.numpy as jnp
import jax_dataloader as jdl

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This arr_ds can be loaded by every backend.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
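
The batch_size and shuffle arguments correspond to the usual index-based scheme for array datasets: permute the row indices once per pass, then slice them into consecutive chunks and index every array with the same chunk. The NumPy sketch below illustrates that scheme. iterate_batches is a hypothetical helper, not part of jax_dataloader, and we do not claim the library's backends are implemented this way.

```python
import numpy as np

def iterate_batches(arrays, batch_size, shuffle=False, seed=0):
    """Yield aligned mini-batches from equally long arrays (hypothetical helper)."""
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all arrays must share the same first dimension")
    indices = np.arange(n)
    if shuffle:
        # Permute the row indices once per pass over the data
        indices = np.random.default_rng(seed).permutation(n)
    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]
        # Index every array with the same rows so features and labels stay aligned
        yield tuple(a[batch_idx] for a in arrays)

X = np.arange(100).reshape(10, 10)
y = np.arange(10)
batches = list(iterate_batches((X, y), batch_size=5, shuffle=True))
```

With 10 rows and batch_size=5 this yields two batches, each holding 5 rows of X and the 5 matching labels.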

Using Huggingface Datasets

Hugging Face datasets is a modern library for downloading, preprocessing, and sharing datasets. jax_dataloader accepts Hugging Face datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset with datasets:

hf_ds = load_dataset("squad")

Then, we can use jax_dataloader to load batches of hf_ds.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
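
Indexing a Hugging Face datasets.Dataset with a list of row indices returns a dictionary that maps each column name to a list of values, rather than a tuple of arrays. The sketch below shows one way to turn such a column dictionary into NumPy arrays and keep only the numeric columns. columns_to_numpy is a hypothetical helper, the sample batch is made up, and this is not a description of how jax_dataloader converts Hugging Face batches internally. Since JAX arrays have no string dtype, text columns such as those in "squad" generally have to be tokenized into integers before they can become jax.numpy arrays.

```python
import numpy as np

def columns_to_numpy(batch):
    """Convert a dict of column lists into a dict of NumPy arrays (hypothetical helper)."""
    # np.asarray infers a numeric dtype for numbers and a
    # fixed-width Unicode dtype ("<U...") for strings
    return {name: np.asarray(values) for name, values in batch.items()}

# A made-up three-row batch in the column-dictionary layout
sample_batch = {"label": [0, 1, 1], "score": [0.5, 0.25, 1.0], "text": ["a", "bb", "ccc"]}
arrays = columns_to_numpy(sample_batch)

# Keep only bool/int/uint/float columns, which JAX can represent
numeric = {name: arr for name, arr in arrays.items() if arr.dtype.kind in "biuf"}
```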

Using Pytorch Datasets

PyTorch's Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) provide many built-in datasets. jax_dataloader accepts a PyTorch Dataset directly.

Note

Unfortunately, a PyTorch Dataset works only with backend="pytorch". See the example below.

from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The ToNumpy transform converts each image into a numpy.array.

class ToNumpy:
    """Convert an image into a float numpy.array."""
    def __call__(self, pic):
        return np.array(pic, dtype=float)

pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)

This pt_ds can only be loaded with the "pytorch" backend:

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
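
PyTorch's default collate function stacks samples into torch.Tensor objects, whereas JAX code usually wants NumPy arrays. A common pattern, used for example in the JAX documentation's tutorial on PyTorch data loading, is a collate function that stacks samples with NumPy instead. The sketch below shows that pattern on made-up samples shaped like the ToNumpy-transformed MNIST images; it illustrates the technique and is not a claim about what jax_dataloader's "pytorch" backend does internally.

```python
import numpy as np

def numpy_collate(batch):
    """Recursively stack a list of samples into NumPy arrays."""
    if isinstance(batch[0], np.ndarray):
        # Equally shaped arrays become one array with a leading batch axis
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        # A list of (image, label) pairs becomes a list of batched fields
        return [numpy_collate(list(field)) for field in zip(*batch)]
    # Scalars such as integer labels become a 1-D array
    return np.array(batch)

# Five made-up (28x28 float image, integer label) samples
samples = [(np.zeros((28, 28), dtype=float), label) for label in range(5)]
images, labels = numpy_collate(samples)
```

With PyTorch installed, a function like this is passed as collate_fn=numpy_collate to torch.utils.data.DataLoader.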

Using Tensorflow Datasets

jax_dataloader also accepts TensorFlow datasets directly.

import tensorflow_datasets as tfds
import tensorflow as tf

For instance, we can load the MNIST dataset from tensorflow_datasets

tf_ds = tfds.load('mnist', split='test', as_supervised=True)

and use jax_dataloader to iterate over the dataset.

dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)
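
A tf.data.Dataset is a stream rather than an indexable array, so it is typically shuffled with a fixed-size buffer, as tf.data.Dataset.shuffle(buffer_size) does: fill a buffer, emit a randomly chosen buffered element, and replace it with the next item from the stream. The pure-Python sketch below illustrates that technique. buffered_shuffle is a hypothetical helper, and the buffer size used by jax_dataloader's "tensorflow" backend (if it shuffles this way at all) is not stated here.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximately shuffle a stream with a fixed-size buffer (hypothetical helper)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        # Emit a random buffered element and put the new item in its slot
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = item
    # Once the stream ends, drain the remaining elements in random order
    rng.shuffle(buffer)
    yield from buffer

shuffled = list(buffered_shuffle(range(20), buffer_size=4))
```

Larger buffers approximate a uniform shuffle more closely at the cost of memory; a buffer at least as large as the dataset gives a full shuffle.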

Download files

Source Distribution

jax-dataloader-0.1.0.tar.gz (18.7 kB)

Built Distribution

jax_dataloader-0.1.0-py3-none-any.whl (19.0 kB)

File details

Details for the file jax-dataloader-0.1.0.tar.gz.

File metadata

  • Download URL: jax-dataloader-0.1.0.tar.gz
  • Upload date:
  • Size: 18.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for jax-dataloader-0.1.0.tar.gz
Algorithm Hash digest
SHA256 19058ee94fe548951f9f4b846965391d15e8ddcf6f8f1176836efb43d3dbe3d4
MD5 b6477ffb2942c2f8930a36aa0a2d6c36
BLAKE2b-256 1d29f9d014ed4fd923a1efc49ac6a5d164de7f0479995eecbbd93d0fd4bfedc7

File details

Details for the file jax_dataloader-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jax_dataloader-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 19.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for jax_dataloader-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 1029e984cface497eb42ac2603f809a4723dbeaa0a7ef765e8108f46d03f0189
MD5 de2681982d644f5fa049332b23ac91e3
BLAKE2b-256 58fbf772b193be9ecf270aff32c931e1720f0225861b972e207ce073fcac59e9