Dataloader for JAX

Overview

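jax_dataloader provides a high-level, pytorch-like dataloader API for jax. It supports jax_dataloader, pytorch, and huggingface datasets, and loads batches through either a jax or a pytorch backend.

A minimal jax-dataloader example: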

import jax_dataloader as jdl

dataloader = jdl.DataLoader(
    dataset, # A jdl.Dataset, a pytorch Dataset, or a huggingface dataset
    backend='jax', # Use 'jax' for loading data (also supports 'pytorch')
)

batch = next(iter(dataloader)) # fetch the first batch

Installation

The latest jax-dataloader release can be installed directly from PyPI:

pip install jax-dataloader

Alternatively, install it from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note

By default, only the jax-related dependencies are installed. To use the pytorch or huggingface integrations, install those packages manually, or run pip install jax-dataloader[all] to install all dependencies.
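Because the optional integrations are not installed by default, it can help to check which ones are importable before choosing a dataset type or backend. The following is a minimal sketch using only the standard library; the helper name available_integrations and the mapping OPTIONAL_INTEGRATIONS are hypothetical and not part of jax_dataloader. It assumes the usual import names torch (PyTorch) and datasets (Hugging Face Datasets).

```python
from importlib.util import find_spec

# Hypothetical mapping: integration label -> import name of the package.
OPTIONAL_INTEGRATIONS = {"pytorch": "torch", "huggingface": "datasets"}


def available_integrations():
    """Return {label: bool} telling whether each optional package is importable.

    find_spec only locates the package; it does not import it.
    """
    return {
        label: find_spec(module) is not None
        for label, module in OPTIONAL_INTEGRATIONS.items()
    }


print(available_integrations())
```

If an entry is False, install that package manually or use the [all] extra described above.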

Usage

jax_dataloader.core.DataLoader follows an API similar to that of the pytorch dataloader.

  • The dataset argument takes a jax_dataloader.core.Dataset, a torch.utils.data.Dataset, or a (huggingface) datasets.Dataset as the source from which to load the data.
  • The backend argument takes "jax" or "pytorch", which specifies the backend dataloader used to load batches.
import jax_dataloader as jdl
import jax.numpy as jnp

Using ArrayDataset

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy arrays into one Dataset. For example, we can create an ArrayDataset as follows:

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This arr_ds can be loaded by both "jax" and "pytorch" dataloaders.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
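To make the batch_size and shuffle arguments concrete, here is a minimal pure-Python sketch of the general batching idea. It is an illustration only, not jax_dataloader's implementation; the function iter_batches and its seed parameter are hypothetical. With 10 rows and batch_size=5, as in the example above, it produces two batches of five row indices each.

```python
import random


def iter_batches(n_rows, batch_size, shuffle=False, seed=0):
    """Yield lists of row indices, one list per batch (hypothetical helper)."""
    indices = list(range(n_rows))
    if shuffle:
        # Shuffle the row order once, then cut it into consecutive chunks.
        random.Random(seed).shuffle(indices)
    for start in range(0, n_rows, batch_size):
        yield indices[start:start + batch_size]


batches = list(iter_batches(n_rows=10, batch_size=5, shuffle=True))
print(len(batches))  # 2
```

Each index list would then be used to select the matching rows of X and y, so features and labels stay aligned within a batch.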

Using Pytorch Datasets

The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) provide many built-in datasets. jax_dataloader supports passing a pytorch Dataset directly.

Note

Unfortunately, the pytorch Dataset only works with backend="pytorch". See the example below.

from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The ToNumpy object transforms images to numpy.array.

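class ToNumpy:
    """Transform that converts an image to a float numpy.array."""
    def __call__(self, pic):
        return np.array(pic, dtype=float)

# Load the MNIST test split and apply the transform to every image
pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)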

This pt_ds can only be loaded via "pytorch" dataloaders.

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
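For context, a map-style pytorch Dataset such as MNIST is an object that implements __len__ and __getitem__, where each item is typically a (features, label) pair. The sketch below mimics that protocol in plain Python so it runs without pytorch installed. The class ToyDataset is hypothetical; for real use it should subclass torch.utils.data.Dataset, and whether jax_dataloader accepts a duck-typed object like this one is an assumption that the text above does not confirm.

```python
class ToyDataset:
    """Hypothetical map-style dataset: indexable and with a known length."""

    def __init__(self, features, labels):
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = features
        self.labels = labels

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.features)

    def __getitem__(self, idx):
        # One sample: a (features, label) pair.
        return self.features[idx], self.labels[idx]


toy_ds = ToyDataset(features=[[0.0, 1.0], [2.0, 3.0]], labels=[0, 1])
print(len(toy_ds), toy_ds[1])
```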

Using Huggingface Datasets

The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

This hf_ds can be loaded via "jax" and "pytorch" dataloaders. Because load_dataset returns one dataset per split, we pass a single split such as hf_ds['train'].

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
