Jittable data loading utilities for JAX.
Cyreal - Another JAX DataLoader
`grain` for the corporations, `cyreal` for the people
Pure JAX utilities for iterating over finite datasets without ever touching torch or tensorflow. Dataloaders support `jax.jit`, `jax.grad`, `jax.lax.scan`, and other function transformations.
Installation
The only dependency is jax. On GPU machines, install the appropriate JAX build for your CUDA version.

```shell
pip install cyreal
```
Quick start with MNIST
Write torch-style dataloaders without torch
```python
import jax
import jax.numpy as jnp

from cyreal import (
    ArraySampleSource,
    BatchTransform,
    DataLoader,
    DevicePutTransform,
    MNISTDataset,
)

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
    ArraySampleSource(train_data, ordering="shuffle"),
    BatchTransform(batch_size=128),
    DevicePutTransform(),
]
loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
    ...  # train your network!
```
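Each batch comes with a boolean mask marking which rows are real samples, since the final batch of an epoch may be padded. A minimal sketch of masked averaging in plain `jax.numpy` (the 96-of-128 split is invented for illustration):

```python
import jax.numpy as jnp

# Hypothetical final batch: 128 slots, but only the first 96 are real samples.
per_sample_loss = jnp.ones((128,))
mask = jnp.arange(128) < 96

# Average only over valid samples; padding rows contribute nothing.
loss = jnp.sum(per_sample_loss * mask) / jnp.maximum(jnp.sum(mask), 1)
```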
See full training examples:
Scan and Avoid Boilerplate
DataLoader.scan_epoch compiles a full pass over the dataset into a single jax.lax.scan, minimizing dispatch overhead. The body_fn is jitted as part of the scan.
```python
def body_fn(model_state, batch, mask):
    model_state = update_model(model_state, batch, mask)
    return model_state, None

loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
```
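For intuition, this is the standard jax.lax.scan pattern: the model state is the carry and pre-batched data is scanned along the leading axis. A self-contained sketch in plain JAX, with a toy accumulate-the-mean "update" standing in for a real training step:

```python
import jax
import jax.numpy as jnp

# Pre-batched toy data: 4 batches of 8 scalar samples, all valid.
batches = jnp.arange(4 * 8, dtype=jnp.float32).reshape(4, 8)
masks = jnp.ones((4, 8), dtype=bool)

def body_fn(model_state, batch_and_mask):
    batch, mask = batch_and_mask
    # Toy "update": accumulate the masked batch mean into the carry.
    model_state = model_state + jnp.mean(jnp.where(mask, batch, 0.0))
    return model_state, None

# One fused pass over all batches; equivalent to a Python loop, minus dispatch.
model_state, _ = jax.lax.scan(body_fn, 0.0, (batches, masks))
```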
JIT Capabilities
Do you enjoy premature optimization? Why not jit the entire training epoch?
```python
import jax
import jax.numpy as jnp

from cyreal import (
    ArraySampleSource,
    BatchTransform,
    DataLoader,
    DevicePutTransform,
    MNISTDataset,
)

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
    ArraySampleSource(train_data, ordering="shuffle"),
    BatchTransform(batch_size=128),
    DevicePutTransform(),
]
loader = DataLoader(pipeline)
loader_state = loader.init_state(jax.random.PRNGKey(0))
model_state = model_init()

@jax.jit
def train_epoch(model_state, loader_state):
    def body_fn(model_state, batch, mask):
        # Update the network using your train fn
        new_model_state = model_update(model_state, batch, mask)
        return new_model_state, None

    loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
    return model_state, loader_state

model_state, loader_state = train_epoch(model_state, loader_state)
```
Streaming from Disk
Is your dataset enormous? Swap in a disk-backed source.
```python
import jax

from cyreal import (
    BatchTransform,
    DataLoader,
    DevicePutTransform,
    MNISTDataset,
)

pipeline = [
    MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
    BatchTransform(batch_size=128),
    DevicePutTransform(),
]
loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
    ...  # stream without holding the dataset in RAM
```
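The prefetch_size argument bounds how much data is buffered ahead of the consumer. As a rough illustration of the general idea (not cyreal's actual implementation), a bounded background-thread prefetcher looks like:

```python
import queue
import threading

def prefetch(iterator, size):
    # Fill a bounded queue on a background thread; maxsize caps peak RAM.
    q = queue.Queue(maxsize=size)
    done = object()

    def worker():
        for item in iterator:
            q.put(item)
        q.put(done)

    threading.Thread(target=worker, daemon=True).start()
    while (item := q.get()) is not done:
        yield item

batches = list(prefetch(iter(range(5)), size=2))
```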
For the Dirty and Impure
Want to jit but also log some metrics? Use HostCallbackTransform, which uses jax.experimental.io_callback under the hood.
```python
import jax.numpy as jnp
import numpy as np

from cyreal import (
    ArraySampleSource,
    BatchTransform,
    DataLoader,
    HostCallbackTransform,
    MNISTDataset,
)

def model(images):
    return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))

def cross_entropy(logits, labels):
    labels = labels.astype(jnp.float32)
    return (logits - labels) ** 2

def log_loss(batch, mask):
    logits = model(batch["image"])
    # Mask out padded samples; logits, labels, and mask all have shape (B,).
    loss = jnp.mean(cross_entropy(logits, batch["label"]) * mask)
    print("loss:", float(np.asarray(loss)))
    return batch

loader = DataLoader(
    pipeline=[
        ArraySampleSource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
        BatchTransform(batch_size=128),
        HostCallbackTransform(fn=log_loss),
    ],
)
```
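For reference, jax.experimental.io_callback lets a jitted computation hand a value to the host and receive a result back. A standalone sketch (independent of cyreal) that logs a loss from inside jit:

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

logged = []

def host_log(loss):
    # Runs on the host with a concrete value, even under jit.
    logged.append(float(loss))
    return loss  # must match the declared result shape/dtype

@jax.jit
def step(x):
    loss = jnp.sum(x ** 2)
    # Round-trip the loss through the host callback.
    loss = io_callback(host_log, jax.ShapeDtypeStruct((), loss.dtype), loss)
    return loss

out = step(jnp.array([1.0, 2.0]))
```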
Reinforcement Learning
GymnaxSource streams transitions from any Gymnax environment, one instance at a time. Keep the policy's trainable parameters and recurrent carries inside the policy_state, and use the provided helpers (set_loader_policy_state and set_source_policy_state from cyreal.rl) to inject that state before calling next. This design keeps the pipeline ergonomic (one method call) while still making jax.vmap straightforward for batched rollouts. Your policy_step_fn also receives a boolean new_episode flag so it can reset its own recurrent state whenever the environment restarts.
```python
import gymnax
import jax
import jax.numpy as jnp

from cyreal import BatchTransform, DataLoader, GymnaxSource
from cyreal.rl import set_loader_policy_state, set_source_policy_state

env = gymnax.environments.classic_control.cartpole.CartPole()
env_params = env.default_params

def policy_step(obs, policy_state, new_episode, key):
    del new_episode
    logits = obs @ policy_state["params"]
    action = jax.random.categorical(key, logits=logits)
    return action, policy_state

policy_state = {
    "params": jnp.zeros((4, 2)),
    "recurrent_state": jnp.zeros((3,)),
}
source = GymnaxSource(
    env=env,
    env_params=env_params,
    policy_step_fn=policy_step,
    policy_state_template=policy_state,
    steps_per_epoch=16,
)
pipeline = [
    source,
    BatchTransform(batch_size=16, drop_last=True),
]
loader = DataLoader(pipeline)
state = loader.init_state(jax.random.PRNGKey(0))
state = set_loader_policy_state(state, policy_state)

# Perform one epoch
for batch, mask in loader.iterate(state):
    ...

# Update the policy state (parameters) after each epoch
policy_state.update({"params": jnp.ones((4, 2))})
state = set_loader_policy_state(state, policy_state)
```
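Because the policy state is an ordinary pytree, batched rollouts reduce to jax.vmap over observations and PRNG keys while sharing the parameters. A hedged sketch with an illustrative linear policy (not cyreal's API):

```python
import jax
import jax.numpy as jnp

def policy_step(obs, params, key):
    # Illustrative linear policy: logits = obs @ params.
    logits = obs @ params
    return jax.random.categorical(key, logits=logits)

# Batch the rollout: map over observations and keys, share the parameters.
obs = jnp.zeros((16, 4))
keys = jax.random.split(jax.random.PRNGKey(0), 16)
params = jnp.zeros((4, 2))
actions = jax.vmap(policy_step, in_axes=(0, None, 0))(obs, params, keys)
```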
File details
Details for the file cyreal-0.1.2.tar.gz.
File metadata
- Download URL: cyreal-0.1.2.tar.gz
- Upload date:
- Size: 28.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `3cfb70113877dc72bc38119906789e3ab5d0c5a4d840850bd771dad7ac569992` |
| MD5 | `af79e92e11f91aa555b4605bbd88ef52` |
| BLAKE2b-256 | `6b3fec65c6f590c6582a80777ad8dad5c8b1b4520317f3b405ac4291d602ab7d` |
Provenance
The following attestation bundles were made for cyreal-0.1.2.tar.gz:
Publisher: python-publish.yml on smorad/cyreal

- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: cyreal-0.1.2.tar.gz
- Subject digest: 3cfb70113877dc72bc38119906789e3ab5d0c5a4d840850bd771dad7ac569992
- Sigstore transparency entry: 764638700
- Sigstore integration time:
- Permalink: smorad/cyreal@5e168897fddbb6d993fb8ba6b44051443a3d8223
- Branch / Tag: refs/tags/0.1.2
- Owner: https://github.com/smorad
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@5e168897fddbb6d993fb8ba6b44051443a3d8223
- Trigger Event: release
File details
Details for the file cyreal-0.1.2-py3-none-any.whl.
File metadata
- Download URL: cyreal-0.1.2-py3-none-any.whl
- Upload date:
- Size: 31.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `feaa4de6a7c62546701d36b04771e22bac633a3461879b1727b8793fc5edb618` |
| MD5 | `8a685c5495f311a52cf0551f579614d5` |
| BLAKE2b-256 | `c33140c665d6e5ff7bd1388f4d118f503b92e46766aa808c7587e673f1923d75` |
Provenance
The following attestation bundles were made for cyreal-0.1.2-py3-none-any.whl:
Publisher: python-publish.yml on smorad/cyreal

- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: cyreal-0.1.2-py3-none-any.whl
- Subject digest: feaa4de6a7c62546701d36b04771e22bac633a3461879b1727b8793fc5edb618
- Sigstore transparency entry: 764638704
- Sigstore integration time:
- Permalink: smorad/cyreal@5e168897fddbb6d993fb8ba6b44051443a3d8223
- Branch / Tag: refs/tags/0.1.2
- Owner: https://github.com/smorad
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@5e168897fddbb6d993fb8ba6b44051443a3d8223
- Trigger Event: release