
Cyreal - Another JAX DataLoader

grain for the corporations, cyreal for the people

Pure JAX utilities for iterating over finite datasets without ever touching PyTorch or TensorFlow. Dataloaders support jax.jit, jax.grad, jax.lax.scan, and other function transformations.

Installation

The only dependency is JAX. On GPU machines, install the appropriate JAX build for your CUDA version.

pip install cyreal
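
For example, on a CUDA 12 machine you might pull in the GPU-enabled JAX wheel alongside cyreal (jax[cuda12] is JAX's documented CUDA extra; check the JAX install guide for the build matching your setup):

pip install cyreal "jax[cuda12]"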

Quick start with MNIST

Write torch-style dataloaders without torch

import jax
import jax.numpy as jnp

from cyreal import (
  ArraySampleSource,
  BatchTransform,
  DataLoader,
  DevicePutTransform,
  MNISTDataset,
)

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
  ArraySampleSource(train_data, ordering="shuffle"),
  BatchTransform(batch_size=128),
  DevicePutTransform(),
]
loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
  ...  # train your network!
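
The mask flags which rows of a batch are real samples, since the final batch of an epoch may be padded up to batch_size (this reading follows the mask-weighted loss in the logging example further down; treat the exact padding semantics as an assumption). A minimal sketch of a mask-weighted metric inside the loop:

for batch, mask in loader.iterate(state):
  # Mean pixel intensity over real (unmasked) samples only
  means = batch["image"].astype(jnp.float32).mean(axis=(1, 2, 3))
  masked_mean = jnp.sum(means * mask) / jnp.maximum(jnp.sum(mask), 1)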

Scan to Avoid Boilerplate

DataLoader.scan_epoch rolls a full pass through the dataset into a single jax.lax.scan to minimize dispatch overhead. Your body_fn is jitted as part of the scan.

def body_fn(model_state, batch, mask):
  model_state = update_model(model_state, batch, mask)  # your training step
  return model_state, None

state, model_state, _ = loader.scan_epoch(state, model_state, body_fn)
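
Here update_model is your own code. As a purely illustrative sketch (none of this is cyreal API): a single SGD step on a mask-weighted squared-error loss, with model_state as a flat weight vector:

import jax
import jax.numpy as jnp

def update_model(model_state, batch, mask, lr=1e-3):
  def loss_fn(w):
    x = batch["image"].reshape(batch["image"].shape[0], -1).astype(jnp.float32)
    preds = x @ w
    per_sample = (preds - batch["label"].astype(jnp.float32)) ** 2
    # Padded rows get zero weight so they never influence the gradient
    return jnp.sum(per_sample * mask) / jnp.maximum(jnp.sum(mask), 1)
  return model_state - lr * jax.grad(loss_fn)(model_state)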

JIT Capabilities

Do you enjoy premature optimization? Why not jit the entire train epoch?

import jax
import jax.numpy as jnp

from cyreal import (
  ArraySampleSource,
  BatchTransform,
  DataLoader,
  DevicePutTransform,
  MNISTDataset,
)

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
  ArraySampleSource(train_data, ordering="shuffle"),
  BatchTransform(batch_size=128),
  DevicePutTransform(),
]
loader = DataLoader(pipeline=pipeline)
loader_state = loader.init_state(jax.random.PRNGKey(0))
model_state = model_init()

@jax.jit
def train_epoch(model_state, loader_state):
  def body_fn(model_state, batch, mask):
    # Update the network using your train fn
    new_model_state = model_update(model_state, batch, mask)
    return new_model_state, None

  loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
  return model_state, loader_state

model_state, loader_state = train_epoch(model_state, loader_state)
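
Since train_epoch returns the advanced loader state, calling it in a loop gives you multiple epochs; with ordering="shuffle", the returned state is presumably what drives the next epoch's reshuffle:

for epoch in range(10):
  model_state, loader_state = train_epoch(model_state, loader_state)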

Streaming from Disk

Is your dataset enormous? Swap in a disk-backed source.

import jax

from cyreal import (
  BatchTransform,
  DataLoader,
  DevicePutTransform,
  MNISTDataset,
)

pipeline = [
  MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
  BatchTransform(batch_size=128),
  DevicePutTransform(),
]

loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
  ...  # stream without holding the dataset in RAM

For the Dirty and Impure

Want to jit but also log some metrics? Use HostCallbackTransform, which uses jax.experimental.io_callback under the hood.

import jax.numpy as jnp
import numpy as np

from cyreal import (
  ArraySampleSource,
  BatchTransform,
  DataLoader,
  HostCallbackTransform,
  MNISTDataset,
)

def model(images):
  return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))

def squared_error(logits, labels):
  # A toy loss: the "logits" here are just mean pixel intensities
  labels = labels.astype(jnp.float32)
  return (logits - labels) ** 2

def log_loss(batch, mask):
  logits = model(batch["image"])
  # logits and mask are both shape (batch,), so weight elementwise
  loss = jnp.mean(squared_error(logits, batch["label"]) * mask)
  print("loss:", float(np.asarray(loss)))  # host-side print via io_callback
  return batch

loader = DataLoader(
  pipeline=[
    ArraySampleSource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
    BatchTransform(batch_size=128),
    HostCallbackTransform(fn=log_loss),
  ],
)
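
The loader is then driven exactly like before, and the print in log_loss fires on the host each step even under tracing (a sketch reusing the quick-start pattern):

import jax

state = loader.init_state(jax.random.PRNGKey(0))
for batch, mask in loader.iterate(state):
  pass  # loss is printed from inside the pipeline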
