Dataloader for jax
Overview
jax_dataloader provides a high-level, pytorch-like dataloader API for jax. It supports
- downloading and pre-processing datasets via huggingface datasets, pytorch Dataset, and tensorflow dataset (forthcoming)
- iteratively loading batches via the (vanilla) jax dataloader, pytorch dataloader, tensorflow (forthcoming), and merlin (forthcoming).
A minimal jax-dataloader example:
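import jax_dataloader as jdl
import jax.numpy as jnp
# Any supported dataset works here; an ArrayDataset is used so the example runs as-is
dataset = jdl.ArrayDataset(jnp.arange(100).reshape(10, 10), jnp.arange(10))
dataloader = jdl.DataLoader(
    dataset, # a jdl.Dataset or huggingface dataset (a pytorch Dataset requires backend='pytorch')
    backend='jax', # use 'jax' for loading data (also supports 'pytorch')
)
batch = next(iter(dataloader)) # fetch the next batch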
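The call next(iter(dataloader)) relies on Python's iterator protocol: iter() asks the loader for a fresh iterator, and next() pulls one batch from it. The sketch below illustrates only that protocol; SimpleLoader is a hypothetical class, not part of jax_dataloader, and it does no shuffling or backend handling.

```python
# Hypothetical, minimal stand-in for a dataloader (NOT jax_dataloader's code):
# it only illustrates the iterator protocol that `next(iter(dataloader))` uses.
class SimpleLoader:
    def __init__(self, data, batch_size=1):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        # A new generator on every call to iter(), so each loop starts a new epoch.
        for start in range(0, len(self.data), self.batch_size):
            yield self.data[start:start + self.batch_size]


loader = SimpleLoader(list(range(10)), batch_size=4)
first = next(iter(loader))   # [0, 1, 2, 3]
all_batches = list(loader)   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Because __iter__ is a generator, looping over the same loader twice yields the same sequence of batches again rather than an exhausted iterator.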
Installation
The latest jax-dataloader release can be installed directly from PyPI:
pip install jax-dataloader
or install directly from the repository:
pip install git+https://github.com/BirkhoffG/jax-dataloader.git
Note
By default, only jax-related dependencies are installed. To use the pytorch or huggingface datasets integrations, install those packages manually, or run pip install jax-dataloader[all] to install all the dependencies.
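Since the pytorch and huggingface integrations are optional, it can help to check which of them are importable before choosing a backend. A minimal sketch using only the standard library; available_backends is a hypothetical helper (not part of jax_dataloader), and it assumes the import names jax, torch, and datasets.

```python
import importlib.util


def available_backends():
    """Report which optional integrations are importable in this environment.

    Hypothetical helper; assumes "jax", "torch", and "datasets" are the import
    names of JAX, PyTorch, and Hugging Face datasets.
    """
    modules = {"jax": "jax", "pytorch": "torch", "huggingface": "datasets"}
    # find_spec returns None for a top-level module that is not installed.
    return {name: importlib.util.find_spec(mod) is not None
            for name, mod in modules.items()}


print(available_backends())
```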
Usage
jax_dataloader.core.DataLoader follows an API similar to the pytorch dataloader.
- The dataset argument takes a jax_dataloader.core.Dataset, a torch.utils.data.Dataset, or a (huggingface) datasets.Dataset as the input from which to load data.
- The backend argument takes "jax" or "pytorch" as input and specifies which backend dataloader is used to load batches.
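The compatibility rules stated in this README (jax_dataloader and huggingface datasets work with both backends, while a pytorch Dataset requires the pytorch backend) can be summarized as a small lookup. This is a hypothetical sketch: the dataset-kind labels and check_backend are illustrative names, not jax_dataloader internals.

```python
# Hypothetical sketch of the dataset/backend compatibility rules described in
# this README; the names below are illustrative, not jax_dataloader internals.
SUPPORTED_BACKENDS = {
    "jdl": {"jax", "pytorch"},          # jax_dataloader.core.Dataset / ArrayDataset
    "huggingface": {"jax", "pytorch"},  # datasets.Dataset
    "torch": {"pytorch"},               # torch.utils.data.Dataset
}


def check_backend(dataset_kind: str, backend: str) -> None:
    """Raise ValueError if `backend` cannot load a dataset of `dataset_kind`."""
    if backend not in {"jax", "pytorch"}:
        raise ValueError(f"Unknown backend {backend!r}; expected 'jax' or 'pytorch'.")
    allowed = SUPPORTED_BACKENDS[dataset_kind]
    if backend not in allowed:
        raise ValueError(
            f"A {dataset_kind} dataset cannot be loaded with backend={backend!r}; "
            f"use one of {sorted(allowed)}."
        )


check_backend("jdl", "jax")        # passes silently
check_backend("torch", "pytorch")  # passes silently
```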
import jax_dataloader as jdl
import jax.numpy as jnp
Using ArrayDataset
The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy.array objects into one Dataset. For example, we can create an ArrayDataset as follows:
# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)
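Conceptually, wrapping X and y together means that indexing the dataset indexes every array at the same position, so features and labels stay aligned. The sketch below shows that idea with NumPy; ToyArrayDataset is a hypothetical stand-in, and the real jdl.ArrayDataset may differ in details such as validation and return types.

```python
import numpy as np


class ToyArrayDataset:
    """Hypothetical stand-in mimicking the idea behind jdl.ArrayDataset:
    several arrays that share the same leading (sample) dimension."""

    def __init__(self, *arrays):
        n = len(arrays[0])
        if any(len(a) != n for a in arrays):
            raise ValueError("All arrays must have the same first dimension.")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, idx):
        # Index every array with the same idx (int, slice, or index array).
        return tuple(a[idx] for a in self.arrays)


X = np.arange(100).reshape(10, 10)
y = np.arange(10)
ds = ToyArrayDataset(X, y)
x0, y0 = ds[0]        # x0 is X[0] with shape (10,); y0 is 0
xb, yb = ds[[2, 5]]   # index arrays return a batch: shapes (2, 10) and (2,)
```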
This arr_ds can be loaded by both "jax" and "pytorch" dataloaders.
# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
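With batch_size=5 and shuffle=True, each epoch draws a random permutation of the 10 sample indices and splits it into two batches of 5, indexing X and y with the same indices. A NumPy sketch of that behaviour; iterate_batches and its seed parameter are hypothetical, and this is not how jax_dataloader manages its random state.

```python
import numpy as np


def iterate_batches(arrays, batch_size, shuffle=False, seed=0):
    """Yield tuples of mini-batches from arrays sharing a leading dimension.

    Hypothetical sketch of what batch_size/shuffle mean; not the
    jax_dataloader implementation.
    """
    n = len(arrays[0])
    indices = np.arange(n)
    if shuffle:
        # One random permutation per call (i.e. per epoch).
        indices = np.random.default_rng(seed).permutation(n)
    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]
        # The same indices select rows from every array, keeping X and y aligned.
        yield tuple(a[batch_idx] for a in arrays)


X = np.arange(100).reshape(10, 10)
y = np.arange(10)
batches = list(iterate_batches((X, y), batch_size=5, shuffle=True))
# 10 samples with batch_size=5 -> 2 batches of (5, 10) features and (5,) labels
```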
Using Pytorch Datasets
The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) support many built-in datasets. jax_dataloader supports passing a pytorch Dataset directly.
Note
Unfortunately, a pytorch Dataset only works with backend='pytorch'. See the example below.
from torchvision.datasets import MNIST
import numpy as np
We load the MNIST dataset from torchvision. The ToNumpy transform converts each image to a numpy.array.
class ToNumpy(object):
    def __call__(self, pic):
        return np.array(pic, dtype=float)
pt_ds = MNIST('/tmp/mnist/', download=True, transform=ToNumpy(), train=False)
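The transform argument works because torchvision's map-style datasets call it on each image inside __getitem__, returning (image, label) pairs. The sketch below reproduces that pattern without torchvision; TransformedDataset and the two fake 2x2 "images" are hypothetical stand-ins for MNIST, not real data.

```python
import numpy as np


class ToNumpy:
    # Same idea as the transform above: any object with __call__ can be a transform.
    def __call__(self, pic):
        return np.array(pic, dtype=float)


class TransformedDataset:
    """Hypothetical sketch of a map-style dataset that applies `transform`
    lazily in __getitem__ and returns (image, label) pairs."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, label = self.samples[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label


# Two fake 2x2 "images" stand in for MNIST digits.
raw = [([[0, 255], [255, 0]], 7), ([[1, 2], [3, 4]], 3)]
ds = TransformedDataset(raw, transform=ToNumpy())
img, label = ds[0]   # img is a float64 array of shape (2, 2); label is 7
```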
This pt_ds can only be loaded via "pytorch" dataloaders.
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)
Using Huggingface Datasets
The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.
from datasets import load_dataset
For example, we load the "squad" dataset from datasets:
hf_ds = load_dataset("squad")
This hf_ds can be loaded via "jax" and "pytorch" dataloaders.
# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
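Each example in a huggingface dataset is a dict keyed by column name, and slicing a datasets.Dataset (e.g. hf_ds['train'][:5]) returns a dict of per-column lists. The sketch below shows that columnar layout for a list of row dicts; collate_dicts and the toy rows (loosely mimicking SQuAD's id/question/context columns) are hypothetical, and how jax_dataloader itself collates string columns is not shown here.

```python
def collate_dicts(rows):
    """Turn a list of row dicts into one dict of per-column lists.

    Hypothetical sketch of columnar batching for dict-shaped examples.
    """
    if not rows:
        return {}
    keys = rows[0].keys()
    return {k: [row[k] for row in rows] for k in keys}


# Toy rows, not real SQuAD data.
rows = [
    {"id": "a1", "question": "Who?", "context": "Alice went home."},
    {"id": "a2", "question": "Where?", "context": "Bob stayed at work."},
]
batch = collate_dicts(rows)   # batch["id"] == ["a1", "a2"]
```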
Download files
Source Distribution: jax-dataloader-0.0.5.tar.gz
Built Distribution: jax_dataloader-0.0.5-py3-none-any.whl
File details
Details for the file jax-dataloader-0.0.5.tar.gz.
File metadata
- Download URL: jax-dataloader-0.0.5.tar.gz
- Upload date:
- Size: 16.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.16
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7c6b8092445256fb8698f3aba2de2dd97f208c3fbd0d1a28b6ba186a3889d80c |
| MD5 | 4b89862d023c7aee9e76513c14c7e4ee |
| BLAKE2b-256 | c5337a9223195bfe809b50cd9cbdf95e55fab0bd2eb7ecad9419b00bf7434a83 |
File details
Details for the file jax_dataloader-0.0.5-py3-none-any.whl.
File metadata
- Download URL: jax_dataloader-0.0.5-py3-none-any.whl
- Upload date:
- Size: 15.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.8.16
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 806d84d0e99707fa96801ada338d03fb4e92b245420c899a3d682b7bf573b378 |
| MD5 | 389a9cfc2fb220b425765b230d748229 |
| BLAKE2b-256 | 8cf970afeb43cfdb6d2c9869922658a68ac903608c1285f22a503066d3a6cb3a |
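A downloaded file can be checked against the SHA256 digests listed for each distribution using only Python's standard hashlib module. A minimal sketch: sha256_of is a hypothetical helper, and the commented-out check assumes the source archive was saved to the current directory under its original name.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# SHA256 digest of jax-dataloader-0.0.5.tar.gz, as listed in its file hashes.
EXPECTED_SDIST_SHA256 = "7c6b8092445256fb8698f3aba2de2dd97f208c3fbd0d1a28b6ba186a3889d80c"
# Assumes the archive was downloaded to the current directory:
# assert sha256_of("jax-dataloader-0.0.5.tar.gz") == EXPECTED_SDIST_SHA256
```

Reading in chunks keeps memory use constant regardless of file size.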