Dataloader for jax

Jax-Dataloader

Overview

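jax_dataloader provides a high-level, PyTorch-like dataloader API for JAX. It supports loading jdl.Dataset, PyTorch, and Hugging Face datasets through either a "jax" or a "pytorch" backend.

A minimal jax-dataloader example: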

import jax_dataloader as jdl

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface dataset
    backend='jax', # Use 'jax' for loading data (also supports `pytorch`)
)

batch = next(iter(dataloader)) # fetch the first batch

Installation

The latest jax-dataloader release can be installed directly from PyPI:

pip install jax-dataloader

or install directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note

By default, only the jax-related dependencies are installed. To use the pytorch or huggingface datasets integrations, install those packages manually, or run pip install jax-dataloader[all] to install all the dependencies.
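Because the pytorch and huggingface integrations are optional, it can help to check which packages are importable before choosing a dataset type or backend. The following is a minimal sketch using only the standard library. The helper name `available_integrations` is hypothetical and not part of jax_dataloader; `torch` and `datasets` are the usual import names of those packages.

```python
from importlib.util import find_spec


def available_integrations(modules=("jax", "torch", "datasets")):
    """Return {module_name: bool} telling whether each module can be found."""
    # find_spec returns None when a top-level module is not installed.
    return {name: find_spec(name) is not None for name in modules}


# Report which of the optional packages are present in this environment.
for name, found in available_integrations().items():
    print(f"{name}: {'installed' if found else 'missing'}")
```

If `torch` is reported as missing, the "pytorch" backend and pytorch datasets are presumably unavailable until it is installed.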

Usage

jax_dataloader.core.DataLoader follows an API similar to the pytorch dataloader.

  • The dataset argument takes a jax_dataloader.core.Dataset, a torch.utils.data.Dataset, or a (huggingface) datasets.Dataset as the source from which to load the data.
  • The backend argument takes "jax" or "pytorch", which specifies the backend dataloader used to load batches.

import jax_dataloader as jdl
import jax.numpy as jnp

Using ArrayDataset

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy arrays into one Dataset. For example, we can create an ArrayDataset as follows:

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This arr_ds can be loaded by both "jax" and "pytorch" dataloaders.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
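To make the roles of batch_size and shuffle concrete, here is a conceptual sketch in plain Python of what an array-style dataset and a batching loop do. It is not the library's implementation: the names `PairedDataset` and `iter_batches` are hypothetical, and a real dataloader would also stack each batch into arrays rather than return lists of samples.

```python
import random


class PairedDataset:
    """Hypothetical stand-in for an array-style dataset wrapping equal-length columns."""

    def __init__(self, *columns):
        if not columns:
            raise ValueError("at least one column is required")
        if len({len(c) for c in columns}) != 1:
            raise ValueError("all columns must have the same length")
        self.columns = columns

    def __len__(self):
        return len(self.columns[0])

    def __getitem__(self, idx):
        # One sample is a tuple holding the idx-th entry of every column.
        return tuple(c[idx] for c in self.columns)


def iter_batches(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of samples; the last batch may be smaller than batch_size."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


# Mirror the shapes used above: 10 samples with 10 features each, plus 10 labels.
X = [list(range(10 * row, 10 * row + 10)) for row in range(10)]
y = list(range(10))
ds = PairedDataset(X, y)

batches = list(iter_batches(ds, batch_size=5, shuffle=True, seed=0))
print(len(batches), [len(b) for b in batches])  # 2 batches of 5 samples each
```

With 10 samples and batch_size=5, one pass over the data yields two batches, and shuffling only changes which samples land in which batch.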

Using Pytorch Datasets

The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) support many built-in datasets. jax_dataloader supports passing a pytorch Dataset directly.

Note

Unfortunately, a pytorch Dataset only works with backend="pytorch". See the example below.

from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The ToNumpy object converts each image to a numpy array.

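class ToNumpy:
    """Convert an image to a float numpy array."""
    def __call__(self, pic):
        return np.array(pic, dtype=float)

# Download the MNIST test split and apply ToNumpy to every image
pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)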

This pt_ds can only be loaded via "pytorch" dataloaders.

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
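The transform argument follows the torchvision convention of a callable that is applied to each image when the dataset is indexed. A minimal sketch of that pattern in plain Python is shown below; `TransformedDataset` and `to_float_rows` are hypothetical stand-ins for MNIST and ToNumpy, and nested lists stand in for images.

```python
class TransformedDataset:
    """Hypothetical map-style dataset that applies `transform` on access."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform is not None:
            # The transform runs lazily, once per requested sample.
            image = self.transform(image)
        return image, label


def to_float_rows(image):
    """Convert a nested list of ints to floats (stand-in for ToNumpy)."""
    return [[float(px) for px in row] for row in image]


toy = TransformedDataset([[[0, 255], [128, 64]]], [7], transform=to_float_rows)
image, label = toy[0]
print(image, label)
```

Because the conversion happens inside `__getitem__`, the stored data stays untouched and only the samples that are actually loaded get converted.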

Using Huggingface Datasets

The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

Each split of hf_ds (e.g., hf_ds['train']) can be loaded via "jax" and "pytorch" dataloaders.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
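Indexing a datasets.Dataset with an integer returns one row as a dictionary keyed by column name, so a batch is naturally a dictionary of columns. The sketch below shows that collation step in plain Python. The `collate_rows` helper and the toy rows are hypothetical, and the exact batch structure produced by jax_dataloader may differ.

```python
def collate_rows(rows):
    """Turn a list of row dicts into a dict mapping column name to a list of values."""
    if not rows:
        return {}
    keys = rows[0].keys()
    for row in rows:
        if row.keys() != keys:
            raise ValueError("all rows must share the same columns")
    return {key: [row[key] for row in rows] for key in keys}


# Toy rows with made-up columns, standing in for rows of a question-answering dataset.
rows = [
    {"id": "a1", "question": "Who wrote it?", "context": "Toy context one."},
    {"id": "a2", "question": "When was it?", "context": "Toy context two."},
]

batch = collate_rows(rows)
print(batch["id"])  # the ids of both rows, in order
```

Text columns such as these stay as Python lists in the sketch; converting them to arrays would first require tokenization or another numeric encoding.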
