Jittable data loading utilities for JAX.
Cyreal - Another JAX DataLoader
grain for the corporations, cyreal for the people
Pure JAX utilities for iterating over finite datasets without ever touching torch or tensorflow. Dataloaders support jax.jit, jax.grad, jax.lax.scan, and other function transformations.
Installation
The only dependency is JAX. On GPU machines, install the appropriate JAX build for your CUDA version.

```shell
pip install cyreal
```
Quick start with MNIST
Write torch-style dataloaders without torch
```python
import jax
import jax.numpy as jnp
from cyreal import (
    ArraySampleSource,
    BatchTransform,
    DataLoader,
    DevicePutTransform,
    MNISTDataset,
)

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
    ArraySampleSource(train_data, ordering="shuffle"),
    BatchTransform(batch_size=128),
    DevicePutTransform(),
]
loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))
for batch, mask in loader.iterate(state):
    ...  # train your network!
```
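The loop above yields a `(batch, mask)` pair. A common reason for such a mask is that the last batch of an epoch is padded up to `batch_size`, so every batch has the same static shape and jitted functions never recompile; the mask then flags the real rows. The sketch below assumes that convention (it is not taken from cyreal's source) and shows how padding plus a masked loss keeps padded rows out of the result. `pad_batch` and `masked_mse` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def pad_batch(x, batch_size):
    """Pad the leading axis of x up to batch_size and return (padded, mask).

    Hypothetical helper: mask is True for real rows and False for padding.
    """
    n = x.shape[0]
    pad = batch_size - n
    padded = jnp.concatenate([x, jnp.zeros((pad,) + x.shape[1:], x.dtype)], axis=0)
    mask = jnp.arange(batch_size) < n
    return padded, mask


@jax.jit
def masked_mse(preds, targets, mask):
    """Mean squared error over real rows only; padded rows contribute nothing."""
    err = (preds - targets) ** 2
    # Zero out padded rows, then divide by the number of real rows (at least 1).
    total = jnp.sum(jnp.where(mask, err, 0.0))
    return total / jnp.maximum(jnp.sum(mask), 1)


# A final batch with only 3 real samples, padded up to batch_size=5.
preds, mask = pad_batch(jnp.array([1.0, 2.0, 3.0]), batch_size=5)
targets = jnp.array([1.0, 2.0, 5.0, 9.0, 9.0])
loss = masked_mse(preds, targets, mask)  # only the 3 real rows count
```

Because the padded shape is fixed, `masked_mse` compiles once and is reused for every batch, including the short final one.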
Scan and Avoid Boilerplate
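DataLoader.scan_epoch runs a full pass through the dataset inside a single jax.lax.scan to minimize dispatch overhead. It jit-compiles body_fn, which takes the model state, a batch, and its mask, and returns the updated model state along with a per-step output (None here).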
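```python
# Assumes `loader` from the quick start above. update_model and the initial
# model_state are your own training step and model state.
loader_state = loader.init_state(jax.random.PRNGKey(0))

def body_fn(model_state, batch, mask):
    model_state = update_model(model_state, batch, mask)
    return model_state, None

loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
```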
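Conceptually, running an epoch inside one `jax.lax.scan` means shuffling indices with a PRNG key, reshaping them into fixed-size batches, and letting `scan` carry the model state from batch to batch, so the whole epoch is one compiled loop instead of one Python dispatch per batch. The sketch below illustrates that idea in plain JAX. It is not cyreal's implementation: `scan_epoch_sketch` and the running-sum "model" are hypothetical, and for simplicity the sketch drops a trailing partial batch instead of padding it.

```python
import jax
import jax.numpy as jnp


def scan_epoch_sketch(key, data, carry, body_fn, batch_size):
    """Run body_fn over every full, shuffled batch of data inside one lax.scan.

    Hypothetical helper: any trailing partial batch is dropped for simplicity.
    """
    n = data.shape[0]
    num_batches = n // batch_size
    perm = jax.random.permutation(key, n)[: num_batches * batch_size]
    # Shape (num_batches, batch_size, ...): scan iterates over the leading axis.
    batches = data[perm].reshape((num_batches, batch_size) + data.shape[1:])

    def step(carry, batch):
        return body_fn(carry, batch), None

    carry, _ = jax.lax.scan(step, carry, batches)
    return carry


def body_fn(total, batch):
    # Toy "model update": accumulate the sum of every sample seen.
    return total + jnp.sum(batch)


data = jnp.arange(10.0)
# body_fn and batch_size must be static: they determine the traced program.
run_epoch = jax.jit(scan_epoch_sketch, static_argnums=(3, 4))
total = run_epoch(jax.random.PRNGKey(0), data, jnp.float32(0.0), body_fn, 5)
```

With 10 samples and `batch_size=5`, every sample is visited exactly once, so `total` equals the sum of the dataset regardless of the shuffle order.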
JIT Capabilities
Do you enjoy premature optimization? Why not jit the entire train epoch?
```python
import jax
import jax.numpy as jnp
from cyreal import (
    ArraySampleSource,
    BatchTransform,
    DataLoader,
    DevicePutTransform,
    MNISTDataset,
)

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
    ArraySampleSource(train_data, ordering="shuffle"),
    BatchTransform(batch_size=128),
    DevicePutTransform(),
]
loader = DataLoader(pipeline)
loader_state = loader.init_state(jax.random.PRNGKey(0))
model_state = model_init()  # your own model initializer

@jax.jit
def train_epoch(model_state, loader_state):
    def body_fn(model_state, batch, mask):
        # Update the network using your train fn (model_update is your own)
        new_model_state = model_update(model_state, batch, mask)
        return new_model_state, None

    loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, body_fn)
    return model_state, loader_state

model_state, loader_state = train_epoch(model_state, loader_state)
```
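Jitting the entire epoch works because the loader state is passed in and returned explicitly instead of being mutated in place, which is what `jax.jit` requires. A minimal sketch of that pattern follows, assuming the state holds a PRNG key that is split once per epoch so each call produces a fresh shuffle. `LoaderState` and `shuffled_epoch` are hypothetical and do not mirror cyreal's internals.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LoaderState(NamedTuple):
    """Hypothetical loader state: a PRNG key and an epoch counter."""
    key: jax.Array
    epoch: jax.Array


@jax.jit
def shuffled_epoch(state, data):
    """Return this epoch's shuffled data and the state for the next epoch."""
    # Split so the key used for shuffling is never reused by a later epoch.
    next_key, shuffle_key = jax.random.split(state.key)
    order = jax.random.permutation(shuffle_key, data.shape[0])
    return data[order], LoaderState(next_key, state.epoch + 1)


data = jnp.arange(8)
state = LoaderState(jax.random.PRNGKey(0), jnp.array(0))
epoch0, state = shuffled_epoch(state, data)  # first shuffle
epoch1, state = shuffled_epoch(state, data)  # new key, so a new shuffle
```

Since `LoaderState` is a NamedTuple, JAX treats it as a pytree, so it can cross the `jit` boundary in both directions just like `loader_state` in the example above.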
Streaming from Disk
Is your dataset enormous? Swap in a disk-backed source.
```python
import jax
from cyreal import (
    BatchTransform,
    DataLoader,
    DevicePutTransform,
    MNISTDataset,
)

pipeline = [
    MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
    BatchTransform(batch_size=128),
    DevicePutTransform(),
]
loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))
for batch, mask in loader.iterate(state):
    ...  # stream without holding the dataset in RAM
```
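The idea behind a disk-backed source is that samples stay on disk and only a window of them is read into memory at a time. The sketch below illustrates that with a memory-mapped `.npy` file and a shuffled index order that is read in chunks of `prefetch_size`. It is a plain NumPy illustration under that assumption, not how `MNISTDataset.make_disk_source` is implemented, and `iter_disk_chunks` is a hypothetical name.

```python
import os
import tempfile

import numpy as np


def iter_disk_chunks(path, prefetch_size, seed=0):
    """Yield shuffled chunks of rows from a .npy file without loading it whole.

    Hypothetical helper: np.load(..., mmap_mode="r") maps the file, and only
    the rows indexed for the current chunk are copied into RAM.
    """
    data = np.load(path, mmap_mode="r")
    order = np.random.default_rng(seed).permutation(len(data))
    for start in range(0, len(order), prefetch_size):
        idx = order[start : start + prefetch_size]
        yield np.array(data[idx])  # copy just this chunk into memory


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "samples.npy")
    np.save(path, np.arange(10, dtype=np.int32).reshape(10, 1))
    chunks = list(iter_disk_chunks(path, prefetch_size=4))
```

Peak memory is bounded by `prefetch_size` rows rather than the dataset size; the trade-off is random-access reads from disk, which a real implementation would likely buffer or prefetch in the background.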
For the Dirty and Impure
Want to jit but also log some metrics? Use HostCallbackTransform, which uses jax.experimental.io_callback under the hood.
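```python
import jax
import jax.numpy as jnp
import numpy as np
from cyreal import (
    ArraySampleSource,
    BatchTransform,
    DataLoader,
    HostCallbackTransform,
    MNISTDataset,
)

def model(images):
    # Toy "model": one prediction per image (its mean pixel value).
    return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))

def squared_error(preds, labels):
    # Toy per-example loss; a real classifier would use cross-entropy.
    labels = labels.astype(jnp.float32)
    return (preds - labels) ** 2

def log_loss(batch, mask):
    preds = model(batch["image"])
    per_example = squared_error(preds, batch["label"])  # shape (batch,)
    # Average over real examples only (assumes mask is True for real rows).
    loss = jnp.sum(per_example * mask) / jnp.maximum(jnp.sum(mask), 1)
    print("loss:", float(np.asarray(loss)))
    return batch

loader = DataLoader(
    pipeline=[
        ArraySampleSource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
        BatchTransform(batch_size=128),
        HostCallbackTransform(fn=log_loss),
    ],
)
state = loader.init_state(jax.random.PRNGKey(0))
for batch, mask in loader.iterate(state):
    pass  # log_loss is called via the host callback as batches are produced
```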
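Since HostCallbackTransform is built on jax.experimental.io_callback, it can help to see a bare io_callback call inside a jitted function. The following is a standalone sketch, not HostCallbackTransform itself; the `logged` list and `log_on_host` function are hypothetical. `ordered=True` asks JAX to run the callbacks in program order, and the callback returns its input so the output shape and dtype can be declared with `jax.ShapeDtypeStruct`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

logged = []  # host-side record of every value the callback receives


def log_on_host(loss):
    # Runs in Python on the host with a NumPy array, so side effects are allowed.
    value = float(loss)
    logged.append(value)
    print("loss:", value)
    return np.asarray(loss, dtype=np.float32)


@jax.jit
def step(x):
    loss = jnp.mean(x ** 2)
    # Declare the callback's output shape/dtype so jit can trace through it.
    return io_callback(
        log_on_host, jax.ShapeDtypeStruct((), jnp.float32), loss, ordered=True
    )


out = step(jnp.array([1.0, 2.0, 3.0]))
jax.block_until_ready(out)  # make sure the callback has run before reading `logged`
```

The jitted computation stays compiled; only the small scalar crosses back to Python, which is why this pattern is cheap enough for per-batch metric logging.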
File details
Details for the file cyreal-0.1.1.tar.gz.
File metadata
- Download URL: cyreal-0.1.1.tar.gz
- Upload date:
- Size: 20.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | c7f0457ee5905e254c5a1033e9cc44b6de6a3644ab7f6832ed335a1778efb3b8 |
| MD5 | 40f006cb185cc28d5d43d5d79854d421 |
| BLAKE2b-256 | 66ffbb0a3d774a743a66cbcee8365a90b4bf743dad372fc8176480933562c426 |

Provenance
The following attestation bundles were made for cyreal-0.1.1.tar.gz:
- Publisher: python-publish.yml on smorad/cyreal
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: cyreal-0.1.1.tar.gz
- Subject digest: c7f0457ee5905e254c5a1033e9cc44b6de6a3644ab7f6832ed335a1778efb3b8
- Sigstore transparency entry: 764331743
- Sigstore integration time:
- Permalink: smorad/cyreal@154c52a09ddbe0e2b24ceea5810ef6c6402fcb3d
- Branch / Tag: refs/tags/0.1.1
- Owner: https://github.com/smorad
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@154c52a09ddbe0e2b24ceea5810ef6c6402fcb3d
- Trigger Event: release
File details
Details for the file cyreal-0.1.1-py3-none-any.whl.
File metadata
- Download URL: cyreal-0.1.1-py3-none-any.whl
- Upload date:
- Size: 17.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 9b8ae11f2bff7f708368b3384ade66e7c358fcfd23bbae6cb2255eec587d5343 |
| MD5 | 4ab36107cbc8d45cecd8a7f15f4d2e18 |
| BLAKE2b-256 | 1b697e3c9ef8ab021be9abc6e303c69acd9d4f40dc2f5e9b423f761f9d4a045b |

Provenance
The following attestation bundles were made for cyreal-0.1.1-py3-none-any.whl:
- Publisher: python-publish.yml on smorad/cyreal
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: cyreal-0.1.1-py3-none-any.whl
- Subject digest: 9b8ae11f2bff7f708368b3384ade66e7c358fcfd23bbae6cb2255eec587d5343
- Sigstore transparency entry: 764331752
- Sigstore integration time:
- Permalink: smorad/cyreal@154c52a09ddbe0e2b24ceea5810ef6c6402fcb3d
- Branch / Tag: refs/tags/0.1.1
- Owner: https://github.com/smorad
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: python-publish.yml@154c52a09ddbe0e2b24ceea5810ef6c6402fcb3d
- Trigger Event: release