
Project description

Dataloader for JAX


Overview | Installation | Usage | Documentation

Overview

jax_dataloader brings a pytorch-like dataloader API to jax. It supports four dataset types (jdl.Dataset, torch.utils.data.Dataset, tf.data.Dataset, and the huggingface datasets.Dataset) and three backends ("jax", "pytorch", and "tensorflow") for loading batches.

A minimal jax-dataloader example:

import jax_dataloader as jdl

jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset, or a pytorch/huggingface/tensorflow dataset
    backend='jax', # Use the 'jax' backend for loading data
    batch_size=32, # Batch size
    shuffle=True, # Whether to shuffle the dataset at every epoch
    drop_last=False, # Whether to drop the last incomplete batch
    generator=jdl.Generator() # Control the randomness of this dataloader
)

batch = next(iter(dataloader)) # get the next batch
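
Putting it together, here is a self-contained version of the example above (a minimal sketch; the toy X and y arrays are made up for illustration):

import jax.numpy as jnp
import jax_dataloader as jdl

jdl.manual_seed(1234)

# Toy data: 100 ten-dimensional feature vectors and 100 labels
X = jnp.arange(1000).reshape(100, 10)
y = jnp.arange(100)

dataloader = jdl.DataLoader(
    jdl.ArrayDataset(X, y), backend='jax',
    batch_size=32, shuffle=True, drop_last=False
)

for x_batch, y_batch in dataloader:
    # Full batches are (32, 10) and (32,); the final batch holds the
    # remaining 4 examples because drop_last=False
    print(x_batch.shape, y_batch.shape)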

Installation

The latest jax-dataloader release can be installed directly from PyPI:

pip install jax-dataloader

or install directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

[!NOTE]

We keep jax-dataloader's dependencies minimal: installing it pulls in only jax and plum-dispatch (used for backend dispatching). If you wish to use the pytorch, huggingface datasets, or tensorflow integrations, we highly recommend installing those dependencies manually.

You can also run pip install jax-dataloader[all] to install everything (not recommended).
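
For the manual installs mentioned above, the integrations map to these well-known packages (install only the ones you need):

pip install torch        # pytorch integration
pip install datasets     # huggingface datasets integration
pip install tensorflow   # tensorflow integration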

Usage

jax_dataloader.core.DataLoader follows a similar API to the pytorch dataloader.

  • The dataset should be an instance of a subclass of jax_dataloader.core.Dataset, a torch.utils.data.Dataset, a (huggingface) datasets.Dataset, or a tf.data.Dataset; a minimal subclass sketch follows this list.
  • The backend should be one of "jax", "pytorch", or "tensorflow". This argument specifies which backend dataloader is used to load batches.
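
For the first option, a custom dataset can subclass jdl.Dataset. The sketch below assumes the pytorch-style __len__/__getitem__ protocol, where __getitem__ may receive an array of indices; the SquaresDataset name and its contents are hypothetical:

import jax.numpy as jnp
import jax_dataloader as jdl

class SquaresDataset(jdl.Dataset):
    """Hypothetical dataset yielding (x, x**2) pairs."""

    def __init__(self, n: int):
        self.x = jnp.arange(n)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # jnp indexing accepts both a single int and an array of indices
        return self.x[idx], self.x[idx] ** 2

dataloader = jdl.DataLoader(SquaresDataset(100), 'jax', batch_size=10)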

Note that not every dataset type is compatible with every backend. See the compatibility table below:

                 jdl.Dataset   torch_data.Dataset   tf.data.Dataset   datasets.Dataset
"jax"                ✅                ❌                   ❌                 ✅
"pytorch"            ✅                ✅                   ❌                 ✅
"tensorflow"         ✅                ❌                   ✅                 ✅

Using ArrayDataset

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy.array objects into one Dataset. For example, we can create an ArrayDataset as follows:

import jax.numpy as jnp

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This arr_ds can be loaded by every backend.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
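
Whichever backend we choose, iteration looks the same (a quick check; the shapes follow from X and y above):

x, y = next(iter(dataloader))
print(x.shape, y.shape)  # (5, 10) and (5,)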

Using Huggingface Datasets

The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

Then, we can use jax_dataloader to load batches of hf_ds.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)
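
Because huggingface datasets are column-based, a batch should mirror the dataset's columns rather than be a tuple. As a sketch (assuming dict-style batches; "question" is one of squad's columns):

batch = next(iter(dataloader))
print(batch['question'][:2])  # the first two questions in this batch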

Using Pytorch Datasets

The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) support many built-in datasets. jax_dataloader supports passing a pytorch Dataset directly.

[!NOTE]

Unfortunately, the pytorch Dataset only works with backend='pytorch'. See the example below.

from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The transform converts each image to a numpy.array.

pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)

This pt_ds can only be loaded by the "pytorch" backend.

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
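
A quick sanity check on one batch (shapes assume MNIST's 28x28 grayscale images and the numpy transform above):

x, y = next(iter(dataloader))
print(x.shape, y.shape)  # (5, 28, 28) and (5,)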

Using Tensorflow Datasets

jax_dataloader supports passing tensorflow datasets directly.

import tensorflow_datasets as tfds
import tensorflow as tf

For instance, we can load the MNIST dataset from tensorflow_datasets

tf_ds = tfds.load('mnist', split='test', as_supervised=True)

and use jax_dataloader to iterate over the dataset.

dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)
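
A quick check on one batch (with as_supervised=True each example is an (image, label) pair, and tfds MNIST images are 28x28x1):

x, y = next(iter(dataloader))
print(x.shape, y.shape)  # (5, 28, 28, 1) and (5,)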



Download files


Source Distribution

jax_dataloader-0.1.4.tar.gz (21.0 kB)

Uploaded Source

Built Distribution


jax_dataloader-0.1.4-py3-none-any.whl (21.9 kB)

Uploaded Python 3

File details

Details for the file jax_dataloader-0.1.4.tar.gz.

File metadata

  • Download URL: jax_dataloader-0.1.4.tar.gz
  • Upload date:
  • Size: 21.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for jax_dataloader-0.1.4.tar.gz:

  • SHA256: 189cf1f8ea4a054563a07730cc312e89fe100ef6e4bd4cf867907cda8d6591ec
  • MD5: 62dff52ab8711051dca7fe92dd2c9c85
  • BLAKE2b-256: e5e4e58bea9e3df8e9b4bb566831e050e190efe7de513f3eab766d44b1b8745b


File details

Details for the file jax_dataloader-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: jax_dataloader-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 21.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for jax_dataloader-0.1.4-py3-none-any.whl:

  • SHA256: b00d1870399b4194923eb52814bd330875a32f99400ea7d48d7be6f9d5ed3341
  • MD5: ccc979b98fe2fc094413343b4741df44
  • BLAKE2b-256: 862fff102bd1cb9565f64286c1f065e442a88419c3567cea34909841cfb58417

