
Dataloading for Jax


Loadax

Loadax is a dataloading library designed for the JAX ecosystem. It provides utilities for feeding data into your training loop without having to worry about batching, shuffling, and other preprocessing steps. Loadax also supports offloading data loading to background workers, prefetching batches into a cache to improve performance, and JAX-native distributed data loading.

Important: Loadax is currently in early development, and the rest of this README is a working draft.

Installation

pip install loadax

Usage

Data Loading

Loadax provides a simple interface for loading data into your training loop. Here is an example of loading data from a list of items:

from loadax import DataLoader, InMemoryDataset, Batcher

dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
batcher = Batcher(lambda x: x)
loader = DataLoader(batcher).batch_size(2).build(dataset)

for batch in loader:
    print(batch)

# Output:
# [1, 2]
# [3, 4]
# [5, 6]
# [7, 8]
# [9, 10]

A dataloader is a definition of how to load data from a dataset. It is itself stateless, which lets you define multiple dataloaders for the same dataset, and even multiple iterators for the same dataloader.

dataloader = DataLoader(batcher).batch_size(2).build(dataset)

fast_iterator = iter(dataloader)
slow_iterator = iter(dataloader)

val = next(fast_iterator)
print(val)
# Output: [1, 2]

# slow_iterator starts from the beginning, independently of fast_iterator
val = next(slow_iterator)
print(val)
# Output: [1, 2]
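
To illustrate what "stateless" means here, the following is a minimal pure-Python sketch, not Loadax's implementation. The hypothetical StatelessLoader only stores its configuration, and all iteration state lives in the generator that each call to iter() creates, so iterators never interfere with each other. The batch_fn argument plays the role of a batcher.

```python
from typing import Callable, Iterator, Sequence


class StatelessLoader:
    """Hypothetical stand-in for a built loader: it only stores configuration."""

    def __init__(self, data: Sequence, batch_size: int, batch_fn: Callable[[list], object]):
        self.data = data
        self.batch_size = batch_size
        self.batch_fn = batch_fn

    def __iter__(self) -> Iterator[object]:
        # The current position is a local variable of this generator, so every
        # call to iter() starts a fresh, independent pass over the data.
        for start in range(0, len(self.data), self.batch_size):
            yield self.batch_fn(list(self.data[start:start + self.batch_size]))


loader = StatelessLoader([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 2, lambda x: x)
fast_iterator = iter(loader)
slow_iterator = iter(loader)
next(fast_iterator)  # [1, 2]
next(fast_iterator)  # [3, 4]
next(slow_iterator)  # [1, 2], unaffected by fast_iterator
```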

In the above examples we create an object called a batcher. A batcher is an interface that defines how to collate your data into batches. This is useful when you want to change the way your data is batched, such as stacking it into a single array.
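
As a sketch of a custom collate function: assuming, as in the examples above, that Batcher wraps any callable that receives the list of items in a batch, a function like the hypothetical collate_dicts below could be passed as Batcher(collate_dicts). It turns a list of per-example dicts into one dict of per-field lists; in a JAX pipeline you might additionally wrap each list in jnp.stack to get one array per field.

```python
def collate_dicts(items: list[dict]) -> dict[str, list]:
    """Turn a list of per-example dicts into one dict of per-field lists.

    Hypothetical collate function; assumes every item has the same keys.
    """
    keys = items[0].keys()
    return {key: [item[key] for item in items] for key in keys}


batch = collate_dicts([
    {"image_id": 0, "label": 3},
    {"image_id": 1, "label": 7},
])
# batch == {"image_id": [0, 1], "label": [3, 7]}
```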

Data Prefetching

When training models, it is essential to ensure that IO-bound tasks do not block the training loop, and especially your accelerator(s). Loadax provides a simple interface for prefetching data into a cache using background worker(s).

from loadax import DataLoader, InMemoryDataset, Batcher

dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
batcher = Batcher(lambda x: x)
loader = DataLoader(batcher).batch_size(2).prefetch(3).build(dataset)

for batch in loader:
    print(batch)

# Output:
# [1, 2]
# [3, 4]
# [5, 6]
# [7, 8]
# [9, 10]

In the above example we create a dataloader with a prefetch factor of 3. This means that the loader will prefetch 3 batches ahead of the current index. The future batches are kept in a cache and, depending on your configuration, can be eagerly loaded into device memory or kept in host memory.
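
The idea behind prefetching can be sketched with a background thread feeding a bounded queue. This is a minimal illustration, not how Loadax implements prefetch(3): the producer thread runs ahead of the consumer until `size` batches are waiting, then blocks. Moving batches into device memory early (for example with jax.device_put inside the producer) is left out to keep the sketch dependency-free.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def prefetch(source: Iterable[T], size: int) -> Iterator[T]:
    """Yield items from `source` while a background thread keeps up to `size` waiting."""
    buffer: queue.Queue = queue.Queue(maxsize=size)

    def producer() -> None:
        try:
            for item in source:
                buffer.put(("item", item))  # blocks once `size` items are waiting
        except Exception as exc:
            buffer.put(("error", exc))  # forward errors to the consumer
        else:
            buffer.put(("done", None))

    # Daemon thread so an abandoned iterator does not keep the program alive.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, value = buffer.get()
        if kind == "item":
            yield value
        elif kind == "error":
            raise value
        else:
            return


batches = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
for batch in prefetch(batches, size=3):
    print(batch)
```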

Using Multiple Workers

In the same way that the dataloader can be used to prefetch data, it can also offload the dataloading to multiple background workers. Let's take a look at an example of why you may want to do this.

In the following example we have a dataset that is slow to load an individual item due to some pre-processing. Ignore the details of MappedDataset for now, as we will get to it later; just know that it lazily transforms the data from the source dataset.

import time

from loadax import DataLoader, RangeDataset, MappedDataset, Batcher

def slow_fn(x):
    time.sleep(0.1)
    return x * 2

dataset = MappedDataset(RangeDataset(0, 10), slow_fn)
batcher = Batcher(lambda x: x)
loader = DataLoader(batcher).batch_size(2).workers(2).build(dataset)

for batch in loader:
    print(batch)

# Output:
# [0, 2]
# [4, 6]
# [8, 10]
# [12, 14]
# [16, 18]

In the above example we create a dataloader with 2 workers, which means the loader creates 2 background workers to load the data. The items are loaded in parallel, allowing the background workers to do the slow processing, and are then batched and made ready for consumption.

An important note is that the background workers are currently implemented with the concurrent.futures library, because multiprocessing does not work well with JAX. This means each node uses a single Python process, so depending on your Python version and how IO-bound your dataset loading is, you may in some cases see GIL contention.
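
The worker model described above can be sketched with concurrent.futures.ThreadPoolExecutor. This is an illustration under assumptions, not Loadax's code: load_batches is a hypothetical helper, and for brevity it loads every item before batching, whereas a real loader would stream. Because time.sleep releases the GIL, two threads roughly halve the wall time of the slow example; CPU-bound Python preprocessing would benefit much less, which is the GIL caveat mentioned above.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_fn(x: int) -> int:
    time.sleep(0.1)  # stand-in for slow per-item preprocessing
    return x * 2


def load_batches(indices, load_item, batch_size, workers):
    """Load items on a thread pool, then group them into batches in index order."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Executor.map runs calls concurrently but yields results in input order,
        # so batch contents do not depend on which worker finishes first.
        items = list(pool.map(load_item, indices))
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


start = time.perf_counter()
batches = load_batches(range(10), slow_fn, batch_size=2, workers=2)
elapsed = time.perf_counter() - start
print(batches)  # [[0, 2], [4, 6], [8, 10], [12, 14], [16, 18]]
```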

Distributed Data Loading

Loadax also supports distributed data loading. This means that you can easily shard your dataset across multiple nodes/JAX processes and load data in parallel. Loadax automatically determines which elements to load on each shard within the network, ensuring that the data is evenly distributed and that each node only gets the data it needs.

With the inter-node distribution handled for you, it is now trivial to build advanced distributed training loops with paradigms such as model and data parallelism.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from loadax import DataLoader, InMemoryDataset, Batcher

# Create a 2D mesh across all the jax devices: every device on the "data" axis and a
# "model" axis of size 1 (Mesh expects one array dimension per axis name)
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("data", "model"))

# Create a partition spec that shards the leading (batch) dimension across the "data" axis
partition_spec = PartitionSpec("data")

dataset_size = 100
batch_size = 10  # device_put below needs this to be divisible by the size of the "data" axis

# Create dataloader for a jax process
dataset = InMemoryDataset(list(range(dataset_size)))
batcher = Batcher(lambda x: jnp.stack(x))

dataloader = (
    DataLoader(batcher)
    .batch_size(batch_size)
    .workers(2)
    .prefetch(2)
    .shard(mesh, partition_spec)
    .build(dataset)
)

# Define a simple model function, you can imagine this is some Flax model or something similar,
# it may even be sharded itself in some other axis such as model parallelism
def simple_model(x, params):
    return x * params

params = jnp.array([2.0])
# Replicate the (tiny) parameters on every device
sharded_params = jax.device_put(params, NamedSharding(mesh, PartitionSpec()))

def compute_loss(batch, predictions):
    # Placeholder loss: replace with your own loss calculation logic
    return jnp.mean((predictions - batch) ** 2)

# Compile the model once, outside the loop
jitted_model = jax.jit(simple_model)

total_loss = 0.0
for batch in dataloader:
    # Distribute the batch across the local devices
    local_batch = jax.device_put(jnp.asarray(batch), NamedSharding(mesh, partition_spec))

    # Apply the model and compute the loss
    predictions = jitted_model(local_batch, sharded_params)
    loss = compute_loss(local_batch, predictions)

    # jnp.mean over a sharded array already reduces across devices, so no explicit
    # pmean is needed (lax.pmean only works inside pmap/shard_map with a bound axis name)
    total_loss += loss

The sharding primitives that Loadax provides are powerful because they declare up front how data is distributed. This makes Loadax deterministic, as it decides which elements to load on each shard, and even which elements go into each specific batch. This guaranteed determinism lets you focus on other things rather than on making sure your dataloading is correct and reproducible.
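
One way to make the per-process assignment deterministic is to derive it purely from the dataset size, the global batch size, and the process's position. The sketch below shows one such scheme; it is hypothetical and not necessarily the scheme Loadax uses. In a JAX program, process_index and process_count would typically come from jax.process_index() and jax.process_count().

```python
def shard_batch_indices(dataset_size, global_batch_size, process_index, process_count):
    """Return, for each global batch, the dataset indices this process should load.

    Hypothetical scheme: each global batch is split into equal contiguous slices,
    one per process. Assumes global_batch_size is divisible by process_count and
    drops a trailing partial batch.
    """
    if global_batch_size % process_count != 0:
        raise ValueError("global_batch_size must be divisible by process_count")
    per_process = global_batch_size // process_count
    num_batches = dataset_size // global_batch_size
    return [
        list(range(b * global_batch_size + process_index * per_process,
                   b * global_batch_size + (process_index + 1) * per_process))
        for b in range(num_batches)
    ]


# Two processes sharing 100 items in global batches of 10:
p0 = shard_batch_indices(100, 10, process_index=0, process_count=2)
p1 = shard_batch_indices(100, 10, process_index=1, process_count=2)
# p0[0] == [0, 1, 2, 3, 4], p1[0] == [5, 6, 7, 8, 9], p0[1] == [10, 11, 12, 13, 14]
```

Because the result is a pure function of its inputs, every process can compute its own slice without communicating, and re-running the job reproduces the same assignment.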

Type Hinting

Another benefit of Loadax is that the type of your data is passed through all the way into your training loop. This means you can use type hints to check that your data has the expected type.

from loadax import DataLoader, RangeDataset, Batcher

# RangeDataset has a DatasetItem type of int. This is a generic argument that can be supplied to any dataset
# type; we will look at it more closely when we get to datasets.
dataset = RangeDataset(0, 10)

# this function is annotated to take a batch of ints and return an int,
# so each batch produced by the loader is typed as an int
def my_fn(x: list[int]) -> int:
    return sum(x)

batcher = Batcher(my_fn)
loader = DataLoader(batcher).batch_size(2).build(dataset)

for batch in loader:
    print(batch)

# Output:
# 1
# 5
# 9
# 13
# 17

Because you define the Batcher (or use a predefined one for common operations), the type of the batch can be inferred all the way from the dataset definition.
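
How a type can flow from the collate function to each batch can be sketched with typing.Generic. The TypedBatcher class below is hypothetical and only mirrors the idea; Loadax's own class definitions may differ. A static type checker infers TypedBatcher[int, int] from my_fn's annotations, so each batch in the loop is typed as int.

```python
from typing import Callable, Generic, Iterator, Sequence, TypeVar

Item = TypeVar("Item")
BatchT = TypeVar("BatchT")


class TypedBatcher(Generic[Item, BatchT]):
    """Hypothetical generic batcher: the collate function's signature fixes both types."""

    def __init__(self, collate: Callable[[list[Item]], BatchT]) -> None:
        self.collate = collate

    def batches(self, data: Sequence[Item], batch_size: int) -> Iterator[BatchT]:
        for start in range(0, len(data), batch_size):
            yield self.collate(list(data[start:start + batch_size]))


def my_fn(x: list[int]) -> int:
    return sum(x)


batcher = TypedBatcher(my_fn)  # inferred by a type checker as TypedBatcher[int, int]
for batch in batcher.batches(range(10), 2):
    print(batch)  # batch is typed as int: prints 1, 5, 9, 13, 17
```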

Datasets

Loadax provides a simple interface for defining your dataset. As long as you can perform indexed access on your data, you can use Loadax to load your data. See the Dataset Protocol for more details.
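
As a rough idea of what "indexed access" means, here is a minimal protocol requiring only a length and integer indexing, together with a dataset that computes items on demand. IndexedDataset and SquaresDataset are hypothetical names; the actual Loadax Dataset Protocol may require different or additional methods.

```python
from typing import Protocol, TypeVar, runtime_checkable

T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class IndexedDataset(Protocol[T_co]):
    """Hypothetical minimal protocol: a length and integer indexing."""

    def __len__(self) -> int: ...

    def __getitem__(self, index: int) -> T_co: ...


class SquaresDataset:
    """Example dataset that computes items on demand instead of storing them."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int) -> int:
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index * index


ds = SquaresDataset(5)
isinstance(ds, IndexedDataset)  # True: it structurally matches the protocol
[ds[i] for i in range(len(ds))]  # [0, 1, 4, 9, 16]
```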

Additionally, Loadax provides a few common datasets that can be used out of the box. These include:

  • InMemoryDataset
  • RangeDataset
dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
dataset = RangeDataset(0, 10)

Datasets can also be transformed using a variety of transformations. Transformations are applied lazily, meaning they only run when the data is actually accessed. Because your dataloader is likely prefetching and using background workers, this should not block your training loop. This also means that you can use JAX to jit-compile your transformation function.

import time

from loadax import MappedDataset, RangeDataset, ShuffledDataset

def slow_fn(x):
    time.sleep(0.1)
    return x * 2

base_dataset = ShuffledDataset(RangeDataset(0, 10))
dataset = MappedDataset(base_dataset, slow_fn)

When iterating through the dataset, slow_fn will be applied lazily to the underlying dataset, which is itself lazily shuffling the range dataset. This composable pattern allows you to build complex dataloading pipelines.
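
The lazy, composable behaviour can be sketched with two small wrappers that only translate indices or apply a function on access. LazyShuffled and LazyMapped are hypothetical stand-ins for ShuffledDataset and MappedDataset, not their implementations. The shuffle stores only a seeded permutation of indices, so the order is reproducible and the underlying data is never copied.

```python
import random
from typing import Callable, Sequence


class LazyShuffled:
    """Hypothetical shuffled view: permutes indices, never copies the data."""

    def __init__(self, source: Sequence, seed: int = 0) -> None:
        self.source = source
        self.order = list(range(len(source)))
        random.Random(seed).shuffle(self.order)  # seeded, so the order is reproducible

    def __len__(self) -> int:
        return len(self.source)

    def __getitem__(self, index: int):
        return self.source[self.order[index]]


class LazyMapped:
    """Hypothetical mapped view: applies `fn` only when an item is accessed."""

    def __init__(self, source: Sequence, fn: Callable) -> None:
        self.source = source
        self.fn = fn

    def __len__(self) -> int:
        return len(self.source)

    def __getitem__(self, index: int):
        return self.fn(self.source[index])


calls = []


def double(x):
    calls.append(x)  # record each call to show the laziness
    return x * 2


dataset = LazyMapped(LazyShuffled(range(10), seed=42), double)
# calls is still empty here: building the pipeline computes nothing
first = dataset[0]  # only now is `double` called, exactly once
```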

Dataset Integrations

Loadax has a few common dataset sources on the roadmap, including:

  • PolarsDataset
  • SQLiteDataset
  • HuggingFaceDataset

Feel free to open an issue if you have a use case that you would like to see included.

Batchers

Batchers are used to define how to collate your data into batches.
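
A common collate step for variable-length data is padding. The hypothetical pad_collate below pads each sequence in a batch to a common length and returns the true lengths so padding can be masked out later. Assuming Batcher accepts any callable over the list of items, it could be wrapped as Batcher(pad_collate).

```python
from typing import Optional


def pad_collate(
    sequences: list[list[int]], pad_value: int = 0, length: Optional[int] = None
) -> tuple[list[list[int]], list[int]]:
    """Pad variable-length sequences to a common length and return the true lengths.

    Pads to the longest sequence in the batch unless a fixed `length` is given.
    """
    target = length if length is not None else max(len(seq) for seq in sequences)
    if any(len(seq) > target for seq in sequences):
        raise ValueError("a sequence is longer than the requested length")
    padded = [seq + [pad_value] * (target - len(seq)) for seq in sequences]
    lengths = [len(seq) for seq in sequences]
    return padded, lengths


padded, lengths = pad_collate([[5, 6, 7], [8], [9, 10]])
# padded  == [[5, 6, 7], [8, 0, 0], [9, 10, 0]]
# lengths == [3, 1, 2]
```

The optional fixed `length` matters with JAX: a jax.jit-compiled step is recompiled for every new input shape, so padding every batch to the same length avoids repeated compilation.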

Download files


Source Distribution

loadax-0.1.0.tar.gz (22.6 kB)

Built Distribution

loadax-0.1.0-py3-none-any.whl (26.1 kB)
