Dataloader for jax

Jax-Dataloader

Overview

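jax_dataloader provides a high-level, pytorch-like dataloader API for jax. It supports jax_dataloader, pytorch, and huggingface datasets, and can load batches through either a jax or a pytorch backend.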

A minimal jax_dataloader example:

import jax_dataloader as jdl

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface dataset
    backend='jax', # Use 'jax' for loading data (also supports `pytorch`)
)

batch = next(iter(dataloader)) # fetch the first batch

Installation

The latest jax_dataloader release can be installed directly from PyPI:

pip install jax_dataloader

Alternatively, install it directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note

By default, only jax-related dependencies are installed. To use the pytorch or huggingface datasets integrations, install those packages manually, or run pip install jax_dataloader[dev] to install all dependencies.
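
To see which optional integrations are usable in the current environment, a small check like the following can help. It is a sketch that relies only on the standard library; the helper name available_integrations is hypothetical and not part of jax_dataloader, and it assumes the import names torch and datasets for pytorch and huggingface datasets.

```python
from importlib.util import find_spec


def available_integrations(modules=("jax", "torch", "datasets")):
    """Map each top-level module name to True if it can be imported."""
    # find_spec returns None when a top-level module is not installed.
    return {name: find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in available_integrations().items():
        print(f"{name}: {'installed' if ok else 'missing'}")
```

If torch or datasets is reported as missing, the corresponding dataset type or backend is presumably unavailable until that package is installed.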

Usage

jax_dataloader.core.DataLoader follows an API similar to that of the pytorch dataloader.

  • The dataset argument takes a jax_dataloader.core.Dataset, a torch.utils.data.Dataset, or a (huggingface) datasets.Dataset as the source from which to load data.
  • The backend argument takes "jax" or "pytorch" and specifies which backend dataloader is used to load batches.

import jax_dataloader as jdl
import jax.numpy as jnp

Using ArrayDataset

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy arrays into a single Dataset. For example, we can create an ArrayDataset as follows:

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)
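
Conceptually, a dataset of this kind only needs a length and index-based access that returns one item per wrapped array. The following is a minimal pure-Python sketch of that idea, not the library's implementation; the class name TupleDataset is hypothetical, and plain lists stand in for jax arrays.

```python
class TupleDataset:
    """Hypothetical stand-in: wraps equally long sequences into one dataset."""

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("at least one array is required")
        if len({len(a) for a in arrays}) != 1:
            raise ValueError("all arrays must have the same length")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, index):
        # One entry per wrapped array, e.g. (features, label).
        return tuple(a[index] for a in self.arrays)


# Same layout as X and y above: 10 rows of 10 features, and 10 labels.
X_rows = [list(range(i * 10, (i + 1) * 10)) for i in range(10)]
y_labels = list(range(10))
ds = TupleDataset(X_rows, y_labels)
features, label = ds[3]
```

Requiring equal lengths up front keeps every index valid for all wrapped arrays, which is what lets a loader treat the arrays as one aligned dataset.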

This arr_ds can be loaded by both "jax" and "pytorch" dataloaders.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
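
The batch_size and shuffle arguments presumably behave as in other pytorch-like loaders: indices are optionally permuted, then cut into consecutive chunks. A pure-Python sketch of that index logic follows; batch_indices is a hypothetical helper, not a jax_dataloader function, and the real backends may differ in details such as random number handling or whether the last partial batch is dropped.

```python
import random


def batch_indices(n, batch_size, shuffle=False, seed=0):
    """Yield lists of dataset indices, one list per batch."""
    indices = list(range(n))
    if shuffle:
        # A seeded generator keeps the sketch reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, n, batch_size):
        # The final batch may be shorter when n is not divisible by batch_size.
        yield indices[start:start + batch_size]


# 10 samples with batch_size=5, as in the DataLoader calls above.
batches = list(batch_indices(10, batch_size=5, shuffle=True))
```

With 10 samples and batch_size=5, this produces two batches of five indices each, and every sample appears exactly once per pass.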

Using Pytorch Datasets

The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) provide many built-in datasets. jax_dataloader supports passing a pytorch Dataset directly.

Note

Unfortunately, a pytorch Dataset only works with backend="pytorch". See the example below.

from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The ToNumpy transform converts each image to a numpy array.

class ToNumpy:
    # Convert a PIL image (as returned by MNIST) into a float numpy array
    def __call__(self, pic):
        return np.array(pic, dtype=float)

pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)

This pt_ds can only be loaded via "pytorch" dataloaders.

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)

Using Huggingface Datasets

The huggingface datasets package is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

Each split of hf_ds (e.g., hf_ds['train']) can be loaded via "jax" and "pytorch" dataloaders.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
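
Huggingface datasets are organized by column, so a batch is naturally a mapping from column name to a list of values. The sketch below shows that column-wise batching on plain dictionaries; collate_columns and the two toy records are hypothetical, and whether jax_dataloader returns batches in exactly this form is an assumption that should be checked against the installed version.

```python
def collate_columns(records):
    """Turn a list of row dicts into one dict of column lists."""
    if not records:
        return {}
    keys = records[0].keys()
    return {key: [row[key] for row in records] for key in keys}


# Toy rows with two field names; the values are made up for illustration.
rows = [
    {"question": "q0", "context": "c0"},
    {"question": "q1", "context": "c1"},
]
batch = collate_columns(rows)
```

Here batch["question"] holds the questions of both rows, in row order, which is the shape a model's tokenization step would typically consume.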
