Dataloader for jax

Jax-Dataloader

Overview

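jax_dataloader provides a high-level, pytorch-like dataloader API for jax. It supports jax_dataloader, pytorch, and huggingface datasets, loaded through either a "jax" or a "pytorch" backend.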

A minimal jax_dataloader example:

import jax_dataloader as jdl

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface dataset
    backend='jax', # Use 'jax' for loading data (also supports `pytorch`)
)

batch = next(iter(dataloader)) # fetch the first batch

Installation

The latest jax_dataloader release can be installed directly from PyPI:

pip install jax_dataloader

or install directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note

By default, only the jax-related dependencies are installed. To use the pytorch or huggingface datasets integrations, install those packages manually, or run pip install jax_dataloader[dev] to install all the dependencies.

Usage

jax_dataloader.core.DataLoader follows an API similar to the pytorch dataloader.

  • The dataset argument takes a jax_dataloader.core.Dataset, a torch.utils.data.Dataset, or a (huggingface) datasets.Dataset as the source from which to load the data.
  • The backend argument takes "jax" or "pytorch", which specifies the backend dataloader used to load batches.
import jax_dataloader as jdl
import jax.numpy as jnp

Using ArrayDataset

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy arrays into a single Dataset. For example, we can create an ArrayDataset as follows:

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)
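
To illustrate what wrapping several arrays into one dataset means, here is a minimal, dependency-free sketch. It is not the jdl.ArrayDataset implementation: the class name TupleDataset is hypothetical, and plain Python lists stand in for jax.numpy arrays. The idea it shows is that indexing the dataset indexes every wrapped array at the same position.

```python
class TupleDataset:
    """Hypothetical stand-in for an array-wrapping dataset (not jdl.ArrayDataset)."""

    def __init__(self, *arrays):
        # All wrapped arrays must hold the same number of samples.
        lengths = {len(a) for a in arrays}
        if len(lengths) != 1:
            raise ValueError("all arrays must have the same length")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, index):
        # Index every wrapped array at the same position.
        return tuple(a[index] for a in self.arrays)


# Plain lists with the same values as jnp.arange(100).reshape(10, 10) and jnp.arange(10)
X = [[10 * i + j for j in range(10)] for i in range(10)]
y = list(range(10))
ds = TupleDataset(X, y)
features, label = ds[3]  # the fourth row of X and its matching label
```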

This arr_ds can be loaded by both "jax" and "pytorch" dataloaders.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
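
To make the batch_size and shuffle arguments concrete, the following sketch shows the index logic a dataloader typically performs. It is plain Python and an assumption about the general technique, not jax_dataloader's actual implementation; the function name batch_indices and its seed parameter are hypothetical and exist only for illustration.

```python
import random


def batch_indices(n_samples, batch_size, shuffle=False, seed=0):
    """Yield one list of sample indices per batch (hypothetical helper)."""
    indices = list(range(n_samples))
    if shuffle:
        # Shuffle the index order only; the underlying data is untouched.
        random.Random(seed).shuffle(indices)
    for start in range(0, n_samples, batch_size):
        # The final batch is smaller when batch_size does not divide n_samples.
        yield indices[start:start + batch_size]


# 10 samples with batch_size=5, mirroring the ArrayDataset example above
batches = list(batch_indices(10, 5, shuffle=True))
```

With these settings every sample index appears exactly once per pass, split across two batches of five.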

Using Huggingface Datasets

The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

Each split of hf_ds, such as hf_ds['train'], can be loaded via "jax" and "pytorch" dataloaders.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
# TODO: This is currently not working
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
