Jittable data loading utilities for JAX.

Project description

Cyreal - Another JAX DataLoader

grain for the corporations, cyreal for the people

Pure JAX utilities for iterating over finite datasets without ever touching torch or tensorflow. Dataloaders support jax.jit, jax.grad, jax.lax.scan, and other function transformations.

Installation

The only dependency is jax. On GPU machines, install the appropriate JAX build for your CUDA version.

pip install cyreal
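
For example, on a CUDA 12 machine the JAX installation guide currently suggests the following (check the JAX docs for the build matching your driver):

pip install -U "jax[cuda12]"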

Quick start with MNIST

Write torch-style dataloaders without torch

import jax
import jax.numpy as jnp

from cyreal import (
  ArraySampleSource,
  BatchTransform,
  DataLoader,
  DevicePutTransform,
  MNISTDataset,
)

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
  ArraySampleSource(train_data, ordering="shuffle"),
  BatchTransform(batch_size=128),
  DevicePutTransform(),
]
loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
  ...  # train your network!
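
Each step yields a batch together with a mask. The mask presumably flags which rows are real samples rather than padding from a final partial batch (compare drop_last in the RL example below), so per-example losses should be weighted by it. A minimal sketch, assuming a hypothetical per-example loss_fn of your own:

import jax.numpy as jnp

def masked_mean_loss(params, batch, mask):
  per_example = loss_fn(params, batch)  # hypothetical, shape (batch_size,)
  # Average only over the real samples; padded rows contribute zero
  return jnp.sum(per_example * mask) / jnp.maximum(jnp.sum(mask), 1)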

See the repository for full training examples.

Scan and Avoid Boilerplate

DataLoader.scan_epoch runs a full pass through the dataset as a single jax.lax.scan to minimize dispatch overhead. It jits the body_fn for you.

def body_fn(model_state, batch, mask):
  model_state = update_model(model_state, batch, mask)  # your training step
  return model_state, None

# loader_state comes from loader.init_state(...), as in the quick start
loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
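
Conceptually (a sketch under assumed internals, not cyreal's actual implementation), scan_epoch behaves roughly like a jax.lax.scan whose carry threads both the loader and model state; loader.next and num_steps below are hypothetical stand-ins:

import jax

def scan_epoch_sketch(loader, loader_state, model_state, body_fn, num_steps):
  def step(carry, _):
    loader_state, model_state = carry
    loader_state, (batch, mask) = loader.next(loader_state)  # hypothetical API
    model_state, out = body_fn(model_state, batch, mask)
    return (loader_state, model_state), out

  (loader_state, model_state), outs = jax.lax.scan(
    step, (loader_state, model_state), None, length=num_steps
  )
  return loader_state, model_state, outs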

JIT Capabilities

Do you enjoy premature optimization? Why not jit the entire train epoch?

import jax
import jax.numpy as jnp

from cyreal import (
  ArraySampleSource,
  BatchTransform,
  DataLoader,
  DevicePutTransform,
  MNISTDataset,
)

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
  ArraySampleSource(train_data, ordering="shuffle"),
  BatchTransform(batch_size=128),
  DevicePutTransform(),
]
loader = DataLoader(pipeline)
loader_state = loader.init_state(jax.random.PRNGKey(0))
model_state = model_init()  # your model's init function

@jax.jit
def train_epoch(model_state, loader_state):
  def body_fn(model_state, batch, mask):
    # Update the network using your train fn
    new_model_state = model_update(model_state, batch, mask)
    return new_model_state, None

  loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
  return model_state, loader_state

model_state, loader_state = train_epoch(model_state, loader_state)
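
Since train_epoch is jitted, it compiles once and is reused across epochs; threading loader_state back in keeps the shuffle order advancing:

# num_epochs is your choice; compilation happens only on the first call
for epoch in range(num_epochs):
  model_state, loader_state = train_epoch(model_state, loader_state)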

Streaming from Disk

Is your dataset enormous? Swap in a disk-backed source.

import jax

from cyreal import (
  BatchTransform,
  DataLoader,
  DevicePutTransform,
  MNISTDataset,
)

pipeline = [
  MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
  BatchTransform(batch_size=128),
  DevicePutTransform(),
]

loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
  ...  # stream without holding the dataset in RAM
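
prefetch_size presumably bounds how many samples sit in host memory at once, trading RAM for fewer disk reads; the value below is an illustrative guess, not a recommendation:

# Larger buffers smooth throughput at the cost of host memory
source = MNISTDataset.make_disk_source(
  split="train", ordering="shuffle", prefetch_size=4096
)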

For the Dirty and Impure

Want to jit but also log some metrics? Use HostCallbackTransform, which uses jax.experimental.io_callback under the hood.

import jax
import jax.numpy as jnp
import numpy as np

from cyreal import (
  ArraySampleSource,
  BatchTransform,
  DataLoader,
  HostCallbackTransform,
  MNISTDataset,
)

def model(images):
  return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))

def squared_error(logits, labels):
  # A toy stand-in loss for demonstration
  labels = labels.astype(jnp.float32)
  return (logits - labels) ** 2

def log_loss(batch, mask):
  logits = model(batch["image"])
  # logits and per-example losses are shape (batch,), so mask applies directly
  loss = jnp.mean(squared_error(logits, batch["label"]) * mask)
  print("loss:", float(np.asarray(loss)))
  return batch

loader = DataLoader(
  pipeline=[
    ArraySampleSource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
    BatchTransform(batch_size=128),
    HostCallbackTransform(fn=log_loss),
  ],
)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
  pass  # log_loss prints from the host each step
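
Under the hood, such a transform plausibly wraps the host function in jax.experimental.io_callback, declaring the output structure up front. A minimal sketch (not cyreal's actual code):

import jax
from jax.experimental import io_callback

def host_apply(fn, batch, mask):
  # io_callback needs the result shapes/dtypes ahead of time;
  # here fn returns the batch unchanged, so we mirror its structure.
  out_shapes = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), batch
  )
  # ordered=True keeps host-side prints in step order
  return io_callback(fn, out_shapes, batch, mask, ordered=True)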

Reinforcement Learning

GymnaxSource streams transitions from any Gymnax environment, one instance at a time. Keep the policy's trainable parameters and recurrent carries inside policy_state, and use the provided helpers (cyreal.rl.set_loader_policy_state and cyreal.rl.set_source_policy_state) to inject that state before the next epoch. This design keeps the pipeline ergonomic (one function call) while still making jax.vmap straightforward for batched rollouts. Your policy_step_fn also receives a boolean new_episode flag so it can reset its own recurrent state whenever the environment restarts.

import gymnax
import jax
import jax.numpy as jnp

from cyreal import BatchTransform, DataLoader, GymnaxSource
from cyreal.rl import set_loader_policy_state, set_source_policy_state

env = gymnax.environments.classic_control.cartpole.CartPole()
env_params = env.default_params

def policy_step(obs, policy_state, new_episode, key):
    del new_episode
    logits = obs @ policy_state["params"]
    action = jax.random.categorical(key, logits=logits)
    return action, policy_state

policy_state = {
    "params": jnp.zeros((4, 2)),
    "recurrent_state": jnp.zeros((3,)),
}

source = GymnaxSource(
    env=env,
    env_params=env_params,
    policy_step_fn=policy_step,
    policy_state_template=policy_state,
    steps_per_epoch=16,
)
pipeline = [
    source,
    BatchTransform(batch_size=16, drop_last=True),
]
loader = DataLoader(pipeline)
state = loader.init_state(jax.random.PRNGKey(0))
state = set_loader_policy_state(state, policy_state)

# Perform one epoch of environment interaction
for batch, mask in loader.iterate(state):
    ...  # update the policy from the collected transitions

# Push the updated parameters into the loader state before the next epoch
policy_state.update({"params": jnp.ones((4, 2))})
state = set_loader_policy_state(state, policy_state)
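
The new_episode flag earns its keep with recurrent policies. A sketch, with hypothetical params_in/params_out entries in policy_state, that zeroes the carry at episode boundaries:

import jax
import jax.numpy as jnp

def recurrent_policy_step(obs, policy_state, new_episode, key):
    # Reset the recurrent carry whenever the environment just restarted
    carry = jnp.where(new_episode, 0.0, policy_state["recurrent_state"])
    carry = jnp.tanh(carry + obs @ policy_state["params_in"])   # (obs_dim, hidden)
    logits = carry @ policy_state["params_out"]                 # (hidden, n_actions)
    action = jax.random.categorical(key, logits=logits)
    return action, {**policy_state, "recurrent_state": carry}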

Download files


Source Distribution

cyreal-0.1.2.tar.gz (28.4 kB)

Uploaded Source

Built Distribution


cyreal-0.1.2-py3-none-any.whl (31.8 kB)

Uploaded Python 3

File details

Details for the file cyreal-0.1.2.tar.gz.

File metadata

  • Download URL: cyreal-0.1.2.tar.gz
  • Upload date:
  • Size: 28.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cyreal-0.1.2.tar.gz:

  • SHA256: 3cfb70113877dc72bc38119906789e3ab5d0c5a4d840850bd771dad7ac569992
  • MD5: af79e92e11f91aa555b4605bbd88ef52
  • BLAKE2b-256: 6b3fec65c6f590c6582a80777ad8dad5c8b1b4520317f3b405ac4291d602ab7d

Provenance

The following attestation bundles were made for cyreal-0.1.2.tar.gz:

Publisher: python-publish.yml on smorad/cyreal


File details

Details for the file cyreal-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: cyreal-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 31.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for cyreal-0.1.2-py3-none-any.whl:

  • SHA256: feaa4de6a7c62546701d36b04771e22bac633a3461879b1727b8793fc5edb618
  • MD5: 8a685c5495f311a52cf0551f579614d5
  • BLAKE2b-256: c33140c665d6e5ff7bd1388f4d118f503b92e46766aa808c7587e673f1923d75

Provenance

The following attestation bundles were made for cyreal-0.1.2-py3-none-any.whl:

Publisher: python-publish.yml on smorad/cyreal

